手把手教你处理TT100K数据集:从COCO格式转换到YOLO格式的完整流程(附Python脚本)
手把手教你处理TT100K数据集:从COCO格式转换到YOLO格式的完整流程(附Python脚本)
当你第一次打开TT100K数据集的压缩包,可能会被里面错综复杂的文件夹结构和数千张图片淹没。作为交通标志检测领域的经典数据集,TT100K虽然数据丰富,但原始格式和类别分布直接用于YOLO训练往往会遇到各种"坑"。本文将带你一步步完成从数据清洗到格式转换的全流程,并提供可直接运行的Python脚本。
1. 理解TT100K数据集的原始结构
TT100K数据集(2021版)默认包含三个主要文件夹:train、test和other。这种划分方式存在几个实际问题:
- 样本分布不均:other文件夹包含的图片数量(7641)甚至超过了train(6105)和test(3071)的总和
- 类别不平衡:部分交通标志类别可能只有个位数样本
- 格式不兼容:原始标注采用JSON格式,与YOLO需要的txt格式差异较大
先来看看原始数据的统计情况:
import os from collections import defaultdict # 统计各类别数量 def count_categories(annotation_path): category_count = defaultdict(int) with open(annotation_path) as f: data = json.load(f) for ann in data['annotations']: category_count[ann['category_id']] += 1 return category_count # 示例输出可能显示某些类别只有几个样本 { 1: 1200, # 限速标志 2: 850, # 禁止停车 ... 78: 3, # 罕见标志 79: 1 # 极罕见标志 }2. 数据清洗与类别过滤
面对类别不平衡问题,我们需要先进行数据清洗。保留样本数超过100的类别是个不错的起点,但实际操作中还需要考虑:
- 类别重要性:某些关键交通标志即使样本少也应保留
- 数据增强潜力:容易通过旋转、变色等增强的类别可以适当放宽标准
以下是过滤低频类别的Python实现:
def filter_categories(annotation_path, min_samples=100): with open(annotation_path) as f: data = json.load(f) # 统计类别 cat_count = defaultdict(int) for ann in data['annotations']: cat_count[ann['category_id']] += 1 # 确定保留的类别ID keep_cats = {k for k,v in cat_count.items() if v >= min_samples} # 过滤标注 new_annotations = [ ann for ann in data['annotations'] if ann['category_id'] in keep_cats ] # 更新数据 data['annotations'] = new_annotations data['categories'] = [ cat for cat in data['categories'] if cat['id'] in keep_cats ] return data3. COCO转YOLO格式的核心转换
YOLO格式要求每个图像对应一个txt文件,每行包含:
<class_id> <x_center> <y_center> <width> <height>这些坐标需要归一化到[0,1]区间。转换脚本的核心逻辑如下:
def coco_to_yolo(coco_ann, output_dir, img_width, img_height): os.makedirs(output_dir, exist_ok=True) for img_info in coco_ann['images']: img_id = img_info['id'] anns = [a for a in coco_ann['annotations'] if a['image_id'] == img_id] if not anns: continue txt_path = os.path.join(output_dir, f"{img_info['file_name'].split('.')[0]}.txt") with open(txt_path, 'w') as f: for ann in anns: # COCO格式是[x,y,width,height] x, y, w, h = ann['bbox'] # 转换为YOLO格式 x_center = (x + w/2) / img_width y_center = (y + h/2) / img_height norm_w = w / img_width norm_h = h / img_height f.write(f"{ann['category_id']} {x_center} {y_center} {norm_w} {norm_h}\n")4. 数据集重新划分的最佳实践
原始TT100K的划分方式不适合现代目标检测训练,我们需要重新划分为train/val/test。推荐的比例是70%/15%/15%,但具体实施时有几个技巧:
- 分层抽样:确保每个类别在三个集合中都有代表
- 防止数据泄漏:同一标志的不同角度图片应放在同一集合
- 考虑地理分布:不同拍摄地点的数据应均匀分布
from sklearn.model_selection import train_test_split def split_dataset(coco_ann, test_size=0.15): # 按图片ID分组 img_ids = list({ann['image_id'] for ann in coco_ann['annotations']}) # 第一次分割:分出test集 train_val_ids, test_ids = train_test_split( img_ids, test_size=test_size, random_state=42 ) # 第二次分割:分出val集 train_ids, val_ids = train_test_split( train_val_ids, test_size=test_size/(1-test_size), random_state=42 ) return { 'train': train_ids, 'val': val_ids, 'test': test_ids }5. 自动化处理流程整合
将上述步骤整合成完整流水线,并添加错误处理和日志记录:
def process_tt100k_dataset(input_dir, output_dir, min_samples=100): """完整的TT100K处理流程""" try: # 1. 加载原始标注 coco_ann = load_coco_annotation(input_dir) # 2. 过滤低频类别 filtered_ann = filter_categories(coco_ann, min_samples) # 3. 重新划分数据集 splits = split_dataset(filtered_ann) # 4. 为每个划分创建YOLO格式 for split_name, img_ids in splits.items(): split_ann = { 'images': [img for img in filtered_ann['images'] if img['id'] in img_ids], 'annotations': [ann for ann in filtered_ann['annotations'] if ann['image_id'] in img_ids], 'categories': filtered_ann['categories'] } # 转换格式 for img in split_ann['images']: img_path = os.path.join(input_dir, 'images', img['file_name']) img = cv2.imread(img_path) h, w = img.shape[:2] coco_to_yolo( split_ann, os.path.join(output_dir, 'labels', split_name), w, h ) # 复制图片到对应目录 os.makedirs(os.path.join(output_dir, 'images', split_name), exist_ok=True) shutil.copy( img_path, os.path.join(output_dir, 'images', split_name, img['file_name']) ) except Exception as e: logging.error(f"处理失败: {str(e)}") raise6. 常见问题与解决方案
在实际操作中,你可能会遇到以下典型问题:
坐标越界:转换后坐标超出[0,1]范围
- 检查原始标注是否有错误
- 添加边界检查代码:
x_center = max(0, min(1, x_center)) y_center = max(0, min(1, y_center))
类别ID不连续:过滤后类别ID出现间隔
- 建议重新映射为连续ID:
cat_mapping = {old_id: new_id for new_id, old_id in enumerate(sorted(keep_cats))}
- 建议重新映射为连续ID:
图片与标注不匹配:部分图片找不到对应标注
- 建立完整的校验流程:
for img in coco_ann['images']: if not any(ann['image_id'] == img['id'] for ann in coco_ann['annotations']): print(f"警告: 图片 {img['file_name']} 没有对应标注")
- 建立完整的校验流程:
7. 高级技巧与优化建议
当处理完基础转换后,可以考虑以下优化:
自动生成YOLO配置文件:
def generate_yaml(categories, output_path): with open(output_path, 'w') as f: f.write("train: ../images/train\n") f.write("val: ../images/val\n") f.write("test: ../images/test\n\n") f.write("nc: {}\n".format(len(categories))) f.write("names: {}\n".format( [cat['name'] for cat in categories] ))可视化验证:开发一个小工具检查转换结果是否正确
def plot_yolo_annotation(img_path, txt_path): img = cv2.imread(img_path) h, w = img.shape[:2] with open(txt_path) as f: for line in f: class_id, xc, yc, bw, bh = map(float, line.split()) # 转换回像素坐标 x1 = int((xc - bw/2) * w) y1 = int((yc - bh/2) * h) x2 = int((xc + bw/2) * w) y2 = int((yc + bh/2) * h) cv2.rectangle(img, (x1,y1), (x2,y2), (0,255,0), 2) cv2.imshow("Preview", img) cv2.waitKey(0)并行处理加速:对于大型数据集,可以使用多进程:
from multiprocessing import Pool def parallel_convert(args): img_path, output_dir, img_w, img_h = args # 转换逻辑... with Pool(processes=4) as pool: pool.map(parallel_convert, task_list)
处理TT100K这类复杂数据集时,最耗时的往往不是技术实现,而是对各种边缘情况的处理。在实际项目中,我通常会先抽取100张样本快速验证整个流程,确认无误后再处理完整数据集。保存中间结果和添加充分的日志也能在出现问题时快速定位原因。
