技术实践干货 | 初探大规模 GBDT 训练

本文作者: 字节,观远数据首席科学家。主导多个AI项目在世界500强的应用落地,多次斩获智能零售方向Hackathon冠军。曾就职于微策略,阿里云,拥有十多年的行业经验。
本文是此前评估在 Spark 上做大规模 GBDT 训练时写的一篇入门级教程与框架评估。目前市面上似乎没有多少使用 Spark 来跑 GBDT 的分享,故分享出来看看是否有做过类似场景的同学可以一道交流。
背景 在服务一些客户做商业问题的机器学习建模时,我们会碰到不少拥有非常大量数据且对模型 pipeline 运行有一定要求的情况。相比直接的单机 Python 建模,这类项目有一些难点:
1. 数据量大。 由于预测粒度较细,导致历史数据量非常巨大。一些场景的 pilot 项目中已经达到近千万级别的训练数据量,后续拓展到整个业务线,数据量会超过十亿甚至百亿行级别。
2. 整体流程运行时间有一定要求。 一般模型所依赖的上游数据会在半夜开始通过一系列 ETL 任务从业务系统导入到 Hive 数仓中,大约在凌晨 3 点后,各类预测所需数据会准备就绪。接下来运行整个取数,清洗,特征,训练,预测,业务系统对接全流程,需要在早上 8 点前完成并下发到业务系统中,整体运行时间必须控制在 5 小时以内。
3. 对监控运维等方面的高要求。 海量数据细粒度的预测,覆盖非常多业务人员的日常工作需求,因而也会受到更多的审视与挑战。如何确保模型预测输出的稳定性,在业务反馈问题后又如何快速定位排查,遇到数据不可用,服务器 down 机等突发情况,有什么样的备选方案确保整体流程的稳定运行,都是需要考虑的问题。
从前两点来看,我们之前习惯的单机 Pandas + lgb/xgb 建模思路已经难以适用(除非搞台神威·太湖之光之类的机器),所以我们需要引入目前大数据界的当红炸子鸡 -- Spark 来协助完成此类项目。
部署 Spark 要玩 Spark,第一步是部署。如果是本机测试运行,一般跑一个 pip install pyspark 就能把一个 local 节点跑起来了,非常的方便。如果想部署一个相对完整一点的 standalone 集群,可以参考以下步骤:
  1. 到 Spark 官方网站[1] 下载 Spark。当时最新的稳定版本是 2.4.5,下载 pre-built for Apache Hadoop 2.7 就好。
  2. 解压缩,做一些简单的配置文件配置。在解压开的 spark 目录下,进入到 conf 目录里,会看到一系列配置的 template。把需要自定义的配置 copy 一份,例如:cp spark-env.sh.template spark-env.sh,然后进行编辑。我改的一些配置具体如下:
spark-env.sh # 配置 master 和 workerSPARK_MASTER_HOST=0.0.0.0 SPARK_DAEMON_MEMORY=4g SPARK_WORKER_CORES=6 SPARK_WORKER_MEMORY=36g

slaves # 指定 slaves 机器的列表,这里就选了本机localhost

spark-defaults.conf # 这个文件很多教程都会让你改,说是 spark-submit 命令会默认从这里读取相关配置 # 但要注意我们写的 PySpark 程序很多时候并不是通过 spark-submit 命令提交的,所以这里改了可能没用spark.driver.memory 4g

  1. 启动集群。直接运行 sbin/start-all.sh即可。或者也可以分别起 master 和 slave,运行./sbin/start-master.sh和 ./sbin/start-slave.sh spark://127.0.0.1:7077 -c 6 -m 36G 即可。
  2. 停止集群。命令与上面非常类似,sbin/stop-all.sh ,或者分别停 slave 和 master 都行。
这样就算部署完了!其中 spark master 会有一个监听 8080 端口的 web-ui,worker 会监听 8081,后面提交 application 就会有监听 4040 端口的管理界面,功能强大,用户友好度强。
跑第一个 PySpark 程序 直接上代码:
from pyspark.sql import SparkSessionspark = (SparkSession.builder .master('spark://127.0.0.1:7077') .appName('zijie') .getOrCreate()) df = spark.read.parquet('data/the_only_data_i_ever_wanted.parquet') df.show()

我们的大数据平台就跑起来了。
Spark 与 Pandas 的一些不同之处 在网上看一些 Spark 相关的介绍应该很快会有一些认识。有几个比较明显的区别点我大致列一下:
  1. Spark 里对 DataFrame 的操作大多是 lazy 的,也就是所谓的 transformation,只有少数的 action,例如 take, count, collect 等会真实进行计算返回结果。而 pandas 只要做了操作就会立刻执行。
  2. Pandas 里对性能方面的关注主要是这个操作能不能利用底层的计算库做 vectorize,而在 Spark 里需要关注的点就太多了,可能比较主要的是看怎么尽量减少 shuffle 这类宽依赖吧。当然还有什么数据倾斜等相关高级话题。
  3. 用 Spark 来做算法相关的应用时,要非常注意整体的计算逻辑(数据 lineage),对需要反复用到的数据集,一定要记得 cache/persist/checkpoint 才行(这条不知是否过时了)。
