"""Evaluate a SAC checkpoint and report success rate.

Downloads a checkpoint from wandb (or uses a local path), rolls out
episodes with the default stochastic reset, reports the success rate,
writes per-episode results to a CSV file, and plots a scatter of initial
object XY positions colored by success/failure.

Usage:
    python scripts/rl/eval_sac.py kiwimanip_rlpd \
        --wandb-run entity/project/run_id \
        --wandb-file model_50000.pt \
        --sac.num-envs 2048 --num-eval-episodes 4096
"""

import argparse
import csv
import dataclasses
from pathlib import Path

import numpy as np
import torch
import gymnasium as gym
import tyro

from isaaclab.app import AppLauncher
import sim_improvement.rl.sac_configs as sac_configs

# Allow TensorFloat-32 kernels for CUDA matmuls and cuDNN ops: faster on GPUs
# that support TF32, at slightly reduced float32 precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Do not force deterministic cuDNN algorithms, and keep the cuDNN autotuner off.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False


@dataclasses.dataclass(frozen=True)
class EvalConfig:
    """Evaluation configuration.

    The checkpoint to evaluate is resolved in ``main`` with this precedence:
    ``checkpoint`` (local path), then ``wandb_run`` (downloaded), then
    ``sac.resume``.
    """

    # Base SAC run configuration (environment, device, num_envs, logging, video).
    sac: sac_configs.SACRunConfig
    wandb_run: str = ""
    """Wandb run path (entity/project/run_id). Overrides sac.resume if set."""
    wandb_file: str = ""
    """Filename in wandb to download (e.g. model_50000.pt). If empty, uses latest."""
    checkpoint: str = ""
    """Local checkpoint path. Overrides wandb_run if set."""
    num_eval_episodes: int = 4096
    """Total number of episodes to evaluate."""


def download_checkpoint(wandb_run: str, wandb_file: str = "") -> str:
    """Download a checkpoint from wandb, returning a local path.

    Files are cached under
    ``~/.cache/wandb_checkpoints/<entity>/<project>/<run_id>/`` and reused on
    later calls. The cache is keyed by the full run path because a run id on
    its own is not guaranteed to be unique across projects.

    Args:
        wandb_run: Run path in the form ``entity/project/run_id``.
        wandb_file: Exact name of the run file to fetch. If empty, the
            ``model_<N>.pt`` file with the largest ``N`` is selected.

    Returns:
        Local filesystem path of the (possibly cached) checkpoint.

    Raises:
        ValueError: If ``wandb_run`` is not three non-empty ``/``-separated parts.
        FileNotFoundError: If the run has no matching file.
    """
    # Validate first so malformed input fails fast, before wandb is imported
    # or any network request is made.
    parts = wandb_run.split("/")
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"wandb_run must be entity/project/run_id, got: {wandb_run}")
    entity, project, run_id = parts

    import re
    import wandb

    api = wandb.Api()
    run = api.run(f"{entity}/{project}/{run_id}")

    if wandb_file:
        # First file whose name matches exactly.
        target = next((f for f in run.files() if f.name == wandb_file), None)
        if target is None:
            raise FileNotFoundError(f"File '{wandb_file}' not found in wandb run {wandb_run}")
    else:
        candidates = [
            f for f in run.files()
            if f.name.startswith("model_") and f.name.endswith(".pt")
        ]
        if not candidates:
            raise FileNotFoundError(f"No model_*.pt files found in wandb run {wandb_run}")

        iter_pattern = re.compile(r"model_(\d+)\.pt")

        def _iter_num(f):
            # Names without a numeric iteration sort before all numbered ones.
            m = iter_pattern.search(f.name)
            return int(m.group(1)) if m else -1

        # Stable sort + last element: the highest iteration wins.
        target = sorted(candidates, key=_iter_num)[-1]

    cache_dir = Path.home() / ".cache" / "wandb_checkpoints" / entity / project / run_id
    cache_dir.mkdir(parents=True, exist_ok=True)
    local_path = cache_dir / target.name

    if local_path.exists():
        print(f"[wandb] Using cached: {local_path}")
    else:
        print(f"[wandb] Downloading {target.name} from {entity}/{project}/{run_id}...")
        target.download(root=str(cache_dir), replace=True)

    return str(local_path)


