#!/usr/bin/env python3
"""Remove outlier actions from HDF5 demonstration datasets.

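Scans each demo's `action/droid_joint_pos_action` dataset for bad timesteps, i.e.:
- Any of the first 7 (arm) action values exceeding the magnitude threshold
- NaN or Inf in any action value

Each affected demo is truncated at its first bad timestep, or deleted entirely
if that timestep falls before --min-timesteps.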

Usage:
    python remove_outliers.py --hdf5-file path/to/dataset.hdf5
    python remove_outliers.py --hdf5-file path/to/dataset.hdf5 --threshold 10.0
    python remove_outliers.py --hdf5-file path/to/dataset.hdf5 --dry-run  # Preview only
"""

import argparse
import h5py
import numpy as np
from pathlib import Path


def get_all_datasets(group, prefix=''):
    """Collect the paths of every dataset below an HDF5 group, depth-first.

    Args:
        group: The h5py group to walk.
        prefix: Path prepended (with '/') to every returned name; empty means
            paths are relative to ``group``.

    Returns:
        List of '/'-joined dataset paths. Sub-groups are descended into; other
        member kinds are skipped.
    """
    found = []
    for name in group.keys():
        member = group[name]
        child_path = name if not prefix else f"{prefix}/{name}"
        if isinstance(member, h5py.Dataset):
            found.append(child_path)
        elif isinstance(member, h5py.Group):
            found += get_all_datasets(member, child_path)
    return found


def _max_abs_ignoring_nan(values):
    """Return max(|values|) skipping NaNs, or NaN if nothing non-NaN remains.

    Unlike ``np.nanmax`` this does not warn on an all-NaN input or raise on an
    empty one. Rows flagged only because of NaNs therefore report ``nan``.
    """
    magnitudes = np.abs(np.asarray(values)).ravel()
    magnitudes = magnitudes[~np.isnan(magnitudes)]
    return magnitudes.max() if magnitudes.size else float('nan')


def _truncate_dataset(parent, ds_path, new_len):
    """Keep only the first ``new_len`` rows of dataset ``ds_path`` under ``parent``.

    A dataset that is resizable along axis 0 is shrunk in place. Any other
    dataset is re-created with the same dtype, compression filters and
    attributes. A plain ``create_dataset(data=...)`` would silently drop all of
    those, and would mis-handle variable-length string data.
    """
    dataset = parent[ds_path]
    maxshape = dataset.maxshape
    if maxshape and maxshape[0] is None:
        dataset.resize(new_len, axis=0)
        return

    data = dataset[:new_len]
    attrs = dict(dataset.attrs)
    create_kwargs = {
        'dtype': dataset.dtype,
        'compression': dataset.compression,
        'compression_opts': dataset.compression_opts,
        'shuffle': dataset.shuffle,
        'fletcher32': dataset.fletcher32,
    }
    del parent[ds_path]
    new_dataset = parent.create_dataset(ds_path, data=data, **create_kwargs)
    for attr_name, attr_value in attrs.items():
        new_dataset.attrs[attr_name] = attr_value


def _truncate_demo(demo, n_steps, keep):
    """Cut every per-timestep dataset in ``demo`` down to the first ``keep`` steps.

    Datasets are classified by the length of their first axis:
    - ``n_steps`` rows (per-action data) are cut to ``keep`` rows.
    - ``n_steps + 1`` rows (presumably states that include the final one) are
      cut to ``keep + 1`` rows.
    - Anything else, including scalar datasets, is left untouched.
    """
    for ds_path in get_all_datasets(demo):
        shape = demo[ds_path].shape
        if not shape:  # scalar dataset: there is no time axis to cut
            continue
        if shape[0] == n_steps:
            _truncate_dataset(demo, ds_path, keep)
        elif shape[0] == n_steps + 1:
            _truncate_dataset(demo, ds_path, keep + 1)
    # robomimic / Isaac Lab style files record the demo length here; keep it in sync.
    if 'num_samples' in demo.attrs:
        demo.attrs['num_samples'] = int(keep)


def remove_outliers_inplace(hdf5_path: str, threshold: float = 5.0, dry_run: bool = True, min_timesteps: int = 10):
    """
    Remove demos with outlier actions or NaN/Inf values from the HDF5 file.
    Truncates at the first outlier timestep, or deletes the entire demo if outliers are too early.

    A timestep is "bad" if any of the first 7 columns of
    ``action/droid_joint_pos_action`` exceeds ``threshold`` in magnitude, or if
    any action value is NaN/Inf. Demos without that dataset are skipped and are
    not counted in the summary.

    Note: HDF5 does not reclaim the space of deleted or re-created datasets.
    Run ``h5repack`` afterwards if the file size matters.

    Args:
        hdf5_path: Path to the HDF5 file
        threshold: Action values beyond this magnitude are considered outliers
        dry_run: If True, only print what would be done without modifying the file
        min_timesteps: Minimum timesteps required before first outlier (demos with earlier outliers are deleted).
            A demo whose very first step is bad is always deleted, even when this is 0,
            so no empty demo is ever left behind.
    """
    action_key = 'action/droid_joint_pos_action'
    mode = 'r' if dry_run else 'r+'
    # Truncating at t=0 would leave a zero-length demo, so require at least one clean step.
    delete_before = max(min_timesteps, 1)

    print(f"{'[DRY RUN] ' if dry_run else ''}Processing: {hdf5_path}")
    print(f"Threshold: {threshold}, Min timesteps: {min_timesteps}")
    print("-" * 60)

    with h5py.File(hdf5_path, mode) as f:
        data_group = f['data']
        demo_keys = sorted([k for k in data_group.keys() if k.startswith('demo_')],
                           key=lambda x: int(x.split('_')[1]))

        print(f"Total demos: {len(demo_keys)}")

        truncated_count = 0
        deleted_count = 0
        clean_count = 0

        for demo_key in demo_keys:
            demo = data_group[demo_key]
            if action_key not in demo:
                continue

            actions = np.array(demo[action_key])
            # Only the first 7 columns are range-checked (presumably the arm joints);
            # NaN/Inf is checked across every column.
            arm_actions = actions[:, :7]

            outlier_mask = np.any(np.abs(arm_actions) > threshold, axis=1)
            nonfinite_mask = np.any(~np.isfinite(actions), axis=1)
            bad_mask = outlier_mask | nonfinite_mask

            if not np.any(bad_mask):
                clean_count += 1
                continue

            first_bad_idx = int(np.flatnonzero(bad_mask)[0])
            max_bad_value = _max_abs_ignoring_nan(arm_actions[bad_mask])

            # Bad data too early to leave a useful prefix: drop the whole demo.
            if first_bad_idx < delete_before:
                print(f"  {demo_key}: DELETE (bad at t={first_bad_idx}, max={max_bad_value:.2e})")
                deleted_count += 1
                if not dry_run:
                    del data_group[demo_key]
                continue

            # Otherwise keep only the clean prefix [0, first_bad_idx).
            n_steps = len(actions)
            n_to_remove = n_steps - first_bad_idx
            print(f"  {demo_key}: TRUNCATE at t={first_bad_idx} (remove {n_to_remove}/{n_steps}, max={max_bad_value:.2e})")
            truncated_count += 1

            if not dry_run:
                _truncate_demo(demo, n_steps, first_bad_idx)

        # Keep the file-level sample count consistent when it is derivable from per-demo
        # `num_samples` (robomimic / Isaac Lab convention — NOTE(review): confirm for this dataset).
        if not dry_run and 'total' in data_group.attrs:
            counts = [data_group[k].attrs.get('num_samples') for k in data_group.keys() if k.startswith('demo_')]
            if all(c is not None for c in counts):
                data_group.attrs['total'] = int(sum(counts))

        print("-" * 60)
        print(f"Clean demos: {clean_count}")
        print(f"Truncated demos: {truncated_count}")
        print(f"Deleted demos: {deleted_count}")

        if dry_run:
            print("\n[DRY RUN] No changes made. Run without --dry-run to apply changes.")
        else:
            print("\n✅ Changes applied successfully!")


def verify_clean(hdf5_path: str, threshold: float = 5.0):
    """Check an HDF5 file for any remaining outlier or non-finite actions.

    Only demos that contain ``action/droid_joint_pos_action`` are inspected.
    A demo counts as an outlier if any of its first 7 action columns exceeds
    ``threshold`` in magnitude. It counts as non-finite if any action value is
    NaN or Inf.

    Returns:
        True when no inspected demo has outliers or NaN/Inf values, else False.
    """
    print(f"\nVerifying: {hdf5_path}")

    demos_with_outliers = 0
    demos_with_nonfinite = 0

    with h5py.File(hdf5_path, 'r') as h5:
        names = [name for name in h5['data'].keys() if name.startswith('demo_')]
        print(f"Total demos: {len(names)}")

        for name in names:
            group = h5[f'data/{name}']
            if 'action/droid_joint_pos_action' not in group:
                continue
            acts = np.array(group['action/droid_joint_pos_action'])
            demos_with_outliers += int(np.any(np.abs(acts[:, :7]) > threshold))
            demos_with_nonfinite += int(np.any(np.isnan(acts)) or np.any(np.isinf(acts)))

    print(f"Demos with outliers (|action| > {threshold}): {demos_with_outliers}")
    print(f"Demos with NaN/Inf: {demos_with_nonfinite}")

    is_clean = demos_with_outliers == 0 and demos_with_nonfinite == 0
    print("✅ All clean!" if is_clean else "❌ Still has issues!")
    return is_clean


def main():
    """Parse CLI arguments, then verify or clean the given HDF5 file.

    Returns:
        Process exit status. It is 0 on success. It is 1 if the path is not an
        existing file, or if verification (``--verify-only`` or the check run
        after applying changes) still finds outliers/NaN/Inf. Invalid arguments
        exit with status 2 via argparse.
    """
    import sys

    parser = argparse.ArgumentParser(description="Remove outlier actions from HDF5 demonstration datasets.")
    parser.add_argument("--hdf5-file", type=str, required=True, help="Path to the HDF5 file")
    parser.add_argument("--threshold", type=float, default=5.0, help="Action magnitude threshold (default: 5.0)")
    parser.add_argument("--min-timesteps", type=int, default=10, help="Min timesteps before first outlier (default: 10)")
    parser.add_argument("--dry-run", action="store_true", help="Preview changes without modifying the file")
    parser.add_argument("--verify-only", action="store_true", help="Only verify the file, don't modify")
    args = parser.parse_args()

    # A non-positive (or NaN) threshold flags essentially every action as an
    # outlier and would delete most of the dataset; reject it up front.
    if not args.threshold > 0:
        parser.error(f"--threshold must be a positive number, got {args.threshold}")
    if args.min_timesteps < 0:
        parser.error(f"--min-timesteps must be >= 0, got {args.min_timesteps}")

    hdf5_path = Path(args.hdf5_file).expanduser().resolve()

    # is_file() rather than exists(): a directory should fail here, not deep inside h5py.
    if not hdf5_path.is_file():
        print(f"Error: File not found: {hdf5_path}", file=sys.stderr)
        return 1

    if args.verify_only:
        return 0 if verify_clean(str(hdf5_path), args.threshold) else 1

    remove_outliers_inplace(str(hdf5_path), args.threshold, args.dry_run, args.min_timesteps)

    if args.dry_run:
        return 0

    # Surface a failed post-cleanup verification in the exit status.
    return 0 if verify_clean(str(hdf5_path), args.threshold) else 1


if __name__ == "__main__":
    # SystemExit is always available, unlike the site-module `exit()` helper
    # (missing under `python -S` and in some embedded/frozen interpreters).
    raise SystemExit(main())














