"""DINO + per-timestep query MLP with AdaLN-Zero(t) conditioning.

Cameron's redesign (2026-05-20): keep the volume formulation but move *all* the
nonlinearity into the per-timestep query, so the spatial scoring stays a cheap
dot product. Architecturally this is cross-attention from per-timestep query
tokens to the (factored) volume features.

Computation graph
─────────────────
  rgb → DINO → patch tokens + cls
      → bilinear upsample to pred_size → 3×3 conv + GELU + 1×1 conv → F (B, d_feat, H, W)   # spatial feature map
      → cls token        (B, embed_dim)

  eef_feat = F[b, :, y_eef, x_eef]                        # current EEF feature
  q_input  = Linear(concat(eef_feat, cls)) → (B, d_model)   # Linear(cls) alone when use_eef=False
  # 5-layer residual MLP with AdaLN-Zero on sin(t), applied per timestep
  # (B, T) copies of the same input; AdaLN(t) differentiates per-step
  penult = MLP_with_AdaLN_t(q_input)                       # (B, T, d_model)

  q_F, q_z, q_t = split(q_head(penult), [d_feat, d_sin_z, d_sin_t])
  gripper       = grip_head(penult)                        # (B, T, n_grip)
  rotation      = rot_head(penult)                         # (B, T, 3, n_rot) per-axis euler bins (default); (B, T, K) for '1d_pca' / 'kmeans'

  # Spatial scoring: dot product of q with the *implicit* volume
  # V[b, t, z, y, x] = concat(F[y,x], sin_z[z], sin_t[t]) — never materialised.
  # The concat structure lets the dot product factor:
  score_yx = einsum('btc, bchw -> bthw', q_F, F)           # (B, T, H, W)
  score_z  = einsum('btc, zc -> btz',    q_z, z_sin)       # (B, T, Z)
  score_t  = einsum('btc, tc -> bt',     q_t, t_sin)       # (B, T)  — constant per (b,t), so it cancels in a softmax over (Z, H, W)

  volume_logits = (score_yx[:, :, None] + score_z[..., None, None]
                                       + score_t[..., None, None, None])     # (B, T, Z, H, W)

The 6-D feature volume (B, T, Z, H, W, d) is never instantiated; only the
5-D scalar logit volume is, which is what the CE loss needs anyway.

Memory usage stays trunk-bound — the head adds <100 MB at any reasonable B,
because the MLP runs B*T times (not B*T*Z*H*W like the FiLM-per-voxel design).
"""
import os, math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Local DINOv3 checkout and ViT-S/16+ checkpoint. Each path can be overridden
# through the environment variable of the same name.
DINO_REPO_DIR = os.environ.get("DINO_REPO_DIR", "/data/cameron/keygrip/dinov3")
DINO_WEIGHTS_PATH = os.environ.get(
    "DINO_WEIGHTS_PATH",
    "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)

# Resolutions and discretisation sizes.
IMG_SIZE = 504          # side length of the square RGB input, in pixels
N_WINDOW = 50           # timesteps queried per sample (T)
N_HEIGHT_BINS = 32      # height bins of the logit volume (Z)
N_GRIPPER_BINS = 32     # gripper classification bins
N_ROT_BINS = 32         # per-axis euler bins (Cameron 2026-05-20: dropped 1D PCA)
PRED_SIZE = 56          # side of the spatial feature / logit grid (H = W)

# Query / volume-key widths. The query splits as [F | sin_z | sin_t].
D_FEAT = 32                            # per-pixel width of F (and of q_F)
D_SINZ = 16                            # width of the height sinusoid key
D_SINT = 16                            # width of the time sinusoid key
D_MODEL = D_FEAT + D_SINZ + D_SINT     # = 64, query / penultimate width
D_COND = 128                           # AdaLN conditioning width (sin(t) -> Linear)
N_BLOCKS = 5                           # residual AdaLN-Zero MLP depth


def sinusoidal_features(n, dim, base=10000.0):
    """Return an (n, dim) float32 table of sinusoidal position encodings.

    Row ``i`` encodes position ``i``. Even columns hold ``sin(i * w_k)`` and
    odd columns hold ``cos(i * w_k)``, with frequencies
    ``w_k = base ** (-2k / dim)``.

    Args:
        n: number of positions (rows); 0 yields an empty (0, dim) table.
        dim: encoding width (columns). Odd widths are supported: the last
            column is a sine with no matching cosine.
        base: wavelength base that sets the lowest frequency.

    Returns:
        torch.Tensor of shape (n, dim), dtype float32, on the CPU.
    """
    pos = torch.arange(n, dtype=torch.float32).unsqueeze(1)                 # (n, 1)
    div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * -(math.log(base) / dim))
    angles = pos * div                                                      # (n, ceil(dim / 2))
    pe = torch.zeros(n, dim)
    pe[:, 0::2] = torch.sin(angles)
    # With odd dim there are ceil(dim/2) sine columns but only dim//2 cosine
    # columns, so drop the last frequency for the cosines.
    pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
    return pe


