Debugging code late one night, staring at a photo of a tomato leaf covered in brown lesions, it suddenly hit me: the romance of agricultural AI is that we are using convolutional neural networks to read the "language" of plants. Behind the seemingly ordinary leaf images of the PlantVillage dataset lies a farmer's entire year of hoped-for harvest. This article walks you from zero to a PyTorch system that can recognize 30+ crop diseases. Along the way you will run into class-imbalance problems and GPU out-of-memory errors, but what you end up with is more than code that runs: it is a complete methodology for solving a real-world problem.
Sharpen your tools before you start. I recommend Google Colab Pro as the experiment environment: it provides T4 GPUs and can mount Google Drive directly for persistent data storage. These are the core components to install:
```bash
pip install torch==2.0.1 torchvision==0.15.2
pip install albumentations==1.3.1 kaggle==1.5.12
```
When downloading the PlantVillage dataset from Kaggle, there is a small trick that spares you the manual download: first create an API token in your Kaggle account, then run:
```python
# In a Colab / Jupyter cell
import os
os.environ['KAGGLE_USERNAME'] = 'your_username'
os.environ['KAGGLE_KEY'] = 'your_key'

!kaggle datasets download -d abdallahalidev/plantvillage-dataset
```
After unzipping you will see a directory structure like this:
```text
plantvillage/
├── color/
│   ├── Apple___Apple_scab/
│   ├── Apple___Black_rot/
│   └── ...          # 38 classes in total
└── grayscale/       # ignore the grayscale images
```
Note: the raw dataset is class-imbalanced; for some crops the healthy-leaf class has roughly three times as many samples as the disease classes. Start by running this quick analysis:
```python
from pathlib import Path

class_dist = {p.stem: len(list(p.glob('*.JPG')))
              for p in Path('plantvillage/color').iterdir()
              if p.is_dir()}
print(sorted(class_dist.items(), key=lambda x: x[1]))
```
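As a quick sanity check on that distribution, the worst-case imbalance ratio can be read straight out of `class_dist`. A small sketch, with made-up counts standing in for the real ones:

```python
# Hypothetical per-class counts standing in for the real class_dist above
class_dist = {
    'Tomato___healthy': 1591,
    'Tomato___Early_blight': 1000,
    'Potato___healthy': 152,
}

counts = sorted(class_dist.values())
imbalance_ratio = counts[-1] / counts[0]  # largest class vs. smallest class
print(f"imbalance ratio: {imbalance_ratio:.1f}x")
```

A ratio well above 1 is what motivates the class-balancing tricks discussed below.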
With limited agricultural image data, a smart augmentation strategy lets the model see far more "virtual disease" variation. I recommend the Albumentations library: it is noticeably faster than torchvision's transforms (around 30% in my tests) and supports more complex spatial transforms:
```python
import albumentations as A

train_transform = A.Compose([
    A.RandomResizedCrop(256, 256, scale=(0.8, 1.0)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.RandomRotate90(p=0.3),
    A.ColorJitter(brightness=0.2, contrast=0.2,
                  saturation=0.2, hue=0.1, p=0.5),
    A.CoarseDropout(max_holes=8, max_height=32,
                    max_width=32, fill_value=0, p=0.3),
    A.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])
```
Three practical tricks for handling class imbalance:

- A class-weighted loss with log-smoothed weights, e.g. `weight = 1 / log(1.2 + class_count)`, which boosts rare classes without letting the weights explode
- Oversampling rare classes with `torch.utils.data.WeightedRandomSampler`
- Applying heavier augmentation to the under-represented disease classes

The core of a custom Dataset class:
```python
import cv2
import torch
from pathlib import Path

class PlantDiseaseDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.classes = sorted([d.name for d in Path(root_dir).glob('*')])
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.samples = list(Path(root_dir).rglob('*.JPG'))
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path = self.samples[idx]
        img = cv2.imread(str(img_path))[:, :, ::-1]  # BGR → RGB
        label = self.class_to_idx[img_path.parent.name]
        if self.transform:
            augmented = self.transform(image=img)
            img = augmented['image']
        return img.transpose(2, 0, 1), label  # HWC → CHW
```
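The log-smoothed weighting trick mentioned above can be sketched in a few lines of plain Python. The counts here are made up; in practice the resulting values would be ordered by class index and passed to `nn.CrossEntropyLoss(weight=...)`:

```python
import math

# Hypothetical per-class image counts (the real ones come from class_dist)
class_counts = {'Apple___Apple_scab': 630,
                'Apple___healthy': 1645,
                'Potato___healthy': 152}

# weight = 1 / log(1.2 + class_count): rare classes get larger weights,
# while the log keeps the spread far gentler than plain 1 / class_count
weights = {cls: 1.0 / math.log(1.2 + n) for cls, n in class_counts.items()}
print(weights)
```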
ResNet50 is a solid starting point, but naively fine-tuning the whole network wastes GPU memory. My refinements: insert a CBAM attention module after the last residual stage, replace the classification head with Dropout + Linear, and train the pretrained backbone with a much smaller learning rate than the new layers (see the optimizer setup further down). The core implementation:
```python
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

class DiseaseClassifier(nn.Module):
    def __init__(self, num_classes=38):
        super().__init__()
        self.backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.attention = CBAM(2048)  # custom attention module, defined separately
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(2048, num_classes)
        )

    def forward(self, x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        x = self.attention(x)  # apply attention to the final feature map
        x = self.backbone.avgpool(x)
        x = torch.flatten(x, 1)
        return self.backbone.fc(x)
```
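The `CBAM(2048)` call above presumes a CBAM module defined elsewhere. Here is a minimal sketch of one, simplified from the CBAM paper's channel-then-spatial design; it matches the shapes used above but is not the author's exact module:

```python
import torch
import torch.nn as nn

class CBAM(nn.Module):
    """Minimal channel + spatial attention block (simplified sketch)."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        # Channel attention: pool to (b, c), excite through a bottleneck MLP
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
        )
        # Spatial attention: 7x7 conv over stacked channel-wise avg/max maps
        self.spatial = nn.Conv2d(2, 1, kernel_size=7, padding=3)

    def forward(self, x):
        b, c, _, _ = x.shape
        avg = self.mlp(x.mean(dim=(2, 3)))  # average-pooled channel descriptor
        mx = self.mlp(x.amax(dim=(2, 3)))   # max-pooled channel descriptor
        x = x * torch.sigmoid(avg + mx).view(b, c, 1, 1)
        s = torch.cat([x.mean(dim=1, keepdim=True),
                       x.amax(dim=1, keepdim=True)], dim=1)
        return x * torch.sigmoid(self.spatial(s))
```

Because the output shape equals the input shape, the module drops into the forward pass between `layer4` and `avgpool` without any other changes.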
Tip: when you hit "CUDA out of memory", besides shrinking batch_size you can also try:
- Gradient accumulation: update the parameters only once every 4 batches
- Mixed-precision training: `scaler = torch.cuda.amp.GradScaler()`
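Both tricks fit into a single training loop. A minimal, device-agnostic sketch (the toy model, fake loader, and `accum_steps` value are placeholders for your real ones; the scaler silently becomes a no-op without CUDA):

```python
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Toy stand-ins so the sketch runs anywhere; swap in your real model and loader
model = nn.Linear(10, 3).to(device)
loader = [(torch.randn(4, 10), torch.randint(0, 3, (4,))) for _ in range(8)]
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))
accum_steps = 4  # update parameters once every 4 batches

optimizer.zero_grad()
for step, (inputs, labels) in enumerate(loader):
    inputs, labels = inputs.to(device), labels.to(device)
    with torch.autocast(device_type=device, enabled=(device == 'cuda')):
        loss = criterion(model(inputs), labels) / accum_steps  # scale for accumulation
    scaler.scale(loss).backward()
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)  # unscales gradients, then optimizer.step()
        scaler.update()
        optimizer.zero_grad()
```

Dividing the loss by `accum_steps` keeps gradient magnitudes comparable to a single large batch.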
Agricultural image classification brings three distinctive challenges: background clutter, visual similarity between different diseases, and varying capture conditions. My approach:

Multi-stage learning-rate scheduling:
```python
# The fc head is excluded from the backbone group: a parameter may not appear
# in two groups, and model.backbone.fc gets its own (higher) learning rate
backbone_params = [p for name, p in model.backbone.named_parameters()
                   if not name.startswith('fc')]
optimizer = torch.optim.AdamW([
    {'params': backbone_params, 'lr': 1e-5},                # pretrained layers
    {'params': model.attention.parameters(), 'lr': 1e-4},   # new attention module
    {'params': model.backbone.fc.parameters(), 'lr': 3e-4}  # new classifier head
])
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=[1e-5, 1e-4, 3e-4],  # one peak LR per param group, same order as above
    steps_per_epoch=len(train_loader),
    epochs=20
)
```
Choosing evaluation metrics:
```python
import torch
from sklearn.metrics import classification_report

def evaluate(model, dataloader):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs.cuda())
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
    print(classification_report(all_labels, all_preds,
                                target_names=class_names))
    plot_confusion_matrix(all_labels, all_preds)  # custom visualization function
```
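`plot_confusion_matrix` is left as a custom function above; a minimal version might look like this (a sketch assuming matplotlib is available; it returns the matrix so the counts can also be inspected directly):

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')  # render off-screen, e.g. on a headless server
import matplotlib.pyplot as plt

def plot_confusion_matrix(labels, preds, num_classes=None, path='confusion.png'):
    labels, preds = np.asarray(labels), np.asarray(preds)
    n = num_classes or int(max(labels.max(), preds.max())) + 1
    cm = np.zeros((n, n), dtype=int)
    np.add.at(cm, (labels, preds), 1)  # rows: true class, cols: predicted class
    plt.imshow(cm, cmap='Blues')
    plt.xlabel('predicted')
    plt.ylabel('true')
    plt.colorbar()
    plt.savefig(path)
    plt.close()
    return cm
```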
Gradio gets a demo interface running far faster than Flask. The following code builds an interactive app that also explains each disease:
```python
import gradio as gr
import torch

# load_model / preprocess / class_names are helpers defined earlier
model = load_model('best_model.pth')

class_descriptions = {
    'Tomato_Early_blight': 'Spray a copper-based fungicide weekly',
    'Corn_Common_rust': 'Remove weed hosts from the field'
}

def predict(image):
    img = preprocess(image).unsqueeze(0).cuda()
    with torch.no_grad():
        output = model(img)
    prob = torch.softmax(output, dim=1)[0]
    top3_idx = torch.topk(prob, 3).indices
    result = []
    for i in top3_idx:
        cls_name = class_names[i]
        result.append((
            cls_name,
            f"{prob[i]:.1%}",
            class_descriptions.get(cls_name, 'No treatment advice available')
        ))
    return result

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.Dataframe(
        headers=['Disease', 'Confidence', 'Treatment advice'],
        datatype=['str', 'str', 'str']
    ),
    examples=['test_images/tomato1.jpg', 'test_images/corn1.jpg']
)
demo.launch(server_name="0.0.0.0", server_port=7860)
```
Once your baseline model passes 85% accuracy, you can try these improvement strategies: multi-modal fusion, an anomaly-detection mechanism for unseen diseases, and model compression.

Anomaly detection:
```python
import numpy as np

# Use the Mahalanobis distance to detect unknown diseases
def is_anomaly(features, mean, cov_inv, threshold=5.0):
    delta = features - mean
    distance = np.sqrt(delta.T @ cov_inv @ delta)
    return distance > threshold
```
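The `mean` and `cov_inv` statistics have to be estimated from features of known classes first. A numpy sketch (random vectors stand in for real penultimate-layer embeddings; the function is repeated so the snippet runs standalone):

```python
import numpy as np

def is_anomaly(features, mean, cov_inv, threshold=5.0):
    delta = features - mean
    return np.sqrt(delta.T @ cov_inv @ delta) > threshold

rng = np.random.default_rng(0)
train_feats = rng.normal(size=(500, 16))  # stand-in for real embeddings

mean = train_feats.mean(axis=0)
cov = np.cov(train_feats, rowvar=False)
cov_inv = np.linalg.pinv(cov)  # pseudo-inverse tolerates a singular covariance

print(is_anomaly(mean, mean, cov_inv))         # a central point is not anomalous
print(is_anomaly(mean + 50.0, mean, cov_inv))  # a far-away point is
```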
Model slimming: dynamic quantization is the one-liner option, though note that it only quantizes `nn.Linear` (and RNN) layers, so on a conv-heavy ResNet the savings come mostly from the classifier head:

```python
model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
```