从实际操作来看,在 PySpark 中其实有很多操作长得跟 Pandas 非常类似,比如我们常用的 df[df['date'] > '2020-01-01'] 之类的写法。当然区别也有不少,所以后来 Databricks 干脆推出了一个 Koalas的库来支持更平滑的切换。
Spark 特征工程 这里主要记录几个在项目过程中写的感觉比较好玩的,并对比 pandas 的版本方便大家理解。
日期填充
pandas version:
# 对每家店每个 SKU 历史无销售情况进行填零处理 def fill_dates(df): new_df = [] for store_id in df.store_id.unique(): for sku in df.query('store_id == @store_id').sku.unique(): tmp = pd.DataFrame() cond = (df.store_id == store_id) & (df.sku == sku) min_date = df.loc[cond, 'date'].min() max_date = df.loc[cond, 'date'].max() dates_in_between = daterange(min_date, max_date) tmp['date'] = dates_in_between tmp['sku'] = sku tmp['store_id'] = store_id new_df.append(tmp) new_df = pd.concat(new_df) new_df = new_df.merge(df, on=['date', 'sku', 'store_id'], how='left').fillna(0) return new_df

可以看到整体逻辑就是取所有 store, sku 的组合,然后找到每个组合最小最大的售卖日期,把中间的日期都填上。
PS: 这段代码应该效率不高,后续我们又迭代了几个版本。
Spark version:
from pyspark.sql import functions as Fdef fill_dates_spark(df): tmp = df.groupby(['store_id', 'sku']).agg(F.min('date').cast('date').alias('min_date'), F.max('date').cast('date').alias('max_date')) tmp = tmp.withColumn('date', F.explode(F.sequence('min_date', 'max_date'))).select( ['date', 'store_id', 'sku']) new_df = tmp.join(df, ['date', 'store_id', 'sku'], 'left').fillna(0, subset=['y']) return new_df

用了 sequence+explode 操作,代码简洁很多。其中 sequence 会自动生成从 start 到 end 的序列(时间,数字都支持),explode 操作直接把一行“炸开”成多行,省去了 join 操作,性能也更好。
Lag 特征
这个是我们最常用的一种特征了,在 pandas 里主要就是做循环 join:
def shift_daily_data(df, delay, shift_by='date', shift_value='https://www.it610.com/article/y'): groupby_df = [x for x in df.columns if (x != shift_by) and (x != shift_value)] shift_df = df.copy() shift_df[shift_by] = shift_df[shift_by].apply(lambda x: x + relativedelta(days=delay)) shift_df = shift_df.rename(columns={shift_value: '%s_%s_day_lag_%d' % ('_'.join(groupby_df), shift_value, delay)}) return shift_dfdef add_daily_shifts(df, days, categories, shift_by='date', shift_value='https://www.it610.com/article/y'): merge_df = df.copy() for base_categories in categories: feat_cols = base_categories + [shift_by] base_df = df.groupby(feat_cols, as_index=False).agg({shift_value: sum}) for i in days: delay_df = shift_daily_data(base_df, i, shift_by, shift_value) merge_df = pd.merge(left=merge_df, right=delay_df, how='left', on=feat_cols, sort=False).reset_index( drop=True).fillna(0) gc.collect() return merge_df# 按照不同维度生成 lag 自回归时序特征 def add_lag_features(all_data_df, fcst_type): lag_days = list(range(1, 11)) + [14, 21, 28, 29, 30, 31] lag_days = [x for x in lag_days if x >= fcst_type] groupby_cats = [['sku'], ['store_id'], ['sku', 'store_id']] all_data_df = add_daily_shifts(all_data_df, lag_days, groupby_cats) return all_data_df

在迁移到 Spark 时第一版我也采用了类似的写法,不过发现性能比较差,而且随着 lag 数的增多,join 次数也增多了,数据血缘关系会拉得非常长。
第二版我们采用了 window function 的写法:
from pyspark.sql import functions as F from pyspark.sql import Windowdef add_date_index(df, date_col, start_day='2016-01-01'): df = df.withColumn(f'{date_col}_index', F.datediff(date_col, F.lit(start_day))) return dfdef add_shifts_by_window(df, days, group_by, order_by='date_index', shift_value='https://www.it610.com/article/y'): # 取 lag 操作,其实就是要取一个时间点往前一个时间窗口中的值 # 然后这个窗口要考虑时间顺序,我们就加上 orderBy,需要分门店分 sku,我们就加上 partitionBy w = Window.orderBy(order_by).partitionBy(*group_by) new_col_prefix = f'{"_".join(group_by)}_{shift_value}_day_lag' # 再用lag函数取之前的值即可 new_cols = [F.coalesce(F.lag(shift_value, i).over(w), F.lit(0)).alias(f'{new_col_prefix}_{i}') for i in days] df = df.select('*', *new_cols) return df# 接下来主要就是调用了def add_daily_shifts_by_categories(df, days, categories, shift_by='date', shift_value='https://www.it610.com/article/y'): df = add_date_index(df, shift_by) shift_by = f'{shift_by}_index' cat_cols = ['store_id', 'sku'] merge_df = add_shifts_by_window(df, max(days), cat_cols, shift_by, shift_value) for base_categories in categories: if len(base_categories) < len(cat_cols): # 先聚合,再添加 lag 特征 feat_cols = base_categories + [shift_by] base_df = df.groupby(feat_cols).agg(F.sum(shift_value).alias(shift_value)) join_df = add_shifts_by_window(base_df, max(days), base_categories, shift_by, shift_value) join_df = join_df.drop(shift_value) merge_df = merge_df.join(join_df, feat_cols, 'left').fillna(0) return merge_df

