"""Script to run dual spacemouse teleoperation with Isaac Lab manipulation environments."""

"""Launch Isaac Sim Simulator first."""

from dataclasses import dataclass

import tyro

from isaaclab.app import AppLauncher


@dataclass
class TeleopArgs:
    """Dual spacemouse teleoperation for Isaac Lab environments."""
    
    task: str
    """Name of the task."""

    num_envs: int = 1
    """Number of environments to simulate."""

    scene_path: str | None = None
    """Path to config JSON file."""

    library_dir: str | None = None
    """Path to shared USD model library."""


    pos_sensitivity: float = 0.15
    """Position sensitivity factor."""

    rot_sensitivity: float = 0.25
    """Rotation sensitivity factor."""



args_cli, remaining = tyro.cli(TeleopArgs, return_unknown_args=True)

# Whatever tyro did not recognise is meant for AppLauncher. AppLauncher takes an
# argparse.Namespace rather than a raw argv list, so the leftovers are parsed by a
# parser that knows only the AppLauncher options.
import argparse

parser = argparse.ArgumentParser()
AppLauncher.add_app_launcher_args(parser)
app_args = parser.parse_args(args=remaining)

# Boot the simulator now; the simulator-dependent imports must come after this.
app_launcher = AppLauncher(app_args)
simulation_app = app_launcher.app

"""Rest everything follows."""

import carb
import gymnasium as gym
import hid
import json
import omni
import signal
import time
import torch
import weakref
from pathlib import Path

import isaaclab.sim as sim_utils
from isaaclab.assets import RigidObject
from isaaclab.devices import Se3SpaceMouse, Se3SpaceMouseCfg
from isaaclab.markers import VisualizationMarkers, VisualizationMarkersCfg
from isaaclab.utils import math as math_utils
from isaaclab_tasks.utils import parse_env_cfg
import sim_improvement.environments  # noqa: F401
from sim_improvement.environments.mdp.recorders.recorders_config import DatasetRecorderManagerCfg
from sim_improvement.environments.mdp.recorders.hdf5_append_handler import HDF5AppendDatasetFileHandler
from isaaclab.managers.recorder_manager import DatasetExportMode


# HID ``product_string`` values that are accepted as 3Dconnexion SpaceMouse devices.
SPACEMOUSE_PRODUCT_STRINGS = {
    "SpaceMouse Compact",
    "SpaceMouse Wireless",
    "SpaceMouse Wireless BT",
    "3Dconnexion Universal Receiver",
}


