"""Two-stage autoregressive transformer policy (refactor of model_autoregressive.py).

Cameron's design (2026-05-16):
  Stage A — PatchEncoder: per-frame DINO + small projection. Cacheable.
            Runs ONCE per frame ever (eval) or once per window (train).
  Stage B — ARHead: cross-frame transformer with rel-PE (target-specific).
            Runs per prediction step. Small attention budget.

Wins vs. v1:
  Eval cost: H DINO calls/step → 1 DINO call/step (H× cheaper).
  Train density: one DINO call → K supervision signals (per window of W>H, K = W-H targets).

Defaults: W=20 window, H=8 attention context, grid=56, D=384 (DINOv3 ViT-S/16+).
"""
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Reuse env-var convention from existing models.
# NOTE(review): the fallbacks below are machine-specific absolute paths; set
# DINO_REPO_DIR / DINO_WEIGHTS_PATH in the environment when running elsewhere.
DINO_REPO_DIR     = os.environ.get("DINO_REPO_DIR",     "/Users/cameronsmith/Projects/robotics_testing/random/dinov3")
DINO_WEIGHTS_PATH = os.environ.get("DINO_WEIGHTS_PATH", "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")
# Patch edge length in pixels; image_size // DINO_PATCH_SIZE gives patches per side.
DINO_PATCH_SIZE   = 16

HISTORY_LEN     = 8        # H — attention context for ARHead
WINDOW_LEN      = 20       # W — frames per dataset window (W - H = #targets per DINO call)
GRID_SIZE       = 56       # 56×56 output classification grid
TRANSFORMER_D   = 384      # matches DINOv3 embed_dim — no projection inside ARHead needed
# Default number of attention heads for the ARHead transformer (n_heads).
TRANSFORMER_H   = 6
# Default number of encoder layers for the ARHead transformer (n_layers).
TRANSFORMER_L   = 4
# Default square input side in pixels: 448 // 16 = 28 patches per side, 784 per frame.
IMAGE_SIZE      = 448
# 7-DoF heads (match existing model.py / model_act.py constants)
# Output width of ARHead.height_head.
N_HEIGHT_BINS = 32
N_GRIPPER_BINS = 32  # unused — gripper uses BCE on a single logit (matches existing PARA)
# ARHead.rotation_head emits 3 * N_ROT_BINS logits, reshaped to (B, 3, N_ROT_BINS).
N_ROT_BINS = 32


# ───────── helpers ─────────

def _sincos_pe_1d(positions, dim, device, dtype):
    """Sinusoidal encoding of scalar positions.

    Uses ``dim // 2`` geometrically spaced frequencies (base 10000) and returns
    the sines followed by the cosines along a new trailing axis, so the output
    shape is ``positions.shape + (2 * (dim // 2),)``.
    """
    n_freqs = dim // 2
    log_step = -(math.log(10000.0) / n_freqs)
    inv_freq = torch.exp(torch.arange(n_freqs, device=device, dtype=dtype) * log_step)
    phase = positions[..., None] * inv_freq
    return torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)


def _sincos_pe_2d(xy, dim, device, dtype):
    """Sinusoidal encoding of 2-D points.

    ``xy[..., 0]`` and ``xy[..., 1]`` are each encoded with ``_sincos_pe_1d``
    using ``dim // 2`` channels, and the x encoding is concatenated before the
    y encoding along the last axis.
    """
    per_axis = dim // 2
    x_enc, y_enc = (
        _sincos_pe_1d(xy[..., axis], per_axis, device, dtype) for axis in (0, 1)
    )
    return torch.cat((x_enc, y_enc), dim=-1)


# ───────── Stage A ─────────