用这个方法的前提是,先要把日期填充做了,否则 window 中的数值可能是不连续的。当时也考虑过不做填充可不可以?比如用 F.create_map 的方法创建出时间点与值的 map:df = df.withColumn('m', F.create_map('date_index', 'y')),然后用类似的 collect_list手法获取 window 中的多个 map,合并 map,然后按 lag 顺序取 key,取不到的就填 0 即可。其中合并 map 需要用 udf,大致如下:combineMap = udf(lambda maps: dict(ChainMap(*maps)), MapType(IntegerType(), DoubleType()))。
在实验中发现,这个 udf 使用过程中会报错,说 pandas udf 目前不支持在 window function 中使用,需要用 Spark 3.0 才行。所以暂时用了以上的方案。实测下来发现,用上了 window function,建 lag 特征的时间从 20 多分钟降到了 200 秒左右,而且不管建多少个 lag,时间基本都是一样的,可扩展性很强!
从这个例子中也可以看到,window function 结合 Spark SQL 中带的各种方法非常强大灵活。而到了 PySpark 这里,还有更加神奇的 pandas udf,光看官方示例[2] 就感觉操作性很强,感兴趣的同学可以到文末参考资料中点击查看。
类别编码
这个项目中我们用的是类似 frequency encoding 的手法,Pandas 代码如下:
def y_rank_transform(df, col_name, orderby, ascending=True): sorted_df = df.groupby(col_name).agg({orderby: np.sum}).reset_index().sort_values(orderby, ascending=ascending) rank_map = {v: i for i, v in enumerate(sorted_df[col_name].values)} df[col_name] = df[col_name].map(rank_map) return df, rank_mapdef convert_category_feats(full_df, category_features, orderby): # 根据 orderby 值的大小对 category_features 进行排序编码 rank_maps = {} for c in category_features: if c in full_df: full_df, rank_map = y_rank_transform(full_df, c, orderby) rank_maps[c] = rank_map gc.collect() return full_df, rank_maps

还是比较好理解的。然后 Spark 里可以直接用 pyspark.ml.feature 里自带的一些实现来帮助我们做类似的事情:
from pyspark.ml.feature import StringIndexer from pyspark.sql import functions as Fdef convert_category_feats(full_df): cat_cols = get_category_cols() cat_cols = [x for x in cat_cols if x in full_df.columns] # 根据 orderby 值的大小对 category_features 进行排序编码 for c in cat_cols: if c in full_df.columns: target_col = f'{c}_index' indexer = StringIndexer(inputCol=c, outputCol=target_col) model = indexer.fit(full_df) full_df = model.transform(full_df).withColumn(target_col, F.col(target_col).cast('int')) return full_df

所以有时候也可以没事浏览下标准库里的东西,说不定你想要的功能都已经有现成实现的。
Spark 模型训练? 构建完特征,就到了模型训练环节!特征构建之类,总体来说还是尽在掌握的感觉,但十亿级数据量的训练,就感觉有点心里发虚了。这部分一开始的工作主要由学弟负责。学弟经过一番调研,最终锁定了一个名为 mmlspark 的库:
技术实践干货 | 初探大规模 GBDT 训练
文章图片

初识 mmlspark
之前我们在不少场景用了 lgb,而这个 mmlspark 同是微软出品的框架,感觉应该稳了!
mmlspark 的安装问题
要尝试这个库,第一步肯定就是安装了!这个库的安装比较奇怪,没有提供 pypi/conda 安装包,官网上给出的用法是这样的:
import pyspark spark = pyspark.sql.SparkSession.builder.appName("MyApp") \ .config("spark.jars.packages", "com.microsoft.ml.spark:mmlspark_2.11:1.0.0-rc1") \ .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven") \ .getOrCreate() import mmlspark

但真正跑起来的时候,碰到了一系列网络问题,中间尝试了好久的更换 maven/ivy2 源等,都没有很好的解决。最后我们在 GCP 的服务器上跑了一下代码,顺利下好了所有的依赖,然后把 ~/.ivy2/cache下的所有文件打包回来到本地缓存文件夹解压开。Work like a charm!
mmlspark 的 early_stopping 问题
代码跑起来后,没过多久,学弟就碰到了第二个困难:
技术实践干货 | 初探大规模 GBDT 训练
文章图片

接口变了
上官方文档[3] 瞄了一眼,感觉不妙。一般文档写成这样的库,大概率使用的人很少。不过看支持的参数,比起 Spark ML 的 GBDT 还是丰富很多的,基本上应该是继承了原生 lgb 的接口。在这里我们的主要目标是通过 early_stopping 来做最基本的调参,这样可以保证模型运行时有比较可靠的表现。但是文档里根本没有提这个 early_stopping 应该怎么用,该怎么办呢?
遇到这种情况,一般就只能找 1)有没有别人的代码用了这个功能,2)源码里是怎么实现这个功能的。具体到这个问题,我就直接选择在 github 的 mmlspark repo 里搜 earlyStoppingRound这个参数:
技术实践干货 | 初探大规模 GBDT 训练
文章图片

代码搜索是个好方法
然后通过调用路径做几层追踪,就会看到相关的实现:
技术实践干货 | 初探大规模 GBDT 训练
文章图片

early stopping 相关逻辑
所以要用 early_stopping,需要几个条件:
  1. 设置 validationIndicatorCol参数
  2. 在训练数据中加一列上述参数指定的 column,以 true /false来指定训练集和验证集
  3. 设置好 earlyStoppingRound参数,需要大于 0
  4. 接下来就可以 fit model 了
