BERT源码分析之数据预处理部分

我们以Mrpc任务来分析源码,Mrpc任务是判断两个句子是不是一个意思,
从run_classifier.py开始 首先定位到296行,找到类MrpcProcessor,所有预处理的类都继承自DataProcessor,这个类在177行,有一个read_tsv方法,

#我对这个方法适当的修改了,尽量避免大家没见过的方法,例如tf.gfile.Open 可以用open代替,csv.reader(delimiter="\t"),也完全可以用readlines()+split("\t)代替 def read_tsv(input_file,seperated="\t"): #seperated就是delimiter,quotechar是指对于双引号的区域也要引用,这个可以忽略 with open(input_file) as f: lines=f.readlines() #lines是一个列表,每一行是一个字符串,也就是input_file的每一行文本 output=[] for line in lines: output.append(line.strip().split(seperated)) #seperated由具体文件内的文本格式决定 return output

这个类还有四个方法是需要一一实现的,回到296行
我们可以看到四个方法一一实现了,前三个是为了得到examples.通过self.read_tsv方法我们知道返回的是一个长度和文件行数一样长的list,每一个元素也是一个list,是文件中的每一行文本字符串经过"\t"分割后的list。接下来将这个list送入create_examples方法生成examples,现在看create_examples,在318行:
#在进入create_examples之前还要知道tokenization中的一个方法convert_to_unicode,在tokenization文件中的78行 def convert_to_unicode(text): #注意传进来的text是一句话或者一个单词,也就是一个字符串 #这个函数也可以适当的修改,因为现在肯定都是python3,而且python3默认的编码方式就是utf-8,所以这个函数甚至可以不用 #但是由于后面经常引用这个函数,还是把它写上吧 if isinstance(text,str): return text#text在python3中就是str形式 elif isinstance(text,bytes): return text.decode("utf-8","ignore")#由bytes转到utf-8要调用decode,由utf-8转到bytes要调用encode else: raise ValueError("Unsupported string type: %s "%(type(text)))#知道convert_to_unicode怎么回事了下面回到run_classifier.py的318行 def create_examples(self,lines,set_type): examples=[] for (i,line) in enumerate(lines): if i==0: continue#注意,官方Mrpc任务中给出的数据集的第一行不是训练的文本 #传进来的lines的第一行是这样的['\ufeffQuality', '#1 ID', '#2 ID', '#1 String', '#2 String'] #所以i==0时continue guid="%s-%s" % (set_type,i)#set_type是"train","dev","test"中的某一个 text_a=tokenization.convert_to_unicode(line[3]) text_b=tokenization.convert_to_unicode(line[4]) #把句子1和句子2拿出来 if set_type=="test": label="0"#测试时不要标签 else: label=tokenization.convert_to_unicode(line[0]) examples.append(InputExample(guid,text_a,text_b=text_b,label=label)) return examples #这个InputExample在127行,就是一个类,甚至没有方法。

现在定位到main方法,在783行,注意接下来我写的代码中适当的略去了源码中main方法的某些行代码,略去的代码不影响我们分析源码
processors={"mrpc": MrpcProcessor,}#其他的就不写了 task_name=FLAGS.task_name.lower()#这个就是我们在命令行中输入的任务名称,从这可以看出来命令行中输入的字符大小写均可。 processor=processors[task_name]()#相当于processor=MrpcProcessor() label_list=processor.get_labels()#['0','1'] #注意我们现在所在的行数是817行,下面的tokenizer带着我们跑到了另一个文件tokenization.py

我们接下来就进入到tokenization.py中,我们最终是要知道FullTokenizer是什么,不过之前需要对里面的函数逐个分析
#首先说明一下源码中随处可见的unicodedata unicodedata.categorty(char)#返回一个字符在unicode里的类别,源码中用到的类别有 ''' [Cc]Other,Control [Cf]Other,Format [Mn]Mark,Nonspacing [Nd]Number,Digit [Po]Punctuation,Other [Zs]Separator,Space ''' #例如unicodedata.category(char)如果是Zs的话,那就意味着这个字符是分隔符或者空格 #unicodedata.category(char)返回的类别如果是以P开头的话,那么说明这个字符是punctuation(标点符号的意思) #unicodedata.category(char)如果是Cc,或者Cf,那么说明char是控制字符(我不了解什么是控制字符) unicodedata.normalize(form,unistr)#将unicode编码形式的unistr转成普通格式的字符串 #normalize主要是解决那些特别奇怪的字符,像平时我们见到的无论是中文还是英文,没有说需要normalize的。 def is_whitespace(char):#362行 #\t,\r,\n是控制字符,但是源码中把它们视为是空白字符 if char==" " or char =="\t" or char=="\r" or char=="\n": return True if unicodedata.category(char)=="Zs": return True#"Zs" 对应于separator,space return Falsedef is_control(char): #注意\t \r \n源码中视为是空白字符 if char=="\t" or char=="\r" or char=="\n": return False if unicodedata.category(char) in ("Cc","Cf"): return True return Falsedef is_punctuation(char): #这个就不写了,判断是不是标点符号的

