import torch
import torch.distributed as dist
import os

def main():
    """Smoke-test NCCL connectivity across all ranks of a distributed job.

    Each rank contributes its own rank id to a sum ``all_reduce`` and checks
    the result against ``0 + 1 + ... + (world - 1)``. The per-rank outcomes are
    then combined with a second ``all_reduce``, so rank 0 reports
    ``ALL RANKS PASSED`` only when every rank saw the expected value.

    The GPU is chosen from the ``LOCAL_RANK`` environment variable (default 0).
    The rendezvous settings consumed by ``init_process_group`` are presumably
    provided by the launcher (e.g. torchrun) — not visible here.

    Raises:
        RuntimeError: if any rank observed a wrong ``all_reduce`` result, so the
            launcher sees a non-zero exit instead of a false "PASSED".
    """
    from datetime import timedelta

    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    device = torch.device(f'cuda:{local_rank}')
    # Bind this process to its GPU before any NCCL work. Otherwise collectives and
    # barrier() use the current/guessed device (often cuda:0 or cuda:<rank>), which
    # is wrong on multi-node jobs where rank != local_rank.
    torch.cuda.set_device(device)

    # Short timeout so a broken interconnect fails fast instead of hanging.
    dist.init_process_group('nccl', timeout=timedelta(seconds=30))
    try:
        rank = dist.get_rank()
        world = dist.get_world_size()
        print(f'Rank {rank}/{world} on {device} ({torch.cuda.get_device_name(device)})', flush=True)

        t = torch.tensor([rank], device=device, dtype=torch.float32)
        dist.all_reduce(t)
        expected = sum(range(world))
        print(f'Rank {rank}: all_reduce = {t.item()} (expected {expected})', flush=True)
        # Small integer sums are exactly representable in float32, so == is safe here.
        ok = t.item() == expected

        # Count failing ranks job-wide. This collective (plus the .item() host sync)
        # is also the synchronization point the previous barrier() provided.
        failures = torch.tensor([0 if ok else 1], device=device, dtype=torch.int32)
        dist.all_reduce(failures)
        n_failed = int(failures.item())

        if n_failed:
            if rank == 0:
                print(f'{n_failed}/{world} RANKS FAILED', flush=True)
            raise RuntimeError(
                f'Rank {rank}: all_reduce check failed on {n_failed} rank(s)'
            )
        if rank == 0:
            print('ALL RANKS PASSED', flush=True)
    finally:
        # Always tear down the process group, even when the check fails.
        dist.destroy_process_group()

# Script entry point: run the smoke test only when this file is executed
# directly (e.g. by a per-rank launcher), never on import.
if __name__ == "__main__":
    main()
