"""Voxel-token AR policy (Cameron's variants B and C).

Architecture:
  Stage A (PatchEncoder, reused from v2): DINO patches per frame.
  Voxel grid (per current frame): (G_xy × G_xy × G_z) voxels. Each voxel feature =
      Linear(PE(xyz)) + dino_patch[x_pix, y_pix]
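    where xyz is either:
      - variant B: the absolute coordinate of the voxel center
      - variant C: (xyz - eef_start_xyz), the EEF-anchored delta
    (The current implementation uses xyz = (pixel_x / image_size, pixel_y / image_size,
    world_z) rather than a full unprojection to world xyz; see VoxelFeatureBuilder.)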
    The image-aligned formulation means voxel (x, y, z) projects to the same pixel as
    (x, y, 0), so dino_patch indexing is trivial.

  Stage B (cross-attention only — Perceiver-IO style):
    Query tokens: past H-1 EEF tokens (causal) + 1 EEF query at current frame = H tokens
    KV tokens:    H × N_patches past patches (cached, small) + V current-frame voxels (large)
    Self-attention among the H query tokens, cross-attention from queries to KV.
    No voxel↔voxel attention. Cross-attention cost per layer is O(H × (H·N_patches + V)).

  Output heads: same 7-DoF (xy, height, gripper, rotation) read off the last EEF query.

Two flavors share this skeleton:
  - VoxelARPolicyAbs: PE input = voxel xyz
  - VoxelARPolicyRel: PE input = voxel xyz - eef_start_xyz_world

eef_start_xyz_world: the EEF position at the FIRST frame of the current attention context
(t = current_step - H + 1). Documented choice, see inbox spec.
"""
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# DINOv3 backbone location; both paths can be overridden through environment variables.
DINO_REPO_DIR = os.environ.get(
    "DINO_REPO_DIR",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3",
)
DINO_WEIGHTS_PATH = os.environ.get(
    "DINO_WEIGHTS_PATH",
    "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)
DINO_PATCH_SIZE = 16  # ViT patch side, in pixels

# Image-aligned voxel grid: VOXEL_XY x VOXEL_XY pixel columns, VOXEL_Z height levels each.
# 56 x 56 x 32 = 100,352 voxels per frame; drop to 28 x 28 x 16 = 12,544 if it runs out of memory.
VOXEL_XY = 56
VOXEL_Z = 32

# Model hyper-parameters. These are meant to match model_autoregressive_v2's 7-DoF setup,
# but they are redefined here (not imported), so keep the two files in sync by hand.
HISTORY_LEN = 8      # frames in the attention context (H)
GRID_SIZE = 56       # side of the xy readout grid (GRID_SIZE ** 2 logits)
TRANSFORMER_D = 384  # token width; must equal the DINO embedding dim
TRANSFORMER_H = 6    # attention heads
TRANSFORMER_L = 4    # number of (self-attn, cross-attn) layer pairs
IMAGE_SIZE = 448     # input frame side, in pixels
N_HEIGHT_BINS = 32
N_ROT_BINS = 32


def _sincos_pe_1d(positions, dim, device, dtype):
    half = dim // 2
    freqs = torch.exp(torch.arange(half, device=device, dtype=dtype) * -(math.log(10000.0) / half))
    angles = positions.unsqueeze(-1) * freqs
    return torch.cat([angles.sin(), angles.cos()], dim=-1)


def _sincos_pe_3d(xyz, dim, device, dtype):
    """Encode 3-D coordinates by concatenating one 1-D sin/cos encoding per axis.

    xyz: (..., 3). Each axis is encoded with `_sincos_pe_1d(..., dim // 3, ...)`, and the
    three encodings are concatenated in [x, y, z] order. Any channels still missing are
    zero-padded at the end, so the result is always (..., dim); `dim` need not be a
    multiple of 6.
    """
    per_axis = dim // 3
    encodings = [_sincos_pe_1d(xyz[..., axis], per_axis, device, dtype) for axis in range(3)]
    pe = torch.cat(encodings, dim=-1)
    missing = dim - pe.shape[-1]
    return F.pad(pe, (0, missing)) if missing else pe


# ───────── PatchEncoder (slim duplicate of v2; same DINO + projection) ───────── #

class _PatchEncoder(nn.Module):
    """DINOv3 (ViT-S/16+) patch encoder followed by a LayerNorm + Linear projection.

    forward maps frames of shape (B, W, 3, target_size, target_size) to projected patch
    tokens of shape (B, W, n_patches, embed_dim), where
    n_patches = (target_size // DINO_PATCH_SIZE) ** 2 and embed_dim must equal d_model.

    When freeze_backbone is True:
      - the DINO weights get requires_grad=False;
      - the backbone runs under torch.no_grad();
      - the backbone stays in eval mode even after the parent model is switched to train mode.
    In that case only `proj` is trainable.
    """

    def __init__(self, target_size=IMAGE_SIZE, d_model=TRANSFORMER_D, freeze_backbone=True):
        super().__init__()
        self.target_size = target_size
        self.patches_per_side = target_size // DINO_PATCH_SIZE
        self.n_patches = self.patches_per_side ** 2
        print(f"PatchEncoder: loading DINOv3 (frozen={freeze_backbone})...")
        self.dino = torch.hub.load(DINO_REPO_DIR, 'dinov3_vits16plus', source='local', weights=DINO_WEIGHTS_PATH)
        if freeze_backbone:
            for p in self.dino.parameters():
                p.requires_grad = False
            self.dino.eval()
        self.embed_dim = self.dino.embed_dim
        assert self.embed_dim == d_model, f"DINO embed_dim={self.embed_dim} != d_model={d_model}"
        self.freeze_backbone = freeze_backbone
        self.proj = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, d_model))

    def train(self, mode=True):
        """Set train/eval mode for this module, but keep a frozen backbone in eval mode.

        nn.Module.train recurses into every child. Without this override, a parent
        `model.train()` call would put the frozen DINO backbone back into training mode.
        """
        super().train(mode)
        if self.freeze_backbone:
            self.dino.eval()
        return self

    def _run_dino(self, x):
        """Run the DINO ViT on images x and return its normalized patch tokens only."""
        tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        for blk in self.dino.blocks:
            # The RoPE table is requested once per block, as in the original loop.
            # NOTE(review): do not hoist this without confirming rope_embed is deterministic in train mode.
            rope = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed else None
            tokens = blk(tokens, rope)
        n_prefix = self.dino.n_storage_tokens + 1  # CLS token + storage tokens precede the patches
        if self.dino.untie_cls_and_patch_norms:
            cls_n = self.dino.cls_norm(tokens[:, :n_prefix])
            pat_n = self.dino.norm(tokens[:, n_prefix:])
            tokens = torch.cat([cls_n, pat_n], dim=1)
        else:
            tokens = self.dino.norm(tokens)
        return tokens[:, n_prefix:]

    def _dino_patches(self, x):
        """Return DINO patch tokens for x; no autograd graph is built when the backbone is frozen."""
        if self.freeze_backbone:
            with torch.no_grad():
                return self._run_dino(x)
        return self._run_dino(x)

    def forward(self, frames):
        B, W = frames.shape[:2]
        # reshape (not view) so non-contiguous frame batches (sliced/permuted) are accepted.
        x = frames.reshape(B * W, *frames.shape[2:])
        patches = self.proj(self._dino_patches(x))
        return patches.reshape(B, W, self.n_patches, self.embed_dim)


# ───────── Cross-attention building block ───────── #

class CrossAttnBlock(nn.Module):
    """Pre-LayerNorm cross-attention block: queries attend to a separate key/value set.

    Update rule:
        q <- q + MHA(LN_q(q), LN_kv(kv), LN_kv(kv))
        q <- q + FFN(LN_2(q))
    The FFN is Linear(d, ffn_mult * d) -> GELU -> Linear(ffn_mult * d, d). `dropout` is
    applied only inside the attention module.
    """

    def __init__(self, d_model, n_heads, ffn_mult=4, dropout=0.0):
        super().__init__()
        hidden = ffn_mult * d_model
        # Submodule names and creation order define state_dict keys and the parameter-init
        # order, so they are kept stable.
        self.ln_q = nn.LayerNorm(d_model)
        self.ln_kv = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, q, kv, attn_mask=None):
        """q: (B, Lq, d_model) queries; kv: (B, Lkv, d_model) context (batch-first layout)."""
        queries = self.ln_q(q)
        context = self.ln_kv(kv)
        attended = self.attn(queries, context, context, attn_mask=attn_mask, need_weights=False)[0]
        q = q + attended
        return q + self.ffn(self.ln2(q))


class SelfAttnBlock(nn.Module):
    """Pre-LayerNorm self-attention block over one token sequence.

    Update rule:
        q <- q + MHA(LN(q), LN(q), LN(q); attn_mask)
        q <- q + FFN(LN_2(q))
    The block is not causal by itself; pass a causal `attn_mask` to restrict each token to
    earlier positions. `dropout` is applied only inside the attention module.
    """

    def __init__(self, d_model, n_heads, ffn_mult=4, dropout=0.0):
        super().__init__()
        hidden = ffn_mult * d_model
        # Names/order kept stable: they define state_dict keys and parameter-init order.
        self.ln = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, q, attn_mask=None):
        """q: (B, L, d_model) tokens (batch-first layout); attn_mask is forwarded to MHA as-is."""
        normed = self.ln(q)
        attended = self.attn(normed, normed, normed, attn_mask=attn_mask, need_weights=False)[0]
        q = q + attended
        return q + self.ffn(self.ln2(q))


# ───────── Voxel AR Head ───────── #

class VoxelARHead(nn.Module):
    """Perceiver-IO style temporal + voxel attention with H EEF query tokens.

    Each layer runs causal self-attention among the H EEF query tokens, then cross-attention
    from those tokens to a KV pool. The pool holds all H frames' patch tokens plus the
    current frame's voxel tokens. KV tokens never attend to each other.

    Inputs to forward:
      patch_tokens:   (B, H, Np, D)   — past H frames' DINO patches (cached)
      eef_history_xy: (B, H, 2)       — past H EEF pixel coords (state EEF, teacher-forced)
      voxel_feats:    (B, V, D)       — voxel features (Linear(PE(xyz)) + image_feat[x_pix, y_pix]),
                                        ONLY the current (last) frame's voxels. KV-only.
      target_size:    pixel scale; eef_history_xy is divided by it before the PE.
    Output: dict with
      xy_logits       (B, grid_size ** 2)
      height_logits   (B, N_HEIGHT_BINS)
      gripper_logit   (B,)
      rotation_logits (B, 3, N_ROT_BINS)
    all read off the last (current-frame) EEF query token.
    """

    def __init__(self, history_len=HISTORY_LEN, grid_size=GRID_SIZE,
                 d_model=TRANSFORMER_D, n_heads=TRANSFORMER_H, n_layers=TRANSFORMER_L,
                 patches_per_side=IMAGE_SIZE // DINO_PATCH_SIZE):
        super().__init__()
        self.history_len = history_len
        self.grid_size   = grid_size
        self.d_model     = d_model
        self.patches_per_side = patches_per_side
        self.n_patches   = patches_per_side ** 2

        # Learned embeddings. Creation order is kept stable for seeded init and checkpoints.
        self.eef_token = nn.Parameter(torch.randn(d_model) * 0.02)
        self.type_embed_eef    = nn.Parameter(torch.randn(d_model) * 0.02)
        self.type_embed_patch  = nn.Parameter(torch.randn(d_model) * 0.02)
        self.type_embed_voxel  = nn.Parameter(torch.randn(d_model) * 0.02)

        # Layers: interleave self-attn (EEF tokens) + cross-attn (EEF → patches+voxels).
        self.self_blocks  = nn.ModuleList([SelfAttnBlock(d_model, n_heads) for _ in range(n_layers)])
        self.cross_blocks = nn.ModuleList([CrossAttnBlock(d_model, n_heads) for _ in range(n_layers)])

        # Readouts on the last EEF query token
        self.readout_xy = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model), nn.GELU(),
            nn.Linear(d_model, grid_size * grid_size),
        )
        self.feat_norm = nn.LayerNorm(d_model)
        self.height_head   = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, N_HEIGHT_BINS))
        self.gripper_head  = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, 1))
        self.rotation_head = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, 3 * N_ROT_BINS))

        # Patch-center coordinates in [0, 1], row-major over the patch grid, stored as (x, y).
        ys = (torch.arange(self.patches_per_side) + 0.5) / self.patches_per_side
        xs = (torch.arange(self.patches_per_side) + 0.5) / self.patches_per_side
        gy, gx = torch.meshgrid(ys, xs, indexing='ij')
        self.register_buffer("patch_xy_01",
                             torch.stack([gx, gy], dim=-1).reshape(self.n_patches, 2),
                             persistent=False)

    def forward(self, patch_tokens, eef_history_xy, voxel_feats, target_size=IMAGE_SIZE):
        B, H, Np, D = patch_tokens.shape
        assert H == self.history_len, f"expected H={self.history_len}, got {H}"
        device = patch_tokens.device
        dtype  = patch_tokens.dtype

        # EEF query tokens: H tokens, one per timestep, each carrying an abs-2D PE of its EEF
        # coordinate plus a temporal PE. Causal self-attention sees the past; the last token
        # is the query that gets read out.
        # NOTE(review): in-image coordinates here fall in [0, 1], while the sincos frequencies
        # are at most 1 rad per unit, so these 2-D PEs vary only slowly across the image.
        # Consider scaling positions before encoding.
        eef_01 = eef_history_xy / float(target_size)
        time_idx = torch.arange(H, device=device, dtype=dtype)
        time_pe = _sincos_pe_1d(time_idx, D, device, dtype)            # (H, D)
        eef_abs = torch.cat([
            _sincos_pe_1d(eef_01[..., 0], D // 2, device, dtype),
            _sincos_pe_1d(eef_01[..., 1], D // 2, device, dtype),
        ], dim=-1)                                                       # (B, H, D)
        eef_proto = (self.eef_token + self.type_embed_eef).view(1, 1, D).expand(B, H, D)
        eef_q = eef_proto + eef_abs + time_pe                            # (B, H, D)

        # Patch tokens get an abs-2D PE only. There is no per-frame temporal PE, so the H frames'
        # patches are indistinguishable by time inside the KV pool (possible future ablation).
        patch_abs = torch.cat([
            _sincos_pe_1d(self.patch_xy_01.to(dtype)[..., 0], D // 2, device, dtype),
            _sincos_pe_1d(self.patch_xy_01.to(dtype)[..., 1], D // 2, device, dtype),
        ], dim=-1).unsqueeze(0).expand(B, Np, D)                         # (B, Np, D)
        # Add the same patch PE to every frame, then flatten H × Np.
        patches_kv = (patch_tokens + self.type_embed_patch + patch_abs.unsqueeze(1)).reshape(B, H * Np, D)

        # Voxel tokens get their own type embedding; the caller already added PE(xyz).
        voxels_kv = voxel_feats + self.type_embed_voxel                   # (B, V, D)

        kv = torch.cat([patches_kv, voxels_kv], dim=1)                    # (B, H*Np + V, D)

        # Causal mask for EEF self-attention (True = blocked): token t may attend to tokens <= t.
        # A boolean mask is dtype-agnostic; a float32 -inf mask may be rejected by attention
        # backends when the queries are fp16/bf16 outside autocast.
        causal = torch.triu(torch.ones(H, H, device=device, dtype=torch.bool), diagonal=1)

        for sa, ca in zip(self.self_blocks, self.cross_blocks):
            eef_q = sa(eef_q, attn_mask=causal)
            eef_q = ca(eef_q, kv)

        # Readout from the last (current) EEF query token
        q = eef_q[:, -1, :]
        xy_logits = self.readout_xy(q)
        f = self.feat_norm(q)
        height_logits  = self.height_head(f)
        gripper_logit  = self.gripper_head(f).squeeze(-1)
        rotation_logits = self.rotation_head(f).view(-1, 3, N_ROT_BINS)
        return {
            "xy_logits": xy_logits,
            "height_logits": height_logits,
            "gripper_logit": gripper_logit,
            "rotation_logits": rotation_logits,
        }


# ───────── Voxel feature builder ───────── #

class VoxelFeatureBuilder(nn.Module):
    """Build voxel features for the current frame.

    Given:
      patch_tokens:        (B, Np, D) — current frame's DINO patches (already projected),
                                        with Np = patches_per_side ** 2
      cam_K:               (B, 3, 3)  — image-pixel intrinsics (accepted but currently unused)
      cam_extrinsic:       (B, 4, 4)  — camera→world (accepted but currently unused)
      eef_start_xyz_world: (B, 3) or None — if not None, subtracted from each voxel
                                            coordinate before the PE (variant C)

    The voxel grid is image-aligned:
      - x and y are the grid_xy cell centers spanning the image pixel space;
      - z takes grid_z evenly spaced world heights from min_height to max_height, endpoints included.
    For each voxel:
      - its patch feature is bilinearly sampled from the DINO patch grid at (x_pix, y_pix);
        all z levels of a column share that feature.
      - its coordinate is the "cheap" image-aligned triple
            (x_pix / image_size, y_pix / image_size, world_z).
        This is NOT a full unprojection to world xyz; the camera matrices are not used yet.
      - feature = Linear(PE(coord [- eef_start_xyz_world])) + patch_feature

    Returns:
      voxel_feats: (B, V, D) with V = grid_xy * grid_xy * grid_z, ordered (y, x, z) with z fastest
      voxel_xyz:   (B, V, 3) — the coordinate fed to the PE (after anchor subtraction, if any),
                   for inspection / debug
    """

    def __init__(self, image_size=IMAGE_SIZE, grid_xy=VOXEL_XY, grid_z=VOXEL_Z,
                 min_height=0.85, max_height=1.55, d_model=TRANSFORMER_D,
                 patches_per_side=IMAGE_SIZE // DINO_PATCH_SIZE):
        super().__init__()
        self.image_size = image_size
        self.grid_xy = grid_xy
        self.grid_z = grid_z
        self.min_h = min_height
        self.max_h = max_height
        self.d_model = d_model
        self.patches_per_side = patches_per_side
        self.pe_to_d = nn.Linear(d_model, d_model)  # Linear over the sincos PE → D
        # Precompute voxel pixel centers (grid_xy per axis) and the grid_z world heights
        # (linspace, so min_height and max_height are themselves sample points).
        cell = image_size / grid_xy
        xs = (torch.arange(grid_xy) + 0.5) * cell                          # (Gxy,) pixel x
        ys = (torch.arange(grid_xy) + 0.5) * cell                          # (Gxy,) pixel y
        zs = torch.linspace(min_height, max_height, grid_z)                 # (Gz,) world z
        self.register_buffer("vox_px", xs, persistent=False)
        self.register_buffer("vox_py", ys, persistent=False)
        self.register_buffer("vox_pz", zs, persistent=False)

    def _patch_lookup(self, patch_tokens):
        """patch_tokens: (B, Np, D). Returns (B, grid_xy, grid_xy, D), bilinear-sampled at the voxel xy centers."""
        B, Np, D = patch_tokens.shape
        Ps = self.patches_per_side
        # reshape (not view) also accepts non-contiguous slices of a larger token tensor.
        feats = patch_tokens.reshape(B, Ps, Ps, D).permute(0, 3, 1, 2)      # (B, D, Ps, Ps)
        # Sample coords in normalized [-1, 1] over the full image extent (align_corners=False).
        nx = (self.vox_px / self.image_size) * 2 - 1                        # (Gxy,)
        ny = (self.vox_py / self.image_size) * 2 - 1
        gy, gx = torch.meshgrid(ny, nx, indexing='ij')                       # (Gxy, Gxy)
        # grid_sample requires the sampling grid to share the input's dtype.
        grid = torch.stack([gx, gy], dim=-1).to(feats.dtype).unsqueeze(0).expand(B, -1, -1, -1)  # (B, Gxy, Gxy, 2)
        sampled = F.grid_sample(feats, grid, mode='bilinear', align_corners=False)  # (B, D, Gxy, Gxy)
        return sampled.permute(0, 2, 3, 1)                                   # (B, Gxy, Gxy, D)

    def forward(self, patch_tokens, cam_K, cam_extrinsic, eef_start_xyz_world=None):
        """Build voxel feats. patch_tokens: (B, Np, D). cam_K: (B,3,3). cam_extrinsic: (B,4,4).

        cam_K and cam_extrinsic are kept in the signature for the planned full unprojection.
        The current pixel-xy + world-z coordinate scheme does not read them.
        """
        B = patch_tokens.shape[0]
        D = self.d_model
        device = patch_tokens.device
        dtype  = patch_tokens.dtype

        # 1. Per-(x_pix, y_pix) patch features, shared by every z level of a column.
        patch_feats_xy = self._patch_lookup(patch_tokens)                    # (B, Gxy, Gxy, D)
        feats = patch_feats_xy.unsqueeze(3).expand(-1, -1, -1, self.grid_z, -1)  # (B, Gxy, Gxy, Gz, D)

        # 2. Per-voxel coordinate for the PE. The cheap version is used: normalized pixel x/y in
        # [0, 1] plus world z. The spec's full world-xyz unprojection is deferred until this
        # proves too weak.
        gx, gy = self.grid_xy, self.grid_xy
        px_n = (self.vox_px.to(dtype) / self.image_size).view(1, 1, gx, 1, 1).expand(B, gy, gx, self.grid_z, 1)
        py_n = (self.vox_py.to(dtype) / self.image_size).view(1, gy, 1, 1, 1).expand(B, gy, gx, self.grid_z, 1)
        pz   = self.vox_pz.to(dtype).view(1, 1, 1, self.grid_z, 1).expand(B, gy, gx, self.grid_z, 1)
        xyz = torch.cat([px_n, py_n, pz], dim=-1)                           # (B, Gxy, Gxy, Gz, 3)

        if eef_start_xyz_world is not None:
            # NOTE(review): the anchor is a world-frame xyz, but x and y here are normalized pixel
            # coordinates; only z shares a frame with it. Project the anchor into the image (or
            # switch to the full unprojection) before trusting variant C's x/y deltas.
            xyz = xyz - eef_start_xyz_world.view(B, 1, 1, 1, 3)

        pe = _sincos_pe_3d(xyz, D, device, dtype)                            # (B, Gxy, Gxy, Gz, D)
        pe = self.pe_to_d(pe)

        feats = feats + pe                                                   # (B, Gxy, Gxy, Gz, D)
        return feats.reshape(B, gy * gx * self.grid_z, D), xyz.reshape(B, -1, 3)


# ───────── Convenience policies ───────── #

class _VoxelARPolicyBase(nn.Module):
    """Shared skeleton for the voxel-token AR policies.

    Pipeline:
      1. The DINO patch encoder runs over all H frames.
      2. Voxel features are built from the last frame's patches.
      3. VoxelARHead attends over (all frames' patches, EEF history, voxels).
    `use_eef_relative` selects whether the voxel PE input is offset by `eef_start_xyz_world`
    (variant C) or left un-anchored (variant B, where any anchor passed is ignored).
    """

    def __init__(self, target_size=IMAGE_SIZE, history_len=HISTORY_LEN,
                 grid_size=GRID_SIZE, voxel_xy=VOXEL_XY, voxel_z=VOXEL_Z,
                 d_model=TRANSFORMER_D, n_heads=TRANSFORMER_H, n_layers=TRANSFORMER_L,
                 freeze_backbone=True, use_eef_relative=False,
                 min_height=0.85, max_height=1.55):
        super().__init__()
        self.target_size = target_size
        self.history_len = history_len
        self.grid_size   = grid_size
        self.use_eef_relative = use_eef_relative
        self.patch_encoder = _PatchEncoder(target_size, d_model, freeze_backbone)
        self.voxel_builder = VoxelFeatureBuilder(
            image_size=target_size, grid_xy=voxel_xy, grid_z=voxel_z,
            min_height=min_height, max_height=max_height, d_model=d_model,
            patches_per_side=target_size // DINO_PATCH_SIZE,
        )
        self.ar_head = VoxelARHead(
            history_len=history_len, grid_size=grid_size,
            d_model=d_model, n_heads=n_heads, n_layers=n_layers,
            patches_per_side=target_size // DINO_PATCH_SIZE,
        )

    def forward(self, frames, eef_history_xy, cam_K, cam_extrinsic,
                eef_start_xyz_world=None):
        """
        frames:              (B, H, 3, target_size, target_size)
        eef_history_xy:      (B, H, 2)
        cam_K:               (B, 3, 3)
        cam_extrinsic:       (B, 4, 4)
        eef_start_xyz_world: (B, 3)  required for variant C (use_eef_relative=True);
                             ignored otherwise

        Returns the VoxelARHead output dict (7-DoF logits).

        Raises:
            ValueError: if use_eef_relative is set but eef_start_xyz_world is None.
                Previously this silently fell back to the absolute variant.
        """
        # Validate before the expensive DINO pass.
        if self.use_eef_relative and eef_start_xyz_world is None:
            raise ValueError("eef_start_xyz_world is required when use_eef_relative=True (variant C)")
        patches = self.patch_encoder(frames)                                 # (B, H, Np, D)
        current = patches[:, -1]                                              # (B, Np, D)
        anchor = eef_start_xyz_world if self.use_eef_relative else None
        voxel_feats, _ = self.voxel_builder(current, cam_K, cam_extrinsic, anchor)
        return self.ar_head(patches, eef_history_xy, voxel_feats, self.target_size)


class VoxelARPolicyAbs(_VoxelARPolicyBase):
    """Variant B: the voxel PE input is the un-anchored voxel coordinate.

    Accepts the same keyword arguments as _VoxelARPolicyBase. A caller-supplied
    `use_eef_relative` is overridden to False.
    """

    def __init__(self, **kwargs):
        super().__init__(**{**kwargs, "use_eef_relative": False})


class VoxelARPolicyRel(_VoxelARPolicyBase):
    """Variant C: the voxel PE input is offset by `eef_start_xyz_world` passed to forward().

    Accepts the same keyword arguments as _VoxelARPolicyBase. A caller-supplied
    `use_eef_relative` is overridden to True.
    """

    def __init__(self, **kwargs):
        super().__init__(**{**kwargs, "use_eef_relative": True})


if __name__ == "__main__":
    # Smoke test: build both variants on a reduced voxel grid and print the output shapes.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    smoke_batch, smoke_hist, smoke_side = 1, 8, 448

    for policy_cls, tag in ((VoxelARPolicyAbs, "abs"), (VoxelARPolicyRel, "rel")):
        print(f"\n== smoke: {tag} ==")
        # 28 x 28 x 16 voxels keeps the smoke run cheap.
        policy = policy_cls(history_len=smoke_hist, voxel_xy=28, voxel_z=16, freeze_backbone=True).to(device)
        trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad)
        print(f"Trainable: {trainable:,}")

        frames = torch.randn(smoke_batch, smoke_hist, 3, smoke_side, smoke_side).to(device)
        eef = torch.rand(smoke_batch, smoke_hist, 2).to(device) * smoke_side
        K = torch.eye(3).expand(smoke_batch, 3, 3).to(device)
        E = torch.eye(4).expand(smoke_batch, 4, 4).to(device)
        anchor = torch.tensor([[0.4, -0.1, 1.0]]).to(device)

        with torch.no_grad():
            outputs = policy(frames, eef, K, E, eef_start_xyz_world=anchor)
        for name, tensor in outputs.items():
            print(f"  {name}: {tuple(tensor.shape)}")
