#!/usr/bin/env python3
"""Stream Franka Emika Panda joint states via rosbridge and render onto live camera feed.

Subscribes to /joint_states over roslibpy for joint positions, captures
camera frames, detects ArUco markers for camera pose estimation, then
renders the MuJoCo robot overlaid on the live video.

Prerequisites:
  - rosbridge_websocket running on the robot
  - SSH tunnel:  ssh -L 9090:localhost:9090 user@robot -p <port>

Usage:
  python3 stream_panda_with_cam.py
  python3 stream_panda_with_cam.py --camera 0 --host localhost --port 9090
"""
import argparse
import signal
import sys
import os
from typing import Optional
import threading
import time

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import mujoco
import cv2
import numpy as np
import roslibpy
from mujoco.renderer import Renderer

from ExoConfigs.panda_exo import PANDA_BASE_ONLY_CONFIG as robot_config
from exo_utils import detect_and_set_link_poses, render_from_camera_pose

# Number of revolute arm joints on the Panda / FR3 (gripper excluded).
N_ARM_JOINTS = 7
# Finger joint position at fully open — presumably metres (matches the
# `g_pos_m` usage in main); used to normalize gripper width into [0, 1].
GRIPPER_POS_MAX = 0.04
# Reject an incoming joint-state message if any joint differs from the last
# accepted state by more than this many radians (filters spurious readings).
MAX_JUMP_RAD = 0.5

# Shared state: written by the rosbridge callback thread (on_joint_state),
# read by the render loop in main(). All access is guarded by _lock.
latest_positions = np.zeros(N_ARM_JOINTS)
latest_gripper_width = 1.0
_lock = threading.Lock()
_initialized = False


def _arm_joint_index_for_name(name: str) -> Optional[int]:
    """Map a joint_state joint name to its 0-based arm joint index (0..6).

    Recognizes the standard Panda / FR3 naming schemes; returns None for
    anything that is not one of the seven arm joints.
    """
    for idx in range(N_ARM_JOINTS):
        joint_number = idx + 1
        aliases = {
            f"panda_joint{joint_number}",
            f"joint{joint_number}",
            f"fr3_joint{joint_number}",
            f"fr3v2_joint{joint_number}",
        }
        if name in aliases:
            return idx
    return None


def _is_gripper_finger_joint(name: str) -> bool:
    return name in ("panda_finger_joint1", "finger_joint1", "fr3_finger_joint1")


def on_joint_state(msg: dict) -> None:
    """rosbridge callback: parse a JointState message into the shared state.

    Only commits an update when all seven arm joints appear in the message,
    and rejects updates that jump more than MAX_JUMP_RAD from the last
    accepted positions (once initialized). Runs on the roslibpy thread;
    writes to the module-level state are serialized with _lock.
    """
    global latest_positions, latest_gripper_width, _initialized
    names = msg.get("name") or []
    positions = msg.get("position") or []

    candidate = np.zeros(N_ARM_JOINTS)
    seen = np.zeros(N_ARM_JOINTS, dtype=bool)
    gripper = latest_gripper_width

    # zip stops at the shorter list, matching the original bounds check on
    # the positions array.
    for joint_name, value in zip(names, positions):
        arm_idx = _arm_joint_index_for_name(joint_name)
        if arm_idx is not None:
            candidate[arm_idx] = value
            seen[arm_idx] = True
        if _is_gripper_finger_joint(joint_name):
            gripper = value / GRIPPER_POS_MAX

    # Require a complete set of arm joints before committing anything.
    if not seen.all():
        return

    with _lock:
        if _initialized and np.any(np.abs(candidate - latest_positions) > MAX_JUMP_RAD):
            return
        latest_positions = candidate
        latest_gripper_width = gripper
        _initialized = True


