在医学影像处理领域,CT图像重建是个经典问题。传统方法里,正投影(Forward Projection, FP)和反投影(Back Projection, FBP)是两个关键步骤。简单来说,正投影就是把3D物体"拍扁"成2D投影数据的过程,就像我们用X光拍片;反投影则是反过来,从多个角度的投影数据重建出3D图像。
但传统方法有个大问题:它们是不可微分的黑盒子。这意味着当我们想把它们嵌入到深度学习网络里时,梯度没法顺利传播。想象一下,你在搭积木,但中间有几块积木是胶水粘死的,没法调整——这就是传统FP/FBP在神经网络里的处境。
PyTorch的自动微分机制给了我们新思路。通过用PyTorch重新实现这些模块,我们能让梯度自由流动,让整个系统可以端到端训练。我在实际项目中发现,这种可微分模块特别适合用在像SIN(Sinogram Inpainting Network)这样的架构里,它能同时处理投影域和图像域的信息,大幅提升重建质量。
正投影的核心是旋转+累加。具体来说,就是把图像旋转到不同角度,然后沿一个方向求和。听起来简单,但用PyTorch实现时有几个坑要注意。
首先是旋转。PyTorch的grid_sample函数是我们的好朋友,它支持可微分的双线性插值。但要注意坐标系转换——CT图像通常以中心为原点,而PyTorch的网格坐标是归一化的[-1,1]。我踩过的坑是忘记调整坐标,结果图像旋转后位置偏移。
python复制def forward(self, x):
sino = torch.zeros(self.batchSize, 1, self.chanNum, self.viewNum).cuda()
for i in range(self.viewNum):
angle = -math.pi + 2*math.pi*i/self.viewNum # 均匀采样0-360度
theta = torch.tensor([
[math.cos(angle), -math.sin(angle), 0],
[math.sin(angle), math.cos(angle), 0]
]).unsqueeze(0).repeat(self.batchSize,1,1).cuda()
grid = F.affine_grid(theta, x.size(), align_corners=False)
x_rotate = F.grid_sample(x, grid, align_corners=False)
sino[:,:,:,i] = x_rotate.sum(dim=2) * self.pixel_size # 累加并考虑像素尺寸
直接循环每个角度计算虽然直观,但效率不高。我后来改用了向量化操作,速度提升了3倍多。关键是把所有旋转矩阵预先计算好,然后用torch.stack批量处理:
python复制angles = torch.linspace(0, 2*math.pi, self.viewNum, device=x.device)
rot_mats = torch.stack([
torch.stack([angles.cos(), -angles.sin(), torch.zeros_like(angles)], dim=1),
torch.stack([angles.sin(), angles.cos(), torch.zeros_like(angles)], dim=1)
], dim=1) # shape: [viewNum, 2, 3]
# 批量生成所有网格
grids = F.affine_grid(rot_mats.repeat(self.batchSize,1,1,1).view(-1,2,3),
torch.Size([self.batchSize*self.viewNum, 1, x.size(2), x.size(3)]))
反投影比正投影复杂得多,因为需要先对投影数据滤波。我试过时域和频域两种实现,最终选择了频域方案——不是因为它更快(实际上可能稍慢),而是因为更稳定,梯度更平滑。
Ramp滤波是核心,它用来补偿反投影过程中的高频信息损失。在频域实现时,要注意零频率分量要放在正确位置(PyTorch的fft默认不调整顺序),还要处理边缘效应。我的经验是前后各补100个零再做FFT,效果最好。
python复制def ramp_filter(batch_size, channels, n):
"""生成ramp滤波器"""
freq = torch.fft.fftfreq(n).abs() * 2 # 乘以2补偿积分效应
freq[0] = 0.25 # 直流分量特殊处理
return freq.view(1,1,-1,1).repeat(batch_size, channels, 1, 1).cuda()
def apply_filter(proj):
n = proj.size(2)
padded = F.pad(proj, (0,0,100,100)) # 前后补零
spec = torch.fft.fft(padded, dim=2)
filtered = spec * ramp_filter(proj.size(0), proj.size(1), n+200)
return torch.fft.ifft(filtered, dim=2).real[:,:,100:-100]
滤波后的反投影和正投影类似,但有个关键区别:正投影是图像旋转后投影,反投影是投影数据"反向"旋转后叠加。这里最容易出错的是旋转方向——一定要和正投影相反,否则重建的图像会是模糊的。
我常用的调试技巧是先用一个简单点源图像测试,如果重建位置不对,就检查旋转角度符号。另一个经验是重建时要加FOV(视场)限制,避免边缘伪影:
python复制# 生成圆形FOV掩码
y, x = torch.meshgrid(torch.linspace(-1,1,512), torch.linspace(-1,1,512))
mask = (x**2 + y**2 <= 1).float().view(1,1,512,512).cuda()
def backproject(sino):
recon = torch.zeros(self.batchSize, 1, 512, 512).cuda()
for i in range(self.viewNum):
angle = math.pi/2 - 2*math.pi*i/self.viewNum # 注意符号和偏移
theta = ... # 类似正投影
grid = F.affine_grid(theta, recon.size())
slice_expanded = sino[:,:,:,i].unsqueeze(-1).expand(-1,-1,-1,512)
recon += F.grid_sample(slice_expanded, grid) * mask
return recon * (math.pi / self.viewNum) # 角度间隔归一化
把这些模块嵌入网络时,我发现几个实用技巧:
nn.Module子类,这样能自动处理梯度detach()有时能稳定训练一个典型的SIN网络架构可能是:
code复制图像域CNN → FP → 投影域CNN → FBP → 图像域CNN
这种设计允许网络同时在两个域进行特征学习。实测下来,比纯图像域网络重建质量高约15%。
由于FP/FBP计算量较大,我有这些优化建议:
torch.cuda.amp)调试时,务必先验证模块单独的正确性:
x.grad is not None)torch.autograd.gradcheck验证数值梯度临床CT更多是用锥束或扇束几何。修改我们的模块支持这些情况主要需要:
我在扇束实现中发现,距离加权的反投影能显著减少边缘伪影。代码改动主要在旋转后的累加步骤:
python复制# 扇束反投影的累加
distance_weight = 1.0 / (source_to_detector / (source_to_center + y_coords * math.sin(angle)))
recon += F.grid_sample(slice_expanded, grid) * distance_weight
在医疗数据敏感的场景,我们可以在FP前加入可学习的噪声层,实现差分隐私。关键是要控制噪声的频谱特性,避免破坏有用信号。一个实用方案是:
python复制class PrivacyLayer(nn.Module):
def __init__(self):
super().__init__()
self.noise_scale = nn.Parameter(torch.zeros(1))
def forward(self, x):
if self.training:
# 生成符合频域特性的噪声
noise = torch.randn_like(x)
noise_fft = torch.fft.fft2(noise)
mask = torch.exp(-torch.fft.fftfreq(x.size(2))**2 * 10)
return x + torch.fft.ifft2(noise_fft * mask).real * self.noise_scale
return x
这种设计让网络能自动学习该加多少噪声,既保护隐私又不严重影响图像质量。