A U2Net Simulation Experiment on ModelArts Notebook
Abstract: U2Net is an excellent salient object detection algorithm by Qin Xuebin et al., published in the journal Pattern Recognition in 2020 [Arxiv]. The name U2Net comes from its architecture: two nested levels of U-Net structures, which lets it be trained from scratch, without a pre-trained backbone, while still achieving excellent performance. This article is shared from the Huawei Cloud community post "ModelArts Notebook Quick Open-Source Project Walkthrough: U2Net" by shpity.
I. Introduction to U2Net
U2Net is an excellent salient object detection algorithm by Qin Xuebin et al., published in the journal Pattern Recognition in 2020 [Arxiv]. The name U2Net comes from its architecture: two nested levels of U-Net structures, which lets it be trained from scratch, without a pre-trained backbone, while still achieving excellent performance. Its network structure is shown in Figure 1.
Figure 1. The overall framework of U2Net is a U-Net-like encoder-decoder structure, but with each block replaced by the newly proposed residual U-block module
Open-source project repository: https://github.com/xuebinqin/...
II. Creating a Notebook Development Environment
1. Go to the ModelArts console
2. Choose DevEnviron -> Notebook -> Create
3. Create the Notebook:
3.1 Choose a name related to the task, to make management easier;
3.2 To avoid unnecessary resource consumption, enable auto-stop;
3.3 The runtime environment U2Net needs is already included in the public images; choose pytorch1.4-cuda10.1-cudnn7-ubuntu18.04;
3.4 A GPU flavor is recommended so the model trains quickly;
3.5 Click Create Now -> Submit, wait for the Notebook to finish creating, and then open it.
4. Import the open-source project code (git or manual upload)
4.1 Clone the remote repository with git in the Terminal:
cd work # Note: only files under /home/ma-user/work (and its subdirectories) are preserved after the Notebook instance stops
git clone https://github.com/xuebinqin/U-2-Net.git
4.2 If git is slow, you can also upload the code from your local machine: drag the archive into the file browser on the left, or upload it via OBS.
III. Data Preparation
1. Download the training data (APDrawing dataset)
You can download it straight into the Notebook with wget, or download it locally and then drag it into the Notebook.
wget https://cg.cs.tsinghua.edu.cn/people/~Yongjin/APDrawingDB.zip
unzip APDrawingDB.zip
Note: if the dataset is large (>5 GB), it has to be downloaded to another directory (whose contents are deleted when the instance stops), so storing it in OBS and pulling it when needed is recommended.
# Pull the dataset from OBS into the target directory
sh-4.4$ source /home/ma-user/anaconda3/bin/activate PyTorch-1.4
sh-4.4$ python
>>> import moxing as mox
>>> mox.file.copy_parallel('obs://bucket-xxxx/APDrawingDB', '/home/ma-user/work/APDrawingDB')
2. Split the training data
./APDrawingDB/data/train contains 420 training images at 512x1024 resolution; the left half of each is the input image and the right half is the corresponding ground truth. Each large image needs to be split down the middle into two sub-images.
2.1 In the Notebook environment, create a new PyTorch-1.4 Jupyter notebook file, e.g. split.ipynb. The script below generates 840 sub-images under ./APDrawingDB/data/train/split_train, with input images ending in .jpg and ground-truth images ending in .png so the training code can tell them apart (the test folder is split the same way, into ./APDrawingDB/data/test/split).
from PIL import Image
import os

train_img_dir = os.path.join("./APDrawingDB/data/train")
img_list = os.listdir(train_img_dir)
for image in img_list:
    img_path = os.path.join(train_img_dir, image)
    if not os.path.isdir(img_path):
        img = Image.open(img_path)
        save_img_dir = os.path.join(train_img_dir, 'split_train')
        if not os.path.exists(save_img_dir):
            os.mkdir(save_img_dir)
        save_img_path = os.path.join(save_img_dir, image)
        cropped_left = img.crop((0, 0, 512, 512))      # (left, upper, right, lower)
        cropped_right = img.crop((512, 0, 1024, 512))  # (left, upper, right, lower)
        cropped_left.save(save_img_path[:-3] + 'jpg')  # input image -> .jpg
        cropped_right.save(save_img_path)              # ground truth -> .png

test_img_dir = os.path.join("./APDrawingDB/data/test")
img_list = os.listdir(test_img_dir)
for image in img_list:
    img_path = os.path.join(test_img_dir, image)
    if not os.path.isdir(img_path):
        img = Image.open(img_path)
        save_img_dir = os.path.join(test_img_dir, 'split')
        if not os.path.exists(save_img_dir):
            os.mkdir(save_img_dir)
        save_img_path = os.path.join(save_img_dir, image)
        cropped_left = img.crop((0, 0, 512, 512))      # only the left half (input) is needed for testing
        cropped_left.save(save_img_path[:-3] + 'jpg')
3. Organize the split data into the datasets folder required for training and testing, using the following hierarchy:
datasets/
├── test (70 split images, input images only)
└── train (840 split images: 420 inputs plus their corresponding gt)
Note: you can also keep the split dataset in an OBS directory to reduce disk usage under ./work.
4. The complete U-2-Net project structure looks like this:
U-2-Net/
├── .git
├── LICENSE
├── README.md
├── __pycache__
├── clipping_camera.jpg
├── data_loader.py
├── datasets
├── figures
├── gradio
├── model
├── requirements.txt
├── saved_models
├── setup_model_weights.py
├── test_data
├── u2net_human_seg_test.py
├── u2net_portrait_demo.py
├── u2net_portrait_test.py
├── u2net_test.py
└── u2net_train.py
IV. Training
1. The data paths in the official training code differ slightly from our datasets layout, so the training script needs a few modifications; a Jupyter notebook is recommended to make debugging easier.
Create a new PyTorch-1.4 Jupyter notebook file, e.g. train.ipynb.
import moxing as mox
# If you need to copy the split training data from OBS:
#mox.file.copy_parallel('obs://bucket-test-xxxx', '/home/ma-user/work/U-2-Net/datasets')
INFO:root:Using MoXing-v1.17.3-43fbf97f
INFO:root:Using OBS-Python-SDK-3.20.7
import os
import glob
import numpy as np
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms

from data_loader import Rescale
from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET
from model import U2NETP
/home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages/skimage/io/manage_plugins.py:23: UserWarning: Your installed pillow version is < 7.1.0. Several security issues (CVE-2020-11538, CVE-2020-10379, CVE-2020-10994, CVE-2020-10177) have been fixed in pillow 7.1.0 or higher. We recommend to upgrade this library.
from .collection import imread_collection_wrapper
bce_loss = nn.BCELoss(size_average=True)
/home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
warnings.warn(warning.format(ret))
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    loss0 = bce_loss(d0, labels_v)
    loss1 = bce_loss(d1, labels_v)
    loss2 = bce_loss(d2, labels_v)
    loss3 = bce_loss(d3, labels_v)
    loss4 = bce_loss(d4, labels_v)
    loss5 = bce_loss(d5, labels_v)
    loss6 = bce_loss(d6, labels_v)
    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
        loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(),
        loss4.data.item(), loss5.data.item(), loss6.data.item()))
    return loss0, loss
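This function implements U2Net's deep supervision: the BCE losses of the fused output d0 and the six side outputs are simply summed. A stdlib-only sketch of that idea on toy scalar probabilities (illustrative numbers, not real tensors):

```python
import math

def bce(p, y):
    """Binary cross-entropy for one predicted probability p and label y."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# Toy side-output probabilities for a single pixel whose label is 1.
side_probs = [0.9, 0.85, 0.8, 0.8, 0.75, 0.7, 0.6]  # d0..d6
side_losses = [bce(p, 1.0) for p in side_probs]
loss0 = side_losses[0]     # reported as "tar" in the training log
total = sum(side_losses)   # the fused loss actually backpropagated
print(round(loss0, 4), round(total, 4))  # 0.1054 1.8693
```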
model_name = 'u2net'  # 'u2netp'

data_dir = os.path.join(os.getcwd(), 'datasets', 'train' + os.sep)
# tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'im_aug' + os.sep)
# tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'gt_aug' + os.sep)
image_ext = '.jpg'
label_ext = '.png'
model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)

epoch_num = 100000
batch_size_train = 24
batch_size_val = 1
train_num = 0
val_num = 0

tra_img_name_list = glob.glob(data_dir + '*' + image_ext)
tra_lbl_name_list = []
for img_path in tra_img_name_list:
    img_name = img_path.split(os.sep)[-1]
    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1, len(bbb)):
        imidx = imidx + "." + bbb[i]
    tra_lbl_name_list.append(data_dir + imidx + label_ext)

print("---")
print("train images: ", len(tra_img_name_list))
print("train labels: ", len(tra_lbl_name_list))
print("---")
train_num = len(tra_img_name_list)
---
train images:  420
train labels:  420
---
salobj_dataset = SalObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([
        RescaleT(320),
        RandomCrop(288),
        ToTensorLab(flag=0)]))
salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)
# ------- 3. define model --------
# define the net
if model_name == 'u2net':
    net = U2NET(3, 1)
elif model_name == 'u2netp':
    net = U2NETP(3, 1)

if torch.cuda.is_available():
    net.cuda()

# ------- 4. define optimizer --------
print("---define optimizer...")
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
---define optimizer...
# ------- 5. training process --------
print("---start training...")
ite_num = 0
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
save_frq = 2000 # save the model every 2000 iterations
---start training...
for epoch in range(0, epoch_num):
    net.train()
    for i, data in enumerate(salobj_dataloader):
        ite_num = ite_num + 1
        ite_num4val = ite_num4val + 1

        inputs, labels = data['image'], data['label']
        inputs = inputs.type(torch.FloatTensor)
        labels = labels.type(torch.FloatTensor)

        # wrap them in Variable
        if torch.cuda.is_available():
            inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
        else:
            inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
        loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.data.item()
        running_tar_loss += loss2.data.item()

        # del temporary outputs and loss
        del d0, d1, d2, d3, d4, d5, d6, loss2, loss

        print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
            epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num,
            running_loss / ite_num4val, running_tar_loss / ite_num4val))

        if ite_num % save_frq == 0:
            model_weight = model_dir + model_name + "_bce_itr_%d_train_%3f_tar_%3f.pth" % (
                ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val)
            torch.save(net.state_dict(), model_weight)
            mox.file.copy_parallel(model_weight, 'obs://bucket-xxxx/output/model_save/' + model_weight.split('/')[-1])
            running_loss = 0.0
            running_tar_loss = 0.0
            net.train()  # resume training
            ite_num4val = 0
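Every save_frq iterations the loop writes a checkpoint whose filename encodes the iteration count and the running losses, then mirrors it to OBS. A stdlib-only sketch of how that name is built (the numeric values are illustrative):

```python
def checkpoint_name(model_name, ite_num, running_loss, running_tar_loss, ite_num4val):
    """Reproduce the checkpoint naming scheme used by the training loop."""
    return model_name + "_bce_itr_%d_train_%3f_tar_%3f.pth" % (
        ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val)

print(checkpoint_name('u2net', 2000, 2400.0, 320.0, 2000))
# u2net_bce_itr_2000_train_1.200000_tar_0.160000.pth
```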
l0: 0.167562, l1: 0.153742, l2: 0.156246, l3: 0.163096, l4: 0.176632, l5: 0.197176, l6: 0.247590
[epoch:   1/100000, batch:    24/  420, ite: 500] train loss: 1.189413, tar: 0.159183
l0: 0.188048, l1: 0.179041, l2: 0.180086, l3: 0.187904, l4: 0.198345, l5: 0.218509, l6: 0.269199
[epoch:   1/100000, batch:    48/  420, ite: 501] train loss: 1.266652, tar: 0.168805
l0: 0.192491, l1: 0.187615, l2: 0.188043, l3: 0.197142, l4: 0.203571, l5: 0.222019, l6: 0.261745
[epoch:   1/100000, batch:    72/  420, ite: 502] train loss: 1.313146, tar: 0.174727
l0: 0.169403, l1: 0.155883, l2: 0.157974, l3: 0.164012, l4: 0.175975, l5: 0.195938, l6: 0.244896
[epoch:   1/100000, batch:    96/  420, ite: 503] train loss: 1.303333, tar: 0.173662
l0: 0.171904, l1: 0.157170, l2: 0.156688, l3: 0.162020, l4: 0.175565, l5: 0.200576, l6: 0.258133
[epoch:   1/100000, batch:   120/  420, ite: 504] train loss: 1.299787, tar: 0.173369
l0: 0.177398, l1: 0.166131, l2: 0.169089, l3: 0.176976, l4: 0.187039, l5: 0.205449, l6: 0.248036
V. Testing
Create a new PyTorch-1.4 Jupyter notebook file, e.g. test.ipynb.
import moxing as mox
# Copy the trained model weights from OBS
mox.file.copy_parallel('obs://bucket-xxxx/output/model_save/u2net.pth', '/home/ma-user/work/U-2-Net/saved_models/u2net/u2net.pth')
import os
import sys
import glob
import numpy as np
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms  # , utils
# import torch.optim as optim
from PIL import Image

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET   # full size version 173.6 MB
from model import U2NETP  # small version u2net 4.7 MB

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
# normalize the predicted SOD probability map
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)
    return dn

def save_output(image_name, pred, d_dir, show=False):
    predict = pred
    predict = predict.squeeze()
    predict_np = predict.cpu().data.numpy()

    im = Image.fromarray(predict_np * 255).convert('RGB')
    img_name = image_name.split(os.sep)[-1]
    image = io.imread(image_name)
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)
    pb_np = np.array(imo)

    if show:
        show_on_notebook(image, im)

    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1, len(bbb)):
        imidx = imidx + "." + bbb[i]

    imo.save(d_dir + imidx + '.png')
    return im

def show_on_notebook(image_original, pred):
    # Display the input image and the model's prediction side by side in the notebook
    plt.subplot(1, 2, 1)
    imshow(np.array(image_original))
    plt.subplot(1, 2, 2)
    imshow(np.array(pred))
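normPRED is plain min-max normalization: it rescales the network's output probability map into [0, 1] before it is saved as an image. The same idea on a plain Python list (a sketch; the real function operates on torch tensors):

```python
def norm_pred_list(values):
    """Min-max normalize a list of floats into [0, 1]."""
    ma, mi = max(values), min(values)
    return [(v - mi) / (ma - mi) for v in values]

print(norm_pred_list([1.0, 2.0, 3.0]))  # [0.0, 0.5, 1.0]
```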
# --------- 1. get image path and name ---------
model_name = 'u2net'  # u2netp

image_dir = os.path.join(os.getcwd(), 'datasets', 'test')  # note: this holds the original images from datasets/test, without gt
prediction_dir = os.path.join(os.getcwd(), 'output', model_name + '_results' + os.sep)
model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')

img_name_list = glob.glob(os.path.join(os.getcwd(), 'datasets/test/*.jpg'))
# print(img_name_list)

# --------- 2. dataloader ---------
test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                    lbl_name_list=[],
                                    transform=transforms.Compose([RescaleT(320),
                                                                  ToTensorLab(flag=0)]))
test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                    batch_size=1,
                                    shuffle=False,
                                    num_workers=1)

# --------- 3. model define ---------
if model_name == 'u2net':
    print("...load U2NET---173.6 MB")
    net = U2NET(3, 1)
elif model_name == 'u2netp':
    print("...load U2NEP---4.7 MB")
    net = U2NETP(3, 1)

if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_dir))
    net.cuda()
else:
    net.load_state_dict(torch.load(model_dir, map_location='cpu'))
net.eval()

# --------- 4. inference for each image ---------
for i_test, data_test in enumerate(test_salobj_dataloader):
    # print("inferencing:", img_name_list[i_test].split(os.sep)[-1])
    inputs_test = data_test['image']
    inputs_test = inputs_test.type(torch.FloatTensor)

    if torch.cuda.is_available():
        inputs_test = Variable(inputs_test.cuda())
    else:
        inputs_test = Variable(inputs_test)

    d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

    # normalization
    pred = d1[:, 0, :, :]
    pred = normPRED(pred)

    # save results to test_results folder
    if not os.path.exists(prediction_dir):
        os.makedirs(prediction_dir, exist_ok=True)
    save_output(img_name_list[i_test], pred, prediction_dir, show=True)

    # sys.exit(0)
    del d1, d2, d3, d4, d5, d6, d7
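Once predictions are saved, salient object detection results are commonly scored with mean absolute error (MAE) between the predicted map and the ground truth. A stdlib-only sketch on flat pixel lists (a hypothetical helper, not part of the repo):

```python
def mae(pred, gt):
    """Mean absolute error between two equal-length pixel lists in [0, 1]."""
    assert len(pred) == len(gt)
    return sum(abs(p - g) for p, g in zip(pred, gt)) / len(pred)

print(round(mae([0.9, 0.1, 0.8, 0.0], [1.0, 0.0, 1.0, 0.0]), 6))  # 0.1
```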
VI. Attachments
See the attachment.
附件.zip (71.28 KB)