Tensorflow之TFRecord制作——VOC数据为例

TFRecord 是 TensorFlow 专用的二进制数据存储格式,便于在训练时高效地读取和迁移数据。
现在就基于VOC数据集介绍一下。
1、生成TFRecord
首先就是封装数据集,其具体方法如下:
[图片:生成 TFRecord(封装数据集)的具体方法示意图]

具体实现代码为:

def int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` (Int64List)."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def int64_list_feature(value):
    """Wrap a list of integers in a ``tf.train.Feature`` (Int64List)."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value):
    """Wrap a single ``bytes`` value in a ``tf.train.Feature`` (BytesList)."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def bytes_list_feature(value):
    """Wrap a list of ``bytes`` values in a ``tf.train.Feature`` (BytesList)."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def float_list_feature(value):
    """Wrap a list of floats in a ``tf.train.Feature`` (FloatList)."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def parse_voc_objects(data, label_map_dict, ignore_difficult_instances=False):
    """Collect the per-object lists stored in a VOC ``tf.train.Example``.

    NOTE(review): assumes ``data`` is the ``annotation`` element of a VOC XML
    file converted to a dict: ``size/width`` and ``size/height`` plus an
    ``object`` list whose items carry ``name``, ``difficult``, ``truncated``,
    ``pose`` and ``bndbox`` (``xmin``/``ymin``/``xmax``/``ymax``) -- confirm
    against the XML parser actually used.

    Box corners are divided by the image width/height, i.e. stored relative
    to the image size, so they stay valid after the image is resized.

    Args:
        data: parsed VOC annotation dict (schema assumed above).
        label_map_dict: maps class names to integer labels.
        ignore_difficult_instances: if true, objects flagged ``difficult``
            are dropped.

    Returns:
        ``(height, width, objects)``: image height and width as ints, and a
        dict mapping ``xmin, ymin, xmax, ymax, classes_text, classes,
        difficult, truncated, poses`` to equally long lists (all empty when
        the image has no kept objects).
    """
    width = int(data['size']['width'])
    height = int(data['size']['height'])
    objects = {key: [] for key in ('xmin', 'ymin', 'xmax', 'ymax', 'classes_text',
                                   'classes', 'difficult', 'truncated', 'poses')}
    # An annotation without any <object> is valid; it yields empty lists.
    for obj in data.get('object', []):
        difficult = bool(int(obj.get('difficult', 0)))
        if ignore_difficult_instances and difficult:
            continue
        box = obj['bndbox']
        objects['xmin'].append(float(box['xmin']) / width)
        objects['ymin'].append(float(box['ymin']) / height)
        objects['xmax'].append(float(box['xmax']) / width)
        objects['ymax'].append(float(box['ymax']) / height)
        objects['classes_text'].append(obj['name'].encode('utf8'))
        objects['classes'].append(label_map_dict[obj['name']])
        objects['difficult'].append(int(difficult))
        objects['truncated'].append(int(obj.get('truncated', 0)))
        objects['poses'].append(obj.get('pose', 'Unspecified').encode('utf8'))
    return height, width, objects


def dict_to_tf_example(data, dataset_directory, label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='JPEGImages'):
    """Build one ``tf.train.Example`` for a VOC image and its annotation.

    Args:
        data: parsed VOC annotation dict (see ``parse_voc_objects``); also
            read for ``folder`` and ``filename``.
        dataset_directory: root directory containing ``data['folder']``.
        label_map_dict: maps class names to integer labels.
        ignore_difficult_instances: drop objects flagged ``difficult``.
        image_subdirectory: directory holding the JPEG files inside the
            VOC folder.

    Returns:
        A ``tf.train.Example`` whose keys match ``IMAGE_FEATURE_MAP``.
    """
    import hashlib
    import os

    # NOTE(review): assumes the standard VOC layout
    # <dataset_directory>/<folder>/<image_subdirectory>/<filename> -- confirm.
    full_path = os.path.join(dataset_directory, data['folder'],
                             image_subdirectory, data['filename'])
    with tf.io.gfile.GFile(full_path, 'rb') as fid:
        encoded_jpg = fid.read()
    key = hashlib.sha256(encoded_jpg).hexdigest()

    height, width, objs = parse_voc_objects(data, label_map_dict,
                                            ignore_difficult_instances)
    filename = data['filename'].encode('utf8')

    # The helpers above are defined in this file; the excerpt called them via
    # an unimported ``dataset_util`` module.
    return tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/filename': bytes_feature(filename),
        'image/source_id': bytes_feature(filename),
        'image/key/sha256': bytes_feature(key.encode('utf8')),
        'image/encoded': bytes_feature(encoded_jpg),
        'image/format': bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': float_list_feature(objs['xmin']),
        'image/object/bbox/xmax': float_list_feature(objs['xmax']),
        'image/object/bbox/ymin': float_list_feature(objs['ymin']),
        'image/object/bbox/ymax': float_list_feature(objs['ymax']),
        'image/object/class/text': bytes_list_feature(objs['classes_text']),
        'image/object/class/label': int64_list_feature(objs['classes']),
        'image/object/difficult': int64_list_feature(objs['difficult']),
        'image/object/truncated': int64_list_feature(objs['truncated']),
        'image/object/view': bytes_list_feature(objs['poses']),
    }))


def write_tfrecord(output_path, annotations, dataset_directory, label_map_dict,
                   ignore_difficult_instances=False):
    """Serialize every parsed VOC annotation in ``annotations`` into one TFRecord.

    Args:
        output_path: path of the ``.tfrecord`` file to create.
        annotations: iterable of parsed VOC annotation dicts, one per image.
        dataset_directory, label_map_dict, ignore_difficult_instances:
            forwarded to ``dict_to_tf_example``.
    """
    # The writer is a context manager, so the file is closed even if a
    # conversion raises; each annotation yielded by the loop is the one
    # converted (the excerpt converted an unassigned ``data`` instead).
    with tf.io.TFRecordWriter(output_path) as writer:
        for data in annotations:
            tf_example = dict_to_tf_example(data, dataset_directory, label_map_dict,
                                            ignore_difficult_instances)
            writer.write(tf_example.SerializeToString())

2、解析TFRecord
解析时需要先定义好特征描述字典 IMAGE_FEATURE_MAP 和解析函数 parse_example,再对解析出来的数据进行组合处理,最终输出结果。
需要注意以下两种特征定义与解析结果之间的对应关系:
tf.io.FixedLenFeature([], tf.int64) ==> tf.Tensor(375, shape=(), dtype=int64)
tf.io.VarLenFeature(tf.int64) ==> SparseTensor(indices=tf.Tensor([[0]], shape=(1, 1), dtype=int64), values=tf.Tensor([12], shape=(1,), dtype=int64), dense_shape=tf.Tensor([1], shape=(1,), dtype=int64))
# Feature spec used to parse the serialized examples; the keys and dtypes
# mirror what the TFRecord writer stored.
IMAGE_FEATURE_MAP = {
    'image/height': tf.io.FixedLenFeature([], tf.int64),
    'image/width': tf.io.FixedLenFeature([], tf.int64),
    'image/filename': tf.io.FixedLenFeature([], tf.string),
    'image/source_id': tf.io.FixedLenFeature([], tf.string),
    'image/key/sha256': tf.io.FixedLenFeature([], tf.string),
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/format': tf.io.FixedLenFeature([], tf.string),
    # Per-object lists have a different length in every image, so they are
    # parsed with VarLenFeature and come back as SparseTensors.
    'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
    'image/object/class/text': tf.io.VarLenFeature(tf.string),
    'image/object/class/label': tf.io.VarLenFeature(tf.int64),
    'image/object/difficult': tf.io.VarLenFeature(tf.int64),
    'image/object/truncated': tf.io.VarLenFeature(tf.int64),
    'image/object/view': tf.io.VarLenFeature(tf.string),
}


def parse_example(serialized_example, height=512, width=512, max_boxes=100):
    """Decode one serialized ``tf.train.Example`` into ``(image, boxes)``.

    Args:
        serialized_example: scalar string tensor holding an Example written
            with the keys of ``IMAGE_FEATURE_MAP``.
        height, width: size the decoded JPEG is resized to.
        max_boxes: number of rows of the returned box tensor. Images with
            fewer objects are zero-padded; extra objects are dropped.
            Defaults to 100, the previously hard-coded value.

    Returns:
        x_train: float32 image tensor of shape ``[height, width, 3]``.
        y_train: float32 tensor of shape ``[max_boxes, 5]``; each row is
            ``(xmin, ymin, xmax, ymax, label)`` and all-zero rows are padding.
    """
    x = tf.io.parse_single_example(serialized_example, IMAGE_FEATURE_MAP)

    x_train = tf.image.decode_jpeg(x['image/encoded'], channels=3)
    # tf.image.resize returns float32 whatever the input dtype is.
    x_train = tf.image.resize(x_train, (height, width))

    # VarLenFeature values are SparseTensors; densify each to shape [m],
    # where m is the number of objects in this image (varies per image).
    labels = tf.cast(tf.sparse.to_dense(x['image/object/class/label']), tf.float32)
    y_train = tf.stack([
        tf.sparse.to_dense(x['image/object/bbox/xmin']),
        tf.sparse.to_dense(x['image/object/bbox/ymin']),
        tf.sparse.to_dense(x['image/object/bbox/xmax']),
        tf.sparse.to_dense(x['image/object/bbox/ymax']),
        labels,
    ], axis=1)  # shape [m, 5]

    # tf.pad rejects negative paddings, so keep at most max_boxes rows before
    # padding the bottom with zero rows up to exactly max_boxes.
    y_train = y_train[:max_boxes]
    paddings = [[0, max_boxes - tf.shape(y_train)[0]], [0, 0]]
    y_train = tf.pad(y_train, paddings)
    return x_train, y_train


def _parse_function(example_proto):
    """Parse one serialized Example into a dict of raw dense/sparse tensors."""
    return tf.io.parse_single_example(example_proto, IMAGE_FEATURE_MAP)


if __name__ == '__main__':
    dataset = tf.data.TFRecordDataset(filenames=['/data/data/VOC2007/train.tfrecord'])
    print(dataset)

    # dataset.map(_parse_function) would yield the raw feature dicts instead.
    parsed_dataset = dataset.map(parse_example)  # parse every serialized Example
    for parsed_record in parsed_dataset.take(10):
        print(repr(parsed_record))
        print('=========')

本文主要参考了yinghuang/yolov2-tensorflow2

    推荐阅读