在深度学习模型部署到移动端或嵌入式设备时,模型大小和计算效率往往是关键瓶颈。想象一下,你训练了一个效果不错的DenseNet-121模型,但当尝试把它部署到手机APP时,发现推理速度慢得无法接受。这时候就需要模型压缩技术,而结构化剪枝正是其中最实用的方法之一。
传统剪枝就像修剪树木的枝叶,可以随意剪掉任意树枝(参数)。但结构化剪枝更像是修剪盆栽——需要按照特定结构(如整条树枝)来修剪,这样才能保持树的基本形态(模型结构)。这种修剪方式特别适合需要保持输入输出张量形状的卷积神经网络。
我去年在部署一个图像分类模型时就踩过坑:先用非结构化剪枝减少了70%参数,结果推理速度只提升了10%。后来改用结构化剪枝,虽然只减少了50%参数,但速度直接翻倍。这就是结构化剪枝的魔力——它真正优化的是计算图的结构,而不仅仅是参数数量。
DepGraph(Dependency Graph)就像模型的"神经系统图"。当我们要剪掉一个卷积层的某些通道时,需要知道这个操作会影响哪些其他层。比如在ResNet中,剪枝第一个卷积层会连锁影响后续的BN层、残差连接等。
传统做法需要手动编写这些依赖规则,就像每次剪枝都要重新画一遍神经连接图。而Torch-Pruning的DepGraph能自动构建这个依赖关系图。举个例子:
python复制import torch_pruning as tp
from torchvision.models import resnet18
model = resnet18().eval()
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))
这三行代码就完成了整个ResNet-18的依赖分析。我实测下来,即使是DenseNet这种密集连接的网络,构建依赖图也只需要几秒钟。
DepGraph的工作原理可以用快递仓库来类比。假设仓库(模型)里有各种传送带(张量流动)连接不同工作站(网络层)。当我们要关闭某个工作站的部分通道时,需要调整所有相连的传送带宽度。
具体到代码层面,当我们想剪枝某个卷积层时:
python复制group = DG.get_pruning_group(
model.conv1,
tp.prune_conv_out_channels,
idxs=[2, 6, 9]
)
这个group就包含了所有需要同步调整的层。就像快递仓库管理员,我们只需要决定关闭哪些通道(idxs),DepGraph会自动处理所有传送带的调整。
DenseNet是出了名的难剪枝,因为它的密集连接会产生复杂的层间依赖。我们先准备好环境和模型:
python复制import torch
import torch_pruning as tp
from torchvision.models import densenet121
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = densenet121(pretrained=True).to(device)
example_inputs = torch.randn(1, 3, 224, 224).to(device)
这里有个小技巧:在构建依赖图前,最好把模型和输入数据放到同一设备上。我曾在CPU上构建依赖图,然后移到GPU微调,结果出现了奇怪的维度错误。
对于复杂模型,我推荐使用迭代式剪枝——分多次小幅度剪枝,每次剪枝后都进行微调。就像理发时多次修剪比一次性剃光更安全:
python复制# 重要性评估策略
imp = tp.importance.MagnitudeImportance(p=2)
# 忽略分类层
ignored_layers = []
for m in model.modules():
if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
ignored_layers.append(m)
# 初始化剪枝器
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs,
importance=imp,
iterative_steps=5, # 分5次剪枝
ch_sparsity=0.4, # 最终剪掉40%通道
ignored_layers=ignored_layers,
)
这里我选择L2范数作为重要性指标(p=2),它对大多数CV任务都表现稳定。如果是小数据集,可以尝试BNScaleImportance,它对数据分布变化更鲁棒。
真正的魔法发生在剪枝和微调的交替进行中:
python复制base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
pruner.step() # 执行剪枝
# 计算当前模型大小
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"Step {i+1}: MACs {macs/1e9:.2f}G, Params {nparams/1e6:.2f}M")
# 微调阶段(简化版)
train_one_epoch(model, train_loader, criterion, optimizer)
在我的实验中,这种渐进式剪枝比一次性剪枝能保持高3-5%的准确率。特别是在DenseNet上,因为其密集连接的特性,更需要谨慎处理。
第一次使用时,我最常遇到的报错是维度不匹配。比如在剪枝后出现类似这样的错误:
code复制RuntimeError: Given groups=1, weight of size [64,64,3,3],
expected input[1,128,56,56] to have 64 channels, but got 128 channels
这通常是因为:
解决方法是用DG.check_pruning_group(group)检查剪枝组是否合法,并打印出group查看所有受影响层。
经过多个项目实践,我总结了几个提升剪枝效果的经验:
这些技巧帮助我在最近的项目中,将DenseNet-121的推理速度提升了2.3倍,而准确率仅下降1.2%。
Torch-Pruning的强大之处在于可以灵活定制剪枝策略。比如实现一个考虑计算量的重要性指标:
python复制class FLOPsAwareImportance(tp.importance.Importance):
def __call__(self, group):
# 计算每个通道的FLOPs
flops = []
for layer in group:
if isinstance(layer, nn.Conv2d):
flops.append(layer.weight.abs().sum(dim=(1,2,3)))
# 返回标准化后的重要性
importance = torch.stack(flops).mean(0)
return importance / importance.sum()
这个策略会优先保留计算量大的通道,适合对延迟敏感的场景。我在部署到树莓派的项目中就采用了类似方法。
Torch-Pruning不仅适用于视觉模型。最近我在处理一个多模态模型时,发现它同样能很好地处理跨模态的连接:
python复制# 处理跨模态依赖的特殊配置
custom_pruning_config = {
'CrossModalAttention': {
'in_channel': 'q_linear',
'out_channel': ['k_linear', 'v_linear']
}
}
DG = tp.DependencyGraph(custom_pruning_config)
通过自定义配置,可以明确指定哪些层应该共享相同的剪枝模式。这比手动修改每个attention头要高效得多。