"""Load DINOv3 from keygrip; optional freeze for teacher. Uses backbones only to avoid PyTorch 2.4-only imports."""

import sys
from pathlib import Path

import torch
import torch.nn as nn

# Patch side length of the vits16plus backbone (a 512-wide input gives 32 patches
# per side). Not referenced elsewhere in this section of the file.
DINO_PATCH_SIZE = 16
# Per-channel mean / std that DinoNormalize subtracts and divides by; three values,
# one per image channel (presumably in RGB order -- confirm against the data loader).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# File name of the pretrained checkpoint that load_dino looks for under
# <keygrip_root>/dinov3/weights/.
WEIGHTS_FILENAME = "dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"


def load_dino(keygrip_root: Path, freeze: bool = True):
    """Build the DINOv3 vits16plus backbone from the dinov3 checkout inside keygrip.

    The model factory is imported from ``dinov3.hub.backbones`` directly (rather than
    through hubconf/segmentors) so the import also works on PyTorch < 2.4.

    Args:
        keygrip_root: Directory containing the ``dinov3`` checkout; the checkpoint is
            expected at ``<keygrip_root>/dinov3/weights/<WEIGHTS_FILENAME>``.
        freeze: When true, disable gradients on every parameter and switch the
            model to eval mode (teacher usage).

    Returns:
        The backbone returned by ``dinov3_vits16plus``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    repo_dir = Path(keygrip_root).resolve() / "dinov3"
    checkpoint = repo_dir / "weights" / WEIGHTS_FILENAME
    if not checkpoint.exists():
        raise FileNotFoundError(f"DINO weights not found: {checkpoint}")

    # The checkout is not an installed package: put it on sys.path (once) so the
    # function-local import below can resolve ``dinov3``.
    repo_entry = str(repo_dir)
    if repo_entry not in sys.path:
        sys.path.insert(0, repo_entry)
    from dinov3.hub.backbones import dinov3_vits16plus

    model = dinov3_vits16plus(pretrained=True, weights=str(checkpoint), check_hash=False)
    if not freeze:
        return model

    # Frozen teacher: no gradients and eval mode.
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
    return model


def extract_patch_features(dino, x, mean, std):
    """Run the backbone blocks on ``x`` and return the normalized patch tokens as a map.

    x: (B, 3, H, W), expected to be already normalized with mean/std. The ``mean``
    and ``std`` arguments themselves are accepted but not used by this function.
    Returns (B, D, H_p, W_p), where (H_p, W_p) is the token grid reported by
    ``dino.prepare_tokens_with_masks``.
    """
    batch_size = x.shape[0]
    tokens, grid = dino.prepare_tokens_with_masks(x)
    grid_h, grid_w = grid

    for block in dino.blocks:
        # The rotary embedding is requested once per block, not hoisted, so any
        # per-call behaviour of rope_embed is kept as-is.
        if dino.rope_embed:
            rope = dino.rope_embed(H=grid_h, W=grid_w)
        else:
            rope = None
        tokens = block(tokens, rope)

    # Skip the leading 1 + n_storage_tokens tokens (presumably CLS plus storage
    # tokens -- confirm against the backbone) so only patch tokens remain.
    first_patch = dino.n_storage_tokens + 1
    if dino.untie_cls_and_patch_norms:
        # Norm is applied to the patch tokens alone.
        patches = dino.norm(tokens[:, first_patch:])
    else:
        # Norm is applied to the full sequence, then the patch tokens are sliced out.
        patches = dino.norm(tokens)[:, first_patch:]

    # (B, H_p * W_p, D) -> (B, H_p, W_p, D) -> (B, D, H_p, W_p)
    return patches.reshape(batch_size, grid_h, grid_w, -1).permute(0, 3, 1, 2)


class DinoNormalize(nn.Module):
    """Resize to dino_size (256 or 512), map to [0,1] if needed, then ImageNet normalize.

    512 -> 32x32 patches.

    The transform is implemented in ``forward`` (not by overriding ``__call__``) so
    that ``nn.Module.__call__`` keeps dispatching hooks and ``module.forward(x)``
    works; calling the instance directly, ``normalizer(x)``, behaves as before.
    """

    def __init__(self, dino_size: int = 256):
        """Store the target side length and register the normalization statistics.

        Args:
            dino_size: Side length of the square output images.
        """
        super().__init__()
        self.dino_size = dino_size
        # Registered as buffers (not parameters) so they move with .to()/.cuda()
        # and are never trained; shaped (1, 3, 1, 1) to broadcast over (B, 3, H, W).
        self.register_buffer("mean", torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(IMAGENET_STD).view(1, 3, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` resized to (dino_size, dino_size) and ImageNet-normalized.

        Args:
            x: Image batch of shape (B, 3, H, W), with values either in [0, 1] or
                in [-1, 1].

        Returns:
            Tensor of shape (B, 3, dino_size, dino_size).
        """
        if x.shape[-1] != self.dino_size or x.shape[-2] != self.dino_size:
            x = torch.nn.functional.interpolate(
                x, size=(self.dino_size, self.dino_size), mode="bilinear", align_corners=False
            )
        # Range heuristic: any negative value means the batch is treated as [-1, 1]
        # and remapped to [0, 1]. NOTE: a [-1, 1] batch that happens to contain no
        # negative value is left unscaled -- callers with such data should rescale
        # before calling.
        if x.min() < 0:
            x = (x + 1.0) / 2.0
        return (x - self.mean) / self.std
