"""Bimanual teleop + record-on-button-press.

Same as `rd teleop`, but with an extra hook: while teleoperating, press the
WHITE (bottom) button on either leader to capture the current 14-D follower
joint state. Press the YELLOW (top) button on either leader to return arms to
home and exit.

Recorded states are also saved at exit to a timestamped .npz under
~/cameron/yam_control/recordings/test_record_*.npz with the same layout as
yam_record.py (joints[N,14], timestamps[N], metadata).

Usage (on a YAM, with raiden_fork venv active):
    source ~/cameron/raiden_fork.venv/bin/activate
    python ~/cameron/yam_control/test_record_joint_seq.py
"""
import argparse
import json
import os
import signal
import socket
import sys
import time
from datetime import datetime
from pathlib import Path

import numpy as np

from raiden.control import build_interface
from raiden.robot.controller import RobotController


def _read_joints14(rc):
    """Return both followers' joint positions as one flat float32 vector.

    The left follower's reading comes first, followed by the right one's.
    When a follower is ``None`` on ``rc``, seven zeros stand in for it.

    NOTE(review): the "14" in the name assumes each follower's
    ``get_joint_pos()`` yields 7 values — confirm against the controller.
    """
    zeros_per_arm = 7
    halves = []
    for attr in ("follower_l", "follower_r"):
        follower = getattr(rc, attr)
        if follower is None:
            # Missing arm: keep the layout fixed by padding with zeros.
            halves.append(np.zeros(zeros_per_arm))
        else:
            halves.append(follower.get_joint_pos())
    return np.concatenate(halves).astype(np.float32)


def _save_recording(recorded, out_path, notes):
    """Write captured joint states to a ``.npz`` archive and return its path.

    Args:
        recorded: Non-empty list of ``(t_rel, joints)`` tuples, where
            ``t_rel`` is seconds since the session started and ``joints`` is
            a 1-D array; all ``joints`` arrays must share one shape.
        out_path: Requested destination. ``np.savez`` appends ``.npz`` when
            the name does not already end with it, so the same rule is applied
            here to make the returned path the file that actually exists.
        notes: Free-form text stored in the metadata.

    Returns:
        ``pathlib.Path`` of the archive written. It holds ``joints`` (N rows),
        ``timestamps`` (N) and ``metadata`` (UTF-8 JSON bytes as uint8).
    """
    out_path = Path(out_path)
    if not str(out_path).endswith(".npz"):
        out_path = Path(f"{out_path}.npz")
    out_path.parent.mkdir(parents=True, exist_ok=True)

    ts_arr = np.asarray([t for t, _ in recorded], dtype=np.float64)
    j_arr = np.stack([q for _, q in recorded], axis=0)
    meta = {
        "notes": notes,
        "datetime": datetime.now().isoformat(timespec="seconds"),
        "hostname": socket.gethostname(),
        "n_samples": int(len(recorded)),
        "source": "test_record_joint_seq.py",
    }
    np.savez(
        out_path,
        joints=j_arr,
        timestamps=ts_arr,
        metadata=np.frombuffer(json.dumps(meta).encode(), dtype=np.uint8),
    )
    return out_path


def main():
    """Run a teleop session that records joint states on button presses.

    Parses the command line, brings up the leader interface and the robot
    controller, then polls at roughly 50 Hz until an exit is requested.
    Each press reported by ``rc.check_failure_button()`` appends the current
    follower joint vector and its time offset to an in-memory list.

    On the way out the interface is closed and whatever was recorded is
    written to ``--out`` (or a timestamped default path). Closing and saving
    are guarded independently, so a failure in one does not skip the other.

    The process always terminates through ``os._exit``: status 0 after a
    clean session, 1 if the session, the close, or the save raised.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--bilateral_kp", type=float, default=0.0)
    p.add_argument("--arms", type=str, default="bimanual", choices=["bimanual", "left"])
    p.add_argument("--out", type=str, default="",
                   help="Optional .npz output path. Default: "
                        "~/cameron/yam_control/recordings/test_record_<timestamp>.npz")
    p.add_argument("--notes", type=str, default="")
    args = p.parse_args()

    if not args.out:
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.out = str(Path.home() / f"cameron/yam_control/recordings/test_record_{ts}.npz")
    out_path = Path(args.out)
    # Create the directory up front so a bad path fails before the robot moves.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    use_right = args.arms == "bimanual"
    use_left = True

    interface = build_interface("leader")  # YAMInterface (leader-follower)

    rc = RobotController(
        use_right_leader=interface.uses_leaders and use_right,
        use_left_leader=interface.uses_leaders and use_left,
        use_right_follower=use_right,
        use_left_follower=use_left,
    )

    interface.open()

    recorded = []  # list of (t_rel, joints14)
    exit_code = 0  # becomes 1 as soon as any stage reports a failure

    try:
        def signal_handler(signum, frame):
            rc.emergency_stop()
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        rc.setup_for_teleop_recording()
        rc.enable_estop()

        interface.setup(rc)
        interface.start(rc)

        print(interface.banner)
        print("  WHITE / BOTTOM leader button → record current joint state")
        print("  YELLOW / TOP   leader button → return to home + exit")
        print()

        t_start = time.monotonic()

        while True:
            if rc.session_estop_requested:
                print("\n[FootPedal] Returning arms to home and exiting.")
                break

            # YELLOW / top button (or footpedal) → exit
            if interface.poll(rc):
                time.sleep(0.5)  # debounce
                break

            # WHITE / bottom button → record current state
            if rc.check_failure_button():
                q14 = _read_joints14(rc)
                t_rel = time.monotonic() - t_start
                recorded.append((t_rel, q14))
                qstr = np.array2string(q14, precision=3, suppress_small=True)
                print(f"recorded joint state  [#{len(recorded):03d}  t={t_rel:6.2f}s]  {qstr}")

            time.sleep(0.02)  # 50 Hz poll — fast enough to catch button presses

        interface.stop(rc)
        rc.shutdown()

        print("\nTeleoperation session ended.")

    except Exception as e:
        print(f"\nError: {e}")
        # Set the status first so it survives even if emergency_stop() raises.
        exit_code = 1
        if rc.has_robots():
            rc.emergency_stop()
    finally:
        # A failing close must not cost the operator the captured states.
        try:
            interface.close()
        except Exception as e:
            print(f"\nError while closing interface: {e}")
            exit_code = 1

        # Save whatever we recorded
        try:
            if recorded:
                saved_path = _save_recording(recorded, out_path, args.notes)
                print(f"Saved {len(recorded)} recorded states → {saved_path}")
            else:
                print("No states recorded — nothing saved.")
        except Exception as e:
            print(f"\nError while saving recording: {e}")
            exit_code = 1
        finally:
            # os._exit() skips interpreter cleanup and does not flush stdio
            # buffers, so flush by hand or the messages above can be lost
            # when output is piped or redirected.
            # NOTE(review): the hard exit is presumably there to avoid hanging
            # on controller background threads — confirm.
            sys.stdout.flush()
            sys.stderr.flush()
            os._exit(exit_code)


# Script entry point: start the teleop/record session only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
