pytorch|MMSegmentation训练自己的分割数据集

首先确保在服务器上正常安装了MMSeg,注意安装完还需建立与自己的数据集之间的软连接,官方安装教程如下:
https://github.com/open-mmlab/mmsegmentation/blob/master/docs/get_started.md#installation

pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html git clone https://github.com/open-mmlab/mmsegmentation.git cd mmsegmentation pip install -e .# or "python setup.py develop"mkdir data ln -s $DATA_ROOT data

由于工作需要,需要在MMSeg上训练自己的数据集,而之前只在上面训过Cityscapes的数据集,两种数据集质检的组织形式不同,需要自己重写配置文件,我的数据原始格式为一对img和label,两者文件名相同,后缀img.tif,label.png,如图:
原始图像:
pytorch|MMSegmentation训练自己的分割数据集
文章图片

标签:
pytorch|MMSegmentation训练自己的分割数据集
文章图片

参考官方文档的教程组织数据集:https://github.com/open-mmlab/mmsegmentation/blob/master/docs/tutorials/customize_datasets.md
首先在/mmsegmentation-master/mmseg/datasets路径下新建一个mydataset.py文件
@DATASETS.register_module() class mydata(CustomDataset): CLASSES = ('background', 'plant', 'plantsoil', 'grass', 'building', 'railway', 'structure', 'humanland', 'baresoil', 'water') PALETTE = [[0,0,0], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35]] def __init__(self, **kwargs): super(mydata, self).__init__( img_suffix='.tif', seg_map_suffix='.png', reduce_zero_label=False, **kwargs) assert osp.exists(self.img_dir)

写法参考该文件夹下其他数据集的py文件,需要修改的地方为CLASSES和PALETTE,前者为你数据集中每一类的名称,后者为类别对应的颜色,注意不要漏掉背景值'background'。然后根据自己数据集的后缀更改img_suffix和seg_map_suffix,改好之后保存
同时还需要在同一文件夹下的__init__.py文件中添加刚刚新建的数据集
from .stare import STAREDataset from .voc import PascalVOCDataset from .mydataset import mydata__all__ = [ 'CustomDataset', 'Test_pathDataset','build_dataloader', 'ConcatDataset', 'RepeatDataset', 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset', 'COCOStuffDataset', 'mydata' ]

/mmsegmentation-master/mmseg/core/evaluation下的class_names.py文件也需要进行更改:
import mmcvdef mydata_classes(): """shengteng class names for external use.""" return [ 'background', 'plant', 'plantsoil', 'grass', 'building', 'railway', 'structure', 'humanland', 'baresoil', 'water' ]def mydata_palette(): return [[0,0,0], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35]]

至此数据集更改完成,还需要更改config文件下的模型文件:
norm_cfg = dict(type='SyncBN', requires_grad=True) backbone_norm_cfg = dict(type='LN', requires_grad=True) model = dict( type='EncoderDecoder', pretrained='pretrain/swin_base_upernet_224x224_1K.pth',#从官网下载与训练模型 backbone=dict( type='SwinTransformer', pretrain_img_size=224, embed_dims=128, patch_size=4, window_size=7, mlp_ratio=4, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], strides=(4, 2, 2, 2), out_indices=(0, 1, 2, 3), qkv_bias=True, qk_scale=None, patch_norm=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.3, use_abs_pos_embed=False, act_cfg=dict(type='GELU'), norm_cfg=dict(type='LN', requires_grad=True)), decode_head=dict( type='UPerHead', in_channels=[128, 256, 512, 1024], in_index=[0, 1, 2, 3], pool_scales=(1, 2, 3, 6), channels=512, dropout_ratio=0.1, num_classes=9,#安装自己数据集的类别数更改 norm_cfg=dict(type='SyncBN', requires_grad=True), align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),

修改batch和数据集路径,train val test三部分都要更改:
data = https://www.it610.com/article/dict( samples_per_gpu=8,#batch size workers_per_gpu=12,#每张GPU分配的CPU核心数 train=dict( type='mydata',#自定义的数据集名称 data_root='data/xxx',#数据根目录 img_dir='img/',#原始图像路径 ann_dir='label/',#标签路径 pipeline=[ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='Resize', img_scale=(512, 512), ratio_range=(1.0, 1.0)), dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict( type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True), dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_semantic_seg']) ])

修改完成后开始在自己的数据集上训练MMSeg:
python ./tools/train.py ./config/swin/upernet_swin_base_patch4_window7_512x512.py

训练的log和权重文件会保存在workdir文件下:
pytorch|MMSegmentation训练自己的分割数据集
文章图片

pytorch|MMSegmentation训练自己的分割数据集
文章图片

多GPU训练:

bash ./tools/dist_train.sh config.py 2#config为你的配置文件,2为使用的GPU的数量

注意!
1、mmseg所使用的数据集标签是从0开始的,即(0,1,2...n),n为类别数;
2、首次运行mmseg工程之前,须先运行,将工作环境切换至当前路径
python setup.py develop

【pytorch|MMSegmentation训练自己的分割数据集】

    推荐阅读