"""
DINO-conditioned diffusion in VAE latent space.
- Tokenizer & diffusion from simple_uva (UVA).
- Single-frame DINO encoder at 32x32 from dino_vid_model.
- Denoising net conditioned on DINO features: per-token MLP (default), 3D conv, or self-attention transformer.
"""

import math
import sys
from pathlib import Path

import torch
import torch.nn as nn

# Make the sibling `unified_video_action` checkout importable so that the
# `simple_uva` package (VAE tokenizer / diffusion utilities) resolves below.
# The directory is prepended to sys.path only if it exists and is not already there.
_uva_root = Path(__file__).resolve().parents[1] / "unified_video_action"
if _uva_root.exists():
    if str(_uva_root) not in sys.path:
        sys.path.insert(0, str(_uva_root))

from simple_uva.diffusion import create_diffusion
from .dino_encoder import load_dino_encoder, encode_frame

# VAE tokenizer geometry: each frame becomes a SEQ_H x SEQ_W latent grid with
# TOKEN_DIM channels; with patch_size=1 every grid cell is one token.
VAE_LATENT_SCALE = 0.2325  # multiplier applied to VAE latents before tokenization
SEQ_H = SEQ_W = 16
TOKENS_PER_FRAME = SEQ_H * SEQ_W  # 256 tokens per frame
TOKEN_DIM = 16
NUM_FRAMES = 8
SEQ_LEN = NUM_FRAMES * TOKENS_PER_FRAME  # 2048 tokens per clip


def patchify_latent(latent):
    """Flatten a latent grid into a token sequence.

    latent: (B, C, H, W), e.g. (B, 16, 16, 16) for this VAE.
    Returns (B, H*W, C): one token per spatial position, in row-major order.
    """
    batch, channels, _height, _width = latent.shape
    return latent.reshape(batch, channels, -1).transpose(1, 2)


def unpatchify_latent(tokens):
    """Fold a row-major token sequence back into a square latent grid.

    Inverse of ``patchify_latent`` for square grids.

    tokens: (B, S, C) with S a perfect square (S = 256 -> 16x16 for this VAE).
    Returns (B, C, sqrt(S), sqrt(S)).

    Raises:
        ValueError: if S is not a perfect square.
    """
    batch, seq, channels = tokens.shape
    # Infer the grid side from the token count instead of hard-coding 16.
    side = math.isqrt(seq)
    if side * side != seq:
        raise ValueError(f"Expected a square number of tokens, got {seq}")
    return tokens.permute(0, 2, 1).reshape(batch, channels, side, side)


def frame_to_tokens(frame, vae, scale=VAE_LATENT_SCALE):
    """Tokenize a batch of single frames.

    frame: (B, 3, H, W). The VAE posterior's mode is used as the latent
    (computed without gradient tracking), multiplied by ``scale`` and flattened
    with ``patchify_latent``. Returns (B, 256, 16) for the 16x16x16 latent
    this module assumes.
    """
    with torch.no_grad():
        latent = vae.encode(frame).mode()
    return patchify_latent(latent * scale)

def video_to_tokens(video, vae, scale=VAE_LATENT_SCALE):
    """
    Tokenize a clip with a single batched VAE encode over all frames.

    video: (B, 3, T, H, W) in [-1,1] (or compatible).
    Returns (B, T*S, D), where each frame contributes S tokens of dim D
    (S=256, D=16 for the 16x16x16 latent this module assumes). Tokens are
    frame-major: all of frame 0's tokens come first.
    """
    B, C, T, H, W = video.shape
    # (B, C, T, H, W) -> (B*T, C, H, W) so every frame is encoded in one call.
    frames = video.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
    with torch.no_grad():
        z = vae.encode(frames).mode() * scale  # (B*T, D, h, w)
    tokens = patchify_latent(z)  # (B*T, S, D)
    # Take S and D from the latent itself rather than the module constants, so
    # a VAE with a different latent size is tokenized correctly.
    return tokens.reshape(B, T * tokens.shape[1], tokens.shape[2])


