"""BC+PPO training using the BC_PPO algorithm.

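Unlike train_bcppo.py, which does a standalone BC warmstart and then switches
to pure PPO, this script loads demo data into BC_PPO's replay buffer so that
an auxiliary BC loss is applied throughout PPO training. An optional pure-BC
warmstart (``bc_epochs > 0``, skipped when resuming) can still be run before
RL starts.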
"""

import argparse
import os
import re
import torch
import logging
import gymnasium as gym
import shutil
import numpy as np
from pathlib import Path

from isaaclab.app import AppLauncher
import sim_improvement.rl.bcppo_configs as bcppo_configs

logger = logging.getLogger(__name__)

# Global numerics / kernel-selection policy for this training script:
# TF32 tensor-core math is allowed for cuDNN and cuBLAS matmuls, cuDNN may
# pick non-deterministic algorithms, and cuDNN autotuning is disabled.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True


# Reuse the demo-data loader and policy evaluator from the sibling train_bcppo.py (same directory)
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent))
from train_bcppo import load_bc_data, evaluate_policy


def resolve_checkpoint(resume_path: str, device: str = "cpu") -> str:
    """Resolve a checkpoint path, supporting wandb run paths.

    Accepts:
      - A local file path (returned as-is)
      - A wandb run path like ``entity/project/run_id`` or
        ``entity/project/run_id:latest`` / ``entity/project/run_id:best`` / ``entity/project/run_id:<iter>``

    For wandb paths, downloads the requested ``.pt`` checkpoint (preferring
    ``model_*.pt``) to a local cache and returns the local path.
    """
    # If it's an existing local file, use it directly
    if Path(resume_path).exists():
        return resume_path

    # Try parsing as wandb run path: entity/project/run_id[:tag]
    parts = resume_path.split("/")
    if len(parts) == 3:
        entity, project, run_and_tag = parts
        # Optional :tag suffix (e.g. :latest, :best, :v3)
        if ":" in run_and_tag:
            run_id, tag = run_and_tag.split(":", 1)
        else:
            run_id, tag = run_and_tag, None
    else:
        raise FileNotFoundError(
            f"Resume path '{resume_path}' is not a local file and doesn't "
            f"look like a wandb run path (entity/project/run_id). "
        )

    import wandb

    api = wandb.Api()
    run = api.run(f"{entity}/{project}/{run_id}")

    all_pt_files = [f for f in run.files() if f.name.endswith(".pt")]
    if not all_pt_files:
        raise FileNotFoundError(f"No .pt files found in wandb run {resume_path}")

    model_files = [f for f in all_pt_files if f.name.startswith("model_")]

    # Sort by iteration number to find the latest
    def _iter_num(f):
        m = re.search(r"model_(\d+)\.pt", f.name)
        return int(m.group(1)) if m else -1

    model_files.sort(key=_iter_num)
    if not model_files:
        raise FileNotFoundError(f"No model_*.pt files found in wandb run {resume_path}")

    # If a tag was provided in the wandb path, try to match it.
    # - "latest": use highest-numbered model_*.pt
    # - "best": prefer files containing "best" (fallback to latest)
    # - "<iter>": match model_<iter>.pt (or a file named exactly <iter>)
    if "tag" in locals() and tag not in (None, "", "latest"):
        if tag == "best":
            best_files = [f for f in all_pt_files if "best" in f.name.lower()]
            if best_files:
                best_files.sort(key=_iter_num)
                target = best_files[-1]
            else:
                print(f"[wandb] Warning: no 'best' checkpoint found, using latest model_*.pt")
                target = model_files[-1]
        else:
            matched = [
                f for f in all_pt_files
                if f.name == tag or f.name == f"model_{tag}.pt"
            ]
            if not matched:
                available = [f.name for f in all_pt_files]
                raise FileNotFoundError(
                    f"No checkpoint matching '{tag}' in wandb run {resume_path}. "
                    f"Available: {available}"
                )
            target = matched[0]
    else:
        target = model_files[-1]  # latest checkpoint

    # Download to local cache
    cache_dir = Path.home() / ".cache" / "wandb_checkpoints" / run_id
    cache_dir.mkdir(parents=True, exist_ok=True)
    local_path = str(cache_dir / target.name)

    if not Path(local_path).exists():
        print(f"[wandb] Downloading {target.name} from {entity}/{project}/{run_id}...")
        target.download(root=str(cache_dir), replace=True)
    else:
        print(f"[wandb] Using cached checkpoint: {local_path}")

    return local_path


def _configure_reset_event(env_cfg, cfg: bcppo_configs.BCPPORunConfig) -> None:
    """Optionally swap the env's ``reset_objects`` event for a dataset-driven reset.

    ``cfg.curriculum`` takes precedence over ``cfg.reset_from_dataset``. If
    neither is set, *env_cfg* is left untouched. Like the other Isaac Lab
    imports in this script, the imports here are deferred until after the
    simulation app has been launched.
    """
    if not (cfg.curriculum or cfg.reset_from_dataset):
        return

    from isaaclab.managers import EventTermCfg as EventTerm

    if cfg.curriculum:
        from sim_improvement.environments.lbm.scenario_helper import ResetCurriculum
        env_cfg.events.reset_objects = EventTerm(
            func=ResetCurriculum,
            mode="reset",
            params={
                "dataset_path": cfg.dataset_path,
                "initial_frontier": cfg.curriculum_initial_frontier,
                "advance_threshold": cfg.curriculum_advance_threshold,
                "retreat_threshold": cfg.curriculum_retreat_threshold,
                "step_size": cfg.curriculum_step_size,
            },
        )
        print(f"[INFO] Using curriculum reset (frontier={cfg.curriculum_initial_frontier}) from {cfg.dataset_path}")
    else:
        from sim_improvement.environments.lbm.scenario_helper import ResetFromDataset
        env_cfg.events.reset_objects = EventTerm(
            func=ResetFromDataset,
            mode="reset",
            params={"dataset_path": cfg.dataset_path, "initial_only": cfg.reset_initial_only},
        )
        mode = "initial states only" if cfg.reset_initial_only else "all timesteps"
        print(f"[INFO] Using dataset-based reset ({mode}) from {cfg.dataset_path}")


def _prepare_log_dir(cfg: bcppo_configs.BCPPORunConfig) -> Path:
    """Create and return the run directory ``<cfg.log_dir>/rl/<cfg.name>``.

    For a fresh training run (not eval-only, not resuming), global rank 0
    refuses to reuse an existing directory unless ``cfg.overwrite`` is set. In
    that case the old directory is deleted first.

    Raises:
        ValueError: if the directory already exists and overwriting is disabled.
    """
    rl_dir = Path(cfg.log_dir) / "rl" / cfg.name
    is_global_rank0 = int(os.getenv("RANK", "0")) == 0
    if is_global_rank0 and not cfg.eval_only and not cfg.resume and rl_dir.exists():
        if not cfg.overwrite:
            raise ValueError(f"RL log directory {rl_dir} already exists. Use --overwrite to overwrite.")
        shutil.rmtree(rl_dir)
    # NOTE(review): other ranks create this directory without waiting for rank 0,
    # so on a multi-GPU fresh run rank 0 may see it as "already existing" (or
    # delete it under --overwrite). Confirm whether launch timing makes this a
    # real race.
    rl_dir.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] Logging experiment in directory: {rl_dir}")
    return rl_dir


def _load_bc_demos(runner, cfg: bcppo_configs.BCPPORunConfig, rl_dir: Path, is_rank0: bool) -> None:
    """Load ``cfg.dataset_path`` demos into BC_PPO's BC buffer and optionally warm-start.

    If the dataset file is missing or yields no episodes, a warning is printed
    and PPO runs without the BC loss. When ``cfg.bc_epochs > 0`` and we are not
    resuming, rank 0 runs a pure-BC warmstart and saves ``bc.pt``. The other
    ranks then load it, and the RL iteration counter is offset past the BC steps.
    """
    trajectory_path = cfg.dataset_path
    if not (trajectory_path and Path(trajectory_path).exists()):
        print(f"[WARN] No trajectory file at {trajectory_path}, running PPO without BC loss")
        return

    # The env's policy observation term order is handed to the loader
    # (presumably so demo observations are laid out like the env's).
    obs_key_order = runner.env.unwrapped.observation_manager.active_terms["policy"]
    print(f"[BC] Obs key order from env: {obs_key_order}")

    episodes = load_bc_data(
        trajectory_path,
        obs_key_order=obs_key_order,
        action_key_order=cfg.action_keys or None,
    )
    if len(episodes) == 0:
        # np.concatenate would otherwise fail with an opaque error.
        print(f"[WARN] No episodes loaded from {trajectory_path}, running PPO without BC loss")
        return

    # Each episode is indexed as (observations, actions); flatten across episodes.
    all_obs = torch.from_numpy(
        np.concatenate([ep[0] for ep in episodes], axis=0)
    ).float()
    all_actions = torch.from_numpy(
        np.concatenate([ep[1] for ep in episodes], axis=0)
    ).float()

    print(f"[BC] Loading {len(all_obs)} transitions into BC_PPO replay buffer")
    runner.alg.add_bc_data(all_obs, all_actions)

    # Pure BC warmstart before RL (rank 0 only, then sync)
    if cfg.bc_epochs <= 0 or cfg.resume:
        return

    if is_rank0:
        video_dir = str(rl_dir / "videos" / "bc_eval") if cfg.video else None

        def _eval_callback(step):
            vpath = f"{video_dir}/step_{step:06d}.mp4" if video_dir else None
            sr = evaluate_policy(runner, num_episodes=1, video_path=vpath)
            return {"eval_success_rate": sr}

        _, bc_final_step = runner.alg.bc_warmstart(
            num_epochs=cfg.bc_epochs,
            learning_rate=cfg.bc_learning_rate,
            writer=runner.writer,
            eval_freq=cfg.bc_eval_freq,
            eval_callback=_eval_callback,
        )
        bc_ckpt = str(rl_dir / "bc.pt")
        runner.save(bc_ckpt)
        print(f"[INFO] Saved BC checkpoint: {bc_ckpt}")
    else:
        bc_final_step = 0

    if runner.is_distributed:
        # The barrier ensures bc.pt is written before other ranks read it; then
        # rank 0's BC step count is shared with everyone.
        torch.distributed.barrier()
        step_tensor = torch.tensor([bc_final_step], device=runner.device)
        torch.distributed.broadcast(step_tensor, src=0)
        bc_final_step = int(step_tensor.item())
        if not is_rank0:
            runner.load(str(rl_dir / "bc.pt"), load_optimizer=False)

    # Offset RL logging so steps continue after BC
    runner.current_learning_iteration = bc_final_step


def _evaluate(runner, cfg: bcppo_configs.BCPPORunConfig, rl_dir: Path) -> None:
    """Load the ``cfg.resume`` checkpoint and print the policy's success rate.

    Raises:
        ValueError: if ``cfg.resume`` is not set.
    """
    if not cfg.resume:
        raise ValueError("--eval-only requires --resume pointing to a checkpoint or wandb run path.")
    ckpt_path = resolve_checkpoint(cfg.resume)
    print(f"[INFO] Eval-only mode — loading checkpoint: {ckpt_path}")
    runner.load(ckpt_path, load_optimizer=False)
    video_path = str(rl_dir / "videos" / "eval.mp4") if cfg.video else None
    sr = evaluate_policy(runner, num_episodes=cfg.num_envs, video_path=video_path)
    print(f"[Eval] Success rate: {sr:.1%}")


def _train(runner, cfg: bcppo_configs.BCPPORunConfig, rl_dir: Path, env_cfg, agent_cfg_snapshot: dict) -> None:
    """Optionally resume, load BC demos, dump configs (rank 0), then run PPO training."""
    from isaaclab.utils.io import dump_yaml

    is_rank0 = not runner.is_distributed or runner.gpu_global_rank == 0

    if cfg.resume:
        ckpt_path = resolve_checkpoint(cfg.resume)
        print(f"[INFO] Loading model checkpoint from: {ckpt_path}")
        runner.load(ckpt_path)

    _load_bc_demos(runner, cfg, rl_dir, is_rank0)

    if is_rank0:
        dump_yaml(str(rl_dir / "params" / "env.yaml"), env_cfg)
        # Dump the config the runner actually received (BC_PPO class name, BC
        # coefficients, per-rank device/seed), not the plain PPO configclass.
        dump_yaml(str(rl_dir / "params" / "agent.yaml"), agent_cfg_snapshot)

    runner.learn(
        num_learning_iterations=cfg.max_iterations, init_at_random_ep_len=True
    )


def _run(cfg: bcppo_configs.BCPPORunConfig, local_rank: int, distributed: bool) -> None:
    """Build the env and the BC_PPO runner, then train or evaluate.

    Must be called after the Isaac Sim app has been launched: all Isaac Lab,
    rsl_rl and project imports below are deferred until then.
    """
    import copy

    from isaaclab_tasks.utils import parse_env_cfg  # type: ignore
    import isaaclab_tasks  # noqa: F401
    from rsl_rl.runners import OnPolicyRunner
    import rsl_rl.runners.on_policy_runner as runner_module

    import sim_improvement.environments  # noqa: F401
    from sim_improvement.rl.test import (
        RslRlVecEnvWrapper, FrankaReachPPORunnerCfg,
        RslRlPpoActorCriticCfg, RslRlPpoAlgorithmCfg,
    )
    from sim_improvement.rl.algs.bcppo import BC_PPO

    # Inject BC_PPO into the runner module so eval("BC_PPO") resolves
    runner_module.BC_PPO = BC_PPO

    # --- Agent config ---
    rl = cfg.rl
    agent_cfg = FrankaReachPPORunnerCfg(
        experiment_name=cfg.name,
        max_iterations=cfg.max_iterations,
        seed=cfg.seed,
        num_steps_per_env=rl.num_steps_per_env,
        policy=RslRlPpoActorCriticCfg(
            init_noise_std=rl.init_noise_std,
            actor_obs_normalization=rl.actor_obs_normalization,
            critic_obs_normalization=rl.critic_obs_normalization,
            actor_hidden_dims=list(rl.actor_hidden_dims),
            critic_hidden_dims=list(rl.critic_hidden_dims),
            activation=rl.activation,
        ),
        algorithm=RslRlPpoAlgorithmCfg(
            value_loss_coef=rl.value_loss_coef,
            use_clipped_value_loss=True,
            clip_param=rl.clip_param,
            entropy_coef=rl.entropy_coef,
            num_learning_epochs=rl.num_learning_epochs,
            num_mini_batches=rl.num_mini_batches,
            learning_rate=rl.learning_rate,
            schedule="adaptive",
            gamma=rl.gamma,
            lam=rl.lam,
            desired_kl=rl.desired_kl,
            max_grad_norm=rl.max_grad_norm,
        ),
    )
    agent_cfg_dict = agent_cfg.to_dict()
    agent_cfg_dict["algorithm"]["class_name"] = "BC_PPO"
    agent_cfg_dict["algorithm"]["bc_coefficient"] = rl.bc_coefficient
    agent_cfg_dict["algorithm"]["bc_batch_size"] = cfg.bc_batch_size

    # --- Env config ---
    env_cfg = parse_env_cfg(
        cfg.environment,
        device=cfg.device,
        num_envs=cfg.num_envs,
        use_fabric=True,
    )
    env_cfg.dynamic_setup(
        scene_path=cfg.scene_path,
        library_dir=cfg.library_dir,
    )
    _configure_reset_event(env_cfg, cfg)

    if distributed:
        env_cfg.sim.device = f"cuda:{local_rank}"
        agent_cfg_dict["device"] = f"cuda:{local_rank}"
        agent_cfg_dict["seed"] = agent_cfg_dict["seed"] + local_rank
        env_cfg.seed = agent_cfg_dict["seed"]

    # OnPolicyRunner may mutate the dict it receives (rsl_rl pops the
    # "class_name" entries while building the policy/algorithm), so snapshot
    # the effective agent config now for params/agent.yaml.
    agent_cfg_snapshot = copy.deepcopy(agent_cfg_dict)

    # --- Logging ---
    rl_dir = _prepare_log_dir(cfg)
    env_cfg.log_dir = str(rl_dir)
    env_cfg.io_descriptors_output_dir = str(rl_dir / "io_descriptors")
    if not cfg.headless:
        env_cfg.sim.render_interval = 1

    # --- Create environment ---
    env = gym.make(
        cfg.environment, cfg=env_cfg, render_mode="rgb_array" if cfg.video else None
    )
    if cfg.video:
        video_kwargs = {
            "video_folder": rl_dir / "videos" / "train",
            "step_trigger": lambda step: step % cfg.video_interval == 0,
            "video_length": cfg.video_length,
            "disable_logger": True,
        }
        print("[INFO] Recording videos during training.")
        env = gym.wrappers.RecordVideo(env, **video_kwargs)
    env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)

    try:
        # --- Create runner (will construct BC_PPO via eval) ---
        runner = OnPolicyRunner(
            env, agent_cfg_dict,
            log_dir=None if cfg.eval_only else str(rl_dir),
            device=agent_cfg_dict.get("device", agent_cfg.device),
        )
        # Initialize logging writer so BC can log to wandb.
        # Also set logger_type (needed by runner.save) — _prepare_logging_writer
        # skips this on non-rank-0 when disable_logs=True.
        runner._prepare_logging_writer()
        if not hasattr(runner, "logger_type"):
            runner.logger_type = runner.cfg.get("logger", "tensorboard")

        # Add note to wandb run. Best-effort: a failure must not abort training,
        # but it should not be silent either.
        if cfg.note:
            try:
                import wandb
                if wandb.run is not None:
                    wandb.run.notes = cfg.note
                    wandb.run.update()
                    print(f"[wandb] Note: {cfg.note}")
            except Exception as err:
                print(f"[wandb] Warning: could not set run note: {err}")

        if cfg.eval_only:
            _evaluate(runner, cfg, rl_dir)
        else:
            _train(runner, cfg, rl_dir, env_cfg, agent_cfg_snapshot)
    finally:
        env.close()


def main(cfg: bcppo_configs.BCPPORunConfig):
    """Launch Isaac Sim, then run BC_PPO training (or evaluation only) per *cfg*.

    Multi-GPU mode is detected from the ``WORLD_SIZE`` / ``LOCAL_RANK``
    environment variables (e.g. as set by torchrun). The environment and the
    simulation app are closed on exit even when training raises.
    """
    import traceback

    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    distributed = int(os.getenv("WORLD_SIZE", "1")) > 1

    # >>>> Isaac Sim App Launcher <<<<
    parser = argparse.ArgumentParser()
    args_cli, _ = parser.parse_known_args()
    args_cli.enable_cameras = cfg.video
    args_cli.headless = cfg.headless
    args_cli.distributed = distributed
    app_launcher = AppLauncher(args_cli)
    simulation_app = app_launcher.app
    # >>>> Isaac Sim App Launcher <<<<

    try:
        _run(cfg, local_rank=local_rank, distributed=distributed)
    except Exception:
        # Print before shutting the app down: if closing the simulator does not
        # return normally, the original error would otherwise be lost.
        traceback.print_exc()
        raise
    finally:
        simulation_app.close()


if __name__ == "__main__":
    # Build the run configuration from the command line and launch the run.
    main(bcppo_configs.cli())
