"""Compute normalization statistics for a config.

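This script computes the normalization statistics for a given config. It runs the
"state" and "actions" fields of the dataset through the config's input transforms,
accumulates running statistics (such as the mean and standard deviation) over them,
and saves the result to the config assets directory.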

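If the config has remote_checkpoint_dir set (an s3://bucket[/prefix] URL), norm stats will be:
  1. Downloaded from S3 before computing (skipping computation if found).
  2. Uploaded to S3 after computing.
The S3 object key is <prefix>/assets/<config_name>/<repo_id>/norm_stats.json.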
"""

import logging
import pathlib
import re

import numpy as np
import tqdm
import tyro

import openpi.models.model as _model
import openpi.shared.normalize as normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.transforms as transforms

# Configure root logging at INFO so the logger.info progress messages emitted by this
# script (e.g. the S3 cache checks) are shown when it runs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _parse_s3_url(url: str) -> tuple[str, str]:
    """Parse 's3://bucket/key' into (bucket, key)."""
    match = re.match(r"s3://([^/]+)/?(.*)", url)
    if not match:
        raise ValueError(f"Invalid S3 URL: {url}")
    return match.group(1), match.group(2).rstrip("/")


def _s3_norm_stats_key(s3_bucket: str, config_name: str, repo_id: str) -> tuple[str, str]:
    """Locate the cached ``norm_stats.json`` for a config/dataset pair on S3.

    The object key is ``<prefix>/assets/<config_name>/<repo_id>/norm_stats.json``,
    where ``<prefix>`` is the key part of ``s3_bucket``; the ``<prefix>/`` segment
    is omitted when ``s3_bucket`` has no key part.

    Returns:
        ``(bucket, key)`` as passed to the boto3 S3 client transfer calls.
    """
    bucket, prefix = _parse_s3_url(s3_bucket)
    relative_key = f"assets/{config_name}/{repo_id}/norm_stats.json"
    if prefix:
        return bucket, f"{prefix}/{relative_key}"
    return bucket, relative_key


def _is_s3_not_found(exc: Exception) -> bool:
    """Return True if ``exc`` is an S3 client error meaning "object does not exist".

    botocore's ``ClientError`` carries the parsed error in ``exc.response["Error"]["Code"]``;
    a missing object typically surfaces as ``"404"``, ``"NoSuchKey"`` or ``"NotFound"``.
    Anything else (no boto3, malformed URL, credentials, permissions, network) is not
    a cache miss.
    """
    response = getattr(exc, "response", None)
    if not isinstance(response, dict):
        return False
    error = response.get("Error")
    if not isinstance(error, dict):
        return False
    return str(error.get("Code")) in {"404", "NoSuchKey", "NotFound"}


def try_download_norm_stats(s3_bucket: str, config_name: str, repo_id: str, local_path: pathlib.Path) -> bool:
    """Try to download cached norm stats from S3 into ``local_path``.

    Best-effort: this never raises. A missing object is logged at INFO as an expected
    cache miss; any other failure is logged at WARNING so that misconfiguration
    (missing boto3, non-s3 URL, bad credentials, network problems) is visible instead
    of looking like a routine cache miss.

    Args:
        s3_bucket: Remote checkpoint URL of the form ``s3://bucket[/prefix]``.
        config_name: Training config name (part of the S3 key).
        repo_id: Dataset repo id (part of the S3 key).
        local_path: Directory that receives ``norm_stats.json``; it is not created here.

    Returns:
        True if ``norm_stats.json`` was downloaded, False otherwise.
    """
    try:
        import boto3

        bucket, key = _s3_norm_stats_key(s3_bucket, config_name, repo_id)
        local_file = local_path / "norm_stats.json"
        logger.info(f"Checking S3 for norm stats: s3://{bucket}/{key}")
        s3 = boto3.client("s3")
        s3.download_file(bucket, key, str(local_file))
        logger.info(f"Downloaded norm stats from S3 to {local_file}")
        return True
    # Best-effort cache lookup: every failure falls back to computing the stats locally.
    except Exception as e:
        if _is_s3_not_found(e):
            # Expected on the first run for this config/dataset pair.
            logger.info(f"Norm stats not found on S3 (will compute): {e}")
        else:
            logger.warning(f"Could not fetch norm stats from S3 (will compute): {e!r}")
        return False


def upload_norm_stats(s3_bucket: str, config_name: str, repo_id: str, local_path: pathlib.Path) -> None:
    """Push ``local_path/norm_stats.json`` to the S3 norm-stats cache.

    Best-effort: any exception raised while importing boto3, building the key, or
    uploading is logged as a warning and swallowed, so the caller is never
    interrupted by an upload failure.

    Args:
        s3_bucket: Remote checkpoint URL of the form ``s3://bucket[/prefix]``.
        config_name: Training config name (part of the S3 key).
        repo_id: Dataset repo id (part of the S3 key).
        local_path: Directory containing the ``norm_stats.json`` to upload.
    """
    try:
        import boto3

        bucket, key = _s3_norm_stats_key(s3_bucket, config_name, repo_id)
        source_file = local_path / "norm_stats.json"
        logger.info(f"Uploading norm stats to s3://{bucket}/{key}")
        client = boto3.client("s3")
        client.upload_file(str(source_file), bucket, key)
        logger.info("Norm stats uploaded to S3")
    except Exception as exc:
        logger.warning(f"Failed to upload norm stats to S3: {exc}")


class RemoveStrings(transforms.DataTransformFn):
    """Drop every entry whose value, viewed as a NumPy array, has a unicode-string dtype.

    Remaining entries are passed through unchanged and keep their original order.
    """

    def __call__(self, x: dict) -> dict:
        kept = {}
        for name, value in x.items():
            if np.issubdtype(np.asarray(value).dtype, np.str_):
                # String fields are skipped entirely.
                continue
            kept[name] = value
        return kept


def create_torch_dataloader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    model_config: _model.BaseModelConfig,
    num_workers: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.TorchDataLoader, int]:
    """Build a torch data loader for computing norm stats when no ``rlds_data_dir`` is set.

    Samples are passed through the config's repack and data input transforms, and
    string-valued entries are then dropped.

    Args:
        data_config: Data config; must have a ``repo_id``.
        action_horizon: Action horizon forwarded to ``create_torch_dataset``.
        batch_size: Samples per batch.
        model_config: Model config forwarded to ``create_torch_dataset``.
        num_workers: Number of loader workers.
        max_frames: If set and smaller than the dataset length, only
            ``max_frames // batch_size`` batches are drawn and ``shuffle=True`` is
            passed to the loader, so the cap samples across the dataset.

    Returns:
        ``(data_loader, num_batches)``. ``num_batches`` may be 0 if fewer frames
        than one batch are requested.

    Raises:
        ValueError: If ``data_config.repo_id`` is None.
    """
    if data_config.repo_id is None:
        raise ValueError("Data config must have a repo_id")
    dataset = _data_loader.create_torch_dataset(data_config, action_horizon, model_config)
    dataset = _data_loader.TransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
            RemoveStrings(),
        ],
    )
    if max_frames is not None and max_frames < len(dataset):
        num_batches = max_frames // batch_size
        shuffle = True
    else:
        num_batches = len(dataset) // batch_size
        shuffle = False
    data_loader = _data_loader.TorchDataLoader(
        dataset,
        local_batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        num_batches=num_batches,
    )
    return data_loader, num_batches


def create_rlds_dataloader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.RLDSDataLoader, int]:
    """Build a batched, iterable data loader for norm stats when ``rlds_data_dir`` is set.

    The dataset comes from ``create_rlds_dataset`` (requested with ``shuffle=False``)
    and is already batched. The config's repack and data input transforms are applied
    per batch, and string-valued entries are then dropped.

    Args:
        data_config: Data config providing the dataset location and input transforms.
        action_horizon: Action horizon forwarded to ``create_rlds_dataset``.
        batch_size: Batch size forwarded to ``create_rlds_dataset``.
        max_frames: If set and smaller than ``len(dataset)``, only
            ``max_frames // batch_size`` batches are read. Unlike
            ``create_torch_dataloader``, no shuffling is requested here, so the cap
            presumably takes the start of the stream.

    Returns:
        ``(data_loader, num_batches)``. ``num_batches`` is passed to the loader and
        may be 0 if fewer frames than one batch are available or requested.
    """
    dataset = _data_loader.create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=False)
    dataset = _data_loader.IterableTransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
            RemoveStrings(),
        ],
        is_batched=True,
    )
    if max_frames is not None and max_frames < len(dataset):
        num_batches = max_frames // batch_size
    else:
        # NOTE: this length is currently hard-coded for DROID.
        num_batches = len(dataset) // batch_size
    data_loader = _data_loader.RLDSDataLoader(
        dataset,
        num_batches=num_batches,
    )
    return data_loader, num_batches


def create_mds_dataloader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.RLDSDataLoader, int]:
    """Build a batched, iterable data loader over an MDS dataset for norm stats.

    The dataset comes from ``create_mds_dataset`` and is already batched. The
    config's repack and data input transforms are applied per batch, and
    string-valued entries are then dropped. The result is wrapped in the same
    ``RLDSDataLoader`` used by ``create_rlds_dataloader``.

    Args:
        data_config: Data config providing the dataset location and input transforms.
        action_horizon: Action horizon forwarded to ``create_mds_dataset``.
        batch_size: Batch size forwarded to ``create_mds_dataset``.
        max_frames: If set and smaller than ``len(dataset)``, only
            ``max_frames // batch_size`` batches are read.

    Returns:
        ``(data_loader, num_batches)``. ``num_batches`` is passed to the loader and
        may be 0 if fewer frames than one batch are available or requested.
    """
    dataset = _data_loader.create_mds_dataset(data_config, action_horizon, batch_size)
    dataset = _data_loader.IterableTransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
            RemoveStrings(),
        ],
        is_batched=True,
    )
    if max_frames is not None and max_frames < len(dataset):
        num_batches = max_frames // batch_size
    else:
        num_batches = len(dataset) // batch_size
    data_loader = _data_loader.RLDSDataLoader(
        dataset,
        num_batches=num_batches,
    )
    return data_loader, num_batches

def main(config_name: str, max_frames: int | None = None, dry_run: bool = False):
    """Compute and save normalization statistics for a training config.

    If the config sets ``remote_checkpoint_dir``, cached stats are fetched from S3
    first (skipping computation when found), and freshly computed stats are uploaded
    there afterwards.

    Args:
        config_name: Name of the training config (e.g. 'pi05_yam').
        max_frames: Cap on the number of dataset frames to process.
        dry_run: If True, validate the data pipeline (read 3 batches, print
                 shapes) without computing or saving any statistics.

    Raises:
        ValueError: If the data loader would yield no batches (for example,
            ``max_frames`` smaller than the batch size), since no statistics could be
            computed.
    """
    config = _config.get_config(config_name)
    data_config = config.data.create(config.assets_dirs, config.model)

    if dry_run:
        print("DRY RUN: validating data pipeline — no stats will be saved", flush=True)
        _dry_run_data_pipeline(config, data_config)
        return

    output_path = config.assets_dirs / data_config.repo_id
    output_path.mkdir(parents=True, exist_ok=True)

    # Try to download pre-computed norm stats from S3.
    # NOTE: the S3 key depends only on config name and repo_id, not on max_frames, so
    # stats computed on a capped subset are shared with uncapped runs (and vice versa).
    s3_bucket = config.remote_checkpoint_dir
    if s3_bucket and try_download_norm_stats(s3_bucket, config.name, data_config.repo_id, output_path):
        print(f"Using cached norm stats from S3 at: {output_path}")
        return

    # Compute norm stats from scratch.
    # Both WDS and MDS set rlds_data_dir; pure LeRobot configs leave it None.
    if data_config.rlds_data_dir is not None:
        data_loader, num_batches = create_rlds_dataloader(
            data_config, config.model.action_horizon, config.batch_size, max_frames
        )
    else:
        data_loader, num_batches = create_torch_dataloader(
            data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames
        )

    # Fail fast with an actionable message instead of iterating an empty loader and
    # failing later when statistics are requested from empty accumulators.
    if num_batches <= 0:
        raise ValueError(
            f"No batches available to compute norm stats (batch_size={config.batch_size}, "
            f"max_frames={max_frames}); increase max_frames or reduce the batch size."
        )

    keys = ["state", "actions"]
    running_stats = {key: normalize.RunningStats() for key in keys}

    for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
        for key in keys:
            running_stats[key].update(np.asarray(batch[key]))
    norm_stats = {key: key_stats.get_statistics() for key, key_stats in running_stats.items()}

    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)

    # Upload to S3 for future runs.
    if s3_bucket:
        upload_norm_stats(s3_bucket, config.name, data_config.repo_id, output_path)


def _dry_run_data_pipeline(config, data_config, num_batches: int = 3) -> None:
    """Iterate a few batches through the norm-stats pipeline and print their shapes.

    Uses the same loader selection as ``main``: the RLDS-style loader when
    ``data_config.rlds_data_dir`` is set, the torch loader otherwise. Loading is
    capped at ``num_batches * config.batch_size`` frames.

    Args:
        config: Training config; ``model.action_horizon``, ``batch_size``, ``model``
            and ``num_workers`` are read from it.
        data_config: Data config built from ``config``.
        num_batches: Maximum number of batches to read.

    Raises:
        RuntimeError: If the pipeline yields no batches at all (e.g. a dataset shorter
            than one batch), which would otherwise be misreported as OK.
        Any error raised while building or iterating the pipeline propagates unchanged.
    """
    max_frames = num_batches * config.batch_size
    # Both WDS and MDS set rlds_data_dir; pure LeRobot configs leave it None.
    if data_config.rlds_data_dir is not None:
        data_loader, _ = create_rlds_dataloader(
            data_config, config.model.action_horizon, config.batch_size, max_frames=max_frames
        )
    else:
        data_loader, _ = create_torch_dataloader(
            data_config,
            config.model.action_horizon,
            config.batch_size,
            config.model,
            config.num_workers,
            max_frames=max_frames,
        )

    batches_read = 0
    for batch in data_loader:
        print(f"  batch {batches_read}:", flush=True)
        for key in ("state", "actions"):
            if key in batch:
                arr = np.asarray(batch[key])
                print(f"    {key}: shape={arr.shape}  dtype={arr.dtype}", flush=True)
        if "image" in batch:
            for cam, img in batch["image"].items():
                print(f"    image/{cam}: shape={np.asarray(img).shape}", flush=True)
        batches_read += 1
        if batches_read >= num_batches:
            break

    if batches_read == 0:
        raise RuntimeError("DRY RUN: data pipeline produced no batches")
    print("DRY RUN: data pipeline OK", flush=True)


if __name__ == "__main__":
    # Expose main() as a command-line interface via tyro.
    tyro.cli(main)