NLP | BERT (5) --- Hands-on: BERT + CNN Text Classification

1. Model overview

Take BERT's final token-level output (get_sequence_output), i.e. the output of the last Transformer layer, and use it as embedding_inputs, the input to the convolution. Three kernel sizes, [2, 3, 4], are used for convolution and pooling, and the three results are concatenated; there are 128 filters for each kernel size. The convolution flow is similar to the figure below, though not identical: the figure shows the same three kernel sizes [2, 3, 4] but with only 2 filters per size.
[Figure: TextCNN-style convolution and max-pooling with kernel sizes [2, 3, 4], 2 filters per size]

After convolution and pooling, the concatenated output has shape [batch_size, num_filters * len(filter_sizes)] = [batch_size, 128 * 3]. This output is then passed through two fully connected layers, the second of which produces the classification logits.
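The architecture can be sketched roughly as follows, written against the TF1-style modeling module from the BERT source that the rest of this post uses. The function name bert_cnn_model, the size of the first fully connected layer (256), and the dropout placement are illustrative assumptions, not the author's exact implementation.

import tensorflow as tf
import modeling  # modeling.py from the google-research/bert source

def bert_cnn_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                   num_labels, filter_sizes=(2, 3, 4), num_filters=128, keep_prob=0.9):
    bert = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids)
    # Token-level output of the last Transformer layer: [batch_size, seq_len, hidden_size]
    embedding_inputs = bert.get_sequence_output()

    pooled_outputs = []
    for size in filter_sizes:
        with tf.variable_scope("conv_%d" % size):
            # 1-D convolution over the token dimension, then max-pooling over time
            conv = tf.layers.conv1d(embedding_inputs, num_filters, size,
                                    activation=tf.nn.relu)
            pooled_outputs.append(tf.reduce_max(conv, axis=1))
    # Concatenate: [batch_size, num_filters * len(filter_sizes)] = [batch_size, 384]
    h_pool = tf.concat(pooled_outputs, axis=-1)

    # Two fully connected layers; the second one produces the classification logits.
    with tf.variable_scope("fc"):
        hidden = tf.layers.dense(h_pool, 256, activation=tf.nn.relu)  # 256 is an assumed size
        hidden = tf.nn.dropout(hidden, keep_prob=keep_prob)
        logits = tf.layers.dense(hidden, num_labels)
    return logits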
2. Data processing and training

Because the data is first fed into BERT and BERT's output is then fed into the CNN, the data has to be converted into the format that BERT accepts. Most of the processing here therefore follows the data-handling code in the BERT source.
As in the BERT fine-tuning post, we first define a class to handle the raw data. It is mainly responsible for loading the training, dev, and test sets and for defining the classification labels:

class TextProcessor(object):
    """Load the datasets in the InputExample format."""

    def get_train_examples(self, data_dir):
        """Load train examples."""
        return self._create_examples(
            self._read_file(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """Load dev examples."""
        return self._create_examples(
            self._read_file(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """Load test examples."""
        return self._create_examples(
            self._read_file(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """Set the classification labels."""
        return ['sport', 'military', 'aerospace', 'car', 'business', 'chemistry',
                'construction', 'culture', 'electric', 'finance', 'geology', 'it',
                'law', 'mechanical', 'medicine', 'tourism']

    def _read_file(self, input_file):
        """Read a tab-separated file of (label, text) lines."""
        with codecs.open(input_file, "r", encoding='utf-8') as f:
            lines = []
            for line in f.readlines():
                try:
                    line = line.split('\t')
                    assert len(line) == 2
                    lines.append(line)
                except:
                    pass
            np.random.shuffle(lines)
            return lines

    def _create_examples(self, lines, set_type):
        """Create InputExample objects for the data set."""
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[1])
            label = tokenization.convert_to_unicode(line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples

The examples obtained above only hold BERT's text_a, text_b, and label, which is not yet something the model can consume. They have to be converted into the form [CLS] + tokenized text_a + [SEP] + tokenized text_b before being fed to BERT. We therefore define convert_examples_to_features, which converts every InputExample into the token form the model expects and outputs the four variables BERT needs: input_ids, input_mask, segment_ids, and label_ids.
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
    label_map = {}
    for (i, label) in enumerate(label_list):
        label_map[label] = i

    input_data = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)
        if ex_index % 10000 == 0:
            tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

        # Reserve two positions for [CLS] and [SEP]
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[0:(max_seq_length - 2)]

        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the maximum sequence length
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        label_id = label_map[example.label]
        if ex_index < 3:
            tf.logging.info("*** Example ***")
            tf.logging.info("guid: %s" % (example.guid))
            tf.logging.info("tokens: %s" % " ".join(
                [tokenization.printable_text(x) for x in tokens]))
            tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            tf.logging.info("label: %s (id = %d)" % (example.label, label_id))

        features = collections.OrderedDict()
        features["input_ids"] = input_ids
        features["input_mask"] = input_mask
        features["segment_ids"] = segment_ids
        features["label_ids"] = label_id
        input_data.append(features)

    return input_data
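Putting the two pieces together, a hypothetical end-to-end call might look like the following. The vocabulary path, data directory, and max_seq_length of 128 are illustrative values; FullTokenizer comes from the BERT repo's tokenization module.

processor = TextProcessor()
label_list = processor.get_labels()
# vocab.txt ships with the released BERT checkpoint; the path here is illustrative.
tokenizer = tokenization.FullTokenizer(
    vocab_file="chinese_L-12_H-768_A-12/vocab.txt", do_lower_case=True)
train_examples = processor.get_train_examples("data/")
train_data = convert_examples_to_features(
    train_examples, label_list, max_seq_length=128, tokenizer=tokenizer)
# Each element of train_data is an OrderedDict holding the four fields BERT needs:
# input_ids, input_mask, segment_ids and label_ids.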

Next, load the saved pre-trained BERT model and start training. Since the model is fed one batch at a time, the processed data also has to be packed into batches; the training loop looks like this:
# best_acc, last_improved and flag are initialised before this loop
for epoch in range(config.num_epochs):
    batch_train = batch_iter(train_data, config.batch_size)
    start = time.time()
    tf.logging.info('Epoch:%d' % (epoch + 1))
    for batch_ids, batch_mask, batch_segment, batch_label in batch_train:
        feed_dict = feed_data(batch_ids, batch_mask, batch_segment, batch_label,
                              config.keep_prob)
        _, global_step, train_summaries, train_loss, train_accuracy = session.run(
            [model.optim, model.global_step, merged_summary, model.loss, model.acc],
            feed_dict=feed_dict)
        tf.logging.info('step:%d' % (global_step))
        if global_step % config.print_per_batch == 0:
            end = time.time()
            val_loss, val_accuracy = evaluate(session, dev_data)
            merged_acc = (train_accuracy + val_accuracy) / 2
            if merged_acc > best_acc:
                saver.save(session, save_path)
                best_acc = merged_acc
                last_improved = global_step
                improved_str = '*'
            else:
                improved_str = ''
            tf.logging.info(
                "step: {}, train loss: {:.3f}, train accuracy: {:.3f}, "
                "val loss: {:.3f}, val accuracy: {:.3f}, "
                "training speed: {:.3f}sec/batch {}".format(
                    global_step, train_loss, train_accuracy, val_loss, val_accuracy,
                    (end - start) / config.print_per_batch, improved_str))
            start = time.time()

        # Early stopping: no improvement for config.require_improvement steps
        if global_step - last_improved > config.require_improvement:
            tf.logging.info("No optimization over 1500 steps, stop training")
            flag = True
            break
    if flag:
        break
    config.lr *= config.lr_decay
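The batch_iter and feed_data helpers used in the loop are not shown in the post. Below is a minimal sketch of what they might look like, where the feature keys follow the dictionaries built by convert_examples_to_features and the placeholder names on model (model.input_ids, model.keep_prob, etc.) are assumptions.

import numpy as np

def batch_iter(data, batch_size):
    # Shuffle the feature dicts and yield them one batch at a time.
    indices = np.random.permutation(len(data))
    for start in range(0, len(data), batch_size):
        batch = [data[i] for i in indices[start:start + batch_size]]
        yield ([d["input_ids"] for d in batch],
               [d["input_mask"] for d in batch],
               [d["segment_ids"] for d in batch],
               [d["label_ids"] for d in batch])

def feed_data(batch_ids, batch_mask, batch_segment, batch_label, keep_prob):
    # Placeholder names on `model` are assumed; they are not shown in the post.
    return {
        model.input_ids: batch_ids,
        model.input_mask: batch_mask,
        model.segment_ids: batch_segment,
        model.labels: batch_label,
        model.keep_prob: keep_prob,
    }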

3. Serving

The model is served with Tornado, as shown below:
class ClassifierHandler(tornado.web.RequestHandler):
    def post(self):
        error_logger = Logger(
            loggername="bert_cnn_error_" + time.strftime("%Y_%m_%d", time.localtime()),
            logpath=LOGPATH + "bert_cnn_error_" + time.strftime("%Y_%m_%d", time.localtime()) + ".log").log()
        try:
            segment = self.get_argument("segment")
            language = self.get_argument("lang")
            sentence_data = json.loads(segment)
            results = predict.run(sentence_data, language)
            self.write(json.dumps(results))
        except Exception as e:
            self.write(repr(e))
            error_logger.error("error message:%s" % repr(e))
            error_logger.error("error position:%s" % traceback.format_exc())

    def write_error(self, status_code, **kwargs):
        self.write("errors: %d." % status_code)


if __name__ == "__main__":
    tornado.options.parse_command_line()
    app = tornado.web.Application(handlers=[(r"/classifier", ClassifierHandler)])
    http_server = tornado.httpserver.HTTPServer(app)
    http_server.listen(options.port)
    tornado.ioloop.IOLoop.instance().start()
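A hypothetical client call against this handler might look like the following; the port 8000 and the assumption that predict.run accepts a JSON-encoded list of sentences are illustrative, since neither is shown in the post.

import json
import requests

payload = {
    "segment": json.dumps(["A short news snippet about a basketball game."]),
    "lang": "en",
}
resp = requests.post("http://localhost:8000/classifier", data=payload)
print(resp.text)  # JSON-encoded prediction result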

4. Other details

The data covers 16 classes ['sport', 'military', 'aerospace', 'car', 'business', 'chemistry', 'construction', 'culture', 'electric', 'finance', 'geology', 'it', 'law', 'mechanical', 'medicine', 'tourism'], with roughly 100 million samples in total.
Chinese BERT model: 12-layer, 768-hidden, 12-heads, 110M parameters
English BERT model: 12-layer, 768-hidden, 12-heads, 110M parameters (a rough parameter-count check is sketched below)
Entry point (main program): text_run.py
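As a quick sanity check on the quoted "110M parameters", here is a back-of-the-envelope count for the 12-layer, 768-hidden configuration, using the English model's 30522-token vocabulary (the Chinese model has a smaller vocabulary, so its exact total is somewhat lower):

# Rough parameter count for a 12-layer, 768-hidden, 12-head BERT (English vocab = 30522).
hidden, layers, vocab, max_pos, ffn = 768, 12, 30522, 512, 3072

embeddings = (vocab + max_pos + 2) * hidden        # token + position + segment embeddings
attention = 4 * (hidden * hidden + hidden)         # Q, K, V and output projections
ffn_block = hidden * ffn + ffn + ffn * hidden + hidden
per_layer = attention + ffn_block                  # layer norms omitted (~3k each)
pooler = hidden * hidden + hidden

total = embeddings + layers * per_layer + pooler
print(total)  # about 109.4 million, i.e. the quoted "110M parameters"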