通过类似的手段,我们解决了一系列因为文档和示例缺乏导致的使用困难,包括传入类别变量等。
mmlspark 训练数据要求
之前用原生 lgb 时,训练数据处理基本比较简单,直接用 lgb.dataset从 pandas/numpy 数据集进行构建即可。不过 mmlspark 就很不一样了,竟然要求传入 2 个 columns,一个叫 featuresCol,另一个叫 labelCol 。总不至于只支持 1 个特征吧?
转念一想,lightgbm 的 Python API 分了 native 和 sklearn 两套,那 mmlspark 这个 API 应该同理是为了符合 Spark ml 的标准。顺着这个思路,果然发现 Spark ml 都是这个套路,然后 Spark ml 库里也自带了一个类, 叫 pyspark.ml.feature.VectorAssembler,直接用上就能把需要的 feature columns 转换成 Vector 类型的单个 column 了。代码类似:
def vectorize(df, feat_cols): assembler = VectorAssembler(inputCols=feat_cols, outputCol='features') df = assembler.transform(df) return df

mmlspark 训练卡死的问题 1
终于一切代码就绪,开始跑训练了!没想到刚开始就出现了问题,训练启动后就一直没反应,看 Spark 的任务也完全没有进度,非常诡异。这个问题的排查绕了一些弯路,看了不少 mmlspark 的源码,尝试 callstack 的收集,strace 等,都没有理想的结果。最后还是从日志中发现了问题。
先截取一个酷炫的图,给大家看下在哪里看日志:
技术实践干货 | 初探大规模 GBDT 训练
文章图片

查看 Spark 任务日志
具体日志:
技术实践干货 | 初探大规模 GBDT 训练
文章图片

正常日志
上面这个是正常的日志,当时出错时发现 Spark executor 在启动 lgb 时报了一堆的错误:
NetworkInit failed with exception on local port...Retrying NetworkInit with local port...
然后可以用类似的方法,在 git repo 里搜这些错误信息,看到底是从哪里报出来的:
技术实践干货 | 初探大规模 GBDT 训练
文章图片

搜索到的错误来源代码
再接着往上排查几层,看了下 lgb 分布式训练的一些文档说明,就大致明白问题出在哪了。总体的 mmlspark 训练过程其实就是把分布在各个机器上的数据转化为lgb.dataset 形式,然后再各自起原生的 lightgbm 来训练。多节点训练时各个节点需要通过网络端口来进行同步,因此需要在启动时设定好大家各自的端口。而且 mmlspark 里用的是 mapPartitions 方法来做具体的训练,我是在单机上跑(当然只要一台机器上有多个 parition 都会有这个问题),所以就出现多个 partition 启动 lightgbm 时监听的端口冲突问题。要解决的话也比较简单,只需要 repartition 数据到每台服务器启动一个 lgb task 即可。
此外对于类似此类 MPI 的计算 load,官方还提供了一个新的 barrier execution mode 来解决一系列相关问题:
技术实践干货 | 初探大规模 GBDT 训练
文章图片

Barrier Mode
如何获取 best_iteration?
刚解决完训练卡死的问题,立刻又来了下一个问题。前面刚提到我们用 early_stopping 来寻找一个合适的树的数量参数,不过在 mmlspark 中用完 early_stopping 后,发现没有方法可以获取到这个 best_iteration 到底是多少?
搜了半天文档和代码,都没发现隐藏功能,只好在 git 上提了一个 issue[4],那会儿没人理,直到一年后终于支持了。如果我们要自力更生,怎么解决?
  1. mmlspark 的库里还是有不少方法,展现了如何调用原生 API,例如:
def getFeatureImportances(self, importance_type="split"): """ Get the feature importances as a list.The importance_type can be "split" or "gain". """ return list(self._java_obj.getFeatureImportances(importance_type))

这么简单吗?当然不是,还需要在 Scala 里转一层:
/** * Calls into LightGBM to retrieve the feature importances. * @param importanceType Can be "split" or "gain" * @return The feature importance values as an array. */ def getFeatureImportances(importanceType: String): Array[Double] = { val importanceTypeNum = if (importanceType.toLowerCase.trim == "gain") 1 else 0 if (boosterPtr == null) { LightGBMUtils.initializeNativeLibrary() boosterPtr = getBoosterPtrFromModelString(model) } val numFeaturesOut = lightgbmlib.new_intp() LightGBMUtils.validate( lightgbmlib.LGBM_BoosterGetNumFeature(boosterPtr, numFeaturesOut), "Booster NumFeature") val numFeatures = lightgbmlib.intp_value(numFeaturesOut) val featureImportances = lightgbmlib.new_doubleArray(numFeatures) LightGBMUtils.validate( lightgbmlib.LGBM_BoosterFeatureImportance(boosterPtr, -1, importanceTypeNum, featureImportances), "Booster FeatureImportance") (0 until numFeatures).map(lightgbmlib.doubleArray_getitem(featureImportances, _)).toArray }

所以要实现比如原生的 get_current_iteration方法,也得按照上面这个流程走一遍。
  1. 改代码还是太麻烦了,还需要考虑后续怎么持续维护跟主线不一样的代码(当然也可以直接成为项目 contributor)。所以我们还是另辟蹊径,通过已有的方法来绕过。
首先看到 mmlspark 实现了saveNativeModel方法,一看这个名字,应该会把模型存成 lgb native model。看了下代码应该没问题,就尝试存了一个。
接下来拿出我们的原生 lightgbm,来 load 这个存好的模型。因为是 Spark 存模型,还需要考虑分布式文件系统等问题,不过 mmlspark 也比较暴力,直接用了 coalesce(1) 加 write text 的方法来存模型,所以最终肯定就是一个文件啦!
读到原生模型后,取 current iteration 就易如反掌了!最后实现代码如下:
def get_native_lgb_model(file_path): txt_files = list(Path(file_path).glob('*.txt')) if len(txt_files) != 1: raise Exception('Aww...cannot read model file!') native_model = lgb.Booster(model_file=txt_files[0].as_posix()) return native_modeldef get_best_iteration(model, path_prefix='/share'): file_path = f'{path_prefix}/lgb_model' model.saveNativeModel(file_path) native_model = get_native_lgb_model(file_path) best_iteration = int(native_model.current_iteration() * 1.02) return best_iteration

