"""DA3-based volume model with factored KV-attention.

Architecture (Cameron 2026-05-18 spec):
  - DA3 backbone + DPT head; aux head's final 1×1 conv replaced to output KEY_DIM (+1 conf).
  - Per-pixel feature F ∈ R^(B × KEY_DIM × H × W) — "value/query" stream.
  - Learnable height embeddings h_emb ∈ R^(N_HEIGHT_BINS × H_DIM).
  - Learnable time   embeddings t_emb ∈ R^(N_WINDOW × T_DIM).
  - Key per (t, z): key(t, z) = concat(t_emb[t], h_emb[z]) ∈ R^(T_DIM + H_DIM = KEY_DIM).
  - Volume logits via bilinear scoring on LayerNorm-ed pixel features:
    l(b, t, z, u, v) = LN(F)(b, :, u, v) · key(t, z) / sqrt(KEY_DIM), with LN taken over the channel dim.

We replace the original libero dense head, which predicted (B, N_WINDOW * N_HEIGHT_BINS, H, W)
logits directly, with this factored key/value decomposition. It is parameter-efficient and adds a
structural inductive bias: height and time are categorical attributes whose embeddings are shared
across all spatial locations.

forward(rgb) returns:
  volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, h_out, w_out)
  pred_depth:    (B, H, W)  — for distillation against frozen DA3 depth
  dino_feats:    list of intermediate DINO features (for PCA viz)
  pixel_feats:   (B, KEY_DIM, h_out, w_out) — raw aux features (before the LayerNorm), for debugging / viz
"""
import sys, types, os, math
def _install_da3_import_stubs():
    """Register no-op stand-ins for two DA3 utility modules before the DA3 API is imported.

    Each stub is a bare module whose listed functions accept any arguments and return None.
    Presumably this lets ``depth_anything_3.api`` import without those modules' own
    dependencies (TODO confirm). A module that is already loaded in ``sys.modules`` is left
    untouched, so a real, working import elsewhere in the process is never silently
    replaced by a no-op. A ``None`` entry (an import blocker) is treated as "not loaded"
    and replaced, matching the previous unconditional behaviour for that case.
    """
    stubs = {
        'depth_anything_3.utils.export': {
            'export': lambda *a, **k: None,
        },
        'depth_anything_3.utils.pose_align': {
            'align_poses_umeyama': lambda *a, **k: None,
            'batch_align_poses_umeyama': lambda *a, **k: None,
        },
    }
    for mod_name, attrs in stubs.items():
        if sys.modules.get(mod_name) is not None:
            # Real (or already-stubbed) module present: don't clobber it for other importers.
            continue
        stub = types.ModuleType(mod_name)
        for attr_name, fn in attrs.items():
            setattr(stub, attr_name, fn)
        sys.modules[mod_name] = stub


_install_da3_import_stubs()

import torch
import torch.nn as nn

from depth_anything_3.api import DepthAnything3

# Default DA3 checkpoint location passed to DepthAnything3.from_pretrained.
DA3_WEIGHTS_DEFAULT = "/data/cameron/da3_large_weights"
# Square input side used for the smoke test and documented in DA3VolumeModel.forward.
DA3_INPUT = 504

# Volume layout: one logit per (time step, height bin, pixel).
N_WINDOW = 8          # number of time embeddings (T axis)
N_HEIGHT_BINS = 32    # number of height embeddings (Z axis)

# Keys are concat(time_emb, height_emb), so the key width is the sum of the two parts.
TIME_DIM = 24
HEIGHT_DIM = 24
KEY_DIM = TIME_DIM + HEIGHT_DIM   # 48
# Guards against KEY_DIM being re-hardcoded out of sync with its parts.
assert TIME_DIM + HEIGHT_DIM == KEY_DIM, "TIME_DIM + HEIGHT_DIM must equal KEY_DIM"


