"""Replay states from an HDF5 trajectory and re-record observations.

Loads a recorded HDF5 file (produced by AllStatesRecorder), replays each
demo by writing articulation joint positions and rigid-object poses into
the sim, and re-captures observations into a new HDF5 file.

Multiple demos are replayed in parallel across ``num_envs`` environments.

Usage::

    python scripts/reset_states/replay_and_record.py \
        --scene_path ./configs/3_cabot_breakfast/spawn_config.json \
        --library_dir /path/to/usd_library \
        --input_file demonstrations/trajectory.hdf5 \
        --output_file demonstrations/trajectory_rerendered.hdf5

    python scripts/reset_states/replay_and_record.py \
        --scene_path ./configs/3_cabot_breakfast/spawn_config.json \
        --library_dir /path/to/usd_library \
        --input_file demonstrations/trajectory.hdf5 \
        --output_file demonstrations/trajectory_rerendered.hdf5 \
        --num_envs 8 --headless
"""

import tyro
import argparse
import h5py
import numpy as np
import torch
import tqdm

from dataclasses import dataclass
from pathlib import Path
from isaaclab.app import AppLauncher


@dataclass
class Args:
    """Command-line options for replaying an HDF5 trajectory and re-recording observations.

    The four path fields have no default and must be supplied; ``num_envs``
    and ``headless`` are optional.
    """

    scene_path: str  # Path to config JSON file (spawn_config.json)
    library_dir: str  # Path to shared USD model library
    input_file: str  # Path to the input HDF5 trajectory file
    output_file: str  # Path to write the output HDF5 file
    num_envs: int = 1  # Number of parallel environments
    headless: bool = True  # Whether to run in headless mode


# Dataset that is replayed into the sim for each supported entity type.
_STATE_DATASET = {"articulation": "joint_position", "rigid_object": "root_pose"}
# Per-demo groups/datasets copied verbatim from the input file to the output file.
_COPIED_KEYS = ("action", "states", "next_states", "reward")


def _collect_entities(states_group) -> list[tuple[str, str]]:
    """Return ``(name, type)`` pairs for every recorded articulation and rigid object."""
    entities: list[tuple[str, str]] = []
    for entity_type in ("articulation", "rigid_object"):
        type_group = states_group.get(entity_type)
        if type_group is not None:
            entities.extend((name, entity_type) for name in type_group.keys())
    return entities


def _trajectory_length(states_group, entities) -> int:
    """Return the step count, read from the first dataset of the first entity (0 if it has none)."""
    first_name, first_type = entities[0]
    first_group = states_group[first_type][first_name]
    first_key = next(iter(first_group.keys()), None)
    if first_key is None:
        return 0
    return int(first_group[first_key].shape[0])


def _bind_entities(env, entities, asset_types, warned) -> list[tuple]:
    """Pair each recorded entity with its scene asset, dropping entities that cannot be replayed.

    An entity is dropped when the scene has no asset of that name or the asset
    is not of the recorded type. Each dropped entity is reported once per run
    (tracked through ``warned``) instead of being skipped silently every step.
    """
    bound = []
    for entity_name, entity_type in entities:
        try:
            scene_obj = env.scene[entity_name]
        except KeyError:
            scene_obj = None
        if isinstance(scene_obj, asset_types[entity_type]):
            bound.append((scene_obj, entity_name, entity_type))
        elif (entity_name, entity_type) not in warned:
            warned.add((entity_name, entity_type))
            print(f"  Warning: '{entity_name}' is not a {entity_type} in the scene, its states are not replayed")
    return bound


def _write_state_to_sim(bound_entities, states_group, step_idx, env_ids, device) -> None:
    """Write the recorded state at ``step_idx`` into the sim for the env(s) in ``env_ids``."""
    for scene_obj, entity_name, entity_type in bound_entities:
        values = states_group[entity_type][entity_name][_STATE_DATASET[entity_type]][step_idx]
        # Add a leading axis so the tensor has one row for the single env in ``env_ids``.
        state = torch.tensor(values, dtype=torch.float32, device=device)[None]
        if entity_type == "articulation":
            scene_obj.write_joint_position_to_sim(state, env_ids=env_ids)
        else:
            scene_obj.write_root_pose_to_sim(state, env_ids=env_ids)


def _obs_to_numpy(obs_dict) -> dict:
    """Copy one observation buffer to host memory, keeping its group/term nesting."""
    host = {}
    for group_name, group_data in obs_dict.items():
        if isinstance(group_data, dict):
            host[group_name] = {key: value.detach().cpu().numpy() for key, value in group_data.items()}
        else:
            host[group_name] = group_data.detach().cpu().numpy()
    return host


