"""DinoVolumeQuery, image-concat 2view variant.

Inputs (rgb_bev, rgb_wrist) → concatenated horizontally to (3, IMG, 2*IMG) → DINOv3 → a 28×56
patch grid (for IMG=448, patch size 16) → upsampled to 56×112 → the first 56 columns (the BEV
half) feed the volume head.

DINO's self-attention naturally crosses both views via the extra patches; the wrist's
features influence the BEV-side patches through attention, but the world-space scoring
remains anchored on the BEV camera (which is the only one with a fixed world frame here).
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from model_dino_volume_query import (
    DINO_REPO_DIR, DINO_WEIGHTS_PATH,
    N_WINDOW, N_HEIGHT_BINS, N_GRIPPER_BINS, N_ROT_BINS,
    D_FEAT, D_SINZ, D_SINT, D_COND, N_BLOCKS,
    IMG_SIZE, PRED_SIZE,
    sinusoidal_features, AdaLNZeroMLPBlock, DinoVolumeQuery,
)


class DinoVolumeQueryConcat(DinoVolumeQuery):
    """Image-concat variant: takes (rgb_bev, rgb_wrist) and concatenates side-by-side.

    All submodules come from ``DinoVolumeQuery`` (this class adds none). Only ``forward``
    differs: the backbone sees the two views concatenated along the width axis, and the
    per-pixel feature map is cropped back to its BEV (left) half before scoring.
    """

    def __init__(self, *args, **kwargs):
        # No extra parameters: the concat width (2 * image_size) only changes the backbone's
        # input shape. NOTE(review): relies on the DINO backbone accepting non-square inputs
        # whose sides are multiples of its patch size — confirm for the loaded hub model.
        super().__init__(*args, **kwargs)

    def forward(self, rgb_bev, rgb_wrist, start_pix, kp_zyx=None):
        """Score the (T, Z, H, W) action volume from a BEV + wrist image pair.

        The two views are concatenated side by side (BEV left, wrist right) and run through
        DINO as a single image, so self-attention mixes information across views. The
        upsampled, refined feature map is then cropped to its left (BEV) half, so
        ``volume_logits`` live in BEV pixel/world space (matched to the ``bev_xyz`` table
        used at inference for 3D recovery).

        Args:
            rgb_bev: (B, 3, IMG, IMG) BEV camera image.
            rgb_wrist: wrist camera image; must have exactly the same shape as ``rgb_bev``
                so that the left half of the concatenated feature map is the BEV view.
            start_pix: (B, 2) EEF pixel in the BEV frame, in IMG coordinates, ordered (x, y).
                Only read when ``self.use_eef`` is true.
            kp_zyx: unused here. Presumably kept for signature compatibility with
                ``DinoVolumeQuery.forward`` — confirm.

        Returns:
            dict with keys
              ``volume_logits``   (B, T, Z, H, W) = yx + z + t score terms,
              ``gripper_logits``  ``grip_head`` output on the (B, T, d) penultimate features,
              ``rotation_logits`` (B, T, 3, n_rot_bins) when ``rotation_mode == 'per_axis'``,
                                  otherwise the raw ``rot_head`` output,
              ``pixel_feats``     (B, d_feat, H, W) BEV-half feature map.

        Raises:
            ValueError: if ``rgb_bev`` and ``rgb_wrist`` have different shapes. A width
                mismatch would otherwise pass ``torch.cat`` and silently misalign the BEV crop.
        """
        if rgb_wrist.shape != rgb_bev.shape:
            raise ValueError(
                "rgb_bev and rgb_wrist must have the same shape, got "
                f"{tuple(rgb_bev.shape)} and {tuple(rgb_wrist.shape)}"
            )

        B = rgb_bev.shape[0]
        T = self.n_window
        d = self.d_model

        # Horizontal concat: (B, 3, IMG, 2*IMG) — BEV left, wrist right.
        rgb_cat = torch.cat([rgb_bev, rgb_wrist], dim=-1)

        # DINO extracts patches over the full concat (e.g. IMG=448, patch=16 -> 28x56 grid).
        # NOTE(review): assumes _extract_dino_features handles non-square input (e.g. by
        # interpolating its positional embedding on demand) — confirm.
        patch_cat, cls = self._extract_dino_features(rgb_cat)   # (B, embed, H_p_full, W_p_full)

        # Upsample to (pred_size, 2*pred_size), refine, then keep the BEV half
        # (columns [0, pred_size)).
        # NOTE(review): the bilinear upsample (and `refine`, if it has a spatial kernel) runs
        # across the BEV/wrist seam, so the right-most BEV columns blend in some wrist-side
        # features. Left unchanged here; cropping before upsampling would change outputs.
        feat_up = F.interpolate(
            patch_cat,
            size=(self.pred_size, 2 * self.pred_size),
            mode='bilinear',
            align_corners=False,
        )
        F_full = self.refine(feat_up)                           # (B, d_feat, H, 2W)
        F_feat = F_full[..., :self.pred_size]                   # (B, d_feat, H, W)
        H, W = F_feat.shape[-2:]

        # Query input: (eef_feat ⊕ cls). start_pix is already in the BEV frame, so it maps
        # directly onto the BEV half of the feature map.
        if self.use_eef:
            # IMG-space (x, y) -> feature-grid indices; truncate, then clamp to stay in-bounds.
            sx = (start_pix[..., 0] * (W / self.image_size)).long().clamp(0, W - 1)
            sy = (start_pix[..., 1] * (H / self.image_size)).long().clamp(0, H - 1)
            b_idx = torch.arange(B, device=rgb_bev.device)
            eef_feat = F_feat[b_idx, :, sy, sx]                 # (B, d_feat)
            q_in = self.input_proj(torch.cat([eef_feat, cls], dim=-1))
        else:
            q_in = self.input_proj(cls)
        # Same query for every time step; per-step information enters via `cond_bt` below.
        q_in_bt = q_in.unsqueeze(1).expand(B, T, d).reshape(B * T, d)

        cond_t = self.t_cond_proj(self.t_sin)
        cond_bt = cond_t.unsqueeze(0).expand(B, T, -1).reshape(B * T, -1)

        h = q_in_bt
        for blk in self.blocks:
            h = blk(h, cond_bt)
        h = self.final_norm(h)
        penult = h.view(B, T, d)

        q_spatial = self.q_head(penult)
        gripper = self.grip_head(penult)
        if self.rotation_mode == 'per_axis':
            rotation = self.rot_head(penult).view(B, T, 3, self.n_rot_bins)
        else:
            rotation = self.rot_head(penult)

        # q_spatial splits into [feature-query | z-query | t-query] channel segments.
        q_F = q_spatial[..., :self.d_feat]
        q_z = q_spatial[..., self.d_feat:self.d_feat + self.d_sin_z]
        q_t = q_spatial[..., self.d_feat + self.d_sin_z:]

        score_yx = torch.einsum('btc, bchw -> bthw', q_F, F_feat)
        score_z = torch.einsum('btc, zc   -> btz', q_z, self.z_sin)
        score_t = torch.einsum('btc, tc   -> bt', q_t, self.t_sin)

        # Broadcast the three factorized scores into a (B, T, Z, H, W) volume.
        volume_logits = (
            score_yx.unsqueeze(2)
            + score_z.unsqueeze(-1).unsqueeze(-1)
            + score_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        )

        return {
            "volume_logits":   volume_logits,
            "gripper_logits":  gripper,
            "rotation_logits": rotation,
            "pixel_feats":     F_feat,
        }


if __name__ == "__main__":
    # Smoke test: build the concat model, push random inputs through it once, and report
    # the trainable-parameter count plus the shape of every non-None output.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DinoVolumeQueryConcat(n_window=8, rotation_mode='1d_pca')
    model = model.to(device).eval()

    n_t = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable: {n_t:,}")

    batch = 2
    img_shape = (batch, 3, IMG_SIZE, IMG_SIZE)
    rgb_bev = torch.rand(*img_shape).to(device)
    rgb_wrist = torch.rand(*img_shape).to(device)
    start_pix = torch.rand(batch, 2).to(device) * IMG_SIZE   # EEF (x, y) in BEV pixels

    with torch.no_grad():
        out = model(rgb_bev, rgb_wrist, start_pix)

    for k, v in out.items():
        if v is None:
            continue
        print(f"  {k}: {tuple(v.shape)}")
