使用pmml实现跨平台部署机器学习模型
一、概述
??对于由Python训练的机器学习模型,通常有pickle和pmml两种部署方式,pickle方式用于在python环境中的部署,pmml方式用于跨平台(如Java环境)的部署,本文叙述的是pmml的跨平台部署方式。
??PMML(Predictive Model Markup Language,预测模型标记语言)是一种基于XML描述来存储机器学习模型的标准语言。如,对在Python环境中由sklearn训练得到的模型,通过sklearn2pmml模块可将它完整地保存为一个pmml格式的文件,再在其他平台(如java)中加载该文件进行使用,从而实现模型的跨平台部署。
文章图片
二、实现步骤
?1.训练环境中安装生成pmml文件的工具。
??如在Python环境中安装sklearn2pmml模块(pip install sklearn2pmml)。
?2.训练模型。
?3.将模型保存为pmml文件。
?4.部署环境中导入依赖的工具包。
??如在Java环境中导入pmml-evaluator、pmml-evaluator-extension(特殊情况下另加)、jaxb-core、jaxb-api、jaxb-impl等jar包。
?5.开发应用,加载、使用模型。
注:对sklearn2pmml生成的pmml模型文件,在java中加载使用时,需将文件中的命名空间属性xmlns=".../PMML-4_4"改为xmlns=".../PMML-4_3",以适应低版本的jar包对它的解析。
三、示例 ??在python中使用sklearn训练一个线性回归模型,并在java环境中部署使用。
工具:PyCharm-2017、Python-39、sklearn2pmml-0.76.1;IntelliJ IDEA-2018、jdk-14.0.2。
1.训练数据集training_data.csv
文章图片
2.训练、保存模型
import sklearn2pmml as pmml
from sklearn2pmml import PMMLPipeline
from sklearn import linear_model as lm
import os
import pandas as pddef save_model(data, model_path):
pipeline = PMMLPipeline([("regression", lm.LinearRegression())]) #定义模型,放入pipeline管道
pipeline.fit(data[["x"]], data["y"]) #训练模型,由数据中第一行的名称确定自变量和因变量
pmml.sklearn2pmml(pipeline, model_path, with_repr=True) #保存模型if __name__ == "__main__":
data = https://www.it610.com/article/pd.read_csv("training_data.csv")
model_path = model_path = os.path.dirname(os.path.abspath(__file__)) + "/my_example_model.pmml"
save_model(data, model_path)
print("模型保存完成。")
3.将pmml文件的xmlns属性修改为PMML-4_3
文章图片
4.java程序中加载、使用模型
(1)创建maven项目,将pmml模型文件拷贝至项目根目录下。
(2)加入依赖包
org.jpmml
pmml-evaluator
1.4.15
com.sun.xml.bind
jaxb-core
2.2.11
javax.xml
jaxb-api
2.1
com.sun.xml.bind
jaxb-impl
2.2.11
(3)java程序加载模型完成预测
public class MLPmmlDeploy {
public static void main(String[] args) {String model_path = "./my_example_model.pmml";
//模型路径
int x = 20;
//测试的自变量值Evaluator model = loadModel(model_path);
//加载模型
Object r = predict(model, x);
//预测Double result = Double.parseDouble(r.toString());
System.out.println("预测的结果为:" + result);
}private static Evaluator loadModel(String model_path){
PMML pmml = new PMML();
//定义PMML对象
InputStream inputStream;
//定义输入流
try {
inputStream = new FileInputStream(model_path);
//输入流接到磁盘上的模型文件
pmml = PMMLUtil.unmarshal(inputStream);
//将输入流解析为PMML对象
}catch (Exception e){
e.printStackTrace();
}ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
//实例化一个模型构造工厂
Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
//将PMML对象构造为Evaluator模型对象return evaluator;
}private static Object predict(Evaluator evaluator, int x){
Map data = https://www.it610.com/article/new HashMap();
//定义测试数据Map,存入各元自变量
data.put("x", x);
//键"x"为自变量的名称,应与训练数据中的自变量名称一致
List inputFieldList = evaluator.getInputFields();
//得到模型各元自变量的属性列表Map arguments = new LinkedHashMap();
for (InputField inputField : inputFieldList) { //遍历各元自变量的属性列表
FieldName inputFieldName = inputField.getName();
Object rawValue = https://www.it610.com/article/data.get(inputFieldName.getValue());
//取出该元变量的值
FieldValue inputFieldValue = inputField.prepare(rawValue);
//将值加入该元自变量属性中
arguments.put(inputFieldName, inputFieldValue);
//变量名和变量值的对加入LinkedHashMap
}Map results = evaluator.evaluate(arguments);
//进行预测
List targetFieldList = evaluator.getTargetFields();
//得到模型各元因变量的属性列表
FieldName targetFieldName = targetFieldList.get(0).getName();
//第一元因变量名称
Object targetFieldValue = https://www.it610.com/article/results.get(targetFieldName);
//由因变量名称得到值return targetFieldValue;
}}
【使用pmml实现跨平台部署机器学习模型】
文章图片
示例下载:
https://download.csdn.net/download/Albert201605/45645889
End.
参考
- https://www.freesion.com/article/4628411548/
- https://www.cnblogs.com/pinard/p/9220199.html
- https://www.cnblogs.com/moonlightpoet/p/5533313.html
推荐阅读
- 由浅入深理解AOP
- 【译】20个更有效地使用谷歌搜索的技巧
- 关于QueryWrapper|关于QueryWrapper,实现MybatisPlus多表关联查询方式
- mybatisplus如何在xml的连表查询中使用queryWrapper
- MybatisPlus|MybatisPlus LambdaQueryWrapper使用int默认值的坑及解决
- MybatisPlus使用queryWrapper如何实现复杂查询
- python学习之|python学习之 实现QQ自动发送消息
- 孩子不是实现父母欲望的工具——林哈夫
- opencv|opencv C++模板匹配的简单实现
- Node.js中readline模块实现终端输入