def _render_and_observe(env) -> dict:
    """Push the written states through the sim, render, and return fresh observations as numpy."""
    env.sim.forward()
    env.scene.update(0)
    env.sim.render()
    # ``env.obs_buf`` is only refreshed by ``env.step()`` / ``env.reset()``; since the
    # replay never steps the env, the observations must be recomputed explicitly.
    return _obs_to_numpy(env.observation_manager.compute())


def _append_obs(store, host_obs, env_idx) -> None:
    """Append env ``env_idx``'s slice of every observation term to ``store``."""
    for group_name, group_data in host_obs.items():
        if isinstance(group_data, dict):
            group_store = store.setdefault(group_name, {})
            for key, value in group_data.items():
                group_store.setdefault(key, []).append(value[env_idx].copy())
        else:
            store.setdefault(group_name, []).append(group_data[env_idx].copy())


def _write_obs_datasets(demo_group, store, traj_len) -> None:
    """Split the collected frames into ``obs`` (first T) and ``next_obs`` (shifted by one)."""
    obs_group = demo_group.create_group("obs")
    next_obs_group = demo_group.create_group("next_obs")

    def emit(obs_parent, next_parent, name, frames):
        if not frames:
            return
        stacked = np.array(frames)  # T+1 frames stacked along axis 0
        obs_parent.create_dataset(name, data=stacked[:traj_len])
        next_parent.create_dataset(name, data=stacked[1 : traj_len + 1])

    for group_name, group_data in store.items():
        if isinstance(group_data, dict):
            obs_sub = obs_group.create_group(group_name)
            next_sub = next_obs_group.create_group(group_name)
            for key, frames in group_data.items():
                emit(obs_sub, next_sub, key, frames)
        else:
            emit(obs_group, next_obs_group, group_name, group_data)


def _replay_batch(env, input_file, output_data_group, batch_keys, env_ids, asset_types, warned) -> None:
    """Replay one batch of demos (demo ``i`` in env ``i``) and write their re-captured observations."""
    batch_size = len(batch_keys)
    states_groups = [None] * batch_size
    bound = [[] for _ in range(batch_size)]
    # A length of 0 marks a demo that is skipped.
    traj_lengths = [0] * batch_size

    for env_idx, demo_key in enumerate(batch_keys):
        states_group = input_file.get(f"data/{demo_key}/states")
        if states_group is None:
            print(f"  Warning: No states found for {demo_key}, skipping")
            continue
        entities = _collect_entities(states_group)
        if not entities:
            print(f"  Warning: No articulation/rigid_object states for {demo_key}, skipping")
            continue
        traj_length = _trajectory_length(states_group, entities)
        if traj_length == 0:
            print(f"  Warning: Empty trajectory for {demo_key}, skipping")
            continue
        states_groups[env_idx] = states_group
        bound[env_idx] = _bind_entities(env, entities, asset_types, warned)
        traj_lengths[env_idx] = traj_length

    max_traj_length = max(traj_lengths, default=0)
    if max_traj_length == 0:
        return
    print(f"  Trajectory lengths: {traj_lengths} (max={max_traj_length})")

    env.reset()

    active = [env_idx for env_idx, length in enumerate(traj_lengths) if length > 0]
    stores: dict[int, dict] = {env_idx: {} for env_idx in active}

    # Collect T+1 observations per demo: frames 0..T-1 come from ``states`` and
    # frame T from ``next_states``, so that obs = frames[0:T], next_obs = frames[1:T+1].
    for step_idx in tqdm.tqdm(range(max_traj_length), desc="  Replaying batch"):
        stepping = [env_idx for env_idx in active if step_idx < traj_lengths[env_idx]]
        for env_idx in stepping:
            _write_state_to_sim(bound[env_idx], states_groups[env_idx], step_idx, env_ids[env_idx], env.device)
        host_obs = _render_and_observe(env)
        for env_idx in stepping:
            _append_obs(stores[env_idx], host_obs, env_idx)

    # Write ``next_states`` of the last timestep to obtain the (T+1)-th observation.
    for env_idx in active:
        demo_key = batch_keys[env_idx]
        next_states_group = input_file.get(f"data/{demo_key}/next_states")
        if next_states_group is None:
            # The env still holds this demo's final state, so the extra frame repeats it.
            print(f"  Warning: No next_states found for {demo_key}, reusing its final state for the last next_obs")
            continue
        _write_state_to_sim(
            bound[env_idx], next_states_group, traj_lengths[env_idx] - 1, env_ids[env_idx], env.device
        )

    host_obs = _render_and_observe(env)
    for env_idx in active:
        _append_obs(stores[env_idx], host_obs, env_idx)

    # Write results for each replayed demo in the batch.
    for env_idx in active:
        demo_key = batch_keys[env_idx]
        demo_group = output_data_group.create_group(demo_key)
        for key in _COPIED_KEYS:
            src = input_file.get(f"data/{demo_key}/{key}")
            if src is not None:
                input_file.copy(src, demo_group, name=key)
        _write_obs_datasets(demo_group, stores[env_idx], traj_lengths[env_idx])


