"""Sequentially deploy joint states recorded by test_record_joint_seq.py.

Loads a .npz with `joints[N,14]` (left arm 7 + right arm 7, gripper-last per arm).
For each recorded state, both followers smooth-move IN PARALLEL to the target
pose. We wait for BOTH arms to finish before moving on to the next pose.

Usage (on a YAM, with raiden_fork venv active):
    source ~/cameron/raiden_fork.venv/bin/activate
    python ~/cameron/yam_control/deploy_joint_state.py \\
        ~/cameron/yam_control/recordings/test_record_20260528_120000.npz \\
        --time_per_move_s 3.0 --pause_between_s 0.5
"""
import argparse
import json
import threading
import time
from pathlib import Path

import numpy as np

from raiden.robot.controller import RobotController, smooth_move_joints


DOF: int = 7  # joints per arm: 6 arm joints + 1 gripper (gripper is the last of the 7); a full pose stacks left then right = 2 * DOF values


def _confirm(prompt: str) -> bool:
    try:
        return input(prompt).strip().lower() in ("y", "yes")
    except EOFError:
        return False


def _parse_args():
    """Build the command-line parser for a trajectory replay and parse sys.argv."""
    p = argparse.ArgumentParser()
    p.add_argument("trajectory", type=str,
                   help="Path to .npz from test_record_joint_seq.py")
    p.add_argument("--time_per_move_s", type=float, default=3.0,
                   help="Seconds each smooth_move from pose i → i+1 should take.")
    p.add_argument("--steps", type=int, default=100,
                   help="Sub-steps inside each smooth_move (raiden's default is 100).")
    p.add_argument("--pause_between_s", type=float, default=0.5,
                   help="Hold time at each pose before starting the next move.")
    p.add_argument("--first_move_s", type=float, default=4.0,
                   help="Longer settle time for the FIRST move (current proprio → "
                        "pose[0]) since the arms may be far from home.")
    p.add_argument("--max_delta_rad", type=float, default=2.0,
                   help="Abort if any per-joint jump between consecutive poses "
                        "exceeds this (rad). Safety guardrail; set high to disable.")
    p.add_argument("--use_right_follower", type=int, default=1)
    p.add_argument("--use_left_follower", type=int, default=1)
    p.add_argument("--start_idx", type=int, default=0,
                   help="Skip poses before this index.")
    p.add_argument("--end_idx", type=int, default=-1,
                   help="Stop after this index (inclusive). -1 = run to end.")
    p.add_argument("--yes", action="store_true",
                   help="Skip the pre-flight confirmation prompt.")
    return p.parse_args()


def _load_trajectory(path):
    """Return ``(joints, metadata)`` read from a recording .npz.

    ``joints`` is the ``joints`` array cast to float32. ``metadata`` is the
    decoded ``metadata`` entry, or ``"{}"`` when the archive has no such key.
    The archive is opened with ``allow_pickle=False`` and closed before
    returning, so no file handle is leaked.
    """
    with np.load(path, allow_pickle=False) as z:
        joints = z["joints"].astype(np.float32)
        metadata = (bytes(z["metadata"]).decode()
                    if "metadata" in z.files else "{}")
    return joints, metadata


def _arm_mask(left_on, right_on):
    """Return a bool mask over a stacked [left DOF | right DOF] vector, True for enabled arms."""
    return np.repeat([bool(left_on), bool(right_on)], DOF)


def _move_followers(rc, target14, time_s, steps):
    """Smooth-move every present follower to its half of ``target14`` in parallel.

    Blocks until all moves have finished. The worker threads cannot be
    cancelled. On Ctrl-C this therefore still waits for the in-flight moves
    before re-raising; otherwise a following ``return_to_home`` would command
    the same arm at the same time as a still-running move.

    Raises:
        RuntimeError: if ``smooth_move_joints`` raised in any worker (chained
            to the first such exception). It is raised only after every worker
            has been joined.
    """
    errors = []

    def _worker(robot, q):
        try:
            smooth_move_joints(robot, q, time_interval_s=time_s, steps=steps)
        except Exception as e:
            # An exception escaping a Thread is only printed; record it so the
            # main thread can stop the sequence instead of moving on.
            errors.append(e)

    jobs = []
    if rc.follower_l is not None:
        jobs.append((rc.follower_l, target14[:DOF]))
    if rc.follower_r is not None:
        jobs.append((rc.follower_r, target14[DOF:DOF * 2]))

    threads = [threading.Thread(target=_worker, args=job) for job in jobs]
    for t in threads:
        t.start()
    try:
        for t in threads:
            t.join()
    except KeyboardInterrupt:
        for t in threads:
            t.join()
        raise

    if errors:
        raise RuntimeError(
            f"smooth_move_joints failed on {len(errors)} follower(s): "
            + "; ".join(repr(e) for e in errors)
        ) from errors[0]


