1. 初识Weights & Biases(wandb)
wandb(Weights & Biases)已经成为机器学习实验管理的行业标准工具之一。作为一名长期从事深度学习研究的工程师,我亲身体验过从手动记录Excel到使用TensorBoard,再到全面迁移到wandb的完整历程。这个工具彻底改变了我的实验管理方式。
wandb的核心价值在于它解决了机器学习研究中的几个关键痛点:
- 实验可复现性:通过自动记录超参数、代码版本和环境依赖,确保任何实验都能被精确复现
- 实时监控:训练过程中的指标变化可以实时查看,无需等待训练结束
- 团队协作:实验结果可以轻松分享给团队成员,支持评论和讨论
- 资源管理:自动跟踪GPU/CPU使用情况,帮助优化计算资源分配
与TensorBoard相比,wandb的独特优势在于:
- 云端存储:实验结果自动同步到云端,不受本地机器限制
- 可视化灵活:支持自定义面板布局,可以自由组合各种图表
- 项目管理:实验可以按项目组织,支持跨实验比较
- 生态系统:与主流ML框架(PyTorch、TensorFlow等)深度集成
提示:wandb免费版对个人用户已经足够强大,支持500GB的存储空间和100GB的模型存储,对于大多数研究项目来说完全够用。
2. wandb基础使用指南
2.1 安装与配置
wandb的安装非常简单,只需要一个pip命令:
bash复制pip install wandb
安装完成后需要进行登录认证。这里有两种方式:
- 交互式登录(推荐):
bash复制wandb login
执行后会提示输入API key,可以在wandb官网的个人设置页面找到。
- 非交互式登录(适合脚本自动化):
bash复制wandb login --relogin YOUR_API_KEY
验证登录状态:
bash复制wandb status
2.2 核心概念解析
理解wandb的几个核心概念对高效使用至关重要:
- Run(运行):代表一次完整的实验过程,包含从开始到结束的所有记录数据
- Project(项目):相关实验的集合,相当于一个文件夹
- Sweep(扫描):用于超参数搜索的自动化工具
- Artifacts(工件):版本化的数据集、模型等大型文件
2.3 基础使用模式
一个典型的wandb使用流程包含以下步骤:
python复制import wandb
# 1. 初始化一个Run
wandb.init(
project="my-project", # 项目名称
name="experiment-1", # 实验名称
config={ # 超参数配置
"learning_rate": 0.001,
"batch_size": 32,
"epochs": 10
}
)
# 2. 获取配置(方便代码中使用)
config = wandb.config
# 3. 在训练循环中记录指标
for epoch in range(config.epochs):
# ...训练代码...
wandb.log({
"loss": train_loss,
"accuracy": train_acc,
"val_loss": val_loss,
"val_acc": val_acc
})
# 4. 可选:保存模型
torch.save(model.state_dict(), "model.pth")
wandb.save("model.pth") # 上传到wandb
# 5. 结束Run
wandb.finish()
3. 实战MNIST分类任务
让我们通过一个完整的MNIST分类示例,展示wandb在实际项目中的应用。
3.1 项目初始化
python复制import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# 初始化wandb
wandb.init(
project="mnist-classification",
name="cnn-1",
config={
"learning_rate": 0.001,
"batch_size": 128,
"epochs": 10,
"dropout": 0.3,
"optimizer": "Adam"
}
)
config = wandb.config
3.2 数据准备
python复制# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 加载数据集
train_dataset = datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transform
)
test_dataset = datasets.MNIST(
root='./data',
train=False,
transform=transform
)
# 创建数据加载器
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=config.batch_size,
shuffle=False
)
3.3 模型定义
python复制class CNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout = nn.Dropout(config.dropout)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = torch.relu(self.conv1(x))
x = torch.max_pool2d(x, 2)
x = torch.relu(self.conv2(x))
x = torch.max_pool2d(x, 2)
x = self.dropout(x)
x = torch.flatten(x, 1)
x = torch.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
model = CNN().to(device)
wandb.watch(model) # 自动跟踪模型梯度
3.4 训练循环
python复制criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
for epoch in range(1, config.epochs + 1):
model.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
train_loss += loss.item()
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
total += target.size(0)
# 计算训练指标
avg_train_loss = train_loss / len(train_loader)
train_acc = 100. * correct / total
# 验证阶段
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item()
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
avg_test_loss = test_loss / len(test_loader)
test_acc = 100. * correct / len(test_loader.dataset)
# 记录指标
wandb.log({
"epoch": epoch,
"train_loss": avg_train_loss,
"train_acc": train_acc,
"test_loss": avg_test_loss,
"test_acc": test_acc
})
print(f"Epoch {epoch}: Train Acc {train_acc:.2f}%, Test Acc {test_acc:.2f}%")
# 保存模型
torch.save(model.state_dict(), "mnist_cnn.pth")
wandb.save("mnist_cnn.pth")
wandb.finish()
4. 高级功能与NavRL项目实战
4.1 自定义指标记录
在NavRL这类强化学习项目中,我们通常需要跟踪更多自定义指标。以下是如何在wandb中实现:
python复制# 在环境类中定义统计指标
stats_spec = {
"return": UnboundedContinuousTensorSpec(1),
"episode_len": UnboundedContinuousTensorSpec(1),
"reach_goal": UnboundedContinuousTensorSpec(1),
"collision": UnboundedContinuousTensorSpec(1),
# 各种调试奖励和惩罚
"debug_reward_vel": UnboundedContinuousTensorSpec(1),
"debug_reward_facing": UnboundedContinuousTensorSpec(1),
"debug_penalty_smooth": UnboundedContinuousTensorSpec(1),
"debug_penalty_yaw": UnboundedContinuousTensorSpec(1),
"debug_heading_error": UnboundedContinuousTensorSpec(1),
"debug_current_speed": UnboundedContinuousTensorSpec(1)
}
# 在训练循环中收集并记录指标
episode_stats = EpisodeStats(episode_stats_keys)
for i, data in enumerate(collector):
info = {
"env_frames": collector._frames,
"rollout_fps": collector._fps
}
# 训练模型并获取损失
train_loss_stats = policy.train(data)
info.update(train_loss_stats)
# 收集环境统计量
episode_stats.add(data)
if len(episode_stats) >= transformed_env.num_envs:
stats = {
"train/" + ".".join(k): torch.mean(v.float()).item()
for k, v in episode_stats.pop().items(True, True)
}
info.update(stats)
# 记录到wandb
run.log(info)
4.2 断点续训功能
wandb支持从特定运行ID恢复训练,这对于长时间运行的强化学习实验特别有用:
python复制if cfg.wandb.run_id is None:
run = wandb.init(
project=cfg.wandb.project,
name=f"{cfg.wandb.name}/{datetime.datetime.now().strftime('%m-%d_%H-%M')}",
entity=cfg.wandb.entity,
config=cfg,
mode=cfg.wandb.mode,
id=wandb.util.generate_id(),
)
else:
run = wandb.init(
project=cfg.wandb.project,
entity=cfg.wandb.entity,
resume="must",
id=cfg.wandb.run_id
)
4.3 离线模式使用
当在没有网络连接的环境中训练时,可以使用wandb的离线模式:
python复制wandb.init(mode="offline")
# ...训练代码...
wandb.finish()
训练完成后,可以在有网络的环境中将结果同步到wandb服务器:
bash复制wandb sync /path/to/offline/runs
5. 最佳实践与经验分享
5.1 命名规范建议
- 项目名称:简洁明确,如"drl-navigation"、"image-classification"
- 运行名称:包含关键信息,如"ppo-lr0.0001-bs2048"或"resnet50-adam"
- 指标名称:使用一致的命名约定,如"train/loss"和"val/loss"
5.2 高效使用技巧
- 自定义可视化:在wandb面板中可以创建自定义图表布局,将相关指标放在一起比较
- 超参数搜索:使用wandb sweep功能进行自动化超参数优化
- 模型版本控制:利用Artifacts功能管理不同版本的模型权重
- 团队协作:通过共享项目和添加评论与团队成员高效协作
5.3 常见问题解决
-
指标没有显示:
- 确保调用了wandb.log()
- 检查指标名称是否包含特殊字符
- 确认wandb.init()和wandb.finish()配对使用
-
上传速度慢:
- 减少日志频率
- 使用wandb.save()只上传关键模型检查点
- 考虑先使用离线模式,训练完成后再同步
-
内存占用高:
- 限制wandb.watch()的log_freq参数
- 减少记录的图像/视频数量
- 使用wandb.config而不是直接记录大型字典
在实际项目中,我发现wandb特别适合以下场景:
- 需要长时间运行的强化学习训练
- 涉及大量超参数组合的实验
- 团队协作开发机器学习模型
- 需要精确复现的实验记录
通过合理使用wandb的各种功能,我的实验管理效率提升了至少50%,调试时间减少了约70%。特别是在NavRL这类复杂项目中,能够清晰地看到各个奖励分量的变化趋势,对于调试奖励函数设计帮助极大。