"""Two-view (BEV + wrist) DinoVolumeQuery — same architecture skeleton as the
single-view query-MLP model, but with:

  - Shared DINO trunk applied to both views → patch tokens for each
  - 1-2 joint self-attention (transformer) layers over the concatenated final-layer
    patch tokens of both views (tagged with a view embedding), mixing BEV ↔ wrist
  - Two refine heads → F_bev, F_wrist ∈ (B, d_feat, H, W)
  - Query input: Linear(concat(eef_feat_bev, cls_bev, cls_wrist)) → d_model
  - 5-layer AdaLN-Zero MLP per timestep with sin(t) conditioning
  - Per-step query split into (q_F_bev, q_F_wrist, q_z, q_t)
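  - Volume scoring (default fusion_mode='sum') adds contributions from BEV
    (direct lookup) + WRIST (project each voxel's world XYZ through the wrist
    camera, grid_sample F_wrist at the projected uv with padding_mode='zeros'
    so out-of-frustum voxels contribute 0), plus the height/time terms. The
    other fusion modes ('max', 'poe', 'oof_mask', 'aug_token') combine the
    per-view volumes differently.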
"""
import os, math
import torch
import torch.nn as nn
import torch.nn.functional as F

# ── DINOv3 backbone location (each can be overridden by an environment variable) ──
DINO_REPO_DIR = os.environ.get("DINO_REPO_DIR", "/data/cameron/keygrip/dinov3")
DINO_WEIGHTS_PATH = os.environ.get(
    "DINO_WEIGHTS_PATH",
    "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)

# ── Geometry / discretisation defaults ──
IMG_SIZE = 448          # input image side length in pixels
PRED_SIZE = 56          # side length of the refined feature maps (volume H = W)
N_WINDOW = 8            # number of predicted timesteps T
N_HEIGHT_BINS = 32      # number of height bins Z in the volume
N_GRIPPER_BINS = 32     # gripper classification bins
N_ROT_BINS = 32         # rotation classification bins (per axis for 'per_axis')

# ── Query-model widths ──
D_FEAT = 32             # per-pixel feature width of F_bev / F_wrist
D_SINZ = 16             # sinusoidal height-feature width
D_SINT = 16             # sinusoidal time-feature width
D_MODEL = sum((D_FEAT, D_SINZ, D_SINT))   # query MLP width: 32 + 16 + 16 = 64
D_COND = 128            # AdaLN conditioning width
N_BLOCKS = 5            # AdaLN-Zero MLP blocks
N_CROSS_ATTN_LAYERS = 2 # joint BEV/wrist transformer layers on the final patch tokens


def sinusoidal_features(n, dim, base=10000.0):
    """Return an (n, dim) table of Transformer-style sinusoidal position features.

    Row ``p`` holds ``sin(p * w_k)`` in the even columns and ``cos(p * w_k)`` in
    the odd columns, with frequencies ``w_k = base ** (-2k / dim)``.

    Args:
        n:    number of positions (rows); positions are 0 .. n-1.
        dim:  feature width (columns). Odd widths are supported: the last
              column is then a sine without a matching cosine.
        base: wavelength base controlling the frequency spread.

    Returns:
        float32 tensor of shape (n, dim).
    """
    pos = torch.arange(n, dtype=torch.float32).unsqueeze(1)                      # (n, 1)
    div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * -(math.log(base) / dim))
    angles = pos * div                                                           # (n, ceil(dim/2))
    pe = torch.zeros(n, dim)
    pe[:, 0::2] = torch.sin(angles)
    # There are only floor(dim/2) odd columns, so for odd dim the last
    # frequency has no cosine partner; slicing avoids a shape mismatch.
    pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
    return pe


class AdaLNZeroMLPBlock(nn.Module):
    """Residual MLP block with DiT-style adaptive LayerNorm-Zero conditioning.

    Works on tokens ``x`` of shape (N, d) with a conditioning tensor ``cond``
    of shape (N, d_cond). ``cond`` is projected to a scale ``g``, a shift ``b``
    and a residual gate ``a``, and the block returns

        x + a * MLP(LN(x) * (1 + g) + b)

    with MLP = Linear(d → mlp_ratio·d) → GELU → Linear(mlp_ratio·d → d).
    ``cond_proj`` is zero-initialised, so a fresh block is the identity map.
    """

    def __init__(self, d, d_cond, mlp_ratio=4):
        super().__init__()
        hidden = mlp_ratio * d
        # LayerNorm without its own affine: scale/shift come from `cond` instead.
        self.norm = nn.LayerNorm(d, elementwise_affine=False)
        self.cond_proj = nn.Linear(d_cond, 3 * d)
        for param in (self.cond_proj.weight, self.cond_proj.bias):
            nn.init.zeros_(param)
        self.mlp = nn.Sequential(
            nn.Linear(d, hidden),
            nn.GELU(),
            nn.Linear(hidden, d),
        )

    def forward(self, x, cond):
        scale, shift, gate = self.cond_proj(cond).chunk(3, dim=-1)
        modulated = self.norm(x) * (1.0 + scale) + shift
        return x + gate * self.mlp(modulated)


def build_bev_world_xyz_table(K_norm_bev, bev_extrinsic, n_height_bins,
                              min_height, max_height, H, W, image_size, device):
    """Compute world XYZ for every voxel (z_bin, y_grid, x_grid) — BEV camera is static.

    Each cell of the H x W grid is back-projected from its centre pixel through
    the BEV camera, and the ray is intersected with ``n_height_bins`` planes
    ``world_z = h`` for h evenly spaced over [min_height, max_height].

    Args:
        K_norm_bev:    (3, 3) intrinsics normalised by image dims.
        bev_extrinsic: (4, 4) camera→world transform (from get_camera_extrinsic_matrix).
        n_height_bins: number of height planes Z.
        min_height, max_height: world-z of the first / last plane (inclusive).
        H, W:          grid size. Cell centres are laid over the square
                       ``image_size`` x ``image_size`` image, with rows scaled
                       by image_size/H and columns by image_size/W.
        image_size:    side length in pixels of the (square) BEV image.
        device:        device on which the table is built.

    Returns:
        (Z, H, W, 3) float32 world XYZ table on ``device``.
    """
    # Work in float32 on `device`: the pixel grid below is float32 and matmul
    # rejects mixed dtypes / devices.
    K = K_norm_bev.to(device=device, dtype=torch.float32).clone()
    K[0] *= float(image_size); K[1] *= float(image_size)
    extrinsic = bev_extrinsic.to(device=device, dtype=torch.float32)
    # Pixel-centre coordinates of each grid cell. Rows and columns are scaled
    # separately so non-square (H != W) grids still span the full image.
    ys = (torch.arange(H, device=device).float() + 0.5) * (float(image_size) / float(H))
    xs = (torch.arange(W, device=device).float() + 0.5) * (float(image_size) / float(W))
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy')                      # (H, W) each
    K_inv = torch.inverse(K)
    ones = torch.ones_like(grid_x)
    uv1 = torch.stack([grid_x, grid_y, ones], dim=-1).reshape(-1, 3).T          # (3, H*W)
    rays_cam = (K_inv @ uv1).T.reshape(H, W, 3)                                   # (H, W, 3) in cam frame
    # cam→world: rotate ray, translation = camera position
    R_cw = extrinsic[:3, :3]
    t_cw = extrinsic[:3, 3]
    rays_world = (R_cw @ rays_cam.reshape(-1, 3).T).T.reshape(H, W, 3)

    # For each height bin, solve for the scalar s such that
    #   world_z(pix) = t_cw[2] + s * rays_world[..., 2] == target_z
    # → s = (target_z - t_cw[2]) / rays_world[..., 2]
    heights = torch.linspace(min_height, max_height, n_height_bins, device=device)   # (Z,)
    # rays_world[..., 2] is negative for downward-looking cameras — DO NOT clamp_min.
    # Near-zero z-components (near-horizontal rays) are replaced by -1e-6 to avoid
    # dividing by zero; the negative sign assumes a downward-looking camera.
    rwz = rays_world[..., 2].unsqueeze(0)
    rwz_safe = torch.where(rwz.abs() < 1e-6, torch.full_like(rwz, -1e-6), rwz)
    s = (heights.view(n_height_bins, 1, 1) - t_cw[2]) / rwz_safe                    # (Z, H, W)
    # world_xyz = t_cw + s * rays_world
    xyz = t_cw.view(1, 1, 1, 3) + s.unsqueeze(-1) * rays_world.unsqueeze(0)       # (Z, H, W, 3)
    return xyz


def build_bev_world_xyz_table_batched(K_norm_bev, bev_extrinsic, n_height_bins,
                                        min_height, max_height, H, W, image_size):
    """Batched version of build_bev_world_xyz_table.

    Args:
        K_norm_bev:    (B, 3, 3) intrinsics normalised by image dims.
        bev_extrinsic: (B, 4, 4) camera→world transform per sample.
        n_height_bins: number of height planes Z.
        min_height, max_height: world-z of the first / last plane (inclusive).
        H, W:          grid size. Rows are scaled by image_size/H and columns
                       by image_size/W over the square image.
        image_size:    side length in pixels of the (square) BEV image.

    Returns:
        (B, Z, H, W, 3) float32 world XYZ table per sample, on K_norm_bev's device.
    """
    device = K_norm_bev.device
    B = K_norm_bev.shape[0]
    # float32 throughout: the pixel grid is float32 and matmul/einsum reject
    # mixed dtypes; the extrinsic is also moved to K's device.
    K = K_norm_bev.to(dtype=torch.float32).clone()
    K[:, 0] *= float(image_size); K[:, 1] *= float(image_size)                    # (B, 3, 3)
    extrinsic = bev_extrinsic.to(device=device, dtype=torch.float32)
    # Pixel-centre coordinates; rows and columns scaled separately so H != W works.
    ys = (torch.arange(H, device=device).float() + 0.5) * (float(image_size) / float(H))
    xs = (torch.arange(W, device=device).float() + 0.5) * (float(image_size) / float(W))
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy')                       # (H, W) each
    ones = torch.ones_like(grid_x)
    uv1 = torch.stack([grid_x, grid_y, ones], dim=-1).reshape(-1, 3).T            # (3, H*W)
    K_inv = torch.inverse(K)                                                       # (B, 3, 3)
    rays_cam = (K_inv @ uv1.unsqueeze(0).expand(B, -1, -1)).transpose(1, 2)        # (B, H*W, 3)
    rays_cam = rays_cam.reshape(B, H, W, 3)
    R_cw = extrinsic[:, :3, :3]                                                    # (B, 3, 3)
    t_cw = extrinsic[:, :3, 3]                                                     # (B, 3)
    rays_world = torch.einsum('bij,bhwj->bhwi', R_cw, rays_cam)                    # (B, H, W, 3)

    # Ray/plane intersection: t_z + s * ray_z == height → s = (height - t_z) / ray_z.
    heights = torch.linspace(min_height, max_height, n_height_bins, device=device) # (Z,)
    rwz = rays_world[..., 2].unsqueeze(1)                                          # (B, 1, H, W)
    # Near-zero ray z is replaced by -1e-6 (assumes a downward-looking camera).
    rwz_safe = torch.where(rwz.abs() < 1e-6, torch.full_like(rwz, -1e-6), rwz)
    s = (heights.view(1, n_height_bins, 1, 1) - t_cw[:, 2].view(B, 1, 1, 1)) / rwz_safe  # (B, Z, H, W)
    xyz = t_cw.view(B, 1, 1, 1, 3) + s.unsqueeze(-1) * rays_world.unsqueeze(1)     # (B, Z, H, W, 3)
    return xyz


def project_world_to_wrist_uv_grid(xyz_table, wrist_K_norm, wrist_extrinsic, image_size, return_mask=False):
    """Project world-space voxel points into the wrist camera as a grid_sample grid.

    Args:
        xyz_table:       (Z, H, W, 3) world points shared by the whole batch,
                         or (B, Z, H, W, 3) per-sample world points.
        wrist_K_norm:    (B, 3, 3) intrinsics normalised by image size.
        wrist_extrinsic: (B, 4, 4) camera→world (from get_camera_extrinsic_matrix);
                         it is inverted here to get world→camera.
        image_size:      side length in pixels of the square wrist image.
        return_mask:     when True, also return the in-frustum mask.

    Returns:
        grid: (B, Z, H, W, 2) normalised (u, v). Pixel 0 maps to -1 and pixel
              ``image_size - 1`` to +1, i.e. the align_corners=True convention.
              Points with camera z <= 1e-3 or |u|/|v| > 1 are set to (2, 2),
              so a zero-padded grid_sample reads 0 there.
        in_frustum (only if return_mask=True): (B, Z, H, W) bool, True for
              points that were kept.

    NOTE(review): build_bev_world_xyz_table places pixel centres at (j + 0.5);
    if the wrist intrinsics follow the same convention, normalising by
    (image_size - 1) with align_corners=True is off by up to half a feature
    cell. Confirm before changing, since trained weights depend on it.
    """
    n_batch = wrist_K_norm.shape[0]
    intrinsics = wrist_K_norm.clone()
    intrinsics[:, 0] *= float(image_size)
    intrinsics[:, 1] *= float(image_size)                                         # (B, 3, 3) in pixels
    world_to_cam = torch.inverse(wrist_extrinsic)                                 # (B, 4, 4)

    if xyz_table.dim() != 4:
        # Per-sample table (B, Z, H, W, 3).
        _, Z, H, W, _ = xyz_table.shape
        flat = xyz_table.reshape(n_batch, -1, 3)                                  # (B, N, 3)
        homog = torch.cat([flat, torch.ones_like(flat[..., :1])], dim=-1)         # (B, N, 4)
        cam_pts = torch.einsum('bij,bnj->bin', world_to_cam, homog)              # (B, 4, N)
    else:
        # Shared table (Z, H, W, 3), broadcast over the batch.
        Z, H, W, _ = xyz_table.shape
        flat = xyz_table.reshape(-1, 3)                                           # (N, 3)
        homog = torch.cat([flat, torch.ones_like(flat[:, :1])], dim=-1)           # (N, 4)
        cam_pts = world_to_cam @ homog.T.unsqueeze(0).expand(n_batch, -1, -1)     # (B, 4, N)

    depth = cam_pts[:, 2]                                                         # (B, N)
    behind = depth <= 1e-3
    on_plane = cam_pts[:, :3] / depth.clamp_min(1e-6).unsqueeze(1)                # (B, 3, N)
    pixels = (intrinsics @ on_plane)[:, :2]                                       # (B, 2, N)

    denom = float(image_size - 1)
    u = pixels[:, 0] / denom * 2 - 1
    v = pixels[:, 1] / denom * 2 - 1

    invalid = behind | (u.abs() > 1.0) | (v.abs() > 1.0)
    u = torch.where(invalid, torch.full_like(u, 2.0), u)
    v = torch.where(invalid, torch.full_like(v, 2.0), v)
    grid = torch.stack([u, v], dim=-1).view(n_batch, Z, H, W, 2)
    if not return_mask:
        return grid
    return grid, (~invalid).view(n_batch, Z, H, W)


class DinoVolumeQuery2View(nn.Module):
    """Two-view (BEV + wrist) DINO volume-query model.

    Pipeline:
      1. A shared DINOv3 trunk encodes both views.
      2. A joint transformer mixes the two views' patch tokens.
      3. Per-view refine heads produce F_bev / F_wrist of shape
         (B, d_feat, pred_size, pred_size).
      4. A per-timestep AdaLN-Zero MLP, fed a query built from the EEF's BEV
         feature and both CLS tokens, emits a spatial query that scores every
         (z, y, x) voxel, plus gripper and rotation logits.
    """

    # Valid values of ``fusion_mode``; anything else is rejected in __init__.
    FUSION_MODES = ('sum', 'max', 'poe', 'oof_mask', 'aug_token')

    def __init__(self,
                 n_window=N_WINDOW, n_height_bins=N_HEIGHT_BINS,
                 n_gripper_bins=N_GRIPPER_BINS, n_rot_bins=N_ROT_BINS,
                 d_feat=D_FEAT, d_sin_z=D_SINZ, d_sin_t=D_SINT,
                 d_cond=D_COND, n_blocks=N_BLOCKS,
                 n_cross_layers=N_CROSS_ATTN_LAYERS,
                 image_size=IMG_SIZE, pred_size=PRED_SIZE,
                 rotation_mode='1d_pca', kmeans_n_clusters=0,
                 freeze_backbone=False,
                 fusion_mode='sum'):
        super().__init__()
        # fusion_mode — how the BEV and wrist score volumes are combined:
        #   'sum'        — raw scores added; out-of-frustum (OOF) voxels get a 0 wrist
        #                  score from grid_sample's zero padding (an implicit bias)
        #   'max'        — logsumexp over the two per-view volumes, per voxel
        #   'poe'        — product of experts (per-view log_softmax, log-probs summed)
        #   'oof_mask'   — OOF voxels get a learned per-timestep oof_logit
        #   'aug_token'  — OOF voxels get the BEV score only; one learned "abstain"
        #                  token is appended to the flattened volume
        # Validate up front so a typo fails here, before the backbone is loaded,
        # rather than on the first forward pass.
        if fusion_mode not in self.FUSION_MODES:
            raise ValueError(f"Unknown fusion_mode: {fusion_mode}")
        self.fusion_mode     = fusion_mode
        # Both parameters below are created for every fusion_mode, so state_dict
        # keys do not depend on the mode. NOTE(review): they are unused outside
        # 'oof_mask' / 'aug_token', which may need find_unused_parameters=True
        # under DDP — confirm.
        # Per-timestep OOF logit (used by 'oof_mask' mode)
        self.wrist_oof_logit = nn.Parameter(torch.zeros(n_window))
        # Learned feature for the abstain (+1) token (used by 'aug_token' mode)
        self.F_oof_token     = nn.Parameter(torch.randn(d_feat) * 0.02)
        self.n_window        = n_window
        self.n_height_bins   = n_height_bins
        self.n_gripper_bins  = n_gripper_bins
        self.n_rot_bins      = n_rot_bins
        self.d_feat          = d_feat
        self.d_sin_z         = d_sin_z
        self.d_sin_t         = d_sin_t
        self.d_model         = d_feat + d_sin_z + d_sin_t
        self.image_size      = image_size
        self.pred_size       = pred_size
        assert rotation_mode in ('1d_pca', 'kmeans', 'per_axis')
        self.rotation_mode   = rotation_mode
        self.kmeans_n_clusters = kmeans_n_clusters

        # Shared DINO trunk (loaded from the local hub repo)
        self.dino = torch.hub.load(DINO_REPO_DIR, 'dinov3_vits16plus',
                                    source='local', weights=DINO_WEIGHTS_PATH)
        if freeze_backbone:
            for p in self.dino.parameters():
                p.requires_grad = False
        self.embed_dim = self.dino.embed_dim

        # Cross-view mixing transformer (n_cross_layers pre-norm encoder layers).
        # Operates on (B, n_bev_tokens + n_wrist_tokens, embed_dim) — every token attends globally.
        enc_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim, nhead=6, dim_feedforward=self.embed_dim * 2,
            dropout=0.0, activation='gelu', batch_first=True, norm_first=True,
        )
        self.cross_attn = nn.TransformerEncoder(enc_layer, num_layers=n_cross_layers)
        # View-type embedding so the cross-attn can distinguish BEV (id 0) vs wrist (id 1) tokens
        self.view_emb = nn.Embedding(2, self.embed_dim)

        # Two refine heads with separate weights per view
        self.refine_bev = nn.Sequential(
            nn.Conv2d(self.embed_dim, self.embed_dim, 3, padding=1), nn.GELU(),
            nn.Conv2d(self.embed_dim, d_feat, 1),
        )
        self.refine_wrist = nn.Sequential(
            nn.Conv2d(self.embed_dim, self.embed_dim, 3, padding=1), nn.GELU(),
            nn.Conv2d(self.embed_dim, d_feat, 1),
        )

        # Sinusoidal PE buffers (non-persistent: rebuilt from n_height_bins / n_window)
        self.register_buffer("z_sin", sinusoidal_features(n_height_bins, d_sin_z), persistent=False)
        self.register_buffer("t_sin", sinusoidal_features(n_window, d_sin_t), persistent=False)
        self.t_cond_proj = nn.Linear(d_sin_t, d_cond)

        # Query input: concat(eef_feat_bev (d_feat), cls_bev (embed), cls_wrist (embed)) → d_model
        self.input_proj = nn.Linear(d_feat + 2 * self.embed_dim, self.d_model)
        self.blocks = nn.ModuleList([
            AdaLNZeroMLPBlock(self.d_model, d_cond) for _ in range(n_blocks)
        ])
        self.final_norm = nn.LayerNorm(self.d_model)

        # Heads. The spatial query is split into (q_F_bev, q_F_wrist, q_z, q_t),
        # so its width is 2 * d_feat + d_sin_z + d_sin_t (2*32 + 16 + 16 = 96 by default).
        self.q_head_dim = 2 * d_feat + d_sin_z + d_sin_t
        self.q_head    = nn.Linear(self.d_model, self.q_head_dim)
        self.grip_head = nn.Linear(self.d_model, n_gripper_bins)
        if rotation_mode == 'per_axis':
            rot_out_dim = 3 * n_rot_bins
        elif rotation_mode == '1d_pca':
            rot_out_dim = n_rot_bins
        else:
            assert kmeans_n_clusters > 0
            rot_out_dim = kmeans_n_clusters
        self.rot_head  = nn.Linear(self.d_model, rot_out_dim)

    def _extract_dino_tokens(self, x):
        """Run the DINO trunk block by block and apply its final norm(s).

        Returns (cls_token (B, embed), patch_tokens (B, H_p*W_p, embed), (H_p, W_p)).
        The storage (register) tokens between CLS and the patches are dropped.
        """
        x_tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        for blk in self.dino.blocks:
            # Rope embedding is (re)computed per block, as before. NOTE(review):
            # hoisting it out of the loop would be cheaper but could change
            # behaviour if rope_embed is stochastic in train mode — confirm first.
            rope = self.dino.rope_embed(H=H_p, W=W_p) if self.dino.rope_embed is not None else None
            x_tokens = blk(x_tokens, rope)
        if self.dino.untie_cls_and_patch_norms:
            x_cls = self.dino.cls_norm(x_tokens[:, :self.dino.n_storage_tokens + 1])
            x_pat = self.dino.norm(x_tokens[:, self.dino.n_storage_tokens + 1:])
            x_tokens = torch.cat([x_cls, x_pat], dim=1)
        else:
            x_tokens = self.dino.norm(x_tokens)
        cls = x_tokens[:, 0]                                              # (B, embed)
        patches = x_tokens[:, self.dino.n_storage_tokens + 1:]            # (B, H_p*W_p, embed)
        return cls, patches, (H_p, W_p)

    def forward(self, rgb_bev, rgb_wrist, start_pix_bev, bev_xyz_table,
                wrist_K_norm, wrist_extrinsic):
        """Score the action volume from a BEV + wrist image pair.

        Args:
            rgb_bev / rgb_wrist: (B, 3, IMG, IMG) images; both must produce the
                same DINO patch grid.
            start_pix_bev: (B, 2) current EEF pixel (x, y) in IMG-coords on the BEV image.
            bev_xyz_table: (Z, H, W, 3) world XYZ at each (z_bin, y_grid, x_grid)
                voxel, shared by the batch (static BEV camera), or (B, Z, H, W, 3)
                per sample.
            wrist_K_norm: (B, 3, 3) normalised wrist intrinsics.
            wrist_extrinsic: (B, 4, 4) wrist camera→world transform per sample
                (inverted inside project_world_to_wrist_uv_grid).

        Returns:
            dict with
              volume_logits:      (B, T, Z, H, W); for fusion_mode='aug_token'
                                  (B, T, Z*H*W + 1), the last entry being the abstain token.
              vol_bev, vol_wrist: (B, T, Z, H, W) per-view volumes incl. height/time terms.
              in_frustum:         (B, Z, H, W) bool.
              gripper_logits:     (B, T, n_gripper_bins).
              rotation_logits:    (B, T, 3, n_rot_bins) for 'per_axis', else (B, T, rot_out_dim).
              pixel_feats, pixel_feats_wrist: (B, d_feat, H, W).

        Raises:
            ValueError: if self.fusion_mode is not a known mode.
        """
        B = rgb_bev.shape[0]
        T, d = self.n_window, self.d_model

        # ── DINO forward on both views ──
        cls_bev,   patches_bev,   (Hp_b, Wp_b) = self._extract_dino_tokens(rgb_bev)
        cls_wrist, patches_wrist, (Hp_w, Wp_w) = self._extract_dino_tokens(rgb_wrist)
        n_p = Hp_b * Wp_b                                                 # patches per view
        assert (Hp_b, Wp_b) == (Hp_w, Wp_w), "Both views must produce the same patch grid"

        # ── Cross-view attention ──
        # Concat (BEV ; wrist) tokens + add a view-type embedding so they're distinguishable.
        view_ids = torch.cat([
            torch.zeros(n_p, dtype=torch.long, device=rgb_bev.device),
            torch.ones (n_p, dtype=torch.long, device=rgb_bev.device),
        ])                                                                # (2*n_p,)
        view_e = self.view_emb(view_ids).unsqueeze(0)                     # (1, 2*n_p, embed)
        joint = torch.cat([patches_bev, patches_wrist], dim=1) + view_e   # (B, 2*n_p, embed)
        joint = self.cross_attn(joint)                                     # (B, 2*n_p, embed)
        pat_bev_x   = joint[:, :n_p]                                       # (B, n_p, embed)
        pat_wrist_x = joint[:, n_p:]

        # Reshape patches → (B, embed, Hp, Wp), upsample to pred_size, refine
        def _to_grid(p):
            return p.reshape(B, Hp_b, Wp_b, self.embed_dim).permute(0, 3, 1, 2).contiguous()
        feat_bev_up   = F.interpolate(_to_grid(pat_bev_x),   size=(self.pred_size, self.pred_size),
                                       mode='bilinear', align_corners=False)
        feat_wrist_up = F.interpolate(_to_grid(pat_wrist_x), size=(self.pred_size, self.pred_size),
                                       mode='bilinear', align_corners=False)
        F_bev   = self.refine_bev  (feat_bev_up)                          # (B, d_feat, H, W)
        F_wrist = self.refine_wrist(feat_wrist_up)                        # (B, d_feat, H, W)
        H, W = F_bev.shape[-2:]

        # ── EEF feature from BEV (nearest feature cell, clamped to the grid) ──
        sx = (start_pix_bev[..., 0] * (W / self.image_size)).long().clamp(0, W - 1)
        sy = (start_pix_bev[..., 1] * (H / self.image_size)).long().clamp(0, H - 1)
        b_idx = torch.arange(B, device=rgb_bev.device)
        eef_feat = F_bev[b_idx, :, sy, sx]                                 # (B, d_feat)

        # ── Query input: concat(eef_feat, cls_bev, cls_wrist) → d_model, broadcast across T ──
        q_in = self.input_proj(torch.cat([eef_feat, cls_bev, cls_wrist], dim=-1))
        q_in_bt = q_in.unsqueeze(1).expand(B, T, d).reshape(B * T, d)

        # AdaLN conditioning: sin(t) projected, broadcast across batch
        cond_t  = self.t_cond_proj(self.t_sin)                              # (T, d_cond)
        cond_bt = cond_t.unsqueeze(0).expand(B, T, -1).reshape(B * T, -1)

        h = q_in_bt
        for blk in self.blocks:
            h = blk(h, cond_bt)
        h = self.final_norm(h)
        penult = h.view(B, T, d)                                            # (B, T, d_model)

        q_spatial = self.q_head(penult)                                     # (B, T, 2d_feat + d_sin_z + d_sin_t)
        gripper   = self.grip_head(penult)                                  # (B, T, n_grip)
        if self.rotation_mode == 'per_axis':
            rotation = self.rot_head(penult).view(B, T, 3, self.n_rot_bins)
        else:
            rotation = self.rot_head(penult)

        # Split q_spatial into (q_F_bev, q_F_wrist, q_z, q_t)
        d_F, d_z, d_t = self.d_feat, self.d_sin_z, self.d_sin_t
        q_F_bev   = q_spatial[..., :d_F]
        q_F_wrist = q_spatial[..., d_F:2 * d_F]
        q_z       = q_spatial[..., 2 * d_F:2 * d_F + d_z]
        q_t       = q_spatial[..., 2 * d_F + d_z:]

        # ── Volume scoring ──
        # 1) BEV: direct einsum
        score_bev_yx = torch.einsum('btc, bchw -> bthw', q_F_bev, F_bev)    # (B, T, H, W)
        # 2) Wrist: project each voxel through wrist camera and grid-sample F_wrist.
        #    The projection is pure geometry, so no gradient is needed through it.
        with torch.no_grad():
            xyz = bev_xyz_table  # (Z, H, W, 3) shared, or (B, Z, H, W, 3) per sample
            uv_grid, in_frustum = project_world_to_wrist_uv_grid(
                xyz, wrist_K_norm, wrist_extrinsic, self.image_size, return_mask=True,
            )                                                                # (B, Z, H, W, 2), (B, Z, H, W) bool
        Bv, Zv, Hv, Wv, _ = uv_grid.shape
        grid_flat = uv_grid.view(Bv, Zv * Hv, Wv, 2)
        F_w_sampled = F.grid_sample(F_wrist, grid_flat, mode='bilinear',
                                     padding_mode='zeros', align_corners=True)
        F_w_sampled = F_w_sampled.view(Bv, self.d_feat, Zv, Hv, Wv)
        # Wrist score from features (only meaningful for in-frustum voxels)
        score_wrist_from_feat = torch.einsum('btc, bczhw -> btzhw', q_F_wrist, F_w_sampled)  # (B, T, Z, H, W)
        in_frustum_BTZHW = in_frustum.unsqueeze(1).float()                     # (B, 1, Z, H, W)
        if self.fusion_mode == 'oof_mask':
            # OOF voxels get learned per-timestep global logit
            oof_logit_BT = self.wrist_oof_logit.view(1, -1, 1, 1, 1)
            score_wrist_zyx = (in_frustum_BTZHW * score_wrist_from_feat
                                + (1.0 - in_frustum_BTZHW) * oof_logit_BT)
        elif self.fusion_mode == 'aug_token':
            # OOF voxels get NO wrist contribution (just bev). The wrist's "abstain" energy
            # goes to the +1 augmented token below.
            score_wrist_zyx = in_frustum_BTZHW * score_wrist_from_feat
        else:
            # 'sum', 'max', 'poe' — keep original (OOF voxels = 0 wrist contribution from grid_sample's padding)
            score_wrist_zyx = score_wrist_from_feat

        # Z and T terms (additive, broadcast)
        score_z = torch.einsum('btc, zc -> btz', q_z, self.z_sin)            # (B, T, Z)
        score_t = torch.einsum('btc, tc -> bt',  q_t, self.t_sin)            # (B, T)

        # Build per-view full volumes BEFORE fusion. Each is (B, T, Z, H, W).
        z_term = score_z.unsqueeze(-1).unsqueeze(-1)                          # (B, T, Z, 1, 1)
        t_term = score_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)            # (B, T, 1, 1, 1)
        vol_bev_full   = score_bev_yx.unsqueeze(2) + z_term + t_term          # (B, T, Z, H, W)
        vol_wrist_full = score_wrist_zyx           + z_term + t_term          # (B, T, Z, H, W)

        if self.fusion_mode in ('sum', 'oof_mask'):
            # Add raw scores from both views per voxel. 'oof_mask' replaces OOF wrist score with learned logit (handled above).
            volume_logits = (
                score_bev_yx.unsqueeze(2) + score_wrist_zyx + z_term + t_term
            )
        elif self.fusion_mode == 'aug_token':
            # Spatial volume: bev + (wrist * in_frustum_mask) — OOF voxels get bev only
            vol_spatial = score_bev_yx.unsqueeze(2) + score_wrist_zyx + z_term + t_term  # (B, T, Z, H, W)
            # +1 abstain token: wrist's score against the learned F_oof_token feature
            B_, T_, _, _, _ = vol_spatial.shape
            oof_score = torch.einsum('btc, c -> bt', q_F_wrist, self.F_oof_token).unsqueeze(-1)  # (B, T, 1)
            # Augment: flatten spatial, append +1 token → (B, T, Z*H*W + 1)
            flat_spatial = vol_spatial.reshape(B_, T_, -1)
            volume_logits = torch.cat([flat_spatial, oof_score], dim=-1)
        elif self.fusion_mode == 'max':
            # logsumexp over views per voxel — smooth "most confident view wins"
            stacked = torch.stack([vol_bev_full, vol_wrist_full], dim=0)      # (2, B, T, Z, H, W)
            volume_logits = torch.logsumexp(stacked, dim=0)                   # soft-max over views
        elif self.fusion_mode == 'poe':
            # Product of experts: per-view softmax over the volume → sum of log-probs.
            B_, T_, Z_, H_, W_ = vol_bev_full.shape
            log_p_bev   = vol_bev_full  .reshape(B_, T_, -1).log_softmax(dim=-1).reshape(B_, T_, Z_, H_, W_)
            log_p_wrist = vol_wrist_full.reshape(B_, T_, -1).log_softmax(dim=-1).reshape(B_, T_, Z_, H_, W_)
            volume_logits = log_p_bev + log_p_wrist
        else:
            # Reachable only if fusion_mode was changed after construction.
            raise ValueError(f"Unknown fusion_mode: {self.fusion_mode}")

        return {
            "volume_logits":   volume_logits,
            "vol_bev":         vol_bev_full,                  # (B, T, Z, H, W) — BEV-only score volume
            "vol_wrist":       vol_wrist_full,                # (B, T, Z, H, W) — wrist-only volume; OOF voxels: 0 feat score ('sum'/'max'/'poe'/'aug_token') or oof_logit ('oof_mask')
            "in_frustum":      in_frustum,                    # (B, Z, H, W) bool — wrist sees this voxel?
            "gripper_logits":  gripper,
            "rotation_logits": rotation,
            "pixel_feats":     F_bev,
            "pixel_feats_wrist": F_wrist,
        }


if __name__ == "__main__":
    # Smoke test: build the model (needs the local DINOv3 repo and weights),
    # run one forward pass on random inputs and print every output's shape.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DinoVolumeQuery2View(n_window=8, rotation_mode='1d_pca').to(dev).eval()
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable: {n_trainable:,}")

    batch = 2
    # Dummy static BEV voxel→world table (normally built by build_bev_world_xyz_table).
    bev_xyz = torch.rand(N_HEIGHT_BINS, PRED_SIZE, PRED_SIZE, 3, device=dev)
    rgb_bev = torch.rand(batch, 3, IMG_SIZE, IMG_SIZE, device=dev)
    rgb_wrist = torch.rand(batch, 3, IMG_SIZE, IMG_SIZE, device=dev)
    start_pix = torch.rand(batch, 2, device=dev) * IMG_SIZE
    wrist_K = torch.eye(3, device=dev).repeat(batch, 1, 1)
    # forward() expects a camera→world wrist transform; identity here.
    wrist_cam_to_world = torch.eye(4, device=dev).repeat(batch, 1, 1)

    with torch.no_grad():
        outputs = model(rgb_bev, rgb_wrist, start_pix, bev_xyz, wrist_K, wrist_cam_to_world)
    for key, value in outputs.items():
        if not hasattr(value, 'shape'):
            continue
        print(f"  {key}: {tuple(value.shape)}")
