import tyro
import mediapy

import tqdm 
import gymnasium as gym
import torch
import argparse

from dataclasses import dataclass
from pathlib import Path
from isaaclab.app import AppLauncher

from polaris.config import EvalArgs


@dataclass
class RolloutArgs(EvalArgs):
    """CLI arguments for running policy rollouts with trajectory recording.

    Extends ``EvalArgs`` (from ``polaris.config``) with scene, parallelism,
    stopping and output-folder options. Parsed from the command line with
    ``tyro.cli``.
    """

    scene_path: str | None = None  # Path to scene config JSON file, forwarded to env_cfg.dynamic_setup()
    library_dir: str | None = None  # Path to shared USD model library (enables library mode), forwarded to env_cfg.dynamic_setup()
    num_envs: int = 1  # Number of parallel environments
    rollouts: int = 10  # Number of *successful* episodes to collect before stopping
    overwrite: bool = False  # Allow reusing an existing run_folder (otherwise creating it raises FileExistsError)



def main(eval_args: RolloutArgs):
    """Launch Isaac Sim, roll out a policy client, and record the episodes.

    Steps, in order:
      1. Create ``eval_args.run_folder`` (fails if it already exists and
         ``eval_args.overwrite`` is False).
      2. Launch the Isaac Sim app (cameras enabled, headless per
         ``eval_args.headless``) and import the IsaacLab / project modules.
      3. Build the env config for ``eval_args.environment`` with
         ``eval_args.num_envs`` envs, apply ``dynamic_setup`` with the scene
         and library paths, and attach a dataset recorder that exports
         succeeded and failed episodes to separate ``rollouts`` files in the
         run folder.
      4. Step the env with actions from the policy client selected by
         ``eval_args.policy.client`` until ``eval_args.rollouts`` successful
         episodes have been counted.

    The env and the simulation app are always closed on exit, including when
    an exception (e.g. Ctrl+C) interrupts the rollout.

    Note: only successful episodes count toward ``eval_args.rollouts``; a
    policy that never succeeds keeps stepping until interrupted.

    Raises:
        FileExistsError: ``run_folder`` exists and ``overwrite`` is False.
        ValueError: ``eval_args.policy.client`` is not a registered client.
    """
    # Create the output folder before the (slow) Isaac Sim launch so that an
    # existing folder without --overwrite fails immediately.
    run_folder = Path(eval_args.run_folder)
    run_folder.mkdir(parents=True, exist_ok=eval_args.overwrite)

    # This must be done before importing anything from IsaacLab
    # Inside main function to avoid launching IsaacLab in global scope
    # >>>> Isaac Sim App Launcher <<<<
    parser = argparse.ArgumentParser()
    args_cli, _ = parser.parse_known_args()
    args_cli.enable_cameras = True
    args_cli.headless = eval_args.headless
    app_launcher = AppLauncher(args_cli)
    simulation_app = app_launcher.app
    # >>>> Isaac Sim App Launcher <<<<

    try:
        from isaaclab_tasks.utils import parse_env_cfg  # noqa: E402
        from isaaclab.envs import ManagerBasedRLEnv  # noqa: E402, F401
        from isaaclab.managers.recorder_manager import DatasetExportMode  # noqa: E402
        # Imported only for side effects (presumably gym env registration
        # used by gym.make below).
        import polaris.environments  # noqa: E402, F401
        import sim_improvement.environments  # noqa: E402, F401
        from sim_improvement.environments.mdp.recorders.recorders_config import DatasetRecorderManagerCfg  # noqa: E402
        from sim_improvement.inference.droid_jointpos import DroidJointPosClient  # noqa: E402
        from sim_improvement.inference.lbm_grpc_client import LbmGrpcClient  # noqa: E402
        from sim_improvement.inference.lbm_openpi_client import LbmOpenpiClient  # noqa: E402
        from sim_improvement.inference.zero_action_client import ZeroActionClient  # noqa: E402

        CLIENT_REGISTRY = {
            "DroidJointPos": DroidJointPosClient,
            "LbmGrpc": LbmGrpcClient,
            "LbmOpenpi": LbmOpenpiClient,
            "Zero": ZeroActionClient,
        }

        # Resolve the policy client class before building the env so a typo in
        # the client name fails without paying for environment construction.
        client_name = eval_args.policy.client
        client_cls = CLIENT_REGISTRY.get(client_name)
        if client_cls is None:
            raise ValueError(f"Unknown client {client_name!r}. Available: {list(CLIENT_REGISTRY.keys())}")

        env_cfg = parse_env_cfg(
            eval_args.environment,
            device="cuda",
            num_envs=eval_args.num_envs,
            use_fabric=True,
        )
        env_cfg.dynamic_setup(  # type: ignore
            scene_path=eval_args.scene_path,
            library_dir=eval_args.library_dir,
        )

        # Record trajectories: succeeded and failed episodes go to separate files.
        env_cfg.recorders = DatasetRecorderManagerCfg()
        env_cfg.recorders.dataset_export_dir_path = str(run_folder)
        env_cfg.recorders.dataset_filename = "rollouts"
        env_cfg.recorders.dataset_export_mode = DatasetExportMode.EXPORT_SUCCEEDED_FAILED_IN_SEPARATE_FILES

        env = gym.make(eval_args.environment, cfg=env_cfg)  # type: ignore
        try:
            # All registered clients are constructed with the same kwargs;
            # open_loop_horizon is used by clients that support action chunking.
            policy_client = client_cls(
                host=eval_args.policy.host,
                port=eval_args.policy.port,
                open_loop_horizon=eval_args.policy.open_loop_horizon,
                action_shape=env.action_space.shape,
            )
            instruction = eval_args.instruction or "put the red bell pepper in the bin"
            _collect_rollouts(env, policy_client, instruction, eval_args.rollouts)
        finally:
            env.close()
    finally:
        simulation_app.close()


