"""Replay a YAM joint trajectory recorded by yam_record.py.

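Smooth-moves the followers to the first recorded pose with raiden's
`smooth_move_joints`, then streams the trajectory (resampled to `control_hz`)
with `command_joint_pos` at `control_hz × speed` Hz. Same logic as raiden's
`rd replay`, but without the camera/SVO2/IK overhead — pure joint streaming.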

Usage:
    source ~/cameron/raiden_fork/.venv/bin/activate
    python ~/cameron/yam_control/yam_replay.py \\
        ~/cameron/yam_control/recordings/20260528_120000.npz \\
        --speed 1.0 \\
        --control_hz 150
"""
import argparse
import json
import threading
import time
from pathlib import Path

import numpy as np

from raiden.robot.controller import RobotController, smooth_move_joints


def _resample(joints_native, t_native, control_hz):
    """Linearly interpolate joints sampled at irregular t_native onto a uniform
    control_hz grid spanning the original time range."""
    if len(t_native) < 2:
        return joints_native.copy(), t_native.copy()
    n_out = int(round((t_native[-1] - t_native[0]) * control_hz)) + 1
    t_uniform = t_native[0] + np.arange(n_out) / control_hz
    out = np.empty((n_out, joints_native.shape[1]), dtype=np.float32)
    for j in range(joints_native.shape[1]):
        out[:, j] = np.interp(t_uniform, t_native, joints_native[:, j])
    return out, t_uniform


def _load_trajectory(path):
    """Load a yam_record.py ``.npz`` and return ``(joints, timestamps, raw_metadata)``.

    ``raw_metadata`` is ``None`` when the archive has no ``metadata`` entry.
    """
    # An NpzFile keeps the archive open until it is closed; the context manager
    # closes it. Indexing reads each member into an in-memory array, so the
    # returned arrays stay valid after the ``with`` block exits.
    with np.load(path, allow_pickle=False) as z:
        joints = z["joints"]                # (N, 14): left arm columns first, then right (see main)
        t_native = z["timestamps"]          # (N,)
        raw_meta = z["metadata"] if "metadata" in z.files else None
    return joints, t_native, raw_meta


def _print_metadata(raw_meta):
    """Best-effort decode and print of the recording's JSON metadata; never aborts a replay."""
    # Assumes the recorder stores UTF-8 JSON as a byte buffer (e.g. a uint8
    # array) — TODO confirm against yam_record.py.
    try:
        text = "{}" if raw_meta is None else bytes(raw_meta).decode()
        meta = json.loads(text)
    except Exception as exc:  # display-only: a malformed blob must not stop the replay
        print(f"  metadata: <unreadable: {exc}>")
        return
    print(f"  metadata: {meta}")


def _trim(joints, t_native, start_t, end_t):
    """Keep samples from ``start_t`` to ``end_t`` seconds after the first sample; rebase time to 0.

    A negative ``end_t`` means "to the end". Requires at least one sample.
    Returns ``(joints, times, end_t_used)``. The arrays are empty when no
    sample falls inside the window.
    """
    # Bounds are relative to the first sample, so --start_t/--end_t mean
    # "seconds into the trajectory" even when timestamps are absolute clock values.
    t_rel = t_native - t_native[0]
    stop = t_rel[-1] if end_t < 0 else end_t
    mask = (t_rel >= start_t) & (t_rel <= stop)
    kept_joints, kept_t = joints[mask], t_rel[mask]
    if kept_t.size:
        kept_t = kept_t - kept_t.min()
    return kept_joints, kept_t, stop


def _move_to_start(rc, q0, duration_s, dof):
    """Smooth-move each active follower to its slice of ``q0``, with both arms in parallel.

    Raises:
        RuntimeError: if ``smooth_move_joints`` raised for any arm. This stops
            streaming from starting at a pose that was never reached.
    """
    print(f"\nMoving to start ({duration_s:.1f}s smooth)...")
    errors = []

    def _run(robot, target):
        # An exception raised inside a Thread is only reported by
        # threading.excepthook and never reaches the caller. Record it here so
        # the main thread can refuse to start streaming.
        try:
            smooth_move_joints(robot, target, time_interval_s=duration_s, steps=100)
        except Exception as exc:
            errors.append(exc)

    targets = [(rc.follower_l, q0[:dof]), (rc.follower_r, q0[dof:2 * dof])]
    threads = [
        threading.Thread(target=_run, args=(robot, target))
        for robot, target in targets
        if robot is not None
    ]
    for th in threads:
        th.start()
    # NOTE(review): if Ctrl-C interrupts these joins, main's cleanup calls
    # return_to_home while the (non-daemon) move threads may still be
    # commanding the arms. Confirm that is safe for the hardware.
    for th in threads:
        th.join()
    if errors:
        raise RuntimeError("smooth move to start pose failed") from errors[0]


def _stream(rc, joints_u, control_hz, speed, dof):
    """Send one resampled frame per tick to each active follower, at ``control_hz × speed`` Hz."""
    n_frames = joints_u.shape[0]
    print(f"Streaming at {control_hz} Hz × {speed}x speed "
          f"(~{n_frames / (control_hz * speed):.1f}s)\n"
          "(Ctrl-C to stop)\n")
    dt_s = 1.0 / (control_hz * speed)
    t_start = time.monotonic()
    for i, frame in enumerate(joints_u):
        # Deadlines are absolute (t_start + i*dt), so sleep jitter does not
        # accumulate. Frames are never dropped: when behind schedule, they go
        # out back-to-back until the loop catches up.
        sleep_s = t_start + i * dt_s - time.monotonic()
        if sleep_s > 0:
            time.sleep(sleep_s)
        if rc.follower_l is not None:
            rc.follower_l.command_joint_pos(frame[:dof])
        if rc.follower_r is not None:
            rc.follower_r.command_joint_pos(frame[dof:2 * dof])
        # Progress is reported in trajectory time, once per trajectory second.
        if i % control_hz == 0:
            print(f"  t={i / control_hz:6.2f}s / {n_frames / control_hz:.2f}s")
    print("Replay complete.")


def main(argv=None):
    """Replay a recorded YAM joint trajectory on the follower arms.

    Steps:
      1. Load the ``.npz`` and trim it to ``[--start_t, --end_t]``, in seconds
         after the first sample.
      2. Resample it to ``--control_hz``.
      3. Smooth-move the followers to the first frame.
      4. Stream the frames at ``control_hz × speed`` Hz.
      5. Always attempt ``return_to_home`` afterwards, including on Ctrl-C or error.

    Invalid arguments or an unusable recording end the program through
    ``argparse`` (``SystemExit(2)``) before any hardware is touched.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) uses the real command line.
    """
    p = argparse.ArgumentParser()
    p.add_argument("trajectory", type=str, help="Path to .npz from yam_record.py")
    p.add_argument("--speed", type=float, default=1.0,
                   help="Playback speed multiplier. 0.5 = half speed; 2.0 = double.")
    p.add_argument("--control_hz", type=int, default=150,
                   help="Command rate to the followers (Hz). Smoother → higher.")
    p.add_argument("--start_t", type=float, default=0.0,
                   help="Seconds into the trajectory to start replay at.")
    p.add_argument("--end_t", type=float, default=-1.0,
                   help="Seconds into the trajectory to stop. -1 = run to end.")
    p.add_argument("--use_right_follower", type=int, default=1)
    p.add_argument("--use_left_follower", type=int, default=1)
    p.add_argument("--smooth_move_to_start_s", type=float, default=3.0,
                   help="Smooth-move time from current proprio to trajectory[0] (seconds).")
    args = p.parse_args(argv)

    # Reject rates that would divide by zero or give a negative frame period.
    if args.control_hz <= 0:
        p.error("--control_hz must be positive")
    if args.speed <= 0:
        p.error("--speed must be positive")

    DOF = 7  # joints per arm: left arm = columns [0, DOF), right arm = [DOF, 2*DOF)

    path = Path(args.trajectory)
    joints, t_native, raw_meta = _load_trajectory(path)
    if joints.ndim != 2 or t_native.ndim != 1 or joints.shape[0] != t_native.shape[0]:
        p.error(f"{path}: expected joints (N, D) and timestamps (N,), "
                f"got {joints.shape} and {t_native.shape}")
    if t_native.size == 0:
        p.error(f"{path} contains no samples")
    if args.use_left_follower and joints.shape[1] < DOF:
        p.error(f"{path} has {joints.shape[1]} joint columns; "
                f"the left follower needs at least {DOF}")
    if args.use_right_follower and joints.shape[1] < 2 * DOF:
        p.error(f"{path} has {joints.shape[1]} joint columns; "
                f"the right follower needs at least {2 * DOF}")
    print(f"Loaded {path}: {joints.shape[0]} samples, {t_native[-1] - t_native[0]:.2f}s")
    _print_metadata(raw_meta)

    joints, t_native, end_t = _trim(joints, t_native, args.start_t, args.end_t)
    print(f"  trimmed to [{args.start_t}, {end_t}]: {joints.shape[0]} samples")
    if joints.shape[0] == 0:
        p.error(f"no samples between --start_t {args.start_t} and end {end_t}")

    # Resample to control_hz so the streaming loop sees uniformly-spaced frames.
    joints_u, _ = _resample(joints, t_native, args.control_hz)
    n_frames = joints_u.shape[0]
    print(f"  resampled to {args.control_hz} Hz: {n_frames} frames")

    rc = RobotController(
        use_right_leader=False, use_left_leader=False,
        use_right_follower=bool(args.use_right_follower),
        use_left_follower=bool(args.use_left_follower),
    )
    rc.check_can_interfaces()
    rc.initialize_robots(gravity_comp_mode=False)

    try:
        _move_to_start(rc, joints_u[0], args.smooth_move_to_start_s, DOF)
        _stream(rc, joints_u, args.control_hz, args.speed, DOF)
    finally:
        # Always try to park the arms, even after Ctrl-C or a failure above.
        try:
            rc.return_to_home()
        except Exception as e:
            print(f"return_to_home failed: {e}")


if __name__ == "__main__":
    # Script entry point: main() returns None, so the process exits with status 0
    # on normal completion; uncaught exceptions propagate as usual.
    raise SystemExit(main())
