"""Replay buffer and normalization utilities for SAC.

Adapted from holosoma FastSAC (https://github.com/amazon-far/holosoma).
"""

from __future__ import annotations

import h5py
import numpy as np
import torch
import torch.distributed as dist
from torch import nn


class SimpleReplayBuffer(nn.Module):
    """Circular replay buffer with n-step return support.

    Stores transitions as (obs, action, reward, next_obs, done, truncation)
    across ``n_env`` parallel environments.
    """

    def __init__(
        self,
        n_env: int,
        buffer_size: int,
        n_obs: int,
        n_act: int,
        n_critic_obs: int | None = None,
        n_steps: int = 1,
        gamma: float = 0.99,
        device=None,
    ):
        super().__init__()
        self.n_env = n_env
        self.buffer_size = buffer_size
        self.n_obs = n_obs
        self.n_act = n_act
        self.n_critic_obs = n_critic_obs or n_obs
        self.gamma = gamma
        self.n_steps = n_steps
        self.device = device

        self.observations = torch.zeros((n_env, buffer_size, n_obs), device=device)
        self.actions = torch.zeros((n_env, buffer_size, n_act), device=device)
        self.rewards = torch.zeros((n_env, buffer_size), device=device)
        self.dones = torch.zeros((n_env, buffer_size), device=device, dtype=torch.long)
        self.truncations = torch.zeros((n_env, buffer_size), device=device, dtype=torch.long)
        self.next_observations = torch.zeros((n_env, buffer_size, n_obs), device=device)

        if n_critic_obs and n_critic_obs != n_obs:
            self.critic_observations = torch.zeros((n_env, buffer_size, self.n_critic_obs), device=device)
            self.next_critic_observations = torch.zeros((n_env, buffer_size, self.n_critic_obs), device=device)
            self._asymmetric = True
        else:
            self._asymmetric = False

        self.ptr = 0

    def extend(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_observations: torch.Tensor,
        dones: torch.Tensor,
        truncations: torch.Tensor,
        critic_observations: torch.Tensor | None = None,
        next_critic_observations: torch.Tensor | None = None,
    ):
        """Add one timestep of transitions from all envs."""
        ptr = self.ptr % self.buffer_size
        self.observations[:, ptr] = observations
        self.actions[:, ptr] = actions
        self.rewards[:, ptr] = rewards
        self.dones[:, ptr] = dones
        self.truncations[:, ptr] = truncations
        self.next_observations[:, ptr] = next_observations

        if self._asymmetric and critic_observations is not None:
            self.critic_observations[:, ptr] = critic_observations
            self.next_critic_observations[:, ptr] = next_critic_observations

        self.ptr += 1

    @property
    def size(self) -> int:
        return min(self.ptr, self.buffer_size)

    @torch.no_grad()
    def sample(self, batch_size: int) -> dict[str, torch.Tensor]:
        """Sample a flat batch of transitions.

        Returns dict with keys: observations, actions, rewards, next_observations,
        dones, truncations, effective_n_steps. Also critic_observations and
        next_critic_observations if asymmetric.
        """
        valid = min(self.buffer_size, self.ptr)

        if self.n_steps == 1:
            indices = torch.randint(0, valid, (self.n_env, batch_size), device=self.device)
            obs_idx = indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
            act_idx = indices.unsqueeze(-1).expand(-1, -1, self.n_act)

            flat = self.n_env * batch_size
            observations = torch.gather(self.observations, 1, obs_idx).reshape(flat, self.n_obs)
            next_observations = torch.gather(self.next_observations, 1, obs_idx).reshape(flat, self.n_obs)
            actions = torch.gather(self.actions, 1, act_idx).reshape(flat, self.n_act)
            rewards = torch.gather(self.rewards, 1, indices).reshape(flat)
            dones = torch.gather(self.dones, 1, indices).reshape(flat)
            truncations = torch.gather(self.truncations, 1, indices).reshape(flat)
            effective_n_steps = torch.ones_like(dones)

            result = {
                "observations": observations,
                "actions": actions,
                "rewards": rewards,
                "next_observations": next_observations,
                "dones": dones,
                "truncations": truncations,
                "effective_n_steps": effective_n_steps,
            }

            if self._asymmetric:
                crit_idx = indices.unsqueeze(-1).expand(-1, -1, self.n_critic_obs)
                result["critic_observations"] = torch.gather(
                    self.critic_observations, 1, crit_idx
                ).reshape(flat, self.n_critic_obs)
                result["next_critic_observations"] = torch.gather(
                    self.next_critic_observations, 1, crit_idx
                ).reshape(flat, self.n_critic_obs)

            return result

        # N-step returns
        if self.ptr >= self.buffer_size:
            current_pos = self.ptr % self.buffer_size
            curr_truncations = self.truncations[:, current_pos - 1].clone()
            self.truncations[:, current_pos - 1] = torch.logical_not(self.dones[:, current_pos - 1])
            indices = torch.randint(0, self.buffer_size, (self.n_env, batch_size), device=self.device)
        else:
            max_start_idx = max(1, self.ptr - self.n_steps + 1)
            indices = torch.randint(0, max_start_idx, (self.n_env, batch_size), device=self.device)

        obs_idx = indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
        act_idx = indices.unsqueeze(-1).expand(-1, -1, self.n_act)
        flat = self.n_env * batch_size

        observations = torch.gather(self.observations, 1, obs_idx).reshape(flat, self.n_obs)
        actions = torch.gather(self.actions, 1, act_idx).reshape(flat, self.n_act)

        seq_offsets = torch.arange(self.n_steps, device=self.device).view(1, 1, -1)
        all_indices = (indices.unsqueeze(-1) + seq_offsets) % self.buffer_size

        all_rewards = torch.gather(self.rewards.unsqueeze(-1).expand(-1, -1, self.n_steps), 1, all_indices)
        all_dones = torch.gather(self.dones.unsqueeze(-1).expand(-1, -1, self.n_steps), 1, all_indices)
        all_truncations = torch.gather(self.truncations.unsqueeze(-1).expand(-1, -1, self.n_steps), 1, all_indices)

        all_dones_shifted = torch.cat([torch.zeros_like(all_dones[:, :, :1]), all_dones[:, :, :-1]], dim=2)
        done_masks = torch.cumprod(1.0 - all_dones_shifted, dim=2)
        effective_n_steps = done_masks.sum(2)

        discounts = torch.pow(self.gamma, torch.arange(self.n_steps, device=self.device))
        n_step_rewards = (all_rewards * done_masks * discounts.view(1, 1, -1)).sum(dim=2)

        first_done = torch.argmax((all_dones > 0).float(), dim=2)
        first_trunc = torch.argmax((all_truncations > 0).float(), dim=2)
        no_dones = all_dones.sum(dim=2) == 0
        no_truncs = all_truncations.sum(dim=2) == 0
        first_done = torch.where(no_dones, self.n_steps - 1, first_done)
        first_trunc = torch.where(no_truncs, self.n_steps - 1, first_trunc)
        final_indices = torch.minimum(first_done, first_trunc)

        final_next_obs_indices = torch.gather(all_indices, 2, final_indices.unsqueeze(-1)).squeeze(-1)
        next_observations = self.next_observations.gather(
            1, final_next_obs_indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
        ).reshape(flat, self.n_obs)
        dones = self.dones.gather(1, final_next_obs_indices).reshape(flat)
        truncations = self.truncations.gather(1, final_next_obs_indices).reshape(flat)

        result = {
            "observations": observations,
            "actions": actions,
            "rewards": n_step_rewards.reshape(flat),
            "next_observations": next_observations,
            "dones": dones,
            "truncations": truncations,
            "effective_n_steps": effective_n_steps.reshape(flat),
        }

        if self._asymmetric:
            crit_idx = indices.unsqueeze(-1).expand(-1, -1, self.n_critic_obs)
            result["critic_observations"] = torch.gather(
                self.critic_observations, 1, crit_idx
            ).reshape(flat, self.n_critic_obs)
            result["next_critic_observations"] = self.next_critic_observations.gather(
                1, final_next_obs_indices.unsqueeze(-1).expand(-1, -1, self.n_critic_obs)
            ).reshape(flat, self.n_critic_obs)

        if self.n_steps > 1 and self.ptr >= self.buffer_size:
            self.truncations[:, current_pos - 1] = curr_truncations

        return result


