#!/usr/bin/env python3
"""Command the Franka to each hand-eye calibration pose via rosbridge.

Reads current joint state from /joint_states, holds it for a settling
period, then smoothly interpolates to each target with velocity limiting.
Publishes on /gello/joint_states at 30Hz for the JointImpedanceController.

Press Enter to advance to the next pose, 'b' to go back, a pose number (1-based) to jump to it, or 'q' to quit.

Prerequisites:
  - Docker container running with blue_fixed_gello_command_config (no GELLO publisher)
  - rosbridge accessible at localhost:9090

Usage:
  python hand_eye_calib/command_calib_poses.py
  python hand_eye_calib/command_calib_poses.py --duration 8
"""
import argparse
import threading
import time
import sys

import numpy as np
import roslibpy

# Number of arm joints read from /joint_states and published as commands.
N_ARM_JOINTS = 7
# Joint names attached to every published command, in joint1..joint7 order.
JOINT_NAMES = [f"fr3_joint{j}" for j in range(1, N_ARM_JOINTS + 1)]
# Rate (Hz) at which joint and gripper commands are published.
PUBLISH_HZ = 30.0
# Default for the --duration command-line option (seconds).
MOVE_DURATION_SEC = 6.0
SETTLE_SEC = 3.0  # publish current position for this long before first move
MAX_JOINT_VEL = 0.15  # rad/s max per joint — very conservative (default for --max-vel)

# Same calibration poses as calibrate.py.
# Each row is a joint-position target in radians for joint1..joint7 (travel is
# reported via np.degrees and speed is limited in rad/s).
CALIB_POSES = [
    [0.0, -0.785, 0.0, -2.356,  0.0,  1.571, 2.3],
    [0.0, -0.785, 0.0, -2.356, -0.8,  2.3,   2.3],
    [0.0, -0.785, 0.0, -2.2,   -0.1,  1.4,   2.7],
    [ 0.4, -0.785, 0.0, -2.356,  0.0,  1.571, 2.3],
    [-0.4, -0.785, 0.0, -2.356,  0.0,  1.571, 2.3],
    [ 0.3, -0.6,   0.0, -2.0,   -0.5,  1.8,   1.5],
    [-0.3, -1.0,   0.0, -2.5,    0.3,  1.2,   2.8],
    [ 0.0, -0.4,   0.0, -2.0,    0.0,  1.571, 2.3],
    [ 0.0, -1.2,   0.0, -2.5,    0.0,  1.571, 2.3],
    [ 0.0, -0.785, 0.5, -2.356,  0.0,  1.571, 2.3],
    [ 0.0, -0.785,-0.5, -2.356,  0.0,  1.571, 2.3],
    [ 0.2, -0.6,   0.3, -1.8,   -0.4,  2.0,   1.8],
    [-0.2, -0.9,  -0.3, -2.6,    0.5,  1.0,   0.5],
    [ 0.5, -0.5,   0.2, -2.1,    0.8,  1.8,   0.8],
    [-0.5, -1.1,  -0.2, -2.4,   -0.6,  2.5,   1.2],
    [ 0.1, -0.7,   0.4, -1.9,    0.3,  0.8,   2.0],
    [-0.1, -0.85, -0.4, -2.7,   -0.3,  1.6,  -0.5],
]


def _arm_joint_index(name):
    """Return the 0-based arm index for a joint name, or None if it is not an arm joint.

    Recognised spellings for joint N (1..N_ARM_JOINTS) are ``panda_jointN``,
    ``jointN``, ``fr3_jointN`` and ``fr3v2_jointN``.
    """
    prefixes = ("panda_joint", "joint", "fr3_joint", "fr3v2_joint")
    for idx in range(N_ARM_JOINTS):
        aliases = tuple(f"{prefix}{idx + 1}" for prefix in prefixes)
        if name in aliases:
            return idx
    return None


def smooth_interp(t):
    """Quintic smoothstep s(t) = 6t^5 - 15t^4 + 10t^3 with t clamped to [0, 1].

    Goes from 0 to 1 with zero first and second derivatives at both ends,
    i.e. zero velocity and acceleration at the start and end of a move.
    """
    t = min(1.0, max(0.0, t))
    cubed = t * t * t
    # Horner form of 6t^2 - 15t + 10; multiplied by t^3 gives the quintic.
    inner = t * (t * 6.0 - 15.0) + 10.0
    return cubed * inner


def _parse_args():
    """Parse and validate the command-line options.

    Exits through ``argparse`` (status 2) when a value would make the motion
    unsafe or meaningless.
    """
    p = argparse.ArgumentParser(description="Command Franka to calibration poses via rosbridge.")
    p.add_argument("--host", default="localhost", help="rosbridge host")
    p.add_argument("--port", type=int, default=9090, help="rosbridge websocket port")
    p.add_argument("--start", type=int, default=0, help="Starting pose index (0-based)")
    p.add_argument("--duration", type=float, default=MOVE_DURATION_SEC,
                   help="Minimum seconds to interpolate between poses "
                        "(stretched when needed to respect --max-vel)")
    p.add_argument("--max-vel", type=float, default=MAX_JOINT_VEL,
                   help="Max joint velocity in rad/s")
    p.add_argument("--gripper-width-percent", type=float, default=1.0,
                   help="Gripper width percent (0=closed, 1=open)")
    args = p.parse_args()

    # A non-positive limit is not merely useless: np.clip with lo > hi returns
    # hi everywhere, so a negative --max-vel would drive every joint steadily
    # in one direction regardless of the goal. `not x > 0` also rejects NaN.
    if not args.max_vel > 0:
        p.error("--max-vel must be > 0")
    if not args.duration >= 0:
        p.error("--duration must be >= 0")
    if not 0.0 <= args.gripper_width_percent <= 1.0:
        p.error("--gripper-width-percent must be in [0, 1]")
    return args


def _stamp_now():
    """Return the current wall-clock time as a ``{"sec", "nanosec"}`` stamp dict."""
    now = time.time()
    sec = int(now)
    # A single clock read keeps sec/nanosec consistent: two reads can straddle
    # a second boundary and yield nanosec >= 1e9. min() guards float rounding.
    nanosec = min(int((now - sec) * 1e9), 999_999_999)
    return {"sec": sec, "nanosec": nanosec}


