#!/usr/bin/env python3
"""Deploy IK-recovered joint states to the real Panda robot.

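First computes IK targets for the episode (``--dry_run`` stops here), then
handles the FULL robot startup sequence:
  1. Restarts the robot's Docker container and starts rosbridge
  2. Starts a fixed safe-default-position publisher on the robot
  3. THEN launches the controller (so it sees targets from the start)
  4. Opens an SSH tunnel, waits for joint states, and takes over publishing
     /gello/joint_states from the actual position (fixed publisher is killed)
  5. Settles, then ramps through the IK targets at a capped joint velocity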

Usage:
  cd /data/cameron/para/panda_streaming
  MUJOCO_GL=egl python deploy_ik_sequence.py \
    --data_dir /data/cameron/panda_data/single_demo_sanity \
    --episode 0 --stride 1
"""
import argparse
import json
import os
import sys
import signal
import time
import threading
import subprocess

import numpy as np
import roslibpy
import mujoco
from scipy.spatial.transform import Rotation as Rot

sys.path.insert(0, os.path.dirname(__file__))
from data_panda_para import T_CAM_WORLD, CAM_K, N_ARM_JOINTS, GRIPPER_POS_MAX

# Orientation every IK target is solved for: diag(1, -1, -1), i.e. a
# 180-degree rotation about the world x-axis.
FIXED_EEF_ROT = np.array([
    [1.0, 0.0, 0.0],
    [0.0, -1.0, 0.0],
    [0.0, 0.0, -1.0],
])
# Seed configuration (rad) for the first IK solve; later solves are seeded
# from the previous solution. Presumably the arm's "ready" pose — confirm.
HOME_Q = np.array((0.0, -0.785, 0.0, -2.356, 0.0, 1.571, 0.785), dtype=np.float64)
# Joint names attached to every published target message.
JOINT_NAMES = ["fr3_joint%d" % (k + 1) for k in range(7)]
# Rate (Hz) of the local velocity-limited target publisher.
PUBLISH_HZ = 30.0
# Default per-joint speed cap for --max_vel, rad/s (~2.9 deg/s).
MAX_JOINT_VEL = 0.05
# Seconds spent re-syncing targets to the measured position before moving.
SETTLE_SEC = 5.0
# Pause (s) after each waypoint is reached.
WAYPOINT_DWELL_SEC = 0.3

# Robot host, reached over Tailscale.
ROBOT_IP = "100.126.97.121"


def mujoco_ik(mj_model, mj_data, target_pos, target_rot, q_init, hand_id,
              max_iter=300, damping=1e-4, rot_weight=0.3, max_step=0.1,
              clamp_to_limits=True):
    """Solve position + orientation IK for the first 7 DoF by damped least squares.

    Iteratively updates ``mj_data.qpos[:7]`` so that body ``hand_id`` reaches
    ``target_pos`` with rotation ``target_rot``. ``mj_data`` is used as scratch
    space: on return its ``qpos[:7]`` holds the solution and its kinematics
    (``xpos``/``xmat``) reflect it.

    Args:
        mj_model, mj_data: MuJoCo model/data. The arm is assumed to occupy
            qpos[:7], dof columns [:7] and joints 0..6.
        target_pos: (3,) target body position in the world frame.
        target_rot: (3, 3) target body rotation matrix.
        q_init: (7,) seed configuration.
        hand_id: MuJoCo body id to drive.
        max_iter: maximum number of solver iterations.
        damping: Tikhonov damping added to J^T J.
        rot_weight: scale applied to the rotational Jacobian rows and error,
            trading orientation accuracy against position accuracy.
        max_step: per-iteration clip (rad) on each joint update.
        clamp_to_limits: if True, clamp every iterate to the model's
            ``jnt_range`` for joints the model marks as limited, so the
            returned configuration is one the model considers feasible.

    Returns:
        (7,) copy of the solved joint configuration.
    """
    if clamp_to_limits:
        limited = mj_model.jnt_limited[:7].astype(bool)
        lower = np.where(limited, mj_model.jnt_range[:7, 0], -np.inf)
        upper = np.where(limited, mj_model.jnt_range[:7, 1], np.inf)
    else:
        lower = np.full(7, -np.inf)
        upper = np.full(7, np.inf)

    mj_data.qpos[:7] = q_init
    mujoco.mj_forward(mj_model, mj_data)
    for _ in range(max_iter):
        pos_err = target_pos - mj_data.xpos[hand_id]
        cur_rot = mj_data.xmat[hand_id].reshape(3, 3)
        # Axis-angle vector of the remaining rotation. scipy stays accurate as
        # the angle approaches pi, where angle/(2 sin angle) * vee(R - R^T)
        # collapses toward zero and falsely reports convergence.
        rot_err = Rot.from_matrix(target_rot @ cur_rot.T).as_rotvec()
        if np.linalg.norm(pos_err) < 1e-4 and np.linalg.norm(rot_err) < 1e-3:
            break
        jacp, jacr = np.zeros((3, mj_model.nv)), np.zeros((3, mj_model.nv))
        mujoco.mj_jacBody(mj_model, mj_data, jacp, jacr, hand_id)
        J = np.vstack([jacp[:, :7], rot_weight * jacr[:, :7]])
        err = np.concatenate([pos_err, rot_weight * rot_err])
        dq = np.linalg.solve(J.T @ J + damping * np.eye(7), J.T @ err)
        # Clip the step for stability, then keep the iterate inside joint limits
        # (the result is streamed to real hardware).
        mj_data.qpos[:7] = np.clip(
            mj_data.qpos[:7] + np.clip(dq, -max_step, max_step), lower, upper)
        mujoco.mj_forward(mj_model, mj_data)
    return mj_data.qpos[:7].copy()


def _arm_joint_index(name):
    for j in range(1, 8):
        if name in (f"panda_joint{j}", f"joint{j}", f"fr3_joint{j}", f"fr3v2_joint{j}"):
            return j - 1
    return None


def ssh_cmd(cmd, timeout=30):
    """Run a shell command on the robot box via SSH.

    ``cmd`` is handed to ssh as its own argv element with no local shell in
    between. It is therefore interpreted only by the remote shell: ``$``,
    backticks, backslashes and double quotes inside it are not touched locally.

    Args:
        cmd: command line to execute on the robot host.
        timeout: seconds before ``subprocess.TimeoutExpired`` is raised.

    Returns:
        ``subprocess.CompletedProcess`` with text stdout/stderr captured.
    """
    argv = [
        "ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no",
        f"cameron@{ROBOT_IP}", cmd,
    ]
    return subprocess.run(argv, capture_output=True, text=True, timeout=timeout)


def _compute_ik_targets(data_dir, frame_indices, mj_model, mj_data, hand_id):
    """Re-solve each recorded frame's hand position with a fixed EEF rotation.

    For every frame, FK on the recorded joints gives the hand position. IK then
    finds joints reaching that position with rotation FIXED_EEF_ROT, seeded from
    the previous solution (HOME_Q for the first frame).

    Returns:
        list of ``(ik_q, gripper, frame_idx, err_mm)`` tuples, where ``err_mm``
        is the residual hand-position error of the IK solution in millimetres.
    """
    targets = []
    prev_q = HOME_Q.copy()
    for frame_idx in frame_indices:
        js = np.load(os.path.join(data_dir, f"{frame_idx:06d}.npy")).astype(np.float64)
        # Frames without an 8th (gripper) entry default to 1.0.
        gw = js[7] if len(js) > 7 else 1.0
        mj_data.qpos[:7] = js[:7]
        mj_data.qpos[7] = mj_data.qpos[8] = gw * GRIPPER_POS_MAX
        mujoco.mj_forward(mj_model, mj_data)
        target_pos = mj_data.xpos[hand_id].copy()
        ik_q = mujoco_ik(mj_model, mj_data, target_pos, FIXED_EEF_ROT, prev_q, hand_id)
        mj_data.qpos[:7] = ik_q
        mujoco.mj_forward(mj_model, mj_data)
        err_mm = np.linalg.norm(mj_data.xpos[hand_id] - target_pos) * 1000
        targets.append((ik_q.copy(), gw, frame_idx, err_mm))
        prev_q = ik_q.copy()
        print(f"  Frame {frame_idx:06d}: err={err_mm:.2f}mm, gripper={gw:.2f}")
    return targets


def _restart_robot_stack():
    """Restart the robot-side container and launch rosbridge, fixed publisher, controller.

    Every step is fired over SSH; the fixed sleeps are the only synchronisation.
    """
    # Step 1: Kill old processes, restart container
    print("Restarting container...")
    ssh_cmd("docker exec gello-ros2-cameron bash -c 'pkill -9 -f ros2; pkill -9 -f franka; pkill -9 -f gello; pkill -9 -f realsense; pkill -9 -f sync_recorder' 2>/dev/null")
    time.sleep(2)
    ssh_cmd("docker stop gello-ros2-cameron; docker start gello-ros2-cameron")
    time.sleep(2)

    # Step 2: Start rosbridge
    print("Starting rosbridge...")
    ssh_cmd("docker exec -d gello-ros2-cameron bash -c 'source /opt/ros/humble/setup.bash && source /workspace/ros2/install/setup.bash && ros2 launch rosbridge_server rosbridge_websocket_launch.xml'")
    time.sleep(5)

    # Step 3: Start fixed_gello_state_publisher FIRST
    # This publishes a safe default position at 30Hz so the controller
    # has targets immediately when it activates (prevents robot drop)
    print("Starting fixed position publisher...")
    ssh_cmd("docker exec -d gello-ros2-cameron bash -c 'source /opt/ros/humble/setup.bash && source /workspace/ros2/install/setup.bash && python3 /workspace/ros2/scripts/fixed_gello_state_publisher.py > /tmp/fixed_pub.log 2>&1'")
    time.sleep(3)

    # Step 4: Start controller (fixed publisher is already sending targets)
    print("Starting controller...")
    ssh_cmd("docker exec -d gello-ros2-cameron bash -c 'source /opt/ros/humble/setup.bash && source /workspace/ros2/install/setup.bash && ros2 launch gello_launcher main_system.launch.py main_config_file:=/workspace/ros2/src/gello_launcher/config/blue_fixed_gello_command_config.yaml'")
    print("Waiting 25s for controller to S-curve to default position...")
    time.sleep(25)


def _open_tunnel():
    """Replace any local port-9090 SSH tunnel with a fresh one to the robot's rosbridge."""
    print("Setting up tunnel...")
    subprocess.run("kill $(ps aux | grep 'ssh.*-L 9090' | grep -v grep | awk '{print $2}') 2>/dev/null",
                   shell=True, capture_output=True)
    # List form (no shell) so tunnel.terminate() signals ssh itself rather than
    # an intermediate /bin/sh that could leave ssh running.
    tunnel = subprocess.Popen([
        "ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no",
        "-N", "-L", "9090:localhost:9090", f"cameron@{ROBOT_IP}",
    ])
    time.sleep(3)
    return tunnel


def _best_effort(label, fn):
    """Run one cleanup step; report a failure instead of raising so later steps still run."""
    try:
        fn()
    except Exception as exc:  # cleanup must not mask the original error
        print(f"WARNING: cleanup step '{label}' failed: {exc}")


def main():
    """Compute IK targets for one episode and stream them to the real robot.

    Returns:
        Process exit code: 0 on success or dry run, 1 when there is nothing to
        deploy, no joint states arrive, or the publish loop dies mid-sequence.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--data_dir", required=True)
    p.add_argument("--episode", type=int, default=0)
    p.add_argument("--stride", type=int, default=1)
    p.add_argument("--max_vel", type=float, default=MAX_JOINT_VEL)
    p.add_argument("--dry_run", action="store_true")
    args = p.parse_args()
    # A non-positive stride breaks range(). A non-positive max_vel means the
    # publisher never moves, so the waypoint wait loop would spin forever.
    if args.stride < 1:
        p.error("--stride must be >= 1")
    if args.max_vel <= 0:
        p.error("--max_vel must be > 0")

    # Load episodes
    with open(os.path.join(args.data_dir, "episodes.json")) as f:
        episodes = json.load(f)["episodes"]
    ep = episodes[args.episode]
    print(f"Episode {args.episode}: frames {ep['start']}-{ep['end']}")

    # MuJoCo for FK + IK
    from ExoConfigs.panda_exo_handeye_4x2 import PANDA_HANDEYE_4X2_CONFIG
    mj_model = mujoco.MjModel.from_xml_string(PANDA_HANDEYE_4X2_CONFIG.xml)
    mj_data = mujoco.MjData(mj_model)
    hand_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "hand")

    # Compute IK targets
    frame_indices = list(range(ep["start"], ep["end"] + 1, args.stride))
    print(f"Computing IK for {len(frame_indices)} frames (stride={args.stride})...")
    targets = _compute_ik_targets(args.data_dir, frame_indices, mj_model, mj_data, hand_id)
    if not targets:
        print("ERROR: Episode contains no frames to deploy.")
        return 1

    print(f"\nMax IK error: {max(t[3] for t in targets):.2f}mm across {len(targets)} targets")
    if args.dry_run:
        print("Dry run — not sending to robot.")
        return 0

    # ── Startup: fixed_publisher → controller → connect → deploy ──
    print("\n=== Starting robot stack ===")
    _restart_robot_stack()

    # Step 5: Set up tunnel and connect
    tunnel = _open_tunnel()
    client = roslibpy.Ros(host='localhost', port=9090)
    joint_pub = roslibpy.Topic(client, "/gello/joint_states",
                               "sensor_msgs/msg/JointState", queue_length=1)
    gripper_pub = roslibpy.Topic(client, "/gripper/gripper_client/target_gripper_width_percent",
                                 "std_msgs/msg/Float32", queue_length=1)
    joint_sub = roslibpy.Topic(client, "/joint_states",
                               "sensor_msgs/msg/JointState", queue_length=1)

    current_q = np.zeros(7)
    got_state = threading.Event()
    state_lock = threading.Lock()
    stop_event = threading.Event()
    pub_thread = None

    def on_joint_state(msg):
        """Store the 7 arm joint positions from a JointState message, if all are present."""
        names, positions = msg.get("name", []), msg.get("position", [])
        q = np.zeros(7)
        filled = 0
        for n, pv in zip(names, positions):
            idx = _arm_joint_index(n)
            if idx is not None:
                q[idx] = pv
                filled += 1
        if filled == 7:
            with state_lock:
                current_q[:] = q
            got_state.set()

    # Everything after the tunnel opens runs under try/finally so the publisher,
    # ROS client and tunnel are torn down on every exit path, errors included.
    try:
        joint_sub.subscribe(on_joint_state)
        client.run()

        print("Waiting for joint states...")
        if not got_state.wait(timeout=10.0):
            print("ERROR: No joint states! Is FCI enabled in Desk UI?")
            return 1

        with state_lock:
            start_q = current_q.copy()
        print(f"Current position: {['%.3f' % v for v in start_q]}")

        # Step 6: Start our velocity-limited publish loop (takes over from fixed publisher)
        pub_q = start_q.copy()
        goal_q = start_q.copy()
        goal_grip = 1.0
        pub_lock = threading.Lock()

        def publish_loop():
            """Move pub_q toward goal_q by at most max_vel/PUBLISH_HZ per tick and publish it."""
            period = 1.0 / PUBLISH_HZ
            max_step = args.max_vel / PUBLISH_HZ
            while not stop_event.is_set():
                with pub_lock:
                    diff = goal_q - pub_q
                    step = np.clip(diff, -max_step, max_step)
                    pub_q[:] += step
                    q = pub_q.copy()
                    grip = goal_grip

                # Single clock read so sec and nanosec describe the same instant.
                secs, nsecs = divmod(time.time_ns(), 1_000_000_000)
                joint_pub.publish(roslibpy.Message({
                    "header": {"stamp": {"sec": secs, "nanosec": nsecs}, "frame_id": "fr3_link0"},
                    "name": JOINT_NAMES,
                    "position": q.tolist(),
                }))
                gripper_pub.publish(roslibpy.Message({"data": grip}))
                time.sleep(period)

        pub_thread = threading.Thread(target=publish_loop, daemon=True)
        pub_thread.start()

        # Kill the fixed publisher now that we're publishing
        print("Taking over from fixed publisher...")
        ssh_cmd("docker exec gello-ros2-cameron bash -c 'pkill -f fixed_gello_state_publisher' 2>/dev/null")

        # Settle — continuously track actual position
        print(f"Settling for {SETTLE_SEC:.0f}s...")
        settle_end = time.time() + SETTLE_SEC
        while time.time() < settle_end:
            with state_lock:
                actual = current_q.copy()
            with pub_lock:
                pub_q[:] = actual
                goal_q[:] = actual
            time.sleep(1.0 / PUBLISH_HZ)

        with state_lock:
            settled = current_q.copy()
        print(f"Settled at: {['%.3f' % v for v in settled]}")

        running = True

        def shutdown(*_):
            """SIGINT handler: stop advancing waypoints; the publisher keeps holding pub_q."""
            nonlocal running
            running = False
            print("\nStopping! Robot will hold current position.")
        signal.signal(signal.SIGINT, shutdown)

        # ── Step 7: Execute trajectory ──
        print(f"\n{'='*60}")
        print(f"Moving through {len(targets)} waypoints at {np.degrees(args.max_vel):.1f} deg/s")
        print(f"{'='*60}\n")

        for i, (ik_q, gw, frame_idx, _err_mm) in enumerate(targets):
            if not running:
                break

            with pub_lock:
                dist = np.max(np.abs(ik_q - pub_q))
            eta = dist / args.max_vel

            print(f"[{i+1}/{len(targets)}] Frame {frame_idx:06d} "
                  f"(max travel: {np.degrees(dist):.1f}°, ETA ~{eta:.1f}s, "
                  f"gripper={gw:.2f})", flush=True)

            with pub_lock:
                goal_q[:] = ik_q
                goal_grip = gw

            # Wait for pub_q to reach goal; a dead publisher would never get there.
            while running and pub_thread.is_alive():
                with pub_lock:
                    remaining = np.max(np.abs(goal_q - pub_q))
                if remaining < 0.005:
                    break
                time.sleep(0.1)

            if not pub_thread.is_alive():
                print("ERROR: Publish loop stopped unexpectedly; aborting sequence.", flush=True)
                return 1

            if running:
                print(f"  Reached.", flush=True)
                time.sleep(WAYPOINT_DWELL_SEC)

        print(f"\nSequence complete!")

    finally:
        stop_event.set()
        if pub_thread is not None:
            pub_thread.join(timeout=1.0)
        _best_effort("unsubscribe /joint_states", joint_sub.unsubscribe)
        _best_effort("unadvertise /gello/joint_states", joint_pub.unadvertise)
        _best_effort("unadvertise gripper target", gripper_pub.unadvertise)
        _best_effort("terminate ROS client", client.terminate)
        _best_effort("terminate SSH tunnel", tunnel.terminate)

    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