上面三个函数就是用来判断一个字符是空格,分隔符,换行,标点符号,控制字符的哪一类。
#定位到121行load_vocab def load_vocab(vocab_file):#这个vocab_file就是我们下载的预训练模型中的vocab.txt vocab=collections.OrderedDict()#和{}的区别就是这个有序,我也不知道换成普通的{}行不行 index=0 with open(vocab_file) as f: lines=f.readlines() for line in lines: token=convert_to_unicode(line) if not token: break vocab[token.strip()]=index index+=1 return vocab#这个函数的作用就是建立vocab.txt中每一个词到对应的id的一个词典#152行 def whitespace_tokenize(text): text=text.strip() if not text: return [] tokens=text.split() return tokens #whitespace_tokenize函数就是一句话text.strip().split(),也就是将一行字符串按照空格分割

FullTokenizer由两部分组成BasicTokenizer+WordPieceTokenizer
class BasicTokenizer(object): def __init__(self,do_lower_case=True): self.do_lower_case=do_lower_case def clean_text(self,text): #text是一句话,clean_text就是将这句话中的\t,\r,\n替换成空格,对于其它无效字符或者控制字符直接去掉 output=[] for char in text: if ord(char)==0 or ord(char)==0xfffd or is_control(char): continue#ord()就是将字符转成对应的整数值,例如ord('a')=97 if is_whitespace(char): output.append(' ')#\t \r \n用空格代替 else: output.append(char) return "".join(output)#返回的是字符串 def tokenize_chinese_chars(self,text): output=[] for char in text: cp=ord(char) if self.is_chinese_char(cp): output.append(" ") output.append(char) output.append(" ") else: output.append(char) return "".join(output) ''' 说明一下,假如输入的句子是"处理 中文 是 按照 字 来 处理 的" 那么tokenize_chinese_chars的输出是' 处理 \u3000 中文 \u3000 是 \u3000 按照 \u3000 字 \u3000 来 \u3000 处理 \u3000 的 ' \u3000就是中文的空格 再经过whitespace_tokenize后就变成了 ['处', '理', '中', '文', '是', '按', '照', '字', '来', '处', '理', '的'] 这时看一下196行tokenize方法就明白流程了 text=convert_to_unicode(text)#通常这个函数用不上 text=clean_text(text)#将\t \r \n等换成空格,控制字符等去掉 text=tokenize_chinese_chars(text)#针对中文的处理 ''' def run_strip_accents(self,text): #accents是重音符号的意思,这个貌似中文用不上,其它的语言有重音符号的现象 text=unicodedata.normalize("NFD",text)#normalize的作用就是将text转换成普通字符 output=[] for char in text: if unicodedata.category(char)=="Mn": continue#Mn 指的是Mark nonspace output.append(char) return "".join(output) def run_split_on_punc(self,text): #将text中的标点符号与单词分离, #输入是"Splited the sentence, with punctuations." #输出是['Splited the sentence', ',', ' with punctuations', '.'] chars=list(text)#所有单个字符组成的列表 start_new_word=True outputs=[] for i in range(len(chars)): char=chars[i] if is_punctuation(char): outputs.append([char]) start_new_word=True else: if start_new_word==True: outputs.append([])#如果开始一个新的单词,那么在output中加入一个[] outputs[-1].append(char)#这个单词的每一个字符就会相继的加入到这个[]中 start_new_word=False return ["".join(x) for x in outputs] #注意观察211-215行就会发现,其实输入给run_split_on_punc的是单词,而不是句子,假如输入的单词没有标点符号那么这个函数就没什么作用 #输入的单词是单词如果是"punctuations." #输出就是["punctuations","."] #现在我们来看tokenize函数196行,一定要知道whitespace_tokenize(text)函数就是一句代码text.strip().split() def tokenize(self,text): #假设输入的是"This is basic, tokenizer.\n" text=convert_to_unicode(text)#This is basic, tokenizer.\n text=self.clean_text(text)#This is basic, tokenizer. text=self.tokenize_chinese_chars(text)#This is basic, tokenizer. orig_tokens=whitespace_tokenize(text)#["This","is","basic,","tokenizer."] split_tokens=[] for token in orig_tokens: if self.do_lower_case: token=token.lower() token=self.run_strip_accents(token) split_tokens.extend(self.run_split_on_punc(token)) #split_tokens==["this","is","basic",",","tokenizer","."] output_tokens=whitespace_tokenize(" ".join(split_tokens)) return output_tokens

