找回密码
 会员注册
查看: 34|回复: 0

SHAP解释运用基于python的树模型特征选择+随机森林回归预测+SHAP解释预测

[复制链接]

6

主题

0

回帖

19

积分

新手上路

积分
19
发表于 2024-9-10 09:49:51 | 显示全部楼层 |阅读模式
1.导入必要的库importpandasaspdimportnumpyasnpimportmatplotlib.pyplotaspltimportseabornassnsfromsklearn.model_selectionimporttrain_test_splitfromsklearn.ensembleimportRandomForestRegressorfromsklearn.treeimportexport_graphviz#fromsklearn.inspectionimportplot_partial_dependencefromsklearn.metricsimportmean_squared_errorimportshapimportwarnings2.设置忽略警告与显示字体、负号warnings.filterwarnings("ignore")#设置Matplotlib的字体属性plt.rcParams['font.sans-serif']=['SimHei']#用于中文显示,你可以更改为其他支持中文的字体plt.rcParams['axes.unicode_minus']=False#用来正常显示负号3.导入数据集3.1加载数据#1.加载数据df=pd.read_excel('train.xlsx')X=df.iloc[:,:-1]#特征y=df.iloc[:,-1]#标签3.2查看数据分布1.箱线图plt.figure(figsize=(30,6))sns.boxplot(data=df)plt.title('BoxPlotsofDatasetFeatures',fontsize=40,color='black')#如果需要设置坐标轴标签的字体大小和颜色plt.xlabel('X-axisLabel',fontsize=20,color='red')#设置x轴标签的字体大小和颜色plt.ylabel('Y-axisLabel',fontsize=20,color='green')#设置y轴标签的字体大小和颜色#还可以调整刻度线的长度、宽度等属性plt.tick_params(axis='x',labelsize=20,colors='blue',length=5,width=1)#设置x轴刻度线、刻度标签的更多属性plt.tick_params(axis='y',labelsize=20,colors='deepskyblue',length=5,width=1)#设置y轴刻度线、刻度标签的更多属性plt.xticks(rotation=45)#如果特征名很长,可以旋转x轴标签plt.show()        结果如图3-1所示:图3-1        结果图实在丑陋,这是由数据分布不均衡造成的,这里重点不是数据清洗,就这样凑着用吧。2.分布图#注意:distplot在seaborn0.11.0+中已被移除#你可以分别使用histplot和kdeplotplt.figure(figsize=(50,10))fori,featureinenumerate(df.columns,1):plt.subplot(1,len(df.columns),i)sns.histplot(df[feature],kde=True,bins=30,label=feature,color='blue')plt.title(f'QQplotof{feature}',fontsize=40,color='black')#如果需要设置坐标轴标签的字体大小和颜色plt.xlabel('X-axisLabel',fontsize=35,color='red')#设置x轴标签的字体大小和颜色plt.ylabel('Y-axisLabel',fontsize=40,color='green')#设置y轴标签的字体大小和颜色#还可以调整刻度线的长度、宽度等属性plt.tick_params(axis='x',labelsize=40,colors='blue',length=5,width=1)#设置x轴刻度线、刻度标签的更多属性plt.tick_params(axis='y',labelsize=40,colors='deepskyblue',length=5,width=1)#设置y轴刻度线、刻度标签的更多属性plt.tight_layout()plt.show()        结果如图3-2所示:图3-23.QQ图fromscipyimportstatsplt.figure(figsize=(50,10))fori,featureinenumerate(df.columns,1):plt.subplot(1,len(df.columns),i)stats.probplot(df[feature],dist="norm",plot=plt)plt.title(f'QQplotof{feature}',fontsize=40,color='black')#如果需要设置坐标轴标签的字体大小和颜色plt.xlabel('X-axisLabel',fontsize=35,color='red')#设置x轴标签的字体大小和颜色plt.ylabel('Y-axisLabel',fontsize=40,color='green')#设置y轴标签的字体大小和颜色#还可以调整刻度线的长度、宽度等属性plt.tick_params(axis='x',labelsize=40,colors='blue',length=5,width=1)#设置x轴刻度线、刻度标签的更多属性plt.tick_params(axis='y',labelsize=40,colors='deepskyblue',length=5,width=1)#设置y轴刻度线、刻度标签的更多属性plt.tight_layout()plt.show()        结果如图3-3所示:图3-34.树模型特征选择#4.特征选择(使用随机森林的特征重要性)rf=RandomForestRegressor(n_estimators=100,random_state=42)rf.fit(X_scaled,y)importances=rf.feature_importances_indices=np.argsort(importances)[::-1]#可视化特征重要性plt.figure(figsize=(10,7))plt.title("Featureimportances")plt.bar(range(X.shape[1]),importances[indices],align="center",color='cyan')plt.xticks(range(X.shape[1]),[X.columns[i]foriinindices],rotation='vertical')plt.xlim([-1,X.shape[1]])plt.show()        特征重要性比较如图4-1所示:图4-15.随机森林回归预测#划分训练集和测试集X_train,X_test,y_train,y_test=train_test_split(X_scaled,y,test_size=0.2,random_state=42)#随机森林回归预测rf_final=RandomForestRegressor(n_estimators=100,random_state=42)rf_final.fit(X_train,y_train)y_pred=rf_final.predict(X_test)mse=mean_squared_error(y_test,y_pred)print(f"MeanSquaredError:{mse}")#预测结果输出与比对plt.figure()plt.plot(np.arange(21),y_test[:100],"go-",label="Truevalue")plt.plot(np.arange(21),y_pred[:100],"ro-",label="Predictvalue")plt.title("TruevalueAndPredictvalue")plt.legend()plt.show()        预测结果如图5-1所示:图5-1        由图5-1结合这里的误差MeanSquaredError:16.092619015714185,说明预测效果很一般,不过本身数据集没有经过清洗,数据分布不合理,有这样的结果也能接受。我一般使用matlab进行数据清晰和标准化,matlab暂时打不开,先搁置,后面我会出数据标准化的文章。5.SHAP库解释预测5.1shap库下载安装        这里的shap库我已经下载安装过了,没有下载安装的在pycharm终端、AnacondaPromt终端等等执行命令进行下载安装,最好带上清华镜像源,在网络信号不好时也能顺利安装且节省时间。pipinstall-ihttps://pypi.tuna.tsinghua.edu.cn/simpleshap5.2waterfallshap.plots.waterfall(shap_values[0])#Forthefirstobservation        结果如图5-1所示:图5-15.3forceplot#相互作用图force_plot1=shap.force_plot(explainer.expected_value,np.mean(shap_values,axis=0),np.mean(X_test,axis=0),feature_label,matplotlib=True,show=False)shap_interaction_values=explainer.shap_interaction_values(X_test)shap.summary_plot(shap_interaction_values,X_test)        结果如图5-2所示:图5-25.4特征影响图shap.plots.force(explainer.expected_value,shap_values.values,shap_values.data)        结果如图5-3所示:图5-35.5特征密度散点图:summary_plot/beeswarm5.5.1summary_plot#创建SHAP解释器explainer=shap.TreeExplainer(rf)#计算SHAP值shap_values=explainer.shap_values(X_test)#特征标签feature_label=['feature1','feature2','feature3','feature4','feature5','feature6','feature7']plt.rcParams['font.family']='serif'plt.rcParams['font.serif']='TimesNewRoman'plt.rcParams['font.size']=13#设置字体大小为14#现在创建SHAP可视化#配色viridisSpectralcoolwarmRdYlGnRdYlBuRdBuRdGyPuOrBrBGPRGnPiYGshap.summary_plot(shap_values,X_test,feature_names=feature_label)#粉红色点:表示该特征值在这个观察中对模型预测产生了正面影响(增加预测值)#蓝色点:表示该特征值在这个观察中对模型预测产生了负面影响(降低预测值)#水平轴(SHAP值)显示了影响的大小。点越远离中心线(零点),该特征对模型输出的影响越大#图中垂直排列的特征按影响力从上到下排序。上方的特征对模型输出的总体影响更大,而下方的特征影响较小。#最上方的特征显示了大量的正面和负面影响,表明它在不同的观察值中对模型预测的结果有很大的不同影响。#中部的特征也显示出两种颜色的点,但点的分布更集中,影响相对较小。#底部的特征对模型的影响最小,且大部分影响较为接近零,表示这些特征对模型预测的贡献较小    结果如图5-4所示:图5-4#创建SHAP解释器explainer=shap.TreeExplainer(rf)#计算SHAP值shap_values=explainer.shap_values(X_test)#特征标签feature_label=['feature1','feature2','feature3','feature4','feature5','feature6','feature7']plt.rcParams['font.family']='serif'plt.rcParams['font.serif']='TimesNewRoman'plt.rcParams['font.size']=13#设置字体大小为14#现在创建SHAP可视化#配色viridisSpectralcoolwarmRdYlGnRdYlBuRdBuRdGyPuOrBrBGPRGnPiYGshap.summary_plot(shap_values,X_test,feature_names=feature_label,cmap='Spectral')使颜色丰富些如图5-5所示:图5-55.5.2beeswarm#summarizetheeffectsofallthefeatures#样本决策图shap.initjs()shap_values=explainer(X_test)expected_value=explainer.expected_valueshap.plots.beeswarm(shap_values)结果如图5-6所示:图5-65.6特征重要性SHAP值shap.summary_plot(shap_values,X_test,feature_names=feature_label,plot_type='bar')#主要表示绝对重要值的大小,把SHAPvalue的样本取了绝对平均值        或者:shap.plots.bar(shap_values)    结果如图5-7、图5-8所示,本质都是一样的:图5-7图5-85.7聚类热力图:heatmapplot#热图shap.initjs()shap_values=explainer(X_test)shap.plots.heatmap(shap_values)    结果如图5-9所示:图5-95.7层次聚类shap值#层次聚类+SHAP值clust=shap.utils.hclust(X,y,linkage="single")shap.plots.bar(shap_values,clustering=clust,clustering_cutoff=1)    结果如图5-10所示:图5-105.8决策图#样本决策图shap.initjs()shap_values=explainer.shap_values(X_test)expected_value=explainer.expected_valueshap.decision_plot(expected_value,shap_values,feature_label)    结果如图5-11所示:图5-11变形1:由数值->概率#样本决策图shap.initjs()shap_values=explainer.shap_values(X_test)expected_value=explainer.expected_valuefeature_label=['feature1','feature2','feature3','feature4','feature5','feature6','feature7']shap.decision_plot(expected_value,shap_values,feature_label,link='logit')    结果如图5-12所示:图5-12变形2:高亮某个样本线highlightshap.decision_plot(expected_value,shap_values,feature_label,highlight=12)    结果如图5-13所示:图5-135.9特征依赖图:dependence_plot5.9.1单个特征依赖shap.dependence_plot('feature1',shap_values,X_test,interaction_index=None)    结果如图5.14所示:图5-145.9.2相互依赖图shap.dependence_plot('feature3',shap_values,X_test,interaction_index='feature4')    结果如图5-15所示:图5-155.10相互作用图:summary_plotshap.summary_plot(shap_interaction_values,X_test)    结果如图5-16所示:图5-16具体的每种解释图的含义可以搜寻以下参考文章:代码借鉴:http://t.csdnimg.cn/6JWrj理论借鉴  http://t.csdnimg.cn/6JWrjhttp://t.csdnimg.cn/H9X0Bhttp://t.csdnimg.cn/zvtA8http://t.csdnimg.cn/nygl6http://t.csdnimg.cn/zyHy0http://t.csdnimg.cn/rTPw2
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 会员注册

本版积分规则

QQ|手机版|心飞设计-版权所有:微度网络信息技术服务中心 ( 鲁ICP备17032091号-12 )|网站地图

GMT+8, 2025-1-5 09:25 , Processed in 1.166243 second(s), 26 queries .

Powered by Discuz! X3.5

© 2001-2025 Discuz! Team.

快速回复 返回顶部 返回列表