当你第一次打开PA100K数据集时,面对近10万张图片和密密麻麻的标签文件,可能会感到无从下手。这个数据集包含了26种行人属性标注,从基础的性别年龄到复杂的服饰搭配,每张图片都对应一个多维特征向量。传统单标签分类方法在这里完全失效——因为同一张图片可能同时属于"背包"和"长裤"等多个类别。
理解标签文件的结构是第一步。val_list.txt中的每一行都遵循图片路径\t标签向量的格式,其中标签向量是由26个0/1值组成的逗号分隔字符串。例如:
code复制val/000001.jpg 1,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,0,1,0,0,1,0,1,0,0,0
这个样本表示:
关键脚本组件解析:
python复制def find_indices(lst, value):
"""返回所有值等于value的索引"""
return [i for i, x in enumerate(lst) if x == value]
这个辅助函数是整个分类逻辑的核心,它能够快速定位标签向量中所有值为'1'的位置索引,对应图片所具有的属性。
建议采用以下项目结构,便于团队协作和后续扩展:
code复制PA100K_Processor/
├── configs/
│ ├── train_list.txt
│ └── val_list.txt
├── src/
│ └── dataset_organizer.py
├── data/
│ ├── raw/ # 原始数据集
│ └── processed/ # 分类后数据
└── requirements.txt
安装必要依赖:
bash复制pip install tqdm opencv-python
原始脚本有几个可以优化的关键点:
python复制import os
from pathlib import Path # 更现代的路径处理
from tqdm import tqdm
import shutil
import cv2 # 用于图片验证
class PA100KProcessor:
def __init__(self, config):
self.labels = config['labels']
self.src_root = config['src_root']
self.target_root = config['target_root']
self.file_list = config['file_list']
def validate_image(self, img_path):
"""验证图片是否可正常读取"""
try:
img = cv2.imread(img_path)
return img is not None
except:
return False
def process_dataset(self):
# 创建所有属性目录
Path(self.target_root).mkdir(parents=True, exist_ok=True)
for idx, attr in enumerate(self.labels):
(Path(self.target_root) / f"{idx}_{attr}").mkdir(exist_ok=True)
# 处理每条记录
with open(self.file_list) as f:
lines = f.readlines()
for line in tqdm(lines, desc="Processing images"):
parts = line.strip().split('\t')
if len(parts) != 2:
continue
img_rel_path, label_str = parts
img_path = Path(self.src_root) / img_rel_path
if not self.validate_image(str(img_path)):
continue
labels = label_str.split(',')
active_indices = [i for i, x in enumerate(labels) if x == '1']
for idx in active_indices:
target_dir = Path(self.target_root) / f"{idx}_{self.labels[idx]}"
shutil.copy2(img_path, target_dir / img_path.name)
改进亮点:
创建config.yaml:
yaml复制train_config:
labels: ["Hat", "Glasses", ..., "Back"] # 完整26个标签
src_root: "./data/raw"
target_root: "./data/processed/train"
file_list: "./configs/train_list.txt"
val_config:
labels: ["Hat", "Glasses", ..., "Back"]
src_root: "./data/raw"
target_root: "./data/processed/val"
file_list: "./configs/val_list.txt"
主执行脚本:
python复制import yaml
def main():
with open("config.yaml") as f:
configs = yaml.safe_load(f)
# 处理训练集
train_processor = PA100KProcessor(configs['train_config'])
train_processor.process_dataset()
# 处理验证集
val_processor = PA100KProcessor(configs['val_config'])
val_processor.process_dataset()
if __name__ == "__main__":
main()
分类完成后,我们可以进一步分析数据分布:
python复制import pandas as pd
import matplotlib.pyplot as plt
def analyze_distribution(target_root):
attr_stats = []
for attr_dir in Path(target_root).iterdir():
if not attr_dir.is_dir():
continue
attr_name = attr_dir.name.split('_', 1)[1]
count = len(list(attr_dir.glob('*.jpg')))
attr_stats.append({'attribute': attr_name, 'count': count})
df = pd.DataFrame(attr_stats)
df = df.sort_values('count', ascending=False)
plt.figure(figsize=(12, 6))
plt.barh(df['attribute'], df['count'], color='skyblue')
plt.xlabel('Image Count')
plt.title('Attribute Distribution')
plt.tight_layout()
plt.savefig('./attribute_distribution.png')
return df
常见分析维度:
这套处理流程可以轻松适配到其他多标签数据集,只需修改配置即可。以RAPv2数据集为例:
yaml复制rapv2_config:
labels: ["Male", "Age16-30", ..., "NoAccessory"] # RAP的72个属性
src_root: "./data/rapv2/images"
target_root: "./data/rapv2/processed"
file_list: "./configs/rapv2_list.txt"
通用化建议:
python复制def prepare_dataset(config):
processor = DatasetProcessor(
labels=config['labels'],
delimiter=config.get('delimiter', ','),
img_ext=config.get('img_ext', '.jpg'),
# 其他参数...
)
processor.run()
在实际项目中,这种预处理流程通常只是整个机器学习管道的第一步。接下来你可能需要:
当处理大规模数据集时,原始脚本可能会遇到性能瓶颈。以下是几个优化方向:
1. 并行处理加速
python复制from concurrent.futures import ThreadPoolExecutor
def process_image(args):
img_path, active_indices, labels, target_root = args
for idx in active_indices:
target_dir = Path(target_root) / f"{idx}_{labels[idx]}"
shutil.copy2(img_path, target_dir / img_path.name)
with ThreadPoolExecutor(max_workers=8) as executor:
args_list = [(img_path, active_indices, labels, target_root)
for ...] # 构建参数列表
list(tqdm(executor.map(process_image, args_list), total=len(args_list)))
2. 常见错误处理清单
| 错误类型 | 检测方法 | 解决方案 |
|---|---|---|
| 路径不存在 | Path.exists() |
创建父目录或跳过 |
| 图片损坏 | cv2.imread()验证 |
记录错误并跳过 |
| 标签格式错误 | 分割后长度检查 | 记录错误行号 |
| 磁盘空间不足 | shutil.disk_usage() |
提前检查并报警 |
| 权限问题 | try-catch块 | 提示用户修改权限 |
3. 内存优化技巧
shutil.copyfileobj()处理超大文件在实际部署这类预处理脚本时,有几点值得特别注意:
路径规范化:总是将路径转换为绝对路径后再操作,避免相对路径导致的意外行为。可以使用pathlib.Path.resolve()方法。
幂等性设计:确保脚本可以安全地多次运行,不会因为部分成功而导致重复拷贝或目录混乱。可以通过记录已处理文件的状态来实现。
日志系统:完善的日志记录比print语句更专业:
python复制import logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('preprocess.log'),
logging.StreamHandler()
]
)
python复制import unittest
class TestFindIndices(unittest.TestCase):
def test_empty_list(self):
self.assertEqual(find_indices([], '1'), [])
def test_multiple_matches(self):
self.assertEqual(find_indices(['0','1','0','1'], '1'), [1, 3])
配置验证:使用JSON Schema或Pydantic验证配置文件的完整性,避免运行时因配置错误导致的异常。
进度持久化:对于超大规模数据集,实现断点续处理功能:
python复制class ProgressTracker:
def __init__(self, state_file='.progress'):
self.state_file = state_file
self.processed = set()
if Path(state_file).exists():
with open(state_file) as f:
self.processed = set(line.strip() for line in f)
def add_processed(self, img_path):
self.processed.add(str(img_path))
def save(self):
with open(self.state_file, 'w') as f:
f.writelines(f"{p}\n" for p in self.processed)
def is_processed(self, img_path):
return str(img_path) in self.processed