"""Compute train-set 2D argmax pixel error for any of our checkpoints.

This is the headline accuracy metric — measures how well the model FITS the training data
under the standard argmax decoder. Val_pix is multimodal and uninformative; train_argmax is
direct.

Usage:
  CUDA_VISIBLE_DEVICES=8 PYTHONPATH=... python eval_train_argmax.py --ckpt <path>

Auto-detects the model class from the checkpoint path when --model is left at its
default (auto); pass --model explicitly to override detection or when it fails.
"""
import argparse, os, sys
sys.path.insert(0, "/data/cameron/para/libero")
sys.path.insert(0, "/data/cameron/keygrip/dinov3")
sys.path.insert(0, "/data/cameron/da3_repo/src")

import torch
import numpy as np
from torch.utils.data import DataLoader
from data_da3_volume import Smith300DA3VolumeDataset, DA3_INPUT, N_WINDOW, N_HEIGHT_BINS


def load_model(model_kind: str, ckpt_path: str, device):
    """Instantiate the network named by ``model_kind`` and load its checkpoint.

    The model is constructed, moved to ``device`` and switched to eval mode;
    then ``ckpt_path`` is read with ``torch.load`` and its
    ``"model_state_dict"`` entry is applied non-strictly (key mismatches are
    only counted and printed, never raised).

    Args:
        model_kind: ``da3_volume_v1`` / ``da3_volume_v2`` / ``da3_volume_v3``,
            ``da3_pixel``, ``dino_vanilla``, ``dino_eef_attn``, or any string
            starting with ``dino_kv``. For the latter, optional ``_h<enc>``
            and ``_t<enc>`` parts select the height / time encodings (both
            default to ``"sin"``).
        ckpt_path: Checkpoint file. For the DINO-based kinds, the substring
            ``"vitl16"`` in this path selects the ``dinov3_vitl16`` backbone
            instead of ``dinov3_vits16plus``.
        device: Device the model and checkpoint tensors are mapped to.

    Returns:
        ``(model, needs_start)``. ``needs_start`` is True only for
        ``dino_eef_attn``; the caller in this file uses it to decide whether
        to pass a start pixel to the model alongside the image.

    Raises:
        ValueError: If ``model_kind`` is not recognised.
    """
    da3_weights = "/data/cameron/da3_large_weights"
    needs_start = model_kind == "dino_eef_attn"

    def dino_variant():
        # Backbone size is encoded in the checkpoint path; S/16+ is the default.
        return "dinov3_vitl16" if "vitl16" in ckpt_path else "dinov3_vits16plus"

    def kv_encoding(tag):
        # Token between the first `tag` and the following underscore.
        if tag not in model_kind:
            return "sin"
        return model_kind.partition(tag)[2].partition("_")[0]

    if model_kind == "da3_volume_v3":
        from model_da3_volume_v3 import DA3VolumeModel
        net = DA3VolumeModel(weights_path=da3_weights)
    elif model_kind == "da3_volume_v2":
        from model_da3_volume_v2 import DA3VolumeModel
        net = DA3VolumeModel(weights_path=da3_weights)
    elif model_kind == "da3_volume_v1":
        from model_da3_volume import DA3VolumeModel
        net = DA3VolumeModel(weights_path=da3_weights)
    elif model_kind == "da3_pixel":
        from model_da3_pixel import DA3PixelModel
        net = DA3PixelModel(weights_path=da3_weights)
    elif model_kind == "dino_vanilla":
        from model_dino_vanilla import DinoVanillaModel
        net = DinoVanillaModel(dino_variant=dino_variant())
    elif model_kind == "dino_eef_attn":
        from model_dino_eef_attn import DinoEefAttnModel
        net = DinoEefAttnModel()
    elif model_kind.startswith("dino_kv"):
        height_enc = kv_encoding("_h")
        time_enc = kv_encoding("_t")
        backbone = dino_variant()
        from model_dino_volume_kv import DinoVolumeKV
        net = DinoVolumeKV(height_enc=height_enc, time_enc=time_enc,
                           dino_variant=backbone)
    else:
        raise ValueError(f"unknown model_kind {model_kind}")

    net = net.to(device).eval()

    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    missing_keys, unexpected_keys = net.load_state_dict(
        checkpoint["model_state_dict"], strict=False)
    if missing_keys:
        print(f"  missing keys: {len(missing_keys)}")
    if unexpected_keys:
        print(f"  unexpected keys: {len(unexpected_keys)}")

    epoch = checkpoint.get("epoch", "?")
    val_metric = checkpoint.get("val_v", checkpoint.get("val_h", "?"))
    print(f"  loaded ckpt epoch={epoch}, val_v/h={val_metric}")
    return net, needs_start


