#!/usr/bin/env python3
"""Benchmark MDS data loading throughput (no model)."""

import time
import argparse
import logging

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")


def benchmark_mds(
    mds_path: str,
    batch_size: int = 256,
    num_workers: int = 64,
    num_batches: int = 100,
    warmup_batches: int = 10,
):
    """Benchmark MDS data loading throughput."""
    from openpi.training.mds_dataset import MDSDataset, MDSDatasetConfig

    logging.info(f"Benchmarking MDS: {mds_path}")
    logging.info(f"batch_size={batch_size}, num_workers={num_workers}")
    logging.info(f"num_batches={num_batches}, warmup_batches={warmup_batches}")

    config = MDSDatasetConfig(
        external_cam_key="external_cam",
        wrist_cam_key="wrist_cam",
        arm_joint_pos_key="arm_joint_pos",
        gripper_joint_pos_key="gripper_pos",
        action_key="action_chunk",
        prompt_key="prompt",
    )

    dataset = MDSDataset(
        remote=mds_path,
        local="/tmp/mds_cache",
        shuffle=False,
        action_chunk_size=16,
        batch_size=batch_size,
        num_workers=num_workers,
        config=config,
        predownload=6000,
        cache_limit="100gb",
    )

    logging.info(f"Dataset size: {len(dataset)} samples")
    logging.info(f"Starting benchmark...")

    # Warmup
    logging.info(f"Warming up ({warmup_batches} batches)...")
    data_iter = iter(dataset)
    for i in range(warmup_batches):
        batch = next(data_iter)
        if i == 0:
            logging.info(f"Batch shape: actions={batch['actions'].shape}, image={batch['observation']['image'].shape}")

    # Benchmark
    logging.info(f"Benchmarking ({num_batches} batches)...")
    start_time = time.perf_counter()

    for i in range(num_batches):
        batch = next(data_iter)
        if (i + 1) % 10 == 0:
            elapsed = time.perf_counter() - start_time
            batches_per_sec = (i + 1) / elapsed
            samples_per_sec = batches_per_sec * batch_size
            logging.info(f"  [{i+1}/{num_batches}] {batches_per_sec:.2f} batches/s, {samples_per_sec:.0f} samples/s")

    end_time = time.perf_counter()
    total_time = end_time - start_time

    batches_per_sec = num_batches / total_time
    samples_per_sec = batches_per_sec * batch_size
    time_per_batch = total_time / num_batches

    logging.info("=" * 60)
    logging.info(f"RESULTS:")
    logging.info(f"  Total time: {total_time:.2f}s for {num_batches} batches")
    logging.info(f"  Throughput: {batches_per_sec:.2f} batches/sec")
    logging.info(f"  Throughput: {samples_per_sec:.0f} samples/sec")
    logging.info(f"  Time per batch: {time_per_batch*1000:.1f}ms")
    logging.info("=" * 60)

    return batches_per_sec, samples_per_sec


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Benchmark MDS data loading")
    cli.add_argument("--mds-path", type=str, required=True, help="Path to MDS dataset (local or s3://)")

    # Integer tuning knobs as (flag, default, help text). Each flag's argparse
    # dest (e.g. --batch-size -> batch_size) matches a benchmark_mds parameter.
    int_options = (
        ("--batch-size", 256, "Batch size"),
        ("--num-workers", 64, "Number of DataLoader workers"),
        ("--num-batches", 100, "Number of batches to benchmark"),
        ("--warmup-batches", 10, "Number of warmup batches"),
    )
    for flag, default, help_text in int_options:
        cli.add_argument(flag, type=int, default=default, help=help_text)

    # The parsed namespace holds exactly the five parameters benchmark_mds takes.
    benchmark_mds(**vars(cli.parse_args()))