def _collect_rollouts(env, policy_client, instruction: str, target_successes: int) -> int:
    """Step ``env`` with ``policy_client`` actions until enough episodes succeed.

    After every step, for each env index that terminated or truncated, the
    policy client's state is reset for those indices and the number of them
    whose "success" termination term is set is added to the running tally.

    Args:
        env: Environment returned by ``gym.make`` (wrapping an IsaacLab env).
        policy_client: Client exposing ``reset`` and ``infer``.
        instruction: Language instruction passed to every ``infer`` call.
        target_successes: Stop once at least this many successes are counted;
            a value <= 0 performs no steps.

    Returns:
        The number of successful episodes counted. It can exceed
        ``target_successes`` when several envs succeed on the same step.
    """
    successful_episodes = 0
    obs, _ = env.reset()
    policy_client.reset()
    # The termination manager lives on the unwrapped IsaacLab env; the
    # gymnasium wrappers added by gym.make do not reliably forward attributes.
    termination_manager = env.unwrapped.termination_manager

    while successful_episodes < target_successes:
        action, _ = policy_client.infer(obs, instruction=instruction, return_viz=False)
        action_tensor = torch.tensor(action, dtype=torch.float32)
        obs, _, term, trunc, _ = env.step(action_tensor)

        done = term | trunc
        if not done.any():
            continue

        # Indices of environments whose episode just ended.
        needs_reset = done.nonzero().flatten().detach().cpu().numpy()
        policy_client.reset(env_ids=needs_reset, obs=obs)
        success_values = termination_manager.get_term("success")[needs_reset]
        successful_episodes += success_values.sum().item()
        print(f"Successful episodes: {successful_episodes}")

    return successful_episodes


if __name__ == "__main__":
    # Parse command-line flags into RolloutArgs and hand them straight to main().
    main(tyro.cli(RolloutArgs))
