"""Velocity-paced variant of deploy_joint_state.py.

Same execution pattern (parallel smooth_move_joints + thread.join per pose),
but the per-pose duration is computed from a max-joint-velocity budget and the
ACTUAL distance to travel — not a fixed wall-clock time. Long hops take longer,
short hops are quick, peak joint speed stays bounded.

For each pose, both arms get the SAME time_interval_s — sized by whichever arm
has the larger max-per-joint delta — so they finish simultaneously.

Usage (on a YAM, with raiden_fork venv active):
    source ~/cameron/raiden_fork.venv/bin/activate
    python ~/cameron/yam_control/deploy_joint_state_velocity.py \\
        ~/cameron/yam_control/recordings/test_record_20260528_120000.npz \\
        --max_joint_vel_rad_s 0.5
"""
import argparse
import json
import threading
import time
from pathlib import Path

import numpy as np

from raiden.robot.controller import RobotController, smooth_move_joints


# Joints per arm. A full pose row is 2 * DOF wide: the left arm occupies
# columns [:DOF] and the right arm columns [DOF:2 * DOF].
DOF = 7  # 6 arm + 1 gripper


def _confirm(prompt: str) -> bool:
    """Ask a yes/no question on stdin and report whether the answer was yes.

    Only "y" or "yes" (any letter case, surrounding whitespace ignored) counts
    as agreement. A closed stdin (EOF) is treated as "no" so that
    non-interactive runs never proceed by accident.
    """
    try:
        reply = input(prompt)
    except EOFError:
        return False
    answer = reply.strip().lower()
    return answer == "y" or answer == "yes"


def _arm_columns(use_left: bool, use_right: bool) -> np.ndarray:
    """Return the pose-row column indices that belong to the enabled arms."""
    cols = []
    if use_left:
        cols.extend(range(0, DOF))
    if use_right:
        cols.extend(range(DOF, 2 * DOF))
    return np.asarray(cols, dtype=int)


def _max_abs_delta(pose_a, pose_b, cols) -> float:
    """Largest absolute per-joint difference between two poses over *cols*."""
    diff = np.asarray(pose_a)[cols] - np.asarray(pose_b)[cols]
    return float(np.max(np.abs(diff)))


def _move_duration(delta_rad: float, vel_rad_s: float, min_move_s: float) -> float:
    """Duration (s) that covers *delta_rad* at *vel_rad_s*, floored at *min_move_s*."""
    # The 1e-6 clamp only guards the division; main() rejects budgets <= 0.
    return max(min_move_s, delta_rad / max(vel_rad_s, 1e-6))


def _metadata_summary(npz) -> str:
    """Best-effort, printable rendering of the optional ``metadata`` entry."""
    if "metadata" not in npz.files:
        return "{}"
    try:
        raw = npz["metadata"]
        if raw.dtype.kind == "U" and raw.size == 1:
            # A unicode-dtype entry must not be reinterpreted as raw bytes.
            text = str(raw.reshape(-1)[0])
        else:
            text = bytes(raw).decode()
        return str(json.loads(text))
    except ValueError as exc:
        # Covers UnicodeDecodeError, JSONDecodeError and numpy's refusal to
        # load object arrays with allow_pickle=False. Metadata is purely
        # informational, so report the problem and carry on.
        return f"<unreadable: {exc}>"


def _run_parallel_moves(jobs, time_s: float, steps: int) -> None:
    """Run one smooth_move_joints per ``(label, robot, target)`` job in parallel.

    Blocks until every worker has finished. All jobs share the same *time_s*.

    Raises:
        RuntimeError: if any worker raised. Without this, a failure inside a
            worker thread would never reach the caller and the sequence would
            carry on to the next pose.
    """
    failures = []

    def _worker(label, robot, target):
        try:
            smooth_move_joints(robot, target, time_interval_s=time_s, steps=steps)
        except Exception as exc:
            failures.append((label, exc))

    threads = [threading.Thread(target=_worker, args=job) for job in jobs]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    if failures:
        label, exc = failures[0]
        raise RuntimeError(f"smooth_move_joints failed on {label} arm: {exc}") from exc


