核心代码
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),#每边填充4,把32^*32填充至40*40,再随机裁剪
Cutout(0.5),#参数是遮挡的概率
transforms.RandomHorizontalFlip(),#随机左右翻转
transforms.ToTensor()#必不可少的数据转换
])
效果图
文章图片
填充、随机裁剪效果
文章图片
文章图片
【机器学习|pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)】Cutout() 的遮挡效果
文章图片
文章图片
完整代码
import torch as t
import numpy as np
import torchvision as tv
import matplotlib.pyplot as plt
from torchvision import transforms
from torchtoolbox.transform import Cutout
ROOT = '../pytorch/cifar-10'
BATCH_SIZE=128train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),#每边填充4,把32^*32填充至40*40,再随机裁剪
Cutout(0.5),#参数是遮挡的概率
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])train_data = https://www.it610.com/article/tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=train_transform)
train_load = t.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')for i,data in enumerate(train_load):
img,label = data
if i<10:
image = img[i].numpy()
tag = label[i].numpy()
#print(tag)
print('img is a ', classes[tag])
show = np.zeros((32,32,3))
show[:,:,0]=image[0,:,:]
show[:,:,1]=image[1,:,:]
show[:,:,2]=image[2,:,:]
#print(show.shape)
#print(img[i])
plt.figure()
plt.imshow(show)
plt.show()
print("------------------------------------")
else:
break
推荐阅读
- Python如何将两个字典合并到一个表达式中(有哪些方式?)
- 如何按照值对Python字典排序(有几种方式和对应的实例?)
- 如何理解Python中if__name__==“__main__”(它是什么意思,有什么用?)
- Web自动化|Web自动化 - 三种等待
- Python的yield关键字有什么作用(如何理解yield?)
- 使用Python调用外部命令有什么办法(有哪些方式?有实例吗?)
- Python中如何使用全局变量(求实例解释)
- Python中函数的星号和双星号参数是什么(有什么作用?)
- Python如何通过引用传递变量(python有引用传递吗?)