"""Evaluate a checkpoint and report success rate by episode progress %.

Supports both BC+PPO and SAC checkpoints.

Usage (bcppo):
    python scripts/rl/eval_by_timestep.py bcppo:kiwimanip_relative_ik \
        --bcppo.eval-only --bcppo.resume <checkpoint_or_wandb_path> \
        --bcppo.num-envs 2048 --num-rounds 1 --num-bins 10

Usage (sac):
    python scripts/rl/eval_by_timestep.py sac:kiwimanip_sac \
        --sac.eval-only --sac.resume <checkpoint_or_wandb_path> \
        --sac.num-envs 2048 --num-rounds 1 --num-bins 10

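Each round resets all num_envs in parallel and steps them until every env has
finished its first episode (or max_episode_length steps elapse), giving
num_envs data points per round. Total episodes = num_rounds * num_envs.

Reset progress is stored as timestep / episode_length (0.0 = start, 1.0 = end);
without dataset-based resets it is recorded as 0 for every episode. Per-episode
results are written to <log_dir>/rl/<name>/eval_by_timestep.csv, and a summary
plot to eval_by_timestep.png in the same directory.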
"""

import argparse
import csv
import dataclasses
import torch
import numpy as np
import gymnasium as gym
from pathlib import Path
from typing import Union

from isaaclab.app import AppLauncher
import sim_improvement.rl.bcppo_configs as bcppo_configs
import sim_improvement.rl.sac_configs as sac_configs


def resolve_checkpoint(resume_path: str) -> str:
    """Turn ``resume_path`` into a local checkpoint file path.

    A path that already exists on disk is returned unchanged. Otherwise the
    string must look like a wandb run path ``entity/project/run_id``; an
    optional ``:suffix`` on the run id is stripped and otherwise ignored. The
    ``model_<N>.pt`` file with the largest ``N`` in that run is downloaded to
    ``~/.cache/wandb_checkpoints/<run_id>/``, unless a file with that name is
    already cached there.

    Raises:
        FileNotFoundError: if the path is neither an existing local file nor a
            three-part wandb run path, or the run has no ``model_*.pt`` files.
    """
    import re as _re

    if Path(resume_path).exists():
        return resume_path

    pieces = resume_path.split("/")
    if len(pieces) != 3:
        raise FileNotFoundError(
            f"Resume path '{resume_path}' is not a local file and doesn't "
            f"look like a wandb run path (entity/project/run_id)."
        )
    entity, project, run_spec = pieces
    run_id, _, _ = run_spec.partition(":")

    import wandb

    run = wandb.Api().run(f"{entity}/{project}/{run_id}")
    candidates = [
        remote for remote in run.files()
        if remote.name.startswith("model_") and remote.name.endswith(".pt")
    ]
    if not candidates:
        raise FileNotFoundError(f"No model_*.pt files found in wandb run {resume_path}")

    iter_pattern = _re.compile(r"model_(\d+)\.pt")

    def _iteration(remote) -> int:
        # Names without a numeric iteration sort before every numbered checkpoint.
        hit = iter_pattern.search(remote.name)
        return -1 if hit is None else int(hit.group(1))

    # Stable sort + last element: among equal iteration numbers the later file wins.
    latest = sorted(candidates, key=_iteration)[-1]

    cache_dir = Path.home() / ".cache" / "wandb_checkpoints" / run_id
    cache_dir.mkdir(parents=True, exist_ok=True)
    local_path = str(cache_dir / latest.name)

    if Path(local_path).exists():
        print(f"[wandb] Using cached checkpoint: {local_path}")
    else:
        print(f"[wandb] Downloading {latest.name} from {entity}/{project}/{run_id}...")
        latest.download(root=str(cache_dir), replace=True)

    return local_path

# Process-wide numeric settings for this evaluation script: permit TF32 math in
# CUDA matmuls and cuDNN ops, and leave cuDNN autotuning and forced
# determinism switched off.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True


def plot_results(progress: np.ndarray, successes: np.ndarray, values: np.ndarray,
                 num_bins: int, out_path: str):
    """Save a two-panel bar chart of per-bin value estimates and success rates.

    Episodes are grouped by reset progress into ``num_bins`` equal-width bins
    over [0, 1] using ``np.histogram`` (the last bin is closed, so 1.0 is
    counted; values outside [0, 1] fall in no bin). The top panel shows the
    mean value estimate per bin and the bottom panel the success rate in
    percent. Every non-empty bar is annotated with its episode count, and
    empty bins are drawn at 0.

    Args:
        progress: Reset progress per episode.
        successes: Per-episode success flags; anything coercible to a bool array.
        values: Value estimate at the reset state, one entry per episode.
        num_bins: Number of equal-width progress bins.
        out_path: Destination image path (saved with ``dpi=150``).
    """
    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend: render straight to file
    import matplotlib.pyplot as plt

    progress = np.asarray(progress)
    values = np.asarray(values)
    # Force a boolean mask: 0/1 integer arrays would otherwise be treated as
    # fancy indices by ``progress[successes]`` below.
    successes = np.asarray(successes, dtype=bool)

    bin_edges = np.linspace(0.0, 1.0, num_bins + 1)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    bin_width = bin_edges[1] - bin_edges[0]
    total_counts, _ = np.histogram(progress, bins=bin_edges)
    nonempty = total_counts > 0

    def _per_bin_ratio(numerators):
        """Element-wise ``numerators / total_counts``; 0 for empty bins."""
        # ``where=`` skips the empty bins entirely, so no divide-by-zero /
        # invalid-value RuntimeWarnings (np.where would still divide them).
        return np.divide(numerators, total_counts,
                         out=np.zeros(num_bins, dtype=float), where=nonempty)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

    # Top: average value function at reset state per bin
    value_sums, _ = np.histogram(progress, bins=bin_edges, weights=values)
    avg_values = _per_bin_ratio(value_sums)
    bars1 = ax1.bar(bin_centers, avg_values, width=bin_width * 0.9, color="#f0ad4e",
                    edgecolor="black", linewidth=0.5)
    for bar, n in zip(bars1, total_counts):
        if n > 0:
            ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                     f"n={n}", ha="center", va="bottom", fontsize=7)
    ax1.set_ylabel("Avg Value (V)")
    ax1.set_title(f"Value Function & Success Rate by Reset Progress (n={len(progress)})")

    # Bottom: success rate per bin, in percent
    success_counts, _ = np.histogram(progress[successes], bins=bin_edges)
    rates = _per_bin_ratio(success_counts) * 100
    bars2 = ax2.bar(bin_centers, rates, width=bin_width * 0.9, color="steelblue",
                    edgecolor="black", linewidth=0.5)
    for bar, n in zip(bars2, total_counts):
        if n > 0:
            ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1,
                     f"n={n}", ha="center", va="bottom", fontsize=7)
    ax2.set_ylabel("Success Rate (%)")
    ax2.set_xlabel("Reset Progress (fraction of episode)")
    ax2.set_xlim(0, 1)
    # Headroom above the tallest bar for the "n=" labels, capped at 100%.
    ax2.set_ylim(0, min(100, rates.max() + 15) if rates.max() > 0 else 10)

    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f"[INFO] Plot saved to {out_path}")


def _setup_env(cfg, simulation_app):
    """Build the wrapped evaluation environment shared by the BC+PPO and SAC paths.

    ``simulation_app`` is accepted for interface parity; it is not used here.

    The env config is parsed for ``cfg.environment`` and adjusted as follows:
    - Optionally, its ``reset_objects`` event is replaced with
      ``ResetFromDataset`` when ``cfg.reset_from_dataset`` is set.
    - Its log directories are pointed at the run directory.
    - If ``cfg.video`` is set, video recording of the first steps is enabled.

    Returns:
        ``(env, rl_dir)``: the ``RslRlVecEnvWrapper``-wrapped environment, and
        ``<cfg.log_dir>/rl/<cfg.name>`` as a ``Path`` (created if missing).
    """
    from isaaclab_tasks.utils import parse_env_cfg
    from isaaclab.managers import EventTermCfg as EventTerm

    # Imported for its side effects (presumably gym task registration — verify).
    import sim_improvement.environments  # noqa: F401
    from sim_improvement.rl.test import RslRlVecEnvWrapper
    from sim_improvement.environments.lbm.scenario_helper import ResetFromDataset

    record_video = cfg.video

    env_cfg = parse_env_cfg(cfg.environment, device=cfg.device, num_envs=cfg.num_envs, use_fabric=True)
    env_cfg.dynamic_setup(scene_path=cfg.scene_path, library_dir=cfg.library_dir)

    # Reset strategy: sample start states from a dataset, or keep the env default.
    if not cfg.reset_from_dataset:
        print("[INFO] Using default stochastic reset (reset_from_dataset=False)")
    else:
        initial_only = cfg.reset_initial_only
        env_cfg.events.reset_objects = EventTerm(
            func=ResetFromDataset,
            mode="reset",
            params={"dataset_path": cfg.dataset_path, "initial_only": initial_only},
        )
        mode = "initial states only" if initial_only else "all timesteps"
        print(f"[INFO] Using dataset-based reset ({mode}) from {cfg.dataset_path}")

    rl_dir = Path(cfg.log_dir) / "rl" / cfg.name
    rl_dir.mkdir(parents=True, exist_ok=True)
    env_cfg.log_dir = str(rl_dir)
    env_cfg.io_descriptors_output_dir = str(rl_dir / "io_descriptors")
    if record_video:
        env_cfg.sim.render_interval = 1

    env = gym.make(cfg.environment, cfg=env_cfg, render_mode="rgb_array" if record_video else None)
    if record_video:
        print("[INFO] Recording video.")
        env = gym.wrappers.RecordVideo(
            env,
            video_folder=str(rl_dir / "videos" / "eval"),
            step_trigger=lambda step: step == 0,  # start a clip only at step 0
            video_length=getattr(cfg, "video_length", 200),
            disable_logger=True,
        )

    env = RslRlVecEnvWrapper(env, clip_actions=getattr(cfg, "clip_actions", 1.0))
    return env, rl_dir


def _build_bcppo_runner(cfg: bcppo_configs.BCPPORunConfig, env):
    """Construct an rsl_rl ``OnPolicyRunner`` whose algorithm is ``BC_PPO``.

    Starts from ``FrankaReachPPORunnerCfg`` (experiment name, iteration count
    and seed taken from ``cfg``). It switches the algorithm class to
    ``"BC_PPO"`` and adds the BC hyper-parameters. The runner is created with
    ``log_dir=None``.
    """
    from rsl_rl.runners import OnPolicyRunner
    import rsl_rl.runners.on_policy_runner as runner_module
    from sim_improvement.rl.test import FrankaReachPPORunnerCfg
    from sim_improvement.rl.algs.bcppo import BC_PPO

    # Inject BC_PPO into the runner module's namespace so the "BC_PPO"
    # class_name below can be resolved there (presumably looked up by name
    # inside on_policy_runner — verify against the installed rsl_rl version).
    setattr(runner_module, "BC_PPO", BC_PPO)

    agent_cfg = FrankaReachPPORunnerCfg(
        experiment_name=cfg.name,
        max_iterations=cfg.max_iterations,
        seed=cfg.seed,
    )
    train_cfg = agent_cfg.to_dict()
    train_cfg["algorithm"].update(
        class_name="BC_PPO",
        bc_coefficient=0.1,
        bc_batch_size=cfg.bc_batch_size,
    )

    # NOTE(review): the runner device comes from the agent config, whereas the
    # env is built with cfg.device — confirm these always agree.
    device = train_cfg.get("device", agent_cfg.device)
    return OnPolicyRunner(env, train_cfg, log_dir=None, device=device)


def _build_sac_runner(cfg: sac_configs.SACRunConfig, env):
    """Construct an ``OffPolicyRunner`` for a SAC checkpoint.

    The train config is assembled from ``cfg``. ``cfg.sac`` is converted to a
    plain dict for the algorithm section, and actor/critic observation groups
    are mapped to "policy"/"critic". ``num_steps_per_env`` is fixed to 1. The
    runner is created on ``cfg.device`` with ``log_dir=None``.
    """
    from sim_improvement.rl.runners import OffPolicyRunner

    train_cfg = dict(
        seed=cfg.seed,
        device=cfg.device,
        algorithm=dataclasses.asdict(cfg.sac),
        obs_groups=dict(policy=cfg.actor_obs_groups, critic=cfg.critic_obs_groups),
        save_interval=cfg.save_interval,
        log_interval=cfg.log_interval,
        max_iterations=cfg.max_iterations,
        experiment_name=cfg.experiment_name or cfg.name,
        logger=cfg.logger,
        wandb_project=cfg.wandb_project,
        num_steps_per_env=1,
    )
    return OffPolicyRunner(env, train_cfg, log_dir=None, device=cfg.device)


def _get_bcppo_policy_fns(runner):
    """Return (action_fn, value_fn) for bcppo."""
    policy = runner.alg.policy
    policy.eval()

    def action_fn(obs_td):
        actor_obs = policy.get_actor_obs(obs_td)
        actor_obs = policy.actor_obs_normalizer(actor_obs)
        return policy.actor(actor_obs)

    def value_fn(obs_td):
        return policy.evaluate(obs_td).squeeze(-1)

    return action_fn, value_fn


def _get_sac_policy_fns(runner):
    """Return (action_fn, value_fn) for SAC.

    Value is estimated as min Q(s, a) where a is the deterministic policy action.
    """
    inference_fn = runner.get_inference_policy()
    alg = runner.alg
    alg.eval()

    obs_groups = runner.cfg.get("obs_groups", {})
    actor_keys = obs_groups.get("policy", ["policy"])
    critic_keys = obs_groups.get("critic", actor_keys)

    def _flat_obs(obs_td, keys):
        from tensordict import TensorDict
        if isinstance(obs_td, (TensorDict, dict)):
            return torch.cat([obs_td[k] for k in keys if k in obs_td], dim=-1)
        return obs_td

    def action_fn(obs_td):
        return inference_fn(obs_td)

    def value_fn(obs_td):
        actor_obs = _flat_obs(obs_td, actor_keys)
        critic_obs = _flat_obs(obs_td, critic_keys)
        # Get deterministic action
        if alg.obs_normalization:
            norm_actor_obs = alg.obs_normalizer(actor_obs, update=False)
        else:
            norm_actor_obs = actor_obs
        actions = alg.actor.explore(norm_actor_obs, deterministic=True)
        # Compute Q-values and take min across ensemble
        if alg.obs_normalization:
            norm_critic_obs = alg.critic_obs_normalizer(critic_obs, update=False)
        else:
            norm_critic_obs = critic_obs
        q_logits = alg.qnet_target(norm_critic_obs, actions)  # (num_q, batch, num_atoms)
        q_probs = torch.softmax(q_logits, dim=-1)
        q_values = alg.qnet_target.get_value(q_probs)  # (num_q, batch)
        return q_values.min(dim=0).values  # (batch,)

    return action_fn, value_fn


def main(cfg: Union[bcppo_configs.BCPPORunConfig, sac_configs.SACRunConfig],
         num_rounds: int = 1, num_bins: int = 10):
    """Evaluate a BC+PPO or SAC checkpoint and report success rate by reset progress.

    Launches Isaac Sim, builds the environment and runner, and loads a
    checkpoint. The checkpoint is ``cfg.resume`` (a local path or wandb run
    path), or else the highest-numbered ``model_*.pt`` in the run directory.

    It then runs ``num_rounds`` rounds. Each round resets all envs and records
    the value estimate at the reset state. It then steps the policy until every
    env has finished its first episode or ``max_episode_length`` steps have
    elapsed. An env counts as a success if the "success" termination term was
    set on the step its first episode ended.

    Results are printed as a per-bin table (equal-width bins over [0, 1]; the
    last bin is closed so progress == 1.0 is included). They are also written
    to ``eval_by_timestep.csv`` and ``eval_by_timestep.png`` in the run
    directory.

    Args:
        cfg: Run configuration; a ``SACRunConfig`` selects the SAC path,
            anything else the BC+PPO path.
        num_rounds: Number of parallel reset-and-rollout rounds.
        num_bins: Number of equal-width reset-progress bins.

    Raises:
        ValueError: if no checkpoint is given and none is found in the run dir.
        RuntimeError: if dataset reset is enabled but the ``ResetFromDataset``
            event term cannot be found.
    """
    is_sac = isinstance(cfg, sac_configs.SACRunConfig)

    # >>>> Isaac Sim App Launcher <<<<
    parser = argparse.ArgumentParser()
    args_cli, _ = parser.parse_known_args()
    args_cli.enable_cameras = cfg.video
    args_cli.headless = cfg.headless
    app_launcher = AppLauncher(args_cli)
    simulation_app = app_launcher.app
    # >>>> Isaac Sim App Launcher <<<<

    # Deferred until after the app is launched, like the other Isaac Lab
    # imports in this file.
    from sim_improvement.environments.lbm.scenario_helper import ResetFromDataset

    env, rl_dir = _setup_env(cfg, simulation_app)

    # --- Create runner & load checkpoint ---
    if is_sac:
        runner = _build_sac_runner(cfg, env)
    else:
        runner = _build_bcppo_runner(cfg, env)

    if not cfg.resume:
        # Fall back to the highest-numbered model_<N>.pt in the run directory.
        import re
        ckpts = sorted(
            rl_dir.glob("model_*.pt"),
            key=lambda p: int(m.group(1)) if (m := re.search(r"model_(\d+)\.pt", p.name)) else -1,
        )
        if not ckpts:
            raise ValueError(f"No --resume specified and no model_*.pt found in {rl_dir}")
        ckpt_path = str(ckpts[-1])
        print(f"[INFO] No checkpoint specified, using latest: {ckpt_path}")
    else:
        ckpt_path = resolve_checkpoint(cfg.resume)
    print(f"[INFO] Loading checkpoint: {ckpt_path}")
    runner.load(ckpt_path, load_optimizer=False)

    # --- Find the ResetFromDataset instance (if dataset reset is enabled) ---
    reset_term = None
    if cfg.reset_from_dataset:
        # Reads the event manager's private per-mode term list to get the
        # ResetFromDataset instance, whose sampled progress is read each round.
        for term_cfg in env.unwrapped.event_manager._mode_term_cfgs.get("reset", []):
            if isinstance(term_cfg.func, ResetFromDataset):
                reset_term = term_cfg.func
                break
        if reset_term is None:
            raise RuntimeError("reset_from_dataset=True but could not find ResetFromDataset in event manager")
        print(f"[INFO] Dataset has {reset_term._num_states} states, "
              f"progress range [{reset_term._progress.min().item():.2f}, "
              f"{reset_term._progress.max().item():.2f}]")
    else:
        print("[INFO] Stochastic reset — progress tracking disabled (all progress values will be 0)")

    # --- Get policy functions ---
    if is_sac:
        action_fn, value_fn = _get_sac_policy_fns(runner)
    else:
        action_fn, value_fn = _get_bcppo_policy_fns(runner)

    # --- Run eval rounds ---
    num_envs = env.num_envs
    max_steps = env.max_episode_length
    total_episodes = num_rounds * num_envs

    all_progress = torch.zeros(total_episodes, device=env.device, dtype=torch.float32)
    all_success = torch.zeros(total_episodes, device=env.device, dtype=torch.bool)
    all_values = torch.zeros(total_episodes, device=env.device, dtype=torch.float32)

    print(f"[INFO] Running {num_rounds} round(s) x {num_envs} envs = {total_episodes} episodes")

    for rnd in range(num_rounds):
        env.reset()
        obs_td = env.get_observations()

        if reset_term is not None:
            reset_prog = reset_term._last_sampled_progress.clone()  # (num_envs,)
        else:
            reset_prog = torch.zeros(num_envs, device=env.device)

        # Compute value function estimate at the reset state
        with torch.no_grad():
            reset_values = value_fn(obs_td)  # (num_envs,)

        env_done = torch.zeros(num_envs, device=env.device, dtype=torch.bool)
        env_success = torch.zeros(num_envs, device=env.device, dtype=torch.bool)

        for step in range(max_steps):
            with torch.no_grad():
                actions = action_fn(obs_td)

            obs_td, _, dones, _ = env.step(actions)
            last_success = env.unwrapped.termination_manager.get_term("success")
            dones = dones.bool()

            if dones.any():
                # Only each env's first episode counts; later dones are ignored.
                newly_done = dones & ~env_done
                env_done |= dones
                env_success |= newly_done & last_success

            if env_done.all():
                break

        start = rnd * num_envs
        all_progress[start:start + num_envs] = reset_prog
        all_success[start:start + num_envs] = env_success
        all_values[start:start + num_envs] = reset_values

        round_sr = env_success.float().mean().item()
        print(f"  Round {rnd + 1}/{num_rounds}: {env_done.sum().item()}/{num_envs} done, "
              f"success rate {round_sr:.1%}")

    # --- Analyze results ---
    progress = all_progress.cpu().numpy()
    successes = all_success.cpu().numpy()
    values = all_values.cpu().numpy()

    bin_edges = np.linspace(0.0, 1.0, num_bins + 1)

    print(f"\n{'='*60}")
    print(f"Success rate by reset progress ({total_episodes} episodes)")
    print(f"{'='*60}")
    print(f"{'Progress Range':>20s}  {'Success':>8s}  {'Total':>6s}  {'Rate':>8s}")
    print(f"{'-'*20}  {'-'*8}  {'-'*6}  {'-'*8}")

    for i in range(num_bins):
        lo, hi = bin_edges[i], bin_edges[i + 1]
        # Bins are half-open [lo, hi) except the last, which is closed [lo, hi]
        # so episodes reset at progress == 1.0 are not dropped. This matches
        # np.histogram, which the plot uses for the same bin edges.
        is_last = i == num_bins - 1
        below_hi = (progress <= hi) if is_last else (progress < hi)
        mask = (progress >= lo) & below_hi
        total = mask.sum()
        if total == 0:
            continue
        succ = successes[mask].sum()
        rate = succ / total
        closing = "]" if is_last else ")"
        print(f"  [{lo:.0%}, {hi:.0%}{closing}{' ':>10s}  {succ:8d}  {total:6d}  {rate:8.1%}")

    overall = successes.mean()
    print(f"{'-'*20}  {'-'*8}  {'-'*6}  {'-'*8}")
    print(f"  {'Overall':>18s}  {successes.sum():8d}  {total_episodes:6d}  {overall:8.1%}")

    # Save CSV
    csv_path = rl_dir / "eval_by_timestep.csv"
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["reset_progress", "success", "value"])
        for p, s, v in zip(progress, successes, values):
            writer.writerow([f"{p:.4f}", int(s), f"{v:.4f}"])
    print(f"[INFO] Raw results saved to {csv_path}")

    # Plot
    plot_path = str(rl_dir / "eval_by_timestep.png")
    plot_results(progress, successes, values, num_bins, plot_path)

    env.close()
    simulation_app.close()


if __name__ == "__main__":
    import tyro

    @dataclasses.dataclass(frozen=True)
    class EvalBCPPOConfig:
        """Eval with a BC+PPO checkpoint."""
        bcppo: bcppo_configs.BCPPORunConfig
        num_rounds: int = 1
        num_bins: int = 10

    @dataclasses.dataclass(frozen=True)
    class EvalSACConfig:
        """Eval with a SAC checkpoint."""
        sac: sac_configs.SACRunConfig
        num_rounds: int = 1
        num_bins: int = 10

    def _prefixed(prefix, registry, wrap):
        # Map "<prefix>:<name>" -> (description, wrapped default config) for tyro.
        return {f"{prefix}:{name}": (f"{prefix}:{name}", wrap(run_cfg))
                for name, run_cfg in registry.items()}

    # The algorithm prefix keeps BC+PPO and SAC config names apart; on a clash
    # the SAC entry (merged second) would win.
    defaults = {
        **_prefixed("bcppo", bcppo_configs.get_bcppo_configs(),
                    lambda run_cfg: EvalBCPPOConfig(bcppo=run_cfg)),
        **_prefixed("sac", sac_configs.get_sac_configs(),
                    lambda run_cfg: EvalSACConfig(sac=run_cfg)),
    }

    eval_cfg = tyro.extras.overridable_config_cli(defaults)

    selected = eval_cfg.bcppo if isinstance(eval_cfg, EvalBCPPOConfig) else eval_cfg.sac
    main(selected, num_rounds=eval_cfg.num_rounds, num_bins=eval_cfg.num_bins)