class Se3SpaceMouseIndexed(Se3SpaceMouse):
    """Se3SpaceMouse that opens a specific device by index (0-based) among all connected spacemice.

    Candidate devices are HID entries whose ``product_string`` is in
    ``SPACEMOUSE_PRODUCT_STRINGS``. They are de-duplicated by HID path and taken in
    the order ``hid.enumerate()`` reports them.

    Modified so both buttons toggle gripper (instead of right button triggering reset).
    """

    def __init__(self, cfg: Se3SpaceMouseCfg, device_index: int = 0):
        """Store the device index, then defer to the base constructor.

        Args:
            cfg: Spacemouse configuration forwarded unchanged to ``Se3SpaceMouse``.
            device_index: 0-based index of the spacemouse to open among those found.
        """
        # Set before super().__init__(). Presumably the base constructor calls
        # _find_device(), which reads this attribute.
        self._device_index = device_index
        super().__init__(cfg)

    def _find_device(self):
        """Open the ``self._device_index``-th connected spacemouse.

        Polls ``hid.enumerate()`` up to 5 times, sleeping 1 s after each attempt that
        does not see enough devices.

        Raises:
            OSError: if fewer than ``self._device_index + 1`` spacemice are found.
        """
        # NOTE(review): hid.enumerate() order is not documented as stable. The
        # index -> physical device mapping (and therefore which arm each mouse drives)
        # may change across runs or re-plugs — confirm.
        found_devices = []
        for _ in range(5):
            seen_paths = set()
            for dev_info in hid.enumerate():
                if dev_info["product_string"] in SPACEMOUSE_PRODUCT_STRINGS:
                    path = dev_info["path"]
                    if path not in seen_paths:
                        seen_paths.add(path)
                        found_devices.append(dev_info)
            if len(found_devices) > self._device_index:
                break
            found_devices.clear()
            time.sleep(1.0)

        if len(found_devices) <= self._device_index:
            raise OSError(
                f"SpaceMouse device_index={self._device_index} not found. "
                f"Found {len(found_devices)} device(s): {[d['product_string'] for d in found_devices]}"
            )

        target = found_devices[self._device_index]
        self._device.close()
        self._device.open_path(target["path"])
        self._device_name = target["product_string"]
        print(f"[SpaceMouse {self._device_index}] Opened: {target['product_string']} @ {target['path']}")

    def _run_device(self):
        """Listener thread - both buttons toggle gripper.

        Reads HID reports forever. Each report updates ``_delta_pos``/``_delta_rot``
        (scaled by the sensitivities) and the ``_close_gripper``/``_read_rotation``
        toggles.
        """
        from isaaclab.devices.spacemouse.utils import convert_buffer

        # The Universal Receiver packs translation and rotation into one 13-byte report.
        # Other devices send separate 7-byte reports (id 1 = translation, id 2 = rotation).
        is_receiver = self._device_name == "3Dconnexion Universal Receiver"
        report_len = 7 + 6 if is_receiver else 7
        while True:
            data = self._device.read(report_len)
            # An empty read would make data[0] raise IndexError and silently kill this
            # listener thread. Yield briefly instead of spinning.
            if not data:
                time.sleep(0.001)
                continue
            # readings from 6-DoF sensor
            if is_receiver:
                if data[0] == 1:
                    self._delta_pos[1] = self.pos_sensitivity * convert_buffer(data[1], data[2])
                    self._delta_pos[0] = self.pos_sensitivity * convert_buffer(data[3], data[4])
                    self._delta_pos[2] = self.pos_sensitivity * convert_buffer(data[5], data[6]) * -1.0
                    self._delta_rot[1] = self.rot_sensitivity * convert_buffer(data[1 + 6], data[2 + 6])
                    self._delta_rot[0] = self.rot_sensitivity * convert_buffer(data[3 + 6], data[4 + 6])
                    self._delta_rot[2] = self.rot_sensitivity * convert_buffer(data[5 + 6], data[6 + 6]) * -1.0
            else:
                if data[0] == 1:
                    self._delta_pos[1] = self.pos_sensitivity * convert_buffer(data[1], data[2])
                    self._delta_pos[0] = self.pos_sensitivity * convert_buffer(data[3], data[4])
                    self._delta_pos[2] = self.pos_sensitivity * convert_buffer(data[5], data[6]) * -1.0
                elif data[0] == 2 and not self._read_rotation:
                    self._delta_rot[1] = self.rot_sensitivity * convert_buffer(data[1], data[2])
                    self._delta_rot[0] = self.rot_sensitivity * convert_buffer(data[3], data[4])
                    self._delta_rot[2] = self.rot_sensitivity * convert_buffer(data[5], data[6]) * -1.0
            # readings from the side buttons - BOTH toggle gripper
            if data[0] == 3:
                if data[1] == 1:  # left button
                    self._close_gripper = not self._close_gripper
                    if "L" in self._additional_callbacks:
                        self._additional_callbacks["L"]()
                if data[1] == 2:  # right button - now also toggles gripper
                    self._close_gripper = not self._close_gripper
                    if "R" in self._additional_callbacks:
                        self._additional_callbacks["R"]()
                if data[1] == 3:  # both buttons together toggle the rotation gate
                    self._read_rotation = not self._read_rotation


