在音频信号处理领域,我们经常遇到复数形式的频谱数据——无论是STFT变换结果还是梅尔频谱图,这些数据本质上都是复数。传统做法往往简单地将实部和虚部分离处理,或者只关注幅度信息而忽略相位,这种处理方式实际上破坏了复数数据内在的关联性。本文将带你深入理解复数神经网络的核心原理,并手把手实现一个完整的PyTorch复数网络架构。
当我们对音频信号进行短时傅里叶变换(STFT)时,得到的频谱数据天然具有复数形式。传统实数神经网络处理这类数据时,通常采用三种方法:
这些方法都存在明显缺陷:
| 处理方法 | 优点 | 缺点 |
|---|---|---|
| 分离实部虚部 | 实现简单 | 破坏复数乘法关系 |
| 幅度相位转换 | 保留部分信息 | 相位信息难以学习 |
| 实数近似 | 计算效率高 | 完全忽略复数特性 |
复数神经网络的核心优势在于它严格遵循复数运算规则:
python复制# 复数乘法示例
(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
这种运算保持了复数乘法的旋转特性,对于音频信号处理至关重要。研究表明,在音频分类、语音增强等任务中,复数网络相比实数网络能获得3-5%的性能提升。
复数卷积是复数网络的基础构建块。在PyTorch中,我们可以通过组合两个实数卷积层来实现复数卷积:
python复制class ComplexConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0):
super().__init__()
self.conv_r = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.conv_i = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
def forward(self, input_r, input_i):
return (self.conv_r(input_r) - self.conv_i(input_i),
self.conv_r(input_i) + self.conv_i(input_r))
注意:复数卷积的前向传播需要同时处理实部和虚部输入,并按照复数乘法规则组合结果
复数批归一化(Complex BatchNorm)比实数版本更复杂,需要同时归一化实部、虚部及其相关性:
python复制class ComplexBatchNorm2d(nn.Module):
def __init__(self, num_features, eps=1e-5):
super().__init__()
self.bn_r = nn.BatchNorm2d(num_features, eps=eps)
self.bn_i = nn.BatchNorm2d(num_features, eps=eps)
def forward(self, input_r, input_i):
# 分别归一化实部和虚部
output_r = self.bn_r(input_r)
output_i = self.bn_i(input_i)
# 调整协方差关系
# ... (具体实现略)
return output_r, output_i
常用的复数激活函数有三种实现方式:
python复制def complex_relu(input_r, input_i):
"""分离ReLU实现"""
return F.relu(input_r), F.relu(input_i)
现在我们将上述组件组合成一个完整的复数分类网络:
python复制class ComplexAudioNet(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.conv1 = ComplexConv2d(1, 32, kernel_size=5)
self.bn1 = ComplexBatchNorm2d(32)
self.conv2 = ComplexConv2d(32, 64, kernel_size=5)
self.bn2 = ComplexBatchNorm2d(64)
self.fc = ComplexLinear(64*4*4, num_classes)
def forward(self, x_r, x_i):
# 第一复数卷积层
x_r, x_i = self.conv1(x_r, x_i)
x_r, x_i = complex_relu(x_r, x_i)
x_r, x_i = self.bn1(x_r, x_i)
# 第二复数卷积层
x_r, x_i = self.conv2(x_r, x_i)
x_r, x_i = complex_relu(x_r, x_i)
x_r, x_i = self.bn2(x_r, x_i)
# 展平并分类
x_r = x_r.view(x_r.size(0), -1)
x_i = x_i.view(x_i.size(0), -1)
x_r, x_i = self.fc(x_r, x_i)
# 使用模作为最终输出
output = torch.sqrt(x_r**2 + x_i**2)
return output
处理音频数据时,正确的STFT参数设置至关重要:
python复制def compute_stft(waveform, sr=16000):
n_fft = 400
hop_length = 160
return torch.stft(waveform, n_fft=n_fft, hop_length=hop_length,
window=torch.hann_window(n_fft))
复数网络的训练需要特别注意以下几点:
python复制def train(model, train_loader, epochs=50):
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
for x_r, x_i, labels in train_loader:
optimizer.zero_grad()
outputs = model(x_r, x_i)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
评估复数网络性能时,除了准确率指标,还可以关注:
复数神经网络为音频信号处理提供了更自然的建模方式。在实际项目中,我们观察到复数网络在以下场景表现尤为突出: