"""Load DINOv3 from keygrip and extract 16x16 patch features from 256x256 RGB image."""

import os
from pathlib import Path

import torch
import torch.nn as nn

# Presumably the side length, in pixels, of one DINO patch; not referenced by the
# code visible in this file — TODO confirm external users.
DINO_PATCH_SIZE = 16
# Per-channel mean and std applied by DINOFeaturizer.normalize as (x - mean) / std.
# The names indicate ImageNet statistics for RGB values scaled to [0, 1].
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def load_dino_from_keygrip(keygrip_root: Path) -> nn.Module:
    """Load the frozen DINOv3 model from a local keygrip checkout.

    The ``dinov3_vits16plus`` entry point of the local ``dinov3`` hub repo is
    invoked with a weights path relative to ``keygrip_root``, so the process
    working directory is temporarily switched to ``keygrip_root`` while
    ``torch.hub.load`` runs.

    Args:
        keygrip_root: Directory that contains ``dinov3/`` and
            ``dinov3/weights/``.

    Returns:
        The loaded model in eval mode with ``requires_grad`` disabled on every
        parameter.

    Raises:
        OSError: If ``keygrip_root`` cannot be entered.
        Exception: Whatever ``torch.hub.load`` raises is propagated unchanged.
    """
    keygrip_root = Path(keygrip_root).resolve()
    previous_cwd = os.getcwd()
    # Both the hub repo ("dinov3") and the weights file are given as relative
    # paths, so loading has to happen from inside keygrip_root.
    os.chdir(keygrip_root)
    try:
        dino = torch.hub.load(
            "dinov3",
            "dinov3_vits16plus",
            source="local",
            weights="dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth",
        )
    finally:
        # Always restore the caller's working directory, even when loading
        # fails; otherwise later relative-path I/O in the process would break.
        os.chdir(previous_cwd)
    # Freeze the backbone: no gradients for its weights, inference-mode layers.
    for param in dino.parameters():
        param.requires_grad = False
    dino.eval()
    return dino


class DINOFeaturizer(nn.Module):
    """Frozen DINOv3 wrapper: RGB (B, 3, H, W) -> patch features (B, D, H_p, W_p).

    Inputs are resized to 256x256 before tokenization; for that size the patch
    grid is expected to be H_p = W_p = 16. The wrapped backbone is kept in eval
    mode at all times, including when ``.train()`` is called on this module or
    on a parent module.
    """

    def __init__(self, keygrip_root: Path):
        """Load the backbone from ``keygrip_root`` and register normalization buffers."""
        super().__init__()
        self.dino = load_dino_from_keygrip(keygrip_root)
        self.embed_dim = self.dino.embed_dim
        # Buffers (not parameters) so they follow .to(device) / state_dict
        # without being trained; shaped to broadcast over (B, 3, H, W).
        self.register_buffer("mean", torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(IMAGENET_STD).view(1, 3, 1, 1))

    def train(self, mode: bool = True):
        """Set the training flag on this module while keeping the backbone in eval mode.

        ``nn.Module.train`` recurses into submodules, which would otherwise flip
        the frozen backbone back to training mode whenever a parent calls
        ``.train()``.
        """
        super().train(mode)
        self.dino.eval()
        return self

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Apply per-channel ``(x - mean) / std`` using the registered buffers."""
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute normalized patch features.

        Args:
            x: (B, 3, H, W) image batch in [0, 1] or [-1, 1]. The [-1, 1] range
                is detected by the presence of any negative value, so a
                [-1, 1] batch containing no negative values is treated as
                [0, 1].

        Returns:
            Tensor of shape (B, D, H_p, W_p) with ``D == self.embed_dim``.
        """
        if x.shape[-2:] != (256, 256):
            x = nn.functional.interpolate(x, size=(256, 256), mode="bilinear", align_corners=False)
        if x.min() < 0:
            # Map [-1, 1] inputs to [0, 1] before applying mean/std.
            x = (x + 1.0) / 2.0
        x = self.normalize(x)
        batch_size = x.shape[0]

        tokens, (H_p, W_p) = self.dino.prepare_tokens_with_masks(x)
        rope_embed = self.dino.rope_embed
        for blk in self.dino.blocks:
            # Explicit None check: do not rely on the truthiness of a module.
            rope_sincos = rope_embed(H=H_p, W=W_p) if rope_embed is not None else None
            tokens = blk(tokens, rope_sincos)

        # Skip the leading class token and storage tokens; keep patch tokens only.
        first_patch = self.dino.n_storage_tokens + 1
        if self.dino.untie_cls_and_patch_norms:
            # Patch tokens have their own norm, so slice before normalizing.
            patch_tokens = self.dino.norm(tokens[:, first_patch:])
        else:
            patch_tokens = self.dino.norm(tokens)[:, first_patch:]

        # (B, H_p * W_p, D) -> (B, D, H_p, W_p)
        return patch_tokens.reshape(batch_size, H_p, W_p, self.embed_dim).permute(0, 3, 1, 2)
