"""DINOv3 + EEF-patch attention as heatmap.

Per Cameron 2026-05-18: instead of vanilla heatmap prediction via a 1×1 conv, let the
patch that the current EEF projects into ATTEND to all patches (including itself) and
interpret the pre-softmax attention scores as the heatmap logits. Architectural shift:
the heatmap IS the attention map, not a separate dense prediction.

Forward outline:
  1. DINOv3 forward → patch tokens P ∈ R^(B × N × D), N = grid × grid.
  2. Compute current EEF patch index from `start_pixel` (passed in or = first GT pixel
     at training time). Extract that token p_eef ∈ R^(B × D).
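  3. Per timestep t, the query is q_t = q_proj(p_eef) + t_query[t], where t_query is a
     learnable embedding ∈ R^(T × D). Keys are k_proj applied to the full patch token
     stack; there are no values, because the scores themselves are the output.
  4. Attention scores s_t = (q_t · K^T) / √d_head ∈ R^(B × N), with d_head = D / n_heads
     (per-head scores are summed over heads, which equals the full D-dim dot product).
     These are the per-timestep heatmap logits, reshaped to (h, w) and bilinearly
     upsampled to the prediction resolution.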
  5. For the volume formulation we ALSO need a height distribution per pixel. We add a
     small per-pixel MLP that, conditioned on the patch token AND timestep t, outputs
     N_HEIGHT_BINS logits. So the joint volume logit is:
        vol[b, t, z, h, w] = score_t[b, h, w] + height_t_pixel[b, t, z, h, w]
     This separates the 2D attention from the height; each timestep still gets its own
     z distribution, conditioned on the visited patch.

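Inputs:  rgb (B, 3, H, W) in [0, 1], resized internally to the model's square image_size;
         start_pixel (B, 2) as (x, y) in the input image's pixel coordinates. It is
         historically called "504-space", but it is rescaled from the actual input size
         (training: GT current EEF; inference: predicted current pixel from prior step or
         a heuristic).
Outputs: volume_logits (B, T, Z, h_out, w_out), with h_out = w_out =
         (image_size // 16) × attn_pred_upsample, plus auxiliary entries in the returned dict.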
"""
import os, sys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Local DINOv3 checkout and the dinov3_vits16plus checkpoint loaded from it.
# Both paths can be overridden through environment variables of the same name.
DINO_REPO_DIR = os.environ.get("DINO_REPO_DIR", "/data/cameron/keygrip/dinov3")
DINO_WEIGHTS_PATH = os.environ.get(
    "DINO_WEIGHTS_PATH",
    "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)

# Geometry and output-size defaults.
DINO_PATCH_SIZE = 16    # pixels per side of one DINO patch
IMG_SIZE = 448          # default model input side (448 // 16 = 28 patches per side)
N_WINDOW = 8            # default number of predicted timesteps T
N_HEIGHT_BINS = 32      # default number of height bins Z per pixel

# Per-channel (R, G, B) ImageNet normalisation statistics.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class DinoEefAttnModel(nn.Module):
    def __init__(self, n_window: int = N_WINDOW, n_height_bins: int = N_HEIGHT_BINS,
                 image_size: int = IMG_SIZE, n_heads: int = 4,
                 attn_pred_upsample: int = 2):
        super().__init__()
        self.n_window      = n_window
        self.n_height_bins = n_height_bins
        self.image_size    = image_size
        self.patch_size    = DINO_PATCH_SIZE
        self.grid          = image_size // DINO_PATCH_SIZE     # 28
        self.pred_size     = self.grid * attn_pred_upsample    # 56

        if DINO_REPO_DIR not in sys.path:
            sys.path.insert(0, DINO_REPO_DIR)
        self.dino = torch.hub.load(DINO_REPO_DIR, "dinov3_vits16plus",
                                    source="local", weights=DINO_WEIGHTS_PATH)
        self.embed_dim = getattr(self.dino, "embed_dim", 384)
        D = self.embed_dim

        # Per-timestep learnable query bias (added to the EEF-patch token before projection).
        self.t_query = nn.Parameter(torch.zeros(n_window, D))
        nn.init.trunc_normal_(self.t_query, std=0.02)

        # Q/K projections (small, low-rank-ish). Multi-head attention to get richer scores.
        assert D % n_heads == 0
        self.n_heads  = n_heads
        self.head_dim = D // n_heads
        self.q_proj = nn.Linear(D, D, bias=False)
        self.k_proj = nn.Linear(D, D, bias=False)
        # We reduce per-head scores to per-pixel logits by summing across heads (each head
        # learns a different "aspect" of the attention; sum is the unnormalized score).

        # Per-pixel height head: small MLP applied to each patch token, conditioned on t.
        # Output (T, Z) channels per pixel. Implemented as a linear over (D + T-onehot).
        self.height_head = nn.Sequential(
            nn.Linear(D + n_window, 128), nn.GELU(),
            nn.Linear(128, n_height_bins),
        )

        self.register_buffer("mean", torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("std",  torch.tensor(IMAGENET_STD ).view(1, 3, 1, 1), persistent=False)

    def _normalize(self, rgb01):
        return (rgb01 - self.mean) / self.std

    def _gather_eef_token(self, patch_tokens, start_pixel):
        """patch_tokens: (B, N, D); start_pixel: (B, 2) in image_size coords; returns (B, D)."""
        B, N, D = patch_tokens.shape
        g = self.grid
        u = (start_pixel[:, 0] / self.image_size * g).long().clamp(0, g - 1)
        v = (start_pixel[:, 1] / self.image_size * g).long().clamp(0, g - 1)
        idx = (v * g + u).view(B, 1, 1).expand(B, 1, D)                       # (B, 1, D)
        return patch_tokens.gather(1, idx).squeeze(1)                         # (B, D)

    def forward(self, rgb, start_pixel_504):
        """rgb: (B, 3, *, *) in [0, 1]. start_pixel_504: (B, 2) GT current EEF pixel in 504-space."""
        B = rgb.shape[0]
        in_size = rgb.shape[-1]
        if in_size != self.image_size:
            rgb = F.interpolate(rgb, size=(self.image_size, self.image_size),
                                mode='bilinear', align_corners=False)
        # Rescale start pixel from input-image size to model-image size.
        start_pixel = start_pixel_504.float() * (self.image_size / in_size)
        x = self._normalize(rgb)
        autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        with torch.autocast(device_type=rgb.device.type, dtype=autocast_dtype):
            feats = self.dino.forward_features(x)
        if isinstance(feats, dict):
            patch_tokens = feats.get("x_norm_patchtokens", feats.get("x_prenorm"))
        else:
            patch_tokens = feats
        patch_tokens = patch_tokens.to(torch.float32)
        N = patch_tokens.shape[1]
        D = self.embed_dim
        h = w = self.grid

        # EEF patch token (B, D)
        p_eef = self._gather_eef_token(patch_tokens, start_pixel)             # (B, D)

        # Per-timestep query: (B, T, D)
        q_t = self.q_proj(p_eef).unsqueeze(1) + self.t_query.unsqueeze(0)     # (B, T, D)
        k_t = self.k_proj(patch_tokens)                                       # (B, N, D)

        # Multi-head split
        Bn, T, _ = q_t.shape
        q_t = q_t.view(B, T, self.n_heads, self.head_dim).permute(0, 2, 1, 3)     # (B, H, T, d)
        k_t = k_t.view(B, N, self.n_heads, self.head_dim).permute(0, 2, 3, 1)     # (B, H, d, N)
        # Scores per head, summed across heads → (B, T, N)
        scores = torch.einsum("bhtd, bhdn -> bhtn", q_t, k_t) / math.sqrt(self.head_dim)
        scores = scores.sum(dim=1)                                            # (B, T, N)

        # Reshape to (B, T, h, w); upsample to pred_size
        scores_2d = scores.view(B, T, h, w)
        scores_2d = F.interpolate(scores_2d, size=(self.pred_size, self.pred_size),
                                   mode='bilinear', align_corners=False)        # (B, T, ph, pw)

        # Per-pixel height head: per-patch token + t-onehot → (T, Z) per pixel
        # Build pixel features at pred_size via bilinear upsample of patch tokens.
        feat_2d = patch_tokens.permute(0, 2, 1).reshape(B, D, h, w)
        feat_2d = F.interpolate(feat_2d, size=(self.pred_size, self.pred_size),
                                 mode='bilinear', align_corners=False)         # (B, D, ph, pw)
        # Expand to (B, T, ph, pw, D + T-onehot) lazily via broadcasting in MLP.
        # Flatten spatial: (B, ph*pw, D)
        feat_flat = feat_2d.permute(0, 2, 3, 1).reshape(B, -1, D)              # (B, ph*pw, D)
        # Per-timestep loop is slow; do it vectorised with t-onehot tiling.
        t_onehot = torch.eye(T, device=rgb.device, dtype=feat_flat.dtype)      # (T, T)
        # For each timestep concat t_onehot[t] to every pixel feature.
        # f_tiled: (B, T, ph*pw, D + T)
        feat_tiled = feat_flat.unsqueeze(1).expand(B, T, -1, D)
        toh_tiled  = t_onehot.view(1, T, 1, T).expand(B, T, feat_flat.shape[1], T)
        joint = torch.cat([feat_tiled, toh_tiled], dim=-1)                     # (B, T, ph*pw, D+T)
        z_logits = self.height_head(joint)                                     # (B, T, ph*pw, Z)
        z_logits = z_logits.view(B, T, self.pred_size, self.pred_size, self.n_height_bins)
        z_logits = z_logits.permute(0, 1, 4, 2, 3)                             # (B, T, Z, ph, pw)

        # Joint volume logits: scores_2d acts on (h, w); z_logits adds the per-z component
        # vol[b, t, z, h, w] = scores_2d[b, t, h, w] + z_logits[b, t, z, h, w]
        vol = scores_2d.unsqueeze(2) + z_logits                                # (B, T, Z, ph, pw)

        return {
            "volume_logits":     vol,
            "attn_scores_2d":    scores_2d,    # (B, T, ph, pw) — the "heatmap = attention" output
            "pred_depth":        None,
            "pixel_feats":       feat_2d,
            "dino_feats":        [patch_tokens],
        }


if __name__ == "__main__":
    # Smoke test: build the model, push one random batch through it, report output shapes.
    dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = DinoEefAttnModel().to(dev).eval()
    n_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
    print(f"Trainable: {n_trainable:,}")
    dummy_rgb = torch.rand(2, 3, IMG_SIZE, IMG_SIZE).to(dev)
    dummy_start = torch.tensor([[200., 200.], [300., 300.]]).to(dev)
    with torch.no_grad():
        outputs = model(dummy_rgb, dummy_start)
    for name, value in outputs.items():
        if hasattr(value, 'shape'):
            print(f"  {name}: {tuple(value.shape)}")
            continue
        if not isinstance(value, (list, tuple)):
            continue
        print(f"  {name}: list({len(value)})")
        if value and hasattr(value[0], 'shape'):
            print(f"    first: {tuple(value[0].shape)}")
    if dev.type == 'cuda':
        peak_gb = torch.cuda.max_memory_allocated() / 1e9
        print(f"peak: {peak_gb:.2f} GB")
