import tyro
import mediapy

import tqdm
import gymnasium as gym
import torch
import argparse


from pathlib import Path
from dataclasses import dataclass
from isaaclab.app import AppLauncher


@dataclass
class Args:
    """Command-line options for the teleoperation recording script."""

    # Environment id; handed to ``parse_env_cfg`` / ``gym.make`` and also used
    # as the name of the sub-folder created under ``output``.
    environment: str
    # Root folder under which the recorded dataset is exported.
    output: str = "demonstrations"
    # Forwarded as the ``headless`` option of the Isaac Sim ``AppLauncher``.
    headless: bool = False

def observation_to_droid_observation(obs):
    """Convert an environment observation into a DROID-style observation dict.

    Reads ``obs["policy"]["ee_pose"]``, treats its first three trailing
    components as the end-effector position and the remaining ones as a
    quaternion, converts the quaternion to XYZ Euler angles and packs
    ``[position, euler]`` into a single array.

    Args:
        obs: Observation mapping whose ``obs["policy"]["ee_pose"]`` entry is a
            torch tensor. Assumed to have shape ``(num_envs, 7)`` with the
            quaternion in Isaac Lab's (w, x, y, z) order -- TODO confirm
            against the environment's observation config.

    Returns:
        A dict ``{"robot_state": {"cartesian_position": ..., "gripper_position": 0.0}}``.
        ``cartesian_position`` is a NumPy array with singleton dimensions
        squeezed out (a flat 6-vector for a single environment).
        ``gripper_position`` is always reported as ``0.0``; the real gripper
        state is not read from ``obs``.
    """
    # isaaclab modules are imported lazily, inside the function, rather than
    # at module import time.
    from isaaclab.utils import math

    ee_pose = obs["policy"]["ee_pose"]
    ee_pos, ee_quat = ee_pose[..., :3], ee_pose[..., 3:]

    # euler_xyz_from_quat returns a (roll, pitch, yaw) tuple with one value
    # per environment. Stacking on a new trailing axis yields (num_envs, 3)
    # for any number of environments; the previous cat-then-unsqueeze only
    # produced the right shape when num_envs == 1.
    ee_euler = torch.stack(math.euler_xyz_from_quat(ee_quat), dim=-1)
    ee_state = torch.cat([ee_pos, ee_euler], dim=-1)

    return {
        "robot_state": {
            "cartesian_position": ee_state.cpu().numpy().squeeze(),
            "gripper_position": 0.0,
        }
    }

def close_isaac(simulation_app, env=None):
    """Close the environment (when one is given) and then the simulation app.

    Args:
        simulation_app: Object exposing ``close()``; always closed, last.
        env: Optional object exposing ``close()``; closed first when not None.
    """
    # Order matters: the environment is shut down before the app that hosts it.
    handles = (simulation_app,) if env is None else (env, simulation_app)
    for handle in handles:
        handle.close()

