"""Vanilla DINOv3 + factored KV volume head (v3-style architecture but DINO backbone).

Ablation flags:
  height_enc: 'sin' | 'learned' | 'sin_plus_learned'
  time_enc:   'sin' | 'learned' | 'sin_plus_learned'

Per Cameron 2026-05-19: try the volume projection approach with DINO features and all
combinations of height/time encoding.

Architecture:
  DINOv3 ViT-S+/16 by default (see `dino_variant`) → patch tokens (B, N, D), reshaped to a 2D grid,
  upsampled 2× and refined by a small conv head to (B, 48, h_out, w_out).
  Per-pixel feature F is the "value" stream.
  Key per (t, z): h_emb[z] + t_emb[t] ∈ R^48.
  L2-normalize F and keys → cosine similarity → scaled by learnable logit_scale.
  Volume logits: einsum(F, K) → (B, T, Z, h_out, w_out).
"""
import os, sys, math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Local checkout of the DINOv3 repository; override with $DINO_REPO_DIR.
DINO_REPO_DIR = os.getenv("DINO_REPO_DIR", "/data/cameron/keygrip/dinov3")
# Default backbone checkpoint; override with $DINO_WEIGHTS_PATH.
DINO_WEIGHTS_PATH = os.getenv(
    "DINO_WEIGHTS_PATH",
    "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)

# Backbone patch size in pixels; the token grid side is image_size // DINO_PATCH_SIZE.
DINO_PATCH_SIZE = 16
# Default side length the input image is resized to before the backbone.
IMG_SIZE = 448
# Default number of time steps (T) in the key table.
N_WINDOW = 8
# Default number of height bins (Z) in the key table.
N_HEIGHT_BINS = 32
# Default channel width shared by per-pixel features and keys.
KEY_DIM = 48

# Per-channel statistics used to normalize the RGB input.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def sinusoidal_features(n_values: int, dim: int):
    """Build a fixed sin/cos encoding table for ``n_values`` evenly spaced positions.

    Position ``i`` is mapped to ``i / max(n_values - 1, 1)``, so positions span
    [0, 1] (a single position maps to 0). Frequency band ``k`` (for ``k`` in
    ``range(dim // 2)``) uses the angle ``position * 2**k * pi``.

    Args:
        n_values: Number of positions (rows of the table).
        dim: Feature width; must be even (checked with ``assert``).

    Returns:
        A float32 tensor of shape ``(n_values, dim)``: all sine bands first,
        followed by all cosine bands, along the last axis.
    """
    assert dim % 2 == 0
    n_bands = dim // 2
    denom = max(n_values - 1, 1)  # avoids 0/0 when there is a single position
    positions = torch.arange(n_values, dtype=torch.float32) / denom
    band_freqs = 2.0 ** torch.arange(n_bands, dtype=torch.float32)
    # Outer product (n_values, 1) * (1, n_bands), then scaled by pi.
    phase = positions[:, None] * band_freqs[None, :] * math.pi
    return torch.cat((phase.sin(), phase.cos()), dim=-1)


class DinoVolumeKV(nn.Module):
    """DINOv3 backbone plus a factored (time, height) key head that emits volume logits.

    Pipeline implemented by :meth:`forward`:

    1. The image is resized to ``image_size`` x ``image_size`` if needed, normalized
       with the ImageNet mean/std buffers, and passed through
       ``self.dino.forward_features`` under autocast.
    2. Patch tokens are reshaped to a ``grid`` x ``grid`` map
       (``grid = image_size // DINO_PATCH_SIZE``), bilinearly upsampled to
       ``pred_size = 2 * grid`` per side, and projected to ``key_dim`` channels
       by the ``refine`` conv stack.
    3. Per-pixel features (after LayerNorm) and the ``(T, Z)`` key table are both
       L2-normalized; their dot product times ``exp(logit_scale)`` (capped at 100)
       gives logits of shape ``(B, T, Z, pred_size, pred_size)``.

    Args:
        n_window: Number of time steps ``T`` in the key table.
        n_height_bins: Number of height bins ``Z`` in the key table.
        key_dim: Channel width shared by pixel features and keys. Must be even
            when a ``'sin'`` encoding is used (``sinusoidal_features`` asserts it).
        image_size: Side length the input is resized to before the backbone.
        height_enc: ``'sin'``, ``'learned'`` or ``'sin_plus_learned'``.
        time_enc: ``'sin'``, ``'learned'`` or ``'sin_plus_learned'``.
        head_hidden: Hidden channel width of the ``refine`` conv stack.
        dino_variant: Entry-point name passed to ``torch.hub.load`` on the local
            DINOv3 checkout at ``DINO_REPO_DIR``.
    """

    def __init__(self, n_window: int = N_WINDOW, n_height_bins: int = N_HEIGHT_BINS,
                 key_dim: int = KEY_DIM, image_size: int = IMG_SIZE,
                 height_enc: str = 'sin', time_enc: str = 'sin',
                 head_hidden: int = 192,
                 dino_variant: str = 'dinov3_vits16plus'):
        super().__init__()
        assert height_enc in ('sin', 'learned', 'sin_plus_learned')
        assert time_enc   in ('sin', 'learned', 'sin_plus_learned')
        self.n_window      = n_window
        self.n_height_bins = n_height_bins
        self.key_dim       = key_dim
        self.image_size    = image_size
        self.height_enc    = height_enc
        self.time_enc      = time_enc
        self.dino_variant  = dino_variant
        self.patch_size    = DINO_PATCH_SIZE
        self.grid          = image_size // DINO_PATCH_SIZE   # patch tokens per side
        self.pred_size     = self.grid * 2                   # output side (2x the token grid)

        # torch.hub's local loader imports from the repo directory, so make it importable.
        if DINO_REPO_DIR not in sys.path:
            sys.path.insert(0, DINO_REPO_DIR)
        variant_weights = {
            "dinov3_vits16plus": DINO_WEIGHTS_PATH,
            "dinov3_vitl16":     "/data/cameron/keygrip/dinov3/weights/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth",
        }
        # Variants not listed above fall back to the default (ViT-S+) checkpoint path.
        w = variant_weights.get(dino_variant, DINO_WEIGHTS_PATH)
        self.dino = torch.hub.load(DINO_REPO_DIR, dino_variant,
                                    source="local", weights=w)
        # Fall back to 384 channels if the backbone does not expose ``embed_dim``.
        self.embed_dim = getattr(self.dino, "embed_dim", 384)

        # Pixel feature head: project DINO tokens to key_dim via small refinement
        self.refine = nn.Sequential(
            nn.Conv2d(self.embed_dim, head_hidden, 3, padding=1), nn.GELU(),
            nn.Conv2d(head_hidden,    head_hidden, 3, padding=1), nn.GELU(),
            nn.Conv2d(head_hidden,    key_dim,     1),
        )
        self.pixel_norm = nn.LayerNorm(key_dim)

        # Time + height embeddings.
        # 'sin_plus_learned' matches both substring tests below, so it gets both the
        # learned table and the fixed sinusoidal buffer. The learned table starts
        # smaller (std 0.02) when it is added to a sinusoid than when it stands
        # alone (std 0.1); in both cases it is truncated at +/- 3 std.
        if 'learned' in time_enc:
            self.t_emb_learned = nn.Embedding(n_window, key_dim)
            init_std = 0.02 if time_enc == 'sin_plus_learned' else 0.1
            nn.init.trunc_normal_(self.t_emb_learned.weight, std=init_std, a=-3*init_std, b=3*init_std)
        if 'sin' in time_enc:
            # Non-persistent: recomputed at construction, never stored in state_dict.
            self.register_buffer("t_emb_sin",
                                  sinusoidal_features(n_window, key_dim),
                                  persistent=False)
        if 'learned' in height_enc:
            self.h_emb_learned = nn.Embedding(n_height_bins, key_dim)
            init_std = 0.02 if height_enc == 'sin_plus_learned' else 0.1
            nn.init.trunc_normal_(self.h_emb_learned.weight, std=init_std, a=-3*init_std, b=3*init_std)
        if 'sin' in height_enc:
            self.register_buffer("h_emb_sin",
                                  sinusoidal_features(n_height_bins, key_dim),
                                  persistent=False)

        # Learnable temperature (CLIP-style); exp(2.66) ≈ 14.
        self.logit_scale = nn.Parameter(torch.tensor(2.66))

        # Shaped (1, 3, 1, 1) so they broadcast over a (B, 3, H, W) image batch.
        self.register_buffer("mean", torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("std",  torch.tensor(IMAGENET_STD ).view(1, 3, 1, 1), persistent=False)

    def _normalize(self, rgb01):
        """Apply per-channel ImageNet mean/std normalization.

        The parameter name suggests inputs in [0, 1]; nothing here enforces it.
        """
        return (rgb01 - self.mean) / self.std

    def _build_keys(self):
        """Return the key table of shape ``(n_window, n_height_bins, key_dim)``.

        Each entry is the sum of one time embedding and one height embedding,
        where each embedding is sinusoidal, learned, or their sum depending on
        ``time_enc`` / ``height_enc``.
        """
        # Time component
        if self.time_enc == 'sin':       t_total = self.t_emb_sin
        elif self.time_enc == 'learned': t_total = self.t_emb_learned.weight
        else:                             t_total = self.t_emb_sin + self.t_emb_learned.weight
        # Height component
        if self.height_enc == 'sin':       h_total = self.h_emb_sin
        elif self.height_enc == 'learned': h_total = self.h_emb_learned.weight
        else:                               h_total = self.h_emb_sin + self.h_emb_learned.weight
        # Combine: key(t, z) = t_total[t] + h_total[z]
        return t_total.unsqueeze(1) + h_total.unsqueeze(0)                     # (T, Z, key_dim)

    def forward(self, rgb):
        """Compute volume logits for a batch of images.

        Args:
            rgb: Image batch of shape ``(B, 3, H, W)``. Any spatial size is accepted;
                it is bilinearly resized to ``(image_size, image_size)`` when either
                dimension differs.

        Returns:
            dict with:
              - ``"volume_logits"``: ``(B, T, Z, pred_size, pred_size)`` scaled cosine
                similarities between pixel features and (time, height) keys.
              - ``"pred_depth"``: always ``None``.
              - ``"pixel_feats"``: ``(B, key_dim, pred_size, pred_size)`` output of
                ``refine`` (before LayerNorm / L2 normalization).
              - ``"dino_feats"``: one-element list holding the float32 patch tokens.
        """
        B = rgb.shape[0]
        # Compare BOTH spatial dims: checking only the width lets an input with
        # matching width but different height through, and the reshape to
        # (grid, grid) below would then fail.
        target_hw = (self.image_size, self.image_size)
        if tuple(rgb.shape[-2:]) != target_hw:
            rgb = F.interpolate(rgb, size=target_hw,
                                mode='bilinear', align_corners=False)
        x = self._normalize(rgb)
        # NOTE(review): the dtype is picked from CUDA capability regardless of
        # rgb.device; confirm this is intended for non-CUDA inputs.
        autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        with torch.autocast(device_type=rgb.device.type, dtype=autocast_dtype):
            feats = self.dino.forward_features(x)
        if isinstance(feats, dict):
            patch_tokens = feats.get("x_norm_patchtokens", feats.get("x_prenorm"))
        else:
            patch_tokens = feats
        # Leave reduced precision before the head; everything below runs in float32.
        patch_tokens = patch_tokens.to(torch.float32)
        D = patch_tokens.shape[-1]
        h = w = self.grid
        # (B, N, D) -> (B, D, grid, grid); requires N == grid * grid.
        feat_2d = patch_tokens.permute(0, 2, 1).reshape(B, D, h, w)
        feat_2d = F.interpolate(feat_2d, size=(self.pred_size, self.pred_size),
                                 mode='bilinear', align_corners=False)
        pixel_feats = self.refine(feat_2d)                                    # (B, key_dim, ph, pw)

        # Norm + L2 unit-vectors (LayerNorm acts on the last dim, hence the permutes).
        f_ln = self.pixel_norm(pixel_feats.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        f_unit = f_ln / (f_ln.norm(dim=1, keepdim=True) + 1e-6)
        keys = self._build_keys()                                              # (T, Z, key_dim)
        keys_unit = keys / (keys.norm(dim=-1, keepdim=True) + 1e-6)

        # Cap the temperature so exp(logit_scale) never exceeds 100.
        scale = self.logit_scale.clamp(max=math.log(100.0)).exp()
        volume_logits = torch.einsum("bchw, tzc -> btzhw", f_unit, keys_unit) * scale

        return {
            "volume_logits": volume_logits,
            "pred_depth":    None,
            "pixel_feats":   pixel_feats,
            "dino_feats":    [patch_tokens],
        }


if __name__ == "__main__":
    # Smoke check: build the model for every (height_enc, time_enc) pair and print
    # its trainable parameter count and the mean norm of the unnormalized keys.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    encodings = ('sin', 'learned', 'sin_plus_learned')
    # Height encoding varies slowest, time encoding fastest.
    combos = [(h_kind, t_kind) for h_kind in encodings for t_kind in encodings]
    for h_enc, t_enc in combos:
        model = DinoVolumeKV(height_enc=h_enc, time_enc=t_enc)
        model = model.to(device).eval()
        trainable_sizes = [p.numel() for p in model.parameters() if p.requires_grad]
        n_t = sum(trainable_sizes)
        keys = model._build_keys()
        print(f"h={h_enc:<18} t={t_enc:<18} | params={n_t:>11,}  keys_mean_norm={keys.norm(dim=-1).mean().item():.3f}")