class EmpiricalNormalization(nn.Module):
    """Running mean/variance normalization.

    Statistics are stored as ``(1, *shape)`` buffers. They are updated from every
    batch passed through :meth:`forward` while the module is in training mode
    and ``update=True``.
    """

    def __init__(self, shape, device, eps=1e-2):
        super().__init__()
        self.eps = eps
        self.device = device
        self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0).to(device))
        self.register_buffer("_var", torch.ones(shape).unsqueeze(0).to(device))
        self.register_buffer("_std", torch.ones(shape).unsqueeze(0).to(device))
        self.register_buffer("count", torch.tensor(0, dtype=torch.long).to(device))

    @property
    def mean(self):
        """Copy of the running mean with the leading batch dim removed."""
        return self._mean.squeeze(0).clone()

    @property
    def std(self):
        """Copy of the running std with the leading batch dim removed."""
        return self._std.squeeze(0).clone()

    @torch.no_grad()
    def forward(self, x: torch.Tensor, center: bool = True, update: bool = True) -> torch.Tensor:
        """Normalize ``x`` (rows along dim 0), optionally updating the stats first."""
        if self.training and update:
            self.update(x)
        if center:
            return (x - self._mean) / (self._std + self.eps)
        return x / (self._std + self.eps)

    @torch.jit.unused
    def update(self, x):
        """Merge the moments of batch ``x`` into the running mean/variance.

        Uses the parallel-variance combination of Chan, Golub & LeVeque:
        ``M2 = M2_a + M2_b + delta^2 * n_a * n_b / n``, where ``delta`` is the
        batch mean minus the *previous* running mean. Under ``torch.distributed``
        the batch moments are all-reduced, and every rank is assumed to pass the
        same number of rows.
        """
        if x.shape[0] == 0:
            # Nothing to merge; continuing would divide by zero and poison the stats with NaN.
            return

        if dist.is_available() and dist.is_initialized():
            local_batch_size = x.shape[0]
            world_size = dist.get_world_size()
            global_batch_size = world_size * local_batch_size

            # Shift by the current mean to reduce cancellation in E[x^2] - E[x]^2.
            x_shifted = x - self._mean
            local_sum_shifted = torch.sum(x_shifted, dim=0, keepdim=True)
            local_sum_sq_shifted = torch.sum(x_shifted.pow(2), dim=0, keepdim=True)

            stats_to_sync = torch.cat([local_sum_shifted, local_sum_sq_shifted], dim=0)
            dist.all_reduce(stats_to_sync, op=dist.ReduceOp.SUM)
            global_sum_shifted, global_sum_sq_shifted = stats_to_sync

            batch_mean_shifted = global_sum_shifted / global_batch_size
            # Rounding can push this slightly below zero; a negative variance would make sqrt NaN.
            batch_var = (global_sum_sq_shifted / global_batch_size - batch_mean_shifted.pow(2)).clamp_min(0.0)
            batch_mean = batch_mean_shifted + self._mean
        else:
            global_batch_size = x.shape[0]
            batch_mean = torch.mean(x, dim=0, keepdim=True)
            batch_var = torch.var(x, dim=0, keepdim=True, unbiased=False)

        new_count = self.count + global_batch_size
        # delta must be measured against the OLD mean for the parallel merge to be exact.
        delta = batch_mean - self._mean
        m_a = self._var * self.count
        m_b = batch_var * global_batch_size
        M2 = m_a + m_b + delta.pow(2) * (self.count * global_batch_size / new_count)
        self._mean.copy_(self._mean + delta * (global_batch_size / new_count))
        self._var.copy_(M2 / new_count)
        self._std.copy_(self._var.sqrt())
        self.count.copy_(new_count)

    @torch.jit.unused
    def inverse(self, y):
        """Undo the centered normalization applied by :meth:`forward`."""
        return y * (self._std + self.eps) + self._mean