class PatchEncoder(nn.Module):
    """Per-frame DINO patches + optional learned projection. No temporal/positional info added here.

    The output is the cache primitive: (B, W, Np, D). At eval the cache is a ring buffer of size H.

    Args:
        target_size: side length (pixels) used to derive ``patches_per_side`` / ``n_patches``.
        d_model: expected token width; must equal the backbone's ``embed_dim``.
        freeze_backbone: if True, backbone parameters get ``requires_grad=False``, the
            backbone runs under ``torch.no_grad()`` and it is kept in eval mode.
        add_projection: if True, apply ``LayerNorm -> Linear`` to the patch tokens;
            otherwise the projection is an identity.
    """

    def __init__(self, target_size=IMAGE_SIZE, d_model=TRANSFORMER_D, freeze_backbone=True,
                 add_projection=True):
        super().__init__()
        self.target_size      = target_size
        self.patch_size       = DINO_PATCH_SIZE
        self.patches_per_side = target_size // self.patch_size
        self.n_patches        = self.patches_per_side ** 2
        # Set before the backbone exists so the train() override can always read it.
        self.freeze_backbone  = freeze_backbone

        print(f"PatchEncoder: loading DINOv3 (frozen={freeze_backbone})...")
        self.dino = torch.hub.load(
            DINO_REPO_DIR, 'dinov3_vits16plus', source='local', weights=DINO_WEIGHTS_PATH,
        )
        if freeze_backbone:
            for p in self.dino.parameters():
                p.requires_grad = False
            self.dino.eval()
        self.embed_dim = self.dino.embed_dim
        assert self.embed_dim == d_model, f"d_model {d_model} != DINO {self.embed_dim}"

        # Learned projection so the ARHead sees a slightly task-tuned representation
        # without paying full DINO compute. Optional; small and cheap.
        if add_projection:
            self.proj = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, d_model))
        else:
            self.proj = nn.Identity()

    def train(self, mode=True):
        """Switch train/eval mode while keeping a frozen backbone in eval mode.

        ``nn.Module.train`` recurses into every child, so without this override a
        plain ``model.train()`` would put the frozen DINO backbone back into
        training mode and any train-only behaviour inside it (presumably e.g.
        stochastic positional-embedding augmentation — verify against the DINOv3
        source) would leak into the supposedly fixed features.
        """
        super().train(mode)
        if self.freeze_backbone:
            self.dino.eval()
        return self

    def _backbone_patch_tokens(self, x):
        """Run the DINO blocks and final norm on ``x``; return only the patch tokens."""
        dino = self.dino
        # The first (n_storage_tokens + 1) tokens are non-patch tokens and are dropped.
        n_prefix = dino.n_storage_tokens + 1
        tokens, (H_p, W_p) = dino.prepare_tokens_with_masks(x)
        for blk in dino.blocks:
            # Re-queried per block, exactly as before, in case rope_embed is stateful.
            rope_sincos = dino.rope_embed(H=H_p, W=W_p) if dino.rope_embed is not None else None
            tokens = blk(tokens, rope_sincos)
        if dino.untie_cls_and_patch_norms:
            # Patch tokens have their own norm here; the separately-normed prefix
            # tokens would be discarded anyway, so they are not computed.
            return dino.norm(tokens[:, n_prefix:])
        return dino.norm(tokens)[:, n_prefix:]

    def _dino_patches(self, x):
        """x: (N, 3, H, W) → (N, Np, D)

        With a frozen backbone the pass runs without autograd and the result is
        detached; otherwise gradients flow (subject to the caller's grad mode).
        """
        if self.freeze_backbone:
            with torch.no_grad():
                patches = self._backbone_patch_tokens(x)
            return patches.detach()
        return self._backbone_patch_tokens(x)

    def forward(self, frames):
        """frames: (B, W, 3, H_img, W_img) → patches: (B, W, Np, D)

        The batch and window axes are folded together so the backbone sees
        B*W independent images in a single call.
        """
        B, W = frames.shape[:2]
        # reshape (not view): time-sliced windows such as frames[:, a:b] are
        # non-contiguous and would make view() raise.
        x = frames.reshape(B * W, *frames.shape[2:])
        patches = self._dino_patches(x)                    # (B*W, Np, D)
        patches = self.proj(patches)                        # (B*W, Np, D)
        return patches.reshape(B, W, self.n_patches, self.embed_dim)


# ───────── Stage B ─────────