class DA3VolumeModel(nn.Module):
    """DA3 backbone + DPT head producing factored (time, height) volume logits.

    The head's final aux 1x1 conv is replaced so the aux branch emits ``key_dim``
    per-pixel feature channels plus one confidence channel. Each pixel feature is
    LayerNorm-ed over channels and scored against keys ``concat(t_emb[t], h_emb[z])``
    with a dot product scaled by ``1/sqrt(key_dim)``.

    Args:
        weights_path: Checkpoint location passed to ``DepthAnything3.from_pretrained``.
        n_window: Number of learnable time embeddings (T axis of the volume).
        n_height_bins: Number of learnable height embeddings (Z axis of the volume).
        key_dim: Width of the per-pixel features and of each key.
        time_dim: Width of each time embedding.
        height_dim: Width of each height embedding.
        dino_feat_layers: Layer indices passed to the backbone as
            ``export_feat_layers``. Defaults to ``backbone.out_layers`` when the
            backbone has that attribute, else ``[5, 7, 9, 11]``.

    Raises:
        ValueError: If ``time_dim + height_dim != key_dim``.
    """

    def __init__(self, weights_path: str = DA3_WEIGHTS_DEFAULT,
                 n_window: int = N_WINDOW, n_height_bins: int = N_HEIGHT_BINS,
                 key_dim: int = KEY_DIM, time_dim: int = TIME_DIM,
                 height_dim: int = HEIGHT_DIM, dino_feat_layers=None):
        super().__init__()
        # Keys are concat(t_emb, h_emb), so their width must equal the pixel-feature width.
        # Raise (not assert) so the check survives `python -O`.
        if time_dim + height_dim != key_dim:
            raise ValueError(
                f"time_dim + height_dim must equal key_dim "
                f"(got {time_dim} + {height_dim} != {key_dim})"
            )
        self.n_window      = n_window
        self.n_height_bins = n_height_bins
        self.key_dim       = key_dim
        self.time_dim      = time_dim
        self.height_dim    = height_dim

        # Load the pretrained DA3 wrapper, keep only its backbone and head.
        full = DepthAnything3.from_pretrained(weights_path)
        self.backbone = full.model.backbone
        self.head     = full.model.head
        del full

        if dino_feat_layers is None:
            self.dino_feat_layers = list(getattr(self.backbone, 'out_layers', [5, 7, 9, 11]))
        else:
            self.dino_feat_layers = list(dino_feat_layers)

        # Swap final aux conv to output (key_dim + 1) channels — split into KEY_DIM features
        # and 1 conf channel (DualDPT splits aux into [..., :-1] + [..., -1]).
        last_aux_seq = self.head.scratch.output_conv2_aux[-1]
        old_conv = last_aux_seq[-1]
        new_conv = nn.Conv2d(old_conv.in_channels, key_dim + 1, kernel_size=1, stride=1, padding=0)
        nn.init.zeros_(new_conv.bias)
        # Small but non-zero init so the bilinear scores are not stuck at 0 at start.
        nn.init.normal_(new_conv.weight, std=0.05)
        last_aux_seq[-1] = new_conv

        # Learnable embeddings. Truncated normal (std=0.1, clipped to ±0.2 = ±2σ) gives each
        # concatenated key an L2 norm of roughly 0.6 at start for the default 48-dim keys.
        self.h_emb = nn.Embedding(n_height_bins, height_dim)
        self.t_emb = nn.Embedding(n_window, time_dim)
        nn.init.trunc_normal_(self.h_emb.weight, std=0.1, a=-0.2, b=0.2)
        nn.init.trunc_normal_(self.t_emb.weight, std=0.1, a=-0.2, b=0.2)

        # LayerNorm over the channel dim of the pixel features, always applied before
        # scoring; intended to keep the logits numerically stable.
        self.pixel_norm = nn.LayerNorm(key_dim)

    def _build_keys(self):
        """Return keys of shape (n_window, n_height_bins, key_dim).

        key(t, z) = concat(t_emb[t], h_emb[z]). The two embedding tables are broadcast
        with ``expand`` (views, no copy) and only materialized by the final ``cat``.
        """
        T = self.n_window
        Z = self.n_height_bins
        t_e = self.t_emb.weight.unsqueeze(1).expand(T, Z, self.time_dim)   # (T, Z, t_dim)
        h_e = self.h_emb.weight.unsqueeze(0).expand(T, Z, self.height_dim) # (T, Z, h_dim)
        return torch.cat([t_e, h_e], dim=-1)                                # (T, Z, key_dim)

    def forward(self, rgb: torch.Tensor) -> dict:
        """Run backbone + head, then score pixel features against every (t, z) key.

        Args:
            rgb: (B, 3, 504, 504) in [0, 1].
                NOTE(review): no mean/std normalization is applied here before the
                backbone — confirm whether the caller or the backbone handles it.

        Returns:
            dict with
              ``volume_logits``: (B, n_window, n_height_bins, h_out, w_out)
              ``pred_depth``:    depth of the single view (view axis dropped)
              ``pixel_feats``:   (B, key_dim, h_out, w_out) aux features before LayerNorm
              ``dino_feats``:    list with one tensor per element of the backbone's ``feats``

        Raises:
            RuntimeError: If the head's ``depth`` or ``ray`` output has an unexpected rank.
        """
        # Add a single-view axis: (B, 3, H, W) -> (B, 1, 3, H, W).
        x = rgb.unsqueeze(1)
        device_type = rgb.device.type
        if device_type == "cpu":
            # CPU autocast uses bf16; don't consult CUDA bf16 support for a CPU input
            # (that query targets the CUDA device and would otherwise select fp16).
            autocast_dtype = torch.bfloat16
        else:
            autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        with torch.autocast(device_type=device_type, dtype=autocast_dtype):
            feats, _aux = self.backbone(
                x, cam_token=None,
                export_feat_layers=self.dino_feat_layers,
                ref_view_strategy="saddle_balanced",
            )
        H, W = x.shape[-2], x.shape[-1]
        # The head runs with autocast disabled.
        with torch.autocast(device_type=device_type, enabled=False):
            head_out = self.head(feats, H, W, patch_start_idx=0)

        # depth: keep view 0. The 5-D case presumably carries an extra singleton
        # channel axis, e.g. (B, S, 1, H, W) — TODO confirm.
        depth = head_out['depth']
        if depth.dim() == 5:
            pred_depth = depth[:, 0, 0]
        elif depth.dim() == 4:
            pred_depth = depth[:, 0]
        else:
            raise RuntimeError(f"unexpected depth shape {tuple(depth.shape)}")

        # ray/aux: (B, S, H, W, KEY_DIM) — channel dim already permuted to last
        ray = head_out['ray']
        if ray.dim() == 5:
            pixel_feats = ray[:, 0].permute(0, 3, 1, 2)   # (B, KEY_DIM, h_out, w_out)
        elif ray.dim() == 4:
            pixel_feats = ray.permute(0, 3, 1, 2)
        else:
            raise RuntimeError(f"unexpected ray shape {tuple(ray.shape)}")

        # LayerNorm over the channel dim (move channels last, normalize, move back).
        f_norm = self.pixel_norm(pixel_feats.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # (B,C,H,W)

        keys = self._build_keys()                                          # (T, Z, key_dim)
        scale = 1.0 / math.sqrt(self.key_dim)
        # logits[b, t, z, h, w] = sum_c f_norm[b, c, h, w] * keys[t, z, c] * scale
        volume_logits = torch.einsum("bchw, tzc -> btzhw", f_norm, keys) * scale

        # NOTE(review): dino_feats is built from `feats` (the head inputs), while the
        # backbone's second return value (`_aux`, discarded above) is presumably what
        # `export_feat_layers` controls. If dino_feat_layers differs from the backbone's
        # out_layers, confirm these are the intended layers.
        dino_feats = [
            layer_feats[0] if isinstance(layer_feats, (list, tuple)) else layer_feats
            for layer_feats in feats
        ]

        return {
            "volume_logits": volume_logits,    # (B, T, Z, H_out, W_out)
            "pred_depth":    pred_depth,
            "pixel_feats":   pixel_feats,      # for viz/debug
            "dino_feats":    dino_feats,
        }


def _smoke_test():
    """Build the model, push one random batch through it, and print output shapes.

    Uses CUDA when available, otherwise CPU; on CUDA also reports peak allocated memory.
    """
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DA3VolumeModel().to(dev).eval()

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable: {trainable:,}")

    batch = torch.rand(2, 3, DA3_INPUT, DA3_INPUT).to(dev)
    with torch.no_grad():
        outputs = model(batch)

    for name, value in outputs.items():
        if hasattr(value, 'shape'):
            print(f"  {name}: {tuple(value.shape)}")
            continue
        if not isinstance(value, list):
            continue
        print(f"  {name}: list({len(value)})")
        if value and hasattr(value[0], 'shape'):
            print(f"    first: {tuple(value[0].shape)}")

    if dev.type == 'cuda':
        print(f"peak: {torch.cuda.max_memory_allocated()/1e9:.2f} GB")


if __name__ == "__main__":
    _smoke_test()
