"""Minimal volume AR model per Cameron's spec (2026-05-17).

Pipeline:
  rgb 448²  → DINO (frozen) 28² patches D=384
            → bilinear up 64² → 1×1 conv MLP → 64² × 32D image features
  voxels (32³ = 32,768) + 20 past EEFs + current EEF, all projected to image pixels via the camera matrix,
            grid_sample → per-token 32D image feature
  per-token PE: sincos(xyz - current_eef) → 2-layer MLP → 32D
  token feature = image_feature + PE_feature + type_embed
  KV pool = 20 past + 1 current EEF tokens (21 tokens)
  Q       = 32k voxel tokens
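  4× cross-attention layers (Q ← KV), each followed by a per-token linear layer
            (a "1×1 conv"; no FFN, "really cheap")
  final per-token linear ("1×1 conv"): 32 → 8 timestep logits per voxel
  per-timestep argmax voxel (or teacher-forced target voxel) → linear heads
            → grip logit + (3, 32) rot-bin logits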
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from robot_volume import (
    voxel_centers_world, world_to_pixel_torch, pixel_to_normalized_grid,
    sincos_pe_3d, PE_DIM, N_PAST_EEF, T_FUTURE, N_ROT_BINS, IMAGE_SIZE, N_VOX,
)

# Location of the local DINOv3 hub checkout and its checkpoint. Both can be
# overridden through environment variables of the same name.
DINO_REPO_DIR = os.environ.get(
    "DINO_REPO_DIR",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3",
)
DINO_WEIGHTS_PATH = os.environ.get(
    "DINO_WEIGHTS_PATH",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)

# Backbone / token geometry.
DINO_PATCH_SIZE = 16  # patch size of the vits16plus backbone (448 px → 28 patches/side)
UPSAMPLE_RES = 64     # side length the DINO patch grid is bilinearly upsampled to
TOKEN_D = 32          # width of every voxel / EEF token
N_HEADS = 4           # attention heads per cross-attention layer
N_LAYERS = 4          # number of stacked cross-attention layers


class CrossAttn(nn.Module):
    """One pre-LN cross-attention layer (Q ← KV) with no FFN — Cameron's 'really cheap'.

    Applies LayerNorm to the queries and to the key/value pool, runs a single
    ``nn.MultiheadAttention`` call, and adds the result residually onto ``q``.

    Args:
        d: token width; must be divisible by ``heads`` (MultiheadAttention requirement).
        heads: number of attention heads.
    """
    def __init__(self, d=TOKEN_D, heads=N_HEADS):
        super().__init__()
        self.ln_q  = nn.LayerNorm(d)
        self.ln_kv = nn.LayerNorm(d)
        self.attn  = nn.MultiheadAttention(d, heads, batch_first=True)

    def forward(self, q, kv):
        """Attend from ``q`` to ``kv``.

        Args:
            q:  (B, Lq, d) query tokens (batch_first layout).
            kv: (B, Lkv, d) key/value tokens.

        Returns:
            (B, Lq, d): ``q`` plus the attention output (residual connection).
        """
        # Normalise the KV pool once and reuse the same tensor as both key and
        # value, instead of running the identical LayerNorm twice.
        kv_n = self.ln_kv(kv)
        a, _ = self.attn(self.ln_q(q), kv_n, kv_n, need_weights=False)
        return q + a


class VolumeARModel(nn.Module):
    """Voxel-query / EEF-history cross-attention model (pipeline in the module docstring).

    Image features come from a DINOv3 backbone (frozen by default). Every voxel
    centre and every past/current end-effector (EEF) position is projected into
    the image to sample those features. It is then tagged with a positional
    encoding of its offset from the current EEF and with a type embedding.
    Voxel tokens cross-attend to the EEF tokens and are scored per future
    timestep. The selected voxel's feature per timestep feeds the grip and
    rotation heads.
    """

    def __init__(self, n_past=N_PAST_EEF, t_future=T_FUTURE, freeze_backbone=True):
        super().__init__()
        # NOTE(review): n_past is stored but forward() takes N from past_eef_world.
        self.n_past   = n_past
        self.t_future = t_future
        self.freeze_backbone = freeze_backbone

        # DINO backbone (frozen unless freeze_backbone=False).
        print(f"Loading DINOv3 (frozen={freeze_backbone})...")
        self.dino = torch.hub.load(DINO_REPO_DIR, 'dinov3_vits16plus', source='local', weights=DINO_WEIGHTS_PATH)
        if freeze_backbone:
            for p in self.dino.parameters():
                p.requires_grad = False
            self.dino.eval()
        self.dino_d = self.dino.embed_dim  # 384

        # Image feature MLP: 384 → 32 → 32 (per-pixel; implemented as 1×1 conv).
        self.image_mlp = nn.Sequential(
            nn.Conv2d(self.dino_d, TOKEN_D, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(TOKEN_D, TOKEN_D, kernel_size=1),
        )

        # Positional-encoding MLP (2 layers → TOKEN_D).
        self.pe_mlp = nn.Sequential(
            nn.Linear(PE_DIM, TOKEN_D),
            nn.GELU(),
            nn.Linear(TOKEN_D, TOKEN_D),
        )

        # Type embeddings (0=voxel, 1=past EEF, 2=current EEF).
        self.type_embed = nn.Embedding(3, TOKEN_D)

        # 4 cross-attention layers with a per-token linear ("1×1 conv") between (no FFN).
        self.attn_layers   = nn.ModuleList([CrossAttn(TOKEN_D, N_HEADS) for _ in range(N_LAYERS)])
        self.between_conv  = nn.ModuleList([nn.Linear(TOKEN_D, TOKEN_D) for _ in range(N_LAYERS)])

        # Final per-voxel head: 32 → t_future timestep logits.
        self.final = nn.Linear(TOKEN_D, t_future)

        # Per-timestep grip + rot heads (operate on the selected voxel's feature).
        self.grip_head = nn.Linear(TOKEN_D, 1)
        self.rot_head  = nn.Linear(TOKEN_D, 3 * N_ROT_BINS)

        # Precompute voxel centers in world coords (static, V × 3).
        self.register_buffer("voxel_centers", voxel_centers_world(), persistent=False)

    def train(self, mode=True):
        """Set training mode on the model while keeping a frozen backbone in eval mode.

        ``nn.Module.train`` recurses into every child. Without this override,
        ``model.train()`` would switch the frozen DINO backbone back into
        training mode and re-enable whatever train-only behaviour it has
        (depends on the backbone config — e.g. dropout / RoPE augmentation).
        """
        super().train(mode)
        if self.freeze_backbone:
            self.dino.eval()
        return self

    # ── DINO patch extraction ─────────────────────────────────────────
    def _dino_tokens(self, x):
        """Run the DINO token pipeline on ``x``; return (patch tokens, (H_p, W_p)).

        Prepares tokens, runs every block with its RoPE embedding, then applies
        the final norm. The CLS and storage tokens are dropped, so the result is
        (B, H_p * W_p, dino_d). Grad mode is decided by the caller.
        """
        dino = self.dino
        n_prefix = dino.n_storage_tokens + 1  # CLS + storage tokens precede the patches
        tokens, (H_p, W_p) = dino.prepare_tokens_with_masks(x)
        for blk in dino.blocks:
            rope = dino.rope_embed(H=H_p, W=W_p) if dino.rope_embed is not None else None
            tokens = blk(tokens, rope)
        if dino.untie_cls_and_patch_norms:
            # Prefix tokens would use dino.cls_norm, but they are discarded, so
            # only the patch norm is needed.
            patches = dino.norm(tokens[:, n_prefix:])
        else:
            patches = dino.norm(tokens)[:, n_prefix:]
        return patches, (H_p, W_p)

    def _dino_patches(self, x):
        """x: (B, 3, H, W) → (B, dino_d, H_p, W_p) feature map (28×28 for a 448² input).

        No gradients are tracked when the backbone is frozen. When it is
        trainable, the same pipeline runs under the caller's grad mode.
        """
        if self.freeze_backbone:
            with torch.no_grad():
                patches, (H_p, W_p) = self._dino_tokens(x)
        else:
            patches, (H_p, W_p) = self._dino_tokens(x)
        B = patches.shape[0]
        return patches.reshape(B, H_p, W_p, self.dino_d).permute(0, 3, 1, 2)

    def forward(self, rgb, past_eef_world, current_eef_world, world_to_camera,
                target_voxel_idx=None):
        """
        rgb:                 (B, 3, 448, 448)
        past_eef_world:      (B, N=20, 3) — world frame; index 19 = most recent
        current_eef_world:   (B, 3)       — world frame (= past_eef_world[:, -1])
        world_to_camera:     (B, 4, 4)    — robosuite-style world→pixel matrix
        target_voxel_idx:    (B, T=8) or None  — for teacher-forced grip/rot heads

        Returns a dict with:
            "voxel_logits":   (B, V, t_future) per-voxel, per-timestep logits
            "grip_logit":     (B, T)
            "rot_logits":     (B, T, 3, N_ROT_BINS)
            "pred_voxel_idx": (B, T) int64 — argmax voxel per timestep, or the
                              teacher-forced indices when target_voxel_idx is given
        """
        B = rgb.shape[0]
        ce = current_eef_world.unsqueeze(1)                                             # (B, 1, 3)

        # 1. DINO patches → upsample → per-pixel MLP → (B, 32, 64, 64)
        patches = self._dino_patches(rgb)                                              # (B, 384, 28, 28)
        feats   = F.interpolate(patches, size=(UPSAMPLE_RES, UPSAMPLE_RES),
                                mode='bilinear', align_corners=False)
        feats   = self.image_mlp(feats)                                                 # (B, 32, 64, 64)

        # 2. Project all voxels + past EEFs + current EEF to image pixels.
        vox_world = self.voxel_centers.unsqueeze(0).expand(B, -1, -1)                   # (B, V, 3)
        vox_pix  = world_to_pixel_torch(vox_world, world_to_camera)                     # (B, V, 2) (u, v)
        past_pix = world_to_pixel_torch(past_eef_world, world_to_camera)                # (B, N, 2)
        cur_pix  = world_to_pixel_torch(ce, world_to_camera)                            # (B, 1, 2)

        # 3. grid_sample to fetch image features at each projected point. Single vectorized call
        #    per group (voxels / past EEFs / current EEF).
        def sample(pix_uv):  # pix_uv: (B, M, 2) in pixels
            grid = pixel_to_normalized_grid(pix_uv, IMAGE_SIZE).unsqueeze(2)            # (B, M, 1, 2)
            s = F.grid_sample(feats, grid, mode='bilinear', align_corners=False,
                              padding_mode='zeros')                                     # (B, 32, M, 1)
            return s.squeeze(-1).permute(0, 2, 1)                                       # (B, M, 32)
        vox_img  = sample(vox_pix)                                                      # (B, V, 32)
        past_img = sample(past_pix)                                                     # (B, N, 32)
        cur_img  = sample(cur_pix)                                                      # (B, 1, 32)

        # 4. Relative positional encoding: subtract current EEF in world coords.
        vox_rel  = vox_world - ce                                                       # (B, V, 3)
        past_rel = past_eef_world - ce                                                  # (B, N, 3)
        # The current EEF is the origin of the relative frame. zeros_like keeps the
        # EEF dtype/device, consistent with the other two groups.
        cur_rel  = torch.zeros_like(ce)                                                 # (B, 1, 3)

        vox_pe  = self.pe_mlp(sincos_pe_3d(vox_rel))                                    # (B, V, 32)
        past_pe = self.pe_mlp(sincos_pe_3d(past_rel))                                   # (B, N, 32)
        cur_pe  = self.pe_mlp(sincos_pe_3d(cur_rel))                                    # (B, 1, 32)

        # 5. Token features = image_feature + PE + type_embed.
        #    Broadcasting one embedding row equals an index lookup (same values and
        #    gradients), without materialising (B, V) index tensors or a (B, V, 32) lookup.
        type_w = self.type_embed.weight                                                 # rows: 0=voxel, 1=past, 2=current
        vox_tokens   = vox_img  + vox_pe  + type_w[0]                                   # (B, V, 32)
        past_tokens  = past_img + past_pe + type_w[1]                                   # (B, N, 32)
        cur_token    = cur_img  + cur_pe  + type_w[2]                                   # (B, 1, 32)

        kv = torch.cat([past_tokens, cur_token], dim=1)                                  # (B, N+1=21, 32)

        # 6. Cross-attention stack: Q = voxel tokens, KV = past+current EEF tokens.
        x = vox_tokens
        for attn, lin in zip(self.attn_layers, self.between_conv):
            x = attn(x, kv)                                                             # (B, V, 32)
            x = lin(x)                                                                  # 1×1 conv (per-token linear)

        # 7. Per-voxel per-timestep logits: (B, V, 32) → (B, V, t_future)
        logits_v_t = self.final(x)                                                      # (B, V, T)

        # 8. Per-timestep voxel argmax (or teacher-forced index for grip/rot heads).
        if target_voxel_idx is None:
            pred_voxel_idx = logits_v_t.argmax(dim=1)                                   # (B, T)
        else:
            pred_voxel_idx = target_voxel_idx.long()                                    # gather() needs int64
        T = pred_voxel_idx.shape[1]

        # 9. Gather features at the (teacher-forced or argmax) voxel per timestep.
        idx_expand = pred_voxel_idx.unsqueeze(-1).expand(-1, -1, TOKEN_D)               # (B, T, 32)
        timestep_feat = x.gather(1, idx_expand)                                         # (B, T, 32)
        grip_logit = self.grip_head(timestep_feat).squeeze(-1)                          # (B, T)
        rot_logits = self.rot_head(timestep_feat).reshape(B, T, 3, N_ROT_BINS)

        return {
            "voxel_logits":   logits_v_t,      # (B, V, T)
            "grip_logit":     grip_logit,       # (B, T)
            "rot_logits":     rot_logits,       # (B, T, 3, 32)
            "pred_voxel_idx": pred_voxel_idx,   # (B, T)
        }


if __name__ == "__main__":
    # Smoke test: build the model, report parameter counts, and run a single
    # forward pass on random inputs to check the output shapes.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VolumeARModel().to(dev)

    all_params = list(model.parameters())
    trainable_count = sum(p.numel() for p in all_params if p.requires_grad)
    total_count = sum(p.numel() for p in all_params)
    print(f"Trainable: {trainable_count:,} / {total_count:,}")

    batch = 2
    rgb_in = torch.randn(batch, 3, 448, 448).to(dev)
    eef_noise = torch.randn(batch, N_PAST_EEF, 3).to(dev)
    eef_center = torch.tensor([0., 0., 1.0]).to(dev)
    past_eef = eef_noise * 0.2 + eef_center
    cur_eef = past_eef[:, -1]
    w2c = torch.eye(4).unsqueeze(0).expand(batch, 4, 4).to(dev)

    outputs = model(rgb_in, past_eef, cur_eef, w2c)
    print("voxel_logits:", outputs["voxel_logits"].shape,
          "grip:", outputs["grip_logit"].shape,
          "rot:", outputs["rot_logits"].shape)
    print("pred_voxel_idx:", outputs["pred_voxel_idx"].shape)
    if dev.type == 'cuda':
        print(f"peak memory: {torch.cuda.max_memory_allocated()/1e9:.2f} GB")
    else:
        print('')