def main(args: Args) -> None:
    """Run an Oculus-teleoperated demonstration recording session.

    Launches Isaac Sim, builds a single-environment instance of
    ``args.environment`` with an ``AllStatesRecorderManagerCfg`` recorder that
    exports to ``<args.output>/<args.environment>``, then repeatedly steps the
    environment with actions produced by ``VRPolicy``. When the controller
    reports failure the environment and controller state are reset and the
    loop continues; when it reports success the environment is reset and the
    loop ends, after which the environment and the simulation app are closed.

    Args:
        args: Parsed command-line options.

    NOTE(review): nothing here closes the environment or the simulation app
    if an exception escapes the loop; cleanup only runs on the success path.
    """
    # The simulation app must be launched before importing any further
    # isaaclab modules (AppLauncher itself, imported at the top of the file,
    # is the exception), which is why the imports below are function-local.
    # >>>> Isaac Sim App Launcher <<<<
    # Cameras are always enabled; only ``headless`` follows the CLI flag.
    args_cli = argparse.Namespace(enable_cameras=True, headless=args.headless)
    app_launcher = AppLauncher(args_cli)
    simulation_app = app_launcher.app
    # >>>> End of Isaac Sim App Launcher <<<<

    from isaaclab_tasks.utils import parse_env_cfg  # noqa: E402
    from sim_improvement.environments.mdp.recorders.recorders_config import AllStatesRecorderManagerCfg
    from sim_improvement.teleop.oculus import VRPolicy
    from isaaclab.envs import ManagerBasedRLEnvCfg, ManagerBasedRLEnv
    from isaaclab.managers.recorder_manager import DatasetExportMode
    from isaaclab.markers import VisualizationMarkersCfg, VisualizationMarkers
    import isaaclab.sim as sim_utils
    from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR
    from isaaclab.utils.math import quat_from_euler_xyz
    from isaaclab.utils import math

    # Make output folder: <output>/<environment>
    output_folder = Path(args.output) / args.environment
    output_folder.mkdir(parents=True, exist_ok=True)

    env_cfg = parse_env_cfg(
            args.environment,
            device="cuda",
            num_envs=1,
            use_fabric=True,
        )
    # Keep the policy observation terms as separate entries instead of one
    # concatenated tensor; the observation conversion looks up "ee_pose" by name.
    env_cfg.observations.policy.concatenate_terms = False
    # Very long episode length -- presumably so episodes never time out while
    # a human is teleoperating (NOTE(review): confirm intent).
    env_cfg.episode_length_s = 50000.0
    env_cfg.sim.render_interval = 1
    env_cfg.viewer.eye = (-1, 1, 1)
    env_cfg.viewer.lookat = (0, 0, 0)

    # Frame marker intended to show the commanded pose. It is created here but
    # only referenced by the commented-out ``marker.visualize`` call in the loop.
    marker_cfg = VisualizationMarkersCfg(
        markers={
            "frame": sim_utils.UsdFileCfg(
                usd_path=f"{ISAAC_NUCLEUS_DIR}/Props/UIElements/frame_prim.usd",
                scale=(0.1, 0.1, 0.1),
            ),
        },
        prim_path="/Visuals/commanded",
    )
    marker = VisualizationMarkers(marker_cfg)

    # Recorder setup: export every episode to <output_folder>/trajectory.
    env_cfg.recorders = AllStatesRecorderManagerCfg()
    env_cfg.recorders.dataset_export_dir_path = str(output_folder)
    env_cfg.recorders.dataset_filename = "trajectory"
    env_cfg.recorders.dataset_export_mode = DatasetExportMode.EXPORT_ALL

    env: ManagerBasedRLEnv = gym.make(args.environment, cfg=env_cfg)  # type: ignore
    obs, info = env.reset()

    oculus = VRPolicy()
    oculus.reset_state()

    # Teleoperation loop: runs until the controller reports success.
    while True:
        droid_obs = observation_to_droid_observation(obs)
        action = oculus.forward(droid_obs)
        # Action layout as used below: [0:3] position delta, [3:-1] Euler
        # delta; the last element is excluded (presumably the gripper
        # command -- confirm against VRPolicy). The reshape to (1, 7) at the
        # step call fixes the total length at 7.
        delta_pos = action[:3]
        delta_euler = action[3:-1]
        delta_quat = quat_from_euler_xyz(torch.tensor(delta_euler[0:1]), torch.tensor(delta_euler[1:2]), torch.tensor(delta_euler[2:3]))

        # NOTE(review): target_pos / target_quat (and the delta/robot
        # quaternions feeding them) are only consumed by the commented-out
        # marker visualisation below; they do not influence the action that
        # is sent to the environment.
        robot_pos = droid_obs["robot_state"]["cartesian_position"][:3]
        target_pos = torch.tensor(robot_pos + delta_pos)
        robot_euler = droid_obs["robot_state"]["cartesian_position"][3:]
        robot_quat = quat_from_euler_xyz(torch.tensor(robot_euler[0:1]), torch.tensor(robot_euler[1:2]), torch.tensor(robot_euler[2:3]))
        target_quat = math.quat_mul(delta_quat, robot_quat)

        # marker.visualize(target_pos.reshape(-1, 3), target_quat.reshape(-1, 4))
        # The raw controller action is forwarded unchanged, batched for the
        # single environment.
        obs, reward, terminated, truncated, info = env.step(torch.tensor(action).reshape(1, 7))
        print(f"Reward: {reward}")

        control_info = oculus.get_info()
        if control_info["success"] or control_info["failure"]:
            print("Resetting environment")
            # The reset happens for both outcomes, including right before
            # leaving the loop on success (presumably so the recorder
            # finalises the episode -- NOTE(review): confirm).
            obs, info = env.reset()
            if control_info["success"]:
                break

            # Failure: clear the controller state and keep recording.
            oculus.reset_state()


    print("Done!")
    close_isaac(simulation_app, env)

if __name__ == "__main__":
    # Parse the command line into ``Args`` and hand it straight to ``main``.
    main(tyro.cli(Args))