"""DinoVolumeQuery with DA3 geometric features fused in.

Architecture (spatial sizes for the 448×448 input used in __main__):
  1. DINOv3 backbone → (B, D_dino, 28, 28) patches
  2. DA3-LARGE backbone (frozen by default) → last-layer patch tokens
     (B, D_da3=2048, H_da3, W_da3), bilinearly resized to the DINO grid (28, 28)
  3. Project DA3 patches: 1×1 conv 2048 → 256
  4. Concat: (D_dino + 256, 28, 28)
  5. Fusion: `fusion_layers` (default 2) 3×3 conv layers, GELU between them → (D_dino, 28, 28)
  6. Upsample to (D_dino, 56, 56)
  7. Refine to (d_feat=32, 56, 56) — F_feat
  8. Existing query-MLP volume scoring

DA3 features carry geometric/depth priors. The fusion module learns to combine
DINO's semantic features with DA3's geometric features.
"""
import os
import sys
import types
import torch
import torch.nn as nn
import torch.nn.functional as F

def _stub_da3_utils():
    """Register no-op stand-ins for DA3 utility modules this file never calls.

    Importing ``depth_anything_3.api`` pulls these modules in; stubbing them
    avoids import errors. An already-imported module is reused, and only the
    listed attributes are overwritten with functions that return None.
    """
    noop = lambda *a, **k: None  # noqa: E731
    stubs = {
        'depth_anything_3.utils.export': ('export',),
        'depth_anything_3.utils.pose_align': ('align_poses_umeyama',
                                              'batch_align_poses_umeyama'),
    }
    for mod_name, attr_names in stubs.items():
        module = sys.modules.setdefault(mod_name, types.ModuleType(mod_name))
        for attr in attr_names:
            setattr(module, attr, noop)


_stub_da3_utils()
# Make the local DA3 checkout importable (must run before the DA3 import below).
sys.path.insert(0, "/data/cameron/da3_repo/src")
from depth_anything_3.api import DepthAnything3

from model_dino_volume_query import (
    DinoVolumeQuery,
    N_WINDOW, N_HEIGHT_BINS, N_GRIPPER_BINS, N_ROT_BINS,
    D_FEAT, D_SINZ, D_SINT, IMG_SIZE, PRED_SIZE,
)

# Default weights location handed to DepthAnything3.from_pretrained; callers can
# override it via the `da3_weights` constructor argument.
DA3_WEIGHTS_DEFAULT = "/data/cameron/da3_large_weights"
# NOTE(review): not referenced by the code below — the DA3 backbone is run at the
# caller's input resolution, not resized to this size.
DA3_INPUT_SIZE = 504  # DA3-LARGE was trained at 504


