import torch
import torch.distributed as dist
# Minimal NCCL sanity check: every rank contributes its own rank id, and
# all_reduce (default op = SUM) must therefore produce sum(range(world_size)).
dist.init_process_group('nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()

# Map ranks onto local GPUs round-robin; set_device makes this GPU the
# current device so the NCCL communicator binds to it (placing the tensor
# alone is not sufficient on multi-GPU nodes).
device = torch.device(f'cuda:{rank % torch.cuda.device_count()}')
torch.cuda.set_device(device)

t = torch.tensor([rank], device=device, dtype=torch.float32)
dist.all_reduce(t)  # in-place SUM across all ranks

# BUG FIX: the expected value was hard-coded to 6, which is only correct
# for world_size == 4 (0+1+2+3). Derive it from the actual world size.
expected = world_size * (world_size - 1) // 2
print(f'Rank {rank}: all_reduce result = {t.item()} (expected {expected})', flush=True)

dist.destroy_process_group()
