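"""Teacher/student DINO video models (frames resized to 512x512 for DINO -> 32x32 patch grid).

Teacher: frozen DINO run on every frame -> features (B, 8, D, 32, 32).
Student: DINO on the first frame only. forward() returns (feats, rgb), where feats is the
first-frame feature map broadcast to (B, 8, D, 32, 32) and rgb is a 2-step MIP-style
prediction of the 8-frame video, shape (B, 8, 3, 256, 256), values in [-1, 1].
"""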

from pathlib import Path

import torch
import torch.nn as nn
from einops import rearrange

from .dino_loader import load_dino, extract_patch_features, DinoNormalize

NUM_FRAMES = 8  # frames per video clip
DINO_SIZE = 512  # side length (pixels) frames are resized to before DINO
DINO_PATCH_SIZE = 16  # pixels per DINO patch side
# Patch grid is derived from DINO_SIZE so the two cannot drift apart: 512 // 16 = 32.
PATCH_H = PATCH_W = DINO_SIZE // DINO_PATCH_SIZE

# -----------------------------------------------------------------------------
# Previous student model (video self-attention + per-frame decoder), kept commented out so it can be restored.
# -----------------------------------------------------------------------------
# class VideoSelfAttentionBlock(nn.Module):
#     """Single transformer block: LayerNorm -> self-attention -> residual -> LayerNorm -> FFN -> residual."""
#
#     def __init__(self, embed_dim: int, num_heads: int = 8, mlp_ratio: float = 4.0):
#         super().__init__()
#         self.norm1 = nn.LayerNorm(embed_dim)
#         self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
#         self.norm2 = nn.LayerNorm(embed_dim)
#         self.mlp = nn.Sequential(
#             nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
#             nn.GELU(),
#             nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
#         )
#
#     def forward(self, x):
#         x = x + self._attn_block(self.norm1(x))
#         x = x + self._mlp_block(self.norm2(x))
#         return x
#
#     def _attn_block(self, x):
#         return self.attn(x, x, x, need_weights=False)[0]
#
#     def _mlp_block(self, x):
#         return self.mlp(x)
#
#
# class StudentDinoVideo(nn.Module):
#     """DINO on first frame (trainable) -> broadcast over 8 frames + spatial/temporal embeddings + video self-attention."""
#
#     def __init__(self, keygrip_root: Path, num_attn_layers: int = 4, num_heads: int = 8):
#         super().__init__()
#         self.dino = load_dino(keygrip_root, freeze=False)
#         self.embed_dim = self.dino.embed_dim
#         self.norm = DinoNormalize(dino_size=DINO_SIZE)
#         self.spatial_embed = nn.Parameter(torch.randn(1, self.embed_dim, 1, PATCH_H, PATCH_W) * 0.02)
#         self.time_embed = nn.Parameter(torch.randn(1, self.embed_dim, NUM_FRAMES, 1, 1) * 0.02)
#         self.attn_blocks = nn.ModuleList([
#             VideoSelfAttentionBlock(self.embed_dim, num_heads=num_heads)
#             for _ in range(num_attn_layers)
#         ])
#         dec_channels1 = self.embed_dim // 2
#         dec_channels2 = max(self.embed_dim // 4, 64)
#         dec_channels3 = max(self.embed_dim // 8, 32)
#         self.decoder = nn.Sequential(
#             nn.Conv2d(self.embed_dim, dec_channels1, kernel_size=3, padding=1),
#             nn.GELU(),
#             nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
#             nn.Conv2d(dec_channels1, dec_channels2, kernel_size=3, padding=1),
#             nn.GELU(),
#             nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
#             nn.Conv2d(dec_channels2, dec_channels3, kernel_size=3, padding=1),
#             nn.GELU(),
#             nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
#             nn.Conv2d(dec_channels3, 3, kernel_size=1),
#             nn.Tanh(),
#         )
#
#     def forward(self, x):
#         B = x.shape[0]
#         first = self.norm(x[:, :, 0])
#         first_feat = extract_patch_features(self.dino, first, None, None)
#         z = first_feat.unsqueeze(2).expand(B, self.embed_dim, NUM_FRAMES, PATCH_H, PATCH_W)
#         z = z + self.spatial_embed + self.time_embed
#         z = rearrange(z, "b d t h w -> b (t h w) d")
#         for blk in self.attn_blocks:
#             z = blk(z)
#         z = rearrange(z, "b (t h w) d -> b d t h w", t=NUM_FRAMES, h=PATCH_H, w=PATCH_W)
#         feats = z.permute(0, 2, 1, 3, 4)
#         z_2d = feats.reshape(B * NUM_FRAMES, self.embed_dim, PATCH_H, PATCH_W)
#         rgb_2d = self.decoder(z_2d)
#         rgb = rgb_2d.view(B, NUM_FRAMES, 3, 256, 256)
#         return feats, rgb
# -----------------------------------------------------------------------------
class TeacherDinoVideo(nn.Module):
    """Frozen DINO applied independently to every frame of a video.

    Frames are normalized with ``DinoNormalize(dino_size=DINO_SIZE)`` (512x512 -> 32x32
    patches) and encoded in one batched DINO call. Returns (B, T, D, 32, 32).
    """

    def __init__(self, keygrip_root: Path):
        super().__init__()
        # freeze=True: the teacher is a fixed feature target and is not trained.
        self.dino = load_dino(keygrip_root, freeze=True)
        self.embed_dim = self.dino.embed_dim  # D, width of each DINO patch feature
        self.norm = DinoNormalize(dino_size=DINO_SIZE)

    def to(self, *args, **kwargs):
        """Move/cast the module; accepts exactly the arguments of ``nn.Module.to``.

        Forwarding ``*args, **kwargs`` keeps ``.to(device)`` working while also allowing
        ``dtype=``, ``non_blocking=`` etc. ``self.dino`` is moved explicitly as well, in
        case the object returned by ``load_dino`` is not tracked as a registered submodule.
        """
        super().to(*args, **kwargs)
        self.dino = self.dino.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        """Encode every frame with frozen DINO.

        Args:
            x: video tensor (B, C, T, H, W); the docstring contract of this pipeline is
                (B, 3, 8, 256, 256).

        Returns:
            (B, T, D, 32, 32) patch features. The whole pass runs without autograd, since
            the teacher is frozen and its output is a fixed target.
        """
        B, _, T, _, _ = x.shape
        # Fold time into batch so all frames go through DINO in a single call.
        frames = rearrange(x, "b c t h w -> (b t) c h w")
        feat = extract_patch_features(self.dino, self.norm(frames), None, None)
        return rearrange(feat, "(b t) d h w -> b t d h w", b=B, t=T)


class StudentDinoVideo(nn.Module):
    """Single-frame DINO + Minimal Iterative Policy (MIP)-style 2-step RGB video regression.

    We treat the "action" as the whole video in pixel space. The same network is invoked twice:
    - step 0: pi(o, 0, 0)
    - step 1: pi(o, I_{t*}, t*)  (training uses GT noisy interpolant; inference uses t* * a0_hat)

    Only the first frame of the input video is encoded by DINO (loaded with freeze=False).
    Outputs are in [-1, 1] (Tanh head). The forward() path is deterministic inference
    (returns the step-1 output). Use `mip_train_preds()` to get both step predictions for
    supervision.
    """

    def __init__(self, keygrip_root: Path):
        super().__init__()
        self.dino = load_dino(keygrip_root, freeze=False)
        self.embed_dim = self.dino.embed_dim
        self.norm = DinoNormalize(dino_size=DINO_SIZE)
        # Per-pixel regression head: 32x32 -> 256x256. The 3 RGB channels of every frame are
        # stacked along the channel axis, frame-major (channel = frame * 3 + rgb).
        self.out_channels = 3 * NUM_FRAMES  # 24 for 8 frames
        dec_c1 = self.embed_dim // 2
        dec_c2 = max(self.embed_dim // 4, 64)
        dec_c3 = max(self.embed_dim // 8, 32)

        # Condition on I_t: resize it to the patch grid and project to embed_dim (1x1 conv).
        self.it_proj = nn.Conv2d(self.out_channels, self.embed_dim, kernel_size=1)
        # Scalar t -> per-channel bias of width embed_dim.
        self.t_embed = nn.Sequential(
            nn.Linear(1, self.embed_dim),
            nn.GELU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        # Each ConvTranspose2d(kernel 4, stride 2, padding 1) exactly doubles H and W.
        self.head = nn.Sequential(
            nn.Conv2d(self.embed_dim, dec_c1, kernel_size=3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(dec_c1, dec_c2, kernel_size=4, stride=2, padding=1),  # 32 -> 64
            nn.GELU(),
            nn.ConvTranspose2d(dec_c2, dec_c3, kernel_size=4, stride=2, padding=1),  # 64 -> 128
            nn.GELU(),
            nn.ConvTranspose2d(dec_c3, dec_c3, kernel_size=4, stride=2, padding=1),  # 128 -> 256
            nn.GELU(),
            nn.Conv2d(dec_c3, self.out_channels, kernel_size=1),
            nn.Tanh(),
        )

    def _encode_first(self, x):
        """Encode only the first frame of x (B, 3, T, H, W) with DINO -> (B, D, 32, 32)."""
        first = self.norm(x[:, :, 0])  # (B, 3, H, W)
        return extract_patch_features(self.dino, first, None, None)  # (B, D, 32, 32)

    def _unflatten(self, a_flat):
        """Reshape channel-stacked frames (B, 3*NUM_FRAMES, H, W) -> (B, NUM_FRAMES, 3, H, W)."""
        B, _, H, W = a_flat.shape
        return a_flat.view(B, NUM_FRAMES, 3, H, W)

    def _two_step(self, first_feat, t_star):
        """Deterministic 2-step MIP inference on precomputed features; returns flat (a0, a1)."""
        a0_flat = self.pi(first_feat, I_t=None, t=0.0)
        # Step 1 (deterministic): I_t = t* * a0_hat, i.e. the interpolant with zero noise.
        a1_flat = self.pi(first_feat, I_t=t_star * a0_flat, t=t_star)
        return a0_flat, a1_flat

    def pi(self, first_feat, I_t, t):
        """pi_theta(o_feat, I_t, t) -> (B, 3*NUM_FRAMES, 256, 256), values in [-1, 1].

        Args:
            first_feat: (B, D, h, w) DINO patch features of the first frame (h = w = 32).
            I_t: (B, 3*NUM_FRAMES, H, W) interpolant, or None for the all-zero step-0 input.
                It is bilinearly resized to the feature grid, so any H, W is accepted.
            t: Python float, a 1-element tensor, or a tensor with B elements (per-sample t).
        """
        B = first_feat.shape[0]
        grid_hw = tuple(first_feat.shape[-2:])
        if I_t is None:
            # Resizing an all-zero map yields zeros, so build them directly at the feature
            # resolution instead of allocating a full-resolution map first.
            it_small = first_feat.new_zeros((B, self.out_channels) + grid_hw)
        else:
            it_small = torch.nn.functional.interpolate(
                I_t, size=grid_hw, mode="bilinear", align_corners=False
            )
        cond = self.it_proj(it_small)
        # Build t in float32 (not first_feat.dtype) so e.g. 0.9 is not rounded by fp16/bf16
        # before embedding, then match t_embed's parameter dtype.
        t = torch.as_tensor(t, device=first_feat.device, dtype=torch.float32)
        t = t.reshape(1, 1).expand(B, 1) if t.numel() == 1 else t.reshape(B, 1)
        t = t.to(dtype=self.t_embed[0].weight.dtype)
        # Embed scalar t and add it as a per-channel bias.
        tb = self.t_embed(t).to(dtype=first_feat.dtype).view(B, self.embed_dim, 1, 1)
        z = first_feat + cond + tb
        return self.head(z)

    def mip_train_preds(self, x, target_video, t_star: float = 0.9):
        """Return (a0_hat, a1_hat) in (B, T, 3, 256, 256) for MIP training supervision.

        target_video should be (B, T, 3, 256, 256) in [-1, 1] with T = NUM_FRAMES. Step 1 is
        conditioned on I_{t*} = t* * a + (1 - t*) * noise, with noise ~ N(0, I).
        """
        B = x.shape[0]
        first_feat = self._encode_first(x)
        # Step 0: I_t = 0, t = 0
        a0_flat = self.pi(first_feat, I_t=None, t=0.0)
        # Step 1: I_{t*} = t* a + (1-t*) z  (z ~ N(0, I))
        a_flat = target_video.reshape(B, self.out_channels, *target_video.shape[-2:])
        noise = torch.randn_like(a_flat)
        I_t = (t_star * a_flat + (1.0 - t_star) * noise).to(dtype=first_feat.dtype)
        a1_flat = self.pi(first_feat, I_t=I_t, t=t_star)
        return self._unflatten(a0_flat), self._unflatten(a1_flat)

    def mip_infer_steps(self, x, t_star: float = 0.9):
        """Deterministic 2-step inference returning both steps (a0_hat, a1_hat).

        Returns:
            a0_hat: (B, T, 3, 256, 256)
            a1_hat: (B, T, 3, 256, 256)
        """
        a0_flat, a1_flat = self._two_step(self._encode_first(x), t_star)
        return self._unflatten(a0_flat), self._unflatten(a1_flat)

    def forward(self, x, t_star: float = 0.9):
        """Deterministic 2-step inference. Returns (feats, rgb) where rgb is step-1 prediction.

        feats: first-frame DINO features broadcast over frames, (B, NUM_FRAMES, D, 32, 32)
            (an expanded view, not a copy).
        rgb: (B, NUM_FRAMES, 3, 256, 256) in [-1, 1].
        """
        first_feat = self._encode_first(x)
        _, a1_flat = self._two_step(first_feat, t_star)
        rgb = self._unflatten(a1_flat)
        feats = first_feat.unsqueeze(1).expand(-1, NUM_FRAMES, -1, -1, -1)
        return feats, rgb

