#!/usr/bin/env python3
"""Simple RGB and kinematics dataset recording script for the Panda arm.

Records RGB images and streamed Panda joint states (including gripper width)
at a fixed framerate, with an optional rendering overlay like simple_dataset_record.py.

Joint states are received via roslibpy (rosbridge websocket) — same as
stream_panda_with_cam.py. Requires the SSH tunnel + rosbridge to be running
(use scripts/start_robot_server.sh).
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import threading
import mujoco
import cv2
import numpy as np
import time
import random
import string
import argparse
import signal
from pathlib import Path

import roslibpy

from ExoConfigs.panda_exo import PANDA_BASE_ONLY_CONFIG, VIRTUAL_GRIPPER_BODY_NAME
from exo_utils import (
    detect_and_set_link_poses,
    position_exoskeleton_meshes,
    render_from_camera_pose,
    get_link_poses_from_robot,
    get_body_pose_in_world,
)

# Configuration
fps = 5  # Recording framerate
frame_interval = 1.0 / fps

N_ARM_JOINTS = 7
GRIPPER_POS_MAX = 0.04  # finger-joint travel in metres when fully open
MAX_JUMP_RAD = 0.5  # per-joint jump (rad) beyond which a message is treated as a glitch

# Shared snapshot of the most recent accepted joint state; guarded by _lock.
latest_positions = np.zeros(N_ARM_JOINTS, dtype=np.float64)
latest_gripper_width = 1.0  # [0,1], default open
_lock = threading.Lock()
_initialized = False

_ARM_JOINT_PREFIX = "fr3_joint"


def on_joint_state(msg: dict) -> None:
    """rosbridge JointState callback: update the shared joint-state snapshot.

    Parses ``fr3_joint1..7`` into ``latest_positions`` and
    ``fr3_finger_joint1`` into ``latest_gripper_width`` (normalized to [0,1]
    by GRIPPER_POS_MAX). Messages that do not contain every arm joint are
    ignored, as are messages whose positions jump more than MAX_JUMP_RAD
    from the last accepted sample (transport glitches).
    """
    global latest_positions, latest_gripper_width, _initialized
    names = msg.get("name") or []
    positions = msg.get("position") or []

    arm_pos = np.zeros(N_ARM_JOINTS)
    # Mask of joints actually present: a duplicated joint name in one message
    # can no longer masquerade as a complete set (a plain counter could).
    seen = [False] * N_ARM_JOINTS
    gripper_width = None  # sentinel: only update the global if the finger joint appeared

    # zip() stops at the shorter list, matching the original bounds check.
    for name, value in zip(names, positions):
        if name.startswith(_ARM_JOINT_PREFIX):
            suffix = name[len(_ARM_JOINT_PREFIX):]
            if suffix.isdigit():
                j = int(suffix)
                if 1 <= j <= N_ARM_JOINTS:
                    arm_pos[j - 1] = value
                    seen[j - 1] = True
        elif name == "fr3_finger_joint1":
            gripper_width = value / GRIPPER_POS_MAX

    if not all(seen):
        return  # incomplete arm state; keep the previous snapshot

    with _lock:
        # Reject glitch frames once we have a reference pose to compare against.
        if _initialized and np.any(np.abs(arm_pos - latest_positions) > MAX_JUMP_RAD):
            return
        latest_positions = arm_pos
        if gripper_width is not None:
            latest_gripper_width = gripper_width
        _initialized = True


# Command-line interface: rosbridge connection, camera, and visualization flags.
arg_parser = argparse.ArgumentParser(
    description="Record RGB and Panda joint state dataset via rosbridge"
)
arg_parser.add_argument(
    "--host",
    type=str,
    default="localhost",
    help="rosbridge host (default: localhost via SSH tunnel)",
)
arg_parser.add_argument(
    "--port", type=int, default=9090, help="rosbridge websocket port (default: 9090)"
)
arg_parser.add_argument(
    "--topic", type=str, default="/joint_states", help="JointState topic (default: /joint_states)"
)
arg_parser.add_argument("--camera", type=int, default=0, help="Camera device ID (default: 0)")
arg_parser.add_argument(
    "--render", action="store_true", help="Enable rendering visualization during recording"
)
arg_parser.add_argument(
    "--show_rgb", action="store_true", help="Show raw RGB stream during recording"
)
arg_parser.add_argument(
    "--dont_save", action="store_true", help="Don't save images and joint states"
)
args = arg_parser.parse_args()


def make_unique_id() -> str:
    """Return a random 6-character alphanumeric run identifier."""
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choices(alphabet, k=6))


# Per-run output directory tagged with a random ID.
unique_id = make_unique_id()
output_dir = Path("scratch") / f"rgb_joints_capture_panda_{unique_id}"
if not args.dont_save:
    output_dir.mkdir(parents=True, exist_ok=True)
print(f"Output directory: {output_dir}")

# Exoskeleton / model config (Panda base-only exo)
robot_config = PANDA_BASE_ONLY_CONFIG
# De-emphasize exo meshes and emphasize ArUco markers, when the config supports it.
for _attr, _alpha in (("exo_alpha", 0.2), ("aruco_alpha", 0.8)):
    if hasattr(robot_config, _attr):
        setattr(robot_config, _attr, _alpha)

camera_device = args.camera

print(f"Using Panda exoskeleton config: {robot_config.name}")
print(f"Recording at {fps} fps")
print(f"Initializing camera device {camera_device}...")

# Load model from config
model = mujoco.MjModel.from_xml_string(robot_config.xml)
data = mujoco.MjData(model)
# Number of arm joints we can actually write into this model's qpos.
n_arm = min(N_ARM_JOINTS, data.qpos.size)
# Heuristic gripper detection: extra actuators beyond the arm plus at least
# two extra qpos slots (two finger joints) — assumes the MJCF lays out the
# arm joints first; TODO confirm against the Panda XML.
has_gripper = data.ctrl.size > N_ARM_JOINTS and data.qpos.size >= N_ARM_JOINTS + 2

# Connect to rosbridge over websocket. The subscription is registered before
# run() so no messages are dropped once the connection comes up.
print(f"Connecting to rosbridge at ws://{args.host}:{args.port} ...")
client = roslibpy.Ros(host=args.host, port=args.port)
sub = roslibpy.Topic(client, args.topic, "sensor_msgs/msg/JointState")
sub.subscribe(on_joint_state)
client.run()

# Wait up to 5 s for the websocket handshake to complete.
deadline = time.time() + 5.0
while not client.is_connected and time.time() < deadline:
    time.sleep(0.05)
if not client.is_connected:
    # Stop the client's background thread before raising, otherwise the
    # process may hang on exit instead of terminating with the error.
    try:
        client.terminate()
    except Exception:
        pass
    raise RuntimeError("Failed to connect to rosbridge.")
print(f"Connected! Subscribing to {args.topic}\n")

# Initialize camera
cap = cv2.VideoCapture(camera_device)
if not cap.isOpened():
    raise RuntimeError(f"Failed to open camera device {camera_device}")

# Warm-up read: some cameras return a few empty frames right after opening.
# Bounded retry (instead of the previous unbounded loop) so a camera that
# never delivers frames fails loudly rather than hanging forever.
ret, frame = cap.read()
for _ in range(100):
    if ret:
        break
    ret, frame = cap.read()
if not ret:
    cap.release()
    raise RuntimeError(f"Camera device {camera_device} produced no frames")
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
height, width = rgb.shape[:2]
print(f"Camera resolution: {width}x{height}")

# Downsampled resolution (half size for saving, lower-res for visualization)
ds_height_save, ds_width_save = height // 2, width // 2
vis_res_ds_factor = 5
ds_height_vis, ds_width_vis = height // vis_res_ds_factor, width // vis_res_ds_factor

# Initialize renderer if rendering is enabled
renderer = None
if args.render:
    from mujoco.renderer import Renderer

    renderer = Renderer(model, height=ds_height_vis, width=ds_width_vis)
    print("Rendering enabled")

# Use same initial intrinsics as stream_panda_with_cam.py for stability.
# Focal lengths and principal point written out by name for readability.
_fx = _fy = 1.58847596e03
_cx, _cy = 9.59500000e02, 5.39500000e02
cam_K = np.array(
    [
        [_fx, 0.0, _cx],
        [0.0, _fy, _cy],
        [0.0, 0.0, 1.0],
    ],
    dtype=np.float64,
)
timestep = 0

running = True


def shutdown(*_):
    """Signal handler: flag the main capture loop to exit cleanly."""
    global running
    running = False


# Handle both Ctrl-C and external termination requests the same way.
for _sig in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_sig, shutdown)

print("\n" + "=" * 60)
print("Recording started. Press 'q' to quit.")
print("=" * 60)

# Main capture loop: one iteration per recorded frame at ~`fps` Hz.
while running:
    frame_start_time = time.time()

    # Read latest joint state from rosbridge callback
    # (copied under the lock so positions and gripper width are consistent).
    with _lock:
        pos = latest_positions.copy()
        gw = latest_gripper_width

    # Capture frame
    ret, frame = cap.read()
    if not ret:
        print("Failed to read frame from camera")
        # NOTE(review): `continue` skips the frame-rate sleep at the bottom of
        # the loop, so repeated read failures busy-spin at full speed.
        continue
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Apply latest joint state to sim (kinematics only)
    # NOTE(review): `pos` always has N_ARM_JOINTS entries; this chained slice
    # assignment assumes n_arm == N_ARM_JOINTS — would raise for a smaller model.
    data.qpos[:n_arm] = data.ctrl[:n_arm] = pos
    if has_gripper:
        # Denormalize [0,1] width back to metres; mirror onto both finger joints.
        g_pos_m = gw * GRIPPER_POS_MAX
        data.qpos[N_ARM_JOINTS] = data.qpos[N_ARM_JOINTS + 1] = g_pos_m
    mujoco.mj_forward(model, data)

    # Get camera pose via ArUco
    camera_pose_world = None
    try:
        # The detector also refines cam_K, which feeds back in on the next
        # iteration (self-calibrating intrinsics).
        link_poses, camera_pose_world, cam_K, corners_cache, corners_vis, obj_img_pts = detect_and_set_link_poses(
            rgb, model, data, robot_config, cam_K=cam_K
        )
        position_exoskeleton_meshes(robot_config, model, data, link_poses)
        mujoco.mj_forward(model, data)
    except Exception as e:
        # Best-effort: a failed detection just means no pose is saved this frame.
        print(f"Error detecting link poses: {e}")
        camera_pose_world = None

    # NOTE(review): cv2.waitKey only receives key events when an imshow window
    # exists, so this quit check is effective only with --render/--show_rgb.
    if cv2.waitKey(1) & 0xFF == ord("q"):
        print("quitting")
        break

    # Downsample image
    rgb_downsampled_save = cv2.resize(
        rgb, (ds_width_save, ds_height_save), interpolation=cv2.INTER_LINEAR
    )
    rgb_downsampled_vis = cv2.resize(
        rgb, (ds_width_vis, ds_height_vis), interpolation=cv2.INTER_LINEAR
    )

    # Save image and joint state (+ gripper width)
    timestep_str = f"{timestep:06d}"
    image_path = output_dir / f"{timestep_str}.png"
    joint_path = output_dir / f"{timestep_str}.npy"
    gripper_pose_path = output_dir / f"{timestep_str}_gripper_pose.npy"
    camera_pose_path = output_dir / f"{timestep_str}_camera_pose.npy"
    cam_K_path = output_dir / f"{timestep_str}_cam_K_norm.npy"

    # Virtual EE pose from MJCF body `virtual_gripper_keypoint` (hand frame offset in panda.xml).
    try:
        gripper_pose_4x4 = get_body_pose_in_world(
            model, data, VIRTUAL_GRIPPER_BODY_NAME
        ).astype(np.float32)
    except ValueError:
        # Fallback chain: exo base-link pose, then identity if even that fails.
        try:
            gripper_pose_4x4 = get_link_poses_from_robot(robot_config, model, data)[
                "larger_base"
            ]
        except Exception:
            gripper_pose_4x4 = np.eye(4, dtype=np.float32)

    if not args.dont_save:
        # Stored layout: 7 arm joint angles followed by normalized gripper width.
        joint_state_with_gripper = np.concatenate(
            [pos.astype(np.float32), np.array([gw], dtype=np.float32)]
        )
        cv2.imwrite(
            str(image_path),
            cv2.cvtColor(rgb_downsampled_save, cv2.COLOR_RGB2BGR),
        )
        np.save(joint_path, joint_state_with_gripper)
        np.save(gripper_pose_path, gripper_pose_4x4)
        # Camera pose / intrinsics only exist for frames where ArUco succeeded,
        # so downstream consumers must tolerate missing pose files.
        if camera_pose_world is not None and cam_K is not None:
            np.save(camera_pose_path, camera_pose_world)
            # Normalize intrinsics by the full-res image size so they stay
            # valid for the half-res saved images.
            cam_K_norm = cam_K.copy()
            cam_K_norm[0] /= rgb.shape[1]
            cam_K_norm[1] /= rgb.shape[0]
            np.save(cam_K_path, cam_K_norm)
        print(
            f"Saved timestep {timestep_str}: {image_path.name}, {joint_path.name} "
            f"(gripper_width={gw:.3f})"
        )

    # Optional rendering
    if args.render and camera_pose_world is not None and cam_K is not None:
        # Scale intrinsics down to the low-res visualization resolution.
        cam_K_low_res = cam_K.copy()
        cam_K_low_res[0] = cam_K_low_res[0] / vis_res_ds_factor
        cam_K_low_res[1] = cam_K_low_res[1] / vis_res_ds_factor
        rendered = render_from_camera_pose(
            model, data, camera_pose_world, cam_K_low_res, *rgb_downsampled_vis.shape[:2]
        )
        # Side-by-side panel: raw | rendered | 50/50 overlay, upscaled 2x.
        overlay = (
            rgb_downsampled_vis.astype(float) * 0.5 + rendered.astype(float) * 0.5
        ).astype(np.uint8)
        display = np.hstack([rgb_downsampled_vis, rendered, overlay])
        display = cv2.resize(
            display, np.array(display.shape[:2][::-1]) * 2, interpolation=cv2.INTER_NEAREST
        )
        cv2.imshow("Recording Panda", cv2.cvtColor(display, cv2.COLOR_RGB2BGR))
        if cv2.waitKey(1) & 0xFF == ord("q"):
            print("quitting")
            break
    elif args.show_rgb:
        cv2.imshow("Recording Panda", cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
        if cv2.waitKey(1) & 0xFF == ord("q"):
            print("quitting")
            break

    timestep += 1

    # Maintain framerate
    elapsed = time.time() - frame_start_time
    sleep_time = max(0, frame_interval - elapsed)
    if sleep_time > 0:
        time.sleep(sleep_time)

# Teardown: release hardware first, then best-effort rosbridge shutdown.
cap.release()
cv2.destroyAllWindows()
# Unsubscribe and terminate independently: previously both shared one try
# block, so an exception from unsubscribe() (e.g. socket already dropped)
# would leave the client's background thread running.
try:
    sub.unsubscribe()
except Exception:
    pass
try:
    client.terminate()
except Exception:
    pass
print("Done!")
