import sys
import tyro
import os
import mediapy
import tqdm
import gymnasium as gym
import torch
import argparse
import numpy as np


from typing import List
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass, field
from isaaclab.app import AppLauncher

# Command-line options for this script; the `__main__` guard parses them with
# tyro.cli(Args). `usd` is forwarded to parse_env_cfg as `usd_file`, and
# `environment` is used both as the gym id and as the output folder name.
# NOTE: tyro can take a field's --help text from its trailing comment (and the
# CLI description from a class docstring), so treat those as user-facing text.
@dataclass
class Args:
    usd: str                                        # Path to the USD file
    demonstrations: str                             # Path to the demonstrations directory
    environment: str = "DROID-RoboSplat"            # Which IsaacLab environment to use
    force: bool = False                             # Whether to force rerendering of demonstrations


def main(args: Args):
    """Render every ``*.npz`` demonstration found in ``args.demonstrations``.

    Launches a headless Isaac Sim app (cameras enabled), builds the
    ``args.environment`` gym environment on ``args.usd``, replays each recorded
    observation in simulation, and writes every episode, with a ``"renders"``
    entry added to each observation, to
    ``<demonstrations>/../../<environment>/<episode file name>``.
    The environment and the simulation app are always closed, even on error.

    Args:
        args: Parsed command-line options.
    """
    # This must be done before importing anything from IsaacLab
    # Inside main function for compatibility with HPC cluster python launch scripts
    # >>>> Isaac Sim App Launcher <<<<
    parser = argparse.ArgumentParser()
    AppLauncher.add_app_launcher_args(parser)

    args_cli, other_args = parser.parse_known_args()
    sys.argv = [sys.argv[0]] + other_args  # clear out sys.argv for hydra
    args_cli.enable_cameras = True
    args_cli.headless = True

    app_launcher = AppLauncher(args_cli)
    simulation_app = app_launcher.app
    # >>>> Isaac Sim App Launcher <<<<

    try:
        # polaris must only be imported after the app is launched (see above).
        # The bare package import is kept for its import-time side effects
        # (presumably gym environment registration -- confirm).
        import polaris.environments  # noqa: F401
        from polaris.environments.manager_based_rl_splat_environment import MangerBasedRLSplatEnv
        from polaris.utils import parse_env_cfg

        env_cfg = parse_env_cfg(
            args.environment,
            usd_file=args.usd,
            device=args_cli.device,
            num_envs=1,
            use_fabric=True,
        )
        env: MangerBasedRLSplatEnv = gym.make(args.environment, cfg=env_cfg)  # type: ignore
        try:
            # NOTE(review): reset several times before rendering, presumably to
            # let the scene settle after construction -- confirm all are needed.
            env.reset()
            env.reset()
            env.reset()
            _render_demonstrations(env, args)
        finally:
            env.close()
    finally:
        simulation_app.close()


def _render_demonstrations(env, args: Args) -> None:
    """Render each ``*.npz`` episode in ``args.demonstrations`` with ``env``.

    Episodes whose output file already exists are skipped unless ``args.force``
    is set. Episodes with fewer than ``_MIN_OBSERVATIONS`` observations are
    deleted from disk. A failure in one episode is reported (with traceback)
    and the loop continues with the next episode.
    """
    import traceback

    directory = Path(args.demonstrations)
    # sorted() makes the processing order deterministic; glob order is arbitrary.
    all_demonstrations = sorted(directory.glob("*.npz"))
    output_folder = directory.parent.parent / args.environment
    output_folder.mkdir(parents=True, exist_ok=True)
    for episode in tqdm.tqdm(all_demonstrations):
        output_demo_path = output_folder / episode.name
        if output_demo_path.exists() and not args.force:
            # already done -- checked before loading so the file is not read
            print(f"skipping {episode}")
            continue
        try:
            # allow_pickle is needed because the observations hold pickled
            # Python objects (indexed like dicts below). Unpickling can run
            # arbitrary code: only use this on trusted demonstration files.
            # The `with` closes the archive before any os.remove below.
            with np.load(episode, allow_pickle=True) as npz:
                data = dict(npz)
            observations = data["observations"]
            if len(observations) < _MIN_OBSERVATIONS:
                # Deliberately deletes the too-short source episode.
                print(f"removing {episode} because it has less than {_MIN_OBSERVATIONS} observations")
                os.remove(episode)
                continue
            # Exactly one render per observation, attached to that same
            # observation, so frames and images stay aligned.
            for obs in observations:
                obs["renders"] = _render_observation(env, obs)
            np.savez(file=output_demo_path, **data)
            print("rendered", output_demo_path)
        except Exception:
            # Best-effort batch job: report the full traceback and move on.
            traceback.print_exc()
            print(f"{episode} failed")


# Episodes shorter than this are considered unusable and are deleted.
_MIN_OBSERVATIONS = 10

# Keys under obs["policy"] that hold robot state rather than the root pose of a
# scene object with that name.
_ROBOT_STATE_KEYS = frozenset({"arm_joint_pos", "gripper_pos", "ee_pose", "all_joints"})


def _render_observation(env, obs) -> dict:
    """Replay one recorded observation in sim and return the composited images.

    Writes ``obs["policy"]["all_joints"]`` to the robot and every other
    non-robot-state policy entry as the root pose of the scene object of that
    name, then renders. For every camera returned by
    ``env.get_robot_from_sim()``, pixels where its mask is set come from the
    sim image and the rest from the splat render (or black if the splat has
    no such camera). If ``env`` has a ``model`` attribute, the result is passed
    through ``env.diffuse_image``.

    Side effect: sets ``obs["splat"]`` (see NOTE below).

    Returns:
        The splat render dict with each simulated camera's entry replaced by
        its composited image.
    """
    policy = obs["policy"]
    env.scene["robot"].write_joint_position_to_sim(torch.tensor(policy["all_joints"]))

    # rigid object poses
    for name, pose in policy.items():
        if name in _ROBOT_STATE_KEYS:
            continue
        env.scene[name].write_root_pose_to_sim(torch.tensor(pose)[None])

    env.sim.forward()  # updates all links and kinematics
    env.sim.render()
    env.scene.update(0)
    env.render(recompute=True)

    mask_and_rgb = env.get_robot_from_sim()
    env.transform_sim_to_splat()
    rgb = env.render_splat()
    # NOTE(review): this stores the same dict object that is composited into
    # below, so obs["splat"] ends up holding the composited images, not the
    # raw splat render. Kept as-is to preserve the saved format -- confirm.
    obs["splat"] = rgb

    for cam in mask_and_rgb:
        sim_img = mask_and_rgb[cam]["rgb"]
        og_img = rgb[cam] if cam in rgb else np.zeros_like(sim_img)
        new_img = np.where(mask_and_rgb[cam]["mask"], sim_img, og_img)
        if hasattr(env, "model"):
            new_img = env.diffuse_image(new_img)
        rgb[cam] = new_img
    return rgb
    

if __name__ == "__main__":
    # Build Args from the command line with tyro and run the renderer.
    main(tyro.cli(Args))
