The first time I saw Swin Transformer's results on ImageNet, I was as stunned as every other computer vision practitioner: this Transformer-based model not only beat CNNs on accuracy, it overtook them on computational efficiency as well. Only after using it in real projects did I realize its advantages go well beyond the numbers in the paper.
Swin Transformer's cleverest designs are its hierarchical feature maps and the shifted-window mechanism. A standard Vision Transformer (ViT) splits the whole image into patches and computes global attention over all of them, so the cost grows quadratically with image size. Swin Transformer instead builds hierarchical features the way a CNN does and computes attention only within local windows; shifting the window partition between consecutive layers lets information flow across window boundaries. This preserves the model's ability to aggregate global context while bringing the attention cost down to linear in the number of patches.
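The quadratic-vs-linear difference is easy to verify with a back-of-the-envelope count (a sketch of my own; it counts only attention-matrix entries and ignores the projection layers):

```python
def global_attn_cost(h, w, dim):
    """Global self-attention: every patch attends to every patch -> O((h*w)^2 * dim)."""
    n = h * w
    return n * n * dim

def window_attn_cost(h, w, dim, window=7):
    """Swin-style window attention: each patch attends only within its
    own window x window region -> O(h*w * window^2 * dim)."""
    n = h * w
    return n * window * window * dim

# A 224px image with patch size 4 gives a 56x56 grid; Swin-Tiny uses dim 96.
ratio = global_attn_cost(56, 56, 96) // window_attn_cost(56, 56, 96)
print(ratio)  # 64: global attention does 64x the work at this resolution
```

Doubling the resolution quadruples the window-attention cost but multiplies the global-attention cost by 16, which is exactly why ViT struggles at 512x512.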
I ran a comparison experiment: training ResNet50 and Swin-Tiny on the same cats-vs-dogs dataset, Swin's validation accuracy came out 3.2% higher while inference was actually 15% faster. On 512x512 medical images in particular, Swin used only a quarter of ViT's memory. These properties make it well suited to real industrial deployment.
Last year, while setting up a Swin Transformer training platform for my company, it took me three full days to get the environment right. A few hard-won lessons:
First, the CUDA version. The official repo recommends PyTorch 1.7 with CUDA 11.0, but in my tests CUDA 11.2 was more stable. Before installing, remove all existing NVIDIA drivers:
```bash
sudo apt-get purge 'nvidia*'
sudo apt-get install cuda-11-2
```
Next, the apex library. NVIDIA's mixed-precision toolkit can cut memory usage by roughly 30%, but compilation fails frequently. The most reliable installation route:
```bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
I recommend creating an isolated conda environment; this is the dependency list I have verified:
```bash
conda create -n swin python=3.7
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=11.0 -c pytorch
pip install timm==0.3.2 opencv-python yacs termcolor
```
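Once the environment is in place, a quick sanity check confirms that PyTorch actually sees the GPU (versions refer to the pins above):

```python
import torch

print(torch.__version__)           # expect 1.7.1 with the install above
print(torch.version.cuda)          # the CUDA version this PyTorch build targets
print(torch.cuda.is_available())   # must print True, or the driver/CUDA setup is broken
```

If the last line prints False, fix the driver before touching the training code; everything downstream assumes a working GPU.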
A cats-vs-dogs demo is easy; real business data is where I hit the pitfalls. In one delivery from the labeling team, 15% of the images were actually .webp files carrying a .jpg extension, which crashed training outright.
A proper validation pass decodes every file, re-encodes the good ones as real JPEGs, and deletes the rest:
```python
import os
from pathlib import Path

import cv2

for img_path in Path('dataset').rglob('*.*'):
    try:
        img = cv2.imread(str(img_path))   # returns None for undecodable files
        assert img is not None
        # Re-encode so the extension matches the actual content
        cv2.imwrite(str(img_path.with_suffix('.jpg')), img)
    except Exception:
        print(f"Bad image: {img_path}")
        os.remove(img_path)
```
Then split with stratification so train and validation keep the same class balance:

```python
from sklearn.model_selection import StratifiedShuffleSplit

# files: list of validated image paths, y: their class labels
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, val_idx in sss.split(files, y):
    train_files = [files[i] for i in train_idx]
    val_files = [files[i] for i in val_idx]
```
Finally, training-time augmentation with albumentations:

```python
from albumentations import (
    HorizontalFlip, ShiftScaleRotate, RandomBrightnessContrast,
    Compose, CoarseDropout
)

train_aug = Compose([
    HorizontalFlip(p=0.5),
    ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10),
    RandomBrightnessContrast(p=0.3),
    CoarseDropout(max_holes=8, max_height=32, max_width=32, fill_value=0),
])
```
The configuration files are the most frustrating part of Swin Transformer. After 20+ experiments, these are the key parameters I ended up adjusting:
In configs/swin_tiny_patch4_window7_224.yaml:
```yaml
DATA:
  BATCH_SIZE: 64        # adjust to fit GPU memory
  IMG_SIZE: 224
TRAIN:
  BASE_LR: 0.001        # drop to 0.0005 for small datasets
  WARMUP_EPOCHS: 5      # prevents oscillation early in training
  WEIGHT_DECAY: 0.05    # Transformers need stronger regularization
MODEL:
  DROP_PATH_RATE: 0.2   # important! guards against overfitting on small data
```
When launching training, I recommend gradient accumulation to simulate a larger batch:
```bash
python -m torch.distributed.launch --nproc_per_node=4 main.py \
    --cfg configs/swin_tiny_patch4_window7_224.yaml \
    --batch-size 32 \
    --accumulation-steps 2 \
    --amp-opt-level O1
```
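For reference, what --accumulation-steps 2 does under the hood can be sketched in a plain PyTorch loop (a simplified illustration; the model, data, and hyperparameters here are stand-ins, not the repo's code):

```python
import torch
import torch.nn as nn

accumulation_steps = 2            # effective batch = batch_size * accumulation_steps
model = nn.Linear(16, 2)          # stand-in for the Swin model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

batches = [(torch.randn(8, 16), torch.randint(0, 2, (8,))) for _ in range(4)]

optimizer.zero_grad()
for step, (x, y) in enumerate(batches):
    loss = criterion(model(x), y) / accumulation_steps  # scale so summed grads average
    loss.backward()                    # gradients accumulate in .grad across steps
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()               # one optimizer update per 2 mini-batches
        optimizer.zero_grad()
```

The division by accumulation_steps matters: without it, the accumulated gradient is the sum, not the mean, and the effective learning rate silently doubles.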
If the loss oscillates, try lowering BASE_LR (0.0005 worked for me on small datasets) or lengthening the warmup.
The official code ships no ready-made inference script, so I built a production-grade pipeline during deployment. The key pieces are TensorRT conversion, a FastAPI serving layer, and Prometheus monitoring.
Export to ONNX first, then parse it into a TensorRT engine (note the ONNX parser requires an explicit-batch network):

```python
import tensorrt as trt
import torch

# Export the trained model to ONNX with a dynamic batch dimension
torch.onnx.export(model, dummy_input, "swin.onnx",
                  input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})

# Parse the ONNX file and build a TensorRT engine
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open("swin.onnx", "rb") as f:
    parser.parse(f.read())

config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GB workspace
engine = builder.build_serialized_network(network, config)
```
Serving goes through a small FastAPI app:

```python
import numpy as np
from fastapi import FastAPI, File

app = FastAPI()

@app.post("/predict")
async def predict(image: bytes = File(...)):
    img = preprocess_image(image)   # decode + normalize the uploaded bytes
    output = model(img)
    return {"class_id": int(np.argmax(output))}
```
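preprocess_image in the endpoint above is my own helper, not something the official repo provides; a minimal version, assuming a 224x224 RGB input and standard ImageNet normalization, might look like this:

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess_array(img: np.ndarray) -> np.ndarray:
    """Turn an HWC uint8 RGB image (already decoded and resized to 224x224)
    into a 1x3x224x224 float32 array ready for the model."""
    x = img.astype(np.float32) / 255.0
    x = (x - IMAGENET_MEAN) / IMAGENET_STD
    x = x.transpose(2, 0, 1)        # HWC -> CHW
    return x[np.newaxis, ...]       # add the batch dimension

# Decoding the uploaded bytes would come first, e.g. with OpenCV:
#   arr = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
#   arr = cv2.cvtColor(cv2.resize(arr, (224, 224)), cv2.COLOR_BGR2RGB)
```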
Monitoring hooks in with prometheus_client:

```python
from prometheus_client import Counter, Gauge

REQUESTS = Counter('predict_total', 'Total prediction requests')
LATENCY = Gauge('predict_latency_seconds', 'Prediction latency in seconds')

@LATENCY.time()
def predict():
    REQUESTS.inc()
    # inference logic goes here
```
When accuracy plateaus, these methods are worth trying: knowledge distillation from a larger Swin, adversarial training, and pruning.
Distillation from Swin-Large into Swin-Tiny:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

teacher = swin_large(pretrained=True)   # frozen teacher
student = swin_tiny()
teacher.eval()
loss_fn = nn.KLDivLoss(reduction='batchmean')

for data in loader:
    with torch.no_grad():
        t_logits = teacher(data)
    s_logits = student(data)
    # KLDivLoss expects log-probabilities from the student
    loss = loss_fn(F.log_softmax(s_logits, dim=1), F.softmax(t_logits, dim=1))
```
Adversarial training with torchattacks:

```python
from torchattacks import FGSM

atk = FGSM(model, eps=0.03)
for images, labels in loader:
    images = atk(images, labels)    # generate adversarial examples on the fly
    outputs = model(images)
    loss = criterion(outputs, labels)
```
And global magnitude pruning over the Linear layers:

```python
import torch.nn as nn
from torch.nn.utils import prune

# Remove the 50% smallest-magnitude weights across all Linear layers at once
parameters_to_prune = [
    (module, 'weight') for module in filter(
        lambda m: isinstance(m, nn.Linear), model.modules())
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.5,
)
```