所以说BasicTokenize所做的就是将一个句子,去掉特殊字符,控制字符,\t,\r,\n等字符,以及将带有标点符号的单词与标点符号分离,最后返回一个列表,每一个元素值就是一个token.
下面来看WordpieceTokenizer,定位300行
class WordpieceTokenize(object): def __init__(self,vocab,unk_token="[UNK]",max_input_chars_per_word=200): self.vocab=vocab #vocab就是load_vocab返回的vocab self.max_input_chars_per_word=max_input_chars_per_word self.unk_token=unk_token def tokenize(self,text): #输入"this is in wordpiece tokenizer" #输出['this', 'is', 'in', 'word', '##piece', 'token', '##izer'] text=convert_to_unicode(text)#this is in wordpiece tokenizer output_tokens=[] token_list=whitespace_tokenize(text)#[this,is,in,wordpiece,tokenizer] for token in token_list: char=list(token)#['w','o','r','d','p','i','e','c','e'] if len(chars)>self.max_input_chars_per_word: output_tokens.append(self.unk_token) continue#这个基本用不上,一个单词,怎么可能有200个字符那么长,如果真这么长的话,就用UNK代替 is_bad=False start=0 sub_tokens=[]#sub_tokens记录的是一个单词token会有多少个子单词#正向最大匹配,假设现在token是wordpiece,len(chars)==9 while(start0: substr="##"+substr if substr in self.vocab: cur_substr=substr break end-=1#当end=4时,chars[0:4]==substr=="word",此时找到了子串,cur_substr="word" break if cur_substr is None:#也就是说无论是整个单词还是子串都没有在vocab中出现过 is_bad=True break#跳出循环 #此时cur_substr=="word" sub_tokens.append(cur_substr)#将找到的子字符串加入到sub_tokens,然后start=end,接着找,会找到piece,而piece是在token的中间,所以substr="##piece" start=end#start=4,此时start:end就是单词piece if is_bad: #没找到怎么办呢,用UNK代替这个token output_tokens.append(self.unk_token) else: output_tokens.append(sub_tokens) return output_tokens#['this', 'is', 'in', 'word', '##piece', 'token', '##izer'] #output_tokens是针对整个句子的,sub_tokens是针对一个单词的,cur_substr是针对一个单词的一个子串的。 #所以整个tokenize翻译过来就是给一个字符串句子,先用whitespace_tokenize(text)将每一个单词切分出来(注意wordpiece_tokenize的text是经过basic_tokenize后传进去的,所以不用担心特殊字符,标点符号等问题), #切分后是一个列表,每一个元素是一个单词,对于每一个单词,取出来单词内的所有字符(chars=list(tokens)),然后从起始位置找这个单词的子单词有没有在vocab中出现过,如果无论是单词还是子串都没有在vocab中那么就用unk_token代替,所以正如this is in wordpiece tokenizer的输出所示: #this is in三个单词没有被切分是因为这三个单词在vocab中均出现过,而wordpiece,tokenizer没有在vocab出现过,而word,token出现过,那么就将wordpiece切分成word+##piece,tokenizer切分成token+##izer,这里##的目的我猜是为了标明这是个切分的单词,而且是单词的尾部

介绍完了BasicTokenizer和WordpieceTokenizer,那么就可以引入FullTokenizer了,定位161行
class FullTokenizer(object): def __init__(self,vocab_file,do_lower_case=True): self.vocab=load_vocab(vocab_file) self.inv_vocab={k:v for v,k in self.vocab.items()} #其实就是word2id和id2word self.basic_tokenizer=BasicTokenizer(do_lower_case=do_lower_case) self.wordpiece_tokenizer=WordpieceTokenizer(vocab=self.vocab) def tokenize(self,text): #传进来的text是一句话 split_token=[] for token in self.basic_tokenizer(text): #basic_tokenizer(text)返回的是一个列表,列表中每一个元素是去掉了特殊符号,标点符号的单词 for sub_token in self.word_tokenizer(token): #传给wordpiece_tokenizer的是一个单词,wordpiece_tokenizer返回的要么是原来的这个单词(说明这个单词在vocab中),要么是子串(如word,##piece,说明这个单词没有在vocab中,但是子部分在vocab中) 要么就是一个unk_token,说明这是一个bad token. split_tokens.append(sub_token) return split_token

现在回到run_classifier.py的783行main函数
processors={"mrpc": MrpcProcessor,}#其他的就不写了 task_name=FLAGS.task_name.lower()#这个就是我们在命令行中输入的任务名称,从这可以看出来命令行中输入的字符大小写均可。 processor=processors[task_name]()#相当于processor=MrpcProcessor() label_list=processor.get_labels()#['0','1'] #注意我们现在所在的行数是817行,下面的tokenizer带着我们跑到了另一个文件tokenization.py #现在我们回到了tokenization.py tokenizer=tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,do_lower_case=FLAGS.do_lower_case)#FLAGS.vocab_file就是下载的预训练文件中的vocab.txt if FLAGS.do_train: train_examples=processor.get_train_examples(FLAGS.data_dir)#train_examples是一个列表,每一个元素值是InputExample的一个对象,这个对象有四个属性guid,text_a,text_b,label。所以你可以把train_examples看成是一个大的列表,长度和data_dir给出的文件的行数一样长,每一个元素值记录着文件的每一行. train_file=os.path.join(FLAGS.output_dir,"train.tf_record")#注意train_file是你在命令终端中输入的output_dir的位置,不要把它误以为是train.txt所在的文件位置 #代码会将train.txt经过一系列的操作后以tfrecord格式存储到train_file中,模型训练时是从train_file中读取数据,所以命名为train_file#下面定位到869行,file_based_convert_examples_to_features(),传进去的变量有train_examples.label_list,max_seq_length,tokenizer,train_file, #这几个变量大家应该知道都是什么了max_seq_length默认是128 #接下来进入file_based_convert_examples_to_features() #定位到479行

