"""DINO + per-voxel AdaLN-Zero MLP volume head (Peebles DiT-style conditioning).

Replaces the rank-1 bilinear F·key scoring of the volume KV head with a small
(t,z)-conditioned MLP per voxel — strictly more expressive, fixes the "per-t
heatmap collapse" diagnosed across volume v1/v2/v3.

Per Cameron 2026-05-20 ("ultrathink" turn):
  - 1×1 conv layout folds (B, T, Z) into the batch dim → the blocks run on a
    4-D (B·T·Z, d, H, W) tensor; no 6-D tensor op is ever needed.
  - Bottleneck d=D_FEAT → D_BOT → D_FEAT per block (4× cheaper than full d).
  - Shared first projection: refine DINO once to F ∈ (B, D_FEAT, H, W),
    then broadcast across (T, Z) and apply N_BLOCKS AdaLN-Zero blocks.
  - α (residual scale) and FiLM γ,β derived from cond = sin(t) + sin(z),
    final cond_proj zero-initialised → block starts as identity at init.
  - Gripper/rotation read from the MLP penultimate (post-blocks, pre-heatmap)
    at one voxel per (b, t): the GT voxel during training (teacher forcing),
    the volume-argmax voxel at inference.
"""
import os, math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as ckpt

# Backbone location; both paths can be overridden through the environment.
DINO_REPO_DIR = os.getenv("DINO_REPO_DIR", "/data/cameron/keygrip/dinov3")
DINO_WEIGHTS_PATH = os.getenv(
    "DINO_WEIGHTS_PATH",
    "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)

# Input / output geometry.
IMG_SIZE = 448
PRED_SIZE = 56        # pixel-aligned heatmap grid (448 / 8 = 56)
N_WINDOW = 48         # near-full episode at stride 2

# Discretisation of the predicted quantities.
N_HEIGHT_BINS = 32
N_GRIPPER_BINS = 32
N_ROT_BINS = 32       # libero — keep per-axis 3D euler bins (no PCA collapse here)

# Volume-head widths.
D_FEAT = 32           # per-voxel feature dim (refined from DINO 384 → 32)
D_BOT = 8             # bottleneck inside each FiLM block
D_COND = 128          # sinusoidal cond dim
N_BLOCKS = 2          # AdaLN-Zero residual blocks


def sinusoidal_features(n, dim, base=10000.0):
    """Return an (n, dim) table of sinusoidal position features.

    Row ``p`` holds ``sin(p * w_k)`` in even columns ``2k`` and ``cos(p * w_k)``
    in odd columns ``2k + 1``, with ``w_k = base ** (-2k / dim)``.

    Args:
        n: number of positions; rows correspond to positions ``0 .. n - 1``.
        dim: feature width. Odd widths are supported: the final column then
            holds a sine term with no cosine partner.
        base: frequency base of the geometric progression.

    Returns:
        A CPU tensor of shape ``(n, dim)`` in torch's default dtype.
    """
    pos = torch.arange(n, dtype=torch.float32).unsqueeze(1)             # (n, 1)
    div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * -(math.log(base) / dim))
    angles = pos * div                                                  # (n, ceil(dim / 2))
    pe = torch.zeros(n, dim)
    pe[:, 0::2] = torch.sin(angles)
    # There are only dim // 2 odd columns; for odd `dim` that is one fewer than
    # the number of frequencies, so drop the unmatched last frequency.
    pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
    return pe


