"""DINO + per-pixel 5-layer residual MLP that regresses all T-step bins.

Baseline for the query-MLP design. Same DINO trunk, same F refinement to d_feat,
but the head is a per-pixel 5-layer residual MLP (1×1 conv stack) that at each
pixel outputs the entire flattened-in-time prediction:
  - volume:   T * Z         (height bin per t)
  - gripper:  T * n_grip
  - rotation: T * 3 * n_rot

Volume is materialised densely (B, T, Z, H, W) — needed for the argmax heatmap.
For gripper/rotation, the dense (T * n_grip, H, W)-style tensors would be huge,
so those heads are evaluated only at the per-(b, t) query pixel (the GT pixel
during training, the volume argmax at inference): the penultimate features are
sampled there first, then passed through a small linear layer. The weights are
equivalent to a dense 1×1 conv — this is purely a memory-efficient evaluation.

Per Cameron 2026-05-20: "produce the same F feature map, and regress the height
bins rotation bins and gripper bins, flattened in time, as a 5 layer residual
mlp on top of each pixel feature (no volume)".
"""
import os, math
import torch
import torch.nn as nn
import torch.nn.functional as F

# DINOv3 checkout and ViT-S/16+ checkpoint; both can be overridden via env vars.
_DEFAULT_DINO_ROOT = "/data/cameron/keygrip/dinov3"
_DEFAULT_DINO_WEIGHTS = (
    _DEFAULT_DINO_ROOT
    + "/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"
)
DINO_REPO_DIR = os.environ.get("DINO_REPO_DIR", _DEFAULT_DINO_ROOT)
DINO_WEIGHTS_PATH = os.environ.get("DINO_WEIGHTS_PATH", _DEFAULT_DINO_WEIGHTS)

# Input / output geometry and discretisation.
IMG_SIZE = 504        # side length of the square RGB input
N_WINDOW = 50         # T: number of predicted timesteps
N_HEIGHT_BINS = 32    # Z: height bins per timestep
N_GRIPPER_BINS = 32   # gripper bins per timestep
N_ROT_BINS = 32       # bins per rotation component (3 components per timestep)
PRED_SIZE = 56        # side length of the square prediction grid

# Per-pixel residual MLP head.
D_FEAT = 32           # channels of the refined feature map F
MLP_RATIO = 4         # hidden-width multiplier inside each residual block
N_BLOCKS = 5          # number of residual blocks


class ResMLPBlock(nn.Module):
    """Pre-norm residual MLP block written as a stack of 1×1 convolutions.

    forward(x) = x + fc2(GELU(fc1(norm(x)))), where fc1: d -> mlp_ratio*d and
    fc2: mlp_ratio*d -> d are 1×1 convs (i.e. the same Linear applied at every
    pixel). NOTE: ``norm`` is GroupNorm with a single group, which normalises
    each sample over its channels *and* all spatial positions jointly — it is
    not a per-pixel LayerNorm, so each pixel's output also depends on
    whole-map statistics.
    """

    def __init__(self, d, mlp_ratio=MLP_RATIO):
        super().__init__()
        hidden = mlp_ratio * d
        # Attribute names and creation order (norm, fc1, fc2) define the
        # state-dict keys and parameter-init order; keep them stable.
        self.norm = nn.GroupNorm(num_groups=1, num_channels=d, affine=True)
        self.fc1 = nn.Conv2d(d, hidden, kernel_size=1)
        self.fc2 = nn.Conv2d(hidden, d, kernel_size=1)

    def forward(self, x):
        update = self.fc2(F.gelu(self.fc1(self.norm(x))))
        return x + update


