from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange


@dataclass
class UVAConfig:
    """Config for the simplified UVA-style video predictor (video-only).

    This model:
      - Encodes each frame into latent tokens via a MAR-style VAE (outside).
      - Uses a transformer over tokens to predict all future frame latents
        in one shot.
      - During training, always masks out *all* future frames: only frame 0
        tokens are given as input; all future tokens are replaced with a
        learned mask token and supervised with GT latents.
    """

    # Shape of each per-frame latent produced by the external VAE: (C, H_l, W_l).
    # One transformer token per (H_l x W_l) spatial position, with C features.
    latent_channels: int
    latent_height: int
    latent_width: int
    num_frames: int = 8  # full clip length, including first frame
    # Transformer encoder hyperparameters (nn.TransformerEncoderLayer).
    d_model: int = 512
    n_heads: int = 8  # must divide d_model evenly
    n_layers: int = 8
    dropout: float = 0.0


class UVAVideoTransformer(nn.Module):
    """One-shot future-latent predictor over a masked token sequence.

    Frame 0's latent tokens are the only real inputs; every future-frame
    position is filled with a single learned mask token. A bidirectional
    transformer encoder attends over the full (T * H * W)-token sequence
    and regresses the latents of frames 1..T-1 in one pass.
    """

    def __init__(self, cfg: "UVAConfig"):
        super().__init__()
        self.cfg = cfg

        # One token per latent spatial position; token feature dim = C.
        self.token_dim = cfg.latent_channels
        self.n_tokens_per_frame = cfg.latent_height * cfg.latent_width
        self.total_tokens = cfg.num_frames * self.n_tokens_per_frame

        self.in_proj = nn.Linear(self.token_dim, cfg.d_model)

        # Learned absolute position embedding over the whole clip's token grid.
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.total_tokens, cfg.d_model) * 0.01
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=cfg.d_model,
            nhead=cfg.n_heads,
            dim_feedforward=cfg.d_model * 4,
            dropout=cfg.dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=cfg.n_layers)

        self.out_proj = nn.Linear(cfg.d_model, self.token_dim)

        # Learned mask token substituted for every future-frame token.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.token_dim))

    def _frame_tokens(self, z_frames: torch.Tensor) -> torch.Tensor:
        """Flatten per-frame latents into one token sequence.

        Args:
            z_frames: (B, T', C, H, W) latents.

        Returns:
            (B, T' * H * W, C) tokens, frame-major then row-major
            (equivalent to einops "b t c h w -> b (t h w) c").
        """
        b, t, c, h, w = z_frames.shape
        assert c == self.token_dim
        assert h == self.cfg.latent_height and w == self.cfg.latent_width
        return z_frames.permute(0, 1, 3, 4, 2).reshape(b, t * h * w, c)

    def _tokens_to_frames(self, tokens: torch.Tensor, num_future: int) -> torch.Tensor:
        """Inverse of `_frame_tokens` for `num_future` frames.

        tokens: (B, num_future * H * W, C) -> (B, num_future, C, H, W)
        """
        b, l, c = tokens.shape
        h, w = self.cfg.latent_height, self.cfg.latent_width
        assert l == num_future * h * w
        # Equivalent to einops "b (t h w) c -> b t c h w".
        z = tokens.reshape(b, num_future, h, w, c).permute(0, 1, 4, 2, 3)
        return z.contiguous()

    def forward(
        self,
        z_first: torch.Tensor,
        z_future: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Predict all future-frame latents from the first frame's latent.

        Args:
            z_first:  (B, C, H_l, W_l) latent for frame 0.
            z_future: (B, T-1, C, H_l, W_l) GT latents for frames 1..T-1,
                used for supervision targets only — it never feeds the
                transformer input.

        Returns:
            pred_future: (B, T-1, C, H_l, W_l) predicted future latents.
            target_future_tokens: (B, (T-1)*N, C) whenever z_future is
                provided (training *or* eval), else None.
        """
        b, c, h, w = z_first.shape
        assert c == self.token_dim
        assert h == self.cfg.latent_height and w == self.cfg.latent_width

        n = h * w
        num_future = self.cfg.num_frames - 1

        # First-frame tokens: (B, N, C); equivalent to "b c h w -> b (h w) c".
        tokens_0 = z_first.permute(0, 2, 3, 1).reshape(b, n, c)

        # The transformer input never contains GT future latents: every future
        # position gets the learned mask token, in training and inference alike
        # (the two original branches built the exact same input).
        mask_tokens = self.mask_token.expand(b, num_future * n, c)
        tokens_in = torch.cat([tokens_0, mask_tokens], dim=1)  # (B, T*N, C)

        # Bug fix: targets are produced whenever GT is available, matching the
        # documented contract. Previously eval-mode calls with z_future got
        # None back, which made computing a validation loss impossible.
        target_tokens: Optional[torch.Tensor] = None
        if z_future is not None:
            assert z_future.shape[1] == num_future
            target_tokens = self._frame_tokens(z_future)  # (B, (T-1)*N, C)

        x = self.in_proj(tokens_in) + self.pos_embed[:, : tokens_in.size(1)]
        x = self.encoder(x)  # (B, T*N, d_model)
        out = self.out_proj(x)  # (B, T*N, C)

        # Slice out future-frame tokens and reshape back to frames.
        pred_future_tokens = out[:, tokens_0.size(1) :, :]  # (B, (T-1)*N, C)
        pred_future = self._tokens_to_frames(pred_future_tokens, num_future=num_future)

        return pred_future, target_tokens