def get_object_xy(env) -> dict[str, torch.Tensor]:
    """Collect the XY position of every rigid object relative to its env origin.

    Args:
        env: Environment whose ``unwrapped.scene`` exposes ``env_origins`` and
            a ``rigid_objects`` mapping.

    Returns:
        Dict mapping object name to (num_envs, 2) tensor of XY positions.
    """
    from isaaclab.assets.rigid_object import RigidObject

    scene = env.unwrapped.scene
    env_origins = scene.env_origins  # (num_envs, 3)

    # World position minus env origin gives env-local coordinates; keep only
    # the first two columns (X, Y) and clone so the result is a detached copy
    # of the simulator buffer.
    return {
        obj_name: (asset.data.root_pos_w - env_origins)[:, :2].clone()
        for obj_name, asset in scene.rigid_objects.items()
        if isinstance(asset, RigidObject)
    }


def plot_xy_scatter(
    obj_positions: dict[str, np.ndarray],
    success: np.ndarray,
    out_path: str,
) -> None:
    """Scatter plot of initial object XY positions, colored by success/failure.

    One subplot is drawn per object. Failures are drawn first so successes
    are rendered on top of them.

    Args:
        obj_positions: Maps object name to an array indexed as ``xy[i, 0]`` /
            ``xy[i, 1]`` (X and Y of episode ``i``).
        success: Per-episode success flags, aligned with the rows of each
            position array. Coerced to ``bool`` before use.
        out_path: Destination image path.
    """
    # Nothing to draw: return before importing matplotlib or switching backend.
    if not obj_positions:
        print("[WARN] No rigid objects found in scene, skipping XY plot.")
        return

    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend, no display required
    import matplotlib.pyplot as plt

    # Force a boolean mask so `~` is a logical (not bitwise) negation even if
    # the caller passes 0/1 integers. The mask is the same for every object.
    success = np.asarray(success, dtype=bool)
    fail_mask = ~success

    n_objects = len(obj_positions)
    fig, axes = plt.subplots(1, n_objects, figsize=(6 * n_objects, 5), squeeze=False)

    try:
        for ax, (name, xy) in zip(axes[0], obj_positions.items()):
            ax.scatter(xy[fail_mask, 0], xy[fail_mask, 1],
                       c="tomato", s=4, alpha=0.3, label="Fail", rasterized=True)
            ax.scatter(xy[success, 0], xy[success, 1],
                       c="seagreen", s=4, alpha=0.3, label="Success", rasterized=True)
            ax.set_xlabel("X (m)")
            ax.set_ylabel("Y (m)")
            ax.set_title(name)
            ax.set_aspect("equal")
            ax.legend(markerscale=3)

        # Avoid numpy's empty-mean warning; the title then shows "nan%".
        sr = success.mean() if success.size else float("nan")
        fig.suptitle(f"Initial Object Positions (N={len(success)}, SR={sr:.1%})", y=1.02)
        fig.tight_layout()
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
    finally:
        # Close this specific figure even if drawing or saving failed.
        plt.close(fig)

    print(f"[INFO] XY scatter plot saved to {out_path}")


