#!/usr/bin/env python3
"""Stream Franka Emika Panda joint states via rosbridge and visualize them in MuJoCo.

Subscribes to /joint_states over roslibpy (rosbridge websocket) and drives
a MuJoCo passive viewer with the live joint positions.

Prerequisites:
  - rosbridge_websocket running on the robot
  - SSH tunnel:  ssh -L 9090:localhost:9090 user@robot -p <port>

Usage:
  python3 stream_panda_with_vis.py
  python3 stream_panda_with_vis.py --host localhost --port 9090 --topic /joint_states
"""
import argparse
import signal
import sys
import os
from typing import Optional
import threading
import time

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import mujoco
import numpy as np
import roslibpy

from ExoConfigs.panda_exo import PANDA_BASE_ONLY_CONFIG
from exo_utils import position_exoskeleton_meshes, get_link_poses_from_robot

# --- Robot description ------------------------------------------------------
N_ARM_JOINTS: int = 7  # number of arm joints read from each JointState message
GRIPPER_POS_MAX: float = 0.04  # per-finger joint travel in metres (range 0–0.04)

# --- Message filtering ------------------------------------------------------
# A message is dropped when any arm joint differs from the last accepted pose
# by more than this many radians.
MAX_JUMP_RAD: float = 0.5

# --- State shared between the rosbridge callback and the render loop --------
# Guarded by _lock: written by on_joint_state(), read by main().
_lock = threading.Lock()
latest_positions: np.ndarray = np.zeros(N_ARM_JOINTS)
latest_gripper_width: float = 1.0  # normalized opening: 0 = closed, 1 = open
_initialized: bool = False  # set once the first complete arm message is accepted


def _arm_joint_index_for_name(name: str) -> Optional[int]:
    """Map a JointState joint name to a 0-based arm joint index, or ``None``.

    A name is recognized when it is exactly ``<prefix><k>``, where ``<prefix>``
    is one of ``panda_joint``, ``joint``, ``fr3_joint`` or ``fr3v2_joint`` and
    ``<k>`` is the plain decimal ``1`` .. ``N_ARM_JOINTS``; the result is
    ``k - 1``. Anything else (finger joints, zero-padded numbers, unknown
    prefixes) yields ``None``.
    """
    if not isinstance(name, str):
        return None
    known_prefixes = ("panda_joint", "joint", "fr3_joint", "fr3v2_joint")
    valid_suffixes = [str(k) for k in range(1, N_ARM_JOINTS + 1)]
    for prefix in known_prefixes:
        if not name.startswith(prefix):
            continue
        suffix = name[len(prefix):]
        if suffix in valid_suffixes:
            # Position in the list is already the 0-based joint index.
            return valid_suffixes.index(suffix)
    return None


def _is_gripper_finger_joint(name: str) -> bool:
    return name in ("panda_finger_joint1", "finger_joint1", "fr3_finger_joint1")


# Number of consecutive jump-rejected arm messages tolerated before the new
# pose is accepted anyway. Without this, a legitimate large move during a
# stream gap (e.g. a rosbridge reconnect) would freeze the viewer forever,
# because every later message is compared against the stale pose.
_MAX_CONSECUTIVE_REJECTS = 10
_consecutive_rejects = 0


def on_joint_state(msg: dict) -> None:
    """rosbridge callback: extract arm/gripper positions and publish them.

    Arm joints are matched by name via ``_arm_joint_index_for_name``. The first
    finger joint (``_is_gripper_finger_joint``) is normalized by
    ``GRIPPER_POS_MAX`` into a 0..1 gripper width.

    Update rules (all shared-state writes happen under ``_lock``):
      * A message without every arm joint, or with non-finite arm values,
        leaves the arm pose untouched but still refreshes the gripper width
        if it carries a finite finger reading.
      * After the first accepted arm message, a message in which any joint
        differs from the last accepted pose by more than ``MAX_JUMP_RAD`` is
        dropped as a glitch. If ``_MAX_CONSECUTIVE_REJECTS`` such messages
        arrive in a row, the new pose is accepted and becomes the reference.

    Args:
        msg: Decoded message dict as delivered by roslibpy. The ``"name"`` and
            ``"position"`` lists are read and paired up to the shorter length.
    """
    global latest_positions, latest_gripper_width, _initialized, _consecutive_rejects

    names = msg.get("name") or []
    positions = msg.get("position") or []

    arm_pos = np.zeros(N_ARM_JOINTS)
    filled = [False] * N_ARM_JOINTS
    gripper_width: Optional[float] = None

    # zip() stops at the shorter list, tolerating a name/position length mismatch.
    for name, value in zip(names, positions):
        idx = _arm_joint_index_for_name(name)
        if idx is not None:
            arm_pos[idx] = value
            filled[idx] = True
        elif _is_gripper_finger_joint(name):
            gripper_width = value / GRIPPER_POS_MAX

    if gripper_width is not None and not np.isfinite(gripper_width):
        gripper_width = None  # ignore a corrupt gripper reading
    # NaN would slip through the jump check (NaN comparisons are False).
    arm_ok = all(filled) and bool(np.all(np.isfinite(arm_pos)))

    with _lock:
        if not arm_ok:
            # Partial or corrupt arm data: only the gripper can be refreshed.
            if gripper_width is not None:
                latest_gripper_width = gripper_width
            return

        if _initialized and np.any(np.abs(arm_pos - latest_positions) > MAX_JUMP_RAD):
            _consecutive_rejects += 1
            if _consecutive_rejects < _MAX_CONSECUTIVE_REJECTS:
                return  # probably a glitch; keep the previous pose
            # Persistent disagreement: the robot really moved, so re-seed.

        _consecutive_rejects = 0
        latest_positions = arm_pos
        if gripper_width is not None:
            latest_gripper_width = gripper_width
        _initialized = True


