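Ranking Feature Importance in Deep Learning Training Data
The idea behind inspecting feature importance in a neural network is to perturb each feature in turn and use the model's resulting predictions to measure how much that feature matters.
The steps for obtaining feature importance from a neural network are as follows: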
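- Train a neural network model;
- For each feature column in turn, randomly shuffle that column and feed the data to the model to obtain the loss;
- Record each shuffled feature column together with its corresponding loss;
- Each loss serves as the importance score of its feature: the larger the loss, that is, the further it rises above the loss on the unshuffled data, the more important the feature is to the model.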
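The steps above can be sketched on a toy problem without any deep learning framework. This is a minimal illustration under stated assumptions, not the code used in this post: the helper name `permutation_importance`, the `n_repeats` parameter, and the stand-in linear `predict` function are all hypothetical. It reports the increase in MAE over the unshuffled baseline, averaged over a few shuffles.

```python
import numpy as np

def permutation_importance(predict, X, y, n_repeats=5, seed=0):
    """Return (baseline_mae, importances).

    importances[j] is the mean increase in MAE, relative to the
    unshuffled baseline, after shuffling column j of X.
    """
    rng = np.random.default_rng(seed)
    baseline = np.mean(np.abs(predict(X) - y))
    importances = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        scores = []
        for _ in range(n_repeats):
            X_perm = X.copy()                             # leave the caller's data intact
            X_perm[:, j] = rng.permutation(X_perm[:, j])  # shuffle one column only
            scores.append(np.mean(np.abs(predict(X_perm) - y)))
        importances[j] = np.mean(scores) - baseline
    return baseline, importances

# Toy data (hypothetical): y depends strongly on column 0,
# weakly on column 1, and not at all on column 2.
data_rng = np.random.default_rng(42)
X = data_rng.normal(size=(500, 3))
y = 3.0 * X[:, 0] + 0.5 * X[:, 1]

def predict(X):
    # Stand-in for a trained model: here simply the true function.
    return 3.0 * X[:, 0] + 0.5 * X[:, 1]

baseline, imp = permutation_importance(predict, X, y)
ranking = np.argsort(imp)[::-1]  # feature indices, most important first
print("baseline MAE:", baseline)
print("importance per feature:", imp)
print("ranking:", ranking)
```

Subtracting the baseline makes the scores comparable across models: a feature the model ignores gets a score near zero, whatever the absolute loss is.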
Code:

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from tqdm.notebook import tqdm

    import tensorflow as tf
    from tensorflow import keras
    import tensorflow.keras.backend as K
    from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
    from tensorflow.keras.callbacks import LearningRateScheduler, ReduceLROnPlateau
    from tensorflow.keras.optimizers.schedules import ExponentialDecay

    from sklearn.metrics import mean_absolute_error as mae
    from sklearn.preprocessing import RobustScaler, normalize
    from sklearn.model_selection import train_test_split, GroupKFold, KFold

    from IPython.display import display

    COMPUTE_LSTM_IMPORTANCE = 1
    ONE_FOLD_ONLY = 1

    # gpu_strategy, NUM_FOLDS, train, targets and COLS are not defined in this
    # snippet; they must already exist in the notebook. train is indexed as
    # (samples, timesteps, features).
    with gpu_strategy.scope():
        kf = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=2021)
        test_preds = []

        for fold, (train_idx, test_idx) in enumerate(kf.split(train, targets)):
            K.clear_session()

            print('-'*15, '>', f'Fold {fold+1}', '<', '-'*15)
            X_train, X_valid = train[train_idx], train[test_idx]
            y_train, y_valid = targets[train_idx], targets[test_idx]

            # Load the already-trained model
            model = keras.models.load_model('models/XXX.h5')

            # Compute feature importance
            if COMPUTE_LSTM_IMPORTANCE:
                results = []
                print(' Computing LSTM feature importance...')

                # k == 0 is the unshuffled baseline, so COLS is assumed to hold
                # a baseline label followed by the feature names.
                for k in tqdm(range(len(COLS))):
                    if k > 0:
                        # Shuffle feature k-1 across samples, in place
                        save_col = X_valid[:, :, k-1].copy()
                        np.random.shuffle(X_valid[:, :, k-1])

                    oof_preds = model.predict(X_valid, verbose=0).squeeze()
                    # Named fold_mae so it does not shadow the sklearn import
                    fold_mae = np.mean(np.abs(oof_preds - y_valid))
                    results.append({'feature': COLS[k], 'mae': fold_mae})

                    if k > 0:
                        # Restore the original column before the next feature
                        X_valid[:, :, k-1] = save_col

                # Display feature importance
                print()
                df = pd.DataFrame(results)
                df = df.sort_values('mae')
                plt.figure(figsize=(10, 20))
                plt.barh(np.arange(len(COLS)), df.mae)
                plt.yticks(np.arange(len(COLS)), df.feature.values)
                plt.title('LSTM Feature Importance', size=16)
                plt.ylim((-1, len(COLS)))
                plt.show()

                # Save LSTM feature importance
                df = df.sort_values('mae', ascending=False)
                df.to_csv(f'lstm_feature_importance_fold_{fold}.csv', index=False)

            # Only do one fold
            if ONE_FOLD_ONLY:
                break
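The loop above shuffles a 2-D slice of a 3-D array in place and then restores it. The following small sketch, on a hypothetical toy array rather than the real `X_valid`, shows what that slice shuffle does: whole per-sample sequences of one feature are reassigned to other samples, the time order inside each sequence is kept, and the other features are left untouched.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy array shaped (samples, timesteps, features); the values are arbitrary.
X = np.arange(4 * 3 * 2, dtype=float).reshape(4, 3, 2)
original = X.copy()

k = 1                             # index of the feature to perturb
saved = X[:, :, k].copy()         # keep a copy so the column can be restored
rng.shuffle(X[:, :, k])           # in-place shuffle along axis 0 (samples)

# The other feature is untouched.
other_unchanged = np.array_equal(X[:, :, 0], original[:, :, 0])

# Every shuffled row is one of the original per-sample sequences, intact.
same_rows = sorted(map(tuple, X[:, :, k])) == sorted(map(tuple, saved))

X[:, :, k] = saved                # restore before moving on to the next feature
restored = np.array_equal(X, original)

print(other_unchanged, same_rows, restored)
```

Restoring the column after each prediction matters: without it, later features would be scored on data in which earlier features are still scrambled.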
Result :