class ARHead(nn.Module):
    """Cross-frame AR transformer + readout.

    Inputs (per prediction call):
      patch_tokens:    (B, H, Np, D)  — sliced cache of past H frames
      eef_history_xy:  (B, H, 2)      — pixel coords at those H frames (state EEF)
      anchor_xy:       (B, 2)         — anchor for the relative-PE; usually = eef_history_xy[:, -1]
      target_size:     int (image size, for normalization)

    Each frame contributes its Np patch tokens followed by one EEF query token.
    The "current" frame is index H-1 and every output head reads the EEF query
    token of that frame.
    """

    def __init__(self, history_len=HISTORY_LEN, grid_size=GRID_SIZE,
                 d_model=TRANSFORMER_D, n_heads=TRANSFORMER_H, n_layers=TRANSFORMER_L,
                 patches_per_side=IMAGE_SIZE // DINO_PATCH_SIZE):
        super().__init__()
        self.history_len = history_len
        self.grid_size   = grid_size
        self.d_model     = d_model
        self.patches_per_side = patches_per_side
        self.n_patches   = patches_per_side ** 2

        # Learned EEF query token plus one type embedding per token kind.
        self.eef_token = nn.Parameter(torch.randn(d_model) * 0.02)
        self.type_embed_patch = nn.Parameter(torch.randn(d_model) * 0.02)
        self.type_embed_eef   = nn.Parameter(torch.randn(d_model) * 0.02)

        # Target-specific relative PE; recomputed per Stage-B call.
        self.rel_pe_mlp = nn.Sequential(
            nn.Linear(2, d_model), nn.GELU(), nn.Linear(d_model, d_model),
        )

        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=0.0, activation="gelu", batch_first=True, norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)

        # XY head: classification over the flattened grid_size × grid_size grid.
        self.readout = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model), nn.GELU(),
            nn.Linear(d_model, grid_size * grid_size),
        )
        # 7-DoF parallel heads read from the same EEF query token hidden state.
        # All share a LayerNorm on the feature first; small MLPs per output.
        self.feat_norm = nn.LayerNorm(d_model)
        self.height_head = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU(),
            nn.Linear(d_model, N_HEIGHT_BINS),
        )
        self.gripper_head = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU(),
            nn.Linear(d_model, 1),
        )
        self.rotation_head = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU(),
            nn.Linear(d_model, 3 * N_ROT_BINS),
        )

        # Patch-centre coordinates in [0, 1], stored as (x, y) in row-major patch order.
        ys = (torch.arange(self.patches_per_side) + 0.5) / self.patches_per_side
        xs = (torch.arange(self.patches_per_side) + 0.5) / self.patches_per_side
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
        self.register_buffer(
            "patch_xy_01",
            torch.stack([grid_x, grid_y], dim=-1).reshape(self.n_patches, 2),
            persistent=False,
        )

    def _causal_time_mask(self, H, device):
        """Additive attention mask that is causal across frames, dense within a frame.

        Returns a float tensor of shape (H*(Np+1), H*(Np+1)): 0 where the key's
        frame is at or before the query's frame, -inf where it is later.
        """
        per_frame = self.n_patches + 1
        future = torch.triu(torch.ones(H, H, device=device, dtype=torch.bool), diagonal=1)
        block = torch.zeros(H, H, device=device)
        block.masked_fill_(future, float("-inf"))
        # Blow each frame-level entry up to a (per_frame × per_frame) tile.
        return block.repeat_interleave(per_frame, dim=0).repeat_interleave(per_frame, dim=1)

    def forward(self, patch_tokens, eef_history_xy, anchor_xy, target_size=IMAGE_SIZE):
        """Predict the next action from an H-frame history.

        Returns:
            dict with
              "xy_logits":       (B, grid_size**2)
              "height_logits":   (B, N_HEIGHT_BINS)
              "gripper_logit":   (B,)
              "rotation_logits": (B, 3, N_ROT_BINS)
        """
        B, H, Np, D = patch_tokens.shape
        assert H == self.history_len, f"expected H={self.history_len}, got {H}"
        assert Np == self.n_patches
        dtype = patch_tokens.dtype
        device = patch_tokens.device

        # Token-type embeddings
        patches = patch_tokens + self.type_embed_patch                          # (B, H, Np, D)
        eef_tok = (self.eef_token + self.type_embed_eef).view(1, 1, D).expand(B, H, D)  # (B, H, D)

        # Positional embeddings (pixel coords normalised by the image size)
        eef_01    = eef_history_xy / float(target_size)                        # (B, H, 2)
        anchor_01 = anchor_xy / float(target_size)                              # (B, 2)
        anchor_b  = anchor_01.unsqueeze(1)                                      # (B, 1, 2)
        patch_xy  = self.patch_xy_01.to(dtype)                                  # (Np, 2)

        time_idx  = torch.arange(H, device=device, dtype=dtype)
        time_pe   = _sincos_pe_1d(time_idx, D, device, dtype)                  # (H, D)
        patch_abs = _sincos_pe_2d(patch_xy, D, device, dtype)                  # (Np, D)
        patch_rel = self.rel_pe_mlp((patch_xy.unsqueeze(0) - anchor_b) * 2.0)  # (B, Np, D)
        eef_abs   = _sincos_pe_2d(eef_01, D, device, dtype)                    # (B, H, D)
        eef_rel   = self.rel_pe_mlp((eef_01 - anchor_b) * 2.0)                 # (B, H, D)

        # Build sequence: per frame t = [patches, eef]. All frames are assembled in
        # one broadcasted pass; operand order matches a per-frame loop.
        patches_pos = patch_abs.unsqueeze(0) + patch_rel                        # (B, Np, D)
        p = patches + patches_pos.unsqueeze(1) + time_pe.view(1, H, 1, D)       # (B, H, Np, D)
        e = (eef_tok + eef_abs + eef_rel + time_pe.unsqueeze(0)).unsqueeze(2)   # (B, H, 1, D)
        seq = torch.cat([p, e], dim=2).reshape(B, H * (Np + 1), D)              # (B, H*(Np+1), D)

        mask = self._causal_time_mask(H, device)
        out  = self.transformer(seq, mask=mask)                                 # (B, T, D)

        query_idx = H * (self.n_patches + 1) - 1                                # EEF token at last frame
        q = out[:, query_idx, :]                                                # (B, D)
        xy_logits = self.readout(q)                                              # (B, grid^2)
        f = self.feat_norm(q)
        height_logits  = self.height_head(f)                                     # (B, N_HEIGHT_BINS)
        gripper_logit  = self.gripper_head(f).squeeze(-1)                        # (B,)
        rotation_logits = self.rotation_head(f).view(-1, 3, N_ROT_BINS)          # (B, 3, N_ROT_BINS)
        return {
            "xy_logits": xy_logits,
            "height_logits": height_logits,
            "gripper_logit": gripper_logit,
            "rotation_logits": rotation_logits,
        }


