"""
Download UVA pretrained checkpoints (Libero10, UMI multitask) into checkpoints/.
Run from unified_video_action repo root: python scripts/download_ckpts.py
"""

import subprocess
import sys
from pathlib import Path

# The repository root is two levels above this file (<root>/scripts/<this file>).
REPO_ROOT = Path(__file__).resolve().parents[1]
# Every checkpoint is saved into <root>/checkpoints/.
CHECKPOINTS_DIR = REPO_ROOT.joinpath("checkpoints")

# Checkpoint filename -> Google Drive file ID (IDs come from the README).
# Insertion order is significant: it is the order shown by the CLI choices
# and in the "Unknown checkpoint name" error message.
GDOWN_IDS = dict(
    [
        ("pusht.ckpt", "1OduHcxfc2hqUYSccMQNf9g-vAt-q2UhF"),
        ("pusht_multitask.ckpt", "1ZppZJyQdEdjhu8TIt4ddyaWy_mSdjoAZ"),
        ("libero10.ckpt", "11c2VrmaRp48yw__5A5xpcu8EPzkexHSi"),
        # UMI Multitask (Cup, Towel, Mouse)
        ("umi_multitask.ckpt", "1rUWtpXReULf8h42P80Go7GeTiZs3irFS"),
    ]
)


def download(name: str, overwrite: bool = False) -> Path:
    """Download one checkpoint from Google Drive into ``CHECKPOINTS_DIR``.

    The file is fetched by running ``gdown`` under the current interpreter
    (``python -m gdown``) into a temporary ``<name>.tmp`` file. That file is
    renamed to its final name only after gdown succeeds, so a failed or
    interrupted download never leaves a truncated checkpoint that a later run
    would skip as "Already exists".

    Args:
        name: Checkpoint filename; must be a key of ``GDOWN_IDS`` unless the
            file already exists locally.
        overwrite: Re-download even if ``CHECKPOINTS_DIR / name`` exists.

    Returns:
        Path to the checkpoint file.

    Raises:
        KeyError: ``name`` is not a known checkpoint (and no local file exists).
        RuntimeError: ``gdown`` is not installed, or the download failed.
    """
    out = CHECKPOINTS_DIR / name
    if out.is_file() and not overwrite:
        print(f"Already exists: {out}")
        return out
    gid = GDOWN_IDS.get(name)
    if not gid:
        raise KeyError(f"Unknown checkpoint name: {name}. Choose from {list(GDOWN_IDS)}")

    # Running `python -m gdown` never raises FileNotFoundError when gdown is
    # missing (the interpreter itself exists); it just exits non-zero. Check up
    # front so the user gets the actionable "install gdown" message.
    import importlib.util

    if importlib.util.find_spec("gdown") is None:
        raise RuntimeError("Install gdown: pip install gdown, then run again.")

    CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
    manual_hint = (
        f"You may need to download manually from "
        f"https://drive.google.com/uc?id={gid} and save to {out}"
    )
    # Download to a side file and move it into place only on success.
    tmp = out.with_name(out.name + ".tmp")
    try:
        subprocess.run(
            [sys.executable, "-m", "gdown", gid, "-O", str(tmp), "--fuzzy"],
            check=True,
        )
        if not tmp.is_file():
            raise RuntimeError(
                f"Download failed for {name}: gdown produced no file. {manual_hint}"
            )
        tmp.replace(out)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Download failed for {name}. {manual_hint}") from e
    finally:
        # Remove any partial file left by a failed or interrupted download;
        # after a successful replace() the temp file no longer exists.
        if tmp.exists():
            tmp.unlink()
    print(f"Downloaded: {out}")
    return out


def main(argv=None):
    """Command-line entry point: download the requested UVA checkpoints.

    With no ``--checkpoint``, fetches ``libero10.ckpt`` and ``umi_multitask.ckpt``.
    Every requested download is attempted even if an earlier one fails. Failures
    are reported on stderr, and if any occurred the process exits with status 1
    so shell scripts and CI can detect them (previously it always exited 0).

    Args:
        argv: Argument list to parse; ``None`` (the default) means
            ``sys.argv[1:]``. Useful when calling ``main`` from Python.

    Raises:
        SystemExit: status 1 if at least one checkpoint failed to download
            (argparse also raises SystemExit for ``--help`` or bad arguments).
    """
    import argparse
    p = argparse.ArgumentParser(description="Download UVA checkpoints")
    p.add_argument("--checkpoint", "-c", choices=list(GDOWN_IDS), default=None,
                   help="Download one checkpoint (default: all libero10 + umi_multitask)")
    p.add_argument("--overwrite", action="store_true", help="Re-download if file exists")
    args = p.parse_args(argv)

    to_download = [args.checkpoint] if args.checkpoint else ["libero10.ckpt", "umi_multitask.ckpt"]
    failed = []
    for name in to_download:
        try:
            download(name, overwrite=args.overwrite)
        except (KeyError, RuntimeError, OSError) as e:
            # Keep going so one bad checkpoint does not block the others.
            print(f"Warning: {e}", file=sys.stderr)
            failed.append(name)
    if failed:
        print(f"Failed to download: {', '.join(failed)}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    # Script entry point: main()'s return value (None -> 0) becomes the exit status.
    sys.exit(main())
