在机器人抓取领域,高质量的训练数据是模型性能的基石。传统手工标注方式不仅耗时费力,还难以保证标注一致性。本文将分享一套基于Cornell抓取数据集和自定义采集数据的自动化处理方案,通过标签转换脚本与数据增强技术,快速生成GGCNN网络可直接训练的.mat格式数据。
构建抓取检测数据集的第一步是获取原始图像资源。对于大多数研究者,可以从两个渠道获得数据:
环境配置方面,需要准备以下关键组件:
python复制# 基础环境安装示例
pip install pyrealsense2 opencv-python imageio scipy
conda install scikit-image=0.18.3 numpy mkl
提示:使用清华源可加速安装过程,添加
-i https://pypi.tuna.tsinghua.edu.cn/simple参数
GGCNN网络需要特定格式的.mat文件作为输入,包含三个关键通道:
转换脚本main_label.py的工作流程如下:
关键数学转换公式:
code复制抓取角度 = arctan2(dy, dx) # 计算抓取方向
cosθ = cos(angle) # 角度余弦分量
sinθ = sin(angle) # 角度正弦分量
宽度归一化 = 原始宽度/最大宽度 # 缩放到0-1范围
为提高模型泛化能力,需要在数据预处理阶段应用多种增强技术:
| 增强类型 | 参数范围 | 标签同步调整 |
|---|---|---|
| 随机缩放 | 0.9-1.1倍 | 等比例调整抓取点坐标 |
| 旋转 | ±30度 | 对应旋转角度矩阵 |
| 裁剪 | 300×300像素 | 调整抓取点偏移量 |
| 翻转 | 50%概率 | 镜像对称处理角度 |
实现示例代码:
python复制def augment_data(image, label):
# 随机缩放
scale = np.random.uniform(0.9, 1.1)
image = cv2.resize(image, None, fx=scale, fy=scale)
label['points'] *= scale
# 随机旋转
angle = np.random.uniform(-30, 30)
M = cv2.getRotationMatrix2D((image.shape[1]/2, image.shape[0]/2), angle, 1)
image = cv2.warpAffine(image, M, (image.shape[1], image.shape[0]))
# 随机裁剪
crop_size = 300
x = np.random.randint(0, image.shape[1] - crop_size)
y = np.random.randint(0, image.shape[0] - crop_size)
image = image[y:y+crop_size, x:x+crop_size]
label['points'] -= [x, y]
return image, label
将预处理流程集成到DataLoader中,实现训练时的实时数据增强:
python复制class GraspDataset(Dataset):
def __init__(self, data_dir, augment=True):
self.image_files = glob.glob(f"{data_dir}/*d.tiff")
self.label_files = [f.replace('d.tiff', 'r.mat') for f in self.image_files]
self.augment = augment
def __getitem__(self, idx):
# 读取原始数据
depth = load_tiff(self.image_files[idx])
label = loadmat(self.label_files[idx])
# 数据增强
if self.augment:
depth, label = random_augment(depth, label)
# 转换为张量
depth = torch.from_numpy(depth).float()
grasp_map = torch.from_numpy(label['grasp_map']).float()
cos_map = torch.from_numpy(label['cos_map']).float()
sin_map = torch.from_numpy(label['sin_map']).float()
width_map = torch.from_numpy(label['width_map']).float()
return depth, (grasp_map, cos_map, sin_map, width_map)
关键处理技巧:
torch.from_numpy实现零拷贝数据转换num_workers=4)在真实项目中应用该流程时,有几个值得注意的经验点:
标注效率优化:
cv2.namedWindow创建交互式标注界面数据质量检查:
存储优化:
python复制# HDF5存储示例
import h5py
with h5py.File('grasp_data.h5', 'w') as f:
# 创建可扩展数据集
f.create_dataset('depth', shape=(0,480,640), maxshape=(None,480,640),
chunks=(1,480,640), compression='gzip')
f.create_dataset('grasp_map', shape=(0,480,640), maxshape=(None,480,640),
chunks=(1,480,640), compression='gzip')
# 追加新数据
f['depth'].resize((f['depth'].shape[0]+1), axis=0)
f['depth'][-1] = new_depth
这套流程在实际项目中显著提升了数据准备效率,将原本需要数周的手工标注工作压缩到2-3天内完成,同时通过自动化增强使训练数据量扩大5-8倍。对于需要处理自定义物体的场景,建议先在小规模数据(50-100张)上验证标注质量,再扩展到完整数据集。