def main() -> int:
    """Hold the robot's current pose, then step through CALIB_POSES interactively.

    Joint commands go to /gello/joint_states and the gripper target to the
    gripper topic, both at PUBLISH_HZ, from a background thread.

    Returns:
        Process exit status: 0 after a normal quit (q, Ctrl-C, EOF), 1 when
        rosbridge or the robot's joint states are unavailable.
    """
    args = _parse_args()

    client = roslibpy.Ros(host=args.host, port=args.port)

    joint_pub = roslibpy.Topic(client, "/gello/joint_states",
                               "sensor_msgs/msg/JointState", queue_length=1)
    gripper_pub = roslibpy.Topic(client, "/gripper/gripper_client/target_gripper_width_percent",
                                 "std_msgs/msg/Float32", queue_length=1)

    # Latest complete arm reading from /joint_states (written by the rosbridge callback).
    current_q = np.zeros(N_ARM_JOINTS)
    got_state = threading.Event()
    state_lock = threading.Lock()

    def on_joint_state(msg):
        q = np.zeros(N_ARM_JOINTS)
        # Track WHICH joints were seen, not how many names matched: the same
        # joint listed under two aliases must not hide a missing joint, which
        # would otherwise be taken as 0.0 rad and commanded.
        seen = np.zeros(N_ARM_JOINTS, dtype=bool)
        for n, p_val in zip(msg.get("name", []), msg.get("position", [])):
            idx = _arm_joint_index(n)
            if idx is not None:
                q[idx] = p_val
                seen[idx] = True
        if seen.all():
            with state_lock:
                current_q[:] = q
            got_state.set()

    joint_sub = roslibpy.Topic(client, "/joint_states",
                               "sensor_msgs/msg/JointState", queue_length=1)
    joint_sub.subscribe(on_joint_state)

    stop_event = threading.Event()
    pub_thread = None
    try:
        print(f"Connecting to ws://{args.host}:{args.port} ...")
        client.run()

        deadline = time.time() + 5.0
        while not client.is_connected and time.time() < deadline:
            time.sleep(0.05)
        if not client.is_connected:
            print("Failed to connect to rosbridge.")
            return 1
        print("Connected!")

        print("Waiting for joint states from robot...")
        if not got_state.wait(timeout=10.0):
            print("No joint states received. Is the Franka driver running?")
            return 1

        with state_lock:
            start_q = current_q.copy()
        print(f"Current joints: {['%.3f' % v for v in start_q]}")

        # Motion state, guarded by pub_lock. Each move is one segment from
        # seg_start_q to goal_q, shaped by smooth_interp over seg_duration
        # seconds starting at seg_t0 (time.monotonic()).
        pub_lock = threading.Lock()
        pub_q = start_q.copy()        # last position actually published
        seg_start_q = start_q.copy()
        goal_q = start_q.copy()
        seg_t0 = time.monotonic()
        seg_duration = 0.0            # 0 => hold goal_q (used while settling)
        # smoothstep's peak velocity is 15/8 of the segment's mean velocity.
        peak_vel_factor = 15.0 / 8.0

        def publish_loop():
            """Publish at PUBLISH_HZ, following the current smooth segment."""
            period = 1.0 / PUBLISH_HZ
            max_step = args.max_vel * period  # max rad per tick
            next_tick = time.monotonic()

            while not stop_event.is_set():
                now = time.monotonic()
                with pub_lock:
                    if seg_duration > 0.0:
                        s = smooth_interp((now - seg_t0) / seg_duration)
                    else:
                        s = 1.0
                    desired = seg_start_q + s * (goal_q - seg_start_q)
                    # Safety net: never move more than max_step per tick, even if
                    # this loop ran late and the profile has jumped ahead.
                    pub_q[:] += np.clip(desired - pub_q, -max_step, max_step)
                    q = pub_q.copy()

                joint_pub.publish(roslibpy.Message({
                    "header": {
                        "stamp": _stamp_now(),
                        "frame_id": "fr3_link0",
                    },
                    "name": JOINT_NAMES,
                    "position": q.tolist(),
                    "velocity": [],
                    "effort": [],
                }))
                gripper_pub.publish(roslibpy.Message({
                    "data": args.gripper_width_percent,
                }))

                # Schedule on a fixed grid so publish time does not lower the rate.
                next_tick += period
                delay = next_tick - time.monotonic()
                if delay > 0:
                    stop_event.wait(delay)
                else:
                    next_tick = time.monotonic()  # fell behind: resync, don't burst

        def set_goal(target):
            """Start a smooth segment from the published position to *target*.

            Returns (max joint travel in rad, segment duration in s). A goal set
            mid-move restarts from the current position with zero commanded
            velocity.
            """
            nonlocal seg_t0, seg_duration
            with pub_lock:
                goal_q[:] = np.asarray(target, dtype=np.float64)
                seg_start_q[:] = pub_q
                travel = float(np.max(np.abs(goal_q - seg_start_q)))
                # Stretch the segment until smoothstep's peak stays within --max-vel.
                seg_duration = max(args.duration, peak_vel_factor * travel / args.max_vel)
                seg_t0 = time.monotonic()
                return travel, seg_duration

        def command_pose(idx):
            """Command CALIB_POSES[idx] and print its summary and the key help."""
            travel, duration = set_goal(CALIB_POSES[idx])
            q_str = "  ".join(f"{v:+.3f}" for v in CALIB_POSES[idx])
            print(f"\n[Pose {idx + 1}/{len(CALIB_POSES)}]  {q_str}")
            print(f"  Max joint travel: {np.degrees(travel):.1f}° — ETA ~{duration:.1f}s "
                  f"(limit {np.degrees(args.max_vel):.1f}°/s)")
            print("  Enter=next  b=back  q=quit  <number>=jump to pose")

        # Settle: publish the current position for a few seconds.
        print(f"Settling: publishing current position for {SETTLE_SEC:.0f}s...")
        pub_thread = threading.Thread(target=publish_loop, daemon=True)
        pub_thread.start()
        time.sleep(SETTLE_SEC)
        print("Settled. Ready to command poses.\n")

        pose_idx = max(0, min(args.start, len(CALIB_POSES) - 1))
        command_pose(pose_idx)

        while True:
            cmd = input("> ").strip().lower()
            if cmd == "q":
                break
            if cmd == "b":
                pose_idx = max(0, pose_idx - 1)
            elif cmd == "":
                if pose_idx >= len(CALIB_POSES) - 1:
                    print("  (already at last pose)")
                    continue
                pose_idx += 1
            else:
                try:
                    idx = int(cmd) - 1
                except ValueError:
                    print("  Unknown command")
                    continue
                if not 0 <= idx < len(CALIB_POSES):
                    print(f"  Pose number must be 1-{len(CALIB_POSES)}")
                    continue
                pose_idx = idx
            command_pose(pose_idx)

    except (KeyboardInterrupt, EOFError):
        pass
    finally:
        # Runs on every exit path, including the early failure returns above.
        if pub_thread is not None:
            print("\nStopping publisher...")
            stop_event.set()
            pub_thread.join(timeout=1.0)
        if client.is_connected:
            joint_sub.unsubscribe()
            joint_pub.unadvertise()
            gripper_pub.unadvertise()
        client.terminate()

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