性能优化
接下来经过了一阵风平浪静的开发日子,我们逐渐实现了一些参数搜索缓存,自动回测,与数据开发平台实现对接等功能,并逐渐把训练数据量提升到了一亿行。在这个阶段我们主要的目标是评估当数据量增长时,整体的性能变化,机器资源占用变化如何,进而产出对机器资源需求的规划方案来。如果所需的机器数量过多,就需要做一系列的优化控制整体成本。
前面有提到我们的整个数据获取,清洗,构建特征,模型训练预测,业务系统对接产出,必须在 5 个小时以内完成,这其中的每一个阶段要花的时间都要做好优化工作,确保没有明显的瓶颈点。
首先,优化的前提是监控,在本地集群和开发平台,我们都设计了相应的日志,用于抓取 pipeline 中每个阶段所需要花费的时间。另外开发平台部署了 Prometheus 和 Grafana,本地集群我们配备了 dstat, jstat, top 等脚本,主要用于监控 Spark, Python 相关进程的 cpu,内存使用情况,为整体的 capacity planning 做准备。
接着在监控的基础上,我们对各个 stage 做了相应的优化:
1. 数据获取。
  • 对于大数据量的表,我们采取了增量同步数据的方式。
  • 对所有 Hive 表的查询,都尽量走 partition key。例如销量表,如果不走 partition key 的查询,哪怕获取近 3 天数据,跟全量拉取的速度并没有多少区别。获取 partition key 的信息,可以通过 describe 语句进行查询。
2. 数据清洗。 这块目前速度都比较快。如果清洗逻辑没有前后依赖,可以适当并发进行。
3. 特征构建。 在模型训练完之后,我们可以获取到各个特征的重要度指标。我们整体的预测流程中,大约会构建 45 个不同的模型,基于这 45 个模型返回的特征重要度信息,我们制订了一个简单的特征筛选策略:
  • 记录每个模型中 feature importance top 10 的特征
  • 记录每个模型中 feature importance 为 0 的特征
  • 统计 b 中出现特征的频次,在频次最高的特征中,排除在 a 中出现过的特征,形成移除特征列表
  • 在特征构建阶段移除这些特征的生成操作
  • 后续特征的添加会考虑一个准入标准,同时考虑运行时间损耗和精度提升量
  • 也可以考虑将一些不依赖最新数据的固定特征预先计算生成好,节约运行窗口的时间
4. 模型训练。 以往的模型训练参数调优都主要以优化准确度为目标。在这个项目中,我们还需要考虑训练时间的问题。在 lgb 模型中,有一些参数与训练时间会有比较大的相关性,例如:
  • learningRate : 越大训练所需的轮次越少
  • numLeaves : 越小则每棵树越简单,但实际可能需要的数量越多
  • maxBin : 越小则训练速度越快,但会损失精度
  • baggingFraction 和 baggingFreq : 训练采样率,采样之后训练数据少了,自然速度就快了,可以控制每多少轮重新采样一次
  • featureFraction : 特征采样率,原理类似上一条。注意有些情况下这个参数设置为 1 效果才比较理想,一个典型的例子是 one hot encoding 后的数据,采样后可能导致类别信息的缺失
我们采取了一个比较简单的做法来做训练速度的优化,在原先随机搜索的基础上,除了记录模型的精度指标,我们还会一并记录训练所花费的时间。最后在做参数选择时,可以灵活选择可以接受的时间耗费,在训练时间小于这个要求的前提下,选取效果最优的参数。通过这一步优化,整体训练时间缩短了一半左右,而且训练精度并没有下降。
  1. 下游系统数据对接。这个目前改用 Spark 计算生成后,整体速度也非常快,基本在 1 分钟内完成,暂时没有优化的必要。
  2. 整体内存使用的优化。Spark 读入数据时为了不丢失精度,默认会用 bigint ,double等类型来存储数据,但在我们这个应用场景中, int , float类型就已经足够。因此可以做一些类型转换,节约内存占用和保存文件的大小。
经过一系列的优化工作,基本上可以达到使用 5 台 16c/64g 机器完成十亿级模型训练预测按时产出的需求。
mmlspark 训练卡死的问题 2
当训练数据扩充到一亿规模时,我们的 mmlspark 又出现了一个奇怪的卡顿问题。在训练过程中,这一亿数据并不是进入一个统一的大模型来训练,而是会根据策略引擎的规则,分发到不同的模型做训练。前面有提到模型的数量大约有 40+个,这其中有些模型分到的数据量会比较大,因而本身训练时间就比较慢。但随着训练流程的逐步进行,这个训练时间变得越来长,直至 task 失去响应。所以我们又启动了新一轮的排查流程。
系统资源检查
遇到卡顿,首先观察系统资源情况,例如 cpu, 内存,磁盘 io/空间,网络等。但运行过程中发现没有一个资源吃紧的情况,其中特别奇怪的是 cpu 使用率在 100%(机器是 8 核,正好用满一个 core),没有发挥所有的性能。
观察模型正常训练时的情况,Spark 启动的 lgb 会基本把所有 cpu 资源打满,因此怀疑是在进入训练之前的某些环节无法并行计算导致的问题。
JVM 排查工具介入
为了更好的追踪 jvm 内部情况,请出了 visualvm 。这是我多年前工作时用的主力排查工具。为了用上这个工具,需要对 Spark 的配置做一些修改:
  1. 配置 ${SPARK_HOME}/conf/metrics.properties文件,加上 jmx 相关的一些 sink
  2. 在启动任务时加上 jmx 相关的配置-Dcom.sun.management.jmxremote -Dcom.sun.management.jmxremote.authenticate=false -Dcom.sun.management.jmxremote.ssl=false -Dcom.sun.management.jmxremote.port=22990 ,注意很多文章都说要加在 spark-defaults.conf里,但是我们直接运行 Python 程序并不会调用 spark-submit命令。所以这些参数需要在程序内的 spark session 中指定:
