本文描述了猿輔導(dǎo)開源分布式機(jī)器學(xué)習(xí)庫ytk-learn及分布式通信庫ytk-mp4j的相關(guān)內(nèi)容,可實(shí)現(xiàn)在多應(yīng)用場(chǎng)景中使用。ytk-learn 是基于Java的高效分布式機(jī)器學(xué)習(xí)庫, 簡單易用,文檔詳細(xì),只需要用戶安裝Java 8運(yùn)行時(shí)環(huán)境即可,而且所有模型都有可運(yùn)行的demo。
猿輔導(dǎo)公司開源了兩個(gè)機(jī)器學(xué)習(xí)項(xiàng)目——ytk-learn, ytk-mp4j,其中 ytk-mp4j 是一個(gè)高效的分布式通信庫,基于該通信庫我們實(shí)現(xiàn)了 ytk-learn 分布式機(jī)器學(xué)習(xí)庫,該機(jī)器學(xué)習(xí)庫目前在猿輔導(dǎo)很多應(yīng)用場(chǎng)景中使用,比如,自適應(yīng)學(xué)習(xí)、學(xué)生高考分預(yù)測(cè)、數(shù)據(jù)挖掘、課程推薦等。
ytk-learn分布式機(jī)器學(xué)習(xí)庫
項(xiàng)目背景
LR(Logistic Regression), GBDT(Gradient Boosting Decision Tree), FM(Factorization Machines), FFM(Field-aware Factorization Machines) 模型是廣告點(diǎn)擊率預(yù)測(cè)和推薦系統(tǒng)中廣泛使用的模型,但是到目前為止幾乎沒有一個(gè)高效的機(jī)器學(xué)習(xí)開源項(xiàng)目集這幾種常用模型于一身,而且很多機(jī)器學(xué)習(xí)開源項(xiàng)目只能在特定計(jì)算平臺(tái)下使用,最重要的是不能高效的整合到線上生產(chǎn)環(huán)境中。ytk-learn 就是解決以上問題而產(chǎn)生的。
圖1 ytk-learn 特性概略
項(xiàng)目簡介
ytk-learn 是基于Java的高效分布式機(jī)器學(xué)習(xí)庫,實(shí)現(xiàn)大量的主流傳統(tǒng)機(jī)器學(xué)習(xí)模型(GBDT, LR, FM, FFM等)和loss函數(shù),支持單機(jī)多線程、多機(jī)集群及分布式計(jì)算環(huán)境。
其中 GBDT/GBRT 的實(shí)現(xiàn)借鑒吸收了 XGBoost 和 LightGBM 的大部分有用特性,支持特征并行和數(shù)據(jù)并行,支持傳統(tǒng)的精確算法和直方圖近似算法,支持 level-wise 或者 leaf-wise 的建樹方式,而且還實(shí)現(xiàn)了分布式帶權(quán)分位數(shù)近似。在單機(jī)數(shù)據(jù)并行的場(chǎng)景中訓(xùn)練速度跟 XGBoost 相當(dāng),在非$2^n$臺(tái)機(jī)器的分布式場(chǎng)景中比 LightGBM 速度更快,更穩(wěn)定。
傳統(tǒng)的 GBDT/GBRT 在含有大量 Categorical 特征的場(chǎng)景中無法使用,我們實(shí)現(xiàn)了多種適用于大量 Categorical特征的 GBST(Gradient Boosting Soft Tree)模型,在猿輔導(dǎo)的點(diǎn)擊率預(yù)測(cè)和推薦場(chǎng)景中效果明顯好于LR、FM、FFM等模型。
ytk-learn 實(shí)現(xiàn)了改進(jìn)的 Hoag(Hyperparameter optimization with approximate gradient, ICML2016)算法,能夠自動(dòng)高效的進(jìn)行超參數(shù)搜索。當(dāng)目標(biāo)函數(shù)是凸函數(shù)時(shí),hoag 能快速得到最優(yōu)超參數(shù)(kaggle 比賽利器),效率明顯高于傳統(tǒng)的網(wǎng)格超參數(shù)搜索算法(grid search),而且在非凸目標(biāo)函數(shù)場(chǎng)景中也適用。
其他特性:
1. 簡單易用,文檔詳細(xì),只需要用戶安裝Java 8運(yùn)行時(shí)環(huán)境即可,而且所有模型都有可運(yùn)行的demo
2. 支持主流的操作系統(tǒng):Linux,Windows,Mac OS,僅需安裝Java8運(yùn)行環(huán)境即可使用
3. 支持單機(jī)多線程,多機(jī)集群及分布式環(huán)境(Hadoop,Spark),相比Hadoop Mahout, Spark MLlib效率高很多
4. 提供簡單易用的在線預(yù)測(cè)代碼,可以方便整合到線上生成環(huán)境
5. 支持多種目標(biāo)函數(shù)和評(píng)估指標(biāo),支持L1,L2,L1+L2正則
6. 樹模型支持樣本采樣,特征采樣,提供初始預(yù)估值的訓(xùn)練
7. 支持特征預(yù)處理(歸一化,縮放),特征哈希,特征過濾,基于樣本標(biāo)簽采樣
8. 提供了讀取數(shù)據(jù)時(shí)進(jìn)行高效數(shù)據(jù)處理的python腳本
9. 訓(xùn)練模型支持checkpoint,繼續(xù)訓(xùn)練
10. LR 支持 Laplace 近似,方便做 Exploitation&Exploration
11. 基于猿輔導(dǎo)的 ytk-mp4j 通信庫,分布式訓(xùn)練效率非常高
ytk-mp4j 分布式機(jī)器學(xué)習(xí)通信庫
?
項(xiàng)目背景
目前可以用于分布式機(jī)器學(xué)習(xí)的通信主要基于MPI和RPC,其中MPI是分布式高性能計(jì)算的標(biāo)配,雖然效率非常高,但是對(duì)于開發(fā)分布式機(jī)器學(xué)習(xí)任務(wù)來說有很多缺點(diǎn): 開發(fā)難度大、數(shù)據(jù)支持太底層、只能用C/C++, Fortran編寫等等;RPC 方式來實(shí)現(xiàn)類似 allreduce 這種操作,在特征維度特別高的場(chǎng)景,通信效率太低。所以我們開發(fā)了一套易用且高效的機(jī)器學(xué)習(xí)分布式通信庫。
圖2 ytk-mp4j 特性概略
項(xiàng)目簡介
ytk-mp4j 是基于Java的高效分布式機(jī)器學(xué)習(xí)通信庫,實(shí)現(xiàn)了類似 MPI Collective 通信中的大部分操作,包含gather, scatter, allgather, reduce-scatter, broadcast, reduce, allreduce,使用 ytk-mp4j 可以快速地把串行機(jī)器學(xué)習(xí)程序改造成支持多線程和多進(jìn)程,ytk-learn 中所有涉及到分布式通信操作都是基于 ytk-mp4j 實(shí)現(xiàn)(表1中給出了部分例子)。
相比于MPI, ytk-mp4j 擴(kuò)展實(shí)現(xiàn)了一些非常實(shí)用的特性:
1. 所有的通信操作都是基于最優(yōu)算法實(shí)現(xiàn)[1,2],性能非常高,同時(shí)支持多線程,多進(jìn)程。同樣的功能,在C/C++ 環(huán)境中,可能需要結(jié)合 MPI 和 OpenMP 才能實(shí)現(xiàn)
2. 不僅支持基本的數(shù)據(jù)類型(double, float, long, int, short, byte),而且還支持Java String及任意普通Java對(duì)象(Java 對(duì)象只需要實(shí)現(xiàn) Kryo的 Serializer 接口)
3. 不僅支持傳統(tǒng)數(shù)組類型的 Collective 通信,而且還支持Java Map數(shù)據(jù)類型,使用Map數(shù)據(jù)類型,用戶可以實(shí)現(xiàn)非常復(fù)雜的通信操作(例如:集合求交、求并,鏈表的連接等操作)
4. 支持?jǐn)?shù)據(jù)壓縮傳輸,在網(wǎng)絡(luò)資源很緊張的情況下,可以節(jié)約大量的帶寬
5. 純Java代碼實(shí)現(xiàn),可以無縫集成到 Hadoop, Spark 等分布式計(jì)算平臺(tái),構(gòu)建自己的分布式機(jī)器學(xué)習(xí)系統(tǒng)
6. 使用 Java的SDP(Sockets Direct Protocol)可以實(shí)現(xiàn)高效的RDMA(Remote Direct Memory Access)
表1 ytk-mp4j在ytk-learn中的使用
ytk-mp4j 操作ytk-learn 中使用場(chǎng)景
allreduceLoss 求和,梯度求和,Hessian 求和,計(jì)算分位數(shù),計(jì)算平均值,
計(jì)算評(píng)估指標(biāo)(AUC, Confusion Matrix…),統(tǒng)計(jì)樣本數(shù)量、
特征出現(xiàn)次數(shù)等等
reduce-scatterGBDT 高效梯度求和
allgatherL-BFGS 中計(jì)算Hv,GBDT 同步梯度
allreduce 操作是分布式機(jī)器學(xué)習(xí)中使用最多的通信操作,它對(duì)機(jī)群中所有的節(jié)點(diǎn)對(duì)應(yīng)的數(shù)據(jù)進(jìn)行歸約操作,然后再分發(fā)給各個(gè)節(jié)點(diǎn)。下面給出了 ytk-mp4j 在多進(jìn)程、多進(jìn)程、數(shù)組,Map 下的 allreduce (歸約操作為求和)示意圖:
性能測(cè)試
表2給出了 ytk-mp4j 實(shí)現(xiàn)的Collective操作時(shí)間復(fù)雜度,其中 ?是網(wǎng)絡(luò)連接延遲, 是傳輸1個(gè)字節(jié)需要的時(shí)間, 是需要傳輸?shù)淖止?jié)數(shù)量, 是進(jìn)行1字節(jié)數(shù)據(jù)歸約(reduction)操作需要的時(shí)間??梢钥闯觯S著機(jī)器數(shù)量的增加,所有操作數(shù)據(jù)傳輸?shù)臅r(shí)間是幾乎不會(huì)增加的,只有連接和歸約操作的時(shí)間會(huì)隨機(jī)器數(shù)量增加,但在大數(shù)據(jù)通信時(shí),連接和歸約的時(shí)間占比很小。這個(gè)時(shí)間復(fù)雜度特性非常重要,它使得在特征維度、樣本數(shù)量超過一定閾值的分布式機(jī)器學(xué)習(xí)訓(xùn)練任務(wù)中,訓(xùn)練加速比與機(jī)器數(shù)量接近線性關(guān)系。
表2 ytk-mp4j 實(shí)現(xiàn)的 Collective 操作時(shí)間復(fù)雜度
下圖是測(cè)試在 1Gigabit Ethernet 網(wǎng)絡(luò)下,10億維 double 數(shù)組,各種 Collective 通信操作在不同的機(jī)器數(shù)量下的通信性能(時(shí)間單位: ms),從圖中可以看出 ytk-mp4j 中的7種 Collective 操作的通信時(shí)間與機(jī)器數(shù)量的關(guān)系與理論值完全符合。
評(píng)論