import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models, transforms
def setup(rank, world_size):
    """Initialize the distributed training environment for one worker.

    Args:
        rank: Index of this process (one process per GPU).
        world_size: Total number of participating processes.
    """
    # Respect rendezvous settings already provided by a launcher
    # (e.g. torchrun); only fall back to localhost defaults when unset.
    # The original code unconditionally overwrote these, which broke
    # multi-node launches.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    # NCCL is the recommended backend for CUDA tensors.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
    """Tear down the distributed process group created by setup()."""
    dist.destroy_process_group()
def create_dummy_dataset(size=1000, num_classes=1000, image_shape=(3, 224, 224)):
    """Create a synthetic classification dataset for demonstration.

    Defaults mimic ImageNet: 224x224 RGB images with 1000 classes.

    Args:
        size: Number of samples to generate.
        num_classes: Labels are drawn uniformly from [0, num_classes).
        image_shape: Shape of each image tensor (channels, height, width).

    Returns:
        TensorDataset of (image, label) pairs.
    """
    images = torch.randn(size, *image_shape)
    labels = torch.randint(0, num_classes, (size,))
    return TensorDataset(images, labels)
def train(rank, world_size, epochs=5):
    """Run one DDP worker: train ResNet-18 on synthetic data.

    Args:
        rank: GPU/process index for this worker.
        world_size: Total number of workers.
        epochs: Number of passes over the dataset.
    """
    print(f"Running DDP on rank {rank}.")
    setup(rank, world_size)
    try:
        # Pin this process to its own GPU.
        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")

        # Build the model and wrap it for gradient synchronization.
        # NOTE(review): `pretrained` is deprecated in recent torchvision;
        # migrate to `weights=None` once the minimum version allows it.
        model = models.resnet18(pretrained=False)
        model = model.to(device)
        ddp_model = DDP(model, device_ids=[rank])

        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)
        criterion = nn.CrossEntropyLoss()

        # The DistributedSampler gives each rank a disjoint shard of the data.
        dataset = create_dummy_dataset()
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, num_replicas=world_size, rank=rank
        )
        dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

        ddp_model.train()
        for epoch in range(epochs):
            # Reshuffle shards each epoch; without this every epoch sees
            # the samples in the same order.
            sampler.set_epoch(epoch)
            epoch_loss = 0.0
            for batch_idx, (data, target) in enumerate(dataloader):
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = ddp_model(data)
                loss = criterion(output, target)
                loss.backward()  # DDP all-reduces gradients during backward
                optimizer.step()
                epoch_loss += loss.item()
                if batch_idx % 10 == 0 and rank == 0:
                    print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
            if rank == 0:
                # Rank-0 local average only; not all-reduced across ranks.
                avg_loss = epoch_loss / len(dataloader)
                print(f"Epoch {epoch} 完成,平均损失: {avg_loss:.4f}")
    finally:
        # Always tear down the process group, even when training raises;
        # otherwise sibling workers block forever on the next collective.
        cleanup()
def main():
    """Spawn one DDP worker process per available GPU."""
    world_size = torch.cuda.device_count()
    if world_size == 0:
        # mp.spawn(nprocs=0) would silently do nothing; fail loudly instead.
        raise RuntimeError(
            "No CUDA GPUs available; DDP training requires at least one GPU."
        )
    print(f"使用 {world_size} 个GPU进行训练")
    # Each spawned process receives its rank as the implicit first argument.
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()