Spark随机森林之多分类模型
Spark随机森林之多分类模型
关于随机森林
随机森林算法是机器学习、计算机视觉等领域内应用极为广泛的一个算法,它不仅可以用来做分类,也可用来做回归即预测,随机森林机由多个决策树构成,相比于单个决策树算法,它分类、预测效果更好,不容易出现过度拟合的情况。
其中,决策树相关 ,这里不再详述。
官方实例
以下是官方给出的一个demo
import org.apache.spark.mllib.tree.RandomForest import org.apache.spark.mllib.tree.model.RandomForestModel import org.apache.spark.mllib.util.MLUtils // 加载数据 val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") // 将数据随机分配为两份,一份用于训练,一份用于测试 val splits = data.randomSplit(Array(0.7, 0.3)) val (trainingData, testData) = (splits(0), splits(1)) // 随机森林训练参数设置 //分类数 val numClasses = 2 // categoricalFeaturesInfo 为空,意味着所有的特征为连续型变量 val categoricalFeaturesInfo = Map[Int, Int]() //树的个数 val numTrees = 3 //特征子集采样策略,auto 表示算法自主选取 val featureSubsetStrategy = "auto" //纯度计算 val impurity = "gini" //树的最大层次 val maxDepth = 4 //特征最大装箱数 val maxBins = 32 //训练随机森林分类器,trainClassifier 返回的是 RandomForestModel 对象 val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins) // 测试数据评价训练好的分类器并计算错误率 val labelAndPreds = testData.map { point => val prediction = model.predict(point.features) (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() println("Test Error = " + testErr) println("Learned classification forest model: " + model.toDebugString) // 将训练后的随机森林模型持久化 model.save(sc, "myModelPath") //加载随机森林模型到内存 val sameModel = RandomForestModel.load(sc, "myModelPath")
在这个demo中,numClasses=2,是一个二分类问题,numTrees=3,也就是说每个森林有三棵决策树,每一特征向量经过这三棵树进行分类,最后综合来看0类和1类在三棵树预测出的标签中的占比,如果2棵树或3棵树预测为0类,则为0类,否则为1类。
这给了我们一个进行多分类的小技巧:假设某一事物有多个类型A、B、C、D类,不仅如此,其中还有一些是复合类型,比如A类和B类的复合类型,那么这时候用随机森林如何进行分类判别?
那么这时就可以定义较多数量的决策树,比如我定义numTrees=10,通过
val model: RandomForestModel=RandomForest.trainClassifier( trainingData,numClasses,categoricalFeaturesInfo,numTrees, featureSubsetStrategy,impurity, maxDepth, maxBins) val tr: Array[DecisionTreeModel] =model.trees
可以获取随机森林中的每棵树,即一个DecisionTreeModel数组,每个DecisionTreeModel都有predict方法,这时可以获取每棵树对某一特征向量的分类判别,对10棵树的结果进行统计分析。
举例来说,对于某一特征向量,10棵树中,有8棵判别为A类,1棵判别为B类,1棵判别为C类,则这个特征所属的载体有80%的“概率”属于A类,当然也可以设定一个阙值,超过这个阙值直接判为该类别。对于另一特征向量,5棵树判别为A类,5棵树判别为D类,这时我们就可以认为它是属于A类和D类的复合类型。