第一次接触医疗影像分割项目时,我完全低估了环境配置的复杂性。记得当时为了完成CVPR2024医疗影像分割赛题的课程作业,我选择了MedSAM作为baseline模型。这个决定让我在环境配置环节就踩了不少坑,现在把这些经验分享给大家。
创建conda环境是第一步,但这里有个小技巧:不要直接用默认的python版本。我最初使用python=3.8创建环境,结果在后续步骤中遇到了各种依赖冲突。后来发现MedSAM对python 3.10的支持最稳定,命令如下:
bash复制conda create -n MEDSAM python=3.10 -y
conda activate MEDSAM
安装PyTorch时最容易遇到版本问题。官方文档通常建议用pip install torch,但这会安装最新版,可能与项目不兼容。我的经验是明确指定版本号:
bash复制pip3 install torch==2.1.2 torchvision==0.16.2
克隆项目仓库时,很多人会忽略分支选择。MedSAM有多个分支,我最初直接克隆主分支,结果发现代码结构完全不同。正确的做法是指定LiteMedSAM分支:
bash复制git clone -b LiteMedSAM https://github.com/bowang-lab/MedSAM/
安装项目依赖时,我遇到了经典的依赖冲突问题。错误提示中提到的rosdep、rospkg等包看起来与项目无关,实际上是因为我之前做过机器人项目,环境残留了ROS相关包。解决方法是用--no-deps参数跳过依赖解析:
bash复制pip3 install -e . --no-deps
提示:如果遇到磁盘空间不足的问题,可以先清理conda缓存:
conda clean --all
医疗影像数据与普通CV数据集有很大不同。我最初以为下载完数据就能直接用,结果发现需要经过复杂的预处理流程。
数据集下载后,文件结构应该是这样的:
code复制MedSAM/
└── data/
├── FLARE22Train/
│ ├── images/ # 原始CT影像
│ └── labels/ # 医生标注的mask
└── MedSAM_test/
└── CT_Abd/ # 测试数据
医疗影像通常是NIfTI格式(.nii.gz),需要特殊处理。我编写了预处理脚本pre_CT_MR.sh:
bash复制#!/bin/bash
python pre_CT_MR.py \
-img_path data/FLARE22Train/images \
-gt_path data/FLARE22Train/labels \
-output_path data \
-modality CT \
-anatomy Abd \
--save_nii
这个阶段最容易遇到的三个坑:
chmod -R 777 data/解决格式转换是另一个难点。MedSAM训练需要NPZ格式,但预处理生成的是NIfTI。我写了转换脚本:
python复制import numpy as np
import nibabel as nib
img = nib.load('data/FLARE22Train/images/case_0000.nii.gz').get_fdata()
mask = nib.load('data/FLARE22Train/labels/case_0000.nii.gz').get_fdata()
np.savez('data/MedSAM_train/case_0000.npz', imgs=img, gts=mask)
在消费级GPU上训练MedSAM需要特别注意显存管理。我的RTX 3060(6GB显存)最初连batch_size=1都跑不起来,经过调优后能稳定跑batch_size=4。
训练脚本train_one_gpu.sh的关键参数:
bash复制#!/bin/bash
python train_one_gpu.py \
-data_root data/npy \ # 转换后的NPY格式数据
-pretrained_checkpoint work_dir/LiteMedSAM/lite_medsam.pth \
-batch_size 4 \ # 根据显存调整
-num_workers 2 \ # 不要超过CPU核心数
-num_epochs 50 \ # 医疗影像通常需要更多epoch
--amp # 启用混合精度训练
我遇到的典型错误及解决方案:
错误1:CUDA out of memory
model.enable_gradient_checkpointing()--amp参数启用混合精度训练错误2:Dataloader卡死
num_workers(从4降到2)pin_memory=False错误3:Loss变为NaN
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)训练过程监控很重要。我修改了原始代码,添加了以下功能:
python复制from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/exp1')
for epoch in range(epochs):
for i, batch in enumerate(train_loader):
loss = model(batch)
writer.add_scalar('Loss/train', loss, epoch*len(train_loader)+i)
训练完成后,推理阶段又有新的挑战。我最初直接使用训练保存的checkpoint,结果遇到state_dict不匹配的问题。
正确的推理流程应该是:
权重提取脚本extract_weights.sh:
bash复制#!/bin/bash
python utils/extract_weights.py \
-from_pth work_dir/medsam_lite_latest.pth \
-to_pth work_dir/extracted_latest.pth
3D推理脚本inference_3D.sh的关键参数:
bash复制#!/bin/bash
python inference_3D.py \
-data_root data/MedSAM_test/CT_Abd \
-medsam_lite_checkpoint_path work_dir/extracted_latest.pth \
-pred_save_dir ./preds/3D \
--save_overlay \ # 生成可视化结果
--overwrite # 覆盖已有结果
医疗影像可视化需要专业工具。我推荐两种方案:
方案1:ITK-SNAP
bash复制sudo apt install itksnap
itksnap -g test_demo/imgs/case_0000.nii.gz -s preds/3D/case_0000.nii.gz
方案2:Python可视化
python复制import matplotlib.pyplot as plt
import nibabel as nib
img = nib.load('data/FLARE22Train/images/case_0000.nii.gz').get_fdata()
mask = nib.load('preds/3D/case_0000.nii.gz').get_fdata()
plt.figure(figsize=(12,6))
plt.subplot(121)
plt.imshow(img[:,:,100], cmap='gray') # 第100层切片
plt.subplot(122)
plt.imshow(mask[:,:,100], alpha=0.5) # 半透明叠加
plt.savefig('case_0000_slice100.png')
在部署到生产环境时,我建议使用Docker容器。这能解决大多数环境依赖问题:
dockerfile复制FROM pytorch/pytorch:2.1.2-cuda11.8-cudnn8-runtime
WORKDIR /app
COPY . .
RUN pip install -e . --no-deps
CMD ["python", "inference_3D.py", "-data_root", "/data", "-pred_save_dir", "/output"]
构建并运行Docker容器:
bash复制docker build -t medsam .
docker run -v $PWD/data:/data -v $PWD/output:/output medsam