第一次用BERT做文本分类时,我对着满屏的报错信息差点崩溃。后来发现只要环境装对了,后面的事情就简单多了。咱们先从最基础的开始,把PyTorch和Transformers库配置好。
建议直接用Anaconda创建虚拟环境,避免包冲突。这是我验证过的稳定版本组合:
bash复制conda create -n bert_tuning python=3.8
conda activate bert_tuning
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
pip install transformers==4.28.1 datasets==2.11.0
数据集处理是NLP任务最耗时的环节。我用IMDb电影评论数据集举例,因为它结构清晰且适合练手。实际工作中你可能需要处理更混乱的数据,这时候要特别注意文本清洗:
python复制from datasets import load_dataset
# 加载数据集
raw_dataset = load_dataset('imdb')
print(raw_dataset['train'][0]) # 查看第一条数据
# 文本清洗函数示例
def clean_text(text):
text = text.replace("<br />", " ") # 去除HTML标签
text = re.sub(r'[^\w\s]', '', text.lower()) # 去标点+转小写
return text.strip()
# 应用清洗并划分数据集
dataset = raw_dataset.map(lambda x: {'text': clean_text(x['text'])})
split_dataset = dataset['train'].train_test_split(test_size=0.2)
遇到内存不足的情况时,可以改用生成器方式逐批加载数据。我常用这个技巧处理大型CSV文件:
python复制import pandas as pd
def batch_loader(file_path, batch_size=1000):
for chunk in pd.read_csv(file_path, chunksize=batch_size):
yield chunk['text'].apply(clean_text).tolist(), chunk['label'].tolist()
新手最容易犯的错误是直接使用原始BERT模型而不做任何调整。就像给你一辆F1赛车却不告诉你怎么换挡,再好的车也跑不起来。我们得给BERT加上分类头:
python复制from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained(
'bert-base-uncased',
num_labels=2, # 二分类
output_attentions=False, # 不需要注意力权重时可关闭节省内存
output_hidden_states=False
)
# 查看模型结构
print(model.classifier) # 默认的分类头是单层线性层
如果你需要更复杂的分类器,可以这样自定义:
python复制import torch.nn as nn
class CustomBertClassifier(nn.Module):
def __init__(self, bert_model):
super().__init__()
self.bert = bert_model
self.dropout = nn.Dropout(0.2)
self.fc1 = nn.Linear(768, 256)
self.fc2 = nn.Linear(256, 2)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids, attention_mask=attention_mask)
pooled_output = outputs.pooler_output
x = self.dropout(pooled_output)
x = nn.ReLU()(self.fc1(x))
return self.fc2(x)
model = CustomBertClassifier(model.bert) # 复用预训练权重
处理多标签分类时(比如一个文本可能同时属于多个类别),需要修改损失函数:
python复制model = BertForSequenceClassification.from_pretrained(
'bert-base-uncased',
num_labels=10,
problem_type="multi_label_classification"
)
直接套用默认训练参数效果往往不理想。经过多次实验,我总结出这些黄金参数组合:
python复制from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(
output_dir='./results',
num_train_epochs=3,
per_device_train_batch_size=16, # 根据GPU显存调整
per_device_eval_batch_size=64,
warmup_steps=500, # 避免初始学习率过大
weight_decay=0.01,
logging_dir='./logs',
logging_steps=50,
evaluation_strategy="steps",
eval_steps=500,
save_steps=1000,
load_best_model_at_end=True,
metric_for_best_model="f1", # 对于不平衡数据集比准确率更好
fp16=True # 30系以上GPU开启加速
)
自定义评估指标能让模型更符合业务需求:
python复制from sklearn.metrics import accuracy_score, f1_score
def compute_metrics(pred):
labels = pred.label_ids
preds = pred.predictions.argmax(-1)
return {
'accuracy': accuracy_score(labels, preds),
'f1': f1_score(labels, preds, average='weighted')
}
遇到显存不足时,可以尝试这些技巧:
训练完模型只是开始,我见过太多在测试集表现良好但线上崩掉的案例。完整的评估应该包括:
python复制# 混淆矩阵分析
from sklearn.metrics import confusion_matrix
import seaborn as sns
y_true = [...] # 真实标签
y_pred = [...] # 预测标签
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True) # 可视化
# 错误样本分析
error_samples = []
for i, (true, pred) in enumerate(zip(y_true, y_pred)):
if true != pred:
error_samples.append({
'text': test_texts[i],
'true': true,
'pred': pred
})
部署模型时,建议使用ONNX格式提升推理速度:
python复制from transformers import convert_graph_to_onnx
convert_graph_to_onnx.convert(
framework="pt",
model=model,
tokenizer=tokenizer,
output_path="model.onnx",
opset=12
)
对于实时性要求高的场景,可以尝试模型量化:
python复制from transformers import BertForSequenceClassification
quantized_model = BertForSequenceClassification.from_pretrained(
'./saved_model',
torch_dtype=torch.qint8,
low_cpu_mem_usage=True
)
最后提醒几个常见坑点: