深度学习|以cifar10为例,讲解TensorFlow数据的输入

前言 【深度学习|以cifar10为例,讲解TensorFlow数据的输入】数据输入一般包括一下8个部分,下面以cifar10为例讲解其中奥妙
1文件名列表
2可配置的文件名乱序
3可配置的最大训练迭代数
4文件名队列
5针对输入文件格式的阅读器
6记录解析器
7可配置的预处理器
8样本队列
1获取文件名

filenames=[os.path.join(data_dir,'data_batch_%d.bin'%i) for i in range(1,6)]

2文件名乱序
filename_queue=tf.train.string_input_producer(filenames)#默认shuffle是True

3针对输入文件的格式,创建阅读器来读取
(来自cifar10中的代码)read_cifar10(file_queue)
针对cifar中固定的二进制格式,使用相应方法进行读取
输入:文件名列表
输出:result类,其中height、width、depth、
key(一个Tensor)、label(int型tensor)、unit8image(一个[h,w,d]的unit型Tensor)
主要函数为tf.FixedLengthRecordReader()读取固定长度的比特数。
tf.decode_raw用来解码成unit8型。
def read_cifar10(filename_queue): key:a scalar string Tensor describing the filename & record number for this example label: an int32 Tensor with the label in the range 0~9 uint8 image: a [height, width, depth] uint8 Tensor with the image dataclass CIFAR10Record: pass result=CIFAR10Record() #CIFAR10数据库中图片的维度 label_bytes=1 #2 for CIFAR-100 result.height=32 result.width=32 result.depth=3 image_bytes=result.height*result.width*result.depth #每个记录都由一个字节的标签和3072字节的图像数据组成,长度固定 record_bytes=label_bytes+image_bytes#read a record, getting filenames from the filename_queue reader=tf.FixedLengthRecordReader(record_bytes=record_bytes) result.key,value=https://www.it610.com/article/reader.read(filename_queue)#注意这里read每次只读取一行!该函数返回两个string型的Tensor #Convert from a string to a vector of uint8 that is record_bytes long record_bytes=tf.decode_raw(value,tf.uint8)#decode_raw可以将一个字符串转换为一个uint8的张量 #tf.stride_slice(data, begin, end)中,end是开区间,begin是闭区间,把data切片 #The first bytes represent the label, which we convert from uint8->int32 #cast(x, dtype, name=None) cast函数是改变数据格式的 result.label=tf.cast(tf.strided_slice(record_bytes,[0],[label_bytes]),tf.int32)#将剩下的图像数据部分reshape为【depth,height,width】的形式 depth_major=tf.reshape(tf.strided_slice(record_bytes,[label_bytes],[label_bytes+image_bytes]),[result.depth,result.height,result.width])#from【depth,height,width】to【height,width,depth】 result.uint8image=tf.transpose(depth_major,[1,2,0]) return result #返回的是一个类的对象!

4预处理,数据增强
distorted_image
5得到样本队列
def _generate_image_and_label_batch(image, label, min_queue_examples, batch_size, shuffle): """Construct a queued batch of images and labels.Args: image: 3-D Tensor of [height, width, 3] of type.float32. label: 1-D Tensor of type.int32 min_queue_examples: int32, minimum number of samples to retain in the queue that provides of batches of examples. batch_size: Number of images per batch. shuffle: boolean indicating whether to use a shuffling queue.Returns: images: Images. 4D tensor of [batch_size, height, width, 3] size. labels: Labels. 1D tensor of [batch_size] size. """ # Create a queue that shuffles the examples, and then # read 'batch_size' images + labels from the example queue. num_preprocess_threads = 16 #16个线程同时处理;这种方案可以保证同一时刻只在一个文件中进行读取操作,而不是同时读取多个文件 #避免了两个不同的线程从同一个文件中读取同一个样本;并且避免了过多的磁盘搜索操作 if shuffle: images, label_batch = tf.train.shuffle_batch( [image, label], batch_size=batch_size, num_threads=num_preprocess_threads, capacity=min_queue_examples + 3 * batch_size, #capacity必须比min_after_dequeue大 min_after_dequeue=min_queue_examples) #min_after_dequeue 定义了我们会从多大的buffer中随机采样;大的值意味着更好的乱序但更慢的开始,和更多内存占用 else: images, label_batch = tf.train.batch( [image, label], batch_size=batch_size, num_threads=num_preprocess_threads, capacity=min_queue_examples + 3 * batch_size)# Display the training images in the visualizer. tf.summary.image('images', images)return images, tf.reshape(label_batch, [batch_size])

其中用到了
tf.FIFOQueue(100,”float”) 先进先出的队列
tf.assign_add(A,B) 将A 赋值为B
q.enqueue(counter) 将counter放入队列
q.dequeue() 出队列
tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*1)
创建一个队列管理器QueueRunner
tf.train.start_queue_runners 在你运行任何训练步骤之前,需要调用此函数,否则数据流图将一直挂起。将会启动输入管道的线程,填充样本到队列中,以便出队操作可以从队列中拿到样本。这种情况下最好配合使用一个tf.train.Coordinator
TensorFlow提供了两个类来帮助多线程的实现:tf.Coordinator和 tf.QueueRunner。从设计上这两个类必须被一起使用。Coordinator类可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常。QueueRunner类用来协调多个工作线程同时将多个张量推入同一个队列中。
部分内容借鉴中文博客
http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html
原文是一个github的英文项目
https://github.com/jikexueyuanwiki/tensorflow-zh/blob/master/SOURCE/how_tos/reading_data/index.md

    推荐阅读