metrics_conf = f"{spark_home}/conf/metrics.properties" jmx_conf = "-Dcom.sun.management.jmxremote -Dcom.sun.management.jmxremote.authenticate=false " \ "-Dcom.sun.management.jmxremote.ssl=false -Dcom.sun.management.jmxremote.port=22990" spark = (SparkSession.builder .master('spark://127.0.0.1:7077') .appName('zijie') .config("spark.executor.memory", "36g") .config("spark.driver.memory", "6g") .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven") .config("spark.jars.packages", "com.microsoft.ml.spark:mmlspark_2.11:1.0.0-rc1") .config("spark.metrics.conf", metrics_conf) .config("spark.executor.extraJavaOptions", jmx_conf) .getOrCreate())

配置好之后,启动应用,就能在 visualvm里添加 jmx 连接做监控了。我们获取到了卡顿时候的 cpu 使用情况截图如下:
技术实践干货 | 初探大规模 GBDT 训练
文章图片

单 CPU 开销
可以看出模型训练阶段,cpu 使用率都是在 80%上下波动,但模型训练的中间,总有一些只占用了 1 个 cpu 资源的时间段。而且这些 cpu 资源使用是黄色的正常工作线程,而不是垃圾回收。
接下来一个比较自然的思路就是在这些 cpu 使用低谷去获取 thread dump,看系统到底在忙什么。用jstack 或者 visualvm等工具都可以获取到。一个典型的 thread dump 如下所示:(截取了前面 10%的内容)
2020-02-20 16:46:33 Full thread dump OpenJDK 64-Bit Server VM (25.242-b08 mixed mode):"Barrier task timer for barrier() calls." - Thread t@2867 java.lang.Thread.State: WAITING at java.lang.Object.wait(Native Method) - waiting on <51997c78> (a java.util.TaskQueue) at java.lang.Object.wait(Object.java:502) at java.util.TimerThread.mainLoop(Timer.java:526) at java.util.TimerThread.run(Timer.java:505)Locked ownable synchronizers: - None"JMX server connection timeout 2854" - Thread t@2854 java.lang.Thread.State: TIMED_WAITING at java.lang.Object.wait(Native Method) - waiting on <3be7c557> (a [I) at com.sun.jmx.remote.internal.ServerCommunicatorAdmin$Timeout.run(ServerCommunicatorAdmin.java:168) at java.lang.Thread.run(Thread.java:748)Locked ownable synchronizers: - None"RMI TCP Connection(6)-10.0.50.59" - Thread t@2853 java.lang.Thread.State: RUNNABLE at java.net.SocketInputStream.socketRead0(Native Method) at java.net.SocketInputStream.socketRead(SocketInputStream.java:116) at java.net.SocketInputStream.read(SocketInputStream.java:171) at java.net.SocketInputStream.read(SocketInputStream.java:141) at java.io.BufferedInputStream.fill(BufferedInputStream.java:246) at java.io.BufferedInputStream.read(BufferedInputStream.java:265) - locked <3e35bfb4> (a java.io.BufferedInputStream) at java.io.FilterInputStream.read(FilterInputStream.java:83) at sun.rmi.transport.tcp.TCPTransport.handleMessages(TCPTransport.java:555) at sun.rmi.transport.tcp.TCPTransport$ConnectionHandler.run0(TCPTransport.java:834) at sun.rmi.transport.tcp.TCPTransport$ConnectionHandler.lambda$run$0(TCPTransport.java:688) at sun.rmi.transport.tcp.TCPTransport$ConnectionHandler$$Lambda$37/13510931.run(Unknown Source) at java.security.AccessController.doPrivileged(Native Method) at sun.rmi.transport.tcp.TCPTransport$ConnectionHandler.run(TCPTransport.java:687) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748)Locked ownable synchronizers: - locked <37829806> (a java.util.concurrent.ThreadPoolExecutor$Worker)

想起我当年作为新人来看 thread dump 时,心情是多么的激动!这里有个 WAITING,这里有个 locked,是不是找到问题了!但后来发现,其实都不是问题。如果你一开始看 thread dump 没有头绪,非常正常,一方面可以去搜索一些这些 thread state 代表什么含义,另一方面也可以在程序正常运行时跑个 thread dump 看看,会发现其实也有很多 WAITING 和 lock。
在这个具体的问题里,出问题的 thread 主要是以下这个:
"Executor task launch worker for task 8327" - Thread t@2767 java.lang.Thread.State: RUNNABLE at java.io.FileInputStream.readBytes(Native Method) at java.io.FileInputStream.read(FileInputStream.java:255) at org.apache.spark.network.util.LimitedInputStream.read(LimitedInputStream.java:99) at net.jpountz.lz4.LZ4BlockInputStream.readFully(LZ4BlockInputStream.java:269) at net.jpountz.lz4.LZ4BlockInputStream.refill(LZ4BlockInputStream.java:245) at net.jpountz.lz4.LZ4BlockInputStream.read(LZ4BlockInputStream.java:157) at org.apache.spark.storage.BufferReleasingInputStream.read(ShuffleBlockFetcherIterator.scala:591) at java.io.BufferedInputStream.fill(BufferedInputStream.java:246) at java.io.BufferedInputStream.read1(BufferedInputStream.java:286) at java.io.BufferedInputStream.read(BufferedInputStream.java:345) - locked <34545dd4> (a java.io.BufferedInputStream) at java.io.DataInputStream.read(DataInputStream.java:149) at org.spark_project.guava.io.ByteStreams.read(ByteStreams.java:899) at org.spark_project.guava.io.ByteStreams.readFully(ByteStreams.java:733) at org.apache.spark.sql.execution.UnsafeRowSerializerInstance$$anon$2$$anon$3.next(UnsafeRowSerializer.scala:127) at org.apache.spark.sql.execution.UnsafeRowSerializerInstance$$anon$2$$anon$3.next(UnsafeRowSerializer.scala:110) at scala.collection.Iterator$$anon$12.next(Iterator.scala:445) at scala.collection.Iterator$$anon$11.next(Iterator.scala:410) at org.apache.spark.util.CompletionIterator.next(CompletionIterator.scala:29) at org.apache.spark.InterruptibleIterator.next(InterruptibleIterator.scala:40) at scala.collection.Iterator$$anon$11.next(Iterator.scala:410) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage4.processNext(Unknown Source) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409) at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:462) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409) at scala.collection.Iterator$class.foreach(Iterator.scala:891) at scala.collection.AbstractIterator.foreach(Iterator.scala:1334) at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48) at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310) at scala.collection.AbstractIterator.to(Iterator.scala:1334) at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302) at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1334) at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289) at scala.collection.AbstractIterator.toArray(Iterator.scala:1334) at com.microsoft.ml.spark.lightgbm.TrainUtils$.translate(TrainUtils.scala:229) at com.microsoft.ml.spark.lightgbm.TrainUtils$.trainLightGBM(TrainUtils.scala:385) at com.microsoft.ml.spark.lightgbm.LightGBMBase$$anonfun$6.apply(LightGBMBase.scala:145) at org.apache.spark.rdd.RDDBarrier$$anonfun$mapPartitions$1$$anonfun$apply$1.apply(RDDBarrier.scala:51) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324) at org.apache.spark.rdd.RDD.iterator(RDD.scala:288) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) at org.apache.spark.scheduler.Task.run(Task.scala:123) at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748)Locked ownable synchronizers: - locked <3ddc0126> (a java.util.concurrent.ThreadPoolExecutor$Worker)

