目标检测|【MMDetection】v2.22.0入门(训练自己的数据集)


文章目录

    • 一、MMDetection安装并测试
      • 安装步骤
    • 二、数据集准备
    • 三、训练前的准备(修改参数)
      • 事先说明
      • 修改数据集相关参数
      • 修改训练相关参数
    • 四、开始训练
      • 单GPU训练
      • 指定多GPU训练
    • 五、使用训练结果进行测试并可视化
      • 验证集图片测试
      • 训练日志可视化
      • DetVisGUI可视化
      • 计算模型复杂度
    • References

一、MMDetection安装并测试 MMDetection是MMLab家族的一员,是由香港中文大学和商汤科技共同推出的,以一个统一的架构支撑了15个大方向的研究领域,实现了1300+的算法复现工作。MMDetection依赖Pytorch和MMCV(mmcv/mmcv-full),因此安装之前需要先安装这两个库,具体安装步骤可参考这篇博客:MMDetection框架入门教程(一):Anaconda3下的安装教程(mmdet+mmdet3d)。
安装步骤
  1. Anaconda虚拟环境搭建
    1. conda create -n mmdet python=3.8
    2. conda activate mmdet
  2. Pytorch安装
    1. nvidia-smi确定服务器中CUDA的版本
    2. conda install pytorch==1.9.1 torchvision torchaudio cudatoolkit=10.2 -c pytorch安装对应版本的torch
    3. torch.cuda.is_available验证torch是否安装成功
  3. 安装MMCV
  4. 安装MMDetection
  5. 使用Demo验证是否安装成功
    1. mmdetection/新建checkpoints文件夹,下载faster-rcnn的预训练模型权重到该文件夹下
    2. mmdetection/新建test_demo.py文件,输入以下代码,然后运行
# 测试mmdet、mmcv是否安装成功 from mmdet.apis import init_detector, inference_detector, show_result_pyplot import torchconfig_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py' checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = init_detector(config_file, checkpoint_file, device=device) img = 'demo/demo.jpg' result = inference_detector(model, img) print(len(result)) show_result_pyplot(model=model, img=img, result=result, score_thr=0.9)

若安装成功,则运行结果如下图所示
目标检测|【MMDetection】v2.22.0入门(训练自己的数据集)
文章图片


二、数据集准备 【目标检测|【MMDetection】v2.22.0入门(训练自己的数据集)】本文所使用的数据集为三种水果数据集,下载链接为:https://download.csdn.net/download/weixin_43799388/84425688。在mmdetection/新建data文件夹,将数据集解压后放到这里。
数据集文档结构如下:
|——-Fruit |---Annotations |---001.xml |---002.xml ... ... |---340.xml |---images |---001.jpg |---002.jpg ... ... |---340.jpg

MMDetection一共支持三种形式应用新数据集:
  1. 将数据集重新组织为 COCO 格式
  2. 将数据集重新组织为一个中间格式
  3. 实现一个新的数据集
官方建议使用前面两种方法,因为它们通常来说比第三种方法要简单。
该数据集的标注形式为xml格式,其中GT框的坐标信息是以左上角坐标(xmin, ymin)和右下角坐标(xmax, ymax)形式来标注的,这里我们采用MMDetection建议的第一种方法,将数据集重新组织为 COCO 格式,具体转换步骤可以参考我写的另一篇博客:VOC(xml)标注格式转换为YOLOv5(txt)和COCO2017(json)格式。
转换结果的文档结构如下所示,我们真正用到的是annotations、train和val这三个文件夹。
|——-Fruit |---annotations |---Fruit_train.json |---Fruit_val.json |---Annotations |---001.xml |---002.xml ... ... |---340.xml |---images |---001.jpg |---002.jpg ... ... |---340.jpg |---ImageSets |---train.txt# 存放训练集图片名称 |---val.txt# 存放验证集图片名称 |---train |---001.jpg |---002.jpg ... ... |---val |---046.jpg |---049.jpg ... ...


三、训练前的准备(修改参数) 事先说明
  • 数据集类别一共有3个:apple/banana/grape
  • 使用的训练模型是configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py

修改数据集相关参数
虽然我们已经将xml格式标注的数据集转换成了MMDetection使用的COCO数据集格式,但还需要修改一些配置参数(官方建议直接修改coco数据集定义文件):
  1. 修改模型配置文件:configs/_base_/models/faster_rcnn_r50_fpn.py
    1. 定位到roi_head字典出,修改bbox_head字典中的num_classes为3
  2. 修改coco数据集定义文件:mmdet/datasets/coco.py
    1. CLASSES那里的参数修改为:CLASSES = ('apple', 'banana', 'grape')
    2. PALETTE参数随意选三个留下即可,例如:PALETTE = [(220, 20, 60), (119, 11, 32), (191, 162, 208)],这个参数用来指定每个类别框的显示颜色
  3. 修改class_name:mmdet/core/evaluation/class_names.py
    1. 定位到coco_classes函数,修改return中的参数为:'apple', 'banana', 'grape'
  4. mmdetection目录下新建test_work_dirs文件夹

