机器学习|pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)

核心代码

train_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4),#每边填充4,把32^*32填充至40*40,再随机裁剪 Cutout(0.5),#参数是遮挡的概率 transforms.RandomHorizontalFlip(),#随机左右翻转 transforms.ToTensor()#必不可少的数据转换 ])

效果图 机器学习|pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)
文章图片

填充、随机裁剪效果
机器学习|pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)
文章图片

机器学习|pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)
文章图片

【机器学习|pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)】Cutout() 的遮挡效果
机器学习|pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)
文章图片

机器学习|pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)
文章图片

完整代码
import torch as t import numpy as np import torchvision as tv import matplotlib.pyplot as plt from torchvision import transforms from torchtoolbox.transform import Cutout ROOT = '../pytorch/cifar-10' BATCH_SIZE=128train_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4),#每边填充4,把32^*32填充至40*40,再随机裁剪 Cutout(0.5),#参数是遮挡的概率 transforms.RandomHorizontalFlip(), transforms.ToTensor() ])train_data = https://www.it610.com/article/tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=train_transform) train_load = t.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')for i,data in enumerate(train_load): img,label = data if i<10: image = img[i].numpy() tag = label[i].numpy() #print(tag) print('img is a ', classes[tag]) show = np.zeros((32,32,3)) show[:,:,0]=image[0,:,:] show[:,:,1]=image[1,:,:] show[:,:,2]=image[2,:,:] #print(show.shape) #print(img[i]) plt.figure() plt.imshow(show) plt.show() print("------------------------------------") else: break

    推荐阅读