第一次接触异构图神经网络时,我和大多数人一样感到困惑。传统的图神经网络处理的是同构图,所有节点和边都是同一种类型,但现实世界的数据往往复杂得多。比如在学术网络中,论文、作者、机构、研究领域都是不同类型的节点,它们之间的关系也各不相同。
PyTorch Geometric(PyG)是目前最流行的图神经网络框架之一,它专门为异构图设计了HeteroData数据结构和一系列工具。我刚开始用PyG处理OGB-MAG数据集时踩过不少坑,比如不理解如何正确初始化不同类型的节点特征,或者搞混了边类型的三元组表示法。经过几个项目的实践,我总结出了一套适合新手的入门方法。
OGB-MAG是Open Graph Benchmark提供的一个学术异构图数据集,包含近200万节点和2100万条边。这个数据集完美展示了现实世界中数据的异构特性:论文节点有128维特征,作者节点也有128维特征,但它们的语义完全不同。处理这种数据时,传统的GNN会丢失类型信息,而异构GNN能保持这种差异。
在开始之前,确保你已经安装了最新版的PyG。我推荐使用conda环境,这样可以避免依赖冲突:
bash复制conda install pytorch-geometric -c pyg
加载OGB-MAG数据集非常简单,PyG已经内置了这个数据集的支持。第一次运行时会自动下载数据,这个过程可能需要一些时间(数据集约5GB):
python复制from torch_geometric.datasets import OGB_MAG
import torch_geometric.transforms as T
# 转换为无向图,这对很多GNN模型很重要
transform = T.ToUndirected()
dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=transform)
data = dataset[0]
数据集预处理时使用了metapath2vec方法,这是一种专门为异构图设计的嵌入技术。transform=ToUndirected()会为每种边类型添加反向边,这对消息传递很重要。比如原本只有"作者写论文"的边,现在会自动添加"论文被作者写"的反向边。
打印data对象可以看到数据集的完整结构:
python复制print(data)
输出展示了四种节点类型和四种边类型,每种都有自己的特征维度。特别要注意的是paper节点有y标签和train_mask等属性,这是我们的预测目标——论文的发表地点分类。
我刚开始时常犯的一个错误是混淆节点和边的访问方式。记住:
对于刚入门的新手,我推荐先用to_hetero()函数,它能把普通的GNN模型自动转换成异构版本。这种方式不需要理解底层细节,适合快速原型开发。
下面是一个完整的例子:
python复制import torch
from torch_geometric.nn import SAGEConv, to_hetero
class GNN(torch.nn.Module):
def __init__(self, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv((-1, -1), hidden_channels)
self.conv2 = SAGEConv((-1, -1), out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index)
return x
model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')
这里有几个关键点:
当你有更复杂的需求时,可以直接使用HeteroConv。这种方式更灵活,可以为不同类型的边设计不同的消息传递逻辑。
下面是我在一个实际项目中用过的结构:
python复制from torch_geometric.nn import HeteroConv, GATConv, SAGEConv, Linear
class HeteroGNN(torch.nn.Module):
def __init__(self, hidden_channels, out_channels):
super().__init__()
self.conv1 = HeteroConv({
('paper', 'cites', 'paper'): GATConv(-1, hidden_channels),
('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
('paper', 'rev_writes', 'author'): SAGEConv((-1, -1), hidden_channels)
}, aggr='sum')
self.conv2 = HeteroConv({
('paper', 'cites', 'paper'): GATConv(hidden_channels, out_channels),
('author', 'writes', 'paper'): SAGEConv((hidden_channels, hidden_channels), out_channels),
('paper', 'rev_writes', 'author'): SAGEConv((hidden_channels, hidden_channels), out_channels)
}, aggr='sum')
def forward(self, x_dict, edge_index_dict):
x_dict = self.conv1(x_dict, edge_index_dict)
x_dict = {key: x.relu() for key, x in x_dict.items()}
x_dict = self.conv2(x_dict, edge_index_dict)
return x_dict
这个模型对引用关系使用GATConv(带注意力机制),对作者-论文关系使用SAGEConv。在实际应用中,我发现这种混合结构通常比单一卷积效果更好。
准备好模型和数据后,训练过程与常规PyTorch模型类似,但要注意输入是字典形式的:
python复制import torch.nn.functional as F
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train():
model.train()
optimizer.zero_grad()
out = model(data.x_dict, data.edge_index_dict)
mask = data['paper'].train_mask
loss = F.cross_entropy(out['paper'][mask], data['paper'].y[mask])
loss.backward()
optimizer.step()
return float(loss)
这里的关键点是:
当图太大无法放入内存时,需要使用邻居采样。PyG的NeighborLoader对异构图有很好的支持:
python复制from torch_geometric.loader import NeighborLoader
train_loader = NeighborLoader(
data,
num_neighbors=[15] * 2,
batch_size=128,
input_nodes=('paper', data['paper'].train_mask)
)
def train():
model.train()
total_loss = 0
for batch in train_loader:
optimizer.zero_grad()
out = model(batch.x_dict, batch.edge_index_dict)
batch_size = batch['paper'].batch_size
loss = F.cross_entropy(out['paper'][:batch_size], batch['paper'].y[:batch_size])
loss.backward()
optimizer.step()
total_loss += float(loss)
return total_loss / len(train_loader)
小批量训练时要注意batch['paper'].batch_size这个属性,它表示原始batch的大小(128)。由于邻居采样会引入额外节点,输出结果的第一个维度可能大于batch_size,我们只需要取前batch_size个结果计算损失。
异构图中不同节点的特征尺度可能差异很大,我强烈建议添加特征归一化:
python复制from torch_geometric.transforms import NormalizeFeatures
transform = T.Compose([
T.ToUndirected(),
NormalizeFeatures()
])
dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=transform)
这个简单的步骤能让我的模型准确率提升3-5个百分点。特别是在OGB-MAG数据集中,paper和author的特征虽然维度相同,但分布差异很大。
OGB-MAG中的论文发表地点类别分布不均衡,可以通过加权损失函数来解决:
python复制from torch import unique
classes, counts = unique(data['paper'].y, return_counts=True)
class_weights = 1.0 / counts.float()
class_weights = class_weights / class_weights.sum()
def train():
# ...
loss = F.cross_entropy(out['paper'][mask], data['paper'].y[mask],
weight=class_weights.to(device))
# ...
在实践中,我发现异构GNN通常不需要太深,2-3层就足够了。过深的模型容易过拟合,可以配合Dropout使用:
python复制class GNN(torch.nn.Module):
def __init__(self, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv((-1, -1), hidden_channels)
self.dropout = torch.nn.Dropout(0.5)
self.conv2 = SAGEConv((-1, -1), out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = self.dropout(x)
x = self.conv2(x, edge_index)
return x
原始OGB-MAG数据没有利用节点类型信息,我们可以手动添加:
python复制# 为每种节点类型添加可学习的嵌入
class HeteroGNNWithType(torch.nn.Module):
def __init__(self, hidden_channels, out_channels):
super().__init__()
self.type_emb = torch.nn.Embedding(4, hidden_channels)
self.conv1 = SAGEConv((-1, -1), hidden_channels)
self.conv2 = SAGEConv((-1, -1), out_channels)
def forward(self, x_dict, edge_index_dict):
# 为每个节点添加类型嵌入
x_dict = {
'paper': x_dict['paper'] + self.type_emb(torch.tensor(0)),
'author': x_dict['author'] + self.type_emb(torch.tensor(1)),
# ...
}
x_dict = self.conv1(x_dict, edge_index_dict).relu()
x_dict = self.conv2(x_dict, edge_index_dict)
return x_dict
这个方法在我的实验中能提升模型对节点类型的敏感度,特别是对于边类型较少的关系。
如果数据随时间变化(如论文发表年份),可以将时间信息融入模型:
python复制# 假设data['paper'].year包含年份信息
year = data['paper'].year
year_norm = (year - year.min()) / (year.max() - year.min())
data['paper'].x = torch.cat([data['paper'].x, year_norm.unsqueeze(1)], dim=1)
理解异构GNN的决策过程很重要,可以使用Captum库进行解释:
python复制from captum.attr import IntegratedGradients
model.eval()
ig = IntegratedGradients(model)
# 只解释paper节点的预测
attr, delta = ig.attribute(
(data.x_dict, data.edge_index_dict),
target=data['paper'].y,
additional_forward_args=('paper',)
)
这能帮助我们分析模型更关注哪些节点和边类型,在实际业务场景中非常有用。