class DinoVolumeQueryDA3(DinoVolumeQuery):
    """Single-view (1view) query-MLP with DA3 geometric features fused in via concat + conv layers.

    Processing order:
      1. DA3 last-layer patch tokens are resized to the DINO patch grid.
      2. They are projected to ``da3_proj_dim`` channels by a 1×1 conv.
      3. The result is concatenated with the DINO patch features.
      4. The concat is fused back to ``embed_dim`` channels.
      5. The fused map goes through the parent's upsample / refine /
         query-MLP scoring path.

    Extra keyword args (everything else is forwarded to ``DinoVolumeQuery``):
        da3_weights: passed to ``DepthAnything3.from_pretrained``.
        da3_proj_dim: channels of the projected DA3 features.
        fusion_layers: number of 3×3 conv layers in the fusion stack (>= 1).
        freeze_da3: if True:
            - DA3 parameters get ``requires_grad=False``;
            - the DA3 forward runs without autograd;
            - the backbone stays in eval mode even after ``model.train()``.
            If False, DA3 receives gradients and is trained with the model.
        da3_patch_size: DA3 ViT patch size, used to recover the token grid.
        da3_token_dim: expected channel width of DA3 last-layer tokens.

    Raises:
        ValueError: if ``fusion_layers < 1``.
    """

    def __init__(self, *args,
                 da3_weights=DA3_WEIGHTS_DEFAULT,
                 da3_proj_dim=256,
                 fusion_layers=2,
                 freeze_da3=True,
                 da3_patch_size=14,
                 da3_token_dim=2048,
                 **kwargs):
        if fusion_layers < 1:
            # An empty stack would pass embed_dim + da3_proj_dim channels
            # downstream instead of embed_dim.
            raise ValueError(f"fusion_layers must be >= 1, got {fusion_layers}")
        super().__init__(*args, **kwargs)
        self.freeze_da3 = freeze_da3
        self.da3_patch_size = da3_patch_size

        # Load DA3 and keep only its backbone; drop the reference to the rest.
        full = DepthAnything3.from_pretrained(da3_weights)
        self.da3_backbone = full.model.backbone
        del full
        if freeze_da3:
            for p in self.da3_backbone.parameters():
                p.requires_grad = False
            self.da3_backbone.eval()

        # Channel width of DA3-LARGE last-layer tokens; verified at runtime in
        # _extract_da3_patches.
        self.da3_token_dim = da3_token_dim
        self.da3_proj = nn.Conv2d(self.da3_token_dim, da3_proj_dim, kernel_size=1)

        # Fusion: concat (embed_dim + da3_proj_dim) → embed_dim via `fusion_layers`
        # 3×3 convs, with GELU between consecutive convs (none after the last).
        fused_in = self.embed_dim + da3_proj_dim
        layers = []
        for i in range(fusion_layers):
            in_ch = fused_in if i == 0 else self.embed_dim
            layers.append(nn.Conv2d(in_ch, self.embed_dim, kernel_size=3, padding=1))
            if i < fusion_layers - 1:
                layers.append(nn.GELU())
        self.fusion = nn.Sequential(*layers)
        self.da3_proj_dim = da3_proj_dim
        self.fusion_layers = fusion_layers

    def train(self, mode=True):
        """Set train/eval mode, always leaving a frozen DA3 backbone in eval mode.

        ``nn.Module.train`` recurses into every submodule. Without this
        override, ``model.train()`` would switch the frozen backbone back to
        training behaviour.
        """
        super().train(mode)
        # getattr/hasattr guards: the parent __init__ may call train() before
        # these attributes exist.
        if getattr(self, 'freeze_da3', False) and hasattr(self, 'da3_backbone'):
            self.da3_backbone.eval()
        return self

    @staticmethod
    def _da3_autocast_dtype(device):
        """Pick the autocast dtype for the DA3 forward on ``device``."""
        if device.type == 'cuda':
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        if device.type == 'cpu':
            # CPU autocast's default dtype is bfloat16. This path also avoids
            # querying CUDA capabilities on CPU-only hosts.
            return torch.bfloat16
        return torch.float16

    def _extract_da3_patches(self, rgb):
        """Run the DA3 backbone and return its last-layer patch tokens as a feature map.

        Args:
            rgb: image batch (B, 3, H, W). It is passed to DA3 as
                (B, S=1, 3, H, W) at the given resolution, with no resizing.

        Returns:
            float32 tensor (B, C, H // da3_patch_size, W // da3_patch_size),
            where C is the DA3 token width.

        Raises:
            RuntimeError: if DA3 returns fewer tokens than the patch grid
                implied by the input size, or tokens of an unexpected width.
        """
        B, _, H_in, W_in = rgb.shape
        x = rgb.unsqueeze(1)  # (B, 1, 3, H, W)
        # Only build an autograd graph through DA3 when it is being trained
        # (and the caller has not disabled grad globally).
        track_grad = torch.is_grad_enabled() and not self.freeze_da3
        with torch.set_grad_enabled(track_grad):
            with torch.autocast(device_type=rgb.device.type,
                                dtype=self._da3_autocast_dtype(rgb.device)):
                feats, _aux = self.da3_backbone(
                    x, cam_token=None,
                    export_feat_layers=list(getattr(self.da3_backbone, 'out_layers', [5, 7, 9, 11])),
                    ref_view_strategy="saddle_balanced",
                )
        # feats holds one entry per exported layer; the last one is used.
        # An entry may be a tuple whose first element is the token tensor.
        last = feats[-1]
        if isinstance(last, (list, tuple)):
            last = last[0]
        last = last[:, 0]  # (B, N, C): drop the single-view axis S
        n_tokens, C = last.shape[1], last.shape[2]

        # Recover the patch grid from the input resolution rather than
        # sqrt(N). This handles non-square inputs. It also works whether or
        # not the backbone already strips prefix (CLS / register) tokens,
        # which come before the patch tokens.
        ps = self.da3_patch_size
        Hp_da3, Wp_da3 = H_in // ps, W_in // ps
        n_patch = Hp_da3 * Wp_da3
        n_prefix = n_tokens - n_patch
        if n_prefix < 0:
            raise RuntimeError(
                f"DA3 returned {n_tokens} tokens, fewer than the {Hp_da3}x{Wp_da3} patch grid "
                f"implied by input {H_in}x{W_in} and patch size {ps}")
        if C != self.da3_token_dim:
            raise RuntimeError(
                f"DA3 token width {C} != expected da3_token_dim {self.da3_token_dim}")
        patches = last[:, n_prefix:]                                                 # (B, Hp*Wp, C)
        patches = patches.reshape(B, Hp_da3, Wp_da3, C).permute(0, 3, 1, 2).contiguous()
        return patches.float()                                                       # (B, C, Hp, Wp)

    def forward(self, rgb, start_pix, kp_zyx=None):
        """Predict volume, gripper and rotation logits for a batch of images.

        Args:
            rgb: image batch (B, 3, H, W), fed to both the DINO and DA3 backbones.
            start_pix: (B, 2) start pixel as (x, y) in ``image_size`` pixel
                units. Only read when ``self.use_eef``.
            kp_zyx: not used here (presumably kept to match the parent's
                signature).

        Returns:
            dict with:
                "volume_logits": (B, T, Z, H_f, W_f), where T = n_window,
                    Z = rows of ``self.z_sin`` and (H_f, W_f) = F_feat size.
                "gripper_logits": output of ``self.grip_head``.
                "rotation_logits": (B, T, 3, n_rot_bins) when
                    rotation_mode == 'per_axis', else ``self.rot_head`` output.
                "pixel_feats": F_feat, (B, C_feat, H_f, W_f).
        """
        B = rgb.shape[0]
        T = self.n_window
        d = self.d_model

        # DINO patches
        patch_dino, cls = self._extract_dino_features(rgb)                          # (B, embed, H_p, W_p), (B, embed)
        H_dino, W_dino = patch_dino.shape[-2:]

        # DA3 patches
        patch_da3 = self._extract_da3_patches(rgb)                                  # (B, 2048, H_da3, W_da3)
        # Resize DA3 to DINO grid size (the two backbones use different patch sizes)
        if patch_da3.shape[-2:] != (H_dino, W_dino):
            patch_da3 = F.interpolate(patch_da3, size=(H_dino, W_dino),
                                      mode='bilinear', align_corners=False)
        # Project DA3 channels
        da3_proj = self.da3_proj(patch_da3)                                          # (B, da3_proj_dim, H_p, W_p)

        # Concat + fusion
        fused = torch.cat([patch_dino, da3_proj], dim=1)                            # (B, embed + da3_proj, H_p, W_p)
        fused = self.fusion(fused)                                                   # (B, embed, H_p, W_p)

        # Upsample + refine (same as parent)
        feat_up = F.interpolate(fused, size=(self.pred_size, self.pred_size),
                                mode='bilinear', align_corners=False)
        F_feat = self.refine(feat_up)                                                # (B, d_feat, H, W)
        H, W = F_feat.shape[-2:]

        # Query MLP (same as parent from here)
        if self.use_eef:
            # Map the start pixel from image coordinates to the F_feat grid.
            sx = (start_pix[..., 0] * (W / self.image_size)).long().clamp(0, W - 1)
            sy = (start_pix[..., 1] * (H / self.image_size)).long().clamp(0, H - 1)
            b_idx = torch.arange(B, device=rgb.device)
            eef_feat = F_feat[b_idx, :, sy, sx]                                      # (B, C_feat)
            q_in = self.input_proj(torch.cat([eef_feat, cls], dim=-1))
        else:
            q_in = self.input_proj(cls)
        q_in_bt = q_in.unsqueeze(1).expand(B, T, d).reshape(B * T, d)

        cond_t  = self.t_cond_proj(self.t_sin)
        cond_bt = cond_t.unsqueeze(0).expand(B, T, -1).reshape(B * T, -1)

        h = q_in_bt
        for blk in self.blocks:
            h = blk(h, cond_bt)
        h = self.final_norm(h)
        penult = h.view(B, T, d)

        q_spatial = self.q_head(penult)
        gripper   = self.grip_head(penult)
        if self.rotation_mode == 'per_axis':
            rotation = self.rot_head(penult).view(B, T, 3, self.n_rot_bins)
        else:
            rotation = self.rot_head(penult)

        # Split the query into image-feature, z-sinusoid and t-sinusoid parts.
        q_F = q_spatial[..., :self.d_feat]
        q_z = q_spatial[..., self.d_feat:self.d_feat + self.d_sin_z]
        q_t = q_spatial[..., self.d_feat + self.d_sin_z:]

        score_yx = torch.einsum('btc, bchw -> bthw', q_F, F_feat)
        score_z  = torch.einsum('btc, zc   -> btz',  q_z, self.z_sin)
        score_t  = torch.einsum('btc, tc   -> bt',   q_t, self.t_sin)

        # Separable score: broadcast-sum to (B, T, Z, H, W).
        volume_logits = (
            score_yx.unsqueeze(2)
            + score_z.unsqueeze(-1).unsqueeze(-1)
            + score_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        )

        return {
            "volume_logits":   volume_logits,
            "gripper_logits":  gripper,
            "rotation_logits": rotation,
            "pixel_feats":     F_feat,
        }


if __name__ == "__main__":
    # Smoke test: build the model, report trainable parameters, and run one
    # no-grad forward pass on random inputs, printing each output's shape.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    side = 448
    model = DinoVolumeQueryDA3(
        n_window=8, image_size=side, rotation_mode='1d_pca', fusion_layers=2,
    ).to(device).eval()

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable: {trainable:,}")

    images = torch.rand(2, 3, side, side).to(device)
    start_px = torch.rand(2, 2).to(device) * side
    with torch.no_grad():
        outputs = model(images, start_px)

    for name, tensor in outputs.items():
        if tensor is None:
            continue
        print(f"  {name}: {tuple(tensor.shape)}")
