"""DINOv3 volume with REGRESSED abstract keys (Cameron 2026-05-19 spec).

Architecture:
  - DINOv3 backbone → per-pixel features F ∈ R^(B × key_dim × H × W)  (same as kv model)
  - Sample F at start-EEF pixel → f_start ∈ R^(B × key_dim)
  - Shared MLP trunk on f_start → 3 outputs:
        keys           (B, T, D_key)         — abstract per-timestep keys
        gripper_logits (B, T, n_grip)        — direct gripper-bin predictions
        rotation_logits (B, T, 3, n_rot)     — direct rotation-bin predictions
  - Volume values: V[b, t, z, u, v] = Linear_F(F[b,u,v]) + Linear_t(sin_t[t]) + Linear_z(sin_z[z])
                   (all D_key-dim)
  - Volume logits: keys[b, t] · V[b, t, z, u, v]  (computed efficiently via 3 einsums)

The gripper/rotation predictions DO NOT use the volume decoding path — they're regressed
directly from f_start through the shared MLP trunk. The volume decoding uses the same trunk
output (and hence the same upstream f_start) to produce keys.
"""
import os, sys, math
import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.insert(0, os.path.dirname(__file__))
from model_dino_volume_kv import (DINO_REPO_DIR, DINO_WEIGHTS_PATH, IMG_SIZE,
                                   N_WINDOW, N_HEIGHT_BINS, KEY_DIM, IMAGENET_MEAN, IMAGENET_STD,
                                   DINO_PATCH_SIZE, sinusoidal_features)

# Module-level defaults used by DinoVolumeRegressed.
D_KEY: int = 32        # width of the regressed keys and of the volume values (d_key default)
PRED_GRID: int = 56    # not referenced by the code in this file; presumably the prediction-grid side — TODO confirm
N_ROT_BINS: int = 32   # bins per rotation axis (the rotation head predicts 3 axes)
N_GRIP_BINS: int = 32  # number of gripper bins
DA3_INPUT: int = 504   # side length of the pixel space that `start_pix_504` coordinates live in