def main() -> None:
    """Replay recorded joint poses on the follower arms, paced by joint velocity.

    Parses the command line, loads and validates the ``.npz`` trajectory, then
    moves the enabled follower arm(s) through the selected poses. Each move's
    duration is the largest per-joint delta across the enabled arms divided by
    the velocity budget, floored at ``--min_move_s``.

    All validation that needs no hardware runs before the controller is
    created. Once the robots are initialized, ``rc.shutdown()`` is guaranteed
    on every exit path, and ``rc.return_to_home()`` is attempted whenever
    motion was started.
    """
    p = argparse.ArgumentParser()
    p.add_argument("trajectory", type=str,
                   help="Path to .npz from test_record_joint_seq.py")
    p.add_argument("--max_joint_vel_rad_s", type=float, default=0.5,
                   help="Peak per-joint angular velocity budget (rad/s). The "
                        "smooth_move duration is sized to keep every joint at "
                        "or below this. Default: 0.5 rad/s (a 1 rad excursion "
                        "takes 2 s).")
    p.add_argument("--min_move_s", type=float, default=0.2,
                   help="Floor on per-pose duration so tiny moves don't try "
                        "to execute in milliseconds. Default: 0.2 s.")
    p.add_argument("--steps", type=int, default=100,
                   help="Sub-steps inside each smooth_move. Default: 100.")
    p.add_argument("--pause_between_s", type=float, default=0.5,
                   help="Hold time at each pose before starting the next move.")
    p.add_argument("--first_move_vel_rad_s", type=float, default=0.25,
                   help="Slower velocity budget for the FIRST move (current "
                        "proprio → pose[0]) since the arms may be far from "
                        "the first recorded pose. Default: 0.25 rad/s.")
    p.add_argument("--max_delta_rad", type=float, default=2.0,
                   help="Abort if any per-joint jump between consecutive poses "
                        "exceeds this (rad). Safety guardrail.")
    p.add_argument("--use_right_follower", type=int, default=1)
    p.add_argument("--use_left_follower", type=int, default=1)
    p.add_argument("--start_idx", type=int, default=0)
    p.add_argument("--end_idx", type=int, default=-1)
    p.add_argument("--yes", action="store_true",
                   help="Skip the pre-flight confirmation prompt.")
    args = p.parse_args()

    # `not x > 0` also rejects NaN. A zero or negative budget would otherwise
    # be clamped to 1e-6 rad/s and produce moves that effectively never end.
    for flag in ("max_joint_vel_rad_s", "first_move_vel_rad_s"):
        if not getattr(args, flag) > 0:
            p.error(f"--{flag} must be > 0")
    if args.steps < 1:
        p.error("--steps must be >= 1")

    path = Path(args.trajectory)
    # NpzFile keeps the archive open until closed; copy out what is needed.
    with np.load(path, allow_pickle=False) as z:
        if "joints" not in z.files:
            print(f"{path.name} has no 'joints' array — nothing to deploy.")
            return
        joints = z["joints"].astype(np.float32)
        metadata = _metadata_summary(z)

    # Each row is sliced into [:DOF] (left) and [DOF:2*DOF] (right) below, so
    # anything other than (N, 2*DOF) must be rejected before touching hardware.
    if joints.ndim != 2 or joints.shape[1] != 2 * DOF:
        print(f"Expected 'joints' of shape (N, {2 * DOF}); got {joints.shape}. "
              "Aborting.")
        return
    print(f"Loaded {path.name}: {joints.shape[0]} poses")
    print(f"  metadata: {metadata}")

    n_poses = joints.shape[0]
    # Clamp so the range reported below matches what is actually sliced.
    end_idx = n_poses - 1 if args.end_idx < 0 else min(args.end_idx, n_poses - 1)
    sel = joints[args.start_idx : end_idx + 1]
    if len(sel) == 0:
        print("No poses in selected range — nothing to deploy.")
        return
    print(f"  deploying poses [{args.start_idx}..{end_idx}] = {len(sel)} pose(s)")

    # Only columns of arms that will actually be commanded take part in the
    # safety checks and in the duration computation.
    cols = _arm_columns(bool(args.use_left_follower), bool(args.use_right_follower))
    if cols.size == 0:
        print("No follower arms enabled — nothing to deploy.")
        return

    # NaN/inf would slip past every `>` guard below (NaN comparisons are
    # False) and collapse the move duration to --min_move_s.
    if not np.all(np.isfinite(sel[:, cols])):
        print("  ⚠ selected poses contain NaN/inf values. Aborting.")
        return

    if len(sel) > 1:
        max_jump = float(np.max(np.abs(np.diff(sel[:, cols], axis=0))))
        print(f"  max per-joint jump between consecutive poses: {max_jump:.3f} rad")
        if max_jump > args.max_delta_rad:
            print(f"  ⚠ exceeds --max_delta_rad {args.max_delta_rad}. Aborting.")
            return

    rc = RobotController(
        use_right_leader=False, use_left_leader=False,
        use_right_follower=bool(args.use_right_follower),
        use_left_follower=bool(args.use_left_follower),
    )
    rc.check_can_interfaces()
    rc.initialize_robots(gravity_comp_mode=False)

    # From here on the robots are initialized: always shut them down.
    try:
        has_left = rc.follower_l is not None
        has_right = rc.follower_r is not None
        cols = _arm_columns(has_left, has_right)
        if cols.size == 0:
            print("No follower arms available — nothing to deploy.")
            return

        # A missing arm is represented by zeros only to keep the 2*DOF layout;
        # its columns are excluded from every delta through `cols`.
        q_now_l = rc.follower_l.get_joint_pos() if has_left else np.zeros(DOF)
        q_now_r = rc.follower_r.get_joint_pos() if has_right else np.zeros(DOF)
        q_now = np.concatenate([q_now_l, q_now_r]).astype(np.float32)
        first_jump = _max_abs_delta(sel[0], q_now, cols)
        if not np.isfinite(first_jump):
            print("  ⚠ current joint reading is not finite. Aborting.")
            return
        first_t = _move_duration(first_jump, args.first_move_vel_rad_s,
                                 args.min_move_s)
        print("\nCurrent proprio:")
        print(f"  L = {np.array2string(q_now_l, precision=3, suppress_small=True)}")
        print(f"  R = {np.array2string(q_now_r, precision=3, suppress_small=True)}")
        print(f"  max delta to first pose: {first_jump:.3f} rad "
              f"@ {args.first_move_vel_rad_s} rad/s → {first_t:.2f} s")
        if first_jump > args.max_delta_rad:
            print(f"  ⚠ exceeds --max_delta_rad {args.max_delta_rad}.")
            if not args.yes and not _confirm("  Continue anyway? [y/N] "):
                return

        if not args.yes:
            if not _confirm(f"\nDeploy {len(sel)} pose(s) "
                            f"@ {args.max_joint_vel_rad_s} rad/s? [y/N] "):
                print("Aborted.")
                return

        def move_to(target14, time_s):
            """Move every available follower to its slice of target14 in
            parallel and block until done. Both arms share time_s so they
            finish simultaneously."""
            jobs = []
            if has_left:
                jobs.append(("left", rc.follower_l, target14[:DOF]))
            if has_right:
                jobs.append(("right", rc.follower_r, target14[DOF:DOF * 2]))
            _run_parallel_moves(jobs, time_s, args.steps)

        # NOTE(review): duration = delta / budget bounds the AVERAGE joint
        # speed. Whether the PEAK stays within the budget depends on the
        # interpolation profile of smooth_move_joints, which is not visible
        # here — confirm.
        try:
            print(f"\n[1/{len(sel)}] first move (Δmax={first_jump:.3f} rad → {first_t:.2f}s)")
            move_to(sel[0], first_t)
            if args.pause_between_s > 0:
                time.sleep(args.pause_between_s)

            for idx, (prev, pose) in enumerate(zip(sel[:-1], sel[1:]), start=2):
                delta = _max_abs_delta(pose, prev, cols)
                t_s = _move_duration(delta, args.max_joint_vel_rad_s,
                                     args.min_move_s)
                qstr = np.array2string(pose, precision=3, suppress_small=True)
                print(f"[{idx}/{len(sel)}] Δmax={delta:.3f} rad → {t_s:.2f}s  → {qstr}")
                move_to(pose, t_s)
                if args.pause_between_s > 0:
                    time.sleep(args.pause_between_s)

            print("\nAll poses deployed.")
        except KeyboardInterrupt:
            # NOTE(review): worker threads of an in-flight move are not
            # cancelled by Ctrl-C; they may still be commanding the arms while
            # return_to_home() runs below — confirm this is acceptable.
            print("\n[Ctrl-C] aborting — returning to home")
        finally:
            try:
                rc.return_to_home()
            except Exception as e:
                print(f"return_to_home failed: {e}")
    finally:
        rc.shutdown()


# Script entry point: run the deployment only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
