from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from rsl_rl.algorithms import PPO

from sim_improvement.rl.storage import ReplayBuffer


class BC_PPO(PPO):
    """PPO with an auxiliary behavioral cloning loss.

    Maintains a :class:`ReplayBuffer` of demonstration (obs, action) pairs.
    During each PPO update step, a BC loss (MSE between the policy's
    mean output and the demo actions) is added to the PPO objective,
    scaled by ``bc_coefficient``.
    """

    def __init__(
        self,
        *args,
        bc_coefficient: float = 0.1,
        bc_batch_size: int = 256,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.bc_coefficient = bc_coefficient
        self.bc_batch_size = bc_batch_size

        self._bc_buffer: ReplayBuffer | None = None

    # ------------------------------------------------------------------
    # BC buffer management
    # ------------------------------------------------------------------

    def add_bc_data(self, obs: torch.Tensor, actions: torch.Tensor):
        """Load demonstration transitions into the BC replay buffer.

        Args:
            obs: ``(N, obs_dim)`` flat observation tensor, concatenated in
                 the same order as the policy's actor input.
            actions: ``(N, act_dim)`` action tensor.
        """
        obs = obs.float()
        actions = actions.float()
        n, obs_dim = obs.shape
        act_dim = actions.shape[-1]

        if self._bc_buffer is None:
            # Create a buffer sized to fit the data exactly.
            # Use a single "policy" key so we can extract flat tensors
            # easily at sample time.
            dummy_obs = {"policy": torch.zeros(1, obs_dim, device=self.device)}
            self._bc_buffer = ReplayBuffer(
                num_envs=1,
                capacity=n,
                obs=dummy_obs,
                actions_shape=(act_dim,),
                device=self.device,
            )

        self._bc_buffer.add_bulk(
            obs={"policy": obs},
            actions=actions,
        )

    @property
    def has_bc_data(self) -> bool:
        return self._bc_buffer is not None and self._bc_buffer.size > 0

    def _compute_bc_loss(self) -> torch.Tensor:
        """MSE between the actor's deterministic output and demo actions."""
        obs_td, actions, _, _, _ = self._bc_buffer.sample(self.bc_batch_size)
        obs_flat = obs_td["policy"]
        pred = self.policy.actor(self.policy.actor_obs_normalizer(obs_flat))
        return F.mse_loss(pred, actions)

    def bc_warmstart(
        self,
        num_epochs: int,
        learning_rate: float | None = None,
        log_interval: int = 10,
        writer=None,
        start_step: int = 0,
        val_fraction: float = 0.1,
        eval_freq: int = 0,
        eval_callback=None,
    ) -> tuple[float, int]:
        """Run pure behavioral cloning on the replay buffer before RL.

        Args:
            num_epochs: Number of full passes over the BC data.
            learning_rate: LR for the BC optimizer. If ``None``, uses
                ``self.learning_rate`` (the PPO LR).
            log_interval: Log to writer every this many gradient steps.
            writer: Optional logger (wandb/tensorboard) with
                ``add_scalar(tag, value, step)`` interface.
            start_step: Starting step for logging (so BC steps don't
                overlap with RL steps).
            val_fraction: Fraction of BC data to hold out for validation.
            eval_freq: Run ``eval_callback`` every this many epochs.
                0 disables eval during warmstart.
            eval_callback: ``fn(step) -> dict`` called for evaluation.
                Should return a dict of ``{metric_name: value}`` to log.

        Returns:
            Tuple of (final epoch-average loss, final global step).
        """
        if not self.has_bc_data:
            raise RuntimeError("No BC data loaded. Call add_bc_data() first.")

        lr = learning_rate if learning_rate is not None else self.learning_rate
        bc_optimizer = torch.optim.Adam(self.policy.actor.parameters(), lr=lr)

        # Split BC data into train/val
        n = self._bc_buffer.size * self._bc_buffer.num_envs
        obs_all = self._bc_buffer.observations[: self._bc_buffer.size].flatten(0, 1)
        act_all = self._bc_buffer.actions[: self._bc_buffer.size].flatten(0, 1)

        perm = torch.randperm(n, device=self.device)
        n_val = max(1, int(n * val_fraction))
        n_train = n - n_val

        train_obs = obs_all[perm[:n_train]]["policy"]
        train_act = act_all[perm[:n_train]]
        val_obs = obs_all[perm[n_train:]]["policy"]
        val_act = act_all[perm[n_train:]]

        steps_per_epoch = max(1, n_train // self.bc_batch_size)

        # Seed the normalizer with BC data so the actor sees properly
        # normalized obs from the start. Without this, the normalizer
        # stays at mean=0/std=1 during warmstart (update() is never called),
        # then jumps when RL calls process_env_step → BC loss spikes.
        if self.policy.actor_obs_normalization:
            self.policy.actor_obs_normalizer.train()
            self.policy.actor_obs_normalizer.update(train_obs)
            print(f"[BC warmstart] Seeded normalizer with {n_train} demo observations")

        print(f"[BC warmstart] {num_epochs} epochs, {n_train} train / {n_val} val, "
              f"batch_size={self.bc_batch_size}, lr={lr}")

        global_step = start_step

        # Eval before any training
        if eval_callback is not None:
            metrics = eval_callback(global_step)
            if writer is not None and metrics:
                for k, v in metrics.items():
                    writer.add_scalar(f"bc/{k}", v, global_step)

        last_epoch_loss = 0.0
        for epoch in range(num_epochs):
            self.policy.train()
            epoch_loss = 0.0

            for _ in range(steps_per_epoch):
                # Sample random train batch
                idx = torch.randint(0, n_train, (self.bc_batch_size,), device=self.device)
                obs_batch = train_obs[idx]
                act_batch = train_act[idx]

                pred = self.policy.actor(self.policy.actor_obs_normalizer(obs_batch))
                loss = F.mse_loss(pred, act_batch)

                bc_optimizer.zero_grad()
                loss.backward()
                bc_optimizer.step()

                loss_val = loss.item()
                epoch_loss += loss_val
                global_step += 1

                if writer is not None and global_step % log_interval == 0:
                    writer.add_scalar("bc/train_loss", loss_val, global_step)

            avg_train_loss = epoch_loss / steps_per_epoch

            # Validation loss
            self.policy.eval()
            with torch.no_grad():
                val_steps = max(1, n_val // self.bc_batch_size)
                val_loss_sum = 0.0
                for _ in range(val_steps):
                    idx = torch.randint(0, n_val, (min(self.bc_batch_size, n_val),), device=self.device)
                    pred = self.policy.actor(self.policy.actor_obs_normalizer(val_obs[idx]))
                    val_loss_sum += F.mse_loss(pred, val_act[idx]).item()
                avg_val_loss = val_loss_sum / val_steps

            print(f"  [BC] epoch {epoch + 1}/{num_epochs}  "
                  f"train_loss={avg_train_loss:.6f}  val_loss={avg_val_loss:.6f}")

            if writer is not None:
                writer.add_scalar("bc/epoch_train_loss", avg_train_loss, global_step)
                writer.add_scalar("bc/epoch_val_loss", avg_val_loss, global_step)
                writer.add_scalar("bc/epoch", epoch + 1, global_step)

            last_epoch_loss = avg_train_loss

            # Periodic eval
            if eval_callback is not None and eval_freq > 0 and (epoch + 1) % eval_freq == 0:
                metrics = eval_callback(global_step)
                if writer is not None and metrics:
                    for k, v in metrics.items():
                        writer.add_scalar(f"bc/{k}", v, global_step)

        print(f"[BC warmstart] done — final train loss: {last_epoch_loss:.6f}")
        return last_epoch_loss, global_step

    # ------------------------------------------------------------------
    # PPO update with BC auxiliary loss
    # ------------------------------------------------------------------

    def update(self):  # noqa: C901
        """Run the PPO optimisation phase over the stored rollout, plus the BC term.

        Iterates over the mini-batches produced by ``self.storage`` and, for
        each one, builds the PPO loss (clipped surrogate + value loss -
        entropy bonus). When demonstration data is loaded, a freshly sampled
        BC loss scaled by ``bc_coefficient`` is added to that loss before the
        backward pass. Optional symmetry and RND losses are handled as well.
        The rollout storage is cleared before returning.

        Returns:
            Dict of losses averaged over
            ``num_learning_epochs * num_mini_batches`` updates, with keys
            ``"value_function"``, ``"surrogate"``, ``"entropy"`` and
            ``"bc"`` (``0.0`` when no demo data is loaded), plus ``"rnd"``
            and ``"symmetry"`` when those features are enabled.
        """
        # Running sums of the per-mini-batch losses; averaged after the loop.
        mean_value_loss = 0
        mean_surrogate_loss = 0
        mean_entropy = 0
        mean_bc_loss = 0.0

        # ``None`` marks an optional loss as disabled so it is skipped when
        # accumulating and averaging.
        if self.rnd:
            mean_rnd_loss = 0
        else:
            mean_rnd_loss = None
        if self.symmetry:
            mean_symmetry_loss = 0
        else:
            mean_symmetry_loss = None

        # mini-batch generator
        if self.policy.is_recurrent:
            generator = self.storage.recurrent_mini_batch_generator(
                self.num_mini_batches, self.num_learning_epochs
            )
        else:
            generator = self.storage.mini_batch_generator(
                self.num_mini_batches, self.num_learning_epochs
            )

        # Each iteration performs one gradient step on one mini-batch.
        for (
            obs_batch,
            actions_batch,
            target_values_batch,
            advantages_batch,
            returns_batch,
            old_actions_log_prob_batch,
            old_mu_batch,
            old_sigma_batch,
            hid_states_batch,
            masks_batch,
        ) in generator:
            # Number of copies of each sample in the batch; stays 1 unless
            # symmetric augmentation enlarges the batch below.
            num_aug = 1
            original_batch_size = obs_batch.batch_size[0]

            if self.normalize_advantage_per_mini_batch:
                with torch.no_grad():
                    advantages_batch = (advantages_batch - advantages_batch.mean()) / (
                        advantages_batch.std() + 1e-8
                    )

            # -- symmetric augmentation
            if self.symmetry and self.symmetry["use_data_augmentation"]:
                data_augmentation_func = self.symmetry["data_augmentation_func"]
                obs_batch, actions_batch = data_augmentation_func(
                    obs=obs_batch, actions=actions_batch, env=self.symmetry["_env"],
                )
                num_aug = int(obs_batch.batch_size[0] / original_batch_size)
                # Tile the per-sample targets so they line up with the
                # enlarged observation/action batch.
                old_actions_log_prob_batch = old_actions_log_prob_batch.repeat(num_aug, 1)
                target_values_batch = target_values_batch.repeat(num_aug, 1)
                advantages_batch = advantages_batch.repeat(num_aug, 1)
                returns_batch = returns_batch.repeat(num_aug, 1)

            # -- recompute log-probs / values / entropy
            # The return value of act() is discarded; the call is made so the
            # policy's action distribution reflects the current parameters.
            self.policy.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
            actions_log_prob_batch = self.policy.get_actions_log_prob(actions_batch)
            value_batch = self.policy.evaluate(
                obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1],
            )
            # Only the first ``original_batch_size`` rows (the un-augmented
            # samples) are used for the KL estimate and the entropy bonus.
            mu_batch = self.policy.action_mean[:original_batch_size]
            sigma_batch = self.policy.action_std[:original_batch_size]
            entropy_batch = self.policy.entropy[:original_batch_size]

            # -- adaptive LR via KL
            if self.desired_kl is not None and self.schedule == "adaptive":
                with torch.inference_mode():
                    # Closed-form KL between the old and the current diagonal
                    # Gaussian policies, summed over action dimensions.
                    kl = torch.sum(
                        torch.log(sigma_batch / old_sigma_batch + 1e-5)
                        + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch))
                        / (2.0 * torch.square(sigma_batch))
                        - 0.5,
                        axis=-1,
                    )
                    kl_mean = torch.mean(kl)
                    if self.is_multi_gpu:
                        torch.distributed.all_reduce(kl_mean, op=torch.distributed.ReduceOp.SUM)
                        kl_mean /= self.gpu_world_size
                    # Rank 0 decides the new learning rate (kept within
                    # [1e-5, 1e-2]); it is broadcast to the other ranks below.
                    if self.gpu_global_rank == 0:
                        if kl_mean > self.desired_kl * 2.0:
                            self.learning_rate = max(1e-5, self.learning_rate / 1.5)
                        elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                            self.learning_rate = min(1e-2, self.learning_rate * 1.5)
                    if self.is_multi_gpu:
                        lr_tensor = torch.tensor(self.learning_rate, device=self.device)
                        torch.distributed.broadcast(lr_tensor, src=0)
                        self.learning_rate = lr_tensor.item()
                    for param_group in self.optimizer.param_groups:
                        param_group["lr"] = self.learning_rate

            # -- surrogate loss
            # Both terms are negated objectives, so taking the element-wise
            # max selects the more pessimistic of the clipped/unclipped pair.
            ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
            surrogate = -torch.squeeze(advantages_batch) * ratio
            surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
                ratio, 1.0 - self.clip_param, 1.0 + self.clip_param,
            )
            surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

            # -- value loss
            if self.use_clipped_value_loss:
                # Limit how far the new value estimate may move away from the
                # value recorded at rollout time.
                value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(
                    -self.clip_param, self.clip_param,
                )
                value_losses = (value_batch - returns_batch).pow(2)
                value_losses_clipped = (value_clipped - returns_batch).pow(2)
                value_loss = torch.max(value_losses, value_losses_clipped).mean()
            else:
                value_loss = (returns_batch - value_batch).pow(2).mean()

            loss = (
                surrogate_loss
                + self.value_loss_coef * value_loss
                - self.entropy_coef * entropy_batch.mean()
            )

            # -- BC auxiliary loss
            # A new demonstration batch is sampled for every mini-batch; the
            # BC term shares the backward pass and optimizer step with PPO.
            # NOTE(review): _compute_bc_loss calls self.policy.actor directly
            # on flat observations — presumably only valid for non-recurrent
            # policies; confirm.
            if self.has_bc_data:
                bc_loss = self._compute_bc_loss()
                loss = loss + self.bc_coefficient * bc_loss
                mean_bc_loss += bc_loss.item()

            # -- symmetry loss
            if self.symmetry:
                # When augmentation was not applied above, augment the
                # observations here so the symmetry loss can still be computed.
                if not self.symmetry["use_data_augmentation"]:
                    data_augmentation_func = self.symmetry["data_augmentation_func"]
                    obs_batch, _ = data_augmentation_func(
                        obs=obs_batch, actions=None, env=self.symmetry["_env"],
                    )
                    num_aug = int(obs_batch.batch_size[0] / original_batch_size)
                mean_actions_batch = self.policy.act_inference(obs_batch.detach().clone())
                action_mean_orig = mean_actions_batch[:original_batch_size]
                _, actions_mean_symm_batch = data_augmentation_func(
                    obs=None, actions=action_mean_orig, env=self.symmetry["_env"],
                )
                # Compare the policy's output on the augmented observations
                # with the augmented version of its output on the originals;
                # the first ``original_batch_size`` rows are excluded.
                mse_loss = torch.nn.MSELoss()
                symmetry_loss = mse_loss(
                    mean_actions_batch[original_batch_size:],
                    actions_mean_symm_batch.detach()[original_batch_size:],
                )
                if self.symmetry["use_mirror_loss"]:
                    loss += self.symmetry["mirror_loss_coeff"] * symmetry_loss
                else:
                    # Kept for logging only; detached so it adds no gradient.
                    symmetry_loss = symmetry_loss.detach()

            # -- RND loss
            # Not added to ``loss``: it is back-propagated separately and
            # stepped with its own optimizer below.
            if self.rnd:
                with torch.no_grad():
                    rnd_state_batch = self.rnd.get_rnd_state(obs_batch[:original_batch_size])
                    rnd_state_batch = self.rnd.state_normalizer(rnd_state_batch)
                predicted_embedding = self.rnd.predictor(rnd_state_batch)
                target_embedding = self.rnd.target(rnd_state_batch).detach()
                rnd_loss = torch.nn.MSELoss()(predicted_embedding, target_embedding)

            # -- backward + step
            self.optimizer.zero_grad()
            loss.backward()
            if self.rnd:
                self.rnd_optimizer.zero_grad()
                rnd_loss.backward()
            if self.is_multi_gpu:
                self.reduce_parameters()
            # Gradient clipping covers the policy parameters only.
            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.optimizer.step()
            # NOTE(review): the step is guarded by ``self.rnd_optimizer`` while
            # the backward pass above is guarded by ``self.rnd`` — presumably
            # both are set together in the parent class; confirm.
            if self.rnd_optimizer:
                self.rnd_optimizer.step()

            # -- accumulate stats
            mean_value_loss += value_loss.item()
            mean_surrogate_loss += surrogate_loss.item()
            mean_entropy += entropy_batch.mean().item()
            if mean_rnd_loss is not None:
                mean_rnd_loss += rnd_loss.item()
            if mean_symmetry_loss is not None:
                mean_symmetry_loss += symmetry_loss.item()

        # -- average over all updates
        # NOTE(review): assumes the generator yields exactly
        # num_learning_epochs * num_mini_batches batches — confirm against
        # the storage implementation.
        num_updates = self.num_learning_epochs * self.num_mini_batches
        mean_value_loss /= num_updates
        mean_surrogate_loss /= num_updates
        mean_entropy /= num_updates
        mean_bc_loss /= num_updates
        if mean_rnd_loss is not None:
            mean_rnd_loss /= num_updates
        if mean_symmetry_loss is not None:
            mean_symmetry_loss /= num_updates

        # The rollout has been consumed; reset storage for the next one.
        self.storage.clear()

        loss_dict = {
            "value_function": mean_value_loss,
            "surrogate": mean_surrogate_loss,
            "entropy": mean_entropy,
            "bc": mean_bc_loss,
        }
        if self.rnd:
            loss_dict["rnd"] = mean_rnd_loss
        if self.symmetry:
            loss_dict["symmetry"] = mean_symmetry_loss

        return loss_dict
