import tyro
import argparse
import torch
import logging
import gymnasium as gym
from datetime import datetime
from pathlib import Path
from isaaclab.app import AppLauncher
from sim_improvement.rl.config import RLConfig


# Module-level logger for this script.
logger = logging.getLogger(__name__)

# Global PyTorch backend switches, applied once at import time.
# Allow TF32 math for CUDA matmuls and cuDNN ops (trades some float32
# precision for speed on GPUs that support TF32).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Neither force deterministic cuDNN algorithms nor autotune them
# (benchmark=False); results are therefore not guaranteed bit-reproducible.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False


def main(cfg: RLConfig):
    """Train with RSL-RL agent.

    Launches the Isaac Sim app, builds the environment and PPO runner
    configurations from ``cfg``, and creates the gym environment (optionally
    wrapped for video recording). It then runs ``OnPolicyRunner.learn``.
    The environment and the simulator app are always closed afterwards, even
    if training fails.

    Args:
        cfg: Run configuration. Reads environment id, device, number of envs,
            headless/distributed flags, video options, an optional checkpoint
            path (``resume``) and ``run_name``, which becomes the experiment
            name.

    Raises:
        ValueError: If distributed training is requested on a CPU device.
    """
    # Validate before launching Isaac Sim so an invalid request fails fast
    # instead of leaving a running simulator app behind.
    if cfg.distributed and cfg.device is not None and "cpu" in cfg.device:
        raise ValueError(
            "Distributed training is not supported when using CPU device. "
            "Please use GPU device (e.g., --device cuda) for distributed training."
        )

    # >>>> Isaac Sim App Launcher <<<<
    # This parser defines no options, so it only yields an empty namespace
    # that is filled in by hand with the launcher settings below.
    parser = argparse.ArgumentParser()
    args_cli, _ = parser.parse_known_args()
    # NOTE(review): cameras are enabled even when cfg.video is False;
    # presumably the environment requires them. Confirm, as it costs performance.
    args_cli.enable_cameras = True
    args_cli.headless = cfg.headless
    args_cli.distributed = cfg.distributed
    app_launcher = AppLauncher(args_cli)
    simulation_app = app_launcher.app
    # >>>> Isaac Sim App Launcher <<<<

    env = None
    try:
        # These imports must run after the app has been launched.
        # The bare package imports are presumably kept for their side effect
        # of registering gym environments (needed by gym.make below).
        from isaaclab.utils.io import dump_yaml
        from isaaclab_tasks.utils import parse_env_cfg
        import isaaclab_tasks  # noqa: F401
        from rsl_rl.runners import OnPolicyRunner

        import sim_improvement.environments  # noqa: F401
        import sim_improvement.rl.utils as rl_utils  # noqa: F401
        from sim_improvement.rl.test import RslRlVecEnvWrapper, FrankaReachPPORunnerCfg

        agent_cfg = FrankaReachPPORunnerCfg(experiment_name=cfg.run_name)
        env_cfg = parse_env_cfg(
            cfg.environment,
            device=cfg.device,
            num_envs=cfg.num_envs,
            use_fabric=True,
        )
        # Run the agent on the same device the environment was built on;
        # otherwise the runner config's default device would be used.
        if cfg.device is not None:
            agent_cfg.device = cfg.device
        # Seed the environment with the agent seed so both share one seed.
        env_cfg.seed = agent_cfg.seed

        # multi-gpu training configuration
        if cfg.distributed:
            print(f"[INFO] Running distributed training on device: {cfg.device}")
            env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
            agent_cfg.device = f"cuda:{app_launcher.local_rank}"
            print(f"[INFO] Environment device: {env_cfg.sim.device}")
            print(f"[INFO] Agent device: {agent_cfg.device}")

            # offset the seed per rank to have diversity across processes
            seed = agent_cfg.seed + app_launcher.local_rank
            env_cfg.seed = seed
            agent_cfg.seed = seed

        # specify directory for logging experiments
        log_root_path = Path("rl-runs") / agent_cfg.experiment_name
        # one directory per run, named by start time: {time-stamp}
        log_dir = log_root_path / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        print(f"[INFO] Logging experiment in directory: {log_dir}")
        # Expose the log directory to the environment config.
        # NOTE(review): which env components read ``log_dir`` is not visible here; confirm.
        env_cfg.log_dir = str(log_dir)
        env_cfg.io_descriptors_output_dir = str(log_dir / "io_descriptors")
        if not cfg.headless:
            env_cfg.sim.render_interval = 1

        # create isaac environment
        env = gym.make(
            cfg.environment, cfg=env_cfg, render_mode="rgb_array" if cfg.video else None
        )

        # wrap for video recording
        if cfg.video:
            video_kwargs = {
                "video_folder": str(log_dir / "videos" / "train"),
                "step_trigger": lambda step: step % cfg.video_interval == 0,
                "video_length": cfg.video_length,
                "disable_logger": True,
            }
            print("[INFO] Recording videos during training.")
            env = gym.wrappers.RecordVideo(env, **video_kwargs)

        # wrap around environment for rsl-rl
        env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)

        # create runner from rsl-rl
        runner = OnPolicyRunner(
            env, agent_cfg.to_dict(), log_dir=str(log_dir), device=agent_cfg.device
        )
        # write git state to logs
        runner.add_git_repo_to_log(__file__)

        # load the checkpoint
        if cfg.resume:
            # ``cfg.resume`` is treated as a direct checkpoint path.
            # NOTE: an earlier version resolved it via
            # rl_utils.get_checkpoint_path(str(log_root_path), cfg.resume).
            resume_path = cfg.resume
            print(f"[INFO]: Loading model checkpoint from: {resume_path}")
            runner.load(resume_path)

        # dump the configuration into log-directory
        dump_yaml(str(log_dir / "params" / "env.yaml"), env_cfg)
        dump_yaml(str(log_dir / "params" / "agent.yaml"), agent_cfg)

        # run training
        runner.learn(
            num_learning_iterations=agent_cfg.max_iterations, init_at_random_ep_len=True
        )
    finally:
        # Always release the environment and shut down the simulator,
        # even when setup or training raised.
        try:
            if env is not None:
                env.close()
        finally:
            simulation_app.close()


if __name__ == "__main__":
    # Build the RLConfig from the command line via tyro, then train with it.
    main(tyro.cli(RLConfig))
