"""Convert HF safetensors DINOv3 weights to the hub-format state_dict expected by
/data/cameron/keygrip/dinov3 ViT.

HF key prefix → hub key prefix
  embeddings.cls_token            → cls_token
  embeddings.mask_token           → mask_token
  embeddings.register_tokens      → storage_tokens
  embeddings.patch_embeddings.W   → patch_embed.proj.W
  layer.N.normM.*                 → blocks.N.normM.*
  layer.N.attention.q/k/v_proj.*  → blocks.N.attn.qkv.* (concatenated dim=0)
  layer.N.attention.o_proj.*      → blocks.N.attn.proj.*
  layer.N.mlp.up_proj.*           → blocks.N.mlp.fc1.*
  layer.N.mlp.down_proj.*         → blocks.N.mlp.fc2.*
  layer.N.layer_scaleM.lambda1    → blocks.N.lsM.gamma  (HF param name unconfirmed; `.weight` and `.lambda_` are also tried)
  norm.*                          → norm.*

rope_embed.periods has no HF counterpart and is not converted: it is a fixed buffer in the hub model (computed from config).
"""
import re
import sys
import torch
from safetensors.torch import load_file


def _copy_weight_and_bias(hf_sd, out, src: str, dst: str) -> None:
    """Copy ``{src}.weight``/``{src}.bias`` into ``out`` as ``{dst}.weight``/``{dst}.bias``.

    Does nothing when ``{src}.weight`` is absent. Once the weight is present the
    bias is read unconditionally, so a weight without its bias raises KeyError.
    """
    if f"{src}.weight" in hf_sd:
        out[f"{dst}.weight"] = hf_sd[f"{src}.weight"]
        out[f"{dst}.bias"] = hf_sd[f"{src}.bias"]


def convert_hf_to_hub_state_dict(hf_sd, n_blocks: int):
    """Map HF DINOv3 safetensors to dinov3 hub state_dict format.

    Keys absent from ``hf_sd`` are skipped, so the result may be partial;
    inspect the missing/unexpected keys reported by
    ``load_state_dict(..., strict=False)`` to see what was left unmapped.

    Args:
        hf_sd: mapping from HF parameter name to tensor (e.g. the output of
            ``safetensors.torch.load_file``).
        n_blocks: number of transformer blocks to convert
            (``layer.0`` .. ``layer.{n_blocks - 1}``).

    Returns:
        dict mapping hub parameter names to tensors. Most values are the HF
        tensors themselves; ``mask_token`` is squeezed and ``attn.qkv.*`` are
        new concatenated tensors.

    Raises:
        ValueError: if a block's q/k/v projections are only partially present
            (q weight without k or v weight, or exactly one of q/v bias).
        KeyError: if a weight is present but its matching bias is not.
    """
    out = {}

    # Top-level tensors.
    if "embeddings.cls_token" in hf_sd:
        out["cls_token"] = hf_sd["embeddings.cls_token"]
    if "embeddings.mask_token" in hf_sd:
        # Drop the leading dim: (1, 1, D) -> (1, D) per the original author's note.
        out["mask_token"] = hf_sd["embeddings.mask_token"].squeeze(0)
    if "embeddings.register_tokens" in hf_sd:
        out["storage_tokens"] = hf_sd["embeddings.register_tokens"]
    _copy_weight_and_bias(hf_sd, out, "embeddings.patch_embeddings", "patch_embed.proj")
    _copy_weight_and_bias(hf_sd, out, "norm", "norm")

    for i in range(n_blocks):
        src = f"layer.{i}"
        dst = f"blocks.{i}"

        # Pre-attention / pre-MLP norms.
        for n in (1, 2):
            _copy_weight_and_bias(hf_sd, out, f"{src}.norm{n}", f"{dst}.norm{n}")

        # Attention: HF keeps separate q/k/v projections while the hub fuses
        # them into one qkv linear, so concatenate along dim 0 in q, k, v order.
        attn = f"{src}.attention"
        qkv_w = [hf_sd.get(f"{attn}.{p}_proj.weight") for p in ("q", "k", "v")]
        if qkv_w[0] is not None:
            if any(w is None for w in qkv_w):
                # Without this check torch.cat fails with an opaque TypeError on None.
                raise ValueError(
                    f"{attn}: q_proj.weight present but k_proj/v_proj weight missing"
                )
            out[f"{dst}.attn.qkv.weight"] = torch.cat(qkv_w, dim=0)

        qb, kb, vb = (hf_sd.get(f"{attn}.{p}_proj.bias") for p in ("q", "k", "v"))
        if (qb is None) != (vb is None):
            # Previously this silently dropped qkv.bias, leaving the hub's init bias.
            raise ValueError(f"{attn}: only one of q_proj.bias / v_proj.bias is present")
        if qb is not None:
            # `mask_k_bias=True` in dinov3 hub means K has no bias → zeros for K slot.
            if kb is None:
                kb = torch.zeros_like(qb)
            out[f"{dst}.attn.qkv.bias"] = torch.cat([qb, kb, vb], dim=0)

        # Output projection and MLP are plain renames.
        _copy_weight_and_bias(hf_sd, out, f"{attn}.o_proj", f"{dst}.attn.proj")
        _copy_weight_and_bias(hf_sd, out, f"{src}.mlp.up_proj", f"{dst}.mlp.fc1")
        _copy_weight_and_bias(hf_sd, out, f"{src}.mlp.down_proj", f"{dst}.mlp.fc2")

        # LayerScale: HF layer_scale{1,2} → hub ls{1,2}.gamma. The HF parameter
        # suffix is not confirmed here, so take the first candidate that exists.
        for n in (1, 2):
            for suffix in ("lambda1", "weight", "lambda_"):
                key = f"{src}.layer_scale{n}.{suffix}"
                if key in hf_sd:
                    out[f"{dst}.ls{n}.gamma"] = hf_sd[key]
                    break
    return out