def tokens_to_frame(tokens, vae, scale=VAE_LATENT_SCALE):
    """Decode one frame's tokens back to pixels.

    tokens: (B, 256, 16). The tokens are folded back into a latent grid,
    divided by ``scale`` (undoing the tokenizer scaling), and decoded by the
    VAE without gradient tracking. Returns ``vae.decode``'s output, expected
    to be (B, 3, H, W).
    """
    latent = unpatchify_latent(tokens)
    latent = latent / scale
    with torch.no_grad():
        decoded = vae.decode(latent)
    return decoded


def tokens_to_video(tokens, vae, scale=VAE_LATENT_SCALE, t=NUM_FRAMES):
    """
    Decode a clip of frame-major tokens back to a video.

    tokens: (B, t*TOKENS_PER_FRAME, D), as produced by ``video_to_tokens``.
    t: number of frames in the clip.
    Returns (B, C_out, t, H_out, W_out). The channel count and spatial size are
    taken from the VAE decoder output (e.g. (B, 3, t, 256, 256)). The value
    range depends on the VAE (typically [-1, 1]).

    Raises:
        ValueError: if the token count does not equal t * TOKENS_PER_FRAME.
    """
    B, S, C = tokens.shape
    if S != t * TOKENS_PER_FRAME:
        raise ValueError(f"Expected tokens second dim {t * TOKENS_PER_FRAME}, got {S}")
    tokens_bt = tokens.reshape(B * t, TOKENS_PER_FRAME, C)
    z = unpatchify_latent(tokens_bt) / scale  # (B*t, D, h, w)
    with torch.no_grad():
        frames = vae.decode(z)  # (B*t, C_out, H_out, W_out)
    # Use the decoder's actual channel count instead of assuming RGB.
    frames = frames.reshape(B, t, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
    return frames


def modulate(x, shift, scale):
    """AdaLN-style affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` must broadcast against ``x``; zero shift and
    scale give back ``x``.
    """
    return shift + x * (scale + 1)


class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps into ``hidden_size`` vectors.

    Timesteps are expanded into fixed sinusoidal features of width
    ``frequency_embedding_size`` and then passed through a two-layer SiLU MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Sinusoidal features for a 1-D tensor of timesteps.

        t: (N,) timesteps (cast to float).
        Returns (N, dim). The layout is the cosines of ``dim // 2``
        geometrically spaced frequencies (from 1 down toward 1/max_period),
        then the matching sines, then one zero column when ``dim`` is odd.
        """
        n_freq = dim // 2
        exponents = torch.arange(start=0, end=n_freq, dtype=torch.float32, device=t.device)
        freqs = torch.exp(-math.log(max_period) * exponents / n_freq)
        phases = t[:, None].float() * freqs[None]
        features = [torch.cos(phases), torch.sin(phases)]
        if dim % 2:
            # Pad so the output width is exactly ``dim``.
            features.append(torch.zeros_like(phases[:, :1]))
        return torch.cat(features, dim=-1)

    def forward(self, t):
        """Map (N,) timesteps to (N, hidden_size) embeddings."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))


class ResBlock(nn.Module):
    """Residual per-token MLP block with AdaLN conditioning.

    Same pattern as UVA's SimpleMLPAdaLN block. The conditioning ``y`` yields
    a shift, scale and gate. The LayerNorm-ed input is modulated by
    shift/scale, passed through a SiLU MLP, and added back to ``x`` scaled by
    the gate.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Construction order fixes parameter names (state_dict keys) and the
        # sequence of RNG draws at init; keep it stable.
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True)
        )

    def forward(self, x, y):
        """x: (..., channels) features; y: (..., channels) conditioning whose
        shift/scale/gate must broadcast against x. Returns a tensor shaped like x."""
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        update = self.mlp(modulate(self.in_ln(x), shift, scale))
        return x + gate * update


class FinalLayer(nn.Module):
    """Output head: affine-free LayerNorm, AdaLN shift/scale from ``y``, then a
    linear projection from ``model_channels`` to ``out_channels``."""

    def __init__(self, model_channels: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True)
        )

    def forward(self, x, y):
        """x: (..., model_channels) features; y: (..., model_channels)
        conditioning. Returns (..., out_channels)."""
        shift, scale = self.adaLN_modulation(y).chunk(2, dim=-1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))


class DinoCondDiffusionNet(nn.Module):
    """
    Per-token MLP denoiser (UVA SimpleMLPAdaLN-style).

    Inputs are flattened over batch and sequence:
      x: (N, in_channels) noisy VAE tokens,
      t: (N,) timesteps,
      c: (N, model_channels) conditioning, already projected to the model width
         upstream (cond_proj + video_pos_embed in DinoDiffusion).
    The timestep embedding plus ``cond_embed(c)`` drives AdaLN in every residual
    block and in the final layer. Every token is processed independently.
    Output: (N, out_channels), by default 2 * TOKEN_DIM.
    """

    def __init__(
        self,
        in_channels=TOKEN_DIM,
        model_channels=512,
        out_channels=TOKEN_DIM * 2,
        num_res_blocks=6,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.model_channels = model_channels
        self.num_res_blocks = num_res_blocks

        # Submodule order determines state_dict keys and init RNG order.
        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(model_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.res_blocks = nn.ModuleList([ResBlock(model_channels) for _ in range(num_res_blocks)])
        self.final_layer = FinalLayer(model_channels, out_channels)
        self._init_weights()

    def _init_weights(self):
        def _xavier_linear(module):
            if not isinstance(module, nn.Linear):
                return
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_xavier_linear)
        # Small-std init for both timestep-MLP linears (first, then second).
        for idx in (0, 2):
            nn.init.normal_(self.time_embed.mlp[idx].weight, std=0.02)
        # Zero every AdaLN projection and the output head so the net initially
        # outputs zeros.
        zero_layers = [block.adaLN_modulation[-1] for block in self.res_blocks]
        zero_layers += [self.final_layer.adaLN_modulation[-1], self.final_layer.linear]
        for layer in zero_layers:
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x, t, c):
        """x: (N, in_channels), t: (N,), c: (N, model_channels) -> (N, out_channels)."""
        cond = self.time_embed(t) + self.cond_embed(c)
        h = self.input_proj(x)
        for block in self.res_blocks:
            h = block(h, cond)
        return self.final_layer(h, cond)


class Conv3DBlock(nn.Module):
    """Residual 3D conv block conditioned on one vector per sample.

    The path is GroupNorm, then AdaLN shift/scale from y (B, cond_dim)
    broadcast over T, H, W, then a 3x3x3 conv with padding 1, added back to
    the input.
    """

    def __init__(self, channels: int, cond_dim: int, num_groups: int = 8):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, channels)  # channels must be divisible by num_groups
        self.conv = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, 2 * channels, bias=True),
        )

    def forward(self, x, y):
        """x: (B, C, T, H, W); y: (B, cond_dim). Returns a tensor shaped like x."""
        params = self.adaLN_modulation(y)[:, :, None, None, None]  # (B, 2C, 1, 1, 1)
        shift, scale = params.chunk(2, dim=1)
        return x + self.conv(modulate(self.norm(x), shift, scale))


class DinoCondDiffusionNetConv3D(nn.Module):
    """
    Denoising net over the latent volume.

    The flat token batch (B*SEQ_LEN, in_channels) is reshaped to
    (B, in_channels, NUM_FRAMES, SEQ_H, SEQ_W) and lifted by a 1x1x1 conv. It
    then passes through ``num_blocks`` Conv3DBlocks and is projected back by
    another 1x1x1 conv to (B*SEQ_LEN, out_channels).

    Conditioning is global per sample. ``c`` (N, model_channels) is averaged
    over the sequence, and only the timestep of each sample's FIRST token is
    used, so callers must give every token of a sample the same timestep.
    """

    def __init__(
        self,
        in_channels=TOKEN_DIM,
        model_channels=512,
        out_channels=TOKEN_DIM * 2,
        num_blocks=5,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(model_channels, model_channels)
        self.in_proj = nn.Conv3d(in_channels, model_channels, kernel_size=1)
        self.blocks = nn.ModuleList([
            Conv3DBlock(model_channels, model_channels) for _ in range(num_blocks)
        ])
        self.out_proj = nn.Conv3d(model_channels, out_channels, kernel_size=1)
        self._init_weights()

    def _init_weights(self):
        def _basic_init(m):
            if isinstance(m, (nn.Linear, nn.Conv3d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        self.apply(_basic_init)
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
        # Zero AdaLN projections and the output conv so the net starts at zero output.
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.out_proj.weight, 0)
        nn.init.constant_(self.out_proj.bias, 0)

    def forward(self, x, t, c):
        """
        Denoise a flat token batch.

        x: (N, in_channels) noisy tokens; t: (N,) timesteps;
        c: (N, model_channels) conditioning. N = B * SEQ_LEN, ordered
        sample-major and then frame-major.
        Returns (N, out_channels).

        Raises:
            ValueError: if N is not a multiple of SEQ_LEN.
        """
        N = x.shape[0]
        # Explicit check (an `assert` would be stripped under `python -O`).
        if N % SEQ_LEN != 0:
            raise ValueError(f"N={N} must be divisible by SEQ_LEN={SEQ_LEN}")
        B = N // SEQ_LEN
        # (N, C) -> (B, C, T, H, W)
        x = x.reshape(B, SEQ_LEN, self.in_channels).permute(0, 2, 1)
        x = x.reshape(B, self.in_channels, NUM_FRAMES, SEQ_H, SEQ_W)
        # Global conditioning: one vector per sample.
        t_batch = t.reshape(B, SEQ_LEN)[:, 0]  # (B,) first token's timestep
        c_global = c.reshape(B, SEQ_LEN, -1).mean(dim=1)  # (B, model_channels)
        y = self.time_embed(t_batch) + self.cond_embed(c_global)  # (B, model_channels)

        x = self.in_proj(x)
        for block in self.blocks:
            x = block(x, y)
        x = self.out_proj(x)
        # (B, out_channels, T, H, W) -> (B, T, H, W, out_channels) -> (N, out_channels)
        return x.permute(0, 2, 3, 4, 1).reshape(N, self.out_channels)


class TransformerLayer(nn.Module):
    """Pre-norm transformer block: self-attention then a 4x SiLU MLP.

    Each sub-block gets its own AdaLN shift/scale/gate computed from the
    per-token conditioning ``y``.
    """

    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.ln1 = nn.LayerNorm(channels, eps=1e-6)
        self.ln2 = nn.LayerNorm(channels, eps=1e-6)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(channels, 4 * channels, bias=True),
            nn.SiLU(),
            nn.Linear(4 * channels, channels, bias=True),
        )
        self.adaLN_modulation1 = nn.Sequential(
            nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True)
        )
        self.adaLN_modulation2 = nn.Sequential(
            nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True)
        )

    def forward(self, x, y):
        """
        x: (B, S, C) token features.
        y: (B, S, C) per-token conditioning (time + DINO).
        Returns (B, S, C).
        """
        attn_shift, attn_scale, attn_gate = self.adaLN_modulation1(y).chunk(3, dim=-1)
        mlp_shift, mlp_scale, mlp_gate = self.adaLN_modulation2(y).chunk(3, dim=-1)

        # Gated self-attention residual.
        attn_in = modulate(self.ln1(x), attn_shift, attn_scale)
        attn_out = self.attn(attn_in, attn_in, attn_in, need_weights=False)[0]
        x = x + attn_gate * attn_out

        # Gated MLP residual.
        mlp_in = modulate(self.ln2(x), mlp_shift, mlp_scale)
        return x + mlp_gate * self.mlp(mlp_in)


class DinoCondDiffusionNetTransformer(nn.Module):
    """
    Self-attention transformer denoiser over the token sequence.

    Reshapes (N, C) -> (B, SEQ_LEN, C) and runs several TransformerLayer blocks.
    It then maps back to (N, out_channels) through an AdaLN FinalLayer. It uses
    the same (time + cond) AdaLN conditioning pattern as the MLP denoiser, per
    token. The net adds no positional encoding of its own, so token positions
    are only visible through ``c``.
    """

    def __init__(
        self,
        in_channels=TOKEN_DIM,
        model_channels=512,
        out_channels=TOKEN_DIM * 2,
        num_layers=4,
        num_heads=8,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.model_channels = model_channels
        self.num_layers = num_layers
        self.num_heads = num_heads

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(model_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.layers = nn.ModuleList(
            [TransformerLayer(model_channels, num_heads=num_heads) for _ in range(num_layers)]
        )
        self.final_layer = FinalLayer(model_channels, out_channels)
        self._init_weights()

    def _init_weights(self):
        def _basic_init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        self.apply(_basic_init)
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
        # Zero AdaLN projections and the output head so the net starts at zero output.
        for layer in self.layers:
            nn.init.constant_(layer.adaLN_modulation1[-1].weight, 0)
            nn.init.constant_(layer.adaLN_modulation1[-1].bias, 0)
            nn.init.constant_(layer.adaLN_modulation2[-1].weight, 0)
            nn.init.constant_(layer.adaLN_modulation2[-1].bias, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, c):
        """
        Denoise a flat token batch.

        x: (N, in_channels) noisy tokens; t: (N,) timesteps;
        c: (N, model_channels) conditioning. N = B * SEQ_LEN, sample-major.
        Returns (N, out_channels).

        Raises:
            ValueError: if N is not a multiple of SEQ_LEN.
        """
        N = x.shape[0]
        # Explicit check (an `assert` would be stripped under `python -O`).
        if N % SEQ_LEN != 0:
            raise ValueError(f"N={N} must be divisible by SEQ_LEN={SEQ_LEN}")
        B = N // SEQ_LEN

        h = self.input_proj(x)  # (N, model_channels)
        y = self.time_embed(t) + self.cond_embed(c)  # (N, model_channels)

        # Attention mixes tokens within each sample: (N, C) -> (B, S, C).
        h = h.reshape(B, SEQ_LEN, self.model_channels)
        y_seq = y.reshape(B, SEQ_LEN, self.model_channels)
        for layer in self.layers:
            h = layer(h, y_seq)

        # The output head acts per token, so go back to (N, C).
        return self.final_layer(h.reshape(N, self.model_channels), y)


class DinoDiffusion(nn.Module):
    """
    Full model: DINO encoder (single frame -> 32x32 features) plus diffusion on residuals in VAE token space.

    - Conditioning: DINO tokens and first-frame VAE tokens, projected by
      ``cond_proj`` and broadcast to all frames with learned video position
      embeddings.
    - Target: residual tokens r0 = x_true - x_ref, where x_ref is the first-frame
      tokens broadcast to all frames.
    - Noise: K = num_sampling_steps levels, ``noise_levels`` = linspace(0, 1, K + 1)
      (sigma_0 = 0). Training draws one step k in {1..K} per clip, forms
      r_k = r0 + sigma_k * eps, and regresses r0. Sampling walks k = K..1.
    - Denoiser: "conv3d" or "transformer" select those nets; any other value
      (default "mlp") uses the per-token MLP.
    """

    def __init__(
        self,
        keygrip_root: Path,
        dino_freeze=True,
        num_sampling_steps="1",
        hidden_size=512,
        num_res_blocks=6,
        denoiser="mlp",
    ):
        super().__init__()
        self.dino, self.dino_norm = load_dino_encoder(Path(keygrip_root), freeze=dino_freeze)
        self.dino_dim = getattr(self.dino, "embed_dim", 384)
        self.seq_len = SEQ_LEN
        self.token_dim = TOKEN_DIM
        self.num_frames = NUM_FRAMES
        self.tokens_per_frame = TOKENS_PER_FRAME
        self.model_channels = hidden_size
        # Number of iterative denoising steps (1 = pure one-step).
        self.num_steps = max(1, int(num_sampling_steps))
        # Noise levels sigma_k for k=0..K; sigma_0=0 (clean), sigma_K=max.
        noise_max = 1.0
        noise_min = 0.0
        noise_levels = torch.linspace(noise_min, noise_max, self.num_steps + 1)
        self.register_buffer("noise_levels", noise_levels)

        # Project [first_frame_tokens, dino_tokens] -> model_channels
        self.cond_proj = nn.Linear(self.token_dim + self.dino_dim, hidden_size)
        self.video_pos_embed = nn.Parameter(torch.zeros(1, SEQ_LEN, hidden_size))

        if denoiser == "conv3d":
            self.net = DinoCondDiffusionNetConv3D(
                in_channels=TOKEN_DIM,
                model_channels=hidden_size,
                out_channels=TOKEN_DIM * 2,
                num_blocks=5,
            )
        elif denoiser == "transformer":
            self.net = DinoCondDiffusionNetTransformer(
                in_channels=TOKEN_DIM,
                model_channels=hidden_size,
                out_channels=TOKEN_DIM * 2,
                num_layers=4,
                num_heads=8,
            )
        else:
            self.net = DinoCondDiffusionNet(
                in_channels=TOKEN_DIM,
                model_channels=hidden_size,
                out_channels=TOKEN_DIM * 2,
                num_res_blocks=num_res_blocks,
            )
        self.in_channels = TOKEN_DIM

        # Initialize video position embeddings
        nn.init.normal_(self.video_pos_embed, std=0.02)

    def get_dino_cond(self, frame):
        """Encode one frame with DINO and pool it onto the VAE token grid.

        frame: (B, 3, H, W). The 4-D feature map from ``encode_frame`` is
        adaptively average-pooled to (SEQ_H, SEQ_W) and flattened row-major.
        Returns (B, TOKENS_PER_FRAME, dino_dim), aligned with one frame's VAE
        tokens (not the full SEQ_LEN clip).

        Raises:
            ValueError: if the DINO features are not 4-D.
        """
        feat = encode_frame(self.dino, self.dino_norm, frame)
        if feat.dim() != 4:
            raise ValueError(f"Expected 4-D DINO features (B, D, H, W), got shape {tuple(feat.shape)}")
        feat = torch.nn.functional.adaptive_avg_pool2d(feat, (SEQ_H, SEQ_W))
        return feat.flatten(2).transpose(1, 2)

    def _check_dino_cond(self, dino_cond, batch_size):
        """Raise ValueError unless dino_cond is (batch_size, TOKENS_PER_FRAME, dino_dim)."""
        expected = (batch_size, TOKENS_PER_FRAME, self.dino_dim)
        if dino_cond.shape != expected:
            raise ValueError(f"Expected dino_cond shape {expected}, got {dino_cond.shape}")

    def _build_cond(self, first_tokens, dino_cond):
        """Project [first-frame tokens, DINO tokens] and broadcast them over all frames.

        first_tokens: (B, TOKENS_PER_FRAME, token_dim); dino_cond: (B, TOKENS_PER_FRAME, dino_dim).
        Returns (B * SEQ_LEN, model_channels). Every frame shares the same
        per-position conditioning and is offset by the learned video position
        embedding.
        """
        B = first_tokens.shape[0]
        cond = self.cond_proj(torch.cat([first_tokens, dino_cond], dim=-1))  # (B, 256, C)
        cond = cond.unsqueeze(1).expand(B, self.num_frames, self.tokens_per_frame, self.model_channels)
        cond = cond.reshape(B, SEQ_LEN, self.model_channels) + self.video_pos_embed
        return cond.reshape(B * SEQ_LEN, self.model_channels)

    def _broadcast_first_frame(self, first_tokens):
        """Repeat first-frame tokens (B, TOKENS_PER_FRAME, token_dim) across all frames -> (B, SEQ_LEN, token_dim)."""
        B = first_tokens.shape[0]
        return first_tokens.unsqueeze(1).expand(
            B, self.num_frames, self.tokens_per_frame, self.token_dim
        ).reshape(B, SEQ_LEN, self.token_dim)

    def compute_loss(self, target_tokens, dino_cond):
        """
        Compute the training loss for a batch of clips.

        target_tokens: (B, SEQ_LEN, token_dim) VAE tokens for all frames (this is x0).
        dino_cond: (B, TOKENS_PER_FRAME, dino_dim) from get_dino_cond(first_frame).

        The net predicts the residual r0 = x0 - x_ref from a noised residual,
        where x_ref is the first-frame tokens broadcast across all frames. The
        full latent is x0 = x_ref + r0. Returns the scalar MSE between the
        predicted and true r0.

        Raises:
            ValueError: on an unexpected target_tokens length or dino_cond shape.
        """
        B, seq_len, _ = target_tokens.shape
        if seq_len != SEQ_LEN:
            raise ValueError(f"Expected target_tokens seq_len={SEQ_LEN}, got {seq_len}")
        self._check_dino_cond(dino_cond, B)

        first_tokens = target_tokens[:, :TOKENS_PER_FRAME, :]  # (B, 256, token_dim)
        cond_flat = self._build_cond(first_tokens, dino_cond)  # (N, model_channels)
        x_ref = self._broadcast_first_frame(first_tokens)
        r0_flat = (target_tokens - x_ref).reshape(B * seq_len, -1)  # (N, token_dim)

        # Draw ONE noise step per clip and share it across all of that clip's
        # tokens, matching `sample` (one step for the whole clip). The conv3d
        # denoiser also reads only each clip's first timestep, so per-token
        # steps would condition its other tokens on the wrong noise level.
        k = torch.randint(1, self.num_steps + 1, (B,), device=target_tokens.device)
        k = k.repeat_interleave(seq_len)  # (N,), sample-major like r0_flat
        sigma = self.noise_levels[k].unsqueeze(1)  # (N, 1)
        r_t = r0_flat + sigma * torch.randn_like(r0_flat)

        # Raw step indices are used directly as timesteps for the embedder.
        pred_all = self.net(r_t, k, cond_flat)  # (N, 2 * token_dim)
        pred_r0 = pred_all[:, : self.token_dim]  # first half is the r0 prediction
        return torch.mean((pred_r0 - r0_flat) ** 2)

    def sample(self, first_tokens, dino_cond, temperature=1.0, device=None):
        """
        Sample full latent tokens x0 for all frames, conditioned on first-frame tokens and DINO.

        first_tokens: (B, TOKENS_PER_FRAME, token_dim) first-frame VAE tokens.
        dino_cond: (B, TOKENS_PER_FRAME, dino_dim) DINO tokens from the same frame.
        temperature: scales the noise injected between steps (not the initial noise).
        device: where sampling noise is drawn; defaults to the model's device.
        Returns: (B, SEQ_LEN, token_dim) tokens for all frames.
        """
        if device is None:
            device = next(self.parameters()).device
        B = first_tokens.shape[0]
        N = B * SEQ_LEN
        self._check_dino_cond(dino_cond, B)

        # Same conditioning/reference construction as compute_loss.
        cond_flat = self._build_cond(first_tokens, dino_cond)
        x_ref_flat = self._broadcast_first_frame(first_tokens).reshape(N, self.token_dim)

        # Start from noise at the highest level. For k = K..1, predict r0, then
        # re-noise to sigma_{k-1}; the last step keeps the prediction as-is.
        K = self.num_steps
        r = torch.randn(N, self.token_dim, device=device) * self.noise_levels[K]
        for k in range(K, 0, -1):
            t = torch.full((N,), k, dtype=torch.long, device=device)
            r0_flat = self.net(r, t, cond_flat)[:, : self.token_dim]
            if k > 1:
                eps = torch.randn_like(r0_flat) * temperature
                r = r0_flat + self.noise_levels[k - 1] * eps
            else:
                r = r0_flat

        return (x_ref_flat + r).view(B, SEQ_LEN, self.token_dim)


def build_dino_diffusion(keygrip_root, num_sampling_steps="1", denoiser="mlp", **kwargs):
    """Construct a ``DinoDiffusion``.

    keygrip_root: str or Path, wrapped in ``Path`` before being passed on.
    num_sampling_steps: number of denoising steps (``DinoDiffusion`` applies ``int``).
    denoiser: "mlp", "conv3d" or "transformer".
    Any other keyword arguments (e.g. dino_freeze, hidden_size, num_res_blocks)
    are forwarded unchanged.
    """
    options = dict(kwargs)
    options.update(
        keygrip_root=Path(keygrip_root),
        num_sampling_steps=num_sampling_steps,
        denoiser=denoiser,
    )
    return DinoDiffusion(**options)