调用栈里面有很多 spark, mmlspark 的关键字,一看就是“自己人”。看到最上层,这个调用主要是在做文件 IO,那么问题来了,为什么这里 IO 不能并行利用多 CPU 呢?一个比较可疑的点是 lz4 那块的调用。搜索一番发现果然 lz4 压缩算法不是 splittable 的,这导致了在处理压缩文件时,必须把所有数据放在一起来运行(比如 MR 里就是 single mapper 了)。另外这里也跟我们前面用了 1 个 partition 有关,Spark 在做 shuffle, broadcast 时都会用到 lz4 压缩[5],然后在解压缩阶段只有一个 partition 参与来做,就自然出现了单 CPU 被打满的现象。
流程优化
细心的同学可能会发现,上面我们在找 mmlspark 中怎么做 early_stopping 的验证集时贴了一段 Spark 代码,其中有一个诡异的 broadcast 调用。这个 broadcast 被用来把 validation data 发送到各个数据分区去做验证集。我们之前代码中选取了 30 天的数据来做 validation set,这就导致有些 broadcast 的数据会非常巨大。Spark 的 broadcast 数据默认也会触发 compression,进一步加剧了这个问题。所以我们有几个改进点:
  1. 考虑在数据量较大的情况下,减小验证集的大小,以提升整体性能。
  2. 参数搜索的操作完全可以从训练主流程中剥离,不占用线上运行时间。
  3. 参数优化可以不用 early_stopping 形式,改用随机搜索或贝叶斯优化等方法。
参数调优
除了流程优化,还可以借鉴一些参数优化的经验(玄学调参处处有)。这里主要参考了几篇文章:
  • Facebook 关于 Spark 性能调优的分享[6]
  • Intel 关于压缩算法的分享[7]
通过一系列调整和实验,最终确定了一组设置:
spark.io.compression.lz4.blockSize="512k" spark.serializer="org.apache.spark.serializer.KryoSerializer" spark.kryoserializer.buffer.max="512m" spark.shuffle.file.buffer="1m"

在解决卡顿问题的基础上,进一步把整体训练时间从之前的 55 分钟缩短到了 25 分钟左右。改完参数后看到的 CPU 使用曲线就正常多了:
技术实践干货 | 初探大规模 GBDT 训练
文章图片