class AdaLNZeroMLPBlock(nn.Module):
    """Residual MLP block with DiT-style adaLN-Zero conditioning.

    Works on features ``x`` of shape (N, d) with a conditioning vector
    ``cond`` of shape (N, d_cond). The conditioning is projected into a
    per-feature scale, shift and residual gate:

        out = x + gate * MLP(LayerNorm(x) * (1 + scale) + shift)

    The projection is zero-initialised, so a freshly built block is an exact
    identity map (gate == 0).

    Args:
        d: feature width of ``x``.
        d_cond: width of ``cond``.
        mlp_ratio: hidden-width multiplier of the d -> mlp_ratio*d -> d MLP.
    """

    def __init__(self, d, d_cond, mlp_ratio=4):
        super().__init__()
        hidden = mlp_ratio * d
        # LayerNorm without learnable affine: scale/shift come from `cond`.
        self.norm = nn.LayerNorm(d, elementwise_affine=False)
        self.cond_proj = nn.Linear(d_cond, 3 * d)
        self.mlp = nn.Sequential(nn.Linear(d, hidden), nn.GELU(), nn.Linear(hidden, d))
        # adaLN-Zero: zero scale, shift and gate at init -> identity residual.
        for param in (self.cond_proj.weight, self.cond_proj.bias):
            nn.init.zeros_(param)

    def forward(self, x, cond):
        scale, shift, gate = self.cond_proj(cond).chunk(3, dim=-1)
        modulated = self.norm(x) * (1.0 + scale) + shift
        return x + gate * self.mlp(modulated)


