在计算机视觉领域,文本识别一直是个极具挑战性的任务。传统方法需要先对文本图像进行单字切割,再逐个识别,这种方法不仅效率低下,而且对复杂排版和变形文本的适应性较差。本文将带你用PyTorch从零实现一个端到端的CRNN模型,直接处理不定长文本识别,彻底告别繁琐的单字切割流程。
CRNN(Convolutional Recurrent Neural Network)是当前最流行的不定长文本识别架构之一,它巧妙地将CNN的特征提取能力与RNN的序列建模能力结合起来。整个流程无需显式的字符分割,可以直接从整张图片预测文本序列。
一个标准的CRNN包含三个核心组件:
python复制class CRNN(nn.Module):
def __init__(self, imgH, nc, nclass, nh):
super(CRNN, self).__init__()
# CNN部分
self.cnn = nn.Sequential(
nn.Conv2d(nc, 64, 3, 1, 1), nn.ReLU(True),
nn.MaxPool2d(2, 2),
# 更多卷积层...
)
# RNN部分
self.rnn = nn.Sequential(
BidirectionalLSTM(512, nh, nh),
BidirectionalLSTM(nh, nh, nclass)
)
文本图像通常具有"高较小、宽较长"的特点,这要求我们在CNN设计中做出特殊调整:
提示:输入图像尺寸建议设置为32像素高,宽度按比例缩放。例如对于280×32的输入,CNN输出特征序列长度(时间步数)为40。
CNN输出的特征序列需要送入循环神经网络进行时序建模。这里我们使用双向LSTM来同时捕捉前后文信息。
python复制class BidirectionalLSTM(nn.Module):
def __init__(self, nIn, nHidden, nOut):
super(BidirectionalLSTM, self).__init__()
self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
self.embedding = nn.Linear(nHidden * 2, nOut)
def forward(self, input):
recurrent, _ = self.rnn(input)
T, b, h = recurrent.size()
t_rec = recurrent.view(T * b, h)
output = self.embedding(t_rec)
return output.view(T, b, -1)
关键参数说明:
| 参数 | 典型值 | 说明 |
|---|---|---|
| nIn | 512 | 输入特征维度 |
| nHidden | 256 | LSTM隐藏单元数 |
| nOut | 字符类别数+1 | 输出维度(含blank类别) |
每个时间步,LSTM接收一个512维的特征向量,输出所有字符的概率分布。对于40个时间步的序列,我们得到一个40×(字符类别数+1)的后验概率矩阵。
CTC(Connectionist Temporal Classification)是解决不定长序列对齐的关键,它允许模型在不需要精确字符位置标注的情况下进行训练。
python复制# PyTorch中的CTC损失计算
criterion = nn.CTCLoss()
loss = criterion(log_probs, targets, input_lengths, target_lengths)
训练阶段:
测试阶段:
python复制# 解码示例
decoded, _ = torch.nn.functional.ctc_beam_search_decoder(
log_probs,
seq_len,
beam_width=10
)
我们使用合成数据来训练CRNN模型,关键步骤包括:
python复制def generate_text_image(text, width=280, height=32):
# 创建空白图像
image = Image.new('L', (width, height), color=255)
draw = ImageDraw.Draw(image)
# 随机选择字体和大小
font_size = random.randint(24, 32)
font = ImageFont.truetype(random.choice(fonts), font_size)
# 绘制文本
draw.text((10, (height-font_size)//2), text, font=font, fill=0)
# 添加噪声
image = add_noise(np.array(image))
return image
由于文本长度不一,我们使用稀疏矩阵存储标签:
python复制def collate_fn(batch):
images = torch.stack([item[0] for item in batch])
# 标签转为稀疏表示
targets = [item[1] for item in batch]
target_lengths = torch.tensor([len(t) for t in targets])
targets = torch.cat(targets)
return images, targets, target_lengths
python复制optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
for epoch in range(100):
for images, targets, target_lengths in train_loader:
# 前向传播
log_probs = model(images)
input_lengths = torch.full(
(images.size(0),),
log_probs.size(0),
dtype=torch.long
)
# 计算CTC损失
loss = criterion(log_probs, targets, input_lengths, target_lengths)
# 反向传播
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
optimizer.step()
scheduler.step()
python复制class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
self.bn2 = nn.BatchNorm2d(out_channels)
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, stride),
nn.BatchNorm2d(out_channels)
)
else:
self.shortcut = nn.Identity()
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
return F.relu(out)
在真实场景中测试模型时,我发现对模糊和低对比度文本的识别仍有提升空间。一个实用的技巧是在推理时对图像进行适度的锐化和对比度增强,这能显著提升困难样本的识别率。