再来看看另外的一些函数
class PaddingInputExample: #就没有了,这个类的目的是用一个什么都没有的对象代替None class InputFeatures(object): def __init__(self,input_ids,input_mask,segment_ids,label_id,is_real_example=True): #input_ids是一个句子中每一个单词对应的在vocab中的索引 #input_mask是指传进来的input_ids中那些单词是pad的,需要mask的, #segment_ids是指明那些单词是第一个句子,那些单词是第二个句子 #label_id是一个整数值,表明当前这个example的标签 assert len(input_ids)==len(input_mask)==len(segment_ids)def truncate_seq_pair(tokens_a,tokens_b,max_seq_length): #如果两个句子的长度加起来>max_seq_length,那么就将长度比较长的句子截断。(不截断句子短的是因为本来就短,再截断后整个句子就没什么信息了) while True: sentence_length=len(tokens_a)+len(tokens_b) if sentence_length<=max_seq_length: break if len(tokens_a)>len(tokens_b): tokens_a.pop() else: tokens_b.pop()#列表还有pop的操作def convert_single_example(example_index,example,label_list,max_seq_length,tokenizer): #传进来的example是examples的一行,examples是InputExample的对象 if isinstance(example,PaddingInputExample): #也就是说如果传进来的example是None的话,那么就会执行下面的语句 return InputFeatures(input_ids=[0]*max_seq_length, input_mask=[0]*max_seq_length,segment_ids=[0]*max_seq_length, label_id=0,is_real_example=False) #这么做的目的看着挺困惑,源码中给出的解释是为了使得输入的examples的长度是batch_size的倍数,因为TPU需要固定的batch_size, #所以说输入的examples的最后的比batch_size少的那几个example,会在它们后面加上一些没有实际值的InputFeatures label_map={} for (i,label) in enumerate(label_list): label_map[label]=i tokens_a=tokenizer.tokenize(example.text_a) tokens_b=None if example.text_b: tokens_b=tokenizer.tokenize(example.text_b) if tokens_b: truncate_seq_pair(tokens_a,tokens_b,max_seq_length-3)#-3是因为如果有两个句子,那么就会有三个特殊字符[CLS]+tokens_a+[SEP]+tokens_b+[SEP] else: #只有一个句子 if len(tokens_a)+2>max_seq_length: tokens_a=tokens_a[:max_seq_length-2] tokens=[]#tokens记录的是每一个单词 segment_ids=[]#记录的是单词是属于哪一个句子的 tokens.append("[CLS]") segment_ids.append(0) for token in tokens_a: tokens.append(token) segment_ids.append(0) tokens.append("[SEP]") segment_ids.append(0) if tokens_b: for token in tokens_b: tokens.append(token) segment.append(1) tokens.append("[SEP]") segment_ids.append(1) input_ids=tokenizer.convert_tokens_to_ids(tokens)#这个函数的名字就告诉我们它的作用了,将每一个token转换成对应的id input_mask=[1]*len(input_ids)#len(input_ids)是句子的真实长度,所以在mask的时候是不能mask这个长度下单词的,所以input_mask在0-len(input_ids)的位置都是1 masked_length=max_seq_length-len(input_mask) input_ids+=[0]*masked_length input_mask+=[0]*masked_length segment_ids+=[0]*masked_length#这里注意的是segment_ids前面为0代表这个单词是第一个句子的,中间为1代表是第二个句子的,1后面还会有0,代表是pad的 label_id=label_map[example.label]#相当于label2id[label],也就是找到当前这个example的标签类别 return InputFeatures(input_ids,input_mask,segment_ids, label_id,is_real_example=True) ''' 也就是说传进来的是一个example,也就是InputExamples的一个对象,没有方法,有四个属性guid,text_a,text_b,label, 而经过convert_single_example后返回的是InputFeatures的一个对象,没有方法,有四个属性,input_ids,input_mask,segment_ids,label_id所以可以拿一个句子对来举例子: ["example is a object of input examples","feature is a object of input features","1"] 那么经过convert_single_example后就是 [input_ids=[1,2,3,4,5,6,7,8,2,3,4,5,6,7,9,0,0,0,0,0,0,0,0,000......] segment_ids=[0,0,0,0,0,0,0,1,1,1,1,1,1,0,000000,00..] input_mask=[1,1,1,1,1,1,1,1,0000000000000.......0] label_id=1] '''

TFRecord数据文件是将数据和对应的标签统一存储的二进制格式文件,生成tfrecord文件的格式是先读取原生数据,根据原生数据生成tfrecord文件,再写回磁盘。
然后再利用API从磁盘读取tfrecord文件
一般的生成tfrecord的步骤是
tf.train.Example(features=tf.train.Features(feature={key:value})) #其中key是你起的特征的名字,value就是特征,大致分为三种类型的特征: ''' Int64List,用来存储int型数据 BytesList,用来存储字符串 FloatList,用来存储浮点型数据 ''' #下面的就是源码中给出的形式 tf_example=tf.train.Example( features=tf.train.Features(feature= {"input_ids":tf.train.Feature(int64_list=tf.train.Int64List(value=https://www.it610.com/article/list(feature.input_ids))),"input_mask":tf.train.Feature(int64_list=tf.train.Int64List(value=https://www.it610.com/article/list(feature.input_mask)))."segment_ids":tf.train.Feature(int64_list=tf.train.Int64List(value=https://www.it610.com/article/list(feature.segment_ids))), })) #然后tf_example序列化,再将序列化后的文件写成tfrecord文件

