"""Standalone SAC algorithm.

Adapted from holosoma FastSAC (https://github.com/amazon-far/holosoma).
No rsl_rl dependency — works with the OffPolicyRunner directly.
"""

from __future__ import annotations

import math
from contextlib import contextmanager

import torch
import torch.nn as nn
import torch.nn.functional as F

from sim_improvement.rl.algs.sac.networks import Actor, Critic
from sim_improvement.rl.algs.sac.replay_buffer import DemoReplayBuffer, EmpiricalNormalization, SimpleReplayBuffer


class SAC:
    """Soft Actor-Critic with distributional critics.

    Constructed directly with observation/action dimensions — no rsl_rl
    ActorCritic or RolloutStorage required.
    """

    def __init__(
        self,
        actor_obs_dim: int,
        critic_obs_dim: int,
        n_act: int,
        num_envs: int,
        device: str = "cpu",
        # Network architecture
        actor_hidden_dim: int = 512,
        critic_hidden_dim: int = 768,
        use_tanh: bool = True,
        use_layer_norm: bool = True,
        log_std_max: float = 0.0,
        log_std_min: float = -5.0,
        num_atoms: int = 101,
        v_min: float = -20.0,
        v_max: float = 20.0,
        num_q_networks: int = 2,
        # Optimizers
        actor_learning_rate: float = 3e-4,
        critic_learning_rate: float = 3e-4,
        alpha_learning_rate: float = 3e-4,
        weight_decay: float = 0.001,
        # SAC core
        alpha_init: float = 0.001,
        target_entropy_ratio: float = 0.0,
        use_autotune: bool = True,
        gamma: float = 0.97,
        tau: float = 0.125,
        # Replay / training schedule
        buffer_size: int = 1024,
        batch_size: int = 8192,
        num_updates: int = 8,
        policy_frequency: int = 4,
        learning_starts: int = 10,
        num_steps: int = 1,
        max_grad_norm: float = 0.0,
        obs_normalization: bool = True,
        # AMP / compile
        amp: bool = True,
        amp_dtype: str = "bf16",
        compile: bool = True,
        # Multi-GPU
        multi_gpu_cfg: dict | None = None,
        # RLPD demo buffer
        demo_buffer: DemoReplayBuffer | None = None,
        demo_ratio: float = 0.0,
    ):
        self.device = device
        self.actor_obs_dim = actor_obs_dim
        self.critic_obs_dim = critic_obs_dim
        self.n_act = n_act
        self.num_envs = num_envs

        # Store config
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.num_updates = num_updates
        self.policy_frequency = policy_frequency
        self.learning_starts = learning_starts
        self.max_grad_norm = max_grad_norm
        self.obs_normalization = obs_normalization
        self.use_autotune = use_autotune
        self.amp_enabled = amp
        self.amp_dtype = amp_dtype
        self.global_step = 0

        # Multi-GPU
        self.is_multi_gpu = multi_gpu_cfg is not None
        if self.is_multi_gpu:
            self.gpu_global_rank = multi_gpu_cfg.get("global_rank", 0)
            self.gpu_world_size = multi_gpu_cfg.get("world_size", 1)
        else:
            self.gpu_global_rank = 0
            self.gpu_world_size = 1

        # Observation normalization
        if obs_normalization:
            self.obs_normalizer: nn.Module = EmpiricalNormalization(shape=actor_obs_dim, device=device)
            self.critic_obs_normalizer: nn.Module = EmpiricalNormalization(shape=critic_obs_dim, device=device)
        else:
            self.obs_normalizer = nn.Identity()
            self.critic_obs_normalizer = nn.Identity()

        # Actor
        self.actor = Actor(
            n_obs=actor_obs_dim, n_act=n_act, hidden_dim=actor_hidden_dim,
            log_std_max=log_std_max, log_std_min=log_std_min,
            use_tanh=use_tanh, use_layer_norm=use_layer_norm, device=device,
        )

        # Critic (Q-networks)
        self.qnet = Critic(
            n_obs=critic_obs_dim, n_act=n_act, num_atoms=num_atoms,
            v_min=v_min, v_max=v_max, hidden_dim=critic_hidden_dim,
            use_layer_norm=use_layer_norm, num_q_networks=num_q_networks, device=device,
        )

        # Target Q-network
        self.qnet_target = Critic(
            n_obs=critic_obs_dim, n_act=n_act, num_atoms=num_atoms,
            v_min=v_min, v_max=v_max, hidden_dim=critic_hidden_dim,
            use_layer_norm=use_layer_norm, num_q_networks=num_q_networks, device=device,
        )
        self.qnet_target.load_state_dict(self.qnet.state_dict())

        # Entropy temperature
        self.log_alpha = torch.tensor([math.log(alpha_init)], requires_grad=True, device=device)
        self.target_entropy = -n_act * target_entropy_ratio

        # Optimizers
        self.actor_optimizer = torch.optim.AdamW(
            self.actor.parameters(), lr=actor_learning_rate,
            weight_decay=weight_decay, betas=(0.9, 0.95),
        )
        self.q_optimizer = torch.optim.AdamW(
            self.qnet.parameters(), lr=critic_learning_rate,
            weight_decay=weight_decay, betas=(0.9, 0.95),
        )
        self.alpha_optimizer = torch.optim.AdamW(
            [self.log_alpha], lr=alpha_learning_rate, betas=(0.9, 0.95),
        )

        # AMP scaler
        self.scaler = torch.amp.GradScaler(enabled=amp)

        # Replay buffer
        self.rb = SimpleReplayBuffer(
            n_env=num_envs, buffer_size=buffer_size,
            n_obs=actor_obs_dim, n_act=n_act,
            n_critic_obs=critic_obs_dim if critic_obs_dim != actor_obs_dim else None,
            n_steps=num_steps, gamma=gamma, device=device,
        )

        # RLPD demo buffer
        self.demo_buffer = demo_buffer
        self.demo_ratio = demo_ratio

        # Compile
        if compile:
            self._update_critic_fn = torch.compile(self._update_critic)
            self._update_actor_fn = torch.compile(self._update_actor)
        else:
            self._update_critic_fn = self._update_critic
            self._update_actor_fn = self._update_actor

    # ------------------------------------------------------------------
    # Public interface (called by OffPolicyRunner)
    # ------------------------------------------------------------------

    @torch.no_grad()
    def act(self, obs: torch.Tensor, critic_obs: torch.Tensor | None = None) -> torch.Tensor:
        """Sample exploration actions. Updates normalizer stats."""
        if critic_obs is None:
            critic_obs = obs

        if self.obs_normalization:
            self.obs_normalizer.train()
            self.critic_obs_normalizer.train()
            norm_obs = self.obs_normalizer(obs)
            self.critic_obs_normalizer(critic_obs)
        else:
            norm_obs = obs

        return self.actor.explore(norm_obs)

    def store_transition(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_obs: torch.Tensor,
        dones: torch.Tensor,
        truncations: torch.Tensor,
        critic_obs: torch.Tensor | None = None,
        next_critic_obs: torch.Tensor | None = None,
    ):
        """Add a transition to the replay buffer."""
        self.rb.extend(
            observations=obs, actions=actions, rewards=rewards,
            next_observations=next_obs, dones=dones.long(), truncations=truncations.long(),
            critic_observations=critic_obs, next_critic_observations=next_critic_obs,
        )
        self.global_step += 1

    def update(self) -> dict[str, float]:
        """Run SAC gradient steps. Returns loss dict for logging."""
        if self.rb.size < self.learning_starts:
            return {
                "actor_loss": 0.0,
                "qf_loss": 0.0,
                "alpha_loss": 0.0,
                "alpha_value": self.log_alpha.exp().item(),
                "policy_entropy": 0.0,
            }

        total_qf_loss = 0.0
        total_actor_loss = 0.0
        total_alpha_loss = 0.0
        total_entropy = 0.0
        actor_updates = 0

        per_env_batch = max(self.batch_size // self.num_envs // self.gpu_world_size, 1)
        online_batch_size = self.batch_size
        demo_batch_size = 0
        if self.demo_buffer is not None and self.demo_ratio > 0:
            demo_batch_size = int(self.batch_size * self.demo_ratio)
            online_batch_size = self.batch_size - demo_batch_size
            per_env_batch = max(online_batch_size // self.num_envs // self.gpu_world_size, 1)

        for i in range(self.num_updates):
            data = self.rb.sample(per_env_batch)

            # Mix in demo data (RLPD)
            if demo_batch_size > 0:
                demo_data = self.demo_buffer.sample(demo_batch_size)
                # For keys present in both, concatenate. For online-only keys
                # (e.g. critic_observations in asymmetric setups), use the demo's
                # observations as a fallback.
                merged = {}
                for k in data:
                    if k in demo_data:
                        merged[k] = torch.cat([data[k], demo_data[k]], dim=0)
                    elif k in ("critic_observations", "next_critic_observations"):
                        obs_key = "observations" if k == "critic_observations" else "next_observations"
                        merged[k] = torch.cat([data[k], demo_data[obs_key]], dim=0)
                data = merged

            # Normalize
            if self.obs_normalization:
                self.obs_normalizer.eval()
                self.critic_obs_normalizer.eval()
                data["observations"] = self.obs_normalizer(data["observations"], update=False)
                data["next_observations"] = self.obs_normalizer(data["next_observations"], update=False)
                if "critic_observations" in data:
                    data["critic_observations"] = self.critic_obs_normalizer(
                        data["critic_observations"], update=False
                    )
                    data["next_critic_observations"] = self.critic_obs_normalizer(
                        data["next_critic_observations"], update=False
                    )

            critic_obs = data.get("critic_observations", data["observations"])
            next_critic_obs = data.get("next_critic_observations", data["next_observations"])

            qf_loss = self._update_critic_fn(data, critic_obs, next_critic_obs)
            total_qf_loss += qf_loss.item()

            # Actor (less frequent)
            should_update_actor = (
                (self.num_updates > 1 and i % self.policy_frequency == 1)
                or (self.num_updates == 1 and self.global_step % self.policy_frequency == 0)
            )
            if should_update_actor:
                actor_loss, entropy = self._update_actor_fn(data, critic_obs)
                total_actor_loss += actor_loss.item()
                total_entropy += entropy.item()
                actor_updates += 1

            # Soft target update
            with torch.no_grad():
                src_ps = [p.data for p in self.qnet.parameters()]
                tgt_ps = [p.data for p in self.qnet_target.parameters()]
                torch._foreach_mul_(tgt_ps, 1.0 - self.tau)
                torch._foreach_add_(tgt_ps, src_ps, alpha=self.tau)

        n = max(self.num_updates, 1)
        na = max(actor_updates, 1)
        return {
            "actor_loss": total_actor_loss / na,
            "qf_loss": total_qf_loss / n,
            "alpha_loss": total_alpha_loss / n,
            "alpha_value": self.log_alpha.exp().detach().item(),
            "policy_entropy": total_entropy / na,
        }

    # ------------------------------------------------------------------
    # Internal update methods
    # ------------------------------------------------------------------

    @contextmanager
    def _maybe_amp(self):
        amp_dtype = torch.bfloat16 if self.amp_dtype == "bf16" else torch.float16
        with torch.amp.autocast(device_type="cuda", dtype=amp_dtype, enabled=self.amp_enabled):
            yield

    def _update_critic(self, data, critic_obs, next_critic_obs):
        with self._maybe_amp():
            actions = data["actions"]
            rewards = data["rewards"]
            dones = data["dones"].bool()
            truncations = data["truncations"].bool()
            bootstrap = (truncations | ~dones).float()

            with torch.no_grad():
                next_actions, next_log_probs = self.actor.get_actions_and_log_probs(data["next_observations"])
                discount = self.gamma ** data["effective_n_steps"]
                target_distributions = self.qnet_target.projection(
                    next_critic_obs, next_actions,
                    rewards - discount * bootstrap * self.log_alpha.exp() * next_log_probs,
                    bootstrap, discount,
                )

            q_outputs = self.qnet(critic_obs, actions)
            critic_log_probs = F.log_softmax(q_outputs, dim=-1)
            qf_loss = (-torch.sum(target_distributions * critic_log_probs, dim=-1)).mean(dim=1).sum(dim=0)

        self.q_optimizer.zero_grad(set_to_none=True)
        self.scaler.scale(qf_loss).backward()
        if self.is_multi_gpu:
            self._all_reduce_grads(self.qnet)
        self.scaler.unscale_(self.q_optimizer)
        if self.max_grad_norm > 0:
            nn.utils.clip_grad_norm_(self.qnet.parameters(), max_norm=self.max_grad_norm)
        self.scaler.step(self.q_optimizer)
        self.scaler.update()

        # Alpha
        if self.use_autotune:
            self.alpha_optimizer.zero_grad(set_to_none=True)
            with self._maybe_amp():
                alpha_loss = (-self.log_alpha.exp() * (next_log_probs.detach() + self.target_entropy)).mean()
            self.scaler.scale(alpha_loss).backward()
            if self.is_multi_gpu and self.log_alpha.grad is not None:
                torch.distributed.all_reduce(self.log_alpha.grad.data, op=torch.distributed.ReduceOp.SUM)
                self.log_alpha.grad.data /= self.gpu_world_size
            self.scaler.unscale_(self.alpha_optimizer)
            self.scaler.step(self.alpha_optimizer)
            self.scaler.update()

        return qf_loss.detach()

    def _update_actor(self, data, critic_obs):
        with self._maybe_amp():
            actions, log_probs = self.actor.get_actions_and_log_probs(data["observations"])
            with torch.no_grad():
                policy_entropy = -log_probs.mean()
            q_outputs = self.qnet(critic_obs, actions)
            q_probs = F.softmax(q_outputs, dim=-1)
            qf_value = self.qnet.get_value(q_probs).mean(dim=0)
            actor_loss = (self.log_alpha.exp().detach() * log_probs - qf_value).mean()

        self.actor_optimizer.zero_grad(set_to_none=True)
        self.scaler.scale(actor_loss).backward()
        if self.is_multi_gpu:
            self._all_reduce_grads(self.actor)
        self.scaler.unscale_(self.actor_optimizer)
        if self.max_grad_norm > 0:
            nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.max_grad_norm)
        self.scaler.step(self.actor_optimizer)
        self.scaler.update()
        return actor_loss.detach(), policy_entropy.detach()

    def _all_reduce_grads(self, model: nn.Module):
        if not self.is_multi_gpu:
            return
        grads = [p.grad.view(-1) for p in model.parameters() if p.grad is not None]
        if not grads:
            return
        flat = torch.cat(grads)
        torch.distributed.all_reduce(flat, op=torch.distributed.ReduceOp.SUM)
        flat /= self.gpu_world_size
        offset = 0
        for p in model.parameters():
            if p.grad is not None:
                n = p.numel()
                p.grad.copy_(flat[offset : offset + n].view_as(p.grad))
                offset += n

    # ------------------------------------------------------------------
    # Distributed helpers
    # ------------------------------------------------------------------

    def broadcast_parameters(self):
        if not self.is_multi_gpu:
            return
        for param in self.actor.parameters():
            torch.distributed.broadcast(param.data, src=0)
        for param in self.qnet.parameters():
            torch.distributed.broadcast(param.data, src=0)
        torch.distributed.broadcast(self.log_alpha.data, src=0)
        self.qnet_target.load_state_dict(self.qnet.state_dict())

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------

    def state_dict(self):
        return {
            "actor": self.actor.state_dict(),
            "qnet": self.qnet.state_dict(),
            "qnet_target": self.qnet_target.state_dict(),
            "log_alpha": self.log_alpha.detach().cpu(),
            "obs_normalizer": self.obs_normalizer.state_dict() if self.obs_normalization else None,
            "critic_obs_normalizer": (
                self.critic_obs_normalizer.state_dict() if self.obs_normalization else None
            ),
            "actor_optimizer": self.actor_optimizer.state_dict(),
            "q_optimizer": self.q_optimizer.state_dict(),
            "alpha_optimizer": self.alpha_optimizer.state_dict(),
            "scaler": self.scaler.state_dict(),
            "global_step": self.global_step,
        }

    def load_state_dict(self, sd):
        self.actor.load_state_dict(sd["actor"])
        self.qnet.load_state_dict(sd["qnet"])
        self.qnet_target.load_state_dict(sd["qnet_target"])
        self.log_alpha.data.copy_(sd["log_alpha"].to(self.device))
        if sd.get("obs_normalizer") and self.obs_normalization:
            self.obs_normalizer.load_state_dict(sd["obs_normalizer"])
        if sd.get("critic_obs_normalizer") and self.obs_normalization:
            self.critic_obs_normalizer.load_state_dict(sd["critic_obs_normalizer"])
        self.actor_optimizer.load_state_dict(sd["actor_optimizer"])
        self.q_optimizer.load_state_dict(sd["q_optimizer"])
        self.alpha_optimizer.load_state_dict(sd["alpha_optimizer"])
        if sd.get("scaler"):
            self.scaler.load_state_dict(sd["scaler"])
        self.global_step = sd.get("global_step", 0)

    def train(self):
        self.actor.train()
        self.qnet.train()

    def eval(self):
        self.actor.eval()
        self.qnet.eval()

    @torch.no_grad()
    def get_inference_policy(self, device=None):
        """Return a deterministic policy callable for evaluation."""
        device = device or self.device
        actor = self.actor.to(device)
        obs_norm = self.obs_normalizer.to(device) if self.obs_normalization else None
        actor.eval()
        if obs_norm is not None:
            obs_norm.eval()

        def policy_fn(obs: torch.Tensor) -> torch.Tensor:
            if obs_norm is not None:
                obs = obs_norm(obs, update=False)
            return actor.explore(obs, deterministic=True)

        return policy_fn
