TFRecord 是 TensorFlow 专用的数据存储格式(二进制记录文件),便于在训练时快速读取和迁移数据。
下面以 VOC 数据集为例进行介绍。
1、生成TFRecord
首先是将数据集封装为 TFRecord 文件,具体方法如下:
【Tensorflow之TFRecord制作——VOC数据为例】
(原文此处为示意图)
具体实现代码为:
def int64_feature(value):
    """Wrap a single integer in an ``int64_list`` ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def int64_list_feature(value):
    """Wrap a sequence of integers in an ``int64_list`` ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value):
    """Wrap a single ``bytes`` value in a ``bytes_list`` ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def bytes_list_feature(value):
    """Wrap a sequence of ``bytes`` values in a ``bytes_list`` ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def float_list_feature(value):
    """Wrap a sequence of floats in a ``float_list`` ``tf.train.Feature``."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def build_tf_example(full_path, data, *, height, width, key, xmin, xmax,
                     ymin, ymax, classes_text, classes, difficult_obj,
                     truncated, poses):
    """Read one JPEG and pack it, with its annotations, into a tf.train.Example.

    This is the example-building step of ``dict_to_tf_example`` (presumably;
    that function's full body is not shown). The feature keys written here
    are the same keys that ``IMAGE_FEATURE_MAP`` parses back.

    Args:
        full_path: path of the encoded JPEG image; its raw bytes are stored
            unchanged under ``image/encoded``.
        data: annotation dict; only ``data['filename']`` (a str) is read.
        height, width: image size as ints.
        key: str stored under ``image/key/sha256`` (presumably the SHA-256
            digest of the image bytes computed by the caller - TODO confirm).
        xmin, xmax, ymin, ymax: per-object box coordinates (floats).
        classes_text: per-object class names as ``bytes``.
        classes: per-object integer class labels.
        difficult_obj, truncated: per-object integer flags.
        poses: per-object pose strings as ``bytes``.

    Returns:
        The assembled ``tf.train.Example``.

    Raises:
        ValueError: if the per-object sequences differ in length. When the
            records are parsed, the box and label columns are stacked side by
            side, which only works if every object contributes exactly one
            entry to each list.
    """
    # Validate before touching the file system so bad annotations fail fast.
    per_object = {
        'xmin': xmin, 'xmax': xmax, 'ymin': ymin, 'ymax': ymax,
        'classes_text': classes_text, 'classes': classes,
        'difficult_obj': difficult_obj, 'truncated': truncated,
        'poses': poses,
    }
    lengths = {name: len(values) for name, values in per_object.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(
            f'per-object feature lists differ in length: {lengths}')

    with tf.io.gfile.GFile(full_path, 'rb') as fid:
        encoded_jpg = fid.read()

    # The filename doubles as the source id, as in the original snippet.
    filename = data['filename'].encode('utf8')
    feature = {
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/filename': bytes_feature(filename),
        'image/source_id': bytes_feature(filename),
        'image/key/sha256': bytes_feature(key.encode('utf8')),
        'image/encoded': bytes_feature(encoded_jpg),
        'image/format': bytes_feature(b'jpeg'),
        'image/object/bbox/xmin': float_list_feature(xmin),
        'image/object/bbox/xmax': float_list_feature(xmax),
        'image/object/bbox/ymin': float_list_feature(ymin),
        'image/object/bbox/ymax': float_list_feature(ymax),
        'image/object/class/text': bytes_list_feature(classes_text),
        'image/object/class/label': int64_list_feature(classes),
        'image/object/difficult': int64_list_feature(difficult_obj),
        'image/object/truncated': int64_list_feature(truncated),
        'image/object/view': bytes_list_feature(poses),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
# Write the examples out to a TFRecord file, one serialized tf.train.Example
# per record. The `with` block closes the writer even if serialization raises.
with tf.io.TFRecordWriter(output_path) as writer:
    for example in examples_list:
        # NOTE(review): `data` is never derived from `example` here, so every
        # iteration serializes the same annotation. Presumably `data` should be
        # loaded from the annotation that `example` names (e.g. its VOC XML)
        # before this call - TODO restore that step.
        tf_example = dict_to_tf_example(data, data_dir, VOC_NAME_LABEL,
                                        ignore_difficult_instances)
        writer.write(tf_example.SerializeToString())
2、解析TFRecord
即定义好解析字典 IMAGE_FEATURE_MAP 和解析函数 parse_example:先按字典解析出各个字段,再对解析出的数据进行组合处理,得到最终输出。
需要注意以下两种特征类型与解析结果的对应关系:
tf.io.FixedLenFeature([], tf.int64) ==> tf.Tensor(375, shape=(), dtype=int64)
tf.io.VarLenFeature(tf.int64) ==> SparseTensor(indices=tf.Tensor([[0]], shape=(1, 1), dtype=int64), values=tf.Tensor([12], shape=(1,), dtype=int64), dense_shape=tf.Tensor([1], shape=(1,), dtype=int64))
# Feature spec for parsing each serialized tf.train.Example.
# Scalar fields use FixedLenFeature([]). Per-object fields, whose list length
# varies from image to image, use VarLenFeature and parse to SparseTensors.
IMAGE_FEATURE_MAP = {
    'image/height': tf.io.FixedLenFeature([], tf.int64),
    'image/width': tf.io.FixedLenFeature([], tf.int64),
    'image/filename': tf.io.FixedLenFeature([], tf.string),
    'image/source_id': tf.io.FixedLenFeature([], tf.string),
    'image/key/sha256': tf.io.FixedLenFeature([], tf.string),
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/format': tf.io.FixedLenFeature([], tf.string),
    'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
    'image/object/class/text': tf.io.VarLenFeature(tf.string),
    'image/object/class/label': tf.io.VarLenFeature(tf.int64),
    'image/object/difficult': tf.io.VarLenFeature(tf.int64),
    'image/object/truncated': tf.io.VarLenFeature(tf.int64),
    'image/object/view': tf.io.VarLenFeature(tf.string),
}


def parse_example(serialized_example, height=512, width=512, max_boxes=100):
    """Decode one serialized example into an image and a fixed-size box tensor.

    Args:
        serialized_example: one serialized ``tf.train.Example`` (for instance an
            element of a ``tf.data.TFRecordDataset``).
        height, width: size the decoded RGB image is resized to.
        max_boxes: number of rows in the returned box tensor. Images with more
            objects keep only the first ``max_boxes``; images with fewer are
            zero-padded.

    Returns:
        ``(x_train, y_train)``. ``x_train`` is the JPEG decoded with 3 channels
        and resized to ``(height, width)``. ``y_train`` has shape
        ``[max_boxes, 5]``; each row is ``[xmin, ymin, xmax, ymax, label]``.
    """
    x = tf.io.parse_single_example(serialized_example, IMAGE_FEATURE_MAP)

    x_train = tf.image.decode_jpeg(x['image/encoded'], channels=3)
    # NOTE(review): only the image is resized; box coordinates are left as
    # stored, so this assumes they were written normalized to [0, 1] - confirm.
    x_train = tf.image.resize(x_train, (height, width))

    # VarLenFeature values parse to SparseTensors; densify each to shape [m],
    # where m is the number of objects in this image (m differs per image).
    # Class names are also sparse; densify them with
    # tf.sparse.to_dense(x['image/object/class/text'], default_value='') if needed.
    labels = tf.cast(tf.sparse.to_dense(x['image/object/class/label']), tf.float32)
    y_train = tf.stack([
        tf.sparse.to_dense(x['image/object/bbox/xmin']),
        tf.sparse.to_dense(x['image/object/bbox/ymin']),
        tf.sparse.to_dense(x['image/object/bbox/xmax']),
        tf.sparse.to_dense(x['image/object/bbox/ymax']),
        labels,
    ], axis=1)  # shape: [m, 5]

    # Truncate before padding: with m > max_boxes the pad amount would be
    # negative, and tf.pad rejects negative paddings.
    y_train = y_train[:max_boxes]
    # Pad only at the bottom of dimension 0. Output size of dimension D is
    # paddings[D][0] + size(D) + paddings[D][1].
    paddings = [[0, max_boxes - tf.shape(y_train)[0]], [0, 0]]
    y_train = tf.pad(y_train, paddings)
    return x_train, y_train


def _parse_function(example_proto):
    """Parse one serialized ``tf.train.Example`` into a dict of raw tensors.

    Applies ``IMAGE_FEATURE_MAP`` only; no decoding, resizing or padding.
    """
    return tf.io.parse_single_example(example_proto, IMAGE_FEATURE_MAP)


if __name__ == '__main__':
    dataset = tf.data.TFRecordDataset(filenames=['/data/data/VOC2007/train.tfrecord'])
    print(dataset)

    # To inspect raw records instead, map with _parse_function, or decode a
    # single record with tf.train.Example.FromString(raw_record.numpy()).
    parsed_dataset = dataset.map(parse_example)  # parse every serialized example
    for parsed_record in parsed_dataset.take(10):
        print(repr(parsed_record))
        print('=========')
本文主要参考了yinghuang/yolov2-tensorflow2
推荐阅读
- Keras|将Pytorch模型迁移到android端(android studio)【未实现】
- Tensorflow|Tensorflow学习笔记----梯度下降
- Tensorflow【branch-官网代码实践-Eager/tf.data/Keras/Graph】_8.19
- nlp|Keras(十一)梯度带(GradientTape)的基本使用方法,与tf.keras结合使用
- tensorflow|tf1.x究竟到底如何如何使用Embedding?
- python|Keras TensorFlow 验证码识别(附数据集)
- AI|bert实现端到端继续预训练
- Tensorflow|cuda由7.0升级到8.0
- tensorflow|利用Tensorflow的队列多线程读取数据
- 深度学习|conda源,tensorflow2,pytorch安装