当现成的AI绘画工具遍地开花时,真正酷的开发者早已将目光投向底层架构。本文将带你用PyTorch和CLIP搭建一个能理解文字描述的图像生成系统——这不是另一个Stable Diffusion教程,而是一次对条件扩散模型的深度解构。
在开始写代码之前,我们需要理解这个架构的独特优势。传统扩散模型像是个盲画家,只能随机涂抹颜料;而加入CLIP文本条件后,它变成了能听懂需求的数字艺术家。
CLIP的双塔结构是其关键所在:
当这个预训练好的模型遇到扩散模型时,魔法就发生了:
python复制# CLIP文本嵌入示例
text_embedding = clip_model.encode_text("a watercolor of sunset") # 输出512维向量
这个向量会成为扩散模型每一步去噪的"指南针"。与直接使用文本标签相比,CLIP嵌入携带更丰富的语义信息——它能区分"卡通龙"和"写实恐龙"的微妙差别。
标准的UNet就像个没有记忆的迷宫,我们需要给它装上理解文本的"大脑"。关键在于两个改造:
传统UNet的卷积块需要升级为条件卷积块:
python复制class ConditionalConv(nn.Module):
def __init__(self, in_ch, out_ch, cond_dim):
super().__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
self.cond_proj = nn.Linear(cond_dim, out_ch * 2) # 同时预测scale和shift
def forward(self, x, cond):
# 投影条件向量
gamma, beta = self.cond_proj(cond).chunk(2, dim=1)
# 应用条件调制
h = self.conv(x)
h = h * (1 + gamma[..., None, None]) + beta[..., None, None]
return h
这种设计比简单的向量相加更有效,类似Transformer中的FiLM机制。
扩散模型还需要处理时间步信息:
| 融合方式 | 优点 | 缺点 |
|---|---|---|
| 简单拼接 | 实现简单 | 信息交互不充分 |
| 自适应归一化 | 条件控制精细 | 计算量稍大 |
| 注意力机制 | 长程依赖捕捉 | 内存消耗高 |
我们推荐使用自适应归一化(AdaGN):
python复制class AdaGN(nn.Module):
def __init__(self, channels, cond_dim):
super().__init__()
self.norm = nn.GroupNorm(8, channels)
self.cond_proj = nn.Linear(cond_dim, channels * 2)
def forward(self, x, cond):
gamma, beta = self.cond_proj(cond).chunk(2, dim=1)
x = self.norm(x)
return x * (1 + gamma[..., None, None]) + beta[..., None, None]
在CIFAR-10这类小数据集上训练条件扩散模型,需要些特殊技巧:
嵌入降维:CLIP的512维向量对小模型可能过大
python复制# 添加一个适配层
self.text_proj = nn.Sequential(
nn.Linear(512, 256),
nn.SiLU(),
nn.Linear(256, 128)
)
数据增强组合:
学习率调度:
python复制scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=2e-4,
total_steps=len(dataloader)*epochs
)
注意:在小数据集上,建议使用预训练的CLIP权重并冻结文本编码器,只微调投影层。
基础的DDPM采样往往产生模糊结果,这些技巧可以提升质量:
温度调节:在预测噪声时加入温度系数
python复制pred_noise = model(x, t, cond) * temperature # 0.7~1.3之间调节
混合初始化:用CLIP图像编码器初始化噪声
python复制with torch.no_grad():
ref_img = preprocess(Image.open("reference.jpg")).unsqueeze(0).to(device)
img_emb = clip_model.encode_image(ref_img)
x_T = img_emb @ clip_model.visual.proj.t() # 投影到噪声空间
提示工程:文本提示的微妙调整
模式崩溃:所有输出都相似
生成质量差:图像扭曲不清晰
python复制# 检查数值稳定性
assert not torch.isnan(x).any(), "出现NaN值!"
assert x.min() >= -1 and x.max() <= 1, "数值超出范围!"
显存不足的解决方案:
python复制from torch.utils.checkpoint import checkpoint
x = checkpoint(self.block, x, cond) # 分段计算
python复制scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
loss = model(x, t, cond)
scaler.scale(loss).backward()
scaler.step(optimizer)
当基本模型跑通后,可以尝试这些升级:
动态条件缩放:
python复制# 在训练时随机丢弃条件
if torch.rand(1) < 0.1:
cond = torch.zeros_like(cond) # 10%概率无条件
多模态条件融合:
python复制# 同时使用CLIP和传统标签
cond = torch.cat([clip_embed, label_embed], dim=1)
分层条件控制:
python复制# 在不同深度注入条件
if i in [0, 3, 6]: # 在特定层注入
x = block(x, cond)
else:
x = block(x)
我在实际项目中发现,用CIFAR-10训练的小模型虽然生成32x32图像,但适当加入以上技巧后,能产生令人惊讶的语义一致性。比如输入"red frog with black spots",模型确实能生成符合描述的斑点图案,尽管分辨率有限。