def main() -> int:
    """Set up the camera and rosbridge subscription, then run the overlay loop.

    Returns:
        Process exit code: 0 on clean shutdown (signal or 'q' key),
        1 if the camera or the rosbridge connection could not be set up.
    """
    p = argparse.ArgumentParser(description="Stream Panda joints via rosbridge, render onto camera feed.")
    p.add_argument("--host", default="localhost", help="rosbridge host")
    p.add_argument("--port", type=int, default=9090, help="rosbridge websocket port")
    p.add_argument("--topic", default="/joint_states", help="JointState topic name")
    p.add_argument("--camera", type=int, default=0, help="Camera device ID (default: 0)")
    args = p.parse_args()

    # Fix: original f-string printed a stray ')' after the config name.
    print(f"Using exoskeleton config: {robot_config.name}")

    model = mujoco.MjModel.from_xml_string(robot_config.xml)
    data = mujoco.MjData(model)
    n_arm = min(N_ARM_JOINTS, data.qpos.size)
    # Gripper present when the model exposes an extra actuator and two finger
    # DOFs beyond the seven arm joints.
    has_gripper = data.ctrl.size > N_ARM_JOINTS and data.qpos.size >= N_ARM_JOINTS + 2

    # --- Camera setup ------------------------------------------------------
    print(f"Initializing camera device {args.camera}...")
    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        print(f"Failed to open camera device {args.camera}")
        return 1

    # Warm up: discard the first frames (auto-exposure settling / stale
    # driver buffers) and keep the last one to probe the resolution.
    ret, frame = False, None
    for _ in range(10):
        ret, frame = cap.read()
    if not ret:
        print("Failed to read from camera")
        cap.release()  # fix: don't leak the capture device on early exit
        return 1

    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    height, width = rgb.shape[:2]
    print(f"Camera resolution: {width}x{height}")

    # NOTE(review): `renderer` is never referenced below — possibly kept for
    # a GL-context side effect of Renderer(); confirm before removing.
    renderer = Renderer(model, height=height, width=width)

    # --- rosbridge setup ---------------------------------------------------
    client = roslibpy.Ros(host=args.host, port=args.port)
    sub = roslibpy.Topic(client, args.topic, "sensor_msgs/msg/JointState")
    sub.subscribe(on_joint_state)

    print(f"Connecting to ws://{args.host}:{args.port} ...")
    client.run()

    # client.run() starts a background thread; poll up to 5 s for the socket.
    deadline = time.time() + 5.0
    while not client.is_connected and time.time() < deadline:
        time.sleep(0.05)

    if not client.is_connected:
        print("Failed to connect to rosbridge.")
        cap.release()
        try:
            # Fix: also stop the roslibpy client thread on the failure path.
            client.terminate()
        except Exception:
            pass
        return 1

    print(f"Connected! Subscribing to {args.topic}")
    print("Starting camera overlay loop...\n")

    running = True

    def shutdown(*_):
        # Signal handler: flip the flag so the main loop exits cleanly.
        nonlocal running
        running = False

    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)

    # Hard-coded pinhole intrinsics (fx, fy, cx, cy) — values are consistent
    # with a 1920x1080 sensor; detect_and_set_link_poses may refine this.
    cam_K = np.array(
        [
            [1.58847596e03, 0.0, 9.59500000e02],
            [0.0, 1.58847596e03, 5.39500000e02],
            [0.0, 0.0, 1.0],
        ],
        dtype=np.float64,
    )

    while running:
        ret, frame = cap.read()
        if not ret:
            print("Failed to read frame from camera")
            continue
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Snapshot the latest joint state under the lock (written by the
        # rosbridge callback thread).
        with _lock:
            pos = latest_positions.copy()
            gw = latest_gripper_width

        data.qpos[:n_arm] = data.ctrl[:n_arm] = pos

        if has_gripper:
            # Mirror the normalized width back onto both finger DOFs.
            g_pos_m = gw * GRIPPER_POS_MAX
            data.qpos[N_ARM_JOINTS] = data.qpos[N_ARM_JOINTS + 1] = g_pos_m

        mujoco.mj_forward(model, data)

        # Detect ArUco markers and estimate the camera pose in world frame.
        try:
            link_poses, camera_pose_world, cam_K, corners_cache, corners_vis, obj_img_pts = (
                detect_and_set_link_poses(rgb, model, data, robot_config, cam_K=cam_K)
            )
        except Exception as e:
            # Detection failures are expected (markers out of view): show the
            # raw frame, keep the loop alive.
            print(e)
            cv2.imshow("display", rgb[..., ::-1])
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            continue

        # Render the robot from the estimated camera pose and blend 50/50.
        rendered = render_from_camera_pose(model, data, camera_pose_world, cam_K, *rgb.shape[:2])
        overlay = (rgb.astype(float) * 0.5 + rendered.astype(float) * 0.5).astype(np.uint8)
        display = np.hstack([corners_vis, rendered, overlay])

        cv2.imshow("display", display[..., ::-1])  # RGB -> BGR for imshow
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # Cleanup: best-effort teardown of camera, GUI, and rosbridge client.
    cap.release()
    cv2.destroyAllWindows()
    try:
        sub.unsubscribe()
        client.terminate()
    except Exception:
        pass
    print("Done!")
    return 0


if __name__ == "__main__":
    # SystemExit propagates main()'s return value as the process exit code.
    raise SystemExit(main())
