"""Video latent model: DINO first frame -> 16x16 patches; learnable patches for other frames; self-attention; conv up to 32x32."""

from pathlib import Path

import torch
import torch.nn as nn
from einops import rearrange

from .dino_featurizer import DINOFeaturizer

# Spatial size of the predicted latent grid (the model's output height/width).
# NOTE: spatial_upsample doubles resolution, so this should stay 2 * PATCH_H / 2 * PATCH_W.
LATENT_H = 32
LATENT_W = 32
# Spatial size of the token grid; the first frame's DINO features are expected at this size.
PATCH_H = 16
PATCH_W = 16
# Number of frames per input clip, and the temporal stride of the pooling conv.
NUM_FRAMES = 8
TIME_DOWN = 2
# Number of latent time steps after temporal pooling.
LATENT_T = NUM_FRAMES // TIME_DOWN
# Channel count of the predicted latent.
Z_CHANNELS = 8


class VideoLatentModel(nn.Module):
    """Predict VidTok-style continuous latents (B, C, T, 32, 32) from 8-frame 256x256 video.

    Only the first frame of the input is read. It is passed through a DINO featurizer
    (under ``torch.no_grad()``, so no gradients flow into DINO) to get a 16x16 grid of
    patch tokens. Frames 1..NUM_FRAMES-1 are represented by learnable 16x16 token grids
    plus a learned time embedding. All NUM_FRAMES x 16 x 16 tokens are mixed by a stack
    of pre-norm self-attention layers, pooled in time by a strided Conv3d (factor
    ``TIME_DOWN``), and upsampled to 32x32 per latent time step.

    Args:
        keygrip_root: Path handed to ``DINOFeaturizer`` (presumably where it finds its
            weights/code — not visible here).
        hidden_dim: Token width inside the transformer.
        num_heads: Attention heads per transformer layer.
        num_layers: Number of transformer encoder layers.
    """

    def __init__(self, keygrip_root: Path, hidden_dim: int = 384, num_heads: int = 6, num_layers: int = 4):
        super().__init__()
        self.keygrip_root = Path(keygrip_root)
        self.dino = DINOFeaturizer(keygrip_root)
        dino_dim = self.dino.embed_dim

        # 1x1 conv mapping DINO feature channels to the transformer width.
        self.dino_proj = nn.Conv2d(dino_dim, hidden_dim, kernel_size=1)
        # One learnable (hidden_dim, PATCH_H, PATCH_W) token grid per non-first frame.
        self.other_frames_embed = nn.Parameter(torch.randn(1, NUM_FRAMES - 1, hidden_dim, PATCH_H, PATCH_W) * 0.02)
        # NOTE(review): forward() looks this up with indices 0..NUM_FRAMES-2 for frames
        # 1..NUM_FRAMES-1, so the first (DINO) frame gets no time embedding and row
        # NUM_FRAMES-1 is never used. Kept as-is so existing checkpoints keep their meaning.
        self.time_embed = nn.Embedding(NUM_FRAMES, hidden_dim)
        nn.init.normal_(self.time_embed.weight, mean=0.0, std=0.02)

        self.attn_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim * 4,
                dropout=0.0,
                activation="gelu",
                batch_first=True,
                norm_first=True,
            )
            for _ in range(num_layers)
        ])
        # Final norm for the pre-norm stack, whose residual output is otherwise un-normalized.
        self.attn_norm = nn.LayerNorm(hidden_dim)

        # (B, hidden, NUM_FRAMES, 16, 16) -> (B, hidden, LATENT_T, 16, 16): strided conv over time only.
        self.temporal_pool = nn.Conv3d(hidden_dim, hidden_dim, kernel_size=(TIME_DOWN, 1, 1), stride=(TIME_DOWN, 1, 1))
        # Per latent time step: (hidden, 16, 16) -> (Z_CHANNELS, 32, 32).
        # ConvTranspose2d with kernel 4, stride 2, padding 1 exactly doubles H and W.
        self.spatial_upsample = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(hidden_dim, Z_CHANNELS, 4, stride=2, padding=1),
        )
        self._init_weights()

    def _init_weights(self):
        """Re-initialise the projection, pooling and upsampling convs.

        Weights get Xavier-uniform init with gain 0.02 and biases are zeroed. Presumably
        this is so the model's initial latent predictions start close to zero. The
        transformer layers keep their default initialisation.
        """
        # Same module order as before (dino_proj, temporal_pool, then upsample convs),
        # so RNG consumption -- and hence seeded init -- is unchanged.
        candidates = [self.dino_proj, self.temporal_pool, *self.spatial_upsample]
        for m in candidates:
            if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
                nn.init.xavier_uniform_(m.weight, gain=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Predict latents for a batch of clips.

        Args:
            x: (B, 3, T, 256, 256) video in [-1, 1]; T is nominally NUM_FRAMES. Only
                frame 0 is read -- the other frames are replaced by learned tokens.

        Returns:
            Tensor of shape (B, Z_CHANNELS, LATENT_T, LATENT_H, LATENT_W) = (B, 8, 4, 32, 32).

        Raises:
            ValueError: if ``x`` is not a 5-D (B, 3, T, H, W) tensor with at least one frame.
        """
        if x.ndim != 5 or x.shape[1] != 3 or x.shape[2] < 1:
            raise ValueError(f"expected x of shape (B, 3, T, H, W) with T >= 1, got {tuple(x.shape)}")

        B = x.shape[0]
        first = x[:, :, 0]  # (B, 3, 256, 256)
        # DINO is used as a fixed feature extractor: no autograd graph is built through it.
        with torch.no_grad():
            dino_feat = self.dino(first)  # (B, D, 16, 16)
        first_tokens = self.dino_proj(dino_feat)  # (B, hidden, 16, 16)

        # Learnable tokens for frames 1..NUM_FRAMES-1, shared across the batch.
        other_tokens = self.other_frames_embed.expand(B, -1, -1, -1, -1)  # (B, 7, hidden, 16, 16)
        time_idx = torch.arange(NUM_FRAMES - 1, device=x.device)
        other_tokens = other_tokens + self.time_embed(time_idx).view(1, NUM_FRAMES - 1, -1, 1, 1)

        tokens = torch.cat([first_tokens.unsqueeze(1), other_tokens], dim=1)  # (B, 8, hidden, 16, 16)
        # Flatten time and space into one sequence for joint spatio-temporal self-attention.
        tokens = rearrange(tokens, "b t c h w -> b (t h w) c")
        for layer in self.attn_layers:
            tokens = layer(tokens)
        tokens = self.attn_norm(tokens)
        tokens = rearrange(tokens, "b (t h w) c -> b c t h w", t=NUM_FRAMES, h=PATCH_H, w=PATCH_W)

        tokens = self.temporal_pool(tokens)  # (B, hidden, LATENT_T, 16, 16)
        latent_t = tokens.shape[2]
        # Fold time into the batch so each latent time step is upsampled independently by 2-D convs.
        tokens = rearrange(tokens, "b c t h w -> (b t) c h w")
        out = self.spatial_upsample(tokens)  # (B * LATENT_T, Z_CHANNELS, 32, 32)
        return rearrange(out, "(b t) c h w -> b c t h w", b=B, t=latent_t)
