引入
- 上一篇文章介绍了如何使用Paddle2.0构建了GPT-2模型
- 本次就使用之前构建好的模型加载清源CPM-LM模型参数来实现简单的问答机器人
- 支持问答和古诗默写两个模式
文章图片
- 可以在百度AIStudio平台上快速体验这个项目:链接
- 清源 CPM (Chinese Pretrained Models) 是北京智源人工智能研究院和清华大学研究团队合作开展的大规模预训练模型开源计划
- 清源计划是以中文为核心的大规模预训练模型
- 首期开源内容包括预训练中文语言模型和预训练知识表示模型,可广泛应用于中文自然语言理解、生成任务以及知识计算应用
- 所有模型免费向学术界和产业界开放下载,供研究使用
- 清源 CPM 项目官网、Github
- 通过下面的代码就可以使用之前构建的GPT-2模型来加载CPM-LM的参数了
- 适用于Paddle平台的CPM-LM 模型参数可以在这里下载
- 由于官方提供的模型参数为FP16半精度储存,所以加载时需要提前将参数转换为FP32格式
- 其他地方与加载普通模型并无差别
import paddle
from GPT2 import GPT2Model# 初始化GPT-2模型
model = GPT2Model(
vocab_size=30000,
layer_size=32,
block_size=1024,
embedding_dropout=0.0,
embedding_size=2560,
num_attention_heads=32,
attention_dropout=0.0,
residual_dropout=0.0)# 读取CPM-LM模型参数(FP16)
state_dict = paddle.load('CPM-LM.pdparams')# FP16 -> FP32
for param in state_dict:
state_dict[param] = state_dict[param].astype('float32')# 加载CPM-LM模型参数
model.set_dict(state_dict)# 将模型设置为评估状态
model.eval()
问答机器人实现方式
- CPM-LM有着不错的few-shot文本生成的能力,即可以通过输入几个样例,然后能够学习样例进行对应的文本生成,就像下面这样:
- ps. 下面的代码仅为演示,其中的model指代的不是上面的那个model
inputs = '''默写古诗:
日照香炉生紫烟,遥看瀑布挂前川。
飞流直下三千尺,'''
outputs = model.predict(inputs, max_len=10, end_word='\n')
print(inputs+outputs)
默写古诗:
日照香炉生紫烟,遥看瀑布挂前川。
飞流直下三千尺,疑是银河落九天。
inputs = '''问题:西游记是谁写的?
答案:'''
outputs = model.predict(inputs, max_len=10, end_word='\n')
print(inputs+outputs)
问题:西游记是谁写的?
答案:吴承恩。
inputs = '''小明决定去吃饭,小红继续写作业
问题:去吃饭的人是谁?
答案:'''
outputs = model.predict(inputs, max_len=10, end_word='\n')
print(inputs+outputs)
小明决定去吃饭,小红继续写作业
问题:去吃饭的人是谁?
答案:小明
inputs = '''默写英文:
狗:dog
猫:'''
outputs = model.predict(inputs, max_len=10, end_word='\n')
print(inputs+outputs)
【NLP|使用GPT-2加载CPM-LM模型实现简单的问答机器人】默写英文:
狗:dog
猫:cat
- 所以只需要通过拼接几个简单的few-shot预测函数,就可以实现一个简单的问答机器人
- 下面通过代码简单了解一下程序的运行流程
- 具体的代码详情参考本人的GitHub项目CPM-Generate-Paddle
- 与官方开源项目使用的采样的解码方式不同
- 本项目解码时使用到的是最简单的Greedy Search,所以相同输入对应的输出是唯一的
- 于是乎本项目不太适合生成文章类的文本,因为生成的结果过于固定
import paddle
import argparse
import numpy as np
from GPT2 import GPT2Model, GPT2Tokenizer# 参数设置
parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_model", type=str, required=True, help="the detection model dir.")
args = parser.parse_args()# 初始化GPT-2模型
model = GPT2Model(
vocab_size=30000,
layer_size=32,
block_size=1024,
embedding_dropout=0.0,
embedding_size=2560,
num_attention_heads=32,
attention_dropout=0.0,
residual_dropout=0.0)print('正在加载模型,耗时需要几分钟,请稍后...')# 读取CPM-LM模型参数(FP16)
state_dict = paddle.load(args.pretrained_model)# FP16 -> FP32
for param in state_dict:
state_dict[param] = state_dict[param].astype('float32')# 加载CPM-LM模型参数
model.set_dict(state_dict)# 将模型设置为评估状态
model.eval()# 加载编码器
tokenizer = GPT2Tokenizer(
'GPT2/bpe/vocab.json',
'GPT2/bpe/chinese_vocab.model',
max_len=512)# 初始化编码器
_ = tokenizer.encode('_')print('模型加载完成.')# 基础预测函数
def predict(text, max_len=10):
ids = tokenizer.encode(text)
input_id = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
output, cached_kvs = model(input_id, use_cache=True)
nid = int(np.argmax(output[0, -1].numpy()))
ids += [nid]
out = [nid]
for i in range(max_len):
input_id = paddle.to_tensor(np.array([nid]).reshape(1, -1).astype('int64'))
output, cached_kvs = model(input_id, cached_kvs, use_cache=True)
nid = int(np.argmax(output[0, -1].numpy()))
ids += [nid]
# 若遇到'\n'则结束预测
if nid==3:
break
out.append(nid)
print(tokenizer.decode(out))# 问答
def ask_question(question, max_len=10):
predict('''问题:中国的首都是哪里?
答案:北京。
问题:李白在哪个朝代?
答案:唐朝。
问题:%s
答案:''' % question, max_len)# 古诗默写
def dictation_poetry(front, max_len=10):
predict('''默写古诗:
白日依山尽,黄河入海流。
%s,''' % front, max_len)# 主程序
mode = 'q'
funs = ask_question
print('输入“切换”更换问答和古诗默写模式,输入“exit”退出')
while True:
if mode == 'q':
inputs = input("当前为问答模式,请输入问题:")
else:
inputs = input("当前为古诗默写模式,请输入古诗的上半句:")
if inputs=='切换':
if mode == 'q':
mode = 'd'
funs = dictation_poetry
else:
mode = 'q'
funs = ask_question
elif inputs=='exit':
break
else:
funs(inputs)
总结
- 通过简单实现的两个few-shot的预测函数,就能构建这样的一个简单的问答机器人,实现问答和古诗默写的功能
- 通过这个例子可以看出,CPM-LM模型的few-shot文本生成能力还是不错的,甚至zero-shot的表现也不错,这样的一个超大的中文预训练模型,确实有那么一点GPT-3的味道了
推荐阅读
- python继承的特征有哪些()
- python之uWSGI和WSGI
- 渗透测试领域.|Python 开发 利用SQLmap API接口进行批量的SQL注入检测.(SRC挖掘)
- #yyds干货盘点#如何用Python发送告警通知到钉钉()
- 我用Python抓取了S11全球总决赛直播评论,EDG nb
- 计算机|.NET Core中JWT+Auth2.0实现SSO,附完整源码(.NET6)
- python基础教程|[python] 函数的缺省参数和注意事项
- python|python hashlib_Python常用模块之hashlib
- python|python hacklib模块_python常用模块——hashlib模块