"""Open-loop replay of demonstration actions to verify environment success."""

import argparse
import torch
import h5py
import numpy as np
import gymnasium as gym
from pathlib import Path
from isaaclab.app import AppLauncher
import sim_improvement.rl.bcppo_configs as bcppo_configs


def load_episodes(
    trajectory_path: str,
    action_key_order: list[str] | None = None,
) -> tuple[list[np.ndarray], list[dict]]:
    """Read each demonstration's actions and first recorded state from an HDF5 file.

    Demonstrations are the children of the file's ``data`` group, visited in
    sorted (lexicographic) key order. For every demo ``data/<key>``:

    * ``action`` is either a dataset, returned as-is, or a group whose member
      datasets (named by ``action_key_order``) are concatenated along the last
      axis in that order.
    * ``states`` is optional. When present it is read as a three-level hierarchy
      ``{entity_type: {entity_name: {component: dataset}}}`` and only index 0
      of every component dataset is kept. A demo without ``states`` yields ``{}``.

    Args:
        trajectory_path: Path of the HDF5 file to open (read-only).
        action_key_order: Member names used to assemble grouped actions.

    Returns:
        ``(episode_actions, episode_initial_states)``: two lists with one entry
        per demonstration, in the same order. Each initial-state dict follows
        the per-environment layout
        ``{entity_type: {entity_name: {component: array_for_one_env}}}``.

    Raises:
        ValueError: If a demo's ``action`` is a group and ``action_key_order``
            is ``None``.
    """

    def _read_actions(h5: h5py.File, demo_key: str) -> np.ndarray:
        # A plain dataset is used directly; a group is flattened by concatenating
        # its members along the last axis in the caller-specified order.
        node = h5[f"data/{demo_key}/action"]
        if isinstance(node, h5py.Dataset):
            return node[()]
        if action_key_order is None:
            raise ValueError(
                f"action in {demo_key} is a Group with keys "
                f"{sorted(node.keys())} but no action_key_order provided."
            )
        return np.concatenate([node[key][()] for key in action_key_order], axis=-1)

    def _read_initial_state(h5: h5py.File, demo_key: str) -> dict:
        # Demos without a states group contribute an empty dict.
        states_path = f"data/{demo_key}/states"
        if states_path not in h5:
            return {}
        return {
            entity_type: {
                entity_name: {
                    component: dataset[0]  # keep timestep 0 only
                    for component, dataset in components.items()
                }
                for entity_name, components in entities.items()
            }
            for entity_type, entities in h5[states_path].items()
        }

    with h5py.File(trajectory_path, "r") as h5:
        demo_keys = sorted(h5["data"].keys())
        print(f"[Replay] Found {len(demo_keys)} demonstrations in {trajectory_path}")
        per_demo = [
            (_read_actions(h5, key), _read_initial_state(h5, key)) for key in demo_keys
        ]

    episode_actions = [actions for actions, _ in per_demo]
    episode_initial_states = [state for _, state in per_demo]
    return episode_actions, episode_initial_states


def _pad_episodes(
    episode_actions: list[np.ndarray], act_dim: int
) -> tuple[np.ndarray, list[int]]:
    """Zero-pad variable-length action sequences into one ``(N, max_T, act_dim)`` float32 array.

    Args:
        episode_actions: One ``(T_i, act_dim)`` action array per episode (non-empty list).
        act_dim: Width of the action vector.

    Returns:
        ``(padded, ep_lengths)``: the padded array and each episode's original length.
    """
    ep_lengths = [len(actions) for actions in episode_actions]
    padded = np.zeros((len(episode_actions), max(ep_lengths), act_dim), dtype=np.float32)
    for i, (actions, length) in enumerate(zip(episode_actions, ep_lengths)):
        padded[i, :length] = actions
    return padded, ep_lengths


def _batch_initial_states(
    episode_states: list[dict], device
) -> dict[str, dict[str, dict[str, torch.Tensor]]]:
    """Stack per-episode timestep-0 states into the batched dict passed to ``scene.reset_to``.

    The first episode's ``{entity_type: {entity_name: {component: array}}}``
    structure is the template. Every component is stacked across episodes
    (leading episode axis), converted to float32 and moved to ``device``.

    Raises:
        ValueError: If an episode lacks an entry present in the first episode.
    """
    batched: dict[str, dict[str, dict[str, torch.Tensor]]] = {}
    for entity_type, entities in episode_states[0].items():
        batched[entity_type] = {}
        for entity_name, components in entities.items():
            batched[entity_type][entity_name] = {}
            for component in components:
                try:
                    arrays = [s[entity_type][entity_name][component] for s in episode_states]
                except KeyError as err:
                    raise ValueError(
                        "[Replay] Initial states differ across episodes: missing "
                        f"{entity_type}/{entity_name}/{component}"
                    ) from err
                batched[entity_type][entity_name][component] = (
                    torch.from_numpy(np.stack(arrays)).float().to(device)
                )
    return batched


def _read_success_flags(env) -> list[bool] | None:
    """Return the per-env value of the ``success`` termination term after the last step.

    Returns None when the env has no termination manager or no ``success`` term.
    The term is re-queried on each call rather than cached.
    """
    try:
        flags = env.unwrapped.termination_manager.get_term("success")
    except (AttributeError, KeyError):
        return None
    return [bool(v) for v in flags.tolist()]


