快捷搜索: 王者荣耀 脱发

可视化决策树之Python实现

决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。一些基础原理这里就不再一一介绍了,直接进入今天的主题,如何可视化决策树。

本篇使用klearn来实现决策树的过程,下面是详细讲解:

首先导入必要的包:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score

然后,导入数据集。我用的是kaggle上的蘑菇数据集,这是一个经典的决策树数据集,非常适合决策树,下面我们就会知道。

data = pd.read_csv("mushrooms.csv")
data.head()

先初步认识一下数据集:

可以看出这是一个分类变量的数据集。然后,我们就要将它变成数值变量,好利于下面的建模。

from sklearn.preprocessing import LabelEncoder
labelencoder = LabelEncoder()
for col in data.columns:
    data[col] = labelencoder.fit_transform(data[col])
data.head()

之后,我们来看看数据的大小:

data.shape

(8124, 23) 数据准备后,我们开始提取训练集与测试集。

y = data[class]
X = data.drop(class, axis=1)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, train_size=0.8)
columns = X_train.columns

接着标准化训练集

# 数据标准化
from sklearn.preprocessing import StandardScaler
ss_X = StandardScaler()
ss_y = StandardScaler()
X_train = ss_X.fit_transform(X_train)
X_test = ss_X.transform(X_test)

接着,构建决策树模型

from sklearn.tree import DecisionTreeClassifier
model_tree = DecisionTreeClassifier()
model_tree.fit(X_train, y_train)

评价模型准确性

y_prob = model_tree.predict_proba(X_test)[:,1]
y_pred = np.where(y_prob > 0.5, 1, 0)
model_tree.score(X_test, y_pred)

可以得到结果:1.

说明决策树非常吻合此数据集。

最后,完成决策树的可视化

# 可视化树图
data_ = pd.read_csv("mushrooms.csv")
data_feature_name = data_.columns[1:]
data_target_name = np.unique(data_["class"])
import graphviz
import pydotplus
from sklearn import tree
from IPython.display import Image
import os
os.environ["PATH"] += os.pathsep + C:/Program Files (x86)/Graphviz2.38/bin/
dot_tree = tree.export_graphviz(model_tree,out_file=None,feature_names=data_feature_name,class_names=data_target_name,filled=True, rounded=True,special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_tree)
img = Image(graph.create_png())
graph.write_png("out.png")

注意:graphviz包不仅需要使用pip install graphviz安装还需要单独安装。使用时,还需要引入graphviz绝对路径。

参考:

graphviz-2.38.msi安装包下载:

数据集:

经验分享 程序员 微信小程序 职场和发展