动手实现ID3决策树代码
ID3决策树
本文从计算数据集的信息熵、划分数据集、选择最优特征、递归训练一棵树、预测五个方面介绍怎样构建ID3决策树。 先要介绍信息熵和信息增益的这两个公式: Ent ( D ) = − ∑ k = 1 ∣ Y ∣ p k log 2 p k operatorname{Ent}(D)=-sum_{k=1}^{|mathcal{Y}|} p_{k} log _{2} p_{k} Ent(D)=−k=1∑∣Y∣pklog2pk Gain ( D , a ) = Ent ( D ) − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ Ent ( D v ) operatorname{Gain}(D, a)=operatorname{Ent}(D)-sum_{v=1}^{V} frac{left|D^{v} ight|}{|D|} operatorname{Ent}left(D^{v} ight) Gain(D,a)=Ent(D)−v=1∑V∣D∣∣Dv∣Ent(Dv) 具体也可参考,详细介绍了决策树公式及知识框架。
计算数据集的信息熵
假设现在的数据集dataSet最后一列为样本标签,因为数据集的信息熵只与标签的纯度有关,所以需要取出数据集最后一列的类别及其数量到字典中,代入公式计算信息熵。
以下为计算数据集信息熵的函数:
def entropy(dataSet): labelCounts = { } length = len(dataSet) for example in dataSet: if example[-1] not in labelCounts: labelCounts[example[-1]] = 0 labelCounts[example[-1]] += 1 e = 0.0 for i in labelCounts.values(): p_k = float(i) / length e -= p_k * log(p_k, 2) return e
划分数据集
以下函数功能为按照特征的不同值将样本集划分为多个数据集。
def splitDataSet(dataSet, axis, value): reDataSet = [] for example in dataSet: if example[axis] == value: reDataSet += [example[:axis] + example[axis + 1:]] return reDataSet
选择最优特征
遍历各个特征,选取信息熵最小的作为最优划分属性。
def chooseBestFeature(dataSet): bestFeature = -1 length = len(dataSet) minEntropy = entropy(dataSet) for axis in range(len(dataSet[0]) - 1): newEntropy = 0.0 record = { } for example in dataSet: if example[axis] not in record: record[example[axis]] = 0 record[example[axis]] += 1 for i in record: newEntropy += entropy(splitDataSet(dataSet, axis, i)) * record[i] / length if newEntropy < minEntropy: bestFeature = axis minEntropy = newEntropy return bestFeature
递归训练一棵树
当前结点都属于同一个标签时,结束递归。
def trainTree(dataSet,feature_name): myTree = { } k = chooseBestFeature(dataSet) s = set() for example in dataSet: s.add(example[k]) tree = { } for i in s: newDataSet = splitDataSet(dataSet, k, i) labelRecord = [] for example in newDataSet: labelRecord.append(example[-1]) if labelRecord.count(labelRecord[0]) == len(labelRecord): tree[i] = labelRecord[0] else: tree[i] = trainTree(newDataSet, feature_name) myTree[feature_name[k]] = tree return myTree
预测
当结果是字典时,继续递归,否则输出该值为预测结果。
def predict(inputTree,feature_name,testVec): feature = list(inputTree.keys())[0] k = feature_name.index(feature) childTree = inputTree[feature] label = childTree[testVec[k]] if isinstance(label, dict): return predict(label, feature_name, testVec) return label
sklearn实现决策树
除了ID3决策树,常用的还有C4.5决策树和CART决策树。
Python的sklearn库中提供了决策树的模型,可用于快速构建不同类型的决策树模型,具体可参考……待续
上一篇:
通过多线程提高代码的执行效率例子
下一篇:
springboot将文件响应给前端