CPU 的充分利用
其它框架评估
看前面提到了这么多 mmlspark 的问题,我们还一直坚持使用,感觉一定是真爱了!其实在整个使用过程中也调研过一些其它的框架和方案。
Spark ML
Spark 自己就带了机器学习相关的库,其中就有 GBDT 的实现。在项目推进过程中,我们也尝试了 Spark ML 中的 GBDT 模型来进行训练。需要注意的是,应该使用 pyspark.ml.regression.GBTRegressor这个类,而不是之前的 pyspark.mllib.tree.GradientBoostedTrees 。实现起来还是非常顺利的,但实测下来发现性能非常的差,感觉用 Spark 来构建整个迭代式的算法流程,整体的效率不高。所以这个方案看起来不可行。
Native Lightgbm
Lightgbm 库自己也带了分布式训练的方案,具体可以参考官方文档[8]。
从支持的功能上和官方提供的性能报告上感觉效果非常优秀。例如可以根据数据与特征的大小,选择 feature/data/voting 三种不同的并行方案。官方给的例子里,15 亿行数据,60+特征,在多机上做 data parallel 训练,整体性能可以达到线性扩展的效果。
但有一个问题,lightgbm 本身并没有带数据分发的能力。官网上的例子可以看出用户需要自行做数据,配置,可执行文件的转换和分发,然后自行在多节点上启动训练任务。其它几点都还好说,可以用 pssh , pscp之类的命令。但数据分发和转换就是一个比较大的问题了。如果仔细往下想,就会发现整体实现思路可能跟 mmlspark 目前的实现非常类似了。
所以总体看下来,如果要自行集成 native lightgbm 做分布式训练,可能会需要写一个类似 mmlspark 的库,工作量大,也没有太大必要。
Xgboost/Catboost
顺带考察了 lgb 的两个老竞争对手,看看他们的分布式方案如何。Catboost 完全没有对分布式的支持,率先出局。Xgboost 里提到如果用 Spark 做数据处理,建议使用 Xgboost4j-Spark。粗略看了下还挺不错的,起码文档比 mmlspark 好多了!不过美中不足的是这个库叫 4j,所以只有 Java/Scala 接口,木有 Python 支持,集成起来会有一些难度。
Dask
从 mmlspark 的思路出发,自然会想到其实也可以结合别的并行计算框架,例如 Dask。Xgboost, lightgbm 都有相关的库,用 Dask 来支持分布式训练。这个方案看起来有几个问题:
  1. 需要维护 Spark,Dask 两套框架,要么就把原先数据处理的逻辑再迁移一次到 Dask。但大数据量的处理框架,总体来说 Spark 还是成熟不少,包括用户数量,工具链支持,可运维性等等。
  2. 数据分发仍然是一个问题,如何把 Spark 中的分布式数据集转化为 Dask 可以读取到的形式,还需要尽量避免数据的交换,不好解决。
  3. 跟数据开发平台的结合会有点困难,而且两者功能点上会有些重合。
不过几个 Dask 库里的实现方式还是值得一看的,提供了一些并发框架集成 Python 算法包的思路。
Angel
腾讯一个较为知名分布式机器学习库[9],基于 parameter server 架构实现了一系列算法,支持分布式大规模的训练。大致看了一下这个库,有几个 concern:1. 同样缺少 Python 接口,在 18 年的某个版本还有 PyAngel,后来就删掉了。2. 整个库的使用量,活跃度,都有点存疑,比如自 19 年 12 月以来基本没有 commit,很多 issue 无人回复。3. 部署相关的额外开销比较多,相比 mmlspark 要复杂不少。
所以结论还是不倾向使用,或许后面可以了解下 Angel 的实现方式,看看有没有借鉴意义。
TensorFlow
TF 里面也有 GBDT[10],可能很多人都不知道,这个我还没试过,另外真的要用的话还得考察下 TensorFlow on Spark。当然好处是说不定还能试一些网络模型看看效果如何。
自行任务管理
在 Spark 完成特征构建后,就可以通过不同的策略,把需要分模型训练的数据分别存储到分布式文件系统中,然后利用一些多机任务管理的框架(例如 Ray,我们的数据开发平台等),在不同的节点上分别取对应的数据进行训练。这个方案的好处是灵活性非常高,不再局限于 Spark 平台能支持的算法,可以跑任意我们熟悉的算法模块。但缺点就是任务管理,高可用,failover,可运维性等等方面都会有些 concern。
另外一个问题就是数据交换的额外开销。我们在项目中也尝试了一下在单机做 lgb 训练,也就是在 Spark 特征构建完之后,通过 toPandas调用把数据集转化为 pandas dataframe,然后再调用原生 Python lightgbm 库来做模型训练与预测。这个操作相比写入文件系统还少了磁盘 IO 的开销,但是整体测试下来用原生 lgb 训练整体时间需要 41 分钟,而直接用 mmlspark 用相同的配置只需要 27 分钟。假设我们有 10 亿行数据,70 个特征,那么每次训练的数据量达到了 500GB 左右,这部分的开销还是非常可观的。H2O 有个产品叫 Sparkling Water,就实现了 internal/external 两种 backend,其中也提到了内外部处理的优劣和适用场景等。
总结来看,目前还是mmlspark方案更加合适。后续我们也会持续关注类似框架,并比较评估大规模的GBDT模型与深度学习模型的表现差异。
参考资料
[1] Spark 官方网站: https://spark.apache.org/downloads.html
[2] 官方示例: https://databricks.com/blog/2017/10/30/introducing-vectorized-udfs-for-pyspark.html
[3] 官方文档: https://mmlspark.blob.core.windows.net/docs/1.0.0-rc1/pyspark/mmlspark.lightgbm.html
[4] issue: https://github.com/Azure/mmlspark/issues/775
[5] lz4 压缩: https://spark.apache.org/docs...
[6] Facebook 关于 Spark 性能调优的分享: https://www.slideshare.net/databricks/tuning-apache-spark-for-largescale-workloads-gaoxiang-liu-and-sital-kedia
[7] Intel 关于压缩算法的分享: https://www.slideshare.net/databricks/best-practice-of-compressiondecompression-codes-in-apache-spark-with-sophia-sun-and-qi-xie
[8] 官方文档: https://lightgbm.readthedocs.io/en/latest/Parallel-Learning-Guide.html
[9] 分布式机器学习库: https://github.com/Angel-ML/sona
【技术实践干货 | 初探大规模 GBDT 训练】[10] TF 里面也有 GBDT: https://www.tensorflow.org/api_docs/python/tf/estimator/BoostedTreesRegressor

    推荐阅读