下面我们来看file_based_convert_examples_to_features,这个函数的用处就是将examples写成tfrecord,定位到479行
def file_based_convert_examples_to_features(examples,label_list,max_seq_length,tokenizer,output_file): #examples就是InputExamples的一个对象,没有方法,有四个属性,guid,text_a,text_b,label #label_list在Mrpc任务中是['0','1'] #max_seq_length 默认是128 #output_file就是将examples转换成tfrecord后的文件输出位置 writer=tf.python_io.TFRecordWriter(output_file) for (example_index,example) in enumerate(examples): feature=convert_single_example(example_index,example,label_list,max_seq_length,tokenizer) #传进去的example_index的作用是为了打印所有example的前五个,这就是当运行源码时一开始会看到一大堆的输出信息,包括input_ids,input_mask,segment_ids,你也可以在convert_single_example中去掉打印的那几行 #convert_single_example返回的feature是就是INputFeatures的一个对象,没有方法,有几个属性包括input_ids,input_mask,segment_ids,label_id,is_real_example等 '''接下来注意,我上面写的那个代码块和源码中492-502行意思是一样的, 目的就是构造一个dict,dict的key是特征名字,对应的value就是由tf.train.Feature创造的特征 然后将构造的dict作为features生成tf.train.Example的一个实例 ''' tf_example=tf.train.Example(features=tf.train.Features(feature=features)) writer.write(tf_example.SerializeToString()) #每生成一个特征feature,就把它转成tfrecord格式,序列化后写入文件 writer.close() #上面就已经生成了tfrecord文件,接下来就是读取 def file_based_input_fn_builder(input_file,max_seq_length,is_training,drop_remainder): #input_file就是上一个函数的output_file '''读取的API主要是tf.data.TFRecordDataset(file_path),返回的是一个dataset, 这个dataset中的每一个元素就是序列化的一个tf_example,我们要把它解析回原来的类型 (详细说明下就是由convert_single_example返回的feature是一个对象,有几个属性(此时打印feature可以看到几个列表),利用这些属性值将feature改造成一个tf.train.Feature(此时打印feature就是json格式的类型,key就是名,value就是列表),然后序列化转成tfrecord(此时打印feature就是一个二进制字符串), 这就是生成tfrecord的过程,读取tfrecord文件再解析回原来列表型的数据类型就是相应的逆过程) ''' def input_fn(params): dataset=tf.data.TFRecordDataset(input_file) dataset=dataset.apply(tf.contrib.data.map_and_batch( lambda record:decode_record(record,name_to_features), batch_size=params["batch_size"], drop_remainder=drop_remainder)) return dataset #dataset.apply()的作用就是将dataset放到tf.contrib.data.map_and_batch()中,map_and_batch()会将dataset中的数据以batch_size的个数拿出来放到decode_record中, #如果最后的部分不足batch_size,drop_remainder=True的意思就是去掉这部分。 #我们刚才说了,你从dataset中读取出来的是二进制字符串,需要把它解析回原来的列表格式才能送进网络,decode_record的作用就是解析 def decode_record(record,name_to_features): example=tf.parse_single_example(record,name_to_features) #parse_single_example就是解析record,name_to_features的作用就是告诉函数record中的每个值原来是什么类型的 ''' name_to_features={ "input_ids":tf.FixedLenFeature([seq_length],tf.int64) } ''' #意思就是说record中name为input_ids的数据原来的格式是一个长度为seq_length的列表,类型是tf.int64 return example#只不过源码中把所有tf.int64类型的数据都转成tf.int32的数据 return input_fn

【BERT源码分析之数据预处理部分】好的现在回到868行
train_file=os.path.join(FLAGS.output_dir,"train.tf_record") #output_dir就是你在终端中输入的参数output_dir,并且程序执行过程中你也看到了Writting example %d of %d,对应487行 #这就是说模型在向output_dir中写入tf_record文件. file_based_convert_examples_to_features(train_examples,label_list,max_seq_length,tokenizer,train_file)#生成tfrecord数据写到train_file中 train_input_fn=file_based_input_fn_builder(train_file,max_seq_length)#读取tfrecord文件,解析回原来的数据格式,返回的是函数 estimator.train(train_input_fn,num_train_steps) #这就是整个数据处理的流程

回头总结一下数据处理的流程
  • 首先自己写一个MyProcessor,然后在790行那里加上名字。你的MyProcessor主要的方法有get_train_examples,get_dev_examples,get_test_examples.这三个方法返回的是InputExamples的一个对象,有四个属性,guid,这个不重要,text_a,text_b,label,后三个就是典型的Mrpc任务需要的数据,即句子1,句子2,标签。
  • 然后把get_train_examples返回的实例传入到file_based_convert_examples_to_features中生成tfrecord文件。
  • file_based_input_fn_builder会读取tfrecord文件,返回一个函数给estimator.
  • 剩下的就是细节,尤其是tokenization.py文件,里面的函数几乎要全部了解。(注意我只是介绍了数据处理的流程)

    推荐阅读