class DinoVolumeQuery(nn.Module):
    """DINOv3 trunk + per-timestep AdaLN-Zero query MLP over a factored volume.

    Each sample gets one query vector, built from the EEF pixel's refined
    feature plus the DINO cls token (or the cls token alone when
    ``use_eef=False``). The vector is replicated across ``n_window``
    timesteps, and AdaLN-Zero conditioning on ``sin(t)`` specialises each
    copy. Every per-step representation feeds three linear heads:

    * ``q_head`` -> query split into (q_F, q_z, q_t). It is dot-producted with
      the implicit volume ``concat(F[y, x], sin_z[z], sin_t[t])``, which is
      never materialised.
    * ``grip_head`` -> gripper logits.
    * ``rot_head`` -> rotation logits, whose layout depends on ``rotation_mode``.

    Args:
        n_window: timesteps T queried per sample.
        n_height_bins: height bins Z of the logit volume.
        n_gripper_bins: gripper classes.
        n_rot_bins: rotation bins (per axis for ``'per_axis'``).
        d_feat: width of the refined spatial feature map F (and of q_F).
        d_sin_z: width of the height sinusoid key (and of q_z).
        d_sin_t: width of the time sinusoid key (and of q_t).
        d_cond: AdaLN conditioning width.
        n_blocks: number of AdaLN-Zero residual MLP blocks.
        image_size: coordinate frame of ``start_pix`` (pixels).
        pred_size: side length H = W of F and of the volume logits.
        freeze_backbone: if True, DINO parameters get ``requires_grad=False``.
        use_eef: include the EEF pixel feature in the query input.
        rotation_mode: ``'per_axis'`` (3 x n_rot_bins), ``'1d_pca'``
            (n_rot_bins) or ``'kmeans'`` (kmeans_n_clusters).
        kmeans_n_clusters: number of rotation clusters; must be > 0 for
            ``'kmeans'``.
    """

    def __init__(self,
                 n_window=N_WINDOW, n_height_bins=N_HEIGHT_BINS,
                 n_gripper_bins=N_GRIPPER_BINS, n_rot_bins=N_ROT_BINS,
                 d_feat=D_FEAT, d_sin_z=D_SINZ, d_sin_t=D_SINT,
                 d_cond=D_COND, n_blocks=N_BLOCKS,
                 image_size=IMG_SIZE, pred_size=PRED_SIZE,
                 freeze_backbone=False, use_eef=True, rotation_mode='per_axis',
                 kmeans_n_clusters=0):
        super().__init__()
        self.n_window        = n_window
        self.n_height_bins   = n_height_bins
        self.n_gripper_bins  = n_gripper_bins
        self.n_rot_bins      = n_rot_bins
        self.d_feat          = d_feat
        self.d_sin_z         = d_sin_z
        self.d_sin_t         = d_sin_t
        self.d_model         = d_feat + d_sin_z + d_sin_t
        self.image_size      = image_size
        self.pred_size       = pred_size
        self.use_eef         = use_eef
        assert rotation_mode in ('per_axis', '1d_pca', 'kmeans')
        self.rotation_mode   = rotation_mode
        self.kmeans_n_clusters = kmeans_n_clusters

        # DINO backbone, loaded from the local hub checkout.
        # NOTE: module creation order below also fixes RNG consumption for
        # parameter init and state_dict key order; keep it stable.
        self.dino = torch.hub.load(DINO_REPO_DIR, 'dinov3_vits16plus',
                                    source='local', weights=DINO_WEIGHTS_PATH)
        if freeze_backbone:
            for p in self.dino.parameters():
                p.requires_grad = False
        self.embed_dim = self.dino.embed_dim

        # Refine upsampled DINO patch tokens into F (d_feat channels).
        # The 3x3 conv uses padding=1, so spatial size is preserved.
        self.refine = nn.Sequential(
            nn.Conv2d(self.embed_dim, self.embed_dim, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(self.embed_dim, d_feat, kernel_size=1),
        )

        # Sinusoidal tables serve two roles: volume keys (z_sin, t_sin) and the
        # time signal for AdaLN conditioning (t_sin). persistent=False keeps
        # them out of the state_dict; they are rebuilt deterministically.
        self.register_buffer("z_sin", sinusoidal_features(n_height_bins, d_sin_z), persistent=False)
        self.register_buffer("t_sin", sinusoidal_features(n_window,      d_sin_t), persistent=False)
        # Project t_sin -> d_cond for AdaLN conditioning.
        self.t_cond_proj = nn.Linear(d_sin_t, d_cond)

        # Input projection: (eef_feat ⊕ cls) -> d_model, or cls alone if use_eef=False.
        in_dim = (d_feat + self.embed_dim) if use_eef else self.embed_dim
        self.input_proj = nn.Linear(in_dim, self.d_model)

        # Residual MLP with AdaLN-Zero(t) conditioning.
        self.blocks = nn.ModuleList([
            AdaLNZeroMLPBlock(self.d_model, d_cond) for _ in range(n_blocks)
        ])
        self.final_norm = nn.LayerNorm(self.d_model)

        # Three heads on the penultimate per-timestep representation.
        self.q_head    = nn.Linear(self.d_model, self.d_model)       # spatial query [q_F | q_z | q_t]
        self.grip_head = nn.Linear(self.d_model, n_gripper_bins)
        if rotation_mode == 'per_axis':
            rot_out_dim = 3 * n_rot_bins
        elif rotation_mode == '1d_pca':
            rot_out_dim = n_rot_bins
        else:  # kmeans
            assert kmeans_n_clusters > 0, "rotation_mode='kmeans' requires kmeans_n_clusters>0"
            rot_out_dim = kmeans_n_clusters
        self.rot_head  = nn.Linear(self.d_model, rot_out_dim)

    def _extract_dino_features(self, x):
        """Run the DINO trunk and return its normed patch map and cls token.

        Args:
            x: image batch as accepted by ``dino.prepare_tokens_with_masks``.

        Returns:
            patch: (B, embed_dim, H_p, W_p) patch tokens laid out as a map.
            cls:   (B, embed_dim) cls token.
        """
        B = x.shape[0]
        # Token layout: [cls, n_storage_tokens extra tokens, patches...].
        n_prefix = self.dino.n_storage_tokens + 1
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        for blk in self.dino.blocks:
            # Rope is evaluated per block rather than hoisted out of the loop.
            # NOTE(review): if the rope module samples random coordinate
            # augmentations in train mode, hoisting would change training
            # behaviour — confirm before optimising.
            rope = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed is not None else None
            x_tokens = blk(x_tokens, rope)
        if self.dino.untie_cls_and_patch_norms:
            # Separate norms for the prefix (cls + extra) tokens and the patch tokens.
            x_prefix = self.dino.cls_norm(x_tokens[:, :n_prefix])
            x_pat = self.dino.norm(x_tokens[:, n_prefix:])
            x_tokens = torch.cat([x_prefix, x_pat], dim=1)
        else:
            x_tokens = self.dino.norm(x_tokens)
        cls = x_tokens[:, 0]                                          # (B, embed)
        patch = x_tokens[:, n_prefix:]
        patch = patch.reshape(B, H_p, W_p, self.embed_dim).permute(0, 3, 1, 2).contiguous()
        return patch, cls

    def forward(self, rgb, start_pix, kp_zyx=None):
        """Predict per-timestep volume, gripper and rotation logits.

        Args:
            rgb: (B, 3, IMG, IMG) image batch.
            start_pix: (B, 2) current EEF pixel as (x, y) in ``image_size``
                coordinates. Only read when ``use_eef`` is True.
            kp_zyx: unused; kept so the train loop's call signature stays
                uniform across models.

        Returns:
            dict with:
              "volume_logits":   (B, T, Z, H, W) with H = W = pred_size.
              "gripper_logits":  (B, T, n_gripper_bins).
              "rotation_logits": (B, T, 3, n_rot_bins) for 'per_axis', else
                                 (B, T, n_rot_bins | kmeans_n_clusters).
              "pixel_feats":     (B, d_feat, H, W) refined feature map F.
        """
        B = rgb.shape[0]
        T = self.n_window
        d = self.d_model

        patch, cls = self._extract_dino_features(rgb)                         # (B, embed, H_p, W_p), (B, embed)
        feat_up = F.interpolate(patch, size=(self.pred_size, self.pred_size),
                                 mode='bilinear', align_corners=False)
        F_feat = self.refine(feat_up)                                          # (B, d_feat, H, W)
        H, W = F_feat.shape[-2:]

        # Query input: (eef_feat ⊕ cls) or cls only -> d_model, broadcast across T.
        if self.use_eef:
            # Rescale image_size pixel coords to the H x W grid. .long() truncates
            # toward zero, and clamp pins out-of-frame pixels to the border.
            sx = (start_pix[..., 0] * (W / self.image_size)).long().clamp(0, W - 1)
            sy = (start_pix[..., 1] * (H / self.image_size)).long().clamp(0, H - 1)
            b_idx = torch.arange(B, device=rgb.device)
            eef_feat = F_feat[b_idx, :, sy, sx]                                # (B, d_feat)
            q_in = self.input_proj(torch.cat([eef_feat, cls], dim=-1))         # (B, d_model)
        else:
            q_in = self.input_proj(cls)                                        # (B, d_model)
        q_in_bt = q_in.unsqueeze(1).expand(B, T, d).reshape(B * T, d)

        # AdaLN conditioning: projected sin(t), shared across the batch.
        cond_t  = self.t_cond_proj(self.t_sin)                                  # (T, d_cond)
        cond_bt = cond_t.unsqueeze(0).expand(B, T, -1).reshape(B * T, -1)

        # Residual MLP with AdaLN-Zero(t); identical inputs diverge per step via cond.
        h = q_in_bt
        for blk in self.blocks:
            h = blk(h, cond_bt)
        h = self.final_norm(h)
        penult = h.view(B, T, d)                                                # (B, T, d_model)

        # Heads
        q_spatial = self.q_head(penult)                                         # (B, T, d_model)
        gripper   = self.grip_head(penult)                                      # (B, T, n_grip)
        if self.rotation_mode == 'per_axis':
            rotation = self.rot_head(penult).view(B, T, 3, self.n_rot_bins)     # (B, T, 3, n_rot)
        else:  # 1d_pca or kmeans — both flat (B, T, K)
            rotation = self.rot_head(penult)

        # Spatial scoring, factored so the (B, T, Z, H, W, d) volume is never built.
        q_F = q_spatial[..., :self.d_feat]                                      # (B, T, d_feat)
        q_z = q_spatial[..., self.d_feat:self.d_feat + self.d_sin_z]            # (B, T, d_sin_z)
        q_t = q_spatial[..., self.d_feat + self.d_sin_z:]                       # (B, T, d_sin_t)

        score_yx = torch.einsum('btc, bchw -> bthw', q_F, F_feat)               # (B, T, H, W)
        score_z  = torch.einsum('btc, zc   -> btz',  q_z, self.z_sin)           # (B, T, Z)
        # score_t adds one scalar to every (z, y, x) cell of a given (b, t), so it
        # cancels under a softmax over that timestep's volume. It is kept so the
        # logits equal the full implicit-volume dot product.
        score_t  = torch.einsum('btc, tc   -> bt',   q_t, self.t_sin)           # (B, T)

        volume_logits = (
            score_yx.unsqueeze(2)                                                # (B, T, 1, H, W)
            + score_z.unsqueeze(-1).unsqueeze(-1)                                # (B, T, Z, 1, 1)
            + score_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)                  # (B, T, 1, 1, 1)
        )

        return {
            "volume_logits":   volume_logits,
            "gripper_logits":  gripper,
            "rotation_logits": rotation,
            "pixel_feats":     F_feat,
        }


if __name__ == "__main__":
    # Smoke test: build the model, push a random batch through, print output shapes.
    dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = DinoVolumeQuery(n_window=50).to(dev).eval()
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable: {trainable:,}")
    batch_rgb = torch.rand(2, 3, IMG_SIZE, IMG_SIZE).to(dev)
    batch_pix = torch.rand(2, 2).to(dev) * IMG_SIZE
    with torch.no_grad():
        outputs = model(batch_rgb, batch_pix)
    for name, value in outputs.items():
        if not hasattr(value, 'shape'):
            continue
        print(f"  {name}: {tuple(value.shape)}")