修改训练相关参数
此外,在configs/_base_/default_runtime.py文件中可以修改训练时的其他参数,本次训练的default_runtime.py代码如下:
# 保存checkpoints的间隔 默认每次都保存 checkpoint_config = dict(interval=4) # yapf:disable 打印log的间隔(每个epoch中) 默认迭代50次打印一次(datasets的大小除以batchsize) log_config = dict( interval=5, hooks=[ dict(type='TextLoggerHook'), # dict(type='TensorboardLoggerHook') ]) # yapf:enable custom_hooks = [dict(type='NumClassCheckHook')] dist_params = dict(backend='nccl') log_level = 'INFO' load_from = None# 加载参数 # 断点续训 重新加载已训练好的checkpoints 包含epoch等信息 会覆盖load_form resume_from = None # 工作流 workflow = [('train', 1)] # disable opencv multithreading to avoid system being overloaded opencv_num_threads = 0 # set multi-process start method as `fork` to speed up the training mp_start_method = 'fork'


四、开始训练 单GPU训练
python tools/train.py configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py --gpus 1 --work-dir test_work_dirs


指定多GPU训练
CUDA_VISIBLE_DEVICES=2,3指定使用GPU-3和GPU-4,放到python命令之前,同时要设置--gpus 2
CUDA_VISIBLE_DEVICES=2,3 python tools/train.py configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py --gpus 2 --work-dir xiaolong_dir0

训练过程界面如下:

训练完之后test_work_dirs文件夹中会保存下训练过程中的log日志文件、每4个epoch的pth文件(因为在default_runtime.py设置了checkpoint_config = dict(interval=4)),这个文件将会用于后面的test测试。

五、使用训练结果进行测试并可视化 验证集图片测试
python tools/test.py test_work_dirs/faster_rcnn_r50_fpn_1x_coco.py test_work_dirs/epoch_12.pth --eval bbox --out test_work_dirs/result12.pkl --show

传递参数说明:
  • config:模型训练的配置文件
  • checkpoint:训练结果的权重文件
  • --eval:验证指标,一般使用bbox
  • --out:测试结果文件保存的路径及名称
  • --show:展示每一张验证集图片的测试结果
测试结果如下:
目标检测|【MMDetection】v2.22.0入门(训练自己的数据集)
文章图片


训练日志可视化
python tools/analysis_tools/analyze_logs.py plot_curve test_work_dirs/20220312_094204.log.json --keys acc loss_cls loss_bbox

将训练结果20220312_094204.log.json中的参数acc、loss_clsloss_bbox进行可视化,结果如下(由于数据集太小,并且只有3个class,因此训练很快就收敛了):
目标检测|【MMDetection】v2.22.0入门(训练自己的数据集)
文章图片


DetVisGUI可视化
DetVisGUI工具是一个用于可视化MMDetection测试结果的轻量级GUI,它可以动态显示不同阈值的检测结果,便于验证检测结果和GT框的差异。
  • 在使用DetVisGUI工具之前,需要利用python tools/test.py命令,保存测试的结果文件(pkl格式或者json格式)
  • 之后从GitHub上下载DetVisGUI源码,放在mmdetection/DetVisGUI/文件夹中,运行DetDetVisGUI的命令格式是:
python DetVisGUI/DetVisGUI.py ${CONFIG_FILE} [--det_file ${RESULT_FILE}] [--stage ${STAGE}] [--output ${SAVE_DIRECTORY}]

传递参数说明:
  • config:模型训练的配置文件
  • --det_file:测试结果文件(pkl格式或json格式)
  • --stage:测试结果文件的三种stage(train/val/test),默认为val
  • --output:测试图片的保存路径,默认为output
运行如下命令:
python DetVisGUI/DetVisGUI.py test_work_dirs/faster_rcnn_r50_fpn_1x_coco.py --det_file test_work_dirs/result12.pkl --output test_work_dirs/val_result

运行结果如下图所示,可以调节IOU阈值、置信度阈值、是否显示GT框及文本信息等,来对比测试结果。
目标检测|【MMDetection】v2.22.0入门(训练自己的数据集)
文章图片

点击左上角的Save All Results按钮,即可将全部验证集图片的测试结果保存到指定路径中。

计算模型复杂度
首先来科普一下FLOPs和FLOPS的区别:
  • 计算复杂度:FLOPs(注意s是小写)
    • floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量,和软硬件的配置没有关系,可以公平地用来衡量算法/模型的复杂度
    • 计算公式: F L O P s = C o u t ? H o u t ? W o u t ? C i n ? k ? k FLOPs=C_{out}*H_{out}*W_{out}*C_{in}*k*k FLOPs=Cout??Hout??Wout??Cin??k?k
  • FLOPS(floating point operations per second)
    • 意指每秒浮点运算次数,理解为计算速度,是一个衡量硬件性能的指标
在MMDetection中可以使用tools/analysis_tools/get_flops.py命令来获取模型的复杂度:
python tools/analysis_tools/get_flops.py test_work_dirs/faster_rcnn_r50_fpn_1x_coco.py

运行结果如下所示,可以看到get_flops.py函数会打印出模型每一层的FLOPs和参数量,以及总的FLOPs和参数量。
但要注意的是,这些数据仅供参考,不一定准确,输出结果的最后一行也提醒我们,不建议把这个输出结果放到论文中。
目标检测|【MMDetection】v2.22.0入门(训练自己的数据集)
文章图片


References MMDetection框架入门教程(一):Anaconda3下的安装教程(mmdet+mmdet3d)
官方教程:MMDETECTION’S DOCUMENTATION!
VOC(xml)标注格式转换为YOLOv5(txt)和COCO2017(json)格式
【mmdetection】使用自定义的coco格式数据集进行训练及测试
mmdetection可视化工具-DetVisGUI

    推荐阅读