"""DINO encoder: single-frame → 32x32 patch features. Delegates to dino_vid_model."""

import sys
from pathlib import Path

# dino_vid_model is a sibling of this file's directory (diffusion_dino), so
# their common parent directory must be on sys.path before importing it.
_vidgen = Path(__file__).resolve().parents[1]
_vidgen_entry = str(_vidgen)
if _vidgen_entry not in sys.path:
    # Prepend so this checkout wins over any other installed copy.
    sys.path.insert(0, _vidgen_entry)
del _vidgen_entry  # keep the module namespace identical to before

from dino_vid_model.dino_loader import load_dino, extract_patch_features, DinoNormalize

# Value passed to DinoNormalize as ``dino_size``; presumably the square input
# resolution (pixels) frames are brought to before DINO — TODO confirm.
DINO_SIZE = 512
# Side length (pixels) of one DINO patch; 512 / 16 gives the 32x32 grid.
DINO_PATCH_SIZE = 16
# Patch-grid height/width of the encoder output. Derived rather than
# hard-coded so it cannot drift out of sync with DINO_SIZE / DINO_PATCH_SIZE.
PATCH_H = PATCH_W = DINO_SIZE // DINO_PATCH_SIZE


def load_dino_encoder(keygrip_root: Path, freeze: bool = True):
    """Build the DINO backbone together with its input normalizer.

    Args:
        keygrip_root: Root directory handed to ``load_dino``. It is converted
            with ``Path(...)``, so a plain string also works.
        freeze: Forwarded unchanged to ``load_dino`` as ``freeze=``.

    Returns:
        A ``(dino_module, norm_module)`` tuple. The normalizer is configured
        with ``dino_size=DINO_SIZE``. A single frame in yields
        (B, D, 32, 32) features out.
    """
    root = Path(keygrip_root)
    backbone = load_dino(root, freeze=freeze)
    return backbone, DinoNormalize(dino_size=DINO_SIZE)


def encode_frame(dino, norm, frame):
    """Encode a batch of single frames into DINO patch features.

    Args:
        dino: Backbone module, as returned by ``load_dino_encoder``.
        norm: Normalizer module, as returned by ``load_dino_encoder``.
            It is applied to ``frame`` first.
        frame: (B, 3, H, W) images in [-1, 1] or [0, 1]. NOTE(review):
            handling of both ranges is up to ``norm``; confirm in DinoNormalize.

    Returns:
        Patch features of shape (B, D, 32, 32).
    """
    # The two trailing ``None`` arguments are forwarded as-is to
    # extract_patch_features. NOTE(review): their meaning lives in
    # dino_vid_model.dino_loader; confirm before changing them.
    return extract_patch_features(dino, norm(frame), None, None)
