"""Autoregressive transformer policy with high-res patches + relative-to-EEF positional embeddings.

Architecture (per Cameron's spec, 2026-05-16):
  Per-frame patches (frozen DINOv3 ViT-S+/16, hub entry `dinov3_vits16plus`; 28x28 = 784 patches, D=384)
  + 1 learnable EEF token per timestep
  History context H frames (default 8)
  Three summed positional embeddings on every token:
    (1) temporal PE — which timestep (0..H-1)
    (2) absolute 2D PE — patch position (or EEF position) in [0,1]^2
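    (3) relative-to-EEF PE — (xy - anchor.xy), where the anchor is the most recent known EEF (the last history entry, i.e. EEF(t-1) as seen from the step being predicted) and is shared by every token; this is the anchor inductive bias for free-space motion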
  Causal mask along time; within a frame all tokens are bidirectional.
  Readout: take the EEF_query token at t = H-1, MLP -> logits over a 56x56 spatial grid
  Loss: cross-entropy on grid cell containing next-EEF GT.

First-pass scope:
  xy-only output (no height / rotation / gripper — bolt on later as parallel readout queries).
  Teacher-force past EEF coords from data; predict next-step (single step) EEF cell.
"""
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Backbone location. Both values follow the env-var convention of the existing
# model.py, so one DINOv3 checkout / weights file can serve both models.
# The fallbacks are machine-specific paths; set the env vars on other machines.
DINO_REPO_DIR = os.getenv(
    "DINO_REPO_DIR",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3",
)
DINO_WEIGHTS_PATH = os.getenv(
    "DINO_WEIGHTS_PATH",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)
DINO_PATCH_SIZE = 16  # pixels per patch side

# Model defaults; train_ar.py may override these at construction time.
HISTORY_LEN = 8      # frames of context
GRID_SIZE = 56       # 56x56 output classification grid (8-pixel cells at 448 input)
TRANSFORMER_D = 384  # must equal the backbone embed_dim (checked in the model constructor)
TRANSFORMER_H = 6    # attention heads
TRANSFORMER_L = 4    # encoder layers
IMAGE_SIZE = 448     # square input resolution in pixels


def _sincos_pe_1d(positions, dim, device, dtype):
    """Return a fixed sinusoidal embedding of scalar positions.

    Channel layout is ``[sin(p * f_0..f_{h-1}), cos(p * f_0..f_{h-1})]`` with
    ``h = dim // 2`` and ``f_i = 10000 ** (-i / h)``.

    NOTE(review): the largest frequency is 1 radian per unit of ``positions``,
    so inputs confined to [0, 1] sweep less than one radian on every channel.

    Args:
        positions: tensor of shape (...,), any numeric range.
        dim: output width; must be a positive even integer.
        device: device on which the frequency table is created.
        dtype: dtype of the frequency table.

    Returns:
        Tensor of shape (..., dim).

    Raises:
        ValueError: if ``dim`` is not a positive even integer. Without this
            check an odd ``dim`` would silently yield ``dim - 1`` channels.
    """
    if dim <= 0 or dim % 2 != 0:
        raise ValueError(f"dim must be a positive even integer, got {dim}")
    half = dim // 2
    exponents = torch.arange(half, device=device, dtype=dtype)
    freqs = torch.exp(exponents * -(math.log(10000.0) / half))  # (half,)
    angles = positions.unsqueeze(-1) * freqs                    # (..., half)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)


def _sincos_pe_2d(xy, dim, device, dtype):
    """Return a fixed sinusoidal embedding of 2D coordinates.

    The first half of the channels embeds ``xy[..., 0]`` and the second half
    embeds ``xy[..., 1]``, each through ``_sincos_pe_1d`` with ``dim // 2``
    channels.

    Args:
        xy: tensor of shape (..., 2); only components 0 and 1 are read.
        dim: output width; must be a positive multiple of 4 so that each axis
            gets an even number of channels.
        device: device on which the frequency tables are created.
        dtype: dtype of the frequency tables.

    Returns:
        Tensor of shape (..., dim).

    Raises:
        ValueError: if ``dim`` is not a positive multiple of 4. Without this
            check the result would silently be narrower than ``dim``.
    """
    if dim <= 0 or dim % 4 != 0:
        raise ValueError(f"dim must be a positive multiple of 4, got {dim}")
    per_axis = dim // 2
    per_axis_pe = [_sincos_pe_1d(xy[..., axis], per_axis, device, dtype) for axis in (0, 1)]
    return torch.cat(per_axis_pe, dim=-1)


class ARTransformerPolicy(nn.Module):
    """Predict the next end-effector (EEF) grid cell from a history of frames.

    Token sequence (per batch element), for t = 0..H-1:
        ``n_patches`` DINOv3 patch tokens of frame t, then one EEF token.
    Each token receives a token-type embedding plus three summed positional
    embeddings (temporal, absolute 2D, relative-to-anchor-EEF). Attention is
    block-causal over time and bidirectional within a frame. The output of the
    last EEF token is mapped to ``grid_size ** 2`` logits.

    When ``freeze_backbone`` is true the DINO backbone has ``requires_grad``
    disabled, runs without gradient tracking, and is kept in eval mode even
    after ``model.train()``.
    """

    def __init__(
        self,
        target_size=IMAGE_SIZE,
        history_len=HISTORY_LEN,
        grid_size=GRID_SIZE,
        d_model=TRANSFORMER_D,
        n_heads=TRANSFORMER_H,
        n_layers=TRANSFORMER_L,
        freeze_backbone=True,
    ):
        """Load the DINOv3 backbone and build the transformer and readout head.

        Args:
            target_size: side length in pixels of the (square) input frames.
            history_len: number of frames H expected by ``forward``.
            grid_size: side length of the output classification grid.
            d_model: transformer width; must equal the backbone ``embed_dim``.
            n_heads: attention heads per encoder layer.
            n_layers: number of encoder layers.
            freeze_backbone: if true, the backbone is frozen (see class docstring).

        Raises:
            AssertionError: if ``d_model`` differs from the backbone ``embed_dim``.
        """
        super().__init__()
        self.target_size = target_size
        self.history_len = history_len
        self.grid_size = grid_size
        self.d_model = d_model
        self.freeze_backbone = freeze_backbone
        self.patch_size = DINO_PATCH_SIZE
        self.patches_per_side = target_size // self.patch_size  # 28 at 448 input
        self.n_patches = self.patches_per_side ** 2              # 784 at 448 input

        print("Loading DINOv3 backbone (frozen={})...".format(freeze_backbone))
        self.dino = torch.hub.load(
            DINO_REPO_DIR,
            'dinov3_vits16plus',
            source='local',
            weights=DINO_WEIGHTS_PATH,
        )
        if freeze_backbone:
            for p in self.dino.parameters():
                p.requires_grad = False
            self.dino.eval()
        self.embed_dim = self.dino.embed_dim
        assert self.embed_dim == d_model, (
            f"d_model ({d_model}) must equal DINO embed_dim ({self.embed_dim}) — "
            "otherwise add a linear projection before the transformer."
        )

        # One learnable EEF prototype shared by every timestep; the positional
        # embeddings are what distinguish the per-frame copies.
        self.eef_token = nn.Parameter(torch.randn(d_model) * 0.02)

        # Token-type embeddings (patch vs eef) — small disambiguator.
        self.type_embed_patch = nn.Parameter(torch.randn(d_model) * 0.02)
        self.type_embed_eef = nn.Parameter(torch.randn(d_model) * 0.02)

        # Relative-to-EEF PE: small MLP over a 2D offset. The offset fed in is
        # (xy - anchor) * 2; with xy and anchor both in [0, 1] each component
        # lies in [-2, 2].
        self.rel_pe_mlp = nn.Sequential(
            nn.Linear(2, d_model), nn.GELU(), nn.Linear(d_model, d_model),
        )

        # Pre-norm encoder stack; the block-causal mask is supplied at forward time.
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=0.0, activation="gelu", batch_first=True, norm_first=True,
        )
        # enable_nested_tensor only matters together with src_key_padding_mask and
        # is unavailable for norm_first layers; passing False avoids the
        # construction-time warning without changing the computation.
        self.transformer = nn.TransformerEncoder(
            layer, num_layers=n_layers, enable_nested_tensor=False,
        )

        # Output head applied to the last EEF token: LayerNorm -> MLP -> grid logits.
        self.readout = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model), nn.GELU(),
            nn.Linear(d_model, grid_size * grid_size),
        )

        # Patch-centre coordinates in [0, 1], row-major (y outer, x inner), stored as (x, y).
        centers = (torch.arange(self.patches_per_side) + 0.5) / self.patches_per_side
        grid_y, grid_x = torch.meshgrid(centers, centers, indexing='ij')
        self.register_buffer(
            "patch_xy_01",
            torch.stack([grid_x, grid_y], dim=-1).reshape(self.n_patches, 2),  # (Np, 2)
            persistent=False,
        )

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen backbone in eval mode.

        ``nn.Module.train`` recurses into every submodule, so without this
        override ``model.train()`` would switch the frozen DINO backbone back
        to train mode and re-enable whatever train-only behaviour it has
        (stochastic layers, positional-encoding augmentation, ...).
        """
        super().train(mode)
        if getattr(self, "freeze_backbone", False):
            self.dino.eval()
        return self

    # ---------------- DINO feature extraction ---------------- #

    def _dino_patches(self, x):
        """Run the DINO backbone and return its normalized patch tokens.

        Args:
            x: image batch of shape (N, 3, H, W).

        Returns:
            Patch tokens of shape (N, n_patches, D) with the leading
            ``n_storage_tokens + 1`` prefix tokens removed. Gradients are
            tracked only when the backbone is not frozen and grad mode is
            enabled in the caller.
        """
        dino = self.dino
        n_prefix = dino.n_storage_tokens + 1
        # Never turn gradients ON inside a caller's no_grad(); only turn them off
        # when the backbone is frozen.
        track_grad = torch.is_grad_enabled() and not self.freeze_backbone
        with torch.set_grad_enabled(track_grad):
            tokens, (H_p, W_p) = dino.prepare_tokens_with_masks(x)
            for blk in dino.blocks:
                # Deliberately evaluated once per block, exactly as before: in train
                # mode the RoPE module may not be deterministic, so hoisting the call
                # out of the loop could change results for an unfrozen backbone.
                rope_sincos = dino.rope_embed(H=H_p, W=W_p) if dino.rope_embed else None
                tokens = blk(tokens, rope_sincos)
            if dino.untie_cls_and_patch_norms:
                # The prefix-token norm is still evaluated so that its parameters keep
                # participating in the autograd graph as they did originally.
                cls_n = dino.cls_norm(tokens[:, :n_prefix])
                pat_n = dino.norm(tokens[:, n_prefix:])
                tokens = torch.cat([cls_n, pat_n], dim=1)
            else:
                tokens = dino.norm(tokens)
            patches = tokens[:, n_prefix:]  # (N, n_patches, D)
        return patches.detach() if self.freeze_backbone else patches

    # ---------------- positional embeddings ---------------- #

    def _positional_embeddings(self, B, H, history_eef_xy_01):
        """Build the summed positional embedding for every token in the sequence.

        Sequence layout per batch element, for t = 0..H-1:
            patch tokens (Np) of frame t, then the EEF token (1) of frame t.
        Total length is H * (Np + 1).

        Per token the result is the sum of:
          (1) temporal PE of its frame index t;
          (2) absolute 2D PE of its own position (patch centre, or that
              frame's EEF xy);
          (3) ``rel_pe_mlp((xy - anchor) * 2)``, where ``anchor`` is the LAST
              history entry ``history_eef_xy_01[:, -1]`` for every token in
              every frame. For the last EEF token the offset is therefore zero.

        NOTE(review): because the anchor comes from the last frame, tokens of
        earlier frames carry last-frame information through their PE; only the
        final token is read out, so this does not leak the prediction target.

        Args:
            B: batch size.
            H: number of frames.
            history_eef_xy_01: (B, H, 2) EEF positions scaled by 1/target_size.

        Returns:
            Tensor of shape (B, H * (Np + 1), D).
        """
        device = history_eef_xy_01.device
        dtype = history_eef_xy_01.dtype
        D = self.d_model
        Np = self.n_patches

        # (1) Temporal PE: one vector per frame.
        frame_idx = torch.arange(H, device=device, dtype=dtype)             # (H,)
        time_pe = _sincos_pe_1d(frame_idx, D, device, dtype)                # (H, D)

        # (2) Absolute 2D PE.
        patch_xy = self.patch_xy_01.to(dtype)                               # (Np, 2)
        patch_abs_pe = _sincos_pe_2d(patch_xy, D, device, dtype)            # (Np, D)
        eef_abs_pe = _sincos_pe_2d(history_eef_xy_01, D, device, dtype)     # (B, H, D)

        # (3) Relative PE w.r.t. the most recent known EEF.
        anchor_xy_01 = history_eef_xy_01[:, -1:, :]                         # (B, 1, 2)
        patch_rel_xy = (patch_xy.unsqueeze(0).expand(B, Np, 2) - anchor_xy_01) * 2.0
        patch_rel_pe = self.rel_pe_mlp(patch_rel_xy)                        # (B, Np, D)
        eef_rel_pe = self.rel_pe_mlp((history_eef_xy_01 - anchor_xy_01) * 2.0)  # (B, H, D)

        # Broadcast over time instead of looping; the addition order matches the
        # per-frame formulation ((abs + rel) + time).
        patch_pe = (patch_abs_pe.unsqueeze(0) + patch_rel_pe).unsqueeze(1) \
            + time_pe.view(1, H, 1, D)                                      # (B, H, Np, D)
        eef_pe = (eef_abs_pe + eef_rel_pe + time_pe.unsqueeze(0)).unsqueeze(2)  # (B, H, 1, D)
        return torch.cat([patch_pe, eef_pe], dim=2).flatten(1, 2)           # (B, H*(Np+1), D)

    def _build_causal_time_mask(self, H, device, dtype=None):
        """Build the block-causal additive attention mask.

        A token of frame t may attend to every token of frames <= t (including
        all tokens of its own frame) and to no token of a later frame.

        Args:
            H: number of frames.
            device: device for the mask.
            dtype: floating dtype for the mask; ``None`` uses torch's default
                dtype.

        Returns:
            (T, T) tensor with T = H * (n_patches + 1); entry [q, k] is
            ``-inf`` where key k belongs to a later frame than query q, else 0.
            At the default sizes (H=8, Np=784) this is a 6280 x 6280 tensor.
        """
        per_frame = self.n_patches + 1
        frame_of_token = torch.arange(H, device=device).repeat_interleave(per_frame)  # (T,)
        key_is_future = frame_of_token.unsqueeze(0) > frame_of_token.unsqueeze(1)      # (T, T)
        mask = torch.zeros(key_is_future.shape, device=device, dtype=dtype)
        return mask.masked_fill_(key_is_future, float("-inf"))

    # ---------------- forward ---------------- #

    def forward(self, history_imgs, history_eef_xy):
        """Predict grid logits for the next-step EEF position.

        Args:
            history_imgs:   (B, H, 3, S, S) with S == target_size — past H frames,
                            most recent at index H-1.
            history_eef_xy: (B, H, 2) — EEF coordinates in image-pixel space for the
                            same frames (last entry = current frame, teacher-forced).
                            Cast to the token dtype before use.

        Returns:
            logits: (B, grid_size * grid_size) over the spatial output grid.

        Raises:
            AssertionError: if H, the frame size, or the EEF shape is unexpected.
        """
        B, H, C, Hi, Wi = history_imgs.shape
        assert H == self.history_len, f"Expected H={self.history_len}, got {H}"
        assert Hi == self.target_size and Wi == self.target_size, \
            f"Expected {self.target_size}x{self.target_size}, got {Hi}x{Wi}"
        assert tuple(history_eef_xy.shape) == (B, H, 2), \
            f"Expected history_eef_xy of shape ({B}, {H}, 2), got {tuple(history_eef_xy.shape)}"

        # 1. DINO patches per frame (time folded into batch). reshape, not view,
        #    so non-contiguous inputs (e.g. after a permute) are accepted.
        frames = history_imgs.reshape(B * H, C, Hi, Wi)
        patches = self._dino_patches(frames).reshape(B, H, self.n_patches, self.d_model)

        # 2. Token-type embeddings (patch vs eef).
        patches = patches + self.type_embed_patch                            # (B, H, Np, D)
        eef_tokens = (self.eef_token + self.type_embed_eef).expand(B, H, 1, self.d_model)

        # 3. Per frame: [patches(t), eef(t)], frames concatenated in time order.
        seq = torch.cat([patches, eef_tokens], dim=2).flatten(1, 2)          # (B, T, D)

        # 4. Positional embeddings. Coordinates are cast to the token dtype so that
        #    e.g. float64 pixel coordinates do not clash with the float32 PE MLP.
        history_eef_01 = history_eef_xy.to(dtype=seq.dtype) / float(self.target_size)
        seq = seq + self._positional_embeddings(B, H, history_eef_01)

        # 5. Transformer with the block-causal mask, built in the sequence dtype.
        mask = self._build_causal_time_mask(H, seq.device, dtype=seq.dtype)  # (T, T) additive
        out = self.transformer(seq, mask=mask)                               # (B, T, D)

        # 6. Readout from the EEF token of the last frame, which is the final
        #    token of the sequence (index H * (Np + 1) - 1).
        query_feat = out[:, -1, :]                                           # (B, D)
        return self.readout(query_feat)                                      # (B, grid_size^2)


if __name__ == "__main__":
    # Smoke test: build the model, report parameter counts, and push one random
    # batch through it to confirm the output shape.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ARTransformerPolicy(history_len=8, freeze_backbone=True).to(device)

    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    n_total = sum(count for count, _ in sizes)
    n_train = sum(count for count, trainable in sizes if trainable)
    print(f"Trainable params: {n_train:,} / {n_total:,}")

    batch_size, n_frames = 2, 8
    # Sampled on CPU and then moved, images first, to keep the RNG draw order.
    dummy_imgs = torch.randn(batch_size, n_frames, 3, 448, 448).to(device)
    dummy_eef = torch.rand(batch_size, n_frames, 2).to(device) * 448
    with torch.no_grad():
        logits = model(dummy_imgs, dummy_eef)
    print("logits", logits.shape)  # (B, 56*56=3136)
