"""SAC training script.

Uses the OffPolicyRunner with the standalone SAC algorithm.
Supports distributed training via torchrun.

Usage:
    # Single GPU
    python scripts/rl/train_sac.py --config kiwimanip_sac

    # Multi-GPU
    torchrun --nproc_per_node=4 scripts/rl/train_sac.py --config kiwimanip_sac
"""

import argparse
import dataclasses
import os
import re
import shutil
import torch
import gymnasium as gym
from pathlib import Path

from isaaclab.app import AppLauncher
import sim_improvement.rl.sac_configs as sac_configs

# Allow TF32 for CUDA matmuls and cuDNN convolutions (faster on GPUs that
# support it, at slightly reduced precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# cuDNN: neither autotune algorithms nor force deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False


def resolve_checkpoint(resume_path: str) -> str:
    """Resolve a checkpoint path, supporting wandb run paths.

    Args:
        resume_path: Either an existing local path, or a wandb run path of the
            form ``entity/project/run_id`` (an optional ``:tag`` suffix on the
            run id is stripped and ignored).

    Returns:
        ``resume_path`` unchanged if it exists locally. Otherwise the local path
        of the run's ``model_*.pt`` file with the highest iteration number,
        downloaded into ``~/.cache/wandb_checkpoints/<run_id>/`` unless a file
        of that name is already cached there.

    Raises:
        FileNotFoundError: If ``resume_path`` is neither an existing local path
            nor a plausible wandb run path, or if the wandb run contains no
            ``model_*.pt`` files.
    """
    if Path(resume_path).exists():
        return resume_path

    parts = resume_path.split("/")
    if len(parts) == 3:
        entity, project, run_and_tag = parts
        # Any ":tag" suffix is stripped; the latest checkpoint is always selected.
        run_id = run_and_tag.split(":")[0]
    else:
        entity = project = run_id = ""

    # Reject anything that cannot be a wandb run path before touching the network.
    # A missing local checkpoint such as "logs/run/model_100.pt" also splits into
    # three parts, so a "run id" ending in ".pt" is reported as a missing file.
    if not (entity and project and run_id) or run_id.endswith(".pt"):
        raise FileNotFoundError(
            f"Resume path '{resume_path}' is not a local file and doesn't "
            f"look like a wandb run path (entity/project/run_id)."
        )

    import wandb  # imported lazily: only needed for remote checkpoints

    api = wandb.Api()
    run = api.run(f"{entity}/{project}/{run_id}")
    files = [f for f in run.files() if f.name.startswith("model_") and f.name.endswith(".pt")]
    if not files:
        raise FileNotFoundError(f"No model_*.pt files found in wandb run {resume_path}")

    def _iter_num(f):
        """Iteration number parsed from ``model_<N>.pt``; -1 when there is none."""
        m = re.search(r"model_(\d+)\.pt", f.name)
        return int(m.group(1)) if m else -1

    # Highest iteration ends up last; the stable sort keeps listing order on ties.
    files.sort(key=_iter_num)
    target = files[-1]

    cache_dir = Path.home() / ".cache" / "wandb_checkpoints" / run_id
    cache_dir.mkdir(parents=True, exist_ok=True)
    local_path = str(cache_dir / target.name)

    if not Path(local_path).exists():
        print(f"[wandb] Downloading {target.name} from {entity}/{project}/{run_id}...")
        handle = target.download(root=str(cache_dir), replace=True)
        # wandb's public-API File.download() returns an open handle to the
        # downloaded file; close it rather than leaking it.
        if hasattr(handle, "close"):
            handle.close()
    else:
        print(f"[wandb] Using cached checkpoint: {local_path}")

    return local_path


def evaluate_policy(runner, num_episodes: int = 1):
    """Run evaluation episodes and return success rate.

    Steps ``runner.env`` with the runner's inference policy until at least
    ``num_episodes`` episodes have finished (counted across all parallel envs)
    or ``2 * env.max_episode_length`` steps have elapsed, whichever comes first.

    Args:
        runner: Runner exposing ``get_inference_policy()``, ``env``, ``device``
            and ``alg.train()``.
        num_episodes: Minimum number of finished episodes to collect.

    Returns:
        Fraction of finished episodes flagged as successful. ``0.0`` if no
        episode finished or the env reports no ``"success"`` entry.
    """
    policy_fn = runner.get_inference_policy()
    env = runner.env

    successes = 0
    total = 0
    max_steps = int(env.max_episode_length) * 2

    try:
        # Rollouts need no gradients; skip building autograd graphs per step.
        with torch.no_grad():
            obs_td = env.get_observations().to(runner.device)
            for _ in range(max_steps):
                actions = policy_fn(obs_td)
                obs_td, _rewards, dones, extras = env.step(actions.to(env.device))
                obs_td = obs_td.to(runner.device)

                if dones.any():
                    done_mask = dones.bool()
                    ep_info = extras.get("episode", {})
                    if "success" in ep_info:
                        # NOTE(review): assumes ep_info["success"] is a per-env
                        # tensor aligned with `dones`; an already-averaged scalar
                        # would fail this indexing — confirm against the wrapper.
                        successes += ep_info["success"][done_mask].sum().item()
                    total += done_mask.sum().item()

                if total >= num_episodes:
                    break
    finally:
        # Restore training mode (get_inference_policy() presumably switched the
        # algorithm to eval mode) even if the rollout raised.
        runner.alg.train()

    return successes / max(total, 1)


