PyTorch学习笔记(六) ---- 数据加载和处理
转载请注明作者和出处: http://blog.csdn.net/john_bh/
文章目录
- 1.准备处理数据的工具包
- 2.准备数据
- 3.读取数据集
- 4.可视化图像和标注信息
- 5.建立数据集类
- 6.数据可视化
- 7.数据变换
- 8.迭代数据集
- 9.torchvision包
1.准备处理数据的工具包
- scikit-image:用于图像的IO和变换
- pandas:用于更容易地进行csv解析
import os,torch
import pandas as pd
from skimage import io, transform#用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms,utilsimport warnings
warnings.filterwarnings("ignore")plt.ion()# interactive mode
2.准备数据 数据下载链接下载数据集, 数据存于“data / faces /”的目录中。这个数据集实际上是imagenet数据集标注为face的图片当中在 dlib 面部检测 (dlib’s pose estimation) 表现良好的图片。我们要处理的是一个面部姿态的数据集。也就是按如下方式标注的人脸:
文章图片
数据中数据集的标注在face_landmarks.csv文件中,标注格式如下:
image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, … ,part_67_x,part_67_y3.读取数据集 将csv中的标注点数据读入(N,2)数组中,其中N是特征点的数量。读取数据代码如下:
import os,torch
import pandas as pd
# from skimage import io,transform
from skimage import io, transform#用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
# from torch.utils.data import Dataset,DataLoader
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms,utilsimport warnings
warnings.filterwarnings("ignore")plt.ion()# interactive mode# 读取数据集将csv中的标注点数据读入(N,2)数组中,其中N是特征点的数量landmarks_frame=pd.read_csv("data/faces/face_landmarks.csv")n=65
img_name=landmarks_frame.iloc[n,0]
landmarks=landmarks_frame.iloc[n,1:].as_matrix()
landmarks=landmarks.astype('float').reshape(-1,2)print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
输出:
Image name: person-7.jpg4.可视化图像和标注信息
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
[33. 76.]
[34. 86.]
[34. 97.]]
import os,torch
import pandas as pd
# from skimage import io,transform
from skimage import io, transform#用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
# from torch.utils.data import Dataset,DataLoader
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms,utilsimport warnings
warnings.filterwarnings("ignore")plt.ion()# interactive mode# 读取数据集将csv中的标注点数据读入(N,2)数组中,其中N是特征点的数量landmarks_frame=pd.read_csv("data/faces/face_landmarks.csv")n=65
img_name=landmarks_frame.iloc[n,0]
landmarks=landmarks_frame.iloc[n,1:].as_matrix()
landmarks=landmarks.astype('float').reshape(-1,2)# print('Image name: {}'.format(img_name))
# print('Landmarks shape: {}'.format(landmarks.shape))
# print('First 4 Landmarks: {}'.format(landmarks[:4]))# 展示一张图片和它对应的标注点def show_landmarks(image,landmarks):
plt.imshow(image)
plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker='.',c='r')
plt.pause(0.001)plt.figure()
show_landmarks(io.imread(os.path.join("data/faces/",img_name)),landmarks)plt.show()
输出:
文章图片
5.建立数据集类 torch.utils.data.Dataset是表示数据集的抽象类,因此自定义数据集应继承Dataset并覆盖以下方法 __len__实现 len(dataset) 返还数据集的尺寸。 __getitem__用来获取一些索引数据,例如 dataset[i] 中的(i)。
为面部数据集创建一个数据集类。我们将在 __init__中读取csv的文件内容,在 __getitem__中读取图片。这么做是为了节省内存空间。只有在需要用到图片的时候才读取它而不是一开始就把图片全部存进内存里。
数据样本将按这样一个字典{‘image’: image, ‘landmarks’: landmarks}组织。 数据集类将添加一个可选参数transform 以方便对样本进行预处理。 __init__方法代码实现如下所示:
import os,torch
import pandas as pd
from skimage import io, transform#用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms,utilsimport warnings
warnings.filterwarnings("ignore")plt.ion()# interactive mode# 读取数据集将csv中的标注点数据读入(N,2)数组中,其中N是特征点的数量landmarks_frame=pd.read_csv("data/faces/face_landmarks.csv")n=65
img_name=landmarks_frame.iloc[n,0]
landmarks=landmarks_frame.iloc[n,1:].as_matrix()
landmarks=landmarks.astype('float').reshape(-1,2)# print('Image name: {}'.format(img_name))
# print('Landmarks shape: {}'.format(landmarks.shape))
# print('First 4 Landmarks: {}'.format(landmarks[:4]))# 展示一张图片和它对应的标注点def show_landmarks(image,landmarks):
plt.imshow(image)
plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker='.',c='r')
plt.pause(0.001)# plt.figure()
# show_landmarks(io.imread(os.path.join("data/faces/",img_name)),landmarks)# plt.show()#数据集类
# torch.utils.data.Dataset是表示数据集的抽象类,因此自定义数据集应继承Dataset并覆盖以下方法 * __len__ 实现 len(dataset) 返还数据集的尺寸。 * __getitem__用来获取一些索引数据,例如 dataset[i] 中的(i)# 为面部数据集创建一个数据集类。我们将在 __init__中读取csv的文件内容,在 __getitem__中读取图片。这么做是为了节省内存 空间。
# 只有在需要用到图片的时候才读取它而不是一开始就把图片全部存进内存里。# 数据样本将按这样一个字典{'image': image, 'landmarks': landmarks}组织。 我们的数据集类将添加一个可选参数transform 以方便对样本进行预处理class FaceLandmarksDataset(Dataset):
"""人脸标记数据集"""def __init__(self,csv_file,root_dir,transform=None):
"""
csv_file(string):带注释的csv文件的路径。
root_dir(string):包含所有图像的目录。
transform(callable, optional):一个样本上的可用的可选变换
"""
self.landmarks_frame=pd.read_csv(csv_file)
self.root_dir=root_dir
self.transform=transformdef __len__(self):
return len(self.landmarks_frame)def __getitem__(self,idx):
img_name=os.path.join(self.root_dir,self.landmarks_frame.iloc[idx,0])
image=io.imread(img_name)
landmarks=self.landmarks_frame.iloc[idx,1:]
landmarks=np.array([landmarks])
landmarks=landmarks.astype('float').reshape(-1,2)
sample={'image':image,'landmarks':landmarks}if self.transform:
sample=self.transform(sample)return sample
6.数据可视化 实例化中的这个数据类并遍历数据样本。这里将会打印出前四个例子的尺寸并展示标注的特征点。 代码如下所示:
import os,torch
import pandas as pd
from skimage import io, transform#用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms,utilsimport warnings
warnings.filterwarnings("ignore")plt.ion()# interactive mode# 读取数据集将csv中的标注点数据读入(N,2)数组中,其中N是特征点的数量landmarks_frame=pd.read_csv("data/faces/face_landmarks.csv")n=65
img_name=landmarks_frame.iloc[n,0]
landmarks=landmarks_frame.iloc[n,1:].as_matrix()
landmarks=landmarks.astype('float').reshape(-1,2)# print('Image name: {}'.format(img_name))
# print('Landmarks shape: {}'.format(landmarks.shape))
# print('First 4 Landmarks: {}'.format(landmarks[:4]))# 展示一张图片和它对应的标注点def show_landmarks(image,landmarks):
plt.imshow(image)
plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker='.',c='r')
plt.pause(0.001)# plt.figure()
# show_landmarks(io.imread(os.path.join("data/faces/",img_name)),landmarks)# plt.show()#数据集类
# torch.utils.data.Dataset是表示数据集的抽象类,因此自定义数据集应继承Dataset并覆盖以下方法 * __len__ 实现 len(dataset) 返还数据集的尺寸。 * __getitem__用来获取一些索引数据,例如 dataset[i] 中的(i)# 为面部数据集创建一个数据集类。我们将在 __init__中读取csv的文件内容,在 __getitem__中读取图片。这么做是为了节省内存 空间。
# 只有在需要用到图片的时候才读取它而不是一开始就把图片全部存进内存里。# 数据样本将按这样一个字典{'image': image, 'landmarks': landmarks}组织。 我们的数据集类将添加一个可选参数transform 以方便对样本进行预处理class FaceLandmarksDataset(Dataset):
"""人脸标记数据集"""def __init__(self,csv_file,root_dir,transform=None):
"""
csv_file(string):带注释的csv文件的路径。
root_dir(string):包含所有图像的目录。
transform(callable, optional):一个样本上的可用的可选变换
"""
self.landmarks_frame=pd.read_csv(csv_file)
self.root_dir=root_dir
self.transform=transformdef __len__(self):
return len(self.landmarks_frame)def __getitem__(self,idx):
img_name=os.path.join(self.root_dir,self.landmarks_frame.iloc[idx,0])
image=io.imread(img_name)
landmarks=self.landmarks_frame.iloc[idx,1:]
landmarks=np.array([landmarks])
landmarks=landmarks.astype('float').reshape(-1,2)
sample={'image':image,'landmarks':landmarks}if self.transform:
sample=self.transform(sample)return sample# 数据可视化:实例化这个类并遍历数据样本。我们将会打印出前四个例子的尺寸并展示标注的特征点face_dataset=FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/')fig=plt.figure()for i in range(len(face_dataset)):
sample=face_dataset[i]print(i,sample['image'].shape,sample['landmarks'].shape)ax=plt.subplot(1,4,i+1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
show_landmarks(**sample)if i==3:
plt.show()
break
输出:
文章图片
7.数据变换 通过上面的例子我们会发现图片并不是同样的尺寸。绝大多数神经网络都假定图片的尺寸相同。因此我们需要做一些预处理。让我们创建三个转换: Rescale:缩放图片 RandomCrop:对图片进行随机裁剪。这是一种数据增强操作 ToTensor:把numpy格式图片转为torch格式图片 (我们需要交换坐标轴).
我们会把它们写成可调用的类的形式而不是简单的函数,这样就不需要每次调用时传递一遍参数。我们只需要实现__call__方法,必 要的时候实现 __init__方法。我们可以这样调用这些转换:
import os,torch
import pandas as pd
from skimage import io, transform#用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms,utilsimport warnings
warnings.filterwarnings("ignore")plt.ion()# interactive mode# 读取数据集将csv中的标注点数据读入(N,2)数组中,其中N是特征点的数量landmarks_frame=pd.read_csv("data/faces/face_landmarks.csv")n=65
img_name=landmarks_frame.iloc[n,0]
landmarks=landmarks_frame.iloc[n,1:].as_matrix()
landmarks=landmarks.astype('float').reshape(-1,2)# print('Image name: {}'.format(img_name))
# print('Landmarks shape: {}'.format(landmarks.shape))
# print('First 4 Landmarks: {}'.format(landmarks[:4]))# 展示一张图片和它对应的标注点def show_landmarks(image,landmarks):
plt.imshow(image)
plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker='.',c='r')
plt.pause(0.001)# plt.figure()
# show_landmarks(io.imread(os.path.join("data/faces/",img_name)),landmarks)# plt.show()#数据集类
# torch.utils.data.Dataset是表示数据集的抽象类,因此自定义数据集应继承Dataset并覆盖以下方法 * __len__ 实现 len(dataset) 返还数据集的尺寸。 * __getitem__用来获取一些索引数据,例如 dataset[i] 中的(i)# 为面部数据集创建一个数据集类。我们将在 __init__中读取csv的文件内容,在 __getitem__中读取图片。这么做是为了节省内存 空间。
# 只有在需要用到图片的时候才读取它而不是一开始就把图片全部存进内存里。# 数据样本将按这样一个字典{'image': image, 'landmarks': landmarks}组织。 我们的数据集类将添加一个可选参数transform 以方便对样本进行预处理class FaceLandmarksDataset(Dataset):
"""人脸标记数据集"""def __init__(self,csv_file,root_dir,transform=None):
"""
csv_file(string):带注释的csv文件的路径。
root_dir(string):包含所有图像的目录。
transform(callable, optional):一个样本上的可用的可选变换
"""
self.landmarks_frame=pd.read_csv(csv_file)
self.root_dir=root_dir
self.transform=transformdef __len__(self):
return len(self.landmarks_frame)def __getitem__(self,idx):
img_name=os.path.join(self.root_dir,self.landmarks_frame.iloc[idx,0])
image=io.imread(img_name)
landmarks=self.landmarks_frame.iloc[idx,1:]
landmarks=np.array([landmarks])
landmarks=landmarks.astype('float').reshape(-1,2)
sample={'image':image,'landmarks':landmarks}if self.transform:
sample=self.transform(sample)return sample# 数据可视化:实例化这个类并遍历数据样本。我们将会打印出前四个例子的尺寸并展示标注的特征点face_dataset=FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/')class Rescale(object):"""将样本中的图像重新缩放到给定大小。Args:
output_size(tuple或int):所需的输出大小。
如果是元组,则输出为与output_size匹配。
如果是int,则匹配较小的图像边缘到output_size保持纵横比相同。
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))#如果要判断两个类型是否相同推荐使用 isinstance()。
self.output_size = output_sizedef __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_sizenew_h, new_w = int(new_h), int(new_w)img = transform.resize(image, (new_h, new_w))# h and w are swapped for landmarks because for images,
# x and y axes are axis 1 and 0 respectively
landmarks = landmarks * [new_w / w, new_h / h]return {'image': img, 'landmarks': landmarks}class RandomCrop(object):
"""随机裁剪样本中的图像.
Args:
output_size(tuple或int):所需的输出大小。 如果是int,方形裁剪是。
"""def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_sizedef __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]
new_h, new_w = self.output_sizetop = np.random.randint(0, h - new_h)
left = np.random.randint(0, w - new_w)image = image[top: top + new_h,
left: left + new_w]landmarks = landmarks - [left, top]return {'image': image, 'landmarks': landmarks}class ToTensor(object):
"""将样本中的ndarrays转换为Tensors."""def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']# 交换颜色轴因为
# numpy包的图片是: H * W * C
# torch包的图片是: C * H * W
image = image.transpose((2, 0, 1))
return {'image': torch.from_numpy(image),'landmarks': torch.from_numpy(landmarks)}#想要把图像的短边调整为256,然后随机裁剪(randomcrop)为224大小的正方形。也就是说,我们打算组合一个Rescale和 RandomCrop的变换。 我们可以调用一个简单的类 torchvision.transforms.Compose来实现这一操作
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256), RandomCrop(224)])# 在样本上应用上述的每个变换。
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
transformed_sample = tsfrm(sample) ax = plt.subplot(1, 3, i + 1)
plt.tight_layout()
ax.set_title(type(tsfrm).__name__)
show_landmarks(**transformed_sample)plt.show()
输出:
文章图片
8.迭代数据集 每次这个数据集被采样时: 及时地从文件中读取图片, 对读取的图片应用转换 ,由于其中一步操作是随机的 (randomcrop) , 数据被增强了
可以像之前那样使用for i in range循环来对所有创建的数据集执行同样的操作。
# 迭代数据集
# 让我们把这些整合起来以创建一个带组合转换的数据集。
# 每次这个数据集被采样时: * 及时地从文件中读取图片 * 对读取的图片应用转换 * 由于其中一步操作是随机的 (randomcrop) , 数据被增强# 使用for i in range循环来对所有创建的数据集执行同样的操作transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
root_dir='data/faces/',
transform=transforms.Compose([
Rescale(256),
RandomCrop(224),
ToTensor()
]))for i in range(len(transformed_dataset)):
sample = transformed_dataset[i]print(i, sample['image'].size(), sample['landmarks'].size())if i == 3:
break
输出:
0 torch.Size([3, 224, 224]) torch.Size([68, 2])但是,对所有数据集简单的使用for循环牺牲了许多功能,尤其是: 批量处理数据 ,打乱数据 , 使用多线程multiprocessingworker 并行加载数据。torch.utils.data.DataLoader是一个提供上述所有这些功能的迭代器。下面使用的参数必须是清楚的。一个值得关注的参数是collate_fn, 可以通过它来决定如何对数据进行批处理。但是绝大多数情况下默认值就能运行良好。
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])
dataloader = DataLoader(transformed_dataset, batch_size=4,shuffle=True, num_workers=4)# 辅助功能:显示批次
def show_landmarks_batch(sample_batched):
"""Show image with landmarks for a batch of samples."""
images_batch, landmarks_batch = sample_batched['image'], sample_batched['landmarks']
batch_size = len(images_batch)
im_size = images_batch.size(2)
grid_border_size = 2grid = utils.make_grid(images_batch)
plt.imshow(grid.numpy().transpose((1, 2, 0)))for i in range(batch_size):
plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
landmarks_batch[i, :, 1].numpy() + grid_border_size,
s=10, marker='.', c='r')plt.title('Batch from dataloader')for i_batch, sample_batched in enumerate(dataloader):
print(i_batch, sample_batched['image'].size(),
sample_batched['landmarks'].size())# 观察第4批次并停止。
if i_batch == 3:
plt.figure()
show_landmarks_batch(sample_batched)
plt.axis('off')
plt.ioff()
plt.show()
break
输出:
文章图片
9.torchvision包 torchvision包提供了 常用的数据集类(datasets)和转换(transforms)。你可能不需要自己构造这些类。torchvision中还有一个更常用的数据集类ImageFolder。 它假定了数据集是以如下方式构造的:
root/ants/xxx.png【PyTorch学习笔记(六) ---- 数据加载和处理】其中’ants’,bees’等是分类标签。在PIL.Image中你也可以使用类似的转换(transforms)例如RandomHorizontalFlip,Scale。利 用这些你可以按如下的方式创建一个数据加载器(dataloader) :
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
import torch
from torchvision import transforms, datasetsdata_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)
推荐阅读
- EffectiveObjective-C2.0|EffectiveObjective-C2.0 笔记 - 第二部分
- 由浅入深理解AOP
- 继续努力,自主学习家庭Day135(20181015)
- python学习之|python学习之 实现QQ自动发送消息
- Android中的AES加密-下
- 一起来学习C语言的字符串转换函数
- 定制一套英文学习方案
- 漫画初学者如何学习漫画背景的透视画法(这篇教程请收藏好了!)
- 《深度倾听》第5天──「RIA学习力」便签输出第16期
- 如何更好的去学习