"""DinoVideoModel: predict future VAE latent tokens from a single DINO-encoded frame.

Uses a 3-step diffusion process with x₀-prediction. During training, a random noise
level is sampled, the clean VAE tokens are noised to that level, and the 3D conv
blocks predict the clean tokens x₀. During inference, 3 reverse diffusion steps
denoise from pure Gaussian noise to predicted VAE tokens, using each x₀ prediction
to take a DDPM posterior step.

Architecture:
  1. Frozen DINOv3 ViT-S/16+ extracts (B, D, 16, 16) patch features from the first frame
  2. Temporal expand + temporal embeddings → (B, D, N_FRAMES, 16, 16)
  3. Concatenate with noised VAE tokens z_t → (B, D + vae_embed_dim, N_FRAMES, 16, 16)
  4. Add timestep embedding (broadcast over spatial/temporal dims)
  5. 3D conv blocks predict clean tokens x₀ → (B, vae_embed_dim, N_FRAMES, 16, 16)
"""

import os
import math

import torch
import torch.nn as nn

# Local DINOv3 checkout passed to torch.hub.load(source="local");
# override with the DINO_REPO_DIR environment variable.
DINO_REPO_DIR = os.getenv("DINO_REPO_DIR", "/data/cameron/dinov3")
# Pretrained backbone checkpoint handed to the hub loader;
# override with the DINO_WEIGHTS_PATH environment variable.
DINO_WEIGHTS_PATH = os.getenv(
    "DINO_WEIGHTS_PATH",
    "/data/cameron/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)
DINO_PATCH_SIZE = 16  # DINO patch size (the "/16" in the backbone name)
N_FRAMES = 6  # default number of latent frames predicted per input image
VAE_EMBED_DIM = 16  # default channel count of a VAE latent token
N_DIFFUSION_STEPS = 3  # default number of discrete noise levels

# Cosine noise schedule: alpha_bar values for 3 steps (t=0 cleanest, t=2 noisiest)
# alpha_bar_t = cos²(π/2 · (t+1)/(T+1)) gives a smooth decay
def _cosine_alpha_bar_schedule(n_steps):
    """Cosine schedule for alpha_bar: alpha_bar[t] = cos²(π/2 · (t+1)/(T+1))."""
    return [math.cos(math.pi / 2 * (t + 1) / (n_steps + 1)) ** 2 for t in range(n_steps)]

ALPHA_BAR = _cosine_alpha_bar_schedule(N_DIFFUSION_STEPS)
# For 3 steps: ≈ [0.8536, 0.5, 0.1464] — clean → medium → noisiest.
# NOTE(review): even the noisiest level keeps √0.1464 ≈ 0.38 of the signal, while
# DinoVideoModel.sample() starts from pure N(0, I) noise — a train/inference mismatch.


class DinoVideoModel(nn.Module):
    """Predict future VAE latent tokens via 3-step diffusion from a single DINO-encoded frame.

    The denoiser is x₀-parameterized. ``forward`` returns the network output as a
    prediction of the clean tokens, paired with the clean target ``z_0``, and
    ``sample`` consumes it the same way. The method that runs the network keeps its
    historical name ``predict_noise`` so existing callers keep working.
    """

    def __init__(self, n_frames=N_FRAMES, vae_embed_dim=VAE_EMBED_DIM,
                 n_diffusion_steps=N_DIFFUSION_STEPS):
        """Build the model and load the frozen DINOv3 backbone from the local hub repo.

        Args:
            n_frames: number of latent frames predicted per input image
            vae_embed_dim: channels of each VAE latent token
            n_diffusion_steps: number of discrete noise levels in the cosine schedule
        """
        super().__init__()
        self.n_frames = n_frames
        self.vae_embed_dim = vae_embed_dim
        self.n_diffusion_steps = n_diffusion_steps
        self.patch_size = DINO_PATCH_SIZE

        # Noise schedule (registered as buffer so they move with .to(device))
        alpha_bar = _cosine_alpha_bar_schedule(n_diffusion_steps)
        self.register_buffer("alpha_bar", torch.tensor(alpha_bar, dtype=torch.float32))
        print(f"Diffusion schedule ({n_diffusion_steps} steps): alpha_bar = {[f'{a:.4f}' for a in alpha_bar]}")

        # Frozen DINO backbone. Assigning it to self registers it as a submodule, so
        # nn.Module.to()/.cuda()/.half() move it together with the rest of the model.
        print("Loading DINOv3 model...")
        self.dino = torch.hub.load(
            DINO_REPO_DIR,
            "dinov3_vits16plus",
            source="local",
            weights=DINO_WEIGHTS_PATH,
        )
        for param in self.dino.parameters():
            param.requires_grad = False
        self.dino.eval()
        self.embed_dim = self.dino.embed_dim
        print(f"Frozen DINOv3 backbone (embed_dim={self.embed_dim})")

        D = self.embed_dim  # 384

        # Learnable temporal embeddings: (1, D, N_FRAMES, 1, 1)
        self.temporal_embed = nn.Parameter(torch.randn(1, D, n_frames, 1, 1) * 0.02)

        # Timestep embeddings: one learned vector per diffusion step, sized to the
        # channel dim of the concatenated [DINO, z_t] input.
        self.timestep_embed = nn.Embedding(n_diffusion_steps, D + vae_embed_dim)

        # 3D conv blocks: input is [DINO features, noised z_t] concatenated
        in_ch = D + vae_embed_dim  # 384 + 16 = 400
        self.conv3d_blocks = nn.Sequential(
            nn.Conv3d(in_ch, D, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv3d(D, D, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv3d(D, 256, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv3d(256, vae_embed_dim, kernel_size=1),
        )
        print(f"3D conv blocks: {in_ch} -> {D} -> {D} -> 256 -> {vae_embed_dim}")
        print(f"Output: (B, {vae_embed_dim}, {n_frames}, 16, 16) — noise prediction")

    def train(self, mode=True):
        """Set train/eval mode, but always keep the frozen DINO backbone in eval mode.

        nn.Module.train() recurses into every submodule. Without this override,
        model.train() would switch the frozen backbone into training mode and
        enable any stochastic layers it contains, even though its weights are
        frozen.
        """
        super().train(mode)
        if hasattr(self, "dino"):
            self.dino.eval()
        return self

    def _extract_dino_features(self, x):
        """Run the frozen DINO backbone and return its normalized patch features as a grid.

        Follows the backbone's token path (prepare tokens → blocks with RoPE → final
        norm). Only the patch tokens are kept; the leading CLS and storage tokens
        are dropped.

        Args:
            x: (B, 3, H, W) ImageNet-normalized input (256×256 gives a 16×16 grid)

        Returns:
            patch_features: (B, D, H_p, W_p), with H_p and W_p as reported by
                ``prepare_tokens_with_masks``
        """
        B = x.shape[0]
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        for blk in self.dino.blocks:
            rope_sincos = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed is not None else None
            x_tokens = blk(x_tokens, rope_sincos)

        # The first n_storage_tokens + 1 tokens are CLS + storage tokens; the rest are patches.
        n_prefix = self.dino.n_storage_tokens + 1
        if self.dino.untie_cls_and_patch_norms:
            x_norm_cls = self.dino.cls_norm(x_tokens[:, :n_prefix])
            x_norm_patches = self.dino.norm(x_tokens[:, n_prefix:])
            x_tokens = torch.cat([x_norm_cls, x_norm_patches], dim=1)
        else:
            x_tokens = self.dino.norm(x_tokens)

        patch_tokens = x_tokens[:, n_prefix:]
        patch_features = patch_tokens.reshape(B, H_p, W_p, self.embed_dim)
        patch_features = patch_features.permute(0, 3, 1, 2).contiguous()  # (B, D, H_p, W_p)
        return patch_features

    def _get_dino_cond(self, x):
        """Extract DINO features (without gradients) and expand them over time.

        Args:
            x: (B, 3, H, W) ImageNet-normalized input

        Returns:
            cond: (B, D, N_FRAMES, H_p, W_p) — DINO features + temporal embeddings
        """
        with torch.no_grad():
            feats = self._extract_dino_features(x)  # (B, D, H_p, W_p)
        # Temporal expansion + embeddings (temporal_embed stays trainable)
        cond = feats.unsqueeze(2).expand(-1, -1, self.n_frames, -1, -1)
        cond = cond + self.temporal_embed
        return cond

    def _denoise(self, cond, z_t, t):
        """Run the conv denoiser on precomputed DINO conditioning (shared by training and sampling).

        Args:
            cond: (B, D, N_FRAMES, H_p, W_p) output of ``_get_dino_cond``
            z_t: (B, vae_embed_dim, N_FRAMES, H_p, W_p) noised VAE tokens
            t: (B,) integer timestep indices in [0, n_diffusion_steps)

        Returns:
            (B, vae_embed_dim, N_FRAMES, H_p, W_p) network output (an x₀ prediction)
        """
        # Concatenate DINO conditioning with noised tokens along channels
        h = torch.cat([cond, z_t], dim=1)  # (B, D + vae_embed_dim, N_FRAMES, H_p, W_p)

        # Add timestep embedding (broadcast over spatial and temporal dims)
        t_emb = self.timestep_embed(t)  # (B, D + vae_embed_dim)
        h = h + t_emb[:, :, None, None, None]

        return self.conv3d_blocks(h)

    def predict_noise(self, x, z_t, t):
        """Run the denoiser for one diffusion step, conditioned on the input frame.

        Despite its name, ``forward`` and ``sample`` both use the output as a
        prediction of the clean tokens x₀, not of the noise ε.

        Args:
            x: (B, 3, 256, 256) ImageNet-normalized input frame
            z_t: (B, vae_embed_dim, N_FRAMES, 16, 16) noised VAE tokens
            t: (B,) integer timestep indices in [0, n_diffusion_steps)

        Returns:
            (B, vae_embed_dim, N_FRAMES, 16, 16) network output (x₀ prediction)
        """
        cond = self._get_dino_cond(x)  # (B, D, N_FRAMES, 16, 16)
        return self._denoise(cond, z_t, t)

    def forward_diffusion(self, z_0, t):
        """Forward diffusion: z_t = √ᾱ_t · z_0 + √(1 - ᾱ_t) · ε with ε ~ N(0, I).

        Args:
            z_0: (B, ...) clean VAE tokens, e.g. (B, C, T, H, W)
            t: (B,) integer timestep indices

        Returns:
            z_t: noised tokens, same shape as z_0
            eps: the noise that was added, same shape as z_0
        """
        # (B,) → (B, 1, ..., 1) so the per-sample level broadcasts over z_0's other dims.
        alpha_bar_t = self.alpha_bar[t].reshape(-1, *([1] * (z_0.dim() - 1)))

        eps = torch.randn_like(z_0)
        z_t = torch.sqrt(alpha_bar_t) * z_0 + torch.sqrt(1.0 - alpha_bar_t) * eps
        return z_t, eps

    def forward(self, x, z_0=None, t=None):
        """Training forward pass: noise z_0 to a random level and predict clean x₀.

        Args:
            x: (B, 3, 256, 256) ImageNet-normalized input
            z_0: (B, vae_embed_dim, N_FRAMES, 16, 16) clean target VAE tokens (required)
            t: (B,) optional timestep indices. If None, sampled uniformly per element.

        Returns:
            x0_pred: (B, vae_embed_dim, N_FRAMES, 16, 16) predicted clean tokens
            z_0: (B, vae_embed_dim, N_FRAMES, 16, 16) actual clean tokens (target)
            t: (B,) timestep indices used

        Raises:
            ValueError: if z_0 is None (use ``sample`` for inference).
        """
        if z_0 is None:
            raise ValueError("forward() requires the clean target tokens z_0; use sample() for inference")

        B = x.shape[0]

        # Sample random timestep per batch element
        if t is None:
            t = torch.randint(0, self.n_diffusion_steps, (B,), device=x.device)

        # Forward diffusion: add noise to clean tokens (the noise itself is not the target)
        z_t, _ = self.forward_diffusion(z_0, t)

        # The network output is trained against z_0, i.e. it is an x₀ prediction
        x0_pred = self.predict_noise(x, z_t, t)

        return x0_pred, z_0, t

    @torch.no_grad()
    def sample(self, x):
        """Reverse diffusion (x₀-prediction): denoise from pure noise to predicted VAE tokens.

        At each step the network predicts x₀ directly. On every step except the last,
        the DDPM posterior q(z_{t-1} | z_t, x₀) is sampled using that prediction. The
        final step returns the x₀ prediction itself.

        NOTE(review): with this cosine schedule the noisiest training level still keeps
        signal (alpha_bar[-1] ≈ 0.146 for 3 steps), while sampling starts from pure
        N(0, I) noise — a train/inference mismatch worth confirming.

        Args:
            x: (B, 3, H, W) ImageNet-normalized input (256×256 gives a 16×16 patch grid)

        Returns:
            z_0_pred: (B, vae_embed_dim, N_FRAMES, H_p, W_p) predicted clean VAE tokens,
                where (H_p, W_p) is the DINO patch grid of ``x``
        """
        B = x.shape[0]
        device = x.device

        # The conditioning does not depend on the diffusion step, so run the frozen
        # backbone once instead of once per reverse step.
        cond = self._get_dino_cond(x)  # (B, D, N_FRAMES, H_p, W_p)

        # Start from pure Gaussian noise on the same latent grid as the conditioning
        z = torch.randn(B, self.vae_embed_dim, *cond.shape[2:], device=device)

        # Reverse diffusion: t = T-1, T-2, ..., 0
        for step in reversed(range(self.n_diffusion_steps)):
            t = torch.full((B,), step, device=device, dtype=torch.long)
            x0_pred = self._denoise(cond, z, t)  # network predicts x₀ directly

            if step == 0:
                # Final step: just use x₀_pred directly (no noise added)
                z = x0_pred
                continue

            # DDPM posterior step; scalar coefficients are computed as Python floats
            alpha_bar_t = float(self.alpha_bar[step])
            alpha_bar_prev = float(self.alpha_bar[step - 1])
            alpha_t = alpha_bar_t / alpha_bar_prev
            beta_t = 1.0 - alpha_t
            one_minus_ab_t = 1.0 - alpha_bar_t

            coef_z0 = math.sqrt(alpha_bar_prev) * beta_t / one_minus_ab_t
            coef_zt = math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / one_minus_ab_t
            mu = coef_z0 * x0_pred + coef_zt * z

            posterior_var = (1.0 - alpha_bar_prev) / one_minus_ab_t * beta_t
            sigma = math.sqrt(posterior_var)
            z = mu + sigma * torch.randn_like(z)

        return z
