"""Vanilla DINOv3 + small conv head — the simpler baseline.

Per Cameron 2026-05-18: go back to the DINOv3 prototype, since the DA3 path may be
over-engineered. Same volume formulation (T × Z × H × W joint logits) and CE loss;
just swap the heavy DA3 backbone for vanilla DINOv3 (default variant: ViT-S/16+).
The head is two 3×3 conv + GELU refinement layers followed by a 1×1 projection.

Inputs:  rgb (B, 3, IMG, IMG) in [0, 1] — we ImageNet-normalize internally
Outputs:
  volume_logits: (B, N_WINDOW, N_HEIGHT_BINS, h_out, w_out), h_out = w_out = 2 * (IMG // 16)
  pred_depth:    None (no depth head — keeps things simple)
  pixel_feats:   refined head features (B, head_hidden, h_out, w_out)
  dino_feats:    single-element list with the last-layer patch tokens (B, N, D), for PCA viz
"""
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F

# Local DINOv3 checkout (loaded through torch.hub with source="local") and the
# default ViT-S/16+ checkpoint. Both paths can be overridden via env variables.
DINO_REPO_DIR = os.environ.get(
    "DINO_REPO_DIR",
    "/data/cameron/keygrip/dinov3",
)
DINO_WEIGHTS_PATH = os.environ.get(
    "DINO_WEIGHTS_PATH",
    "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
)

# Geometry: the ViT sees IMG_SIZE x IMG_SIZE inputs split into DINO_PATCH_SIZE
# patches, so IMG_SIZE should be a multiple of the patch size (448 -> 28 x 28).
DINO_PATCH_SIZE = 16
IMG_SIZE = 448

# Output volume sizes: T (window) and Z (height-bin) axes of the logit volume.
N_WINDOW = 8
N_HEIGHT_BINS = 32