class DemoReplayBuffer(nn.Module):
    """Read-only replay buffer pre-loaded from an HDF5 demonstration dataset.

    Stores flat (obs, action, reward, next_obs, done) tuples from successful
    demos and supports random sampling identical to SimpleReplayBuffer output.
    """

    def __init__(
        self,
        dataset_paths: str | list[str],
        action_keys: list[str],
        obs_key: str = "policy",
        success_only: bool = True,
        device: str | torch.device = "cpu",
    ):
        super().__init__()
        self.device = device

        if isinstance(dataset_paths, str):
            dataset_paths = [dataset_paths]

        observations, actions, rewards, next_observations, dones = self._load(
            dataset_paths, action_keys, obs_key, success_only,
        )
        n = observations.shape[0]
        self.n_obs = observations.shape[1]
        self.n_act = actions.shape[1]

        self.register_buffer("observations", observations.to(device))
        self.register_buffer("actions", actions.to(device))
        self.register_buffer("rewards", rewards.to(device))
        self.register_buffer("next_observations", next_observations.to(device))
        self.register_buffer("dones", dones.to(device))

        print(f"[DemoReplayBuffer] Loaded {n} transitions ({self.n_obs}D obs, {self.n_act}D act) from {len(dataset_paths)} file(s)")

    @staticmethod
    def _load(
        paths: list[str],
        action_keys: list[str],
        obs_key: str,
        success_only: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        all_obs, all_act, all_rew, all_next_obs, all_done = [], [], [], [], []

        for path in paths:
            with h5py.File(path, "r") as f:
                for demo_key in sorted(f["data"].keys()):
                    demo = f[f"data/{demo_key}"]
                    if success_only and not demo.attrs.get("success", False):
                        continue

                    obs = np.array(demo[f"obs/{obs_key}"])  # (T, obs_dim)
                    reward = np.array(demo["reward"])  # (T,)
                    act_parts = [np.array(demo[f"action/{k}"]) for k in action_keys]
                    act = np.concatenate(act_parts, axis=1)  # (T, act_dim)

                    # Use stored next_obs if available, otherwise shift obs
                    if f"data/{demo_key}/next_obs/{obs_key}" in f:
                        next_obs = np.array(demo[f"next_obs/{obs_key}"])  # (T, obs_dim)
                    else:
                        next_obs = None

                    T = obs.shape[0]
                    if next_obs is not None:
                        # Full transitions available including terminal
                        all_obs.append(obs)
                        all_act.append(act)
                        all_rew.append(reward)
                        all_next_obs.append(next_obs)
                        done = np.zeros(T, dtype=np.float32)
                        done[-1] = 1.0
                        all_done.append(done)
                    else:
                        # Fallback: shift obs, drops terminal transition
                        all_obs.append(obs[:-1])
                        all_act.append(act[:-1])
                        all_rew.append(reward[:-1])
                        all_next_obs.append(obs[1:])
                        done = np.zeros(T - 1, dtype=np.float32)
                        done[-1] = 1.0
                        all_done.append(done)

        return (
            torch.from_numpy(np.concatenate(all_obs, axis=0)).float(),
            torch.from_numpy(np.concatenate(all_act, axis=0)).float(),
            torch.from_numpy(np.concatenate(all_rew, axis=0)).float(),
            torch.from_numpy(np.concatenate(all_next_obs, axis=0)).float(),
            torch.from_numpy(np.concatenate(all_done, axis=0)).long(),
        )

    @property
    def size(self) -> int:
        return self.observations.shape[0]

    @torch.no_grad()
    def sample(self, batch_size: int) -> dict[str, torch.Tensor]:
        indices = torch.randint(0, self.size, (batch_size,), device=self.device)
        return {
            "observations": self.observations[indices],
            "actions": self.actions[indices],
            "rewards": self.rewards[indices],
            "next_observations": self.next_observations[indices],
            "dones": self.dones[indices],
            "truncations": torch.zeros_like(self.dones[indices]),
            "effective_n_steps": torch.ones_like(self.dones[indices]),
        }