def _load_frame_vis_specs(scene_path: str) -> list[tuple[str, torch.Tensor]]:
    """Read the named object frames from a scene config JSON for debug visualization.

    Args:
        scene_path: Path to the scene JSON. Each entry of its ``objects`` list must
            have a ``name``. It may also have a ``needs_yup`` flag and a ``frames``
            mapping whose values carry an optional ``translation`` ``[x, y, z]``
            (default ``[0, 0, 0]``).

    Returns:
        One ``(object_name, local_offset)`` pair per frame, in file order.
        ``local_offset`` is a float32 tensor of 3 values in the object's mesh frame.
    """
    with open(scene_path, encoding="utf-8") as f:
        spawn_config = json.load(f)

    specs: list[tuple[str, torch.Tensor]] = []
    for obj in spawn_config.get("objects", []):
        obj_name = obj["name"]
        needs_yup = bool(obj.get("needs_yup", False))
        for frame_data in obj.get("frames", {}).values():
            x, y, z = frame_data.get("translation", [0, 0, 0])
            if needs_yup:
                # Frame offsets are Z-up (SDF convention) but root_quat_w
                # includes Rx(90°) which expects Y-up mesh coords.
                # Convert Z-up -> Y-up: [x, y, z] -> [x, z, -y]
                offset = torch.tensor([x, z, -y], dtype=torch.float32)
            else:
                offset = torch.tensor([x, y, z], dtype=torch.float32)
            specs.append((obj_name, offset))
    return specs


def _setup_camera_viewports(unwrapped_env) -> None:
    """Create the fixed external/teleop cameras in env_0 and open one GUI viewport per camera.

    Camera prims that already exist on the stage are left untouched. Viewports that
    already exist (by name) are not recreated. A viewport that fails to open is
    reported and skipped.
    """
    from isaacsim.core.utils.viewports import create_viewport_for_camera, get_viewport_names
    from pxr import Gf, UsdGeom

    stage = unwrapped_env.sim.stage

    def create_camera(path: str, pos: tuple, quat: tuple):
        """Create a camera at the given pose if it doesn't exist. quat is (w, x, y, z)."""
        if not stage.GetPrimAtPath(path):
            cam = UsdGeom.Camera.Define(stage, path)
            # Translate op is added first; update_teleop_camera relies on op index 0.
            cam.AddTranslateOp().Set(Gf.Vec3d(*pos))
            cam.AddOrientOp().Set(Gf.Quatf(quat[0], quat[1], quat[2], quat[3]))
            # Match scenario_rollout_cfg.py camera settings
            cam.GetClippingRangeAttr().Set(Gf.Vec2f(0.05, 3.0))
            cam.GetFocalLengthAttr().Set(0.5953)
            cam.GetHorizontalApertureAttr().Set(1.0)
            cam.GetVerticalApertureAttr().Set(0.75)
            print(f"[Viewport] Created camera at {path}")

    # Create external cameras at poses from scenario_rollout_cfg.py (opengl convention)
    create_camera("/World/envs/env_0/external_cam_left",
                  pos=(0.49714, -0.36177, 0.77877),
                  quat=(0.72133, 0.33411, 0.07049, 0.60257))
    create_camera("/World/envs/env_0/external_cam_right",
                  pos=(0.54755, 0.18739, 0.83597),
                  quat=(0.63083, 0.15475, 0.27309, 0.70960))
    # Teleop camera - initial pose; repositioned relative to mug_holder on resets
    create_camera("/World/envs/env_0/teleop_camera",
                  pos=(0.19894, -0.37113, 0.34373),
                  quat=(0.46727, 0.42831, 0.52262, 0.57016))

    # Viewport configs: (viewport_name, camera_prim_path, size, position)
    camera_viewports = [
        ("left_wrist", "/World/envs/env_0/left_panda/panda_link8/flir_wrist_left_plus", (480, 300), (0, 0)),
        ("right_wrist", "/World/envs/env_0/right_panda/panda_link8/flir_wrist_right_minus", (480, 300), (490, 0)),
        ("external_left", "/World/envs/env_0/external_cam_left", (480, 360), (0, 310)),
        ("external_right", "/World/envs/env_0/external_cam_right", (480, 360), (490, 310)),
        ("teleop_view", "/World/envs/env_0/teleop_camera", (640, 480), (980, 0)),
    ]

    for vp_name, cam_path, vp_size, vp_pos in camera_viewports:
        if vp_name not in get_viewport_names():
            try:
                create_viewport_for_camera(
                    viewport_name=vp_name,
                    camera_prim_path=cam_path,
                    width=vp_size[0],
                    height=vp_size[1],
                    position_x=vp_pos[0],
                    position_y=vp_pos[1],
                )
                print(f"[Viewport] Created {vp_name} for {cam_path}")
            except Exception as e:
                print(f"[Viewport] Could not create {vp_name}: {e}")


