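"""Custom `rd record` variant. Extras vs. plain `rd record`:

  1. `--skip-wrist-cameras`  Omit wrist cameras (role left_wrist/right_wrist)
                             from this session — no SVO2 written, smaller
                             per-episode footprint, faster startup.
  2. `--rerun-viewer`        Stream camera frames to a live Rerun web viewer
                             (throttled to --rerun-viz-hz per camera and
                             downscaled to --rerun-viz-width). Useful for
                             monitoring scene framing / gripper view during
                             teleop+record.
  3. `--render_robots`       Overlay robot silhouettes, rendered from live
                             follower joint positions and the current
                             calibration, on the scene-camera viz stream.
                             Implies the Rerun viewer.
  4. `--arms`                Record from right_only (default), left_only, or
                             both arms.

Internals:
  - The camera filter: clone `~/.config/raiden/camera.json`, strip wrist cams,
    write to a temp file, pass to raiden's `run_recording`.
  - The arm selection: monkey-patch `RobotController.__init__` to force the
    unused side's leader/follower flags off (run_recording is always called
    with arms="bimanual").
  - The live viewer: monkey-patch `DemonstrationRecorder._camera_loop` to also
    `rr.log()` frames (the original loop only calls grab(), so we add a
    get_frame() + log step). Rerun runs in web mode so you can view from a
    browser through an SSH tunnel.

Usage (on a YAM, with raiden_fork venv active):
    python ~/cameron/yam_control/record_skip_wrist.py \\
        --skip-wrist-cameras --rerun-viewer
"""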
import os
# These must be in the environment before anything imports mujoco, because
# mujoco picks its GL backend on first import. EGL renders headlessly (no X
# server needed). setdefault keeps any value the user exported themselves.
for _gl_env_var in ("MUJOCO_GL", "PYOPENGL_PLATFORM"):
    os.environ.setdefault(_gl_env_var, "egl")
del _gl_env_var

import argparse
import json
import tempfile
import threading
from pathlib import Path

import numpy as np

from raiden._config import CAMERA_CONFIG
from raiden.control import build_interface
from raiden.recorder import run_recording


# Camera roles dropped by --skip-wrist-cameras. A frozenset, so this
# module-level constant cannot be mutated by accident.
_WRIST_ROLES = frozenset({"left_wrist", "right_wrist"})


def _patch_robot_controller_arms(use_right: bool, use_left: bool) -> None:
    """Force arm selection on every RobotController constructed after this call.

    raiden's run_recording hardcodes use_left=True, so rather than fork
    run_recording we wrap RobotController.__init__ and overwrite the
    leader/follower keyword flags of whichever side is disabled.

    Args:
        use_right: Keep the right leader/follower enabled.
        use_left: Keep the left leader/follower enabled.

    Raises:
        ValueError: If both sides are disabled (there would be nothing to record).
    """
    if not (use_right or use_left):
        raise ValueError("at least one of use_right / use_left must be True")

    import functools
    from raiden.robot.controller import RobotController

    _orig_init = RobotController.__init__

    @functools.wraps(_orig_init)
    def _patched_init(self, *args, **kwargs):
        # Flags are forced as keywords. NOTE(review): if a caller ever passed
        # these flags positionally, Python would raise TypeError ("got multiple
        # values"); presumably run_recording passes them by keyword — confirm.
        if not use_right:
            kwargs["use_right_leader"] = False
            kwargs["use_right_follower"] = False
        if not use_left:
            kwargs["use_left_leader"] = False
            kwargs["use_left_follower"] = False
        _orig_init(self, *args, **kwargs)

    RobotController.__init__ = _patched_init
    if use_right and use_left:
        side = "both"
    else:
        side = "right_only" if use_right else "left_only"
    print(f"  --arms {side}: patched RobotController.__init__")


def _filter_camera_config(src_path: str, skip_wrist: bool) -> str:
    """Return a path to a camera.json suitable for run_recording.

    If ``skip_wrist`` is False, ``src_path`` is returned unchanged. Otherwise
    the config is loaded, every entry whose ``"role"`` is in ``_WRIST_ROLES``
    is dropped, and the remainder is written to a new temp file. The temp
    file is created with ``delete=False``, so the caller must remove it.

    Args:
        src_path: Path to the source camera.json (a JSON object keyed by
            camera name).
        skip_wrist: Whether to strip wrist cameras.

    Returns:
        ``src_path`` when not filtering, else the path of the filtered temp file.
    """
    if not skip_wrist:
        return src_path

    with open(src_path, encoding="utf-8") as f:
        cfg = json.load(f)

    kept = {}
    dropped = []
    for name, entry in cfg.items():
        # Entries that are not objects carry no role; pass them through as-is.
        role = entry.get("role") if isinstance(entry, dict) else None
        if role in _WRIST_ROLES:
            dropped.append((name, role))
        else:
            kept[name] = entry

    if dropped:
        print(f"  --skip-wrist-cameras: dropping {len(dropped)} wrist camera(s):")
        for name, role in dropped:
            print(f"     - {name} (role={role})")
    else:
        print("  --skip-wrist-cameras: no wrist cameras found in camera config")

    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix="_camera.json", delete=False,
        prefix="record_skip_wrist_", encoding="utf-8",
    )
    try:
        # `with` guarantees the handle is closed even if json.dump raises.
        with tmp:
            json.dump(kept, tmp, indent=2)
    except BaseException:
        # delete=False means nobody else will clean up a half-written file.
        os.unlink(tmp.name)
        raise
    print(f"  filtered camera config written to {tmp.name}")
    return tmp.name


def _start_rerun_viewer(web_port: int) -> None:
    """Launch Rerun's gRPC log server and web viewer, then print how to reach them.

    The gRPC server listens on ``web_port + 1``. The printed banner holds the
    viewer URL, with the gRPC URI percent-encoded into its ``url`` query
    parameter, and an ``ssh -L`` command that forwards both ports.
    """
    from urllib.parse import quote
    import rerun as rr

    grpc_port = web_port + 1
    rr.init("raiden_record")
    grpc_uri = rr.serve_grpc(grpc_port=grpc_port)
    rr.serve_web_viewer(web_port=web_port, open_browser=False)

    encoded_uri = quote(grpc_uri, safe="")
    viewer_url = f"http://localhost:{web_port}?url={encoded_uri}"
    tunnel_cmd = (
        f"ssh -L {web_port}:localhost:{web_port} "
        f"-L {grpc_port}:localhost:{grpc_port} <host>"
    )
    rule = "=" * 60
    banner = [
        "",
        rule,
        f"  Rerun viewer: {viewer_url}",
        f"  SSH tunnel:   {tunnel_cmd}",
        rule,
        "",
    ]
    for line in banner:
        print(line)


def _load_scene_calibration(calibration_file=None):
    """Load scene-camera calibration plus the bimanual base transform.

    Args:
        calibration_file: Path to a calibration_results.json. Defaults to
            raiden's ``CALIBRATION_FILE``.

    Returns:
        ``(K_native, T_cam2world, T_left_from_right, scene_cam_name)``:

        - K_native: float64 array from ``intrinsics.camera_matrix``.
        - T_cam2world: 4x4 float64 homogeneous transform built from
          ``extrinsics.rotation_matrix`` (reshaped to 3x3) and
          ``extrinsics.translation_vector`` (reshaped to 3).
        - T_left_from_right: inverse of
          ``bimanual_transform.right_base_to_left_base`` (see note below).
        - scene_cam_name: the calibration key that was used.

    Raises:
        RuntimeError: If no scene camera exists, or a required key is missing.
    """
    if calibration_file is None:
        from raiden._config import CALIBRATION_FILE
        calibration_file = CALIBRATION_FILE

    with open(calibration_file, encoding="utf-8") as f:
        cal = json.load(f)

    cams = cal.get("cameras", {})
    # Prefer the conventional name; otherwise take the first camera whose
    # name starts with "scene".
    if "scene_camera" in cams:
        scene_name = "scene_camera"
    else:
        scene_name = next((n for n in cams if n.startswith("scene")), None)
    if scene_name is None:
        raise RuntimeError("No scene camera found in calibration_results.json")

    try:
        sc = cams[scene_name]
        K = np.asarray(sc["intrinsics"]["camera_matrix"], dtype=np.float64)
        R = np.asarray(
            sc["extrinsics"]["rotation_matrix"], dtype=np.float64
        ).reshape(3, 3)
        t = np.asarray(
            sc["extrinsics"]["translation_vector"], dtype=np.float64
        ).reshape(3)
        M = np.asarray(
            cal["bimanual_transform"]["right_base_to_left_base"], dtype=np.float64
        )
    except KeyError as e:
        raise RuntimeError(
            f"calibration file {calibration_file} is missing key {e}"
        ) from e

    T_cam2world = np.eye(4, dtype=np.float64)
    T_cam2world[:3, :3] = R
    T_cam2world[:3, 3] = t
    # raiden's `bimanual_transform.right_base_to_left_base` JSON key is
    # *misnamed* — the matrix actually maps left_base → right_base. raiden's
    # own converter (converter.py:1021) inverts it before passing it through
    # the rest of the pipeline as `T_left_from_right`. Match that.
    T_lfr = np.linalg.inv(M)
    return K, T_cam2world, T_lfr, scene_name


class _OverlayWorker(threading.Thread):
    """Daemon thread that periodically renders bimanual robot silhouettes.

    Every ``period_s`` seconds it reads follower joint positions from the
    robot controller and calls ``yam_overlay_render.render_bimanual`` at
    ``render_w`` x ``render_h``. It stores the resulting ``(mask_l, mask_r)``
    pair under a lock. The camera loop overlays whichever masks are current,
    so they are roughly ``period_s`` (plus render time) stale at most.
    """

    def __init__(
        self,
        robot_controller,
        K_render,
        T_cam2world,
        T_left_from_right,
        render_w: int,
        render_h: int,
        period_s: float,
    ):
        super().__init__(daemon=True, name="overlay-worker")
        self.rc = robot_controller
        self.K = K_render
        self.T_cam2world = T_cam2world
        self.T_lfr = T_left_from_right
        self.W = render_w
        self.H = render_h
        self.period_s = period_s
        # Deliberately NOT named `_stop`: threading.Thread has a private
        # `_stop()` method that join()/is_alive() call internally, and
        # shadowing it with an Event breaks them ("'Event' object is not
        # callable").
        self._stop_event = threading.Event()
        self._latest = (None, None)
        self._lock = threading.Lock()

    def stop(self):
        """Ask the render loop to exit; it checks the event each period."""
        self._stop_event.set()

    @staticmethod
    def _joints_or_zeros(follower):
        """Return follower joint positions, or np.zeros(7) if unavailable."""
        q = follower.get_joint_pos() if follower else None
        return np.asarray(q) if q is not None else np.zeros(7)

    def run(self):
        import sys
        import time

        # yam_overlay_render lives next to this script; add that directory
        # to sys.path only once.
        script_dir = str(Path(__file__).parent)
        if script_dir not in sys.path:
            sys.path.insert(0, script_dir)
        from yam_overlay_render import render_bimanual

        while not self._stop_event.is_set():
            t0 = time.monotonic()
            try:
                # A missing side is padded with zeros so the render still works.
                joints14 = np.concatenate([
                    self._joints_or_zeros(self.rc.follower_l),
                    self._joints_or_zeros(self.rc.follower_r),
                ])
                ml, mr = render_bimanual(
                    joints14, self.T_cam2world, self.T_lfr, self.K, self.W, self.H
                )
                with self._lock:
                    self._latest = (ml, mr)
            except Exception as e:
                print(f"  [render_robots] render failed: {e}")
            # Sleep for the rest of the period; a stop() wakes us up early.
            elapsed = time.monotonic() - t0
            self._stop_event.wait(max(0.0, self.period_s - elapsed))

    def get_masks_resized(self, target_w: int, target_h: int):
        """Return the latest masks resized to (target_w, target_h).

        Returns ``(None, None)`` until the first successful render.
        """
        with self._lock:
            ml, mr = self._latest
        if ml is None:
            return None, None
        import cv2
        size = (target_w, target_h)
        # Nearest-neighbour: no interpolation blends values between pixels.
        ml_r = cv2.resize(ml, size, interpolation=cv2.INTER_NEAREST)
        mr_r = cv2.resize(mr, size, interpolation=cv2.INTER_NEAREST)
        return ml_r, mr_r


def _patch_camera_loop_for_rerun(viz_hz: float, viz_width: int, jpeg_quality: int,
                                 render_robots: bool, render_factor: int,
                                 render_period_s: float) -> None:
    """Monkey-patch DemonstrationRecorder._camera_loop to also log frames to rerun.

    Logging is throttled to viz_hz per camera and resized to viz_width
    (aspect ratio preserved). This keeps the gRPC pipe from backing up. Only
    logging is throttled: every loop iteration still calls camera.grab(), so
    the SVO2 recording is unaffected.

    With render_robots set, the scene camera's loop starts an _OverlayWorker
    from its first successfully grabbed frame; the native size is needed to
    scale the intrinsics. The worker is stopped again when that loop exits.

    Args:
        viz_hz: Max frames logged per second per camera (clamped to >= 0.1).
        viz_width: Downscale width for logged frames; <= 0 keeps native width.
        jpeg_quality: cv2 JPEG quality used for the logged frames.
        render_robots: Overlay robot silhouettes on the scene-camera stream.
        render_factor: Render silhouettes at 1/render_factor of native
            resolution (clamped to >= 1).
        render_period_s: Seconds between silhouette re-renders.
    """
    import time as _time
    import cv2
    import rerun as rr
    from raiden.recorder import DemonstrationRecorder

    min_dt = 1.0 / max(viz_hz, 0.1)
    # Guard against 0, which would divide by zero when sizing the render.
    render_factor = max(1, int(render_factor))
    _last_log_t: dict = {}  # camera name -> monotonic time of last rr.log

    # Overlay state shared across camera threads; only the scene camera's
    # loop reads or writes "worker".
    _overlay_state = {"worker": None, "scene_name": None, "start_failed": False}
    if render_robots:
        K_native, T_cam2world, T_lfr, scene_name = _load_scene_calibration()
        _overlay_state.update({
            "K_native": K_native, "T_cam2world": T_cam2world,
            "T_lfr": T_lfr, "scene_name": scene_name,
        })
        print(f"  --render_robots: scene cam='{scene_name}', factor=1/{render_factor}, "
              f"period={render_period_s:.1f}s")

    def _start_overlay_worker(recorder, color):
        """Scale K to the render resolution and start an _OverlayWorker."""
        nat_h, nat_w = color.shape[:2]
        rw = max(1, nat_w // render_factor)
        rh = max(1, nat_h // render_factor)
        sx, sy = rw / nat_w, rh / nat_h
        K_r = _overlay_state["K_native"].copy()
        K_r[0, 0] *= sx  # fx
        K_r[0, 2] *= sx  # cx
        K_r[1, 1] *= sy  # fy
        K_r[1, 2] *= sy  # cy
        worker = _OverlayWorker(
            robot_controller=recorder.robot_controller,
            K_render=K_r,
            T_cam2world=_overlay_state["T_cam2world"],
            T_left_from_right=_overlay_state["T_lfr"],
            render_w=rw, render_h=rh,
            period_s=render_period_s,
        )
        worker.start()
        _overlay_state["worker"] = worker
        print(f"  [render_robots] worker started at {rw}x{rh} "
              f"(native {nat_w}x{nat_h})")

    def _log_frame(camera, frame, is_scene):
        """Downscale, optionally overlay, JPEG-encode and rr.log one frame."""
        img = frame.color  # HxWx3 BGR

        if viz_width > 0 and img.shape[1] > viz_width:
            h, w = img.shape[:2]
            new_h = int(round(h * viz_width / w))
            img = cv2.resize(img, (viz_width, new_h), interpolation=cv2.INTER_AREA)

        worker = _overlay_state["worker"] if is_scene else None
        if worker is not None:
            ml, mr = worker.get_masks_resized(img.shape[1], img.shape[0])
            if ml is not None:
                # Importable because the worker put its directory on sys.path
                # before producing any masks.
                from yam_overlay_render import overlay_contours
                img = overlay_contours(img, ml, mr)

        ok, buf = cv2.imencode(".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, jpeg_quality])
        if not ok:
            return
        rr.set_time("camera_time", timestamp=frame.timestamp_ns / 1e9)
        rr.log(
            f"cameras/{camera.name}",
            rr.EncodedImage(contents=bytes(buf), media_type="image/jpeg"),
        )

    def _camera_loop_with_rerun(self, camera, stop_event):
        is_scene = render_robots and camera.name == _overlay_state["scene_name"]
        try:
            while not stop_event.is_set():
                if not camera.grab():
                    continue

                # Start the overlay worker on the first grabbed scene frame;
                # a failed grab simply retries on the next iteration.
                if (is_scene and _overlay_state["worker"] is None
                        and not _overlay_state["start_failed"]):
                    try:
                        _start_overlay_worker(self, camera.get_frame().color)
                    except Exception as e:
                        _overlay_state["start_failed"] = True  # don't retry per frame
                        print(f"  [render_robots] failed to start worker: {e}")

                now = _time.monotonic()
                if now - _last_log_t.get(camera.name, 0.0) < min_dt:
                    continue
                _last_log_t[camera.name] = now

                try:
                    _log_frame(camera, camera.get_frame(), is_scene)
                except Exception as e:
                    print(f"  [rerun] {camera.name}: log failed: {e}")
        finally:
            # Tie the worker's lifetime to the scene camera loop so it stops
            # polling the robot controller once this loop ends.
            worker = _overlay_state["worker"] if is_scene else None
            if worker is not None:
                worker.stop()
                _overlay_state["worker"] = None

    DemonstrationRecorder._camera_loop = _camera_loop_with_rerun
    print(f"  --rerun-viewer: patched _camera_loop "
          f"(viz_hz={viz_hz}, viz_width={viz_width}, jpeg_q={jpeg_quality}, "
          f"render_robots={render_robots})")


def main():
    """Parse CLI flags, install the requested monkey-patches, then record.

    Every patch is installed before run_recording is called. A filtered temp
    camera config, if one was created, is removed once setup or recording
    returns or raises.
    """
    import contextlib

    p = argparse.ArgumentParser()
    p.add_argument("--skip-wrist-cameras", action="store_true",
                   help="Omit wrist cameras (role=left_wrist/right_wrist) from this session.")
    p.add_argument("--rerun-viewer", action="store_true",
                   help="Stream camera frames to a live Rerun web viewer.")
    p.add_argument("--rerun-web-port", type=int, default=9090,
                   help="Web port for the Rerun viewer (default: 9090).")
    p.add_argument("--rerun-viz-hz", type=float, default=10.0,
                   help="Max log rate per camera (default: 10 Hz). The grab "
                        "loop still runs at native 30 Hz; this only throttles viz.")
    p.add_argument("--rerun-viz-width", type=int, default=640,
                   help="Resize width before logging (default: 640 px, aspect "
                        "preserved). 0 = native resolution.")
    p.add_argument("--rerun-jpeg-quality", type=int, default=70,
                   help="JPEG quality for viz frames (default: 70).")
    p.add_argument("--render_robots", action="store_true",
                   help="Overlay MuJoCo robot silhouette on the scene-camera "
                        "viz stream using current extrinsics + live joints. "
                        "Useful for visually checking calibration quality.")
    p.add_argument("--render_robots_factor", type=int, default=4,
                   help="Render robot silhouette at 1/N of native scene-cam "
                        "resolution (default: 4 — i.e. 320x180 for 1280x720).")
    p.add_argument("--render_robots_period_s", type=float, default=1.0,
                   help="Re-render the silhouette every N seconds (default: 1).")
    p.add_argument("--control", default="leader", choices=["leader", "spacemouse"])
    p.add_argument("--arms", default="right_only",
                   choices=["right_only", "left_only", "both"],
                   help="Which arms to record from (default: right_only).")
    p.add_argument("--data_dir", default="data",
                   help="Root data directory (default: ./data).")
    p.add_argument("--s3_bucket", default=None)
    p.add_argument("--s3_prefix", default="demonstrations")
    p.add_argument("--camera_config", default=CAMERA_CONFIG,
                   help=f"Source camera config (default: {CAMERA_CONFIG}).")
    args = p.parse_args()

    # Reject values that would otherwise only fail later, inside a camera thread.
    if args.render_robots_factor < 1:
        p.error("--render_robots_factor must be >= 1")
    if not 0 <= args.rerun_jpeg_quality <= 100:
        p.error("--rerun-jpeg-quality must be between 0 and 100")

    cam_cfg = _filter_camera_config(args.camera_config, args.skip_wrist_cameras)

    # Everything after the temp file exists runs under try, so a failure
    # while patching or starting the viewer does not leak the filtered config.
    try:
        use_right = args.arms in ("right_only", "both")
        use_left = args.arms in ("left_only", "both")
        _patch_robot_controller_arms(use_right=use_right, use_left=use_left)

        if args.rerun_viewer or args.render_robots:
            _start_rerun_viewer(args.rerun_web_port)
            _patch_camera_loop_for_rerun(
                viz_hz=args.rerun_viz_hz,
                viz_width=args.rerun_viz_width,
                jpeg_quality=args.rerun_jpeg_quality,
                render_robots=args.render_robots,
                render_factor=args.render_robots_factor,
                render_period_s=args.render_robots_period_s,
            )

        # `arms="bimanual"` keeps run_recording's internal use_right=True; our
        # RobotController patch then disables whichever side we don't want.
        run_recording(
            s3_bucket=args.s3_bucket,
            s3_prefix=args.s3_prefix,
            interface=build_interface(args.control),
            camera_config_file=cam_cfg,
            arms="bimanual",
            data_dir=args.data_dir,
        )
    finally:
        if cam_cfg != args.camera_config:
            # Best-effort cleanup; an already-missing file is fine.
            with contextlib.suppress(OSError):
                os.unlink(cam_cfg)


if __name__ == "__main__":
    main()
