当你第一次听说PyTorch分布式训练时,可能会觉得这是个高大上的概念。其实简单来说,分布式训练就是让多台机器或多个GPU一起协作完成模型训练任务。想象一下,如果训练一个大型模型需要10天时间,用4台机器同时训练可能只需要2-3天,这就是分布式的魅力所在。
torch.distributed.launch是PyTorch提供的一个非常实用的工具,它帮我们处理了分布式训练中最麻烦的部分——进程管理和环境变量设置。在实际项目中,我发现很多同学卡在分布式训练的第一步,就是因为不熟悉这个启动工具的参数配置。
先来看个最简单的例子,单机多卡训练:
bash复制python -m torch.distributed.launch --nproc_per_node=4 train.py
这条命令告诉PyTorch:在当前机器上启动4个进程,每个进程使用一张GPU来运行train.py脚本。launch工具会自动为每个进程设置好所需的环境变量,包括RANK、LOCAL_RANK等。
在分布式训练中,nnodes和node_rank是两个最基础的参数。nnodes指定了参与训练的机器总数,node_rank则是当前机器的编号(从0开始)。
比如我们在两个机器上训练,配置应该是这样的:
bash复制# 机器0(主节点)
python -m torch.distributed.launch --nnodes=2 --node_rank=0 --master_addr="192.168.1.100" --nproc_per_node=4 train.py
# 机器1
python -m torch.distributed.launch --nnodes=2 --node_rank=1 --master_addr="192.168.1.100" --nproc_per_node=4 train.py
这里有个坑我踩过多次:master_addr必须指向主节点的IP,而且所有节点上的master_addr必须一致。有一次训练卡住半小时没反应,最后发现是第二个节点的master_addr写成了自己的IP。
master_port参数指定了主节点监听的端口号,默认是29500。如果多组训练任务在同一批机器上运行,一定要设置不同的端口号,否则会出现端口冲突。
实际项目中我常用这样的配置:
bash复制--master_port=29501
nproc_per_node决定了每个机器上使用多少个GPU。这个数字应该小于等于机器上的实际GPU数量。我曾经犯过一个错误:在只有8卡的机器上设置了--nproc_per_node=10,结果训练直接报错退出。
当launch启动训练脚本时,它会自动设置以下环境变量:
在代码中可以通过os.environ获取这些值:
python复制import os
world_size = int(os.environ['WORLD_SIZE'])
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
这些环境变量在分布式训练中非常有用。比如数据分片:
python复制dataset = MyDataset()
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
dataloader = DataLoader(dataset, sampler=sampler)
还有模型保存时只让rank 0进程执行:
python复制if rank == 0:
torch.save(model.state_dict(), 'model.pth')
PyTorch 2.0开始,官方推荐使用torchrun替代原来的launch.py。新方法简化了不少参数配置:
bash复制# 旧方式(已废弃)
python -m torch.distributed.launch --nproc_per_node=4 train.py
# 新方式
torchrun --nproc_per_node=4 train.py
最大的变化是--use_env现在默认启用,意味着LOCAL_RANK等参数必须从环境变量获取,而不是命令行参数。
问题1:训练卡住不开始
检查所有节点是否都启动了训练,特别是nnodes设置是否正确。分布式训练要求所有节点都就绪才会开始。
问题2:CUDA out of memory
确保nproc_per_node设置合理。可以尝试减小batch size或使用更少的GPU。
问题3:端口冲突
更换master_port值,确保没有其他程序占用该端口。
最后分享一个实用的调试技巧:在训练脚本开头打印所有环境变量,这能帮你快速定位问题:
python复制import os
print("=== Environment Variables ===")
for k, v in os.environ.items():
if 'RANK' in k or 'LOCAL' in k or 'WORLD' in k:
print(f"{k}: {v}")