For the theory behind GBDT (Gradient Boosting Decision Tree), see:
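RDDs in Spark

- instance: RDD[(Id, Instance)]: holds the training samples. This RDD is never regenerated during the training iterations.
- gradient: RDD[(Id, (Gradient, Hessian))]: holds the first-order (gradient) and second-order (Hessian) values of each sample. A new RDD is generated for every tree.
- prediction: RDD[(Id, Prediction)]: holds each sample's prediction under the current model. A new RDD is generated for every tree.

PSMatrix on Angel PS

- instanceLayout: records, for each partition of the instance RDD, which tree node each sample falls into.
- gradHistogram: holds the gradient histogram of each leaf node of the tree.
- gbtModel: consists of three matrices, holding the split feature ID of each tree node, the split feature value of each tree node, and the weight of each leaf node.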
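The layout of gbtModel can be made concrete with a minimal local sketch. It stores one tree as three flat arrays indexed by node ID, mirroring the three matrices above: split feature ID, split feature value and leaf weight. Several details are assumptions for illustration and are not Angel's actual data structures:

- heap-style node numbering, where node i has children 2i+1 and 2i+2;
- `-1` in `splitFeat` marks a leaf;
- samples with `x(feature) < value` go left;
- features are dense `Array[Double]`.

The names `FlatTree`, `leafOf`, `predict` and `updatePrediction` are hypothetical.

```scala
// Hypothetical local stand-in for one tree of gbtModel: three flat arrays
// indexed by node ID, one per matrix (split feature, split value, leaf weight).
// Assumed numbering: node i has children 2i + 1 (left) and 2i + 2 (right).
final case class FlatTree(
    splitFeat: Array[Int],     // split feature ID per node; -1 marks a leaf (assumption)
    splitValue: Array[Double], // split threshold; x(feat) < value goes left (assumption)
    leafWeight: Array[Double]  // leaf weight; only meaningful at leaf nodes
) {
  // Walks from the root to a leaf and returns that leaf's node ID, i.e. the
  // kind of sample-to-node assignment that instanceLayout keeps per partition.
  def leafOf(x: Array[Double]): Int = {
    var node = 0
    while (splitFeat(node) >= 0) {
      node =
        if (x(splitFeat(node)) < splitValue(node)) 2 * node + 1
        else 2 * node + 2
    }
    node
  }

  // Output of this tree for one sample: the weight of the leaf it lands in.
  def predict(x: Array[Double]): Double = leafWeight(leafOf(x))

  // Per-sample analogue of updatePrediction: add this tree's output, scaled
  // by a learning rate (an assumed hyperparameter), to the running prediction.
  def updatePrediction(prev: Double, x: Array[Double], learningRate: Double): Double =
    prev + learningRate * predict(x)
}
```

For example, `FlatTree(Array(0, -1, -1), Array(0.5, 0.0, 0.0), Array(0.0, -1.0, 2.0))` is a depth-1 stump. It sends `Array(0.2)` to leaf 1 (weight -1.0) and `Array(0.9)` to leaf 2 (weight 2.0). Flat arrays keyed by node ID map naturally onto rows of a PS matrix, which is presumably why the model is split into per-attribute matrices rather than stored as linked node objects.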
The GBDT pseudocode is as follows:
```scala
val instance: RDD[(Id, Instance)]
var gradient: RDD[(Id, (Gradient, Hessian))]
var prediction: RDD[(Id, Prediction)]
val instanceLayout, gradHistogram, gbtModel: PSMatrix
// Build the feature sketches (candidate split values) once and broadcast them
val sketch = sparkContext.broadcast(createSketch(instance))
var treeNum = 0
while (treeNum < maxTreeNum) {
  // (1) Create a new tree
  val tree = new Tree()
  // (2) Calculate instance gradients from the current predictions
  gradient = calcGrad(instance, prediction, instanceLayout)
  while (tree.depth < maxDepth) {
    // (3.1) Build gradient histograms and push them to the PS
    gradHistogram.push(buildHist(instance, gradient, instanceLayout))
    // (3.2) Find the best split with a PS function
    gbtModel.update(findSplit(gradHistogram))
    // (3.3) Grow the tree and update the instance layout
    growTree(tree, gbtModel); instanceLayout.update(tree)
  }
  // (4) Update the predictions with the finished tree
  prediction = updatePrediction(instance, gbtModel)
  treeNum += 1
}
```
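Steps (2) to (3.2) can be illustrated with a minimal single-process sketch. In the pseudocode, histograms are pushed to gradHistogram on the PS and the split search runs as a PS function; here plain arrays stand in for both.

The sketch makes these assumptions:

- logistic loss, since the source does not name a loss function;
- features already bucketed into bin indices using the sketch's candidate split values;
- the XGBoost-style second-order gain G_L²/(H_L+λ) + G_R²/(H_R+λ) − G²/(H+λ), without the 1/2 factor or a complexity penalty;
- leaf weight −G/(H+λ).

These are common choices, not necessarily what Angel implements. The names `Split`, `GradHist`, `calcGrad`, `buildHist`, `findSplit` and `leafWeight` are hypothetical.

```scala
// Best split found for one node: feature ID, last bin that goes left, and gain.
final case class Split(feature: Int, bin: Int, gain: Double)

object GradHist {
  // Step (2): first- and second-order gradients of logistic loss
  // for a label in {0, 1} and the current raw score (margin).
  def calcGrad(label: Double, score: Double): (Double, Double) = {
    val p = 1.0 / (1.0 + math.exp(-score))
    (p - label, p * (1.0 - p))
  }

  // Step (3.1): gradient histogram of one node. bins(i)(f) is the bin index of
  // sample i's feature f. Returns (g, h), where g(f)(b) and h(f)(b) are the
  // gradient and Hessian sums over the samples falling in bin b of feature f.
  def buildHist(
      bins: Array[Array[Int]],
      grads: Array[(Double, Double)],
      numFeatures: Int,
      numBins: Int
  ): (Array[Array[Double]], Array[Array[Double]]) = {
    val g = Array.fill(numFeatures, numBins)(0.0)
    val h = Array.fill(numFeatures, numBins)(0.0)
    for (i <- bins.indices; f <- 0 until numFeatures) {
      val b = bins(i)(f)
      g(f)(b) += grads(i)._1
      h(f)(b) += grads(i)._2
    }
    (g, h)
  }

  // Step (3.2): scan each feature's histogram left to right, accumulating the
  // left-side sums, and keep the split with the largest positive gain.
  // Samples in bins 0..bin go left. Use lambda > 0 to avoid dividing by zero.
  def findSplit(g: Array[Array[Double]], h: Array[Array[Double]], lambda: Double): Option[Split] = {
    def score(gs: Double, hs: Double): Double = gs * gs / (hs + lambda)
    var best: Option[Split] = None
    for (f <- g.indices) {
      val gTotal = g(f).sum
      val hTotal = h(f).sum
      var gLeft = 0.0
      var hLeft = 0.0
      for (b <- 0 until g(f).length - 1) {
        gLeft += g(f)(b)
        hLeft += h(f)(b)
        val gain = score(gLeft, hLeft) + score(gTotal - gLeft, hTotal - hLeft) - score(gTotal, hTotal)
        if (gain > best.map(_.gain).getOrElse(0.0)) best = Some(Split(f, b, gain))
      }
    }
    best
  }

  // Optimal leaf weight under the same second-order objective.
  def leafWeight(gSum: Double, hSum: Double, lambda: Double): Double = -gSum / (hSum + lambda)
}
```

Only the per-bin gradient and Hessian sums are needed to evaluate every candidate split. Partial histograms from different partitions can therefore simply be added together. This is what makes pushing them to a shared PS matrix and searching for the split there a natural fit.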
Data parameters
Algorithm parameters
// TODO