class DinoPerPixelMLP(nn.Module):
    """DINOv3 trunk + per-pixel residual MLP that regresses all T steps at once.

    Pipeline: DINO patch tokens -> bilinear upsample to ``pred_size`` ->
    ``refine`` (3×3 conv, GELU, 1×1 conv) giving F with ``d_feat`` channels ->
    ``n_blocks`` ResMLPBlocks -> GroupNorm -> heads. Every head emits the bins
    for all ``n_window`` timesteps, flattened step-major along its output
    dimension (output index = t * n_bins + k).
    """

    def __init__(self,
                 n_window=N_WINDOW, n_height_bins=N_HEIGHT_BINS,
                 n_gripper_bins=N_GRIPPER_BINS, n_rot_bins=N_ROT_BINS,
                 d_feat=D_FEAT, n_blocks=N_BLOCKS, mlp_ratio=MLP_RATIO,
                 image_size=IMG_SIZE, pred_size=PRED_SIZE,
                 freeze_backbone=False):
        super().__init__()
        self.n_window       = n_window
        self.n_height_bins  = n_height_bins
        self.n_gripper_bins = n_gripper_bins
        self.n_rot_bins     = n_rot_bins
        self.d_feat         = d_feat
        self.image_size     = image_size
        self.pred_size      = pred_size

        # DINO backbone, loaded from a local torch.hub checkout.
        self.dino = torch.hub.load(DINO_REPO_DIR, 'dinov3_vits16plus',
                                    source='local', weights=DINO_WEIGHTS_PATH)
        if freeze_backbone:
            for p in self.dino.parameters():
                p.requires_grad = False
        self.embed_dim = self.dino.embed_dim

        # Refine DINO patch tokens to F (d_feat at pred_size)
        self.refine = nn.Sequential(
            nn.Conv2d(self.embed_dim, self.embed_dim, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(self.embed_dim, d_feat, kernel_size=1),
        )

        # 5-layer residual MLP applied per pixel (as 1×1 conv stack)
        self.blocks = nn.ModuleList([ResMLPBlock(d_feat, mlp_ratio) for _ in range(n_blocks)])
        self.final_norm = nn.GroupNorm(num_groups=1, num_channels=d_feat, affine=True)

        # Per-pixel heads, each producing all T-step bins (flattened in time):
        #   volume:   T * Z
        #   gripper:  T * n_grip
        #   rotation: T * 3 * n_rot
        # Implemented as 1×1 conv for the volume (dense) and Linear for grip/rot
        # (evaluated only at sampled query pixels for memory).
        self.volume_head   = nn.Conv2d(d_feat, n_window * n_height_bins, kernel_size=1)
        self.gripper_head  = nn.Linear(d_feat, n_window * n_gripper_bins)
        self.rotation_head = nn.Linear(d_feat, n_window * 3 * n_rot_bins)

    def _extract_dino_features(self, x):
        """Run the DINO blocks on ``x`` and return the normalised patch tokens as a map.

        Mirrors the backbone's token pipeline (prepare tokens, blocks with RoPE,
        final norm(s)) and drops the leading ``n_storage_tokens + 1`` prefix tokens.

        Returns:
            (B, embed_dim, H_p, W_p) patch-feature map.
        """
        B = x.shape[0]
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        n_prefix = self.dino.n_storage_tokens + 1
        for blk in self.dino.blocks:
            # NOTE(review): RoPE is recomputed for every block (as before). It may
            # apply train-time coordinate augmentation, so confirm before hoisting.
            rope = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed is not None else None
            x_tokens = blk(x_tokens, rope)
        if self.dino.untie_cls_and_patch_norms:
            # Prefix tokens and patch tokens use separate norm layers.
            x_cls = self.dino.cls_norm(x_tokens[:, :n_prefix])
            x_pat = self.dino.norm(x_tokens[:, n_prefix:])
            x_tokens = torch.cat([x_cls, x_pat], dim=1)
        else:
            x_tokens = self.dino.norm(x_tokens)
        patch = x_tokens[:, n_prefix:]
        return patch.reshape(B, H_p, W_p, self.embed_dim).permute(0, 3, 1, 2).contiguous()

    @staticmethod
    def _per_step_linear(head, feats, n_steps, n_bins):
        """Apply only the step-t block of a time-flattened Linear head to query t.

        ``head`` maps d -> n_steps * n_bins with outputs laid out step-major.
        For ``feats`` of shape (B, n_steps, d), returns (B, n_steps, n_bins) equal
        to ``head(feats).view(B, n_steps, n_steps, n_bins)[:, t, t]`` for every t
        (the diagonal). It does not compute the n_steps - 1 unused blocks per query.
        """
        weight = head.weight.view(n_steps, n_bins, -1)             # (T, n_bins, d)
        out = torch.einsum("btd,tkd->btk", feats, weight)
        if head.bias is not None:
            # Match the einsum output dtype, e.g. under autocast.
            out = out + head.bias.view(n_steps, n_bins).to(out.dtype)
        return out.contiguous()

    def forward(self, rgb, start_pix=None, query_pixels=None, **kwargs):
        """
        rgb: (B, 3, IMG, IMG)
        start_pix: unused here (kept for interface parity with query model)
        query_pixels: (B, T, 2) of (y_grid, x_grid) in pred_size coords — per-timestep
                       GT pixels at training, volume argmax at inference. If None, grip/rot
                       not returned. Values are truncated to integers and clamped to the grid.

        Returns a dict with:
            volume_logits:   (B, T, Z, H, W)
            pixel_feats:     (B, d_feat, H, W) — F, before the residual MLP blocks
            gripper_logits:  (B, T, n_grip)    — only when query_pixels is given
            rotation_logits: (B, T, 3, n_rot)  — only when query_pixels is given
        """
        B = rgb.shape[0]
        T = self.n_window
        Z = self.n_height_bins

        patch = self._extract_dino_features(rgb)
        feat_up = F.interpolate(patch, size=(self.pred_size, self.pred_size),
                                 mode='bilinear', align_corners=False)
        F_feat = self.refine(feat_up)                                        # (B, d, H, W)

        # 5-layer per-pixel residual MLP
        h = F_feat
        for blk in self.blocks:
            h = blk(h)
        h = self.final_norm(h)                                                # (B, d, H, W)

        # Volume head — dense per-pixel; channel layout is step-major (T outer, Z inner)
        vol = self.volume_head(h)                                             # (B, T*Z, H, W)
        H, W = vol.shape[-2:]
        volume_logits = vol.view(B, T, Z, H, W)

        out = {"volume_logits": volume_logits, "pixel_feats": F_feat}

        if query_pixels is not None:
            # Sample penultimate features at each per-(b, t) query pixel
            qy = query_pixels[..., 0].long().clamp(0, H - 1)                  # (B, T)
            qx = query_pixels[..., 1].long().clamp(0, W - 1)
            b_idx = torch.arange(B, device=h.device).view(B, 1).expand(B, T)
            # Advanced indices separated by a slice: broadcast dims (B, T) come first.
            sampled = h[b_idx, :, qy, qx]                                     # (B, T, d)

            # Query t only needs timestep t's bins, so apply just the t-th slice of
            # each head rather than all T blocks followed by taking the diagonal.
            out["gripper_logits"] = self._per_step_linear(
                self.gripper_head, sampled, T, self.n_gripper_bins)           # (B, T, n_grip)
            rot = self._per_step_linear(
                self.rotation_head, sampled, T, 3 * self.n_rot_bins)
            out["rotation_logits"] = rot.reshape(B, T, 3, self.n_rot_bins)    # (B, T, 3, n_rot)
        return out


if __name__ == "__main__":
    # Smoke test: build the model, count trainable params, run one forward pass.
    use_cuda = torch.cuda.is_available()
    dev = torch.device("cuda") if use_cuda else torch.device("cpu")
    batch, steps = 2, 50

    model = DinoPerPixelMLP(n_window=steps).to(dev).eval()
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable: {trainable:,}")

    rgb = torch.rand(batch, 3, IMG_SIZE, IMG_SIZE).to(dev)
    qp = torch.randint(0, PRED_SIZE, (batch, steps, 2)).to(dev)
    with torch.no_grad():
        outputs = model(rgb, query_pixels=qp)

    for key, val in outputs.items():
        if not hasattr(val, 'shape'):
            continue
        print(f"  {key}: {tuple(val.shape)}")
