"""Smoke-test an NCCL broadcast of a ~70M-element tensor (~280 MB fp32).

Launch under torchrun so each process gets RANK/LOCAL_RANK, e.g.:
    torchrun --nproc_per_node=<num_gpus> this_script.py
"""
import os
from datetime import timedelta

import torch
import torch.distributed as dist


def main() -> None:
    # Pin this process to its GPU BEFORE creating the process group:
    # NCCL builds its communicator on the current device, and without
    # set_device every rank on a node defaults to cuda:0, which can
    # hang the collective or OOM the first GPU.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    device = torch.device(f'cuda:{local_rank}')

    # 120 s timeout so a wedged rank fails fast instead of the NCCL
    # default (~30 min).
    dist.init_process_group('nccl', timeout=timedelta(seconds=120))
    rank = dist.get_rank()
    try:
        # Test with large tensor (similar to model params)
        t = torch.randn(70_000_000, device=device)
        print(f'Rank {rank}: broadcasting 70M params...', flush=True)
        dist.broadcast(t, src=0)
        print(f'Rank {rank}: broadcast done, sum={t.sum().item():.4f}', flush=True)
    finally:
        # Tear the group down even if the collective fails, so peers
        # are not left waiting on a dead rank's communicator.
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
