当我们在LIBERO框架下实现终身学习时,**弹性权重固化(EWC)**算法提供了一种优雅的解决方案来缓解灾难性遗忘问题。EWC的核心洞见是:并非所有神经网络参数对已学任务都同等重要。
关键概念:EWC通过计算参数的Fisher信息矩阵来识别重要参数,这些参数在后续任务训练时会受到"保护"。
EWC算法的数学本质可以概括为:
Fisher信息矩阵的物理意义是:参数对任务损失函数的二阶导数,衡量参数对任务的重要性程度。在LIBERO的机器人操作任务中,那些对抓取、放置等基础动作至关重要的网络参数通常会获得较高的Fisher值。
在LIBERO的算法目录(libero/lifelong/algos/)中,EWC的实现主要包含三个关键部分:
python复制class EWC(Sequential):
def __init__(self, n_tasks, policy, datasets, ewc_lambda=5000, **kwargs):
super().__init__(n_tasks, policy, datasets, **kwargs)
self.ewc_lambda = ewc_lambda # 正则化强度系数
self.fisher_dict = {} # 存储每个任务的Fisher矩阵
self.optpar_dict = {} # 存储每个任务的最优参数
主要参数说明:
ewc_lambda:控制EWC正则化项的强度fisher_dict:以{task_id: fisher_matrix}形式存储历史信息optpar_dict:记录每个任务训练后的最优参数python复制def compute_fisher(self, dataloader):
fisher = {name: torch.zeros_like(param)
for name, param in self.policy.named_parameters()}
self.policy.eval()
for batch in dataloader:
self.optimizer.zero_grad()
loss = self.compute_loss(self.policy(batch['obs']), batch['actions'])
loss.backward()
for name, param in self.policy.named_parameters():
if param.grad is not None:
fisher[name] += param.grad.pow(2) # 梯度平方近似Fisher信息
return {name: f/len(dataloader) for name, f in fisher.items()}
计算过程解析:
python复制def compute_ewc_loss(self):
if not self.fisher_dict: # 第一个任务无需EWC
return 0.0
ewc_loss = 0.0
for task_id in self.fisher_dict:
fisher = self.fisher_dict[task_id]
optpar = self.optpar_dict[task_id]
for name, param in self.policy.named_parameters():
param_diff = param - optpar[name]
ewc_loss += (fisher[name] * param_diff.pow(2)).sum()
return ewc_loss
损失函数组成:
fisher[name]:参数重要性权重param_diff:当前参数与历史最优参数的差值python复制from libero.lifelong.algos import Sequential
class CustomEWC(Sequential):
def __init__(self, n_tasks, policy, datasets, ewc_lambda=5000, **kwargs):
super().__init__(n_tasks, policy, datasets, **kwargs)
self.ewc_lambda = ewc_lambda
self.fisher = {}
self.optimal_params = {}
python复制def learn_task(self, task_id, epochs, batch_size):
# 常规训练过程
dataset = self.datasets[task_id]
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(epochs):
for batch in dataloader:
# 计算任务损失
pred_actions = self.policy(batch['obs'])
task_loss = self.compute_loss(pred_actions, batch['actions'])
# 添加EWC正则项
ewc_loss = self.compute_ewc_loss()
total_loss = task_loss + self.ewc_lambda * ewc_loss
# 反向传播
self.optimizer.zero_grad()
total_loss.backward()
self.optimizer.step()
# 保存当前任务信息
self.fisher[task_id] = self.compute_fisher(dataloader)
self.optimal_params[task_id] = {
name: param.clone().detach()
for name, param in self.policy.named_parameters()
}
| 参数 | 推荐值 | 说明 |
|---|---|---|
| ewc_lambda | 1000-10000 | 控制遗忘防止强度 |
| fisher_samples | 100-1000 | 计算Fisher时的样本数 |
| batch_size | 32-128 | 平衡内存和稳定性 |
通过LIBERO基准测试,EWC算法在不同类型任务上展现出显著优势:
空间推理任务(LIBERO-Spatial)
物体操作任务(LIBERO-Object)
典型训练曲线特征:
实战建议:对于包含10个以上任务的长期学习,建议将ewc_lambda设置为5000-8000范围,并在每个任务训练后保存检查点。
python复制def update_fisher_online(self, batch):
# 小批量更新Fisher估计
self.optimizer.zero_grad()
loss = self.compute_loss(self.policy(batch['obs']), batch['actions'])
loss.backward()
for name, param in self.policy.named_parameters():
if param.grad is not None:
self.fisher[name] = 0.99 * self.fisher[name] + 0.01 * param.grad.pow(2)
优势:
python复制def get_adaptive_lambda(self, task_id):
base_lambda = 5000
decay_factor = 0.9 # 旧任务权重衰减
return base_lambda * (decay_factor ** (self.n_tasks - task_id - 1))
EWC + 经验回放
python复制class EWC_ER(EWC):
def __init__(self, memory_size=1000, **kwargs):
super().__init__(**kwargs)
self.memory = []
self.memory_size = memory_size
def learn_task(self, task_id, **kwargs):
# ...EWC训练逻辑...
# 保存部分数据到记忆库
mem_samples = min(self.memory_size//(task_id+1), len(dataset))
self.memory.extend(random.sample(dataset, mem_samples))
结合优势:
常见问题1:Fisher计算不稳定
python复制torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
常见问题2:正则项主导训练
python复制current_lambda = self.ewc_lambda / (1 + 0.1*task_id)
常见问题3:内存消耗过大
python复制fisher[name] = torch.zeros_like(param).diag() # 只存储对角线元素
在LIBERO框架中实践EWC算法时,建议从简单任务开始逐步验证实现正确性。一个可靠的检查点是:第一个任务训练后,Fisher矩阵应呈现明显的稀疏模式,只有部分参数具有显著非零值。