AI|bert实现端到端继续预训练

ACL2020好多paper都证明如果能有领域内的数据进行继续预训练,能够对模型最终的效果有较大的提升(通常为几个百分点),现有的继续预训练的的步骤为 create_pretraining_data(创建预训练的数据)--- run_pretraining (根据上面的数据进行继续预训练),但是现实中很多时候不想过多的IO(对磁盘有不可毁灭的伤害,而创建预训练数据产生大量的IO),针对这种需求将创建数据和预训练结合到一个代码中实现端到端。
一、首先查看继续预训练读取数据的函数为 input_fn_builder, 所以需要改写input_fn_builder函数实现端到端训练。

#这个代码就是从文件中读取数据的意思,目的就是为了替换这一句 d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))

替换为:
#改变成从生成器,也就是从create_pretraining_data中获取数据 d = tf.data.Dataset.from_generator(take_write_2_flow,(tf.string), (tf.TensorShape([])))

【AI|bert实现端到端继续预训练】 二、那么上面的生成器器函数是怎样的呢,代码精髓第一个是通过yield产生生成器,第二个是通过上面的tf.data.Dataset.from_generator处理生成器
#这里采用了1000每批次去生成训练数据,并通过yield每一条数据产生生成器,这样不需要将所有数据写到磁盘也能进行端到端训练 def take_write_2_flow(): tf.logging.set_verbosity(tf.logging.INFO)tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)input_files = [] for input_pattern in FLAGS.input_file.split(","): input_files.extend(tf.gfile.Glob(input_pattern))tf.logging.info("*** Reading from input files ***") for input_file in input_files: tf.logging.info("%s", input_file)rng = random.Random(FLAGS.random_seed)file_need_precessing = [] with tf.gfile.GFile(input_files[0], "r") as reader: while True: strings = reader.readline() if not strings: break if len(file_need_precessing) == 1000: # dosomething #for some stratrety input_file_one_process = file_need_precessing[:] instances = fusion_input_and_out(input_file_one_process, tokenizer, rng) writers = create_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, FLAGS.max_predictions_per_seq) for write in writers: yield write # yield writers file_need_precessing = [] file_need_precessing.append(strings)def create_instance_to_example_files(instances, tokenizer, max_seq_length, max_predictions_per_seq): """Create TF example files from `TrainingInstance`s.""" writers = []for (inst_index, instance) in enumerate(instances): if inst_index == 0: tf.logging.info("*** Example ***") tf.logging.info("tokens: %s" % " ".join( [tokenization.printable_text(x) for x in instance.tokens])) input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) input_mask = [1] * len(input_ids) segment_ids = list(instance.segment_ids) assert len(input_ids) <= max_seq_lengthwhile len(input_ids) < max_seq_length: input_ids.append(0) input_mask.append(0) segment_ids.append(0)assert len(input_ids) == max_seq_length assert len(input_mask) == max_seq_length assert len(segment_ids) == max_seq_lengthmasked_lm_positions = list(instance.masked_lm_positions) masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) masked_lm_weights = [1.0] * len(masked_lm_ids)while len(masked_lm_positions) < max_predictions_per_seq: masked_lm_positions.append(0) masked_lm_ids.append(0) masked_lm_weights.append(0.0)next_sentence_label = 1 if instance.is_random_next else 0features = collections.OrderedDict() features["input_ids"] = create_int_feature(input_ids) features["input_mask"] = create_int_feature(input_mask) features["segment_ids"] = create_int_feature(segment_ids) features["masked_lm_positions"] = create_int_feature(masked_lm_positions) features["masked_lm_ids"] = create_int_feature(masked_lm_ids) features["masked_lm_weights"] = create_float_feature(masked_lm_weights) features["next_sentence_labels"] = create_int_feature([next_sentence_label])tf_example = tf.train.Example(features=tf.train.Features(feature=features)) writers.append(tf_example.SerializeToString())return writers

    推荐阅读