def main(cfg: EvalConfig) -> None:
    """Roll out a SAC checkpoint and report its success rate.

    Launches the Isaac Sim app, builds the environment described by
    ``cfg.sac``, loads the resolved checkpoint into an ``OffPolicyRunner`` and
    evaluates it in rounds of ``num_envs`` parallel episodes. The number of
    evaluated episodes is ``cfg.num_eval_episodes`` rounded up to a multiple
    of ``num_envs``. Per-episode results go to ``eval_results.csv`` and the
    scatter plot to ``eval_initial_xy.png``, both under
    ``<sac.log_dir>/rl/<sac.name>``.

    Args:
        cfg: Evaluation configuration.

    Raises:
        ValueError: If ``cfg.num_eval_episodes`` is not positive, or if none
            of ``checkpoint``, ``wandb_run`` or ``sac.resume`` is set.
    """
    sac_cfg = cfg.sac

    # Fail before the slow simulator start-up: with zero rounds there would be
    # no per-round results to concatenate in the results section below.
    if cfg.num_eval_episodes <= 0:
        raise ValueError(
            f"num_eval_episodes must be positive, got: {cfg.num_eval_episodes}"
        )

    # >>>> Isaac Sim App Launcher <<<<
    # The parser defines no arguments; it only provides the namespace object
    # that AppLauncher reads its settings from.
    parser = argparse.ArgumentParser()
    args_cli, _ = parser.parse_known_args()
    args_cli.enable_cameras = sac_cfg.video
    args_cli.headless = sac_cfg.headless
    app_launcher = AppLauncher(args_cli)
    simulation_app = app_launcher.app
    # >>>> Isaac Sim App Launcher <<<<

    # These imports are deliberately placed after the app launch above.
    from isaaclab_tasks.utils import parse_env_cfg
    import sim_improvement.environments  # noqa: F401
    from sim_improvement.rl.test import RslRlVecEnvWrapper
    from sim_improvement.rl.runners import OffPolicyRunner

    # --- Resolve checkpoint ---
    # Precedence: local path > wandb download > sac.resume.
    if cfg.checkpoint:
        ckpt_path = cfg.checkpoint
    elif cfg.wandb_run:
        ckpt_path = download_checkpoint(cfg.wandb_run, cfg.wandb_file)
    elif sac_cfg.resume:
        ckpt_path = sac_cfg.resume
    else:
        raise ValueError("Must provide --checkpoint, --wandb-run, or --sac.resume")
    print(f"[INFO] Checkpoint: {ckpt_path}")

    # --- Env setup (mirrors train_sac, but no curriculum/dataset resets) ---
    env_cfg = parse_env_cfg(
        sac_cfg.environment,
        device=sac_cfg.device,
        num_envs=sac_cfg.num_envs,
        use_fabric=True,
    )
    env_cfg.dynamic_setup(
        scene_path=sac_cfg.scene_path,
        library_dir=sac_cfg.library_dir,
    )

    log_dir = Path(sac_cfg.log_dir) / "rl" / sac_cfg.name
    log_dir.mkdir(parents=True, exist_ok=True)
    env_cfg.log_dir = str(log_dir)
    env_cfg.io_descriptors_output_dir = str(log_dir / "io_descriptors")
    if sac_cfg.video:
        env_cfg.sim.render_interval = 1

    env = gym.make(sac_cfg.environment, cfg=env_cfg,
                   render_mode="rgb_array" if sac_cfg.video else None)
    if sac_cfg.video:
        video_kwargs = {
            "video_folder": str(log_dir / "videos" / "eval"),
            # Trigger a recording at step 0 only.
            "step_trigger": lambda step: step == 0,
            "video_length": sac_cfg.video_length,
            "disable_logger": True,
        }
        print("[INFO] Recording video.")
        env = gym.wrappers.RecordVideo(env, **video_kwargs)
    env = RslRlVecEnvWrapper(env)

    # --- Build runner & load checkpoint ---
    sac_alg_dict = dataclasses.asdict(sac_cfg.sac)
    train_cfg = {
        "seed": sac_cfg.seed,
        "device": sac_cfg.device,
        "algorithm": sac_alg_dict,
        "obs_groups": {
            "policy": sac_cfg.actor_obs_groups,
            "critic": sac_cfg.critic_obs_groups,
        },
        "save_interval": sac_cfg.save_interval,
        "log_interval": sac_cfg.log_interval,
        "max_iterations": sac_cfg.max_iterations,
        "experiment_name": sac_cfg.experiment_name or sac_cfg.name,
        "logger": "tensorboard",
        "wandb_project": sac_cfg.wandb_project,
        "num_steps_per_env": 1,
    }

    # log_dir=None: the runner is only used for inference here.
    runner = OffPolicyRunner(env, train_cfg, log_dir=None, device=sac_cfg.device)
    runner.load(ckpt_path, load_optimizer=False)
    policy_fn = runner.get_inference_policy()

    # --- Run eval ---
    num_envs = env.num_envs
    max_steps = env.max_episode_length
    # Ceil division: every round runs all envs, so the total is rounded up.
    num_rounds = (cfg.num_eval_episodes + num_envs - 1) // num_envs
    total_episodes = num_rounds * num_envs

    # Per-object: list of (num_envs, 2) arrays across rounds
    all_obj_xy: dict[str, list[np.ndarray]] = {}
    all_success = []

    print(f"[INFO] Running {num_rounds} round(s) x {num_envs} envs = {total_episodes} episodes")
    print(f"[INFO] Max steps per episode: {max_steps}")

    for rnd in range(num_rounds):
        env.reset()
        obs_td = env.get_observations()

        # Capture initial object XY positions
        init_xy = get_object_xy(env)
        for name, xy in init_xy.items():
            all_obj_xy.setdefault(name, []).append(xy.cpu().numpy())

        env_done = torch.zeros(num_envs, device=env.device, dtype=torch.bool)
        env_success = torch.zeros(num_envs, device=env.device, dtype=torch.bool)

        for _ in range(max_steps):
            with torch.no_grad():
                actions = policy_fn(obs_td)
            obs_td, _, dones, _ = env.step(actions)

            dones = dones.bool()
            if not dones.any():
                continue

            # Only the first episode of each env counts within a round; later
            # "done" signals from the same env are ignored.
            newly_done = dones & ~env_done
            if newly_done.any():
                # presumably a per-env boolean mask for the termination term
                # named "success" — TODO confirm against the env definition
                success = env.unwrapped.termination_manager.get_term("success")
                env_success |= newly_done & success
            env_done |= dones

            if env_done.all():
                break

        all_success.append(env_success.cpu())

        sr = env_success.float().mean().item()
        done_count = env_done.sum().item()
        print(f"  Round {rnd + 1}/{num_rounds}: {done_count}/{num_envs} done, "
              f"success rate {sr:.1%}")

    # --- Results ---
    success = torch.cat(all_success).numpy().astype(bool)

    overall_sr = success.mean()
    print(f"\n{'=' * 50}")
    print(f"Overall success rate: {overall_sr:.1%} ({success.sum()}/{len(success)})")
    print(f"{'=' * 50}")

    # Save CSV (per-episode success + initial object positions)
    obj_names = sorted(all_obj_xy.keys())
    obj_xy_np = {name: np.concatenate(all_obj_xy[name], axis=0) for name in obj_names}

    csv_path = log_dir / "eval_results.csv"
    # Column layout: success, then every object's X, then every object's Y.
    header = ["success"] + [f"{name}_x" for name in obj_names] + [f"{name}_y" for name in obj_names]
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for i, episode_ok in enumerate(success):
            row = [int(episode_ok)]
            row.extend(f"{obj_xy_np[name][i, 0]:.4f}" for name in obj_names)
            row.extend(f"{obj_xy_np[name][i, 1]:.4f}" for name in obj_names)
            writer.writerow(row)
    print(f"[INFO] Results saved to {csv_path}")

    # Plot XY scatter
    plot_path = str(log_dir / "eval_initial_xy.png")
    plot_xy_scatter(obj_xy_np, success, plot_path)

    env.close()
    simulation_app.close()


if __name__ == "__main__":
    # Build one CLI entry per registered SAC config. Each value is a
    # (description, default config) pair; the config name is reused as its
    # description.
    named_defaults = {}
    for cfg_name, run_cfg in sac_configs.get_sac_configs().items():
        named_defaults[cfg_name] = (cfg_name, EvalConfig(sac=run_cfg))

    # Parse the command line on top of the selected default, then evaluate.
    main(tyro.extras.overridable_config_cli(named_defaults))