class AdaLNZeroBlock(nn.Module):
    """Conditioned bottleneck residual block for (N, d, H, W) feature maps.

    Computes ``x + alpha * up(gelu(down(norm(x) * (1 + gamma) + beta)))`` where
    gamma, beta and alpha are per-sample, per-channel vectors produced from
    ``cond`` by a single linear layer. That layer is initialised to all zeros,
    so alpha starts at 0 and the block initially returns its input unchanged.
    """

    def __init__(self, d, d_bot, d_cond):
        super().__init__()
        # Parameter-free normalisation: one group spanning all d channels.
        self.norm = nn.GroupNorm(1, d, affine=False)
        # A single projection emits [gamma | beta | alpha], each of width d.
        self.cond_proj = nn.Linear(d_cond, 3 * d)
        for param in (self.cond_proj.weight, self.cond_proj.bias):
            nn.init.zeros_(param)
        # Pointwise bottleneck d → d_bot → d.
        self.down = nn.Conv2d(d, d_bot, kernel_size=1)
        self.up = nn.Conv2d(d_bot, d, kernel_size=1)

    def forward(self, x, cond):
        """Apply the block to x: (N, d, H, W) given cond: (N, d_cond)."""
        channels = x.shape[1]
        # Split the projection into three (N, d, 1, 1) modulation tensors.
        gamma, beta, alpha = (
            part.view(-1, channels, 1, 1)
            for part in self.cond_proj(cond).chunk(3, dim=-1)
        )
        modulated = self.norm(x) * (1.0 + gamma) + beta
        update = self.up(F.gelu(self.down(modulated)))
        return x + alpha * update