def main(args: Args) -> None:
    """Replay every demo of the input HDF5 file and write re-captured observations.

    Demos are processed in batches of ``args.num_envs``; demo ``i`` of a batch
    is replayed in environment ``i``. For each demo the output file receives a
    copy of the input ``action``, ``states``, ``next_states`` and ``reward``
    entries (when present) plus freshly captured ``obs`` and ``next_obs``.

    The paths and ``num_envs`` are validated before the simulator is launched,
    so an invalid invocation prints an error and returns without starting it.
    The environment and the simulation app are closed even when replay fails.

    Args:
        args: Parsed command-line options.
    """
    import traceback

    input_path = Path(args.input_file)
    output_path = Path(args.output_file)

    # Validate cheap preconditions before paying for the simulator start-up.
    if not input_path.exists():
        print(f"Error: Input trajectory file not found at {input_path}")
        return
    if output_path.resolve() == input_path.resolve():
        print(f"Error: Output file must differ from the input file ({input_path})")
        return
    if args.num_envs < 1:
        print(f"Error: num_envs must be at least 1, got {args.num_envs}")
        return

    # >>>> Isaac Sim App Launcher <<<<
    launcher_args = argparse.Namespace(enable_cameras=True, headless=args.headless)
    app_launcher = AppLauncher(launcher_args)
    simulation_app = app_launcher.app
    # >>>> Isaac Sim App Launcher <<<<

    env = None
    try:
        # These modules can only be imported once the simulation app is running.
        from isaaclab.envs import ManagerBasedRLEnv  # noqa: E402
        from isaaclab.assets.articulation import Articulation  # noqa: E402
        from isaaclab.assets.rigid_object import RigidObject  # noqa: E402
        import sim_improvement.environments  # noqa: E402
        from sim_improvement.environments.lbm.scenario_rollout_cfg import ScenarioIKStateCfg  # noqa: E402

        # Setup environment
        env_cfg = ScenarioIKStateCfg()
        env_cfg.scene.num_envs = args.num_envs
        env_cfg.dynamic_setup(scene_path=args.scene_path, library_dir=args.library_dir)
        env_cfg.episode_length_s = 50000.0
        env_cfg.sim.render_interval = 1
        # policy stays concatenated (flat), policy_dict is already set up as non-concatenated

        env = ManagerBasedRLEnv(cfg=env_cfg)
        num_envs = args.num_envs

        asset_types = {"articulation": Articulation, "rigid_object": RigidObject}
        # One single-element index tensor per env slot, built once and reused for every write.
        env_ids = [torch.tensor([env_idx], device=env.device) for env_idx in range(num_envs)]
        warned: set[tuple[str, str]] = set()

        output_path.parent.mkdir(parents=True, exist_ok=True)

        print(f"Loading trajectories from {input_path}")
        print(f"Writing new trajectories to {output_path}")

        with h5py.File(input_path, "r") as input_file, h5py.File(output_path, "w") as output_file:
            demo_keys = sorted(input_file["data"].keys())  # type: ignore
            print(f"Found {len(demo_keys)} demos, processing in batches of {num_envs}")

            output_data_group = output_file.create_group("data")

            # Process demos in batches of num_envs
            for batch_start in range(0, len(demo_keys), num_envs):
                batch_keys = demo_keys[batch_start : batch_start + num_envs]
                batch_size = len(batch_keys)
                print(f"\nBatch {batch_start // num_envs + 1}: demos {batch_start}–{batch_start + batch_size - 1}")
                _replay_batch(env, input_file, output_data_group, batch_keys, env_ids, asset_types, warned)

        print(f"\nDone! New trajectory file saved to {output_path}")
    except Exception:
        # Report the failure before shutdown, in case closing the app ends the
        # process before the interpreter gets to print the traceback.
        traceback.print_exc()
        raise
    finally:
        try:
            if env is not None:
                env.close()
        finally:
            simulation_app.close()


if __name__ == "__main__":
    # Build ``Args`` from the command line and hand it straight to the replay entry point.
    main(tyro.cli(Args))
