"""DA3-based volume model with factored KV-attention.

Architecture (Cameron 2026-05-18 spec):
  - DA3 backbone + DPT head; aux head's final 1×1 conv replaced to output KEY_DIM (+1 conf).
  - Per-pixel feature F ∈ R^(B × KEY_DIM × H × W) — "value/query" stream.
  - Learnable height embeddings h_emb ∈ R^(N_HEIGHT_BINS × H_DIM).
  - Learnable time   embeddings t_emb ∈ R^(N_WINDOW × T_DIM).
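  - Key per (t, z), v3 sum-merge (T_DIM == H_DIM == KEY_DIM is enforced in __init__):
      key(t, z) = (t_sin[t] + t_emb[t]) + (h_sin[z] + h_emb[z]) ∈ R^KEY_DIM,
    where t_sin / h_sin are fixed sinusoidal encodings. (The original spec used
    key(t, z) = concat(t_emb[t], h_emb[z]) with T_DIM + H_DIM = KEY_DIM.)
  - Volume logits via cosine scoring with a learnable temperature:
      l(b, t, z, u, v) = exp(logit_scale) · cos(LayerNorm(F(b, :, u, v)), key(t, z)).
    (The original spec used the raw dot product scaled by 1 / sqrt(KEY_DIM).)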

We replace the original libero (B, N_WINDOW * N_HEIGHT_BINS, H, W) dense head with the
factored KV decomposition: parameter-efficient, structural inductive bias that height/time
are categorical attributes that share representation across spatial locations.

forward(rgb) returns:
  volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, h_out, w_out)
  pred_depth:    (B, H, W)  — for distillation against frozen DA3 depth
  dino_feats:    list of intermediate DINO features (for PCA viz)
  pixel_feats:   (B, KEY_DIM, h_out, w_out) — for debugging / viz
"""
import sys, types, os, math
def _install_stub_module(module_name, **attrs):
    """Register a placeholder module under *module_name* in ``sys.modules``.

    Any existing ``sys.modules`` entry with that name is replaced. Each keyword
    argument becomes an attribute of the placeholder module.
    """
    stub = types.ModuleType(module_name)
    for attr_name, value in attrs.items():
        setattr(stub, attr_name, value)
    sys.modules[module_name] = stub


def _noop(*args, **kwargs):
    """Accept any arguments and return None (stand-in for stubbed DA3 helpers)."""
    return None


# These stubs must be installed before `depth_anything_3.api` is imported below,
# so that any import of the two utility modules resolves to the placeholders.
# NOTE(review): presumably this avoids pulling in the real modules' optional
# dependencies — confirm against the depth_anything_3 package.
_install_stub_module('depth_anything_3.utils.export', export=_noop)
_install_stub_module(
    'depth_anything_3.utils.pose_align',
    align_poses_umeyama=_noop,
    batch_align_poses_umeyama=_noop,
)

import torch
import torch.nn as nn

from depth_anything_3.api import DepthAnything3

# Default checkpoint location handed to DepthAnything3.from_pretrained.
DA3_WEIGHTS_DEFAULT = "/data/cameron/da3_large_weights"
DA3_INPUT     = 504   # square input side length (pixels) used by the __main__ smoke test
N_WINDOW      = 8     # number of time steps T (rows of t_emb / t_sin)
N_HEIGHT_BINS = 32    # number of height bins Z (rows of h_emb / h_sin)
KEY_DIM       = 48    # channel width shared by pixel features and (t, z) keys
# v3: sinusoidal positional encoding ADDED to learnable embeddings + L2-normalize both
# F and keys before bilinear (cosine similarity), with a learnable temperature.
# Motivation:
#  - Ordinality prior: sin/cos features give adjacent t's / z's similar keys for free.
#  - Magnitude race: L2-norm prevents F (which dominates the raw dot product) from
#    drowning out the t/z signal. Per-t differentiation comes "for cheap".
#  - Temperature: cosine ∈ [-1,1] is too flat for CE over 2.6M classes; learnable
#    logit_scale (CLIP-style) inits at log(1/0.07) ≈ 2.66 → scale ≈ 14.
# Since v2 the time and height embeddings are summed (not concatenated), so both
# widths must equal KEY_DIM; derive them from it so the three cannot drift apart.
TIME_DIM      = KEY_DIM
HEIGHT_DIM    = KEY_DIM


class DA3VolumeModel(nn.Module):
    """DA3 backbone + head whose aux branch is scored against factored (t, z) keys.

    The final conv of the head's aux branch is replaced so that it emits
    ``key_dim + 1`` channels. Each pixel's ``key_dim`` feature vector is
    LayerNorm-ed, L2-normalized and compared (cosine similarity, scaled by a
    learnable temperature) with one key per (time step, height bin). Each key
    is the sum of a fixed sinusoidal encoding and a small learned embedding for
    the time step and for the height bin.
    """

    def __init__(self, weights_path: str = DA3_WEIGHTS_DEFAULT,
                 n_window: int = N_WINDOW, n_height_bins: int = N_HEIGHT_BINS,
                 key_dim: int = KEY_DIM, time_dim: int = TIME_DIM,
                 height_dim: int = HEIGHT_DIM, dino_feat_layers=None):
        """Load the pretrained DA3 model and attach the key-scoring head.

        Args:
            weights_path: Location passed to ``DepthAnything3.from_pretrained``.
            n_window: Number of time steps T (rows of the time embedding).
            n_height_bins: Number of height bins Z (rows of the height embedding).
            key_dim: Channel width of pixel features and keys; must be even.
            time_dim: Width of the time embedding; must equal ``key_dim``.
            height_dim: Width of the height embedding; must equal ``key_dim``.
            dino_feat_layers: Layer indices forwarded to the backbone as
                ``export_feat_layers``. Defaults to the backbone's ``out_layers``
                attribute, or ``[5, 7, 9, 11]`` when that attribute is absent.

        Raises:
            ValueError: If ``time_dim`` or ``height_dim`` differs from ``key_dim``,
                or if ``key_dim`` is odd.
        """
        super().__init__()
        # v2: time_dim == height_dim == key_dim (sum-merge instead of concat).
        # Raised explicitly (not asserted) so the check survives `python -O`.
        if time_dim != key_dim or height_dim != key_dim:
            raise ValueError(
                "time_dim and height_dim must both equal key_dim "
                f"(got time_dim={time_dim}, height_dim={height_dim}, key_dim={key_dim})"
            )
        self.n_window      = n_window
        self.n_height_bins = n_height_bins
        self.key_dim       = key_dim
        self.time_dim      = time_dim
        self.height_dim    = height_dim

        # Keep only the backbone and head of the pretrained wrapper.
        full = DepthAnything3.from_pretrained(weights_path)
        self.backbone = full.model.backbone
        self.head     = full.model.head
        del full

        if dino_feat_layers is None:
            self.dino_feat_layers = list(getattr(self.backbone, 'out_layers', [5, 7, 9, 11]))
        else:
            self.dino_feat_layers = list(dino_feat_layers)

        # Swap final aux conv to output (key_dim + 1) channels — split into KEY_DIM features
        # and 1 conf channel (DualDPT splits aux into [..., :-1] + [..., -1]).
        last_aux_seq = self.head.scratch.output_conv2_aux[-1]
        old_conv = last_aux_seq[-1]
        new_conv = nn.Conv2d(old_conv.in_channels, key_dim + 1, kernel_size=1, stride=1, padding=0)
        nn.init.zeros_(new_conv.bias)
        # Small but non-zero init so the bilinear scores are not stuck at 0 at start.
        nn.init.normal_(new_conv.weight, std=0.05)
        last_aux_seq[-1] = new_conv

        # Learnable embeddings — small init so sinusoidal dominates initially, learned
        # residual fine-tunes per-(t,z).
        self.h_emb = nn.Embedding(n_height_bins, height_dim)
        self.t_emb = nn.Embedding(n_window, time_dim)
        nn.init.trunc_normal_(self.h_emb.weight, std=0.02, a=-0.05, b=0.05)
        nn.init.trunc_normal_(self.t_emb.weight, std=0.02, a=-0.05, b=0.05)

        # Sinusoidal positional features: fixed, non-persistent buffers, so they
        # follow .to(device) but are not stored in the state dict.
        self.register_buffer("t_sin", self._sinusoidal_features(n_window, key_dim),
                              persistent=False)
        self.register_buffer("h_sin", self._sinusoidal_features(n_height_bins, key_dim),
                              persistent=False)

        # LayerNorm on pixel features (kept from v2 — numeric stability).
        self.pixel_norm = nn.LayerNorm(key_dim)
        # Learnable logit scale (CLIP-style). exp(2.66) ≈ 14.
        self.logit_scale = nn.Parameter(torch.tensor(2.66))

    @staticmethod
    def _sinusoidal_features(n_values, dim):
        """Build a fixed sin/cos positional encoding table.

        Positions ``0..n_values-1`` are normalized to ``[0, 1]`` and multiplied
        by ``pi * 2**k`` for ``k = 0..dim/2 - 1``.

        Args:
            n_values: Number of positions (rows).
            dim: Output width; must be even.

        Returns:
            Float32 tensor of shape ``(n_values, dim)``: the first ``dim/2``
            columns are sines, the remaining ``dim/2`` columns cosines.

        Raises:
            ValueError: If ``dim`` is odd.
        """
        if dim % 2 != 0:
            raise ValueError(f"dim must be even for sin/cos pairs, got {dim}")
        n_freqs = dim // 2
        # max(..., 1) guards the single-position case against division by zero.
        pos = torch.arange(n_values, dtype=torch.float32) / max(n_values - 1, 1)  # (n,)
        # Log-spaced frequencies; base of 2^k for k=0..L-1, scaled by π so cos(πx) covers [0,1].
        # NOTE(review): with dim=48 the top frequency is 2^23·π; float32 angle
        # resolution is coarse at that magnitude, so the highest bands may not vary
        # smoothly with position — confirm this is intended.
        freqs = 2.0 ** torch.arange(n_freqs, dtype=torch.float32)                 # (L,)
        angles = pos.unsqueeze(1) * freqs.unsqueeze(0) * math.pi                  # (n, L)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)          # (n, 2L=dim)

    @staticmethod
    def _autocast_dtype(device_type):
        """Pick the reduced-precision dtype for autocast on *device_type*.

        CPU always gets bfloat16, so the CUDA capability query is never used to
        decide a CPU dtype. Every other device type keeps the CUDA-based choice:
        bfloat16 when ``torch.cuda.is_bf16_supported()`` is true, else float16.
        """
        if device_type == "cpu":
            return torch.bfloat16
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    def _build_keys(self):
        """v3: sinusoidal (fixed) + small learned embedding, SUMMED across t and z.

        key(t, z) = (t_sin[t] + t_emb[t]) + (h_sin[z] + h_emb[z])  ∈ R^key_dim.

        Returns:
            Tensor of shape ``(n_window, n_height_bins, key_dim)``.
        """
        t_total = (self.t_sin + self.t_emb.weight).unsqueeze(1)             # (T, 1, key_dim)
        h_total = (self.h_sin + self.h_emb.weight).unsqueeze(0)             # (1, Z, key_dim)
        return t_total + h_total                                             # (T, Z, key_dim)

    def forward(self, rgb):
        """Run backbone + head and score pixel features against all (t, z) keys.

        Args:
            rgb: (B, 3, 504, 504) in [0, 1].

        Returns:
            Dict with:
              ``volume_logits``: (B, T, Z, h_out, w_out) scaled cosine similarities.
              ``pred_depth``: first view of the head's depth output.
              ``pixel_feats``: (B, key_dim, h_out, w_out) raw aux features, taken
                before LayerNorm / L2 normalization (for viz/debug).
              ``dino_feats``: list with one entry per element of the backbone's
                feature output (the first item when that element is a list/tuple).

        Raises:
            RuntimeError: If the head's ``depth`` or ``ray`` output is neither
                4- nor 5-dimensional.
        """
        x = rgb.unsqueeze(1)  # insert a singleton axis after batch (single view)
        device_type = rgb.device.type
        with torch.autocast(device_type=device_type,
                            dtype=self._autocast_dtype(device_type)):
            feats, _ = self.backbone(
                x, cam_token=None,
                export_feat_layers=self.dino_feat_layers,
                ref_view_strategy="saddle_balanced",
            )
        H, W = x.shape[-2], x.shape[-1]
        # The head runs with autocast explicitly disabled.
        with torch.autocast(device_type=device_type, enabled=False):
            head_out = self.head(feats, H, W, patch_start_idx=0)

        # depth (B, S, H, W) → drop S
        depth = head_out['depth']
        if depth.dim() == 5:
            pred_depth = depth[:, 0, 0]
        elif depth.dim() == 4:
            pred_depth = depth[:, 0]
        else:
            raise RuntimeError(f"unexpected depth shape {tuple(depth.shape)}")

        # ray/aux: (B, S, H, W, KEY_DIM) — channel dim already permuted to last
        ray = head_out['ray']
        if ray.dim() == 5:
            pixel_feats = ray[:, 0].permute(0, 3, 1, 2)   # (B, KEY_DIM, h_out, w_out)
        elif ray.dim() == 4:
            pixel_feats = ray.permute(0, 3, 1, 2)
        else:
            raise RuntimeError(f"unexpected ray shape {tuple(ray.shape)}")

        # LayerNorm over the channel dim (moved last for nn.LayerNorm, then restored).
        f_ln = self.pixel_norm(pixel_feats.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # (B,C,H,W)
        # v3: L2-normalize F per pixel and keys per (t,z), then dot product = cosine.
        f_unit = f_ln / (f_ln.norm(dim=1, keepdim=True) + 1e-6)               # (B, C, H, W)

        keys = self._build_keys()                                              # (T, Z, key_dim)
        keys_unit = keys / (keys.norm(dim=-1, keepdim=True) + 1e-6)            # (T, Z, key_dim)

        # logits[b, t, z, h, w] = exp(logit_scale) * sum_c f_unit · keys_unit
        # Clamp logit_scale to ≤ log(100) — CLIP convention to avoid runaway.
        scale = self.logit_scale.clamp(max=math.log(100.0)).exp()
        volume_logits = torch.einsum("bchw, tzc -> btzhw", f_unit, keys_unit) * scale

        # Each backbone feature entry may be a (feature, ...) list/tuple; keep the first item.
        dino_feats = [
            layer_feats[0] if isinstance(layer_feats, (list, tuple)) else layer_feats
            for layer_feats in feats
        ]

        return {
            "volume_logits": volume_logits,    # (B, T, Z, H_out, W_out)
            "pred_depth":    pred_depth,
            "pixel_feats":   pixel_feats,      # for viz/debug
            "dino_feats":    dino_feats,
        }


if __name__ == "__main__":
    # Smoke test: build the model with default arguments, push one random batch
    # through it without gradients, and print a summary of each output.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = DA3VolumeModel().to(device).eval()

    # Count the elements of every parameter that still requires grad.
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    print(f"Trainable: {trainable:,}")

    # Batch of 2 random images, sampled on CPU and then moved to the device.
    batch = torch.rand(2, 3, DA3_INPUT, DA3_INPUT).to(device)
    with torch.no_grad():
        outputs = model(batch)

    for key, value in outputs.items():
        if hasattr(value, 'shape'):
            print(f"  {key}: {tuple(value.shape)}")
            continue
        if not isinstance(value, list):
            continue
        # List outputs: report the length and, when available, the first item's shape.
        print(f"  {key}: list({len(value)})")
        if value and hasattr(value[0], 'shape'):
            print(f"    first: {tuple(value[0].shape)}")

    if use_cuda:
        peak_gb = torch.cuda.max_memory_allocated() / 1e9
        print(f"peak: {peak_gb:.2f} GB")