def _run_sequence(rc, sel, jump_mask, args):
    """Visit every pose in ``sel`` in order, then always attempt to return home.

    The first move uses ``--first_move_s`` and later moves use
    ``--time_per_move_s``. Each move is followed by a ``--pause_between_s``
    hold. Ctrl-C ends the sequence early. Any other exception (e.g. a follower
    move failure) propagates after the return-to-home attempt.
    ``jump_mask`` selects which joint columns count toward the printed Δmax.
    """
    n = len(sel)
    try:
        print(f"\n[1/{n}] moving to first pose ({args.first_move_s:.1f}s)...")
        _move_followers(rc, sel[0], args.first_move_s, args.steps)
        if args.pause_between_s > 0:
            time.sleep(args.pause_between_s)

        for i in range(1, n):
            qstr = np.array2string(sel[i], precision=3, suppress_small=True)
            jump = float(np.max(np.abs(sel[i] - sel[i - 1])[jump_mask], initial=0.0))
            print(f"[{i+1}/{n}] move ({args.time_per_move_s:.1f}s, "
                  f"Δmax={jump:.3f} rad)  → {qstr}")
            _move_followers(rc, sel[i], args.time_per_move_s, args.steps)
            if args.pause_between_s > 0:
                time.sleep(args.pause_between_s)

        print("\nAll poses deployed.")
    except KeyboardInterrupt:
        print("\n[Ctrl-C] aborting — returning to home")
    finally:
        try:
            rc.return_to_home()
        except Exception as e:
            print(f"return_to_home failed: {e}")


def main():
    """Replay a recorded joint-state sequence on the enabled YAM followers.

    Steps:
      1. Parse the CLI, then load and validate the recording.
      2. Apply the consecutive-pose jump guardrail.
      3. Bring up the followers.
      4. Report the distance from the current proprio to the first pose and
         ask for confirmation.
      5. Move through each pose, waiting for every enabled arm.
      6. Return home and shut down.

    Both jump checks only consider the joint columns of arms that will
    actually move. ``rc.shutdown()`` runs on every exit path once the robots
    have been initialised.
    """
    args = _parse_args()

    path = Path(args.trajectory)
    joints, metadata = _load_trajectory(path)
    if joints.ndim != 2 or joints.shape[1] != 2 * DOF:
        print(f"Expected joints of shape (N, {2 * DOF}) in {path.name}, "
              f"got {joints.shape}. Aborting.")
        return
    n_poses = joints.shape[0]
    print(f"Loaded {path.name}: {n_poses} poses")
    try:
        print(f"  metadata: {json.loads(metadata)}")
    except ValueError:
        # Metadata is informational only; show it raw rather than hide it.
        print(f"  metadata (not valid JSON): {metadata!r}")

    # Normalise the window exactly as slicing would (negative start, end past
    # N, ...), so the printed range matches what is actually deployed.
    stop = n_poses if args.end_idx < 0 else args.end_idx + 1
    start, stop, _ = slice(args.start_idx, stop).indices(n_poses)
    sel = joints[start:stop]
    if len(sel) == 0:
        print("No poses in selected range — nothing to deploy.")
        return
    print(f"  deploying poses [{start}..{stop - 1}] = {len(sel)} pose(s)")

    use_left = bool(args.use_left_follower)
    use_right = bool(args.use_right_follower)
    if not (use_left or use_right):
        print("No followers enabled — nothing to deploy.")
        return
    # A disabled arm never moves, so its columns must not trip the guardrail.
    arm_mask = _arm_mask(use_left, use_right)

    # Safety check: largest jump between consecutive poses (enabled arms only)
    if len(sel) > 1:
        max_jump = float(np.max(np.abs(np.diff(sel[:, arm_mask], axis=0))))
        print(f"  max per-joint jump between consecutive poses: {max_jump:.3f} rad")
        if max_jump > args.max_delta_rad:
            print(f"  ⚠ exceeds --max_delta_rad {args.max_delta_rad}. Aborting.")
            return

    rc = RobotController(
        use_right_leader=False, use_left_leader=False,
        use_right_follower=use_right,
        use_left_follower=use_left,
    )
    rc.check_can_interfaces()
    rc.initialize_robots(gravity_comp_mode=False)

    try:
        # Show how far the first move will be from the current proprio.
        has_l = rc.follower_l is not None
        has_r = rc.follower_r is not None
        q_now_l = rc.follower_l.get_joint_pos() if has_l else np.zeros(DOF)
        q_now_r = rc.follower_r.get_joint_pos() if has_r else np.zeros(DOF)
        q_now = np.concatenate([q_now_l, q_now_r]).astype(np.float32)
        # Absent followers report zeros above; exclude them from the distance.
        live_mask = _arm_mask(has_l, has_r)
        first_jump = float(np.max(np.abs(sel[0] - q_now)[live_mask], initial=0.0))
        print("\nCurrent proprio:")
        print(f"  L = {np.array2string(q_now_l, precision=3, suppress_small=True)}")
        print(f"  R = {np.array2string(q_now_r, precision=3, suppress_small=True)}")
        print(f"  max delta to first pose: {first_jump:.3f} rad "
              f"(over {args.first_move_s:.1f}s)")
        if first_jump > args.max_delta_rad:
            print(f"  ⚠ exceeds --max_delta_rad {args.max_delta_rad}.")
            if not args.yes and not _confirm("  Continue anyway? [y/N] "):
                print("Aborted.")
                return

        if not args.yes and not _confirm(f"\nDeploy {len(sel)} pose(s)? [y/N] "):
            print("Aborted.")
            return

        _run_sequence(rc, sel, live_mask, args)
    finally:
        rc.shutdown()


# Script entry point: only runs when executed directly, so importing this
# module does not parse argv or touch the robots.
if __name__ == "__main__":
    main()