# Per-channel ImageNet statistics used to normalize [0, 1] RGB input.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class DinoVanillaModel(nn.Module):
    """DINOv3 ViT backbone + small conv head producing a (T, Z, H, W) logit volume.

    The backbone's last-layer patch tokens are reshaped to a (grid, grid) map,
    bilinearly upsampled 2x, refined by two 3x3 conv + GELU layers, and projected
    by a 1x1 conv to ``n_window * n_height_bins`` channels.

    Args:
        n_window: size of the T axis of the output volume.
        n_height_bins: size of the Z axis of the output volume.
        image_size: square side the input is resized to before the ViT; should
            be a multiple of the 16-pixel patch size.
        head_hidden: channel width of the refinement convs.
        freeze_backbone: if True, backbone parameters get ``requires_grad=False``.
        dino_variant: torch.hub entry-point name in the local DINOv3 repo.
    """

    def __init__(self, n_window: int = N_WINDOW, n_height_bins: int = N_HEIGHT_BINS,
                 image_size: int = IMG_SIZE, head_hidden: int = 192,
                 freeze_backbone: bool = False,
                 dino_variant: str = "dinov3_vits16plus"):
        super().__init__()
        self.n_window      = n_window
        self.n_height_bins = n_height_bins
        self.image_size    = image_size
        self.patch_size    = DINO_PATCH_SIZE
        self.grid          = image_size // DINO_PATCH_SIZE
        self.head_hidden   = head_hidden
        self.dino_variant  = dino_variant

        # torch.hub's "local" source imports hubconf from the repo dir.
        if DINO_REPO_DIR not in sys.path:
            sys.path.insert(0, DINO_REPO_DIR)
        # Resolve weights path per-variant. Larger variants need the converted .pth.
        # NOTE(review): unknown variants fall back to the ViT-S/16+ checkpoint,
        # which presumably fails to load for a different architecture.
        variant_weights = {
            "dinov3_vits16plus": DINO_WEIGHTS_PATH,
            "dinov3_vitl16":     "/data/cameron/keygrip/dinov3/weights/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth",
        }
        w = variant_weights.get(dino_variant, DINO_WEIGHTS_PATH)
        self.dino = torch.hub.load(DINO_REPO_DIR, dino_variant,
                                    source="local", weights=w)
        if freeze_backbone:
            for p in self.dino.parameters():
                p.requires_grad_(False)
        # ViT-S/16 plus: embed_dim = 384
        self.embed_dim = getattr(self.dino, "embed_dim", 384)

        # Head: a small refinement conv stack + 1x1 producing (T*Z) channels.
        # Features are bilinear-upsampled to (pred_size, pred_size) *before* the
        # refinement convs. Pred grid = grid * 2 (28 -> 56 at image_size=448).
        self.pred_size = self.grid * 2
        self.refine = nn.Sequential(
            nn.Conv2d(self.embed_dim, head_hidden, 3, padding=1), nn.GELU(),
            nn.Conv2d(head_hidden,    head_hidden, 3, padding=1), nn.GELU(),
        )
        self.volume_head = nn.Conv2d(head_hidden, n_window * n_height_bins, 1)
        # Small init so the initial volume logits are close to zero.
        nn.init.zeros_(self.volume_head.bias)
        nn.init.normal_(self.volume_head.weight, std=0.01)

        # Non-persistent: not saved in / expected from state_dict.
        self.register_buffer("mean", torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("std",  torch.tensor(IMAGENET_STD ).view(1, 3, 1, 1), persistent=False)

    def _normalize(self, rgb01):
        """ImageNet-normalize an RGB tensor with values in [0, 1]."""
        return (rgb01 - self.mean) / self.std

    @staticmethod
    def _autocast_dtype(device: torch.device) -> torch.dtype:
        """Pick the reduced-precision dtype for running the backbone on ``device``."""
        if device.type == "cuda":
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        if device.type == "cpu":
            # Older PyTorch CPU autocast only supports bfloat16 (it disables
            # itself with a warning for float16), so never request fp16 here.
            return torch.bfloat16
        return torch.float16

    @staticmethod
    def _select_patch_tokens(feats, n_patches: int) -> torch.Tensor:
        """Return the (B, n_patches, D) patch tokens from a ``forward_features`` output.

        Raises:
            KeyError: if ``feats`` is a dict with neither 'x_norm_patchtokens'
                nor 'x_prenorm'.
        """
        if isinstance(feats, dict):
            tokens = feats.get("x_norm_patchtokens")
            if tokens is None:
                tokens = feats.get("x_prenorm")
            if tokens is None:
                raise KeyError(
                    "backbone forward_features() returned neither 'x_norm_patchtokens' "
                    f"nor 'x_prenorm'; got keys {sorted(feats)}")
        else:
            tokens = feats
        # The fallbacks (x_prenorm / a raw tensor) also carry the CLS and
        # storage/register tokens ahead of the patch tokens in the DINOv3
        # reference implementation; keep only the trailing patch tokens so the
        # reshape to (h, w) lines up. A no-op for x_norm_patchtokens.
        return tokens[:, -n_patches:]

    def forward(self, rgb):
        """Predict volume logits.

        Args:
            rgb: (B, 3, H, W) in [0, 1]; resized to (image_size, image_size)
                when either spatial dim differs.

        Returns:
            dict with 'volume_logits' (B, n_window, n_height_bins, pred_size,
            pred_size), 'pred_depth' (None), 'pixel_feats' (B, head_hidden,
            pred_size, pred_size) and 'dino_feats' ([float32 patch tokens (B, N, D)]).
        """
        B = rgb.shape[0]
        target = (self.image_size, self.image_size)
        # Compare both spatial dims: a height-only mismatch would otherwise skip
        # the resize and break the (grid, grid) reshape below.
        if tuple(rgb.shape[-2:]) != target:
            rgb = F.interpolate(rgb, size=target, mode='bilinear', align_corners=False)
        x = self._normalize(rgb)
        h = w = self.grid

        # ViT forward in reduced precision — get last-layer patch tokens.
        with torch.autocast(device_type=rgb.device.type,
                            dtype=self._autocast_dtype(rgb.device)):
            feats = self.dino.forward_features(x)
        patch_tokens = self._select_patch_tokens(feats, h * w).to(torch.float32)

        # (B, N, D) -> (B, D, h, w)
        D = patch_tokens.shape[-1]
        f = patch_tokens.permute(0, 2, 1).reshape(B, D, h, w)
        # Upsample to pred_size, then refine.
        f = F.interpolate(f, size=(self.pred_size, self.pred_size),
                          mode='bilinear', align_corners=False)
        f = self.refine(f)
        vol_flat = self.volume_head(f)                                      # (B, T*Z, pred, pred)
        vol = vol_flat.view(B, self.n_window, self.n_height_bins,
                            self.pred_size, self.pred_size)
        return {
            "volume_logits": vol,
            "pred_depth":    None,
            "pixel_feats":   f,
            "dino_feats":    [patch_tokens],     # single layer for viz
        }


if __name__ == "__main__":
    # Smoke test: build the model, run one random batch, report output shapes.
    run_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = DinoVanillaModel().to(run_device).eval()
    n_t = sum(param.numel() for param in model.parameters() if param.requires_grad)
    print(f"Trainable: {n_t:,}")

    dummy_rgb = torch.rand(2, 3, IMG_SIZE, IMG_SIZE).to(run_device)
    with torch.no_grad():
        outputs = model(dummy_rgb)

    for k, v in outputs.items():
        if hasattr(v, 'shape'):
            print(f"  {k}: {tuple(v.shape)}")
            continue
        if not isinstance(v, (list, tuple)):
            continue  # e.g. pred_depth=None — nothing to report
        print(f"  {k}: list({len(v)})")
        if v and hasattr(v[0], 'shape'):
            print(f"    first: {tuple(v[0].shape)}")

    if run_device.type == 'cuda':
        print(f"peak: {torch.cuda.max_memory_allocated()/1e9:.2f} GB")
