BERT源码分析之数据预处理部分
我们以Mrpc任务来分析源码,Mrpc任务是判断两个句子是不是一个意思,
从run_classifier.py开始 首先定位到296行,找到类MrpcProcessor,所有预处理的类都继承自DataProcessor,这个类在177行,有一个read_tsv方法,
#我对这个方法适当的修改了,尽量避免大家没见过的方法,例如tf.gfile.Open 可以用open代替,csv.reader(delimiter="\t"),也完全可以用readlines()+split("\t)代替
def read_tsv(input_file,seperated="\t"):
#seperated就是delimiter,quotechar是指对于双引号的区域也要引用,这个可以忽略
with open(input_file) as f:
lines=f.readlines()
#lines是一个列表,每一行是一个字符串,也就是input_file的每一行文本
output=[]
for line in lines:
output.append(line.strip().split(seperated))
#seperated由具体文件内的文本格式决定
return output
这个类还有四个方法是需要一一实现的,回到296行
我们可以看到四个方法一一实现了,前三个是为了得到examples.通过self.read_tsv方法我们知道返回的是一个长度和文件行数一样长的list,每一个元素也是一个list,是文件中的每一行文本字符串经过"\t"分割后的list。接下来将这个list送入create_examples方法生成examples,现在看create_examples,在318行:
#在进入create_examples之前还要知道tokenization中的一个方法convert_to_unicode,在tokenization文件中的78行
def convert_to_unicode(text):
#注意传进来的text是一句话或者一个单词,也就是一个字符串
#这个函数也可以适当的修改,因为现在肯定都是python3,而且python3默认的编码方式就是utf-8,所以这个函数甚至可以不用
#但是由于后面经常引用这个函数,还是把它写上吧
if isinstance(text,str):
return text#text在python3中就是str形式
elif isinstance(text,bytes):
return text.decode("utf-8","ignore")#由bytes转到utf-8要调用decode,由utf-8转到bytes要调用encode
else:
raise ValueError("Unsupported string type: %s "%(type(text)))#知道convert_to_unicode怎么回事了下面回到run_classifier.py的318行 def create_examples(self,lines,set_type):
examples=[]
for (i,line) in enumerate(lines):
if i==0:
continue#注意,官方Mrpc任务中给出的数据集的第一行不是训练的文本
#传进来的lines的第一行是这样的['\ufeffQuality', '#1 ID', '#2 ID', '#1 String', '#2 String']
#所以i==0时continue
guid="%s-%s" % (set_type,i)#set_type是"train","dev","test"中的某一个
text_a=tokenization.convert_to_unicode(line[3])
text_b=tokenization.convert_to_unicode(line[4])
#把句子1和句子2拿出来
if set_type=="test":
label="0"#测试时不要标签
else:
label=tokenization.convert_to_unicode(line[0])
examples.append(InputExample(guid,text_a,text_b=text_b,label=label))
return examples
#这个InputExample在127行,就是一个类,甚至没有方法。
现在定位到main方法,在783行,注意接下来我写的代码中适当的略去了源码中main方法的某些行代码,略去的代码不影响我们分析源码
processors={"mrpc": MrpcProcessor,}#其他的就不写了
task_name=FLAGS.task_name.lower()#这个就是我们在命令行中输入的任务名称,从这可以看出来命令行中输入的字符大小写均可。
processor=processors[task_name]()#相当于processor=MrpcProcessor()
label_list=processor.get_labels()#['0','1']
#注意我们现在所在的行数是817行,下面的tokenizer带着我们跑到了另一个文件tokenization.py
我们接下来就进入到tokenization.py中,我们最终是要知道FullTokenizer是什么,不过之前需要对里面的函数逐个分析
#首先说明一下源码中随处可见的unicodedata
unicodedata.categorty(char)#返回一个字符在unicode里的类别,源码中用到的类别有
'''
[Cc]Other,Control
[Cf]Other,Format
[Mn]Mark,Nonspacing
[Nd]Number,Digit
[Po]Punctuation,Other
[Zs]Separator,Space
'''
#例如unicodedata.category(char)如果是Zs的话,那就意味着这个字符是分隔符或者空格
#unicodedata.category(char)返回的类别如果是以P开头的话,那么说明这个字符是punctuation(标点符号的意思)
#unicodedata.category(char)如果是Cc,或者Cf,那么说明char是控制字符(我不了解什么是控制字符)
unicodedata.normalize(form,unistr)#将unicode编码形式的unistr转成普通格式的字符串
#normalize主要是解决那些特别奇怪的字符,像平时我们见到的无论是中文还是英文,没有说需要normalize的。
def is_whitespace(char):#362行
#\t,\r,\n是控制字符,但是源码中把它们视为是空白字符
if char==" " or char =="\t" or char=="\r" or char=="\n":
return True
if unicodedata.category(char)=="Zs":
return True#"Zs" 对应于separator,space
return Falsedef is_control(char):
#注意\t \r \n源码中视为是空白字符
if char=="\t" or char=="\r" or char=="\n":
return False
if unicodedata.category(char) in ("Cc","Cf"):
return True
return Falsedef is_punctuation(char):
#这个就不写了,判断是不是标点符号的
上面三个函数就是用来判断一个字符是空格,分隔符,换行,标点符号,控制字符的哪一类。
#定位到121行load_vocab
def load_vocab(vocab_file):#这个vocab_file就是我们下载的预训练模型中的vocab.txt
vocab=collections.OrderedDict()#和{}的区别就是这个有序,我也不知道换成普通的{}行不行
index=0
with open(vocab_file) as f:
lines=f.readlines()
for line in lines:
token=convert_to_unicode(line)
if not token:
break
vocab[token.strip()]=index
index+=1
return vocab#这个函数的作用就是建立vocab.txt中每一个词到对应的id的一个词典#152行
def whitespace_tokenize(text):
text=text.strip()
if not text:
return []
tokens=text.split()
return tokens
#whitespace_tokenize函数就是一句话text.strip().split(),也就是将一行字符串按照空格分割
FullTokenizer由两部分组成BasicTokenizer+WordPieceTokenizer
class BasicTokenizer(object):
def __init__(self,do_lower_case=True):
self.do_lower_case=do_lower_case
def clean_text(self,text):
#text是一句话,clean_text就是将这句话中的\t,\r,\n替换成空格,对于其它无效字符或者控制字符直接去掉
output=[]
for char in text:
if ord(char)==0 or ord(char)==0xfffd or is_control(char):
continue#ord()就是将字符转成对应的整数值,例如ord('a')=97
if is_whitespace(char):
output.append(' ')#\t \r \n用空格代替
else:
output.append(char)
return "".join(output)#返回的是字符串
def tokenize_chinese_chars(self,text):
output=[]
for char in text:
cp=ord(char)
if self.is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
'''
说明一下,假如输入的句子是"处理 中文 是 按照 字 来 处理 的"
那么tokenize_chinese_chars的输出是' 处理 \u3000 中文 \u3000 是 \u3000 按照 \u3000 字 \u3000 来 \u3000 处理 \u3000 的 '
\u3000就是中文的空格
再经过whitespace_tokenize后就变成了
['处', '理', '中', '文', '是', '按', '照', '字', '来', '处', '理', '的']
这时看一下196行tokenize方法就明白流程了
text=convert_to_unicode(text)#通常这个函数用不上
text=clean_text(text)#将\t \r \n等换成空格,控制字符等去掉
text=tokenize_chinese_chars(text)#针对中文的处理
'''
def run_strip_accents(self,text):
#accents是重音符号的意思,这个貌似中文用不上,其它的语言有重音符号的现象
text=unicodedata.normalize("NFD",text)#normalize的作用就是将text转换成普通字符
output=[]
for char in text:
if unicodedata.category(char)=="Mn":
continue#Mn 指的是Mark nonspace
output.append(char)
return "".join(output)
def run_split_on_punc(self,text):
#将text中的标点符号与单词分离,
#输入是"Splited the sentence, with punctuations."
#输出是['Splited the sentence', ',', ' with punctuations', '.']
chars=list(text)#所有单个字符组成的列表
start_new_word=True
outputs=[]
for i in range(len(chars)):
char=chars[i]
if is_punctuation(char):
outputs.append([char])
start_new_word=True
else:
if start_new_word==True:
outputs.append([])#如果开始一个新的单词,那么在output中加入一个[]
outputs[-1].append(char)#这个单词的每一个字符就会相继的加入到这个[]中
start_new_word=False
return ["".join(x) for x in outputs]
#注意观察211-215行就会发现,其实输入给run_split_on_punc的是单词,而不是句子,假如输入的单词没有标点符号那么这个函数就没什么作用
#输入的单词是单词如果是"punctuations."
#输出就是["punctuations","."]
#现在我们来看tokenize函数196行,一定要知道whitespace_tokenize(text)函数就是一句代码text.strip().split()
def tokenize(self,text):
#假设输入的是"This is basic, tokenizer.\n"
text=convert_to_unicode(text)#This is basic, tokenizer.\n
text=self.clean_text(text)#This is basic, tokenizer.
text=self.tokenize_chinese_chars(text)#This is basic, tokenizer.
orig_tokens=whitespace_tokenize(text)#["This","is","basic,","tokenizer."]
split_tokens=[]
for token in orig_tokens:
if self.do_lower_case:
token=token.lower()
token=self.run_strip_accents(token)
split_tokens.extend(self.run_split_on_punc(token))
#split_tokens==["this","is","basic",",","tokenizer","."]
output_tokens=whitespace_tokenize(" ".join(split_tokens))
return output_tokens
所以说BasicTokenize所做的就是将一个句子,去掉特殊字符,控制字符,\t,\r,\n等字符,以及将带有标点符号的单词与标点符号分离,最后返回一个列表,每一个元素值就是一个token.
下面来看WordpieceTokenizer,定位300行
class WordpieceTokenize(object):
def __init__(self,vocab,unk_token="[UNK]",max_input_chars_per_word=200):
self.vocab=vocab
#vocab就是load_vocab返回的vocab
self.max_input_chars_per_word=max_input_chars_per_word
self.unk_token=unk_token def tokenize(self,text):
#输入"this is in wordpiece tokenizer"
#输出['this', 'is', 'in', 'word', '##piece', 'token', '##izer']
text=convert_to_unicode(text)#this is in wordpiece tokenizer
output_tokens=[]
token_list=whitespace_tokenize(text)#[this,is,in,wordpiece,tokenizer]
for token in token_list:
char=list(token)#['w','o','r','d','p','i','e','c','e']
if len(chars)>self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue#这个基本用不上,一个单词,怎么可能有200个字符那么长,如果真这么长的话,就用UNK代替
is_bad=False
start=0
sub_tokens=[]#sub_tokens记录的是一个单词token会有多少个子单词#正向最大匹配,假设现在token是wordpiece,len(chars)==9
while(start0:
substr="##"+substr
if substr in self.vocab:
cur_substr=substr
break
end-=1#当end=4时,chars[0:4]==substr=="word",此时找到了子串,cur_substr="word" break
if cur_substr is None:#也就是说无论是整个单词还是子串都没有在vocab中出现过
is_bad=True
break#跳出循环
#此时cur_substr=="word"
sub_tokens.append(cur_substr)#将找到的子字符串加入到sub_tokens,然后start=end,接着找,会找到piece,而piece是在token的中间,所以substr="##piece"
start=end#start=4,此时start:end就是单词piece
if is_bad:
#没找到怎么办呢,用UNK代替这个token
output_tokens.append(self.unk_token)
else:
output_tokens.append(sub_tokens)
return output_tokens#['this', 'is', 'in', 'word', '##piece', 'token', '##izer']
#output_tokens是针对整个句子的,sub_tokens是针对一个单词的,cur_substr是针对一个单词的一个子串的。
#所以整个tokenize翻译过来就是给一个字符串句子,先用whitespace_tokenize(text)将每一个单词切分出来(注意wordpiece_tokenize的text是经过basic_tokenize后传进去的,所以不用担心特殊字符,标点符号等问题),
#切分后是一个列表,每一个元素是一个单词,对于每一个单词,取出来单词内的所有字符(chars=list(tokens)),然后从起始位置找这个单词的子单词有没有在vocab中出现过,如果无论是单词还是子串都没有在vocab中那么就用unk_token代替,所以正如this is in wordpiece tokenizer的输出所示:
#this is in三个单词没有被切分是因为这三个单词在vocab中均出现过,而wordpiece,tokenizer没有在vocab出现过,而word,token出现过,那么就将wordpiece切分成word+##piece,tokenizer切分成token+##izer,这里##的目的我猜是为了标明这是个切分的单词,而且是单词的尾部
介绍完了BasicTokenizer和WordpieceTokenizer,那么就可以引入FullTokenizer了,定位161行
class FullTokenizer(object):
def __init__(self,vocab_file,do_lower_case=True):
self.vocab=load_vocab(vocab_file)
self.inv_vocab={k:v for v,k in self.vocab.items()}
#其实就是word2id和id2word
self.basic_tokenizer=BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer=WordpieceTokenizer(vocab=self.vocab)
def tokenize(self,text):
#传进来的text是一句话
split_token=[]
for token in self.basic_tokenizer(text):
#basic_tokenizer(text)返回的是一个列表,列表中每一个元素是去掉了特殊符号,标点符号的单词
for sub_token in self.word_tokenizer(token):
#传给wordpiece_tokenizer的是一个单词,wordpiece_tokenizer返回的要么是原来的这个单词(说明这个单词在vocab中),要么是子串(如word,##piece,说明这个单词没有在vocab中,但是子部分在vocab中) 要么就是一个unk_token,说明这是一个bad token.
split_tokens.append(sub_token)
return split_token
现在回到run_classifier.py的783行main函数
processors={"mrpc": MrpcProcessor,}#其他的就不写了
task_name=FLAGS.task_name.lower()#这个就是我们在命令行中输入的任务名称,从这可以看出来命令行中输入的字符大小写均可。
processor=processors[task_name]()#相当于processor=MrpcProcessor()
label_list=processor.get_labels()#['0','1']
#注意我们现在所在的行数是817行,下面的tokenizer带着我们跑到了另一个文件tokenization.py
#现在我们回到了tokenization.py
tokenizer=tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,do_lower_case=FLAGS.do_lower_case)#FLAGS.vocab_file就是下载的预训练文件中的vocab.txt
if FLAGS.do_train:
train_examples=processor.get_train_examples(FLAGS.data_dir)#train_examples是一个列表,每一个元素值是InputExample的一个对象,这个对象有四个属性guid,text_a,text_b,label。所以你可以把train_examples看成是一个大的列表,长度和data_dir给出的文件的行数一样长,每一个元素值记录着文件的每一行.
train_file=os.path.join(FLAGS.output_dir,"train.tf_record")#注意train_file是你在命令终端中输入的output_dir的位置,不要把它误以为是train.txt所在的文件位置
#代码会将train.txt经过一系列的操作后以tfrecord格式存储到train_file中,模型训练时是从train_file中读取数据,所以命名为train_file#下面定位到869行,file_based_convert_examples_to_features(),传进去的变量有train_examples.label_list,max_seq_length,tokenizer,train_file,
#这几个变量大家应该知道都是什么了max_seq_length默认是128
#接下来进入file_based_convert_examples_to_features()
#定位到479行
再来看看另外的一些函数
class PaddingInputExample:
#就没有了,这个类的目的是用一个什么都没有的对象代替None
class InputFeatures(object):
def __init__(self,input_ids,input_mask,segment_ids,label_id,is_real_example=True):
#input_ids是一个句子中每一个单词对应的在vocab中的索引
#input_mask是指传进来的input_ids中那些单词是pad的,需要mask的,
#segment_ids是指明那些单词是第一个句子,那些单词是第二个句子
#label_id是一个整数值,表明当前这个example的标签
assert len(input_ids)==len(input_mask)==len(segment_ids)def truncate_seq_pair(tokens_a,tokens_b,max_seq_length):
#如果两个句子的长度加起来>max_seq_length,那么就将长度比较长的句子截断。(不截断句子短的是因为本来就短,再截断后整个句子就没什么信息了)
while True:
sentence_length=len(tokens_a)+len(tokens_b)
if sentence_length<=max_seq_length:
break
if len(tokens_a)>len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()#列表还有pop的操作def convert_single_example(example_index,example,label_list,max_seq_length,tokenizer):
#传进来的example是examples的一行,examples是InputExample的对象
if isinstance(example,PaddingInputExample):
#也就是说如果传进来的example是None的话,那么就会执行下面的语句
return InputFeatures(input_ids=[0]*max_seq_length,
input_mask=[0]*max_seq_length,segment_ids=[0]*max_seq_length,
label_id=0,is_real_example=False)
#这么做的目的看着挺困惑,源码中给出的解释是为了使得输入的examples的长度是batch_size的倍数,因为TPU需要固定的batch_size,
#所以说输入的examples的最后的比batch_size少的那几个example,会在它们后面加上一些没有实际值的InputFeatures
label_map={}
for (i,label) in enumerate(label_list):
label_map[label]=i
tokens_a=tokenizer.tokenize(example.text_a)
tokens_b=None
if example.text_b:
tokens_b=tokenizer.tokenize(example.text_b)
if tokens_b:
truncate_seq_pair(tokens_a,tokens_b,max_seq_length-3)#-3是因为如果有两个句子,那么就会有三个特殊字符[CLS]+tokens_a+[SEP]+tokens_b+[SEP]
else:
#只有一个句子
if len(tokens_a)+2>max_seq_length:
tokens_a=tokens_a[:max_seq_length-2]
tokens=[]#tokens记录的是每一个单词
segment_ids=[]#记录的是单词是属于哪一个句子的
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
if tokens_b:
for token in tokens_b:
tokens.append(token)
segment.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
input_ids=tokenizer.convert_tokens_to_ids(tokens)#这个函数的名字就告诉我们它的作用了,将每一个token转换成对应的id
input_mask=[1]*len(input_ids)#len(input_ids)是句子的真实长度,所以在mask的时候是不能mask这个长度下单词的,所以input_mask在0-len(input_ids)的位置都是1
masked_length=max_seq_length-len(input_mask)
input_ids+=[0]*masked_length
input_mask+=[0]*masked_length
segment_ids+=[0]*masked_length#这里注意的是segment_ids前面为0代表这个单词是第一个句子的,中间为1代表是第二个句子的,1后面还会有0,代表是pad的
label_id=label_map[example.label]#相当于label2id[label],也就是找到当前这个example的标签类别
return InputFeatures(input_ids,input_mask,segment_ids,
label_id,is_real_example=True)
'''
也就是说传进来的是一个example,也就是InputExamples的一个对象,没有方法,有四个属性guid,text_a,text_b,label,
而经过convert_single_example后返回的是InputFeatures的一个对象,没有方法,有四个属性,input_ids,input_mask,segment_ids,label_id所以可以拿一个句子对来举例子:
["example is a object of input examples","feature is a object of input features","1"]
那么经过convert_single_example后就是
[input_ids=[1,2,3,4,5,6,7,8,2,3,4,5,6,7,9,0,0,0,0,0,0,0,0,000......]
segment_ids=[0,0,0,0,0,0,0,1,1,1,1,1,1,0,000000,00..]
input_mask=[1,1,1,1,1,1,1,1,0000000000000.......0]
label_id=1]
'''
TFRecord数据文件是将数据和对应的标签统一存储的二进制格式文件,生成tfrecord文件的格式是先读取原生数据,根据原生数据生成tfrecord文件,再写回磁盘。
然后再利用API从磁盘读取tfrecord文件
一般的生成tfrecord的步骤是
tf.train.Example(features=tf.train.Features(feature={key:value}))
#其中key是你起的特征的名字,value就是特征,大致分为三种类型的特征:
'''
Int64List,用来存储int型数据
BytesList,用来存储字符串
FloatList,用来存储浮点型数据
'''
#下面的就是源码中给出的形式
tf_example=tf.train.Example(
features=tf.train.Features(feature=
{"input_ids":tf.train.Feature(int64_list=tf.train.Int64List(value=https://www.it610.com/article/list(feature.input_ids))),"input_mask":tf.train.Feature(int64_list=tf.train.Int64List(value=https://www.it610.com/article/list(feature.input_mask)))."segment_ids":tf.train.Feature(int64_list=tf.train.Int64List(value=https://www.it610.com/article/list(feature.segment_ids))),
}))
#然后tf_example序列化,再将序列化后的文件写成tfrecord文件
下面我们来看file_based_convert_examples_to_features,这个函数的用处就是将examples写成tfrecord,定位到479行
def file_based_convert_examples_to_features(examples,label_list,max_seq_length,tokenizer,output_file):
#examples就是InputExamples的一个对象,没有方法,有四个属性,guid,text_a,text_b,label
#label_list在Mrpc任务中是['0','1']
#max_seq_length 默认是128
#output_file就是将examples转换成tfrecord后的文件输出位置
writer=tf.python_io.TFRecordWriter(output_file)
for (example_index,example) in enumerate(examples):
feature=convert_single_example(example_index,example,label_list,max_seq_length,tokenizer)
#传进去的example_index的作用是为了打印所有example的前五个,这就是当运行源码时一开始会看到一大堆的输出信息,包括input_ids,input_mask,segment_ids,你也可以在convert_single_example中去掉打印的那几行
#convert_single_example返回的feature是就是INputFeatures的一个对象,没有方法,有几个属性包括input_ids,input_mask,segment_ids,label_id,is_real_example等
'''接下来注意,我上面写的那个代码块和源码中492-502行意思是一样的,
目的就是构造一个dict,dict的key是特征名字,对应的value就是由tf.train.Feature创造的特征
然后将构造的dict作为features生成tf.train.Example的一个实例
'''
tf_example=tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
#每生成一个特征feature,就把它转成tfrecord格式,序列化后写入文件
writer.close()
#上面就已经生成了tfrecord文件,接下来就是读取
def file_based_input_fn_builder(input_file,max_seq_length,is_training,drop_remainder):
#input_file就是上一个函数的output_file
'''读取的API主要是tf.data.TFRecordDataset(file_path),返回的是一个dataset,
这个dataset中的每一个元素就是序列化的一个tf_example,我们要把它解析回原来的类型
(详细说明下就是由convert_single_example返回的feature是一个对象,有几个属性(此时打印feature可以看到几个列表),利用这些属性值将feature改造成一个tf.train.Feature(此时打印feature就是json格式的类型,key就是名,value就是列表),然后序列化转成tfrecord(此时打印feature就是一个二进制字符串),
这就是生成tfrecord的过程,读取tfrecord文件再解析回原来列表型的数据类型就是相应的逆过程)
'''
def input_fn(params):
dataset=tf.data.TFRecordDataset(input_file)
dataset=dataset.apply(tf.contrib.data.map_and_batch(
lambda record:decode_record(record,name_to_features),
batch_size=params["batch_size"],
drop_remainder=drop_remainder))
return dataset
#dataset.apply()的作用就是将dataset放到tf.contrib.data.map_and_batch()中,map_and_batch()会将dataset中的数据以batch_size的个数拿出来放到decode_record中,
#如果最后的部分不足batch_size,drop_remainder=True的意思就是去掉这部分。
#我们刚才说了,你从dataset中读取出来的是二进制字符串,需要把它解析回原来的列表格式才能送进网络,decode_record的作用就是解析
def decode_record(record,name_to_features):
example=tf.parse_single_example(record,name_to_features)
#parse_single_example就是解析record,name_to_features的作用就是告诉函数record中的每个值原来是什么类型的
'''
name_to_features={
"input_ids":tf.FixedLenFeature([seq_length],tf.int64)
}
'''
#意思就是说record中name为input_ids的数据原来的格式是一个长度为seq_length的列表,类型是tf.int64
return example#只不过源码中把所有tf.int64类型的数据都转成tf.int32的数据
return input_fn
【BERT源码分析之数据预处理部分】好的现在回到868行
train_file=os.path.join(FLAGS.output_dir,"train.tf_record")
#output_dir就是你在终端中输入的参数output_dir,并且程序执行过程中你也看到了Writting example %d of %d,对应487行
#这就是说模型在向output_dir中写入tf_record文件.
file_based_convert_examples_to_features(train_examples,label_list,max_seq_length,tokenizer,train_file)#生成tfrecord数据写到train_file中
train_input_fn=file_based_input_fn_builder(train_file,max_seq_length)#读取tfrecord文件,解析回原来的数据格式,返回的是函数
estimator.train(train_input_fn,num_train_steps)
#这就是整个数据处理的流程
回头总结一下数据处理的流程
- 首先自己写一个MyProcessor,然后在790行那里加上名字。你的MyProcessor主要的方法有get_train_examples,get_dev_examples,get_test_examples.这三个方法返回的是InputExamples的一个对象,有四个属性,guid,这个不重要,text_a,text_b,label,后三个就是典型的Mrpc任务需要的数据,即句子1,句子2,标签。
- 然后把get_train_examples返回的实例传入到file_based_convert_examples_to_features中生成tfrecord文件。
- file_based_input_fn_builder会读取tfrecord文件,返回一个函数给estimator.
- 剩下的就是细节,尤其是tokenization.py文件,里面的函数几乎要全部了解。(注意我只是介绍了数据处理的流程)
推荐阅读
- 如何寻找情感问答App的分析切入点
- D13|D13 张贇 Banner分析
- 自媒体形势分析
- 2020-12(完成事项)
- Android事件传递源码分析
- Python数据分析(一)(Matplotlib使用)
- Quartz|Quartz 源码解析(四) —— QuartzScheduler和Listener事件监听
- 泽宇读书会——如何阅读一本书笔记
- Java内存泄漏分析系列之二(jstack生成的Thread|Java内存泄漏分析系列之二:jstack生成的Thread Dump日志结构解析)
- [源码解析]|[源码解析] NVIDIA HugeCTR,GPU版本参数服务器---(3)