def _detect_model_kind(ckpt_path: str) -> str:
    """Infer the ``load_model`` kind string from a checkpoint path.

    Matching is by substring on the lower-cased path, in a fixed priority
    order. For ``dino_kv`` checkpoints the encodings are parsed from the
    ``dino_kv_h<enc>[_t<enc>]`` marker itself (parent directory name first,
    then the whole path), so unrelated ``_h`` / ``_t`` fragments elsewhere in
    the name cannot be picked up. A missing ``_t<enc>`` part defaults to
    ``"sin"``, the same default ``load_model`` applies.

    Raises:
        ValueError: If no known marker is present, or the ``dino_kv`` marker
            carries no parsable height encoding.
    """
    import re  # local import: only this helper needs it

    cp = ckpt_path.lower()
    if "eef_attn" in cp:
        return "dino_eef_attn"
    if "dino_kv_h" in cp:
        # e.g. ".../da3_dino_kv_hlearned_tsin/best.pt"
        pattern = re.compile(r"dino_kv_h([a-z0-9]+)(?:_t([a-z0-9]+))?")
        parent = os.path.basename(os.path.dirname(cp))
        match = pattern.search(parent) or pattern.search(cp)
        if match is None:
            raise ValueError(
                f"could not parse dino_kv encodings from {ckpt_path}; pass --model")
        height_enc = match.group(1)
        time_enc = match.group(2) or "sin"
        return f"dino_kv_h{height_enc}_t{time_enc}"
    markers = (
        ("dino_vanilla", "dino_vanilla"),
        ("v4_softargmax", "da3_volume_v3"),  # v4 has same arch as v3
        ("v3_sincos", "da3_volume_v3"),
        ("v2_sum", "da3_volume_v2"),         # also covers "v2_sumkeys"
        ("kv_v0", "da3_volume_v1"),
        ("pixel", "da3_pixel"),
    )
    for marker, kind in markers:
        if marker in cp:
            return kind
    raise ValueError(f"could not auto-detect kind from {ckpt_path}; pass --model")