def infer_n_blocks(hf_sd) -> int:
    """Return 1 + the highest block index N among keys of the form ``layer.N.*``.

    Keys that do not start with ``layer.<digits>`` followed by ``.`` or end of
    string are ignored.

    Raises:
        ValueError: if no key matches, e.g. the checkpoint uses an unexpected
            prefix. This replaces the opaque error ``max()`` gives on an empty
            sequence.
    """
    pattern = re.compile(r"layer\.(\d+)(?:\.|$)")
    indices = []
    for key in hf_sd:
        match = pattern.match(key)
        if match:
            indices.append(int(match.group(1)))
    if not indices:
        raise ValueError("no 'layer.<N>.' keys found in HF state_dict; cannot infer n_blocks")
    return max(indices) + 1


if __name__ == "__main__":
    import argparse
    p = argparse.ArgumentParser(description="Convert HF DINOv3 safetensors to dinov3 hub format.")
    p.add_argument("--hf_safetensors", required=True)
    p.add_argument("--out_pth", required=True)
    p.add_argument("--arch", default="dinov3_vitl16", choices=["dinov3_vitl16", "dinov3_vitb16"])
    args = p.parse_args()

    sys.path.insert(0, "/data/cameron/keygrip/dinov3")
    hf_sd = load_file(args.hf_safetensors)
    print(f"Loaded HF safetensors with {len(hf_sd)} keys")

    # Diagnostic: show how HF names its layer-scale params (the converter tries
    # several candidate suffixes).
    ls_keys = [k for k in hf_sd if "layer_scale" in k][:4]
    print(f"layer_scale key examples: {ls_keys}")

    n_blocks = infer_n_blocks(hf_sd)
    print(f"n_blocks = {n_blocks}")

    hub_sd = convert_hf_to_hub_state_dict(hf_sd, n_blocks)
    print(f"Converted to {len(hub_sd)} hub keys")

    # Load hub model + load_state_dict (strict=False so every missing /
    # unexpected key is reported instead of aborting on the first mismatch).
    m = torch.hub.load("/data/cameron/keygrip/dinov3", args.arch, source="local", pretrained=False)
    missing, unexpected = m.load_state_dict(hub_sd, strict=False)
    print(f"missing keys: {len(missing)}  unexpected: {len(unexpected)}")
    if missing:
        print(f"  missing[:5]: {missing[:5]}")
    if unexpected:
        print(f"  unexpected[:5]: {unexpected[:5]}")

    # Sanity forward pass. With strict=False, missing entries keep whatever the
    # hub constructor gave them, so verify the outputs are finite before
    # persisting anything.
    m.eval()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        out = m.forward_features(x)
    named_outputs = out.items() if isinstance(out, dict) else [("out", out)]
    non_finite = []
    for name, value in named_outputs:
        if not hasattr(value, "shape"):
            continue
        print(f"  {name}: {value.shape}")
        if torch.is_tensor(value) and value.is_floating_point() and not torch.isfinite(value).all():
            non_finite.append(name)
    if non_finite:
        sys.exit(f"ERROR: non-finite values in forward output(s) {non_finite}; not saving {args.out_pth}")

    # Save the full model state_dict (not hub_sd) so buffers the converter does
    # not produce are included.
    torch.save(m.state_dict(), args.out_pth)
    print(f"Saved hub-format weights to {args.out_pth}")