# ───────── Convenience wrapper ─────────

class ARTransformerPolicyV2(nn.Module):
    """Convenience module that chains PatchEncoder and ARHead for one prediction.

    For multi-target training, drive the two stages yourself so the backbone
    runs once per window:
        patches = self.patch_encoder(window_frames)      # (B, W, Np, D), one DINO call
        for t in target_steps:
            history = patches[:, t-H:t]
            out     = self.ar_head(history, eef_xy[:, t-H:t], eef_xy[:, t-1])
    """

    def __init__(self, target_size=IMAGE_SIZE, history_len=HISTORY_LEN, grid_size=GRID_SIZE,
                 d_model=TRANSFORMER_D, n_heads=TRANSFORMER_H, n_layers=TRANSFORMER_L,
                 freeze_backbone=True):
        super().__init__()
        self.target_size = target_size
        self.history_len = history_len
        self.grid_size = grid_size

        # Stage A first, then Stage B (construction order is kept as-is).
        self.patch_encoder = PatchEncoder(
            target_size=target_size,
            d_model=d_model,
            freeze_backbone=freeze_backbone,
        )
        self.ar_head = ARHead(
            history_len=history_len,
            grid_size=grid_size,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            patches_per_side=target_size // DINO_PATCH_SIZE,
        )

    def forward(self, frames, eef_history_xy, anchor_xy=None):
        """Single-target forward pass (predicts the step after the last frame).

        Args:
            frames: (B, H, 3, target_size, target_size)
            eef_history_xy: (B, H, 2)
            anchor_xy: (B, 2); when omitted, eef_history_xy[:, -1] is used.

        Returns:
            Whatever ``self.ar_head`` returns for this history.
        """
        anchor = eef_history_xy[:, -1] if anchor_xy is None else anchor_xy
        tokens = self.patch_encoder(frames)                                     # (B, H, Np, D)
        return self.ar_head(tokens, eef_history_xy, anchor, self.target_size)


# ───────── eval-time rolling cache ─────────

