"""Download DROID dataset by constructing file paths directly.

Bypasses the HF API tree enumeration (which triggers 429 rate limits)
by knowing the exact naming convention:
  data/chunk-NNN/episode_NNNNNN.parquet
  videos/chunk-NNN/observation.images.exterior_{1,2}_left/episode_NNNNNN.mp4

Uses concurrent downloads for speed.
"""

import os
import sys
import time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

# Must be set *before* importing huggingface_hub so the fast hf_transfer
# download backend is picked up at import time.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
from huggingface_hub import hf_hub_download

REPO_ID = "cadene/droid_1.0.1"  # HF dataset repo to pull from
LOCAL_DIR = "/data/cameron/droid"  # destination root for all downloads
TOTAL_EPISODES = 95600  # episodes numbered 000000..095599
CHUNKS = 96  # chunk-000 to chunk-095, 1000 episodes each (last has 600)  # NOTE(review): unused, kept for documentation
N_WORKERS = 8  # parallel downloads
CAMERAS = ["exterior_1_left", "exterior_2_left"]  # skip wrist


def episode_to_chunk(ep_idx):
    """Map a global episode index to its 1000-episode chunk number."""
    chunk, _ = divmod(ep_idx, 1000)
    return chunk


def build_file_list():
    """Enumerate every repo-relative path (one parquet + one mp4 per camera
    per episode) that the dataset download needs."""
    paths = []
    for ep in range(TOTAL_EPISODES):
        chunk_dir = f"chunk-{episode_to_chunk(ep):03d}"
        episode = f"episode_{ep:06d}"
        # Tabular data for this episode.
        paths.append(f"data/{chunk_dir}/{episode}.parquet")
        # One video per configured camera (ext1 + ext2).
        paths.extend(
            f"videos/{chunk_dir}/observation.images.{cam}/{episode}.mp4"
            for cam in CAMERAS
        )
    return paths


def is_downloaded(rel_path):
    """Return True when *rel_path* already exists under LOCAL_DIR."""
    target = Path(LOCAL_DIR, rel_path)
    return target.exists()


def download_file(rel_path, max_retries=3):
    """Download a single repo file into LOCAL_DIR, retrying on failure.

    Args:
        rel_path: Repo-relative path, e.g. "data/chunk-000/episode_000000.parquet".
        max_retries: Total number of attempts before giving up.

    Returns:
        (rel_path, status) where status is "exists" (already on disk),
        "ok" (downloaded), or "FAILED: <error>" (all attempts exhausted).
    """
    local_path = Path(LOCAL_DIR) / rel_path
    if local_path.exists():
        return rel_path, "exists"

    last_error = None
    for attempt in range(max_retries):
        try:
            hf_hub_download(
                REPO_ID,
                rel_path,
                repo_type="dataset",
                local_dir=LOCAL_DIR,
            )
            return rel_path, "ok"
        except Exception as e:
            last_error = e
            if attempt < max_retries - 1:
                # Always back off before retrying (the previous version
                # retried non-429 errors immediately, hammering the server).
                # Rate-limit (429) responses get a longer, linearly growing
                # wait: 5s, 10s, 15s, ...
                base = 5 if "429" in str(e) else 1
                time.sleep(base * (attempt + 1))
    # All attempts failed; report the last exception seen.
    return rel_path, f"FAILED: {last_error}"


def main():
    """Entry point: build the file list, skip existing files, download the rest."""
    print("Building file list...")
    all_files = build_file_list()
    print(f"Total files: {len(all_files)}")

    # Only fetch what is missing locally, so re-running the script resumes.
    remaining = [path for path in all_files if not is_downloaded(path)]
    already = len(all_files) - len(remaining)
    print(f"Already downloaded: {already}")
    print(f"Remaining: {len(remaining)}")

    if not remaining:
        print("All files downloaded!")
        return

    print(f"\nDownloading with {N_WORKERS} workers...")
    total = len(remaining)
    completed = 0
    n_failed = 0
    started = time.time()

    with ThreadPoolExecutor(max_workers=N_WORKERS) as pool:
        pending = [pool.submit(download_file, path) for path in remaining]
        for fut in as_completed(pending):
            rel_path, status = fut.result()
            completed += 1
            if "FAILED" in status:
                n_failed += 1
                print(f"  FAILED: {rel_path}: {status}")

            # Progress report every 500 completions and once at the very end.
            if completed % 500 == 0 or completed == total:
                elapsed = time.time() - started
                rate = completed / elapsed
                eta = (total - completed) / max(rate, 0.01)
                pct = (already + completed) / len(all_files) * 100
                print(f"  [{completed}/{total}] {pct:.1f}% total, "
                      f"{rate:.1f} files/s, ETA {eta/3600:.1f}h, failed={n_failed}")

    print(f"\nDone! Downloaded {completed - n_failed}/{total} files, {n_failed} failed")
    print(f"Total: {already + completed - n_failed}/{len(all_files)} files")


# Script entry point — safe to import this module without triggering downloads.
if __name__ == "__main__":
    main()