def _count_demos_in_file(file_path: str) -> int:
    """Return the number of ``demo_*`` entries under ``data`` in an HDF5 dataset file.

    ``file_path`` may omit the ``.hdf5`` suffix. A missing file, or a file that has
    no ``data`` group yet, counts as zero demos instead of raising.
    """
    import h5py

    full_path = file_path if file_path.endswith(".hdf5") else file_path + ".hdf5"
    if not Path(full_path).exists():
        return 0
    with h5py.File(full_path, "r") as f:
        data_group = f.get("data")
        if data_group is None:
            return 0
        return sum(1 for key in data_group.keys() if key.startswith("demo_"))


def main() -> None:
    """Run dual-spacemouse teleoperation and record episodes under ``runs/<scene>/``.

    Spacemouse 0 drives the left arm and spacemouse 1 the right arm. Pressing 'R'
    resets the episode. Ctrl+C stops the loop. In every case, including an exception
    in the loop, the keyboard subscription is released and the environment is closed.
    """
    # --- Environment setup (matching rollout_pretrained) ---
    env_cfg = parse_env_cfg(
        args_cli.task,
        device="cuda",
        num_envs=args_cli.num_envs,
        use_fabric=True,
    )
    env_cfg.dynamic_setup(  # type: ignore
        scene_path=args_cli.scene_path,
        library_dir=args_cli.library_dir,
    )

    # Effectively unlimited episode length (50000 s) so time-outs don't cut demos short.
    env_cfg.episode_length_s = 50000.0

    # Record trajectory - use scene name for output folder
    scene_name = Path(args_cli.scene_path).stem if args_cli.scene_path else "default"
    output_folder = Path("runs") / scene_name
    output_folder.mkdir(parents=True, exist_ok=True)
    env_cfg.recorders = DatasetRecorderManagerCfg()
    env_cfg.recorders.dataset_file_handler_class_type = HDF5AppendDatasetFileHandler
    env_cfg.recorders.dataset_export_dir_path = str(output_folder)
    env_cfg.recorders.dataset_filename = "teleop"
    env_cfg.recorders.dataset_export_mode = DatasetExportMode.EXPORT_SUCCEEDED_FAILED_IN_SEPARATE_FILES

    env = gym.make(args_cli.task, cfg=env_cfg)  # type: ignore
    unwrapped_env = env.unwrapped

    # --- Frame debug visualization ---
    frame_markers = None
    frame_vis_specs: list[tuple[str, torch.Tensor]] = []  # (obj_name, local_offset in mesh frame)
    if args_cli.scene_path:
        frame_vis_specs = _load_frame_vis_specs(args_cli.scene_path)
        if frame_vis_specs:
            frame_markers = VisualizationMarkers(VisualizationMarkersCfg(
                prim_path="/World/Visuals/frame_debug",
                markers={
                    "sphere": sim_utils.SphereCfg(
                        radius=0.008,
                        visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(1.0, 0.0, 0.0)),
                    ),
                },
            ))
            print(f"Frame visualization: {len(frame_vis_specs)} frames across objects")

    # --- Gripper depth rays ---
    ray_length = 0.5
    gripper_ray_markers = VisualizationMarkers(VisualizationMarkersCfg(
        prim_path="/World/Visuals/gripper_rays",
        markers={
            "ray": sim_utils.CylinderCfg(
                radius=0.008,
                height=ray_length,
                axis="Z",  # Align cylinder along Z axis
                visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.0, 1.0, 0.0), opacity=0.8),
            ),
        },
    ))

    def update_gripper_rays():
        """Draw one cylinder "ray" per arm, oriented like that arm's panda_link8 frame."""
        left_panda = unwrapped_env.scene["left_panda"]
        right_panda = unwrapped_env.scene["right_panda"]

        left_idx = left_panda.body_names.index("panda_link8")
        right_idx = right_panda.body_names.index("panda_link8")

        # Get end-effector positions and orientations
        left_ee_pos = left_panda.data.body_pos_w[:, left_idx]
        right_ee_pos = right_panda.data.body_pos_w[:, right_idx]
        left_ee_quat = left_panda.data.body_quat_w[:, left_idx]
        right_ee_quat = right_panda.data.body_quat_w[:, right_idx]

        # link8's local -Z axis, rotated into world frame, is used as "forward".
        forward_local = torch.tensor([0.0, 0.0, -1.0], device=unwrapped_env.device)
        left_forward = math_utils.quat_apply(left_ee_quat, forward_local.unsqueeze(0).expand(unwrapped_env.num_envs, -1))
        right_forward = math_utils.quat_apply(right_ee_quat, forward_local.unsqueeze(0).expand(unwrapped_env.num_envs, -1))

        # Ray center sits (ray_length / 2 + gripper_offset) along `forward` from link8.
        # NOTE(review): with these values the net shift is -0.05 m, i.e. against
        # `forward` — confirm this matches the intended visual.
        gripper_offset = -0.3
        left_ray_pos = left_ee_pos + left_forward * (ray_length / 2 + gripper_offset)
        right_ray_pos = right_ee_pos + right_forward * (ray_length / 2 + gripper_offset)

        ray_positions = torch.cat([left_ray_pos, right_ray_pos], dim=0)
        ray_orientations = torch.cat([left_ee_quat, right_ee_quat], dim=0)
        gripper_ray_markers.visualize(translations=ray_positions, orientations=ray_orientations)

    def update_frame_markers():
        """Compute world positions of all named frames and update markers."""
        if frame_markers is None:
            return
        positions = []
        for obj_name, local_offset in frame_vis_specs:
            try:
                obj: RigidObject = unwrapped_env.scene[obj_name]
            except KeyError:
                # Object absent from the scene: park its markers at the origin.
                positions.append(torch.zeros(unwrapped_env.num_envs, 3, device=unwrapped_env.device))
                continue
            root_pos = obj.data.root_pos_w    # (num_envs, 3)
            root_quat = obj.data.root_quat_w  # (num_envs, 4)
            offset = local_offset.to(device=unwrapped_env.device)
            world_offset = math_utils.quat_apply(root_quat, offset.unsqueeze(0).expand(unwrapped_env.num_envs, -1))
            positions.append(root_pos + world_offset)
        # Stack: (num_frames * num_envs, 3)
        all_pos = torch.cat(positions, dim=0)
        frame_markers.visualize(translations=all_pos)

    # --- Camera viewports (2 wrist + 2 external + 1 teleop) ---
    if simulation_app.is_running() and unwrapped_env.sim.has_gui():
        _setup_camera_viewports(unwrapped_env)

    # Teleop camera offset from mug_holder
    teleop_cam_offset = (0.11416, 0.09196, 0.34373)

    def update_teleop_camera():
        """Move the teleop camera (if it exists) to mug_holder's env_0 position plus a fixed offset.

        Best-effort: failures are reported and the camera keeps its previous pose.
        """
        try:
            from pxr import UsdGeom, Gf
            stage = unwrapped_env.sim.stage
            cam_prim = stage.GetPrimAtPath("/World/envs/env_0/teleop_camera")
            if not cam_prim.IsValid():
                # No GUI -> camera never created; nothing to update.
                return
            mug_holder = unwrapped_env.scene["mug_holder"]
            mug_pos = mug_holder.data.root_pos_w[0].cpu().numpy()
            new_pos = Gf.Vec3d(
                mug_pos[0] + teleop_cam_offset[0],
                mug_pos[1] + teleop_cam_offset[1],
                mug_pos[2] + teleop_cam_offset[2],
            )
            xform = UsdGeom.Xformable(cam_prim)
            # Op 0 is the translate op added first when the camera was created.
            xform.GetOrderedXformOps()[0].Set(new_pos)
        except Exception as e:
            print(f"[Viewport] Could not update teleop camera: {e}")

    # --- Dual spacemouse setup ---
    left_mouse = Se3SpaceMouseIndexed(
        Se3SpaceMouseCfg(pos_sensitivity=args_cli.pos_sensitivity, rot_sensitivity=args_cli.rot_sensitivity),
        device_index=0,
    )
    right_mouse = Se3SpaceMouseIndexed(
        Se3SpaceMouseCfg(pos_sensitivity=args_cli.pos_sensitivity, rot_sensitivity=args_cli.rot_sensitivity),
        device_index=1,
    )

    should_reset = False
    should_quit = False

    def sigint_handler(sig, frame):
        """Request a graceful stop; the main loop exits after the current step."""
        nonlocal should_quit
        if not should_quit:
            should_quit = True
            print("\nCtrl+C received - finishing up...")

    signal.signal(signal.SIGINT, sigint_handler)

    # --- Keyboard input for reset (R key) ---
    appwindow = omni.appwindow.get_default_app_window()
    input_interface = carb.input.acquire_input_interface()
    keyboard = appwindow.get_keyboard()

    def on_keyboard_event(event, *args):
        """Flag a reset on 'R' key press; the reset itself happens in the main loop."""
        nonlocal should_reset
        if event.type == carb.input.KeyboardEventType.KEY_PRESS:
            if event.input.name == "R":
                should_reset = True
                print("Reset triggered via 'R' key - Environment will reset on next step")
        return True

    # Keep the subscription handle so it can be released during cleanup.
    keyboard_sub = input_interface.subscribe_to_keyboard_events(
        keyboard,
        on_keyboard_event,
    )

    try:
        # --- Main loop ---
        env.reset()
        left_mouse.reset()
        right_mouse.reset()
        update_teleop_camera()

        # Track demo counts for logging (session counts are relative to startup).
        recorder_mgr = unwrapped_env.recorder_manager
        initial_success = recorder_mgr.exported_successful_episode_count
        initial_failed = recorder_mgr.exported_failed_episode_count

        def log_demo_counts():
            """Print demos exported this session and demos currently stored on disk."""
            session_success = recorder_mgr.exported_successful_episode_count - initial_success
            session_failed = recorder_mgr.exported_failed_episode_count - initial_failed
            file_success = _count_demos_in_file(str(output_folder / "teleop"))
            file_failed = _count_demos_in_file(str(output_folder / "teleop_failed"))
            print(f"[Demos] Session: {session_success} success, {session_failed} failed | In file: {file_success} success, {file_failed} failed")

        action_dim = env.action_space.shape[-1]
        print(f"Teleoperation started. Action dim: {action_dim}")
        print("  Press 'R': reset episode | Ctrl+C to quit safely")
        log_demo_counts()

        while simulation_app.is_running() and not should_quit:
            with torch.inference_mode():
                # Each spacemouse returns 7D: [dx, dy, dz, droll, dpitch, dyaw, gripper]
                left_action = left_mouse.advance()
                right_action = right_mouse.advance()

                # Build 14D action: [left_ik(6), right_ik(6), left_gripper(1), right_gripper(1)]
                action = torch.cat([
                    left_action[:6],   # left arm pose
                    right_action[:6],  # right arm pose
                    left_action[6:],   # left gripper
                    right_action[6:],  # right gripper
                ])
                # Same command for every env. Read num_envs from the unwrapped env:
                # attribute forwarding through gym wrappers is removed in newer gymnasium.
                actions = action.unsqueeze(0).repeat(unwrapped_env.num_envs, 1)

                print(f"Actions: {actions}")

                obs, reward, terminated, truncated, info = env.step(actions)
                update_frame_markers()
                update_gripper_rays()

                if terminated.any() or truncated.any():
                    print(f"Terminated: {terminated}, Truncated: {truncated}")
                    log_demo_counts()
                    update_teleop_camera()

                if should_reset:
                    print("Forcing reset")
                    env.reset()
                    left_mouse.reset()
                    right_mouse.reset()
                    should_reset = False
                    print("Environment reset complete")
                    log_demo_counts()
                    update_teleop_camera()

        if should_quit:
            log_demo_counts()
    finally:
        # Always release input and close the env, even if the loop raised.
        input_interface.unsubscribe_to_keyboard_events(keyboard, keyboard_sub)
        env.close()
        print("Environment closed")


if __name__ == "__main__":
    try:
        main()
    finally:
        # Close the simulation app even if main() raised.
        simulation_app.close()
