TensorFlow|TensorFlow 自定义生成 .record 文件
一、生成 .record 文件
前面的文章 TensorFlow 训练自己的目标检测器 中的第二部分第 2 小节中我们已经预先说过,会在后续的文章中阐述怎么自定义的将图像转化为 .record
文件,今天我们就来说一说这件事。
在文章 TensorFlow 训练 CNN 分类器 中我们生成了 50000 张 28 x 28 像素的图像,我们的目标就是将这些图像全部写入到一个后缀为 .record
的文件中。 .record
或 tfrecord
文件是 TensorFlow 中的标准数据读写格式,它是一种能够高效读写的二进制文件,能够快速的复制、移动、读写和存储等。
在文章 TensorFlow 训练 CNN 分类器 和文章 TensorFlow-slim 训练 CNN 分类模型 中,我们在训练模型时导入数据的方式都是一次性的将所有图像读入,然后循环的从中选择一个批量来训练。这对于小数据集来说不会产生问题,但如果训练数据异常大,那么很可能由于内存限制无法一次性将说有数据导入,这样前面的训练方式便不能采用了。此时,我们可以将数据转化为 .record
文件格式,然后再分批次的、逐步的读入 .record
文件进行训练。
要将图像写入 .record
文件,首先要将图像编码为字符或数字特征,这需要调用类 tf.train.Feature
。然后,在调用 tf.train.Example
将特征写入协议缓冲区。最后,通过类 tf.python_io.TFRecordWriter
将数据写入到 .record
文件中。比如,我们将前面提到的 50000 张图像写入 train.record
文件,使用如下代码(命名为 generate_tfrecord.py):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 26 09:02:10 2018@author: shirhe-lyh
""""""Generate tfrecord file from images.Example Usage:
---------------
python3 train.py \
--images_path: Path to the training images (directory).
--output_path: Path to .record.
"""import glob
import io
import os
import tensorflow as tffrom PIL import Imageflags = tf.app.flagsflags.DEFINE_string('images_path', None, 'Path to images (directory).')
flags.DEFINE_string('output_path', None, 'Path to output tfrecord file.')
FLAGS = flags.FLAGSdef int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=https://www.it610.com/article/[value]))def int64_list_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))def bytes_list_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))def float_list_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))def create_tf_example(image_path):
with tf.gfile.GFile(image_path,'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = Image.open(encoded_jpg_io)
width, height = image.size
label = int(image_path.split('_')[-1].split('.')[0])tf_example = tf.train.Example(
features=tf.train.Features(feature={
'image/encoded': bytes_feature(encoded_jpg),
'image/format': bytes_feature('jpg'.encode()),
'image/class/label': int64_feature(label),
'image/height': int64_feature(height),
'image/width': int64_feature(width)}))
return tf_exampledef generate_tfrecord(images_path, output_path):
writer = tf.python_io.TFRecordWriter(output_path)
for image_file in glob.glob(images_path):
tf_example = create_tf_example(image_file)
writer.write(tf_example.SerializeToString())
writer.close()def main(_):
images_path = os.path.join(FLAGS.images_path, '*.jpg')
images_record_path = FLAGS.output_path
generate_tfrecord(images_path, images_record_path)if __name__ == '__main__':
tf.app.run()
在该文件目录的终端执行:
python3 generate_tfrecord.py \
--images_path /home/.../datasets/images \
--output_path /home/.../datasets/train.record
便会在输出路径下生成
train.record
文件。以上代码中,最重要的部分是:1. 函数 create_tf_example
,该函数首先得到图像的二进制格式、图像的宽和高、以及图像对应的类标号等,然后将图像的这些信息写入协议缓冲区;2. 函数 generate_tfrecord
,该函数使用 tf.python_io.TFRecordWriter
类将协议缓冲区内的数据写入到 .record
文件中。二、读取 .record 文件 一旦将图像转化为了
.record
文件,接下来我们关心的就是怎么读取这个 .record
文件用于模型训练了。这可以借助我们前面使用过的模块 tf.contrib.slim
:slim = tf.contrib.slimdef get_record_dataset(record_path,
reader=None, image_shape=[28, 28, 3],
num_samples=50000, num_classes=10):
"""Get a tensorflow record file."""
if not reader:
reader = tf.TFRecordReaderkeys_to_features = {
'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/format':
tf.FixedLenFeature((), tf.string, default_value='https://www.it610.com/article/jpeg'),
'image/class/label':
tf.FixedLenFeature([1], tf.int64, default_value=https://www.it610.com/article/tf.zeros([1],
dtype=tf.int64))}items_to_handlers = {'image': slim.tfexample_decoder.Image(shape=image_shape,
#image_key='image/encoded',
#format_key='image/format',
channels=3),
'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}decoder = slim.tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers)labels_to_names = None
items_to_descriptions = {
'image': 'An image with shape image_shape.',
'label': 'A single integer between 0 and 9.'}
return slim.dataset.Dataset(
data_sources=record_path,
reader=reader,
decoder=decoder,
num_samples=num_samples,
num_classes=num_classes,
items_to_descriptions=items_to_descriptions,
labels_to_names=labels_to_names)
主要是借助了
tf.contib.slim
模块中的slim.dataset.Dataset(data_sources, reader, decoder,
num_samples, items_to_descriptions,
**kwargs)
和
slim.tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
【TensorFlow|TensorFlow 自定义生成 .record 文件】这两个类。
使用时,直接传入
train.record
路径即可:dataset = get_record_dataset('./xxx/train.record')
data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
image, label = data_provider.get(['image', 'label'])
函数
get_record_dataset
返回 slim.dataset.Dataset
类的一个对象,之后通过类slim.dataset_data_provider.DatasetDataProvider(dataset, num_readers=1,
reader_kwargs=None,
shuffle=True, num_epochs=None,
common_queue_capacity=256,
common_queue_min=128,
record_key='record_key',
seed=None, scope=None)
的
get
方法得到图像和类标号的序列数据。参数 num_readers=1
表示一次读取一个数据,即一次读取一张图像,因此实际使用是还需要使用函数 tf.train.batch
将数据形成批量再用于训练,见下一篇文章。预告:下一篇文章将说明怎么完全使用
tf.contrib.slim
来构建和训练模型。推荐阅读
- SpringBoot调用公共模块的自定义注解失效的解决
- python自定义封装带颜色的logging模块
- 列出所有自定义的function和view
- 记录iOS生成分享图片的一些问题,根据UIView生成固定尺寸的分享图片
- ssh生成公钥秘钥
- Java内存泄漏分析系列之二(jstack生成的Thread|Java内存泄漏分析系列之二:jstack生成的Thread Dump日志结构解析)
- 15、IDEA学习系列之其他设置(生成javadoc、缓存和索引的清理等)
- Spring|Spring Boot 自动配置的原理、核心注解以及利用自动配置实现了自定义 Starter 组件
- 自定义MyAdapter
- javaweb|基于Servlet+jsp+mysql开发javaWeb学生成绩管理系统