from __future__ import annotations

import torch
from tensordict import TensorDict

from rsl_rl.storage import RolloutStorage


class ReplayBuffer(RolloutStorage):
    """Circular replay buffer extending RolloutStorage.

    Differences from the base class:
    - Ring-buffer pointer (wraps instead of raising OverflowError)
    - ``next_observations`` storage for off-policy (s, a, r, s', done)
    - ``sample()`` for random mini-batch sampling
    - ``add_bulk()`` for efficient offline data loading (e.g. BC demos)

    The base class is always constructed with ``training_type="replay"``.
    That is neither ``"rl"`` nor ``"distillation"``, so the base class skips
    its on-policy buffers (values / advantages / etc.).

    Storage is time-major. Every buffer is indexed as
    ``[timestep_slot, env, ...]`` with ``capacity`` slots, and one slot holds
    one timestep for all ``num_envs`` environments.
    """

    def __init__(
        self,
        num_envs: int,
        capacity: int,
        obs,
        actions_shape,
        device: str = "cpu",
    ):
        """Allocate storage for ``capacity`` timesteps of ``num_envs`` envs.

        Args:
            num_envs: Number of parallel environments per timestep.
            capacity: Number of timestep slots held before the oldest slot
                is overwritten.
            obs: Observation template mapping ``{key: tensor}``. It is
                forwarded to the base class, and each value's ``.shape`` is
                used here to size ``next_observations``. Presumably each
                value is ``(num_envs, dim)``; TODO confirm against callers.
            actions_shape: Per-env action shape, forwarded to the base class.
            device: Device on which ``next_observations`` (and, via the base
                class, the other buffers) are allocated.
        """
        super().__init__(
            training_type="replay",
            num_envs=num_envs,
            num_transitions_per_env=capacity,
            obs=obs,
            actions_shape=actions_shape,
            device=device,
        )
        self.capacity = capacity
        # Number of valid timestep slots; saturates at ``capacity``.
        self._size = 0

        # Allocate next_observations with the same per-key shapes as the
        # ``obs`` template, batched as (capacity, num_envs), so that s and s'
        # share the same indexing.
        self.next_observations = TensorDict(
            {
                key: torch.zeros(capacity, *value.shape, device=device)
                for key, value in obs.items()
            },
            batch_size=[capacity, num_envs],
            device=device,
        )

    # ------------------------------------------------------------------
    # Writing
    # ------------------------------------------------------------------

    def add_transitions(self, transition):
        """Add one timestep (all envs) at the ring pointer, wrapping when full.

        Args:
            transition: Object exposing ``observations``, ``actions``,
                ``rewards`` and ``dones``, and optionally
                ``next_observations``. ``rewards`` and ``dones`` are viewed
                as ``(-1, 1)`` before being copied into the slot.
        """
        idx = self.step % self.capacity

        self.observations[idx].copy_(transition.observations)
        self.actions[idx].copy_(transition.actions)
        self.rewards[idx].copy_(transition.rewards.view(-1, 1))
        self.dones[idx].copy_(transition.dones.view(-1, 1))
        # ``next_observations`` is optional. When it is absent, the slot keeps
        # whatever it held before: zeros, or stale data from an earlier lap
        # around the ring.
        if getattr(transition, "next_observations", None) is not None:
            self.next_observations[idx].copy_(transition.next_observations)

        # ``step`` counts total writes and is never wrapped itself; only the
        # derived slot index is.
        self.step += 1
        self._size = min(self._size + 1, self.capacity)

    def add_bulk(
        self,
        obs: dict[str, torch.Tensor],
        actions: torch.Tensor,
        next_obs: dict[str, torch.Tensor] | None = None,
        rewards: torch.Tensor | None = None,
        dones: torch.Tensor | None = None,
    ):
        """Efficiently load a batch of transitions.

        Tensors should have shape ``(N, num_envs, dim)`` to match the
        buffer layout.  For convenience, if a 2-D tensor ``(N, dim)`` is
        passed and ``num_envs == 1``, it is automatically unsqueezed.

        Data is written starting at slot 0. Afterwards the ring pointer and
        size are both ``N``, so this *replaces* the buffer's logical contents
        rather than appending to them. Optional fields that are not given
        keep their previous contents in slots ``[0, N)``.

        Args:
            obs: Dict of ``{key: (N, [num_envs,] dim)}`` tensors matching
                 the buffer's observation keys.
            actions: ``(N, [num_envs,] act_dim)`` tensor.
            next_obs: Optional dict, same structure as *obs*.
            rewards: Optional ``(N, [num_envs,] 1)`` tensor. Any shape with
                ``N * num_envs`` elements in time-major order is accepted,
                e.g. ``(N,)`` when ``num_envs == 1``.
            dones: Optional tensor, same shape rules as *rewards*.

        Raises:
            ValueError: If ``N`` exceeds the buffer capacity.
        """
        n = actions.shape[0]
        if n > self.capacity:
            raise ValueError(
                f"Data size ({n}) exceeds buffer capacity ({self.capacity})"
            )

        def _maybe_unsqueeze(t: torch.Tensor) -> torch.Tensor:
            """Insert a num_envs dim if the tensor is 2-D and num_envs==1."""
            if t.dim() == 2 and self.num_envs == 1:
                return t.unsqueeze(1)
            return t

        def _as_column(t: torch.Tensor) -> torch.Tensor:
            """Reshape a per-(step, env) scalar signal to ``(n, num_envs, 1)``."""
            # ``reshape`` rather than ``view(-1, 1)``: flattening to
            # (n * num_envs, 1) cannot be assigned into (n, num_envs, 1) when
            # num_envs > 1, and ``view`` rejects non-contiguous inputs.
            return t.to(self.device).reshape(n, self.num_envs, 1)

        for key in self.observations.keys():
            self.observations[key][:n] = _maybe_unsqueeze(obs[key].to(self.device))
        self.actions[:n] = _maybe_unsqueeze(actions.to(self.device))

        if next_obs is not None:
            for key in self.next_observations.keys():
                self.next_observations[key][:n] = _maybe_unsqueeze(next_obs[key].to(self.device))
        if rewards is not None:
            self.rewards[:n] = _as_column(rewards)
        if dones is not None:
            self.dones[:n] = _as_column(dones)

        self.step = n
        self._size = n

    # ------------------------------------------------------------------
    # Reading
    # ------------------------------------------------------------------

    @torch.no_grad()
    def sample(self, batch_size: int):
        """Sample a random batch uniformly, with replacement.

        Each (timestep slot, env) pair among the ``size * num_envs`` stored
        entries is equally likely.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            ``(obs, actions, rewards, next_obs, dones)`` where *obs* and
            *next_obs* are TensorDicts and the rest are plain tensors, all
            with a leading batch dimension of ``batch_size``.

        Raises:
            RuntimeError: If the buffer contains no transitions.
        """
        total = self._size * self.num_envs
        if total == 0:
            # Without this guard torch.randint(0, 0, ...) fails with an
            # opaque "from must be less than to" message.
            raise RuntimeError("Cannot sample from an empty ReplayBuffer")
        idx = torch.randint(0, total, (batch_size,), device=self.device)

        # Slicing the leading dim and flattening (time, env) into one axis
        # lets a single flat index address every stored (slot, env) entry.
        obs_flat = self.observations[: self._size].flatten(0, 1)
        act_flat = self.actions[: self._size].flatten(0, 1)
        rew_flat = self.rewards[: self._size].flatten(0, 1)
        done_flat = self.dones[: self._size].flatten(0, 1)
        next_obs_flat = self.next_observations[: self._size].flatten(0, 1)

        return (
            obs_flat[idx],
            act_flat[idx],
            rew_flat[idx],
            next_obs_flat[idx],
            done_flat[idx],
        )

    @property
    def size(self) -> int:
        """Number of valid timesteps stored."""
        return self._size