def _replay(cfg: bcppo_configs.BCPPORunConfig) -> None:
    """Replay every demo open-loop in its own env, report per-episode success and save a video.

    Must run after the Isaac Sim app has been launched. The Isaac Lab task
    packages are imported here, after ``AppLauncher``, because Isaac Lab
    modules are only importable once the app is running.

    Raises:
        ValueError: If the dataset has no demos, or only some demos carry initial states.
    """
    import mediapy
    from isaaclab_tasks.utils import parse_env_cfg  # type: ignore
    import isaaclab_tasks  # noqa: F401
    import sim_improvement.environments  # noqa: F401

    # Load demo actions and initial states
    action_key_order = cfg.action_keys or None
    episode_actions, episode_states = load_episodes(cfg.dataset_path, action_key_order)
    num_episodes = len(episode_actions)
    if num_episodes == 0:
        raise ValueError(f"[Replay] No demonstrations found in {cfg.dataset_path}")
    act_dim = episode_actions[0].shape[1]
    # Pad all episodes to the same length with zero actions so they can be stepped as one batch.
    padded, ep_lengths = _pad_episodes(episode_actions, act_dim)
    max_T = padded.shape[1]
    has_states = any(len(s) > 0 for s in episode_states)
    print(f"[Replay] {num_episodes} episodes, action dim: {act_dim}, max length: {max_T}, "
          f"has initial states: {has_states}")
    if has_states and not all(episode_states):
        # A single batched reset_to cannot mix demos with and without states, and
        # replaying some demos from the default reset would make results meaningless.
        missing = [i for i, s in enumerate(episode_states) if not s]
        raise ValueError(f"[Replay] Episodes {missing} have no recorded initial states")

    # One env per episode
    env_cfg = parse_env_cfg(
        cfg.environment,
        device=cfg.device,
        num_envs=num_episodes,
        use_fabric=True,
    )
    env_cfg.dynamic_setup(  # type: ignore
        scene_path=cfg.scene_path,
        library_dir=cfg.library_dir,
    )
    if not cfg.headless:
        env_cfg.sim.render_interval = 1

    env = gym.make(cfg.environment, cfg=env_cfg, render_mode="rgb_array")
    try:
        env.reset()
        device = env.unwrapped.device

        # Apply initial states from the recorded demos using scene.reset_to
        if has_states:
            batched_state = _batch_initial_states(episode_states, device)
            env.unwrapped.scene.reset_to(batched_state, env_ids=None, is_relative=True)
            print("[Replay] Applied initial states via scene.reset_to()")

        frames: list[np.ndarray] = []
        successes = [False] * num_episodes
        done_flags = [False] * num_episodes
        render_enabled = True
        warned_no_success_term = False

        for t in range(max_T):
            action_tensor = torch.from_numpy(padded[:, t]).to(device)
            _, _, terminated, truncated, _ = env.step(action_tensor)
            dones = (terminated | truncated).tolist()

            # Capture viewer frame. Best effort: a render failure disables capture
            # (with a warning) instead of aborting the replay.
            if render_enabled:
                try:
                    frame = env.render()
                except Exception as err:
                    print(f"[Replay] WARNING: env.render() failed ({err!r}); "
                          "disabling video capture.")
                    render_enabled = False
                else:
                    if frame is not None:
                        frames.append(frame)

            # An episode is resolved the first time its env terminates/truncates, or
            # once its recorded actions are exhausted. Later steps only feed zero
            # padding (or run a presumably auto-reset env), so they must not affect
            # the reported result.
            newly_resolved = [
                i for i in range(num_episodes)
                if not done_flags[i] and (dones[i] or t + 1 >= ep_lengths[i])
            ]
            if newly_resolved:
                success_flags = _read_success_flags(env)
                if success_flags is None and not warned_no_success_term:
                    print("[Replay] WARNING: env has no 'success' termination term; "
                          "episodes will be reported as FAIL.")
                    warned_no_success_term = True
                for i in newly_resolved:
                    done_flags[i] = True
                    successes[i] = success_flags[i] if success_flags is not None else False

            if all(done_flags):
                break
    finally:
        env.close()

    # Print per-episode results
    for i, (success, length) in enumerate(zip(successes, ep_lengths)):
        status = "SUCCESS" if success else "FAIL"
        print(f"[Replay] Episode {i}: {status} (length {length})")

    num_success = sum(successes)
    print(f"\n[Replay] Results: {num_success}/{num_episodes} ({num_success / num_episodes:.1%})")

    # Save video
    if frames:
        video_dir = Path(cfg.log_dir) / "replay"
        video_dir.mkdir(parents=True, exist_ok=True)
        video_path = video_dir / "openloop_replay.mp4"
        mediapy.write_video(str(video_path), frames, fps=10)
        print(f"[Replay] Video saved to {video_path}")


def main(cfg: bcppo_configs.BCPPORunConfig):
    """Launch Isaac Sim, replay the dataset's demos open-loop, and always close the app.

    The app is started with cameras enabled (needed for ``rgb_array`` rendering),
    distributed mode off, and headless mode taken from ``cfg.headless``.
    """
    import traceback

    parser = argparse.ArgumentParser()
    args_cli, _ = parser.parse_known_args()
    args_cli.enable_cameras = True
    args_cli.headless = cfg.headless
    args_cli.distributed = False
    app_launcher = AppLauncher(args_cli)
    simulation_app = app_launcher.app
    try:
        _replay(cfg)
    except Exception:
        # Print the traceback before shutting the app down, in case closing the
        # app terminates the process before the exception reaches the interpreter.
        traceback.print_exc()
        raise
    finally:
        simulation_app.close()


if __name__ == "__main__":
    # Build the run configuration from the command line and hand it to the replay.
    main(bcppo_configs.cli())