class RolloutCache:
    """Fixed-length history buffer of patch tokens for closed-loop AR rollout.

    Holds the most recent ``history_len`` frames (batch size 1), oldest first.

    Usage:
        cache = RolloutCache(history_len=8, n_patches=784, d_model=384, device=device)
        # at each rollout step:
        new_patches = patch_encoder(frame.unsqueeze(0).unsqueeze(0))[:, 0]  # (1, Np, D)
        cache.push(new_patches, eef_xy)
        history_patches, history_eef = cache.window()
        out = ar_head(history_patches, history_eef, history_eef[:, -1])
        cell = out["xy_logits"].argmax(dim=-1)   # flat index into the output grid
        next_eef = grid_idx_to_pixel(cell)       # conversion helper is not defined in this module
    """

    def __init__(self, history_len, n_patches, d_model, device):
        self.H = history_len
        self.patches = torch.zeros(1, history_len, n_patches, d_model, device=device)
        self.eef_xy  = torch.zeros(1, history_len, 2, device=device)
        # Number of valid entries; they occupy slots [0, fill) in chronological order.
        self.fill = 0

    def push(self, new_patches, new_eef_xy):
        """Append one frame, evicting the oldest once the buffer is full.

        new_patches: (1, Np, D)  new_eef_xy: (2,) or (1, 2)
        """
        # reshape (not view) so a non-contiguous 1-D input is accepted too.
        new_eef_xy = new_eef_xy.reshape(1, 2) if new_eef_xy.dim() == 1 else new_eef_xy
        if self.fill < self.H:
            slot = self.fill
            self.fill += 1
        else:
            # Shift everything one step towards the past; clone() because the
            # source and destination slices overlap.
            self.patches[:, :-1] = self.patches[:, 1:].clone()
            self.eef_xy[:, :-1]  = self.eef_xy[:, 1:].clone()
            slot = self.H - 1
        self.patches[0, slot] = new_patches[0]
        self.eef_xy[0, slot]  = new_eef_xy[0]

    def window(self):
        """Return the current H-window, left-padded with the earliest frame if not yet full.

        Returns ``(patches, eef_xy)`` of shapes (1, H, Np, D) and (1, H, 2). The
        newest frame is always at index -1 and real frames keep their true
        distance from it; the missing oldest slots repeat the first frame seen.
        Before any push the result is all zeros.

        When the buffer is full the internal tensors themselves are returned
        (no copy), so a later ``push`` mutates them; clone if you need to keep them.
        """
        if self.fill >= self.H:
            return self.patches, self.eef_xy
        if self.fill == 0:
            return self.patches.clone(), self.eef_xy.clone()

        n_pad = self.H - self.fill
        patches = torch.cat(
            [self.patches[:, :1].expand(-1, n_pad, -1, -1), self.patches[:, :self.fill]],
            dim=1,
        )
        eef = torch.cat(
            [self.eef_xy[:, :1].expand(-1, n_pad, -1), self.eef_xy[:, :self.fill]],
            dim=1,
        )
        return patches, eef


if __name__ == "__main__":
    # Manual smoke tests: build the full policy (this loads the DINO weights),
    # then exercise both calling conventions on random inputs and print shapes.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Smoke test 1: convenience wrapper
    print("\n== Smoke test: ARTransformerPolicyV2 single-target ==")
    model = ARTransformerPolicyV2(history_len=8, freeze_backbone=True).to(device)
    n_train = 0
    n_total = 0
    for param in model.parameters():
        n_total += param.numel()
        if param.requires_grad:
            n_train += param.numel()
    print(f"Trainable: {n_train:,} / {n_total:,}")

    B, H = 2, 8
    frames = torch.randn(B, H, 3, 448, 448).to(device)
    eef = torch.rand(B, H, 2).to(device) * 448
    with torch.no_grad():
        out = model(frames, eef)
    print("single-target xy:", out['xy_logits'].shape, "height:", out['height_logits'].shape)

    # Smoke test 2: multi-target training-style call
    print("\n== Smoke test: PatchEncoder once, ARHead × K ==")
    W = 20
    frames_w = torch.randn(B, W, 3, 448, 448).to(device)
    eef_w    = torch.rand(B, W, 2).to(device) * 448
    with torch.no_grad():
        patches = model.patch_encoder(frames_w)          # (B, W, Np, D)
    print("patches:", patches.shape)

    target_steps = list(range(H, W))                      # 12 targets
    # Only the first three targets are run; every call reuses the encoded window.
    with torch.no_grad():
        for t in target_steps[:3]:
            lo = t - H
            hist_p = patches[:, lo:t]
            hist_e = eef_w[:, lo:t]
            out = model.ar_head(hist_p, hist_e, hist_e[:, -1])
            print(f"  target_step={t}: xy {out['xy_logits'].shape}  height {out['height_logits'].shape}  "
                  f"grip {out['gripper_logit'].shape}  rot {out['rotation_logits'].shape}")
    print(f"(would run {len(target_steps)} ARHead calls per DINO call at train)")