def _decode_argmax(out, input_size):
    """Decode per-step argmax predictions from a model output mapping.

    Uses ``out["volume_logits"]`` (B, T, Z, h, w) when present and not None,
    otherwise ``out["pred_heatmap"]`` (B, T, h, w).

    Returns:
        ``(pred_pix, pred_z)``: ``pred_pix`` is a float (B, T, 2) tensor of
        ``(x, y)`` obtained by dividing the argmax cell index by
        ``output_size / input_size``; ``pred_z`` is the (B, T) argmax index
        along Z, or None for heatmap-only outputs.
    """
    if "volume_logits" in out and out["volume_logits"] is not None:
        vol = out["volume_logits"]
        n_batch, n_steps, _, h_out, w_out = vol.shape
        flat = vol.reshape(n_batch, n_steps, -1).argmax(dim=-1)
        # Flattened order is (Z, h, w): split off the Z index first.
        pred_z = flat // (h_out * w_out)
        cell = flat % (h_out * w_out)
    else:
        heatmap = out["pred_heatmap"]
        n_batch, n_steps, h_out, w_out = heatmap.shape
        cell = heatmap.reshape(n_batch, n_steps, -1).argmax(dim=-1)
        pred_z = None
    px = (cell % w_out).float() / (w_out / input_size)
    py = (cell // w_out).float() / (h_out / input_size)
    return torch.stack([px, py], dim=-1), pred_z


def main():
    """Evaluate a checkpoint's argmax pixel (and height-bin) error on the train set.

    Parses CLI arguments, resolves the model kind (auto-detected from the
    checkpoint path unless ``--model`` is given), runs the model over the
    whole dataset without shuffling, and prints:

    * the mean pixel error over all valid future steps,
    * the mean pixel error over samples whose 8 steps are all valid,
    * for volume models, the mean height-bin error over those same samples.

    Sums and counts are accumulated across the whole dataset so every step
    carries equal weight regardless of how it was batched. A metric with no
    contributing steps is printed as ``nan``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", type=str, required=True)
    p.add_argument("--model", type=str, default="auto",
                   help="auto|da3_volume_v1|da3_volume_v2|da3_volume_v3|da3_pixel|"
                        "dino_vanilla|dino_eef_attn|dino_kv_h<enc>_t<enc>")
    p.add_argument("--depth_subdir", type=str, default="da3_depth_large")
    p.add_argument("--batch_size", type=int, default=8)
    args = p.parse_args()
    device = torch.device("cuda")

    kind = args.model
    if kind == "auto":
        kind = _detect_model_kind(args.ckpt)
        print(f"Auto-detected model_kind = {kind}")

    print("Loading dataset…")
    ds = Smith300DA3VolumeDataset(depth_subdir=args.depth_subdir)
    print(f"Loading {kind} from {args.ckpt}…")
    model, needs_start = load_model(kind, args.ckpt, device)
    loader = DataLoader(ds, batch_size=args.batch_size, shuffle=False, num_workers=4)

    full_window = 8  # a sample is "full-8" when all 8 future steps are valid
    n_full8_samples = 0
    # Running (sum, count) pairs; dividing once at the end gives a per-step
    # mean that does not depend on batch size or batch composition.
    pix_sum_all, pix_cnt_all = 0.0, 0
    pix_sum_full8, pix_cnt_full8 = 0.0, 0
    z_sum_full8, z_cnt_full8 = 0.0, 0

    print("Evaluating…")
    with torch.no_grad():
        for batch in loader:
            rgb = batch["rgb"].to(device)
            gt_pix = batch["gt_pix_504"].to(device)           # (B, T, 2)
            gt_z = batch["gt_z_bin"].to(device)               # (B, T)
            valid = batch["gt_pix_valid"].to(device).bool()   # (B, T)
            if needs_start:
                out = model(rgb, batch["start_pix_504"].to(device))
            else:
                out = model(rgb)

            pred_pix, pred_z = _decode_argmax(out, DA3_INPUT)
            pe = (pred_pix - gt_pix).norm(dim=-1)              # (B, T) px-err
            full8 = valid.sum(dim=1) == full_window            # (B,) bool
            n_full8_samples += int(full8.sum().item())

            # Select valid steps by indexing rather than multiplying by the
            # mask, so a NaN/inf error on an invalid step cannot leak in.
            pix_sum_all += pe[valid].double().sum().item()
            pix_cnt_all += int(valid.sum().item())

            if full8.any():
                pe_full8 = pe[full8]
                pix_sum_full8 += pe_full8.double().sum().item()
                pix_cnt_full8 += pe_full8.numel()
                if pred_z is not None:
                    ze = (pred_z[full8] - gt_z[full8]).abs()
                    z_sum_full8 += ze.double().sum().item()
                    z_cnt_full8 += ze.numel()

    nan = float("nan")
    mean_pix_all = pix_sum_all / pix_cnt_all if pix_cnt_all else nan
    mean_pix_full8 = pix_sum_full8 / pix_cnt_full8 if pix_cnt_full8 else nan

    print(f"\n=== {kind} | ckpt={os.path.basename(os.path.dirname(args.ckpt))} ===")
    print(f"Full train set ({len(ds)} samples, {n_full8_samples} full-8 windows):")
    print(f"  train_pix_err_argmax (mask-avg, all valid steps): {mean_pix_all:.2f} px")
    print(f"  train_pix_err_argmax (full-8 windows only):       {mean_pix_full8:.2f} px")
    if z_cnt_full8:
        mean_z_full8 = z_sum_full8 / z_cnt_full8
        # NOTE(review): 5.2 is used as mm per height bin — confirm against the
        # dataset's height binning.
        print(f"  train_z_err (full-8, bins): {mean_z_full8:.2f} bins "
              f"(~{mean_z_full8*5.2:.1f} mm)")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