def _datasets_inside(directory: Path, cfg: sac_configs.SACRunConfig) -> list:
    """Return the configured dataset paths (reset + demo) that lie inside ``directory``.

    Unset entries (``None`` / empty) are skipped. The paths are returned as configured.
    """
    candidates = [cfg.reset_dataset_path]
    demo_paths = cfg.demo_dataset_paths
    if isinstance(demo_paths, (str, os.PathLike)):
        candidates.append(demo_paths)
    elif demo_paths:
        candidates.extend(demo_paths)
    root = directory.resolve()
    return [p for p in candidates if p and Path(p).resolve().is_relative_to(root)]


def main(cfg: sac_configs.SACRunConfig):
    """Launch Isaac Sim, build the environment and SAC runner, then train or evaluate.

    Steps:
      1. Start the Isaac Sim app with camera / headless / distributed flags from ``cfg``.
      2. Build the env config, optionally replacing the object-reset event with a
         curriculum reset or a dataset-driven reset.
      3. Prepare ``<log_dir>/rl/<name>``. Rank 0 may delete an existing one when
         ``cfg.overwrite`` is set.
      4. Create the gym env (optionally wrapped with video recording) and the
         ``OffPolicyRunner``. Attach an RLPD demo buffer when demo datasets are
         configured.
      5. With ``cfg.eval_only``, load ``cfg.resume`` and print the success rate.
         Otherwise, optionally resume, dump configs (rank 0), and train.

    Raises:
        ValueError: If ``cfg.eval_only`` is set without ``cfg.resume``. Also if
            the log directory already exists and may not be overwritten.
    """
    # Validate flag combinations up front, before paying for the Isaac Sim launch.
    if cfg.eval_only and not cfg.resume:
        raise ValueError("--eval-only requires --resume.")

    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    distributed = int(os.getenv("WORLD_SIZE", "1")) > 1

    # >>>> Isaac Sim App Launcher <<<<
    parser = argparse.ArgumentParser()
    args_cli, _ = parser.parse_known_args()
    args_cli.enable_cameras = cfg.video
    args_cli.headless = cfg.headless
    args_cli.distributed = distributed
    app_launcher = AppLauncher(args_cli)
    simulation_app = app_launcher.app
    # >>>> Isaac Sim App Launcher <<<<

    # Deferred until after the app launch (Isaac Lab modules generally expect a
    # running simulation app).
    from isaaclab.utils.io import dump_yaml
    from isaaclab_tasks.utils import parse_env_cfg  # type: ignore
    import isaaclab_tasks  # noqa: F401

    import sim_improvement.environments  # noqa: F401
    from sim_improvement.rl.test import RslRlVecEnvWrapper
    from sim_improvement.rl.runners import OffPolicyRunner

    # --- Env config ---
    env_cfg = parse_env_cfg(
        cfg.environment,
        device=cfg.device,
        num_envs=cfg.num_envs,
        use_fabric=True,
    )
    env_cfg.dynamic_setup(
        scene_path=cfg.scene_path,
        library_dir=cfg.library_dir,
    )

    # Reset distribution
    if cfg.curriculum:
        from isaaclab.managers import EventTermCfg as EventTerm
        from sim_improvement.environments.lbm.scenario_helper import ResetCurriculum
        env_cfg.events.reset_objects = EventTerm(
            func=ResetCurriculum,
            mode="reset",
            params={
                "dataset_path": cfg.reset_dataset_path,
                "scene_path": cfg.scene_path,
                "library_dir": cfg.library_dir,
                "initial_frontier": cfg.curriculum_initial_frontier,
                "advance_threshold": cfg.curriculum_advance_threshold,
                "retreat_threshold": cfg.curriculum_retreat_threshold,
                "step_size": cfg.curriculum_step_size,
            },
        )
        print(f"[INFO] Curriculum reset (frontier={cfg.curriculum_initial_frontier}) from {cfg.reset_dataset_path}")
    elif cfg.reset_from_dataset:
        from isaaclab.managers import EventTermCfg as EventTerm
        from sim_improvement.environments.lbm.scenario_helper import ResetFromDataset
        env_cfg.events.reset_objects = EventTerm(
            func=ResetFromDataset,
            mode="reset",
            params={"dataset_path": cfg.reset_dataset_path, "initial_only": cfg.reset_initial_only},
        )
        mode = "initial states only" if cfg.reset_initial_only else "all timesteps"
        print(f"[INFO] Dataset-based reset ({mode}) from {cfg.reset_dataset_path}")

    device = cfg.device
    if distributed:
        # One GPU per local process.
        device = f"cuda:{local_rank}"
        env_cfg.sim.device = device

    # --- Logging ---
    exp_name = cfg.experiment_name or cfg.name
    log_dir = Path(cfg.log_dir)
    rl_dir = log_dir / "rl" / cfg.name
    _rank0 = int(os.getenv("RANK", "0")) == 0

    if _rank0 and not cfg.eval_only and not cfg.resume:
        if rl_dir.exists():
            if cfg.overwrite:
                # Only delete the RL output dir, never the parent (which may contain
                # data). Also never delete it when a dataset this run reads lies inside.
                inside = _datasets_inside(rl_dir, cfg)
                if inside:
                    raise ValueError(
                        f"Refusing to overwrite: dataset {inside[0]} is inside rl_dir {rl_dir}. "
                        "Move the dataset or change log_dir."
                    )
                shutil.rmtree(rl_dir)
            else:
                raise ValueError(f"Log directory {rl_dir} already exists. Use --overwrite to overwrite.")
    rl_dir.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] Logging experiment in directory: {rl_dir}")
    env_cfg.log_dir = str(rl_dir)
    env_cfg.io_descriptors_output_dir = str(rl_dir / "io_descriptors")
    if not cfg.headless:
        env_cfg.sim.render_interval = 1

    # --- Create environment ---
    env = gym.make(
        cfg.environment, cfg=env_cfg, render_mode="rgb_array" if cfg.video else None
    )
    if cfg.video:
        video_kwargs = {
            "video_folder": rl_dir / "videos" / "train",
            "step_trigger": lambda step: step % cfg.video_interval == 0,
            "video_length": cfg.video_length,
            "disable_logger": True,
        }
        print("[INFO] Recording videos during training.")
        env = gym.wrappers.RecordVideo(env, **video_kwargs)
    env = RslRlVecEnvWrapper(env)

    # --- Build runner config dict ---
    sac_alg_dict = dataclasses.asdict(cfg.sac)
    train_cfg = {
        # Offset the seed per rank so distributed workers do not share RNG streams.
        "seed": cfg.seed + (local_rank if distributed else 0),
        "device": device,
        "algorithm": sac_alg_dict,
        "obs_groups": {
            "policy": cfg.actor_obs_groups,
            "critic": cfg.critic_obs_groups,
        },
        "save_interval": cfg.save_interval,
        "log_interval": cfg.log_interval,
        "max_iterations": cfg.max_iterations,
        "experiment_name": exp_name,
        "logger": cfg.logger,
        "wandb_project": cfg.wandb_project,
        "num_steps_per_env": 1,
    }

    # --- RLPD demo buffer ---
    demo_buffer = None
    if cfg.demo_dataset_paths and cfg.demo_ratio > 0:
        from sim_improvement.rl.algs.sac.replay_buffer import DemoReplayBuffer
        demo_buffer = DemoReplayBuffer(
            dataset_paths=cfg.demo_dataset_paths,
            action_keys=cfg.action_keys,
            obs_key=cfg.actor_obs_groups[0] if cfg.actor_obs_groups else "policy",
            device=device,
            success_only=cfg.demo_success_only,
        )

    # --- Create runner ---
    runner = OffPolicyRunner(
        env, train_cfg,
        log_dir=None if cfg.eval_only else str(rl_dir),
        device=device,
    )

    # Attach demo buffer to SAC algorithm
    if demo_buffer is not None:
        runner.alg.demo_buffer = demo_buffer
        runner.alg.demo_ratio = cfg.demo_ratio

    # Wandb note (best-effort: a missing or failing wandb must not abort the run)
    if cfg.note and _rank0:
        try:
            import wandb
            if wandb.run is not None:
                wandb.run.notes = cfg.note
                wandb.run.update()
        except Exception as err:
            print(f"[WARN] Could not attach note to wandb run: {err}")

    if cfg.eval_only:
        # cfg.resume is guaranteed to be set by the check at the top of main().
        ckpt_path = resolve_checkpoint(cfg.resume)
        print(f"[INFO] Eval-only — loading: {ckpt_path}")
        runner.load(ckpt_path)
        sr = evaluate_policy(runner, num_episodes=cfg.num_envs)
        print(f"[Eval] Success rate: {sr:.1%}")
    else:
        if cfg.resume:
            ckpt_path = resolve_checkpoint(cfg.resume)
            print(f"[INFO] Resuming from: {ckpt_path}")
            runner.load(ckpt_path)

        # Dump configs
        if _rank0:
            dump_yaml(str(rl_dir / "params" / "env.yaml"), env_cfg)
            dump_yaml(str(rl_dir / "params" / "sac.yaml"), cfg)

        # Train
        runner.learn(
            num_learning_iterations=cfg.max_iterations,
            init_at_random_ep_len=True,
        )

    env.close()
    simulation_app.close()


if __name__ == "__main__":
    # Parse the command line into a run config and hand it straight to main().
    main(sac_configs.cli())