class DinoVolumeFiLM(nn.Module):
    """DINO backbone followed by a (t, z)-conditioned per-voxel volume head.

    Pipeline: DINO patch features → bilinear upsample to ``pred_size`` →
    ``refine`` conv stack down to ``d_feat`` channels → broadcast over the
    ``n_window`` × ``n_height_bins`` (t, z) grid → AdaLN-Zero blocks conditioned
    on sinusoidal t and z features → 1×1 conv to one logit per voxel. Gripper
    and rotation logits are linear read-outs of the block output at one voxel
    per (b, t).
    """

    def __init__(self,
                 n_window=N_WINDOW, n_height_bins=N_HEIGHT_BINS,
                 n_gripper_bins=N_GRIPPER_BINS, n_rot_bins=N_ROT_BINS,
                 d_feat=D_FEAT, d_bot=D_BOT, d_cond=D_COND, n_blocks=N_BLOCKS,
                 image_size=IMG_SIZE, pred_size=PRED_SIZE,
                 freeze_backbone=False, use_checkpoint=True):
        """Build the backbone, trunk, conditioning tables and heads.

        Args:
            n_window: number of timesteps T predicted per sample.
            n_height_bins: number of height bins Z of the volume.
            n_gripper_bins: number of gripper classes.
            n_rot_bins: number of bins per rotation axis (3 axes).
            d_feat: per-voxel feature width.
            d_bot: bottleneck width inside each AdaLN-Zero block.
            d_cond: width of the sinusoidal conditioning vector.
            n_blocks: number of AdaLN-Zero blocks.
            image_size: stored on the module; not used by ``forward``.
            pred_size: side length of the heatmap grid.
            freeze_backbone: if True, disable gradients for all DINO parameters.
            use_checkpoint: if True, gradient-checkpoint each block in training.
        """
        super().__init__()
        self.n_window       = n_window
        self.n_height_bins  = n_height_bins
        self.n_gripper_bins = n_gripper_bins
        self.n_rot_bins     = n_rot_bins
        self.d_feat         = d_feat
        self.image_size     = image_size
        self.pred_size      = pred_size
        self.use_checkpoint = use_checkpoint

        # DINO backbone, loaded from the local repo / weights given by the
        # module-level DINO_REPO_DIR and DINO_WEIGHTS_PATH.
        self.dino = torch.hub.load(DINO_REPO_DIR, 'dinov3_vits16plus',
                                    source='local', weights=DINO_WEIGHTS_PATH)
        if freeze_backbone:
            for p in self.dino.parameters():
                p.requires_grad = False
        self.embed_dim = self.dino.embed_dim

        # Refine + project DINO features down to d_feat channels.
        self.refine = nn.Sequential(
            nn.Conv2d(self.embed_dim, self.embed_dim, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(self.embed_dim, d_feat, kernel_size=1),
        )

        # Sinusoidal PE tables; non-persistent, so they are rebuilt from the
        # constructor arguments rather than stored in checkpoints.
        self.register_buffer("t_sin", sinusoidal_features(n_window, d_cond), persistent=False)
        self.register_buffer("z_sin", sinusoidal_features(n_height_bins, d_cond), persistent=False)

        # AdaLN-Zero blocks
        self.blocks = nn.ModuleList([
            AdaLNZeroBlock(d_feat, d_bot, d_cond) for _ in range(n_blocks)
        ])

        # Final heatmap readout: per-voxel scalar logit
        self.heatmap_head = nn.Conv2d(d_feat, 1, kernel_size=1)
        # Gripper / rotation linears on the block output (d_feat) at one voxel
        self.gripper_head  = nn.Linear(d_feat, n_gripper_bins)
        self.rotation_head = nn.Linear(d_feat, 3 * n_rot_bins)

    def _extract_dino_features(self, x):
        """Run the DINO blocks on ``x`` and return patch features as a map.

        Returns a contiguous (B, embed_dim, H_p, W_p) tensor, where (H_p, W_p)
        is the patch grid reported by ``prepare_tokens_with_masks``.
        """
        B = x.shape[0]
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        for blk in self.dino.blocks:
            # Explicit None check: do not rely on nn.Module truthiness.
            rope = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed is not None else None
            x_tokens = blk(x_tokens, rope)
        # The first n_storage_tokens + 1 tokens are the non-patch tokens
        # (presumably CLS + storage/register tokens — confirm against DINO).
        n_prefix = self.dino.n_storage_tokens + 1
        if self.dino.untie_cls_and_patch_norms:
            x_cls = self.dino.cls_norm(x_tokens[:, :n_prefix])
            x_pat = self.dino.norm(x_tokens[:, n_prefix:])
            x_tokens = torch.cat([x_cls, x_pat], dim=1)
        else:
            x_tokens = self.dino.norm(x_tokens)
        patch = x_tokens[:, n_prefix:]
        return patch.reshape(B, H_p, W_p, self.embed_dim).permute(0, 3, 1, 2).contiguous()

    @staticmethod
    def _argmax_zyx(volume_logits):
        """Return the (z, y, x) index of the highest logit per (b, t).

        volume_logits: (B, T, Z, H, W) → long tensor of shape (B, T, 3).
        """
        H, W = volume_logits.shape[-2:]
        flat_idx = volume_logits.flatten(2).argmax(dim=-1)             # (B, T) over Z*H*W
        z = torch.div(flat_idx, H * W, rounding_mode="floor")
        rem = flat_idx - z * (H * W)
        y = torch.div(rem, W, rounding_mode="floor")
        x = rem - y * W
        return torch.stack([z, y, x], dim=-1)

    def forward(self, rgb, kp_zyx=None):
        """Predict the voxel heatmap volume and, where possible, gripper/rotation.

        Args:
            rgb: (B, 3, IMG, IMG) image batch.
            kp_zyx: optional (B, T, 3) integer indices (z_bin, y_grid, x_grid)
                of the voxel to read gripper/rotation features from — the GT
                voxel during training (teacher forcing). Values are cast to
                long (floats truncate) and clamped to the volume bounds. If
                omitted in eval mode, the per-(b, t) argmax of the predicted
                volume is used; if omitted in training mode, the gripper and
                rotation heads are skipped.

        Returns:
            dict with
              "volume_logits":   (B, T, Z, H, W)
              "pixel_feats":     (B, d_feat, H, W)
              "gripper_logits":  (B, T, n_gripper_bins)   — when a voxel is available
              "rotation_logits": (B, T, 3, n_rot_bins)    — when a voxel is available
        """
        B = rgb.shape[0]
        T, Z, d = self.n_window, self.n_height_bins, self.d_feat

        # Trunk: DINO → upsample to pred_size → refine to d_feat
        patch = self._extract_dino_features(rgb)                       # (B, embed, H_p, W_p)
        feat_up = F.interpolate(patch, size=(self.pred_size, self.pred_size),
                                 mode='bilinear', align_corners=False)
        F_feat = self.refine(feat_up)                                  # (B, d_feat, H, W)
        H, W = F_feat.shape[-2:]

        # Per-(t,z) conditioning: cond_tz = sin(t) + sin(z), repeated for every b.
        cond_tz = self.t_sin.view(T, 1, -1) + self.z_sin.view(1, Z, -1)    # (T, Z, d_cond)
        cond_flat = cond_tz.view(1, T, Z, -1).expand(B, T, Z, -1).reshape(B * T * Z, -1)

        # Broadcast F across (T, Z). expand() is a view; reshape() then copies
        # it into a dense tensor for the convs.
        # Memory ≈ B*T*Z * d_feat * H * W * 4B
        x = F_feat.view(B, 1, 1, d, H, W).expand(B, T, Z, d, H, W).reshape(B * T * Z, d, H, W)

        # Apply AdaLN-Zero blocks (with gradient checkpointing on each block to fit memory)
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                x = ckpt(blk, x, cond_flat, use_reentrant=False)
            else:
                x = blk(x, cond_flat)

        penult = x                                                      # (N, d, H, W) — for grip/rot head

        # Heatmap readout
        logit = self.heatmap_head(x)                                    # (N, 1, H, W)
        volume_logits = logit.view(B, T, Z, H, W)

        out = {"volume_logits": volume_logits, "pixel_feats": F_feat}

        # At inference, fall back to the predicted argmax voxel so callers do
        # not need a second forward pass just to obtain gripper/rotation.
        if kp_zyx is None and not self.training:
            kp_zyx = self._argmax_zyx(volume_logits)

        # Gripper/rotation from penultimate at GT/argmax voxel per (b, t)
        if kp_zyx is not None:
            # Index tensors must be integer typed and live with the features.
            kp_zyx = kp_zyx.to(device=x.device, dtype=torch.long)
            z_idx = kp_zyx[..., 0].clamp(0, Z - 1)                      # (B, T)
            y_idx = kp_zyx[..., 1].clamp(0, H - 1)
            x_idx = kp_zyx[..., 2].clamp(0, W - 1)
            # penult flat index: b*T*Z + t*Z + z_idx
            t_idx = torch.arange(T, device=x.device).view(1, T).expand(B, T)
            b_idx = torch.arange(B, device=x.device).view(B, 1).expand(B, T)
            flat  = (b_idx * (T * Z) + t_idx * Z + z_idx).flatten()     # (B*T,)
            yf, xf = y_idx.flatten(), x_idx.flatten()
            sampled = penult[flat, :, yf, xf].reshape(B, T, d)          # (B, T, d_feat)
            out["gripper_logits"]  = self.gripper_head(sampled)         # (B, T, n_grip)
            rot = self.rotation_head(sampled).view(B, T, 3, self.n_rot_bins)
            out["rotation_logits"] = rot
        return out


if __name__ == "__main__":
    # Smoke test: build the model, run one eval-mode forward pass with random
    # inputs and print the shape of every tensor in the output dict.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    m = DinoVolumeFiLM(n_window=N_WINDOW).to(device).eval()
    n_t = sum(p.numel() for p in m.parameters() if p.requires_grad)
    print(f"Trainable: {n_t:,}")
    batch = 2
    rgb = torch.rand(batch, 3, m.image_size, m.image_size).to(device)
    # Random (z, y, x) voxel per (b, t); the upper bounds come from the model's
    # own configuration so they cannot drift from the volume dimensions.
    kp = torch.stack([
        torch.randint(0, high, (batch, m.n_window))
        for high in (m.n_height_bins, m.pred_size, m.pred_size)
    ], dim=-1).to(device)
    with torch.no_grad():
        out = m(rgb, kp)
    for k, v in out.items():
        if hasattr(v, 'shape'):
            print(f"  {k}: {tuple(v.shape)}")
