"""Off-policy runner for SAC and similar algorithms.

Mirrors rsl_rl's OnPolicyRunner interface but handles off-policy specifics:
- Replay buffer instead of rollout storage
- No compute_returns
- Per-step updates (not per-epoch)
- SAC-specific save/load

Supports distributed training, wandb/tensorboard logging, and checkpointing.
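No dependency on rsl_rl internals (Logger, RolloutStorage, etc.); only the optional
wandb writer is imported lazily from rsl_rl.utils.wandb_utils when it is selected.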
"""

from __future__ import annotations

import os
import statistics
import time
from collections import deque

import torch
import tqdm
from tensordict import TensorDict

from sim_improvement.rl.algs.sac import SAC


# ---------------------------------------------------------------------------
# Lightweight logger (replaces rsl_rl.utils.logger.Logger)
# ---------------------------------------------------------------------------

class _Logger:
    """Minimal training logger with wandb/tensorboard/console support.

    Tracks per-environment episode returns and lengths, collects the
    ``extras`` dictionaries reported by the environment, and periodically
    writes scalars to a summary writer and prints a text summary.

    Nothing is tracked or written when ``log_dir`` is ``None``. In distributed
    runs only global rank 0 creates a writer and prints; every other rank has
    ``disable_logs`` set.

    Args:
        log_dir: Directory handed to the summary writer, or ``None`` to
            disable logging entirely.
        cfg: Training configuration. ``cfg["logger"]`` selects the writer
            (``"wandb"`` or ``"tensorboard"``; tensorboard is the default and
            the fallback for any other value).
        env_cfg: Environment configuration, forwarded to the writer's
            ``store_config`` when the wandb writer is active.
        num_envs: Number of parallel environments on this process.
        is_distributed: Whether training runs on more than one process.
        gpu_world_size: Total number of processes.
        gpu_global_rank: Global rank of this process.
        device: Device on which the episode accumulators are allocated.
    """

    def __init__(
        self,
        log_dir: str | None,
        cfg: dict,
        env_cfg,
        num_envs: int,
        is_distributed: bool,
        gpu_world_size: int,
        gpu_global_rank: int,
        device: str,
    ):
        self.log_dir = log_dir
        self.cfg = cfg
        self.num_envs = num_envs
        self.gpu_world_size = gpu_world_size
        self.device = device
        self.tot_timesteps = 0
        self.tot_time = 0.0
        # Bookkeeping so that step counts and the ETA stay correct when log()
        # is only called every few iterations.
        self._last_logged_it: int | None = None
        self._timed_iterations = 0

        self.ep_extras: list[dict] = []
        self.rewbuffer: deque = deque(maxlen=100)
        self.lenbuffer: deque = deque(maxlen=100)
        self.cur_reward_sum = torch.zeros(num_envs, dtype=torch.float, device=device)
        self.cur_episode_length = torch.zeros(num_envs, dtype=torch.float, device=device)

        self.disable_logs = is_distributed and gpu_global_rank != 0
        self.writer = None
        self.logger_type = "tensorboard"
        self._init_writer()

        if self.writer is not None and self.logger_type == "wandb":
            try:
                self.writer.store_config(env_cfg, cfg, cfg.get("algorithm", {}), {})
            except Exception as exc:
                # Storing the config is best-effort and must not abort
                # training, but the failure should not go unnoticed either.
                print(f"[_Logger] Could not store the run configuration: {exc!r}")

    # --- Writer setup ---

    def _init_writer(self):
        """Create the summary writer selected by ``cfg["logger"]``.

        Leaves ``self.writer`` as ``None`` when logging is disabled. After the
        call ``self.logger_type`` always names the writer that was actually
        created, so later capability checks cannot disagree with the writer.
        """
        if self.log_dir is None or self.disable_logs:
            return

        requested = str(self.cfg.get("logger") or "tensorboard").lower()
        if requested == "wandb":
            from rsl_rl.utils.wandb_utils import WandbSummaryWriter

            self.writer = WandbSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
            self.logger_type = "wandb"
            return

        if requested != "tensorboard":
            print(f"[_Logger] Unsupported logger '{requested}', falling back to tensorboard.")
        from torch.utils.tensorboard import SummaryWriter

        self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
        self.logger_type = "tensorboard"

    # --- Per-step tracking ---

    def process_env_step(self, rewards, dones, extras, _intrinsic=None):
        """Accumulate episode statistics for one environment step.

        Args:
            rewards: Per-environment rewards, added to the running returns.
            dones: Per-environment done flags; entries ``> 0`` end an episode.
            extras: Environment info dict. Its ``"episode"`` entry (or,
                failing that, its ``"log"`` entry) is queued for the next
                call to :meth:`log`.
            _intrinsic: Unused; accepted for signature compatibility.
        """
        if self.log_dir is None:
            return

        # Ranks with logging disabled never reach the clear() in log(), so
        # queuing extras there would grow the list without bound.
        if not self.disable_logs:
            if "episode" in extras:
                self.ep_extras.append(extras["episode"])
            elif "log" in extras:
                self.ep_extras.append(extras["log"])

        self.cur_reward_sum += rewards
        self.cur_episode_length += 1

        # Indices of the environments whose episode just finished.
        done_ids = (dones > 0).nonzero(as_tuple=False)[:, 0]
        self.rewbuffer.extend(self.cur_reward_sum[done_ids].tolist())
        self.lenbuffer.extend(self.cur_episode_length[done_ids].tolist())
        self.cur_reward_sum[done_ids] = 0
        self.cur_episode_length[done_ids] = 0

    # --- Periodic log ---

    def log(self, *, it, start_it, total_it, collect_time, learn_time, loss_dict, learning_rate, **_kw):
        """Write scalars for iteration ``it`` and print a console summary.

        Args:
            it: Current iteration index; used as the scalar step.
            start_it: First iteration index of the current ``learn`` call.
            total_it: Iteration index at which the current ``learn`` call ends.
            collect_time: Seconds spent collecting data in this iteration.
            learn_time: Seconds spent updating in this iteration.
            loss_dict: Mapping of loss name to scalar value.
            learning_rate: Learning rate to report.
            **_kw: Ignored; accepted for signature compatibility.
        """
        if self.log_dir is None or self.disable_logs:
            return

        collection_size = self.num_envs * self.gpu_world_size
        iteration_time = collect_time + learn_time

        # log() may only be called every few iterations, so credit every
        # iteration that elapsed since the previous call.
        previous_it = start_it - 1 if self._last_logged_it is None else self._last_logged_it
        self.tot_timesteps += collection_size * max(it - previous_it, 1)
        self._last_logged_it = it

        # Only the timings of logged iterations are known; their mean is the
        # per-iteration estimate used for the ETA.
        self.tot_time += iteration_time
        self._timed_iterations += 1

        fps = int(collection_size / max(iteration_time, 1e-6))
        has_episodes = bool(self.rewbuffer)
        mean_reward = statistics.mean(self.rewbuffer) if has_episodes else 0.0
        mean_length = statistics.mean(self.lenbuffer) if has_episodes else 0.0

        if self.writer is not None:
            self._write_episode_extras(it)
            for key, value in loss_dict.items():
                self.writer.add_scalar(f"Loss/{key}", value, it)
            self.writer.add_scalar("Loss/learning_rate", learning_rate, it)
            self.writer.add_scalar("Perf/total_fps", fps, it)
            self.writer.add_scalar("Perf/collection_time", collect_time, it)
            self.writer.add_scalar("Perf/learning_time", learn_time, it)
            if has_episodes:
                self.writer.add_scalar("Train/mean_reward", mean_reward, it)
                self.writer.add_scalar("Train/mean_episode_length", mean_length, it)
        self.ep_extras.clear()

        # Console
        pad = 35
        lines = [
            "=" * 70,
            f"  Iteration {it}/{total_it}",
            f"{'Total steps:':>{pad}} {self.tot_timesteps}",
            f"{'FPS:':>{pad}} {fps}",
        ]
        lines.extend(f"{f'{k}:':>{pad}} {v:.4f}" for k, v in loss_dict.items())
        if has_episodes:
            lines.append(f"{'Mean reward:':>{pad}} {mean_reward:.2f}")
            lines.append(f"{'Mean ep length:':>{pad}} {mean_length:.1f}")
        done_it = max(it + 1 - start_it, 1)
        remaining_it = max(total_it - start_it - done_it, 0)
        eta = self.tot_time / self._timed_iterations * remaining_it
        lines.append(f"{'ETA:':>{pad}} {self._format_duration(eta)}")
        print("\n".join(lines) + "\n")

    def _write_episode_extras(self, it):
        """Write the mean of every queued episode-extra key as a scalar."""
        # Union of keys in first-seen order, so keys missing from the first
        # queued dict are still reported.
        keys = dict.fromkeys(key for entry in self.ep_extras for key in entry)
        for key in keys:
            # Flatten every value so tensors of any rank and plain numbers
            # can be averaged together.
            chunks = [
                torch.as_tensor(entry[key], dtype=torch.float, device=self.device).reshape(-1)
                for entry in self.ep_extras
                if key in entry
            ]
            tag = key if "/" in key else f"Episode/{key}"
            self.writer.add_scalar(tag, torch.cat(chunks).mean().item(), it)

    @staticmethod
    def _format_duration(seconds) -> str:
        """Format ``seconds`` as ``HH:MM:SS`` without wrapping at 24 hours."""
        whole_seconds = max(int(seconds), 0)
        hours, remainder = divmod(whole_seconds, 3600)
        minutes, secs = divmod(remainder, 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"

    def save_model(self, path, it):
        """Hand the checkpoint at ``path`` to the writer when wandb is active."""
        if self.writer is not None and self.logger_type == "wandb":
            self.writer.save_model(path, it)


# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------

class OffPolicyRunner:
    """Off-policy training runner compatible with rsl_rl VecEnv wrappers.

    Each learning iteration steps the environment once, hands the resulting
    transition to the algorithm's ``store_transition`` and then calls the
    algorithm's ``update``.

    Args:
        env: Vectorized environment. The runner uses its ``get_observations``,
            ``step``, ``num_envs``, ``num_actions``, ``cfg`` and ``device``.
        train_cfg: Training configuration. ``train_cfg["algorithm"]`` holds
            the keyword arguments for the algorithm; ``obs_groups``,
            ``save_interval`` and ``log_interval`` are optional keys read by
            the runner. The whole dict is also passed to the logger.
        log_dir: Directory for logs and checkpoints, or ``None`` to disable
            both.
        device: Device on which observations, rewards and dones are processed.
    """

    def __init__(
        self,
        env,
        train_cfg: dict,
        log_dir: str | None = None,
        device: str = "cpu",
    ):
        self.cfg = train_cfg
        self.alg_cfg = train_cfg["algorithm"]
        self.device = device
        self.env = env

        # Multi-GPU
        self._configure_multi_gpu()

        # Query observations once to derive the network input dimensions.
        obs = self.env.get_observations()
        actor_obs, critic_obs = self._resolve_obs(obs)
        actor_obs_dim = actor_obs.shape[-1]
        critic_obs_dim = critic_obs.shape[-1]

        # Construct algorithm
        self.alg = self._construct_algorithm(
            actor_obs_dim=actor_obs_dim,
            critic_obs_dim=critic_obs_dim,
            n_act=self.env.num_actions,
            num_envs=self.env.num_envs,
        )

        # Logger
        self.logger = _Logger(
            log_dir=log_dir,
            cfg=self.cfg,
            env_cfg=self.env.cfg,
            num_envs=self.env.num_envs,
            is_distributed=self.is_distributed,
            gpu_world_size=self.gpu_world_size,
            gpu_global_rank=self.gpu_global_rank,
            device=self.device,
        )

        self.current_learning_iteration = 0

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False):
        """Run the collect/update loop for ``num_learning_iterations`` steps.

        Args:
            num_learning_iterations: Number of iterations to run, counted from
                ``self.current_learning_iteration``.
            init_at_random_ep_len: If true, randomize the environment's
                ``episode_length_buf`` before training starts.
        """
        if init_at_random_ep_len:
            self.env.episode_length_buf = torch.randint_like(
                self.env.episode_length_buf, high=int(self.env.max_episode_length)
            )

        obs_td = self.env.get_observations().to(self.device)
        actor_obs, critic_obs = self._resolve_obs(obs_td)
        self.alg.train()

        if self.is_distributed:
            print(f"Synchronizing parameters for rank {self.gpu_global_rank}...")
            self.alg.broadcast_parameters()

        start_it = self.current_learning_iteration
        total_it = start_it + num_learning_iterations
        save_interval = self.cfg.get("save_interval", 500)
        log_interval = self.cfg.get("log_interval", 10)
        # Only the logging rank writes checkpoints; otherwise every rank would
        # write the same file concurrently.
        can_write = self.logger.log_dir is not None and not self.logger.disable_logs

        pbar = tqdm.tqdm(total=total_it, initial=start_it, disable=(self.gpu_global_rank != 0))
        try:
            for it in range(start_it, total_it):
                t0 = time.time()

                # --- Collect one transition ---
                with torch.no_grad():
                    actions = self.alg.act(actor_obs, critic_obs)
                next_obs_td, rewards, dones, extras = self.env.step(actions.to(self.env.device))
                next_obs_td = next_obs_td.to(self.device)
                rewards = rewards.to(self.device)
                dones = dones.to(self.device)
                next_actor_obs, next_critic_obs = self._resolve_obs(next_obs_td)

                # Keep truncation flags on the same device as the other
                # transition tensors.
                truncations = extras.get("time_outs")
                if truncations is None:
                    truncations = torch.zeros_like(dones)
                else:
                    truncations = truncations.to(self.device)

                # NOTE(review): _resolve_obs builds two separate tensors, so
                # the identity checks below are always true and the critic
                # observations are always passed explicitly. Confirm whether
                # the None path of store_transition is meant to be used when
                # actor and critic share observation groups.
                self.alg.store_transition(
                    obs=actor_obs,
                    actions=actions,
                    rewards=rewards,
                    next_obs=next_actor_obs,
                    dones=dones,
                    truncations=truncations,
                    critic_obs=critic_obs if critic_obs is not actor_obs else None,
                    next_critic_obs=next_critic_obs if next_critic_obs is not next_actor_obs else None,
                )

                self.logger.process_env_step(rewards, dones, extras)

                actor_obs = next_actor_obs
                critic_obs = next_critic_obs

                collect_time = time.time() - t0
                t1 = time.time()

                # --- Update ---
                loss_dict = self.alg.update()
                learn_time = time.time() - t1

                self.current_learning_iteration = it

                # --- Logging ---
                # A non-positive interval disables periodic logging instead of
                # raising ZeroDivisionError.
                if log_interval > 0 and it % log_interval == 0:
                    self.logger.log(
                        it=it,
                        start_it=start_it,
                        total_it=total_it,
                        collect_time=collect_time,
                        learn_time=learn_time,
                        loss_dict=loss_dict,
                        learning_rate=self.alg.actor_optimizer.param_groups[0]["lr"],
                    )

                # --- Checkpointing ---
                if can_write and save_interval > 0 and it > 0 and it % save_interval == 0:
                    self.save(os.path.join(self.logger.log_dir, f"model_{it}.pt"))

                pbar.update(1)
        finally:
            # Release the progress bar even when an iteration raises.
            pbar.close()

        if can_write:
            self.save(os.path.join(self.logger.log_dir, f"model_{self.current_learning_iteration}.pt"))

    # ------------------------------------------------------------------
    # Save / Load
    # ------------------------------------------------------------------

    def save(self, path: str):
        """Save the algorithm state and the current iteration to ``path``."""
        sd = self.alg.state_dict()
        sd["iter"] = self.current_learning_iteration
        torch.save(sd, path)
        self.logger.save_model(path, self.current_learning_iteration)

    def load(self, path: str, load_optimizer: bool = True):
        """Load a checkpoint written by :meth:`save` and return its contents.

        Args:
            path: Checkpoint file to read.
            load_optimizer: Currently unused; the full state dict is always
                passed to the algorithm's ``load_state_dict``.

        Returns:
            The loaded state dict.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects. Only load
        # checkpoints from trusted sources.
        sd = torch.load(path, weights_only=False, map_location=self.device)
        self.alg.load_state_dict(sd)
        if "iter" in sd:
            self.current_learning_iteration = sd["iter"]
        return sd

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def get_inference_policy(self, device: str | None = None):
        """Return a callable mapping environment observations to actions.

        The algorithm is switched to eval mode. The returned callable accepts
        either a ``TensorDict``/``dict`` of observation groups, from which the
        actor groups are concatenated, or an already flat tensor.

        Args:
            device: Device passed to the algorithm's ``get_inference_policy``;
                defaults to the runner's device.
        """
        self.alg.eval()

        device = device or self.device
        policy_fn = self.alg.get_inference_policy(device)
        obs_groups = self.cfg.get("obs_groups", {})
        actor_keys = obs_groups.get("policy", ["policy"])

        def wrapped_policy(obs_td):
            if isinstance(obs_td, (TensorDict, dict)):
                flat = self._cat_obs(obs_td, actor_keys)
            else:
                flat = obs_td
            return policy_fn(flat)

        return wrapped_policy

    # ------------------------------------------------------------------
    # Multi-GPU setup
    # ------------------------------------------------------------------

    def _configure_multi_gpu(self):
        """Read the distributed settings from the environment variables.

        Uses ``WORLD_SIZE``, ``LOCAL_RANK`` and ``RANK``. When more than one
        process is configured, the NCCL process group is initialized (unless
        it already is) and the CUDA device is set to the local rank.
        """
        self.gpu_world_size = int(os.getenv("WORLD_SIZE", "1"))
        self.is_distributed = self.gpu_world_size > 1

        if not self.is_distributed:
            self.gpu_local_rank = 0
            self.gpu_global_rank = 0
            self.multi_gpu_cfg = None
            return

        self.gpu_local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.gpu_global_rank = int(os.getenv("RANK", "0"))
        self.multi_gpu_cfg = {
            "global_rank": self.gpu_global_rank,
            "local_rank": self.gpu_local_rank,
            "world_size": self.gpu_world_size,
        }

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend="nccl", rank=self.gpu_global_rank, world_size=self.gpu_world_size
            )
        torch.cuda.set_device(self.gpu_local_rank)

    # ------------------------------------------------------------------
    # Algorithm construction
    # ------------------------------------------------------------------

    def _construct_algorithm(
        self, actor_obs_dim: int, critic_obs_dim: int, n_act: int, num_envs: int
    ) -> SAC:
        """Build the SAC instance from ``train_cfg["algorithm"]``.

        Every entry of the algorithm config except ``class_name`` is forwarded
        as a keyword argument.
        """
        alg_kwargs = {k: v for k, v in self.alg_cfg.items() if k != "class_name"}
        return SAC(
            actor_obs_dim=actor_obs_dim,
            critic_obs_dim=critic_obs_dim,
            n_act=n_act,
            num_envs=num_envs,
            device=self.device,
            multi_gpu_cfg=self.multi_gpu_cfg,
            **alg_kwargs,
        )

    # ------------------------------------------------------------------
    # Observation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _cat_obs(obs: TensorDict | dict, keys) -> torch.Tensor:
        """Concatenate the observation groups named in ``keys`` on the last dim.

        Groups that are not present in ``obs`` are skipped.

        Raises:
            KeyError: If none of the requested groups is present.
        """
        # Membership is tested on keys() rather than on the container itself;
        # this works for plain dicts and, presumably, for every tensordict
        # release (some are believed not to support ``in`` on the TensorDict).
        available = obs.keys()
        parts = [obs[k] for k in keys if k in available]
        if not parts:
            raise KeyError(
                f"None of the observation groups {list(keys)} are present "
                f"(available: {list(available)})"
            )
        return torch.cat(parts, dim=-1)

    def _resolve_obs(self, obs: TensorDict | dict) -> tuple[torch.Tensor, torch.Tensor]:
        """Extract flat actor and critic observation tensors from env obs dict.

        The groups come from ``cfg["obs_groups"]``: ``"policy"`` for the actor
        (default ``["policy"]``) and ``"critic"`` for the critic (default: the
        actor groups).
        """
        obs_groups = self.cfg.get("obs_groups", {})
        actor_keys = obs_groups.get("policy", ["policy"])
        critic_keys = obs_groups.get("critic", actor_keys)

        actor_obs = self._cat_obs(obs, actor_keys)
        critic_obs = self._cat_obs(obs, critic_keys)
        return actor_obs, critic_obs