class DinoVolumeRegressed(nn.Module):
    """DINOv3 volume model whose per-timestep keys are regressed from one start feature.

    Pipeline:
      1. DINO patch tokens -> bilinear upsample to ``pred_size`` (= 2 * grid) ->
         ``refine`` conv head -> per-pixel features F (B, key_dim_volume, pred_size, pred_size).
      2. F is sampled at the start end-effector pixel -> f_start (B, key_dim_volume).
      3. A shared MLP trunk on f_start feeds three linear heads:
         keys (B, T, d_key), gripper logits (B, T, n_gripper_bins),
         rotation logits (B, T, 3, n_rot_bins).
      4. Volume logits (B, T, Z, pred_size, pred_size) are
         keys · (proj(F) + proj(t_sin[t]) + proj(h_sin[z])), evaluated term by term.
    """

    def __init__(self, n_window: int = N_WINDOW, n_height_bins: int = N_HEIGHT_BINS,
                 key_dim_volume: int = KEY_DIM, image_size: int = IMG_SIZE,
                 d_key: int = D_KEY, trunk_hidden: int = 512,
                 dino_variant: str = 'dinov3_vits16plus',
                 n_rot_bins: int = N_ROT_BINS, n_gripper_bins: int = N_GRIP_BINS):
        """Build backbone, refine head, value projections, trunk and output heads.

        Args:
            n_window: number of predicted timesteps T.
            n_height_bins: number of height bins Z in the volume.
            key_dim_volume: channel count C of the refined per-pixel features.
            image_size: side length the RGB input is resized to before DINO.
            d_key: dimension of the regressed keys and volume values.
            trunk_hidden: hidden width of the shared MLP trunk.
            dino_variant: torch.hub entry point loaded from DINO_REPO_DIR.
            n_rot_bins: bins per rotation axis (3 axes are predicted).
            n_gripper_bins: number of gripper bins.
        """
        super().__init__()
        self.n_window       = n_window
        self.n_height_bins  = n_height_bins
        self.key_dim_volume = key_dim_volume
        self.image_size     = image_size
        self.d_key          = d_key
        self.n_rot_bins     = n_rot_bins
        self.n_gripper_bins = n_gripper_bins
        self.patch_size     = DINO_PATCH_SIZE
        self.grid           = image_size // DINO_PATCH_SIZE   # DINO patch-token grid side
        self.pred_size      = self.grid * 2                    # output map side (2x token grid)

        # DINO backbone
        if DINO_REPO_DIR not in sys.path: sys.path.insert(0, DINO_REPO_DIR)
        variant_weights = {
            "dinov3_vits16plus": DINO_WEIGHTS_PATH,
            "dinov3_vitl16":     "/data/cameron/keygrip/dinov3/weights/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth",
        }
        # Unknown variants fall back to the default (vits16plus) weight file.
        w = variant_weights.get(dino_variant, DINO_WEIGHTS_PATH)
        self.dino = torch.hub.load(DINO_REPO_DIR, dino_variant, source="local", weights=w)
        self.embed_dim = getattr(self.dino, "embed_dim", 384)

        # Refine head producing the per-pixel volume feature F (key_dim_volume channels,
        # KEY_DIM by default).
        self.refine = nn.Sequential(
            nn.Conv2d(self.embed_dim, 192, 3, padding=1), nn.GELU(),
            nn.Conv2d(192, 192, 3, padding=1), nn.GELU(),
            nn.Conv2d(192, key_dim_volume, 1),
        )

        # F → D_key projection used in volume value computation
        self.feat_to_value = nn.Linear(key_dim_volume, d_key)
        # Sinusoidal positional features for time and height (fixed, non-persistent buffers).
        self.register_buffer("t_sin", sinusoidal_features(n_window, d_key), persistent=False)
        self.register_buffer("h_sin", sinusoidal_features(n_height_bins, d_key), persistent=False)
        # Learnable linear projections of the sin features (per Cameron's spec).
        self.t_sin_proj = nn.Linear(d_key, d_key)
        self.z_sin_proj = nn.Linear(d_key, d_key)

        # Shared MLP trunk from f_start
        self.start_trunk = nn.Sequential(
            nn.LayerNorm(key_dim_volume),
            nn.Linear(key_dim_volume, trunk_hidden), nn.GELU(),
            nn.Linear(trunk_hidden, trunk_hidden), nn.GELU(),
        )
        # 3 output heads off the trunk
        self.keys_head     = nn.Linear(trunk_hidden, n_window * d_key)
        self.gripper_head  = nn.Linear(trunk_hidden, n_window * n_gripper_bins)
        self.rotation_head = nn.Linear(trunk_hidden, n_window * 3 * n_rot_bins)

        self.register_buffer("mean", torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("std",  torch.tensor(IMAGENET_STD ).view(1, 3, 1, 1), persistent=False)

    def _normalize(self, rgb01):
        """Apply per-channel ImageNet mean/std normalization to a [0, 1] RGB batch."""
        return (rgb01 - self.mean) / self.std

    @staticmethod
    def _autocast_dtype(device):
        """Pick the autocast dtype for ``device``.

        CUDA: bfloat16 if the GPU supports it, else float16.
        CPU: bfloat16, the dtype CPU autocast is built around. ``torch.cuda.is_bf16_supported()``
        is deliberately not consulted here: on some torch versions it raises on a GPU-less machine.
        Any other device type: float16.
        """
        if device.type == "cuda":
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        if device.type == "cpu":
            return torch.bfloat16
        return torch.float16

    def _pixel_features(self, rgb):
        """Run DINO and the refine head on an RGB batch.

        Args:
            rgb: (B, 3, H, W) tensor in [0, 1]. It is resized to image_size x image_size
                whenever either spatial side differs.

        Returns:
            Tuple ``(pixel_feats, patch_tokens)``:
              pixel_feats  (B, key_dim_volume, pred_size, pred_size)
              patch_tokens (B, grid * grid, embed_dim) DINO patch tokens, float32
        """
        B = rgb.shape[0]
        # Check both sides. Checking only the width let a non-square input with the
        # right width skip the resize and then break the token reshape below.
        if tuple(rgb.shape[-2:]) != (self.image_size, self.image_size):
            rgb = F.interpolate(rgb, size=(self.image_size, self.image_size),
                                mode='bilinear', align_corners=False)
        x = self._normalize(rgb)
        with torch.autocast(device_type=rgb.device.type, dtype=self._autocast_dtype(rgb.device)):
            feats = self.dino.forward_features(x)
        if isinstance(feats, dict):
            patch_tokens = feats.get("x_norm_patchtokens")
            if patch_tokens is None:
                patch_tokens = feats.get("x_prenorm")
        else:
            patch_tokens = feats
        patch_tokens = patch_tokens.to(torch.float32)
        h = w = self.grid
        n_patch = h * w
        # Full token sequences (e.g. the x_prenorm fallback) put CLS/register tokens ahead
        # of the patch tokens. Keep only the trailing grid*grid patch tokens.
        if patch_tokens.shape[1] > n_patch:
            patch_tokens = patch_tokens[:, -n_patch:]
        D = patch_tokens.shape[-1]
        feat_2d = patch_tokens.permute(0, 2, 1).reshape(B, D, h, w)
        feat_2d = F.interpolate(feat_2d, size=(self.pred_size, self.pred_size),
                                mode='bilinear', align_corners=False)
        return self.refine(feat_2d), patch_tokens                              # (B, C, ph, pw)

    def forward(self, rgb, start_pix_504):
        """Predict volume, gripper and rotation logits.

        Args:
            rgb: (B, 3, *, *) tensor in [0, 1].
            start_pix_504: (B, 2) GT current EEF pixel (x, y) in DA3_INPUT (504) pixel space.

        Returns:
            dict with
              volume_logits   (B, T, Z, ph, pw)   T = n_window, Z = n_height_bins
              gripper_logits  (B, T, n_grip)
              rotation_logits (B, T, 3, n_rot)
              pixel_feats     (B, key_dim_volume, ph, pw)  for compatibility
              dino_feats      [patch tokens (B, grid*grid, embed_dim)]
              pred_depth      None
        """
        B = rgb.shape[0]
        pixel_feats, dino_tokens = self._pixel_features(rgb)                   # (B, C, ph, pw)
        C, ph, pw = pixel_feats.shape[1:]
        T = self.n_window
        Z = self.n_height_bins
        dev = pixel_feats.device

        # 1. Sample f_start at the start-EEF pixel: scale 504-space coords to the
        #    (ph, pw) grid, truncate to an integer cell, clamp into range.
        sx = pw / DA3_INPUT   # x indexes width
        sy = ph / DA3_INPUT   # y indexes height
        start_pix = start_pix_504.to(dev)
        gx = (start_pix[:, 0] * sx).long().clamp(0, pw - 1)
        gy = (start_pix[:, 1] * sy).long().clamp(0, ph - 1)
        f_start = pixel_feats[torch.arange(B, device=dev), :, gy, gx]         # (B, C)

        # 2. Trunk + 3 heads
        h = self.start_trunk(f_start)                                          # (B, trunk_hidden)
        keys = self.keys_head(h).view(B, T, self.d_key)                        # (B, T, D_key)
        grip = self.gripper_head(h).view(B, T, self.n_gripper_bins)
        rot  = self.rotation_head(h).view(B, T, 3, self.n_rot_bins)

        # 3. Volume value components
        # F_proj: (B, D_key, ph, pw) — project pixel features into the key-space
        f_proj = self.feat_to_value(pixel_feats.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        t_proj = self.t_sin_proj(self.t_sin)                                   # (T, D_key)
        z_proj = self.z_sin_proj(self.h_sin)                                   # (Z, D_key)

        # 4. Volume logits = keys · (f_proj + t_proj + z_proj), split by linearity so
        #    the full (B, T, Z, ph, pw, D_key) value tensor is never materialized.
        #    term1[b,t,h,w]   = sum_d keys[b,t,d] * f_proj[b,d,h,w]
        #    term2[b,t]       = sum_d keys[b,t,d] * t_proj[t,d]      (broadcast over z,h,w)
        #    term3[b,t,z]     = sum_d keys[b,t,d] * z_proj[z,d]      (broadcast over h,w)
        term1 = torch.einsum("btd, bdhw -> bthw", keys, f_proj)                # (B, T, ph, pw)
        term2 = torch.einsum("btd, td -> bt",     keys, t_proj)                # (B, T)
        term3 = torch.einsum("btd, zd -> btz",    keys, z_proj)                # (B, T, Z)
        vol = (term1.unsqueeze(2)                              # (B, T, 1, ph, pw)
               + term2.view(B, T, 1, 1, 1)
               + term3.view(B, T, Z, 1, 1))
        # vol: (B, T, Z, ph, pw)

        return {
            "volume_logits":   vol,
            "gripper_logits":  grip,
            "rotation_logits": rot,
            "pixel_feats":     pixel_feats,
            "dino_feats":      [dino_tokens],
            "pred_depth":      None,
        }


if __name__ == "__main__":
    # Smoke test: build the model, push a random batch through it, report output shapes.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DinoVolumeRegressed().to(dev).eval()
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable: {trainable:,}")
    dummy_rgb = torch.rand(2, 3, IMG_SIZE, IMG_SIZE).to(dev)
    # Start-EEF pixels, expressed in 504-space.
    dummy_start = torch.tensor([[200., 200.], [300., 250.]]).to(dev)
    with torch.no_grad():
        outputs = model(dummy_rgb, dummy_start)
    for name, val in outputs.items():
        if not hasattr(val, 'shape'):
            continue  # lists / None entries have no shape to report
        print(f"  {name}: {tuple(val.shape)}")
    if dev.type == 'cuda':
        peak_gb = torch.cuda.max_memory_allocated() / 1e9
        print(f"peak mem: {peak_gb:.2f} GB")
