"""
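Example script for running 10 rollouts of a DROID policy on the example environment.

NOTE: in its current form the script does not query the policy. It replays the joint
positions recorded in droid/proprio_states.npy and saves the external-camera frames to
runs/<date>/<time>/replay.mp4. The --episodes flag is currently ignored.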

Usage:

First, make sure you download the simulation assets and unpack them into the root directory of this package.

Then, in a separate terminal, launch the policy server on localhost:8000.
Make sure to set XLA_PYTHON_CLIENT_MEM_FRACTION so that JAX does not preallocate all of the GPU memory.

For example, to launch a pi0-FAST-DROID policy (with joint position control), 
run the command below in a separate terminal from the openpi "karl/droid_policies" branch:

XLA_PYTHON_CLIENT_MEM_FRACTION=0.5 uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid_jointpos --policy.dir=s3://openpi-assets-simeval/pi0_fast_droid_jointpos

Finally, run the evaluation script:

python run_eval.py --episodes 10 --headless
"""

import tyro
import argparse
import gymnasium as gym
import torch
import numpy as np
import cv2
import mediapy
from datetime import datetime
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from PIL import ImageDraw
from scipy.spatial.transform import Rotation as R

# from src.inference.droid_jointpos import Client as DroidJointPosClient
from src.inference.random_client import Client as RandomClient
    
def main(
        episodes = 10,
        headless: bool = True,
        scene: int = 1,
        ):
    # launch omniverse app with arguments (inside function to prevent overriding tyro)
    from isaaclab.app import AppLauncher
    parser = argparse.ArgumentParser(description="Tutorial on creating an empty stage.")
    AppLauncher.add_app_launcher_args(parser)
    args_cli, _ = parser.parse_known_args()
    args_cli.enable_cameras = True
    args_cli.headless = headless
    app_launcher = AppLauncher(args_cli)
    simulation_app = app_launcher.app

    # All IsaacLab dependent modules should be imported after the app is launched
    import src.environments # noqa: F401
    from isaaclab_tasks.utils import parse_env_cfg
    from isaaclab.utils.math import matrix_from_quat
    from custom.utils import overwrite_camera_pose, overwrite_joint_positions, get_camera_pose
    import isaaclab.sim as sim_utils
    # Initialize the env
    env_cfg = parse_env_cfg(
        "DROID",
        device=args_cli.device,
        num_envs=1,
        use_fabric=True,
    )
    instruction = None
    match scene:
        case 1:
            instruction = "put the cube in the bowl"
        case 2:
            instruction = "put the can in the mug"
        case 3:
            instruction = "put banana in the bin"
        case _:
            raise ValueError(f"Scene {scene} not supported")

    camera_pose = np.array([[ 0.12289379, -0.25066368,  0.96024207, -0.12524252],
                            [-0.9860285,  -0.14047504,  0.08952408, -0.41773149],
                            [ 0.1124496,  -0.957828,   -0.26442504,  0.41110082],
                            [ 0.,          0.,          0.,          1.        ]])
    joint_positions = np.load("droid/proprio_states.npy")
    # np.array([[0.07692957, -0.62007982, -0.14516954, -2.71003532,  0.07004238,  2.26837683, -0.2618871,   0.0]])
    intrinsics = np.array([[525.31878662,   0.,         648.12060547],
                            [  0.,         525.31878662, 374.60479736],
                            [  0.,           0.,           1.        ]])

    overwrite_camera_pose(env_cfg, camera_pose)

    overwrite_joint_positions(env_cfg, joint_positions[0].tolist())
    env_cfg.scene.external_cam.spawn = sim_utils.PinholeCameraCfg.from_intrinsic_matrix(intrinsics.flatten(), 1280, 720, focal_length=2.1, focus_distance=28.0)

    # env_cfg.set_scene(scene)
    env = gym.make("DROID", cfg=env_cfg)
    obs, _ = env.reset()
    obs, _ = env.reset() # need second render cycle to get correctly loaded materials
    client = RandomClient() # DroidJointPosClient()

    cam = env.unwrapped.scene["external_cam"]
    
    T_C2W = get_camera_pose(np.array(cam.cfg.offset.pos).reshape(-1), np.array(cam.cfg.offset.rot).reshape(-1))
    T_W2C = torch.linalg.inv(torch.from_numpy(T_C2W)).to(torch.float32)

    # cam.set_intrinsic_matrices(torch.tensor(intrinsics).unsqueeze(0))
    cam_intrinsics = cam.data.intrinsic_matrices[0]

    video_dir = Path("runs") / datetime.now().strftime("%Y-%m-%d") / datetime.now().strftime("%H-%M-%S")
    video_dir.mkdir(parents=True, exist_ok=True)
    video = []
    ep = 0
    try:
        with torch.no_grad():
            for ind in tqdm(range(len(joint_positions))):
                # instead of doing env.step, we set the joint positions directly
                robot = env.unwrapped.scene["robot"]
                env_joint_position = robot.data.joint_pos[0]
                qpos = torch.concatenate([
                    torch.from_numpy(joint_positions[ind][:-1]).to(torch.float32).to(robot.device),
                    env_joint_position[7:]
                ])
                robot.write_joint_position_to_sim(qpos)

                obs, _ = env.unwrapped.reset_to(env.unwrapped.scene.get_state(), None)

                # # write, step, and update buffers/sensors
                # env.unwrapped.scene.write_data_to_sim()
                # env.unwrapped.sim.step()  # render=True by default; ensures RTX cameras render
                # env.unwrapped.scene.update(env.unwrapped.physics_dt)

                # now compute fresh observations
                # obs = env.unwrapped.observation_manager.compute()
                # obs = env.unwrapped.obs_buf # {"policy": {"arm_joint_pos", "gripper_pos", "external_cam", "wrist_cam"}}

                video.append(obs['policy']['external_cam'].cpu().numpy()[0])

                # cube_pos = (T_W2C @ torch.tensor([*obs['policy']['end_effector_pose'][:3], 1.0]).to(T_W2C.device))[:3]
                # cube_pos_img = cam_intrinsics @ cube_pos.to(cam_intrinsics.device)
                # cube_pos_img = cube_pos_img / cube_pos_img[2]

                # robot_base_pos = (T_W2C @ torch.tensor([0.0, 0.0, 0.0, 1.0]).to(T_W2C.device))[:3]
                # robot_base_pos_img = cam_intrinsics @ robot_base_pos.to(cam_intrinsics.device)
                # robot_base_pos_img = robot_base_pos_img / robot_base_pos_img[2]

                # img = obs['policy']['external_cam'].cpu().numpy()[0]
                # img = Image.fromarray(img)
                
                # # Draw circle at projected cube position
                # draw = ImageDraw.Draw(img)
                # radius = 5
                # x, y = cube_pos_img[0].item(), cube_pos_img[1].item()
                # draw.ellipse([x-radius, y-radius, x+radius, y+radius], fill='red')
                # x, y = robot_base_pos_img[0].item(), robot_base_pos_img[1].item()
                # draw.ellipse([x-radius, y-radius, x+radius, y+radius], fill='blue')
                # img.save("external_cam.png")
                # print("saved external_cam.png")

            mediapy.write_video(
                video_dir / f"replay.mp4",
                video,
                fps=15,
            )

    except Exception as e:
        print(e)
        pass
    finally:
        env.close()
        simulation_app.close()

if __name__ == "__main__":
    # tyro derives the command-line interface from main()'s signature and
    # invokes it; main() returns nothing, so the result is not kept.
    tyro.cli(main)
