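"""DinoVolumeKV + temporal-transformer gripper/rotation heads (libero-style teacher forcing).

Inputs:  rgb (B, 3, IMG, IMG), kp_zyx (B, T, 3) long — (z_bin, y_grid, x_grid) keypoints
         (GT at train, argmax of the volume head at eval)
Outputs:
  volume_logits:   (B, T, Z, h, w)         — unchanged from DinoVolumeKV
  gripper_logits:  (B, T, N_GRIPPER_BINS)  — CE over discretised gripper (only when kp_zyx is given)
  rotation_logits: (B, T, N_ROT_BINS)      — single CE over 1D (PCA-collapsed) rotation bins
                                             (only when kp_zyx is given)
  pixel_feats:     (B, key_dim, h, w)      — for viz/debug

The gripper/rotation head indexes `pixel_feats` at the keypoint (y, x) locations
(teacher-forced GT at train) and runs a small transformer over the T resulting tokens.
With detach_for_head=True the indexed features are DETACHED so head gradients cannot
destabilise the volume objective, mirroring the original libero pattern where the volume
head is the load-bearing one. The flag defaults to False, so by default head gradients DO
flow into the shared features.
"""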
import os, sys, math
import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.insert(0, os.path.dirname(__file__))
from model_dino_volume_kv import (DinoVolumeKV, DINO_REPO_DIR, DINO_WEIGHTS_PATH, sinusoidal_features,
                                   IMG_SIZE, N_WINDOW, N_HEIGHT_BINS, KEY_DIM)

# ---- Target discretisation -------------------------------------------------------
N_GRIPPER_BINS = 32
N_ROT_BINS = 48        # was 32; finer bins for the 1D PCA rotation target (Cameron 2026-05-20)
DA3_INPUT = 504


# ---- Temporal-transformer head hyper-parameters ----------------------------------
D_MODEL_HEAD = 128     # token width of the head transformer
TF_HEADS = 4           # attention heads; D_MODEL_HEAD must be divisible by this
TF_LAYERS = 5          # was 2 (Cameron 2026-05-20 spec)
PRED_GRID = 56         # side of the (y, x) keypoint grid: 448 / patch 16 = 28, upsampled x2


class DinoVolumeKVFull(DinoVolumeKV):
    """Volume KV + TEMPORAL TRANSFORMER head for gripper and rotation.

    Per Cameron 2026-05-19 (later in session): replace per-timestep MLPs with a tiny
    transformer encoder that processes all T future keypoints jointly (T <= n_window;
    8 in the original design). Each token =
       z_sin[z_t] + t_sin[t] + y_emb[y_t] + x_emb[x_t] + Linear(F[y_t, x_t])
    where z_sin / t_sin are fixed sinusoidal tables and y_emb / x_emb are learned.
    Self-attention over the tokens lets gripper@t=5 see the keypoints at t=0..4 and t=6,7
    so it can reason about trajectory phase ("descending → grasping → ascending"). Each
    output token is mapped to gripper bins and to a single 1D (PCA-collapsed) rotation
    distribution via two small linears.

    Inputs at forward time:
      rgb:     (B, 3, IMG, IMG)
      kp_zyx:  (B, T, 3) long — (z_bin, y_grid, x_grid). At training this is GT (teacher
               forcing); at inference it's the argmax over the volume head's output.

    Extra constructor args (the rest are forwarded to DinoVolumeKV):
      n_rot_bins / n_gripper_bins: number of output classes of the two heads.
      d_model / tf_layers / tf_heads: transformer width, depth and head count
          (d_model must be divisible by tf_heads).
      pred_grid: size of the learned y/x embedding tables; y/x grid indices must be
          < pred_grid.
      detach_for_head: if True the head reads detached features, so its loss does not
          back-propagate into the shared pixel features.
    """
    def __init__(self, n_window: int = N_WINDOW, n_height_bins: int = N_HEIGHT_BINS,
                 key_dim: int = KEY_DIM, image_size: int = IMG_SIZE,
                 height_enc: str = 'sin', time_enc: str = 'sin',
                 head_hidden: int = 192, dino_variant: str = 'dinov3_vits16plus',
                 n_rot_bins: int = N_ROT_BINS, n_gripper_bins: int = N_GRIPPER_BINS,
                 d_model: int = D_MODEL_HEAD, tf_layers: int = TF_LAYERS,
                 tf_heads: int = TF_HEADS, pred_grid: int = PRED_GRID,
                 detach_for_head: bool = False):
        super().__init__(n_window=n_window, n_height_bins=n_height_bins, key_dim=key_dim,
                         image_size=image_size, height_enc=height_enc, time_enc=time_enc,
                         head_hidden=head_hidden, dino_variant=dino_variant)
        self.n_rot_bins      = n_rot_bins
        self.n_gripper_bins  = n_gripper_bins
        self.d_model         = d_model
        self.pred_grid       = pred_grid
        self.detach_for_head = detach_for_head

        # Sinusoidal PE for z (height) and t (time) — matches parent's height_enc='sin',
        # time_enc='sin'. Fixed buffers, no learnable parameters. Their first dimension is
        # the number of valid z bins / time steps respectively.
        self.register_buffer("z_sin", sinusoidal_features(n_height_bins, d_model), persistent=False)
        self.register_buffer("t_sin", sinusoidal_features(n_window,     d_model), persistent=False)
        # y/x stay as learned embeddings (pixel features already carry y/x positional info
        # implicitly since they're indexed at (y, x), so y/x emb is a smaller residual).
        self.y_tok_emb = nn.Embedding(pred_grid, d_model)
        self.x_tok_emb = nn.Embedding(pred_grid, d_model)
        self.feat_to_token = nn.Linear(key_dim, d_model)

        # Tiny transformer encoder — self-attention across the T future tokens
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=tf_heads, dim_feedforward=d_model * 4,
            dropout=0.0, activation='gelu', batch_first=True, norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=tf_layers)
        # Pre-norm layers leave the residual stream un-normalised, hence a final LayerNorm.
        self.out_norm = nn.LayerNorm(d_model)

        # Per-token output heads. Rotation is 1D (PCA-collapsed) — single CE over
        # n_rot_bins, not 3-axis CE.
        self.gripper_out  = nn.Linear(d_model, n_gripper_bins)
        self.rotation_out = nn.Linear(d_model, n_rot_bins)
        # Default init throughout — proven to start near log(N) without choking gradients
        # (previous std=0.01 init plateaued for 400+ iters).

    def predict_from_keypoints(self, pixel_feats, kp_zyx):
        """Run the temporal transformer head on features sampled at the keypoints.

        Args:
            pixel_feats: (B, C, ph, pw) feature map; C must equal ``key_dim`` (the input
                width of ``feat_to_token``).
            kp_zyx: (B, T, 3) integer tensor — (z_bin, y_grid, x_grid) per timestep.
                z is clamped to the number of height bins, y/x to the feature-map size.
                y/x also index the learned embeddings, so ph and pw should not exceed
                ``pred_grid``.

        Returns:
            (gripper_logits, rotation_logits) with shapes (B, T, n_gripper_bins) and
            (B, T, n_rot_bins). Rotation is a single 1D (PCA-collapsed) distribution,
            not one per Euler axis.

        Raises:
            ValueError: if T exceeds the number of time embeddings (n_window).
        """
        B, _, ph, pw = pixel_feats.shape
        T = kp_zyx.shape[1]
        n_time = self.t_sin.shape[0]
        if T > n_time:
            # Fail loudly here instead of with an opaque IndexError / CUDA device assert
            # from indexing t_sin out of range.
            raise ValueError(
                f"kp_zyx has T={T} timesteps but the head only has {n_time} time embeddings")

        feats = pixel_feats.detach() if self.detach_for_head else pixel_feats
        # Clamp z against the PE table it indexes so it can never go out of range.
        z = kp_zyx[..., 0].clamp(0, self.z_sin.shape[0] - 1)
        y = kp_zyx[..., 1].clamp(0, ph - 1)
        x = kp_zyx[..., 2].clamp(0, pw - 1)

        # Sample pixel features at each keypoint. Advanced indices separated by a slice
        # put the broadcast index dims first, giving (B, T, C).
        batch_idx = torch.arange(B, device=feats.device).view(B, 1).expand(B, T)
        sampled = feats[batch_idx, :, y, x]                                    # (B, T, C)

        # Per-token embedding: sin(z) + sin(t) + learned(y) + learned(x) + Linear(F)
        z_pe = self.z_sin[z]                                                    # (B, T, d_model)
        t_pe = self.t_sin[:T].unsqueeze(0)                                      # (1, T, d_model), broadcast over B
        tokens = (z_pe + t_pe + self.y_tok_emb(y) + self.x_tok_emb(x)
                  + self.feat_to_token(sampled))                                # (B, T, d_model)

        # Self-attention across the T tokens, then per-token classification heads.
        h = self.out_norm(self.transformer(tokens))                             # (B, T, d_model)
        grip = self.gripper_out(h)                                              # (B, T, n_grip)
        rot  = self.rotation_out(h)                                             # (B, T, n_rot) — 1D PCA
        return grip, rot

    def forward(self, rgb, kp_zyx=None):
        """Run the volume model and, if kp_zyx is given, the gripper/rotation head.

        kp_zyx: (B, T, 3) long — (z_bin, y_grid, x_grid). GT at train, argmax at inference.
        Returns the parent's output dict; when kp_zyx is given it additionally contains
        "gripper_logits" (B, T, n_gripper_bins) and "rotation_logits" (B, T, n_rot_bins).
        """
        out = super().forward(rgb)
        if kp_zyx is not None:
            grip, rot = self.predict_from_keypoints(out["pixel_feats"], kp_zyx)
            out["gripper_logits"]  = grip
            out["rotation_logits"] = rot
        return out


if __name__ == "__main__":
    # Smoke test: build the model with default hyper-parameters, feed random RGB plus
    # random keypoints and print the output shapes.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    m = DinoVolumeKVFull().to(device).eval()
    n_t = sum(p.numel() for p in m.parameters() if p.requires_grad)
    print(f"Trainable: {n_t:,}")
    batch = 2
    rgb = torch.rand(batch, 3, IMG_SIZE, IMG_SIZE, device=device)
    # Random (z, y, x) keypoints drawn from the model's default ranges instead of
    # hard-coded 32 / 56, so the smoke test tracks config changes.
    kp = torch.stack([
        torch.randint(0, N_HEIGHT_BINS, (batch, N_WINDOW), device=device),
        torch.randint(0, PRED_GRID, (batch, N_WINDOW), device=device),
        torch.randint(0, PRED_GRID, (batch, N_WINDOW), device=device),
    ], dim=-1)
    with torch.no_grad():
        out = m(rgb, kp)
    for k, v in out.items():
        if hasattr(v, 'shape'):
            print(f"  {k}: {tuple(v.shape)}")