def main() -> int:
    """Stream joint states from rosbridge into a passive MuJoCo viewer.

    Parses ``--host``, ``--port``, ``--topic`` and ``--fps`` from the command
    line and builds the MuJoCo model from ``PANDA_BASE_ONLY_CONFIG``. It then
    subscribes ``on_joint_state`` to the JointState topic and copies the
    latest received pose into the model at up to ``--fps`` frames per second.
    This continues until the viewer window is closed or SIGINT/SIGTERM arrives.

    Returns:
        0 on normal exit, 1 if no rosbridge connection was established.
    """
    # ``import mujoco`` does not load the viewer submodule; import it explicitly
    # so ``mujoco.viewer.launch_passive`` is always available.
    import mujoco.viewer

    p = argparse.ArgumentParser(description="Stream Panda joint states into MuJoCo viewer via rosbridge.")
    p.add_argument("--host", default="localhost", help="rosbridge host")
    p.add_argument("--port", type=int, default=9090, help="rosbridge websocket port")
    p.add_argument("--topic", default="/joint_states", help="JointState topic name")
    p.add_argument(
        "--fps", type=float, default=60.0,
        help="maximum viewer update rate in Hz (<= 0 disables throttling)",
    )
    args = p.parse_args()

    robot_config = PANDA_BASE_ONLY_CONFIG
    if hasattr(robot_config, "exo_link_alpha"):
        robot_config.exo_link_alpha = 1

    print(f"Model: {robot_config.base_xml_path}")
    model = mujoco.MjModel.from_xml_string(robot_config.xml)
    data = mujoco.MjData(model)

    # Clamp to what the model actually has so a 7-vector is never broadcast
    # into a shorter qpos/ctrl slice.
    n_arm = min(N_ARM_JOINTS, data.qpos.size)
    n_arm_ctrl = min(n_arm, data.ctrl.size)
    has_gripper = data.ctrl.size > N_ARM_JOINTS and data.qpos.size >= N_ARM_JOINTS + 2

    client = roslibpy.Ros(host=args.host, port=args.port)
    sub = roslibpy.Topic(client, args.topic, "sensor_msgs/msg/JointState")
    sub.subscribe(on_joint_state)

    def shutdown() -> None:
        """Best-effort teardown; each step runs even if the previous one fails."""
        try:
            sub.unsubscribe()
        except Exception as exc:  # the connection may already be gone
            print(f"Warning: unsubscribe failed: {exc}", file=sys.stderr)
        try:
            client.terminate()
        except Exception as exc:
            print(f"Warning: rosbridge terminate failed: {exc}", file=sys.stderr)

    print(f"Connecting to ws://{args.host}:{args.port} ...")
    try:
        client.run()
    except Exception as exc:
        # NOTE(review): roslibpy may raise here on connection timeout; fall
        # through to the is_connected check for a single failure path.
        print(f"rosbridge connection error: {exc}", file=sys.stderr)

    deadline = time.monotonic() + 5.0
    while not client.is_connected and time.monotonic() < deadline:
        time.sleep(0.05)

    if not client.is_connected:
        print("Failed to connect to rosbridge.")
        shutdown()
        return 1

    print(f"Connected! Subscribing to {args.topic}")
    print("Launching MuJoCo viewer...\n")

    stop_requested = threading.Event()

    def request_stop(*_) -> None:
        # Only flag the render loop; teardown happens once, in ``finally``.
        stop_requested.set()

    signal.signal(signal.SIGINT, request_stop)
    signal.signal(signal.SIGTERM, request_stop)

    period = 1.0 / args.fps if args.fps > 0 else 0.0

    try:
        with mujoco.viewer.launch_passive(
            model, data, show_left_ui=False, show_right_ui=False
        ) as viewer:
            while viewer.is_running() and not stop_requested.is_set():
                frame_start = time.monotonic()

                with _lock:
                    pos = latest_positions.copy()
                    gw = latest_gripper_width

                data.qpos[:n_arm] = pos[:n_arm]
                data.ctrl[:n_arm_ctrl] = pos[:n_arm_ctrl]

                if has_gripper:
                    g_pos_m = gw * GRIPPER_POS_MAX
                    data.qpos[N_ARM_JOINTS] = data.qpos[N_ARM_JOINTS + 1] = g_pos_m

                mujoco.mj_forward(model, data)
                position_exoskeleton_meshes(
                    robot_config, model, data, get_link_poses_from_robot(robot_config, model, data)
                )
                viewer.sync()

                # Throttle: without this the loop spins flat out and pins a core.
                remaining = period - (time.monotonic() - frame_start)
                if remaining > 0:
                    time.sleep(remaining)
    finally:
        shutdown()
    return 0


if __name__ == "__main__":
    # Propagate main()'s integer result as the process exit status.
    sys.exit(main())
