快捷搜索: 王者荣耀 脱发

动手实现ID3决策树代码

ID3决策树

本文从计算数据集的信息熵、划分数据集、选择最优特征、递归训练一棵树、预测五个方面介绍怎样构建ID3决策树。 先要介绍信息熵和信息增益的这两个公式: Ent ⁡ ( D ) = − ∑ k = 1 ∣ Y ∣ p k log ⁡ 2 p k operatorname{Ent}(D)=-sum_{k=1}^{|mathcal{Y}|} p_{k} log _{2} p_{k} Ent(D)=−k=1∑∣Y∣pklog2pk Gain ⁡ ( D , a ) = Ent ⁡ ( D ) − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ Ent ⁡ ( D v ) operatorname{Gain}(D, a)=operatorname{Ent}(D)-sum_{v=1}^{V} frac{left|D^{v} ight|}{|D|} operatorname{Ent}left(D^{v} ight) Gain(D,a)=Ent(D)−v=1∑V∣D∣∣Dv∣Ent(Dv) 具体也可参考,详细介绍了决策树公式及知识框架。

计算数据集的信息熵

假设现在的数据集dataSet最后一列为样本标签,因为数据集的信息熵只与标签的纯度有关,所以需要取出数据集最后一列的类别及其数量到字典中,代入公式计算信息熵。

以下为计算数据集信息熵的函数:

def entropy(dataSet):
    labelCounts = {
          
   }
    length = len(dataSet)
    for example in dataSet:
        if example[-1] not in labelCounts: labelCounts[example[-1]] = 0
        labelCounts[example[-1]] += 1
    e = 0.0
    for i in labelCounts.values():
        p_k = float(i) / length
        e -= p_k * log(p_k, 2)
    return e

划分数据集

以下函数功能为按照特征的不同值将样本集划分为多个数据集。

def splitDataSet(dataSet, axis, value):
    reDataSet = []
    for example in dataSet:
        if example[axis] == value: 
            reDataSet += [example[:axis] + example[axis + 1:]]
    return reDataSet

选择最优特征

遍历各个特征,选取信息熵最小的作为最优划分属性。

def chooseBestFeature(dataSet):
    bestFeature = -1
    length = len(dataSet)
    minEntropy = entropy(dataSet)
    for axis in range(len(dataSet[0]) - 1):
        newEntropy = 0.0
        record = {
          
   }
        for example in dataSet:
            if example[axis] not in record: record[example[axis]] = 0
            record[example[axis]] += 1
        for i in record:
            newEntropy += entropy(splitDataSet(dataSet, axis, i)) * record[i] / length
        if newEntropy < minEntropy:
            bestFeature = axis
            minEntropy = newEntropy
    return bestFeature

递归训练一棵树

当前结点都属于同一个标签时,结束递归。

def trainTree(dataSet,feature_name):
    myTree = {
          
   }
    k = chooseBestFeature(dataSet)
    s = set()
    for example in dataSet:
        s.add(example[k])
    tree = {
          
   }
    for i in s:
        newDataSet = splitDataSet(dataSet, k, i)
        labelRecord = []
        for example in newDataSet:
            labelRecord.append(example[-1])
        if labelRecord.count(labelRecord[0]) == len(labelRecord):
            tree[i] = labelRecord[0]  
        else: 
            tree[i] = trainTree(newDataSet, feature_name)
    myTree[feature_name[k]] = tree
    return myTree

预测

当结果是字典时,继续递归,否则输出该值为预测结果。

def predict(inputTree,feature_name,testVec):
    feature = list(inputTree.keys())[0]
    k = feature_name.index(feature)
    childTree = inputTree[feature]
    label = childTree[testVec[k]]
    if isinstance(label, dict): return predict(label, feature_name, testVec)
    return label

sklearn实现决策树

除了ID3决策树,常用的还有C4.5决策树和CART决策树。

Python的sklearn库中提供了决策树的模型,可用于快速构建不同类型的决策树模型,具体可参考……待续

经验分享 程序员 微信小程序 职场和发展