"""SmoothVolumeARModel v2 — Cameron's two tweaks on the smooth arch (2026-05-18):

(1) The image feature sampled at the current EEF's pixel (cur_img) is added to the
    timestep query embeddings BEFORE cross-attention. This anchors each timestep
    query with visual context at the starting EEF.

(2) Gripper + rotation regression moved to a separate stage:
      - argmax voxel per timestep (or GT during training)
      - gather VOXEL TOKEN features at those indices  →  (B, T, D)
      - 2 rounds of self-attention among the T gathered tokens
      - MLP per token → grip logit + 3-axis rot bins
    Rationale: the voxel feature carries "what's at that spot visually" — better
    for grasp-timing + orientation than the timestep query feature (which is
    more about navigation).
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from robot_volume import (
    voxel_centers_world, world_to_pixel_torch, pixel_to_normalized_grid,
    sincos_pe_3d, PE_DIM, N_PAST_EEF, T_FUTURE, N_ROT_BINS, IMAGE_SIZE, N_VOX,
)

# Local DINOv3 checkout and pretrained weights; each can be overridden by an
# environment variable of the same name.
DINO_REPO_DIR = os.getenv(
    "DINO_REPO_DIR",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3",
)
DINO_WEIGHTS_PATH = os.getenv(
    "DINO_WEIGHTS_PATH",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)

DINO_PATCH_SIZE = 16   # DINOv3 patch size (presumably the "16" in "vits16plus")
UPSAMPLE_RES = 64      # side length the DINO patch grid is bilinearly upsampled to
TOKEN_D = 32           # width of every token / attention layer
N_HEADS = 4            # attention heads per attention block
N_LAYERS = 4           # (self-attn, cross-attn) pairs in the timestep-query stack
N_RG_SELFATTN = 2      # rounds of self-attn among the T gathered voxel features


class SelfAttn(nn.Module):
    """Pre-norm multi-head self-attention with a residual connection.

    Computes ``x + MHA(LN(x), LN(x), LN(x))``. The attention layer is built with
    ``batch_first=True``, so inputs are ``(batch, seq, d)``.
    """

    def __init__(self, d=TOKEN_D, heads=N_HEADS):
        super().__init__()
        # Attribute names `ln` / `attn` define the state_dict keys; keep them stable.
        self.ln = nn.LayerNorm(d)
        self.attn = nn.MultiheadAttention(d, heads, batch_first=True)

    def forward(self, x):
        normed = self.ln(x)
        update = self.attn(normed, normed, normed, need_weights=False)[0]
        return x + update


class CrossAttn(nn.Module):
    """Pre-norm multi-head cross-attention with a residual on the query stream.

    Computes ``q + MHA(LN_q(q), LN_kv(kv), LN_kv(kv))``. The attention layer is
    built with ``batch_first=True``: ``q`` is ``(batch, n_q, d)`` and ``kv`` is
    ``(batch, n_kv, d)``.
    """

    def __init__(self, d=TOKEN_D, heads=N_HEADS):
        super().__init__()
        self.ln_q  = nn.LayerNorm(d)
        self.ln_kv = nn.LayerNorm(d)
        self.attn  = nn.MultiheadAttention(d, heads, batch_first=True)

    def forward(self, q, kv):
        # Normalize the key/value pool ONCE (it can be large, e.g. every voxel
        # token) and pass the same tensor as key and value. This avoids a duplicate
        # LayerNorm, and MultiheadAttention can take its packed K/V projection path.
        kv_n = self.ln_kv(kv)
        a, _ = self.attn(self.ln_q(q), kv_n, kv_n, need_weights=False)
        return q + a


class SmoothVolumeARModelV2(nn.Module):
    """Voxel-scoring trajectory model (v2) with a voxel-feature rot/grip stage.

    Pipeline (see ``forward``):
      1. DINOv3 patch features -> bilinear upsample to UPSAMPLE_RES -> 1x1-conv MLP.
      2. Voxel centers, past EEFs and the current EEF are projected into the image.
         Their features are bilinearly sampled and summed with a 3-D sin-cos PE
         (relative to the current EEF) and a type embedding to form a KV pool.
      3. ``t_future`` learned timestep queries, anchored with the current-EEF
         image feature, run N_LAYERS rounds of self-attn + cross-attn over the pool.
         They then score every voxel token with a scaled dot product.
      4. Voxel tokens at the argmax (or provided target) index of each timestep
         are gathered and refined by N_RG_SELFATTN self-attn rounds. They are then
         decoded into a gripper logit and 3 x N_ROT_BINS rotation logits.
    """

    def __init__(self, n_past=N_PAST_EEF, t_future=T_FUTURE, freeze_backbone=True):
        """Build the model and load the DINOv3 backbone from ``DINO_REPO_DIR``.

        Args:
            n_past: number of past EEF positions (stored as ``self.n_past``).
            t_future: number of predicted future timesteps T.
            freeze_backbone: if True, DINOv3 weights get ``requires_grad=False``.
                The backbone is also kept in eval mode even when ``train()`` is
                called. If False, DINOv3 is fine-tuned with the rest of the model.
        """
        super().__init__()
        self.n_past = n_past
        self.t_future = t_future
        self.freeze_backbone = freeze_backbone

        print(f"Loading DINOv3 (frozen={freeze_backbone})...")
        self.dino = torch.hub.load(DINO_REPO_DIR, 'dinov3_vits16plus', source='local', weights=DINO_WEIGHTS_PATH)
        if freeze_backbone:
            for p in self.dino.parameters():
                p.requires_grad = False
            self.dino.eval()
        self.dino_d = self.dino.embed_dim

        # Per-pixel (1x1 conv) projection of DINO features down to TOKEN_D channels.
        self.image_mlp = nn.Sequential(
            nn.Conv2d(self.dino_d, TOKEN_D, kernel_size=1), nn.GELU(),
            nn.Conv2d(TOKEN_D, TOKEN_D, kernel_size=1),
        )
        # Maps the sin-cos 3-D positional encoding (PE_DIM) to TOKEN_D.
        self.pe_mlp = nn.Sequential(
            nn.Linear(PE_DIM, TOKEN_D), nn.GELU(), nn.Linear(TOKEN_D, TOKEN_D),
        )
        # Token-type ids used in forward: 0 = voxel, 1 = past EEF, 2 = current EEF.
        self.type_embed = nn.Embedding(3, TOKEN_D)

        # Timestep query tokens + temporal PE.
        self.timestep_token = nn.Parameter(torch.randn(t_future, TOKEN_D) * 0.02)
        self.timestep_pe = nn.Parameter(torch.randn(t_future, TOKEN_D) * 0.02)

        self.self_blocks  = nn.ModuleList([SelfAttn(TOKEN_D, N_HEADS) for _ in range(N_LAYERS)])
        self.cross_blocks = nn.ModuleList([CrossAttn(TOKEN_D, N_HEADS) for _ in range(N_LAYERS)])

        # Attention-score projection for voxel logits.
        self.q_proj = nn.Linear(TOKEN_D, TOKEN_D)
        self.k_proj = nn.Linear(TOKEN_D, TOKEN_D)

        # Rot/grip self-attn stack (operates on the T gathered voxel features).
        self.rg_self_blocks = nn.ModuleList([SelfAttn(TOKEN_D, N_HEADS) for _ in range(N_RG_SELFATTN)])
        self.rg_norm = nn.LayerNorm(TOKEN_D)
        self.grip_head = nn.Linear(TOKEN_D, 1)
        self.rot_head  = nn.Linear(TOKEN_D, 3 * N_ROT_BINS)

        self.register_buffer("voxel_centers", voxel_centers_world(), persistent=False)

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen backbone in eval mode.

        ``nn.Module.train`` recurses into every submodule. Without this override,
        ``model.train()`` would switch the frozen DINOv3 back to training mode,
        enabling any train-only behavior it has (if configured), even though its
        weights are never updated.
        """
        super().train(mode)
        if self.freeze_backbone:
            self.dino.eval()
        return self

    def _dino_tokens(self, x):
        """Run the DINOv3 blocks on ``x``; return normed patch tokens and the patch grid size.

        Returns ``(patch_tokens, (H_p, W_p))``. ``patch_tokens`` excludes the CLS
        and storage tokens and is later reshaped to ``(B, H_p, W_p, dino_d)``. That
        reshape assumes the tokens are laid out row-major over the patch grid.
        """
        dino = self.dino
        tokens, (H_p, W_p) = dino.prepare_tokens_with_masks(x)
        for blk in dino.blocks:
            # Kept inside the loop, as before. NOTE(review): the RoPE module may
            # behave differently per call in train mode; confirm before hoisting.
            rope = dino.rope_embed(H=H_p, W=W_p) if dino.rope_embed is not None else None
            tokens = blk(tokens, rope)
        n_prefix = dino.n_storage_tokens + 1  # CLS + storage tokens, dropped below
        if dino.untie_cls_and_patch_norms:
            # Only the patch part survives, so the prefix's cls_norm is never needed.
            return dino.norm(tokens[:, n_prefix:]), (H_p, W_p)
        return dino.norm(tokens)[:, n_prefix:], (H_p, W_p)

    def _dino_patches(self, x):
        """Return DINOv3 patch features as a ``(B, dino_d, H_p, W_p)`` feature map.

        A frozen backbone runs under ``torch.no_grad()`` and the result is
        detached. Otherwise gradients flow into DINOv3 so it can be fine-tuned.
        """
        if self.freeze_backbone:
            with torch.no_grad():
                p, (H_p, W_p) = self._dino_tokens(x)
            p = p.detach()
        else:
            p, (H_p, W_p) = self._dino_tokens(x)
        B = p.shape[0]
        return p.reshape(B, H_p, W_p, self.dino_d).permute(0, 3, 1, 2)

    def forward(self, rgb, past_eef_world, current_eef_world, world_to_camera,
                target_voxel_idx=None):
        """Predict per-timestep voxel logits plus gripper and rotation logits.

        Args:
            rgb: image batch ``(B, 3, H, W)`` fed to DINOv3. Presumably H, W match
                IMAGE_SIZE, which the pixel-to-grid normalization uses (TODO confirm).
            past_eef_world: past EEF positions in world frame, ``(B, N, 3)``.
            current_eef_world: current EEF position in world frame, ``(B, 3)``.
            world_to_camera: per-sample camera transform, passed straight to
                ``world_to_pixel_torch``.
            target_voxel_idx: optional ``(B, T)`` voxel indices (e.g. ground truth
                during training). When given, the rot/grip stage gathers voxel
                tokens at these indices instead of at the predicted argmax.

        Returns:
            dict with ``voxel_logits`` (B, V, T), ``grip_logit`` (B, T),
            ``rot_logits`` (B, T, 3, N_ROT_BINS) and ``pred_voxel_idx`` (B, T).
            ``pred_voxel_idx`` is the argmax of ``voxel_logits`` over voxels.
        """
        B = rgb.shape[0]
        device = rgb.device

        # 1. Image feature map: DINO patches -> upsample -> per-pixel MLP.
        patches = self._dino_patches(rgb)
        feats   = F.interpolate(patches, size=(UPSAMPLE_RES, UPSAMPLE_RES),
                                mode='bilinear', align_corners=False)
        feats   = self.image_mlp(feats)                                               # (B, 32, 64, 64)

        # 2. Project all voxels + past EEFs + current EEF to image pixels
        vox_world = self.voxel_centers.unsqueeze(0).expand(B, -1, -1)                 # (B, V, 3)
        vox_pix   = world_to_pixel_torch(vox_world, world_to_camera)                  # (B, V, 2)
        past_pix  = world_to_pixel_torch(past_eef_world, world_to_camera)             # (B, N, 2)
        cur_pix   = world_to_pixel_torch(current_eef_world.unsqueeze(1), world_to_camera).squeeze(1)

        # 3. Vectorized grid_sample: (B, P, 2) pixel coords -> (B, P, TOKEN_D) features.
        #    Points that fall outside the image sample zeros (padding_mode='zeros').
        def sample(pix_uv):
            grid = pixel_to_normalized_grid(pix_uv, IMAGE_SIZE).unsqueeze(2)
            s = F.grid_sample(feats, grid, mode='bilinear', align_corners=False, padding_mode='zeros')
            return s.squeeze(-1).permute(0, 2, 1)
        vox_img  = sample(vox_pix)                                                    # (B, V, 32)
        past_img = sample(past_pix)
        cur_img  = sample(cur_pix.unsqueeze(1)).squeeze(1)                            # (B, 32)

        # 4. PE relative to current EEF (so the current EEF itself sits at the origin)
        ce = current_eef_world.unsqueeze(1)
        vox_rel  = vox_world - ce
        past_rel = past_eef_world - ce
        cur_rel  = torch.zeros(B, 1, 3, device=device, dtype=rgb.dtype)

        vox_pe  = self.pe_mlp(sincos_pe_3d(vox_rel))
        past_pe = self.pe_mlp(sincos_pe_3d(past_rel))
        cur_pe  = self.pe_mlp(sincos_pe_3d(cur_rel))

        # 5. Tokens with type embeddings (0 = voxel, 1 = past EEF, 2 = current EEF)
        type_vox  = self.type_embed(torch.zeros(B, vox_world.size(1), dtype=torch.long, device=device))
        type_past = self.type_embed(torch.ones (B, past_eef_world.size(1), dtype=torch.long, device=device))
        type_cur  = self.type_embed(torch.full ((B, 1), 2, dtype=torch.long, device=device))

        vox_tokens  = vox_img  + vox_pe  + type_vox                                    # (B, V, 32)
        past_tokens = past_img + past_pe + type_past                                   # (B, N, 32)
        cur_token   = cur_img.unsqueeze(1) + cur_pe + type_cur                         # (B, 1, 32)

        kv = torch.cat([vox_tokens, past_tokens, cur_token], dim=1)                    # (B, V+N+1, 32)

        # 6. Timestep query tokens, with cur_img added as the visual "starting anchor"
        ts_q = (self.timestep_token + self.timestep_pe).unsqueeze(0).expand(B, -1, -1) # (B, T, 32)
        ts_q = ts_q + cur_img.unsqueeze(1)                                              # (B, T, 32)

        # 7. Cross-attention stack (timestep queries <- KV pool)
        x = ts_q
        for sa, ca in zip(self.self_blocks, self.cross_blocks):
            x = sa(x)
            x = ca(x, kv)

        # 8. Voxel logits via scaled dot-product scoring against voxel tokens only
        q = self.q_proj(x)                                                             # (B, T, 32)
        k = self.k_proj(vox_tokens)                                                    # (B, V, 32)
        voxel_logits_tv = torch.einsum('btd, bvd -> btv', q, k) / (TOKEN_D ** 0.5)     # (B, T, V)
        voxel_logits    = voxel_logits_tv.permute(0, 2, 1)                              # (B, V, T)
        pred_voxel_idx  = voxel_logits.argmax(dim=1)                                    # (B, T)

        # 9. Gather VOXEL TOKEN features at argmax/target positions, then self-attn, regress
        ref_idx = target_voxel_idx if target_voxel_idx is not None else pred_voxel_idx   # (B, T)
        # Cast to int64, the index dtype gather() expects (e.g. int32 targets from a loader).
        ref_idx = ref_idx.long()
        target_vox_feat = vox_tokens.gather(1, ref_idx.unsqueeze(-1).expand(-1, -1, TOKEN_D))  # (B, T, 32)
        rg = target_vox_feat
        for sa in self.rg_self_blocks:
            rg = sa(rg)
        rg = self.rg_norm(rg)
        grip_logit = self.grip_head(rg).squeeze(-1)                                     # (B, T)
        rot_logits = self.rot_head(rg).reshape(B, self.t_future, 3, N_ROT_BINS)

        return {
            "voxel_logits":   voxel_logits,
            "grip_logit":     grip_logit,
            "rot_logits":     rot_logits,
            "pred_voxel_idx": pred_voxel_idx,
        }


if __name__ == "__main__":
    # Smoke test: build the model, report parameter counts, then run one forward
    # pass on random inputs and print the shape of every output.
    dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = SmoothVolumeARModelV2().to(dev)

    all_params = list(model.parameters())
    n_t = sum(p.numel() for p in all_params if p.requires_grad)
    n_a = sum(p.numel() for p in all_params)
    print(f"Trainable: {n_t:,} / {n_a:,}")

    batch = 2
    rgb = torch.randn(batch, 3, 448, 448).to(dev)
    # Past EEFs scattered around (0, 0, 1); the most recent one doubles as "current".
    past = torch.randn(batch, N_PAST_EEF, 3).to(dev) * 0.2 + torch.tensor([0., 0., 1.0]).to(dev)
    cur = past[:, -1]
    w2c = torch.eye(4).unsqueeze(0).expand(batch, 4, 4).to(dev)

    outputs = model(rgb, past, cur, w2c)
    for k, v in outputs.items():
        print(f"  {k}: {tuple(v.shape)}")

    if dev.type == 'cuda':
        print(f"peak: {torch.cuda.max_memory_allocated()/1e9:.2f} GB")
    else:
        print('')
