R语言如何做SHAP可解释性分析?

作品简介

 【文章转载自微信公众号:python算法小当家

SHAP(SHapley Additive exPlanations)是一种强大的工具,用于解释复杂模型的预测结果。通过分配给每个特征的贡献值,SHAP能帮助我们深入理解模型决策背后的原因。

SHAP分析Python实现可见往期推文:

顶刊复现:机器学习解释利器—SHAP实战【免费获取】

为什么用R语言进行SHAP分析?R语言不仅拥有丰富的数据分析和可视化包,还支持许多机器学习模型。结合shapviz、ggplot2和xgboost等包,R语言可以轻松实现SHAP值计算和可视化。

本文将说明如何使用R和shapviz进行SHAP分析,从数据读取到XGBoost模型训练,再到生成和保存各种图表。

代码免费获取见文末

数据集介绍

下面的数据集展示了钻石的几个重要特征以及其对应的价格。每一行代表一颗钻石,包含以下列:

  • carat(克拉重量):表示钻石的重量,单位是克拉(carat)。
  • cut(切工):表示钻石的切工质量,常见的评级有 Ideal、Premium、Good、Very Good 等。
  • color(颜色):表示钻石的颜色,使用字母从 D 到 J 来表示,其中 D 是最白的。
  • clarity(净度):表示钻石的净度,常见的评级有 SI2、SI1、VS1、VS2、VVS1、VVS2 等。
  • price(价格):表示钻石的价格,以美元为单位。

代码展示

1、安装和加载必要的R包

#安装并加载必要的包
install.packages("shapviz")
install.packages("gridExtra")
install.packages("readxl")
library(shapviz)
library(ggplot2)
library(xgboost)
library(gridExtra)
library(readxl)

2、读取数据并训练XGBoost模型

接下来,我们选择特征变量,并使用XGBoost训练模型。xgb.DMatrix函数将数据转换为DMatrix格式,xgb.train函数用于训练模型。

# 选择特征变量
x <- c("carat""cut""color""clarity")
# 准备训练数据,将特征变量转换为矩阵格式
dtrain <- xgb.DMatrix(data.matrix(diamonds_data[x]), label = diamonds_data$price)
# 训练XGBoost模型
fit <- xgb.train(params = list(learning_rate = 0.1), data = dtrain, nrounds = 65)

3、特征重要性图

# 绘制特征重要性图表并设置经典的黑白主题
importance_plot <- sv_importance(shp, show_numbers = TRUE) + theme_bw()
# 保存特征重要性图表
ggsave("importance_plot.png", importance_plot, width = 8, height = 6)

这个图展示了不同特征对钻石价格预测的重要性,通过SHAP值进行衡量。从图中可以看出,carat(克拉重量)是最重要的特征,对价格的影响最大,其平均SHAP值高达3172。其次是clarity(净度),平均SHAP值为594。color(颜色)和cut(切工)的影响相对较小,平均SHAP值分别为403和99。这表明在预测钻石价格时,克拉重量起到了决定性的作用,而净度、颜色和切工的影响则相对较弱。

4、依赖图

# 生成依赖图并设置经典的黑白主题
plot1 <- sv_dependence(shp, v = "carat") + theme_bw()
plot2 <- sv_dependence(shp, v = "cut") + theme_bw()
plot3 <- sv_dependence(shp, v = "color") + theme_bw()
plot4 <- sv_dependence(shp, v = "clarity") + theme_bw()
# 保存依赖图
ggsave("dependence_plot_carat.png", plot1, width = 8, height = 6)
ggsave("dependence_plot_cut.png", plot2, width = 8, height = 6)
ggsave("dependence_plot_color.png", plot3, width = 8, height = 6)
ggsave("dependence_plot_clarity.png", plot4, width = 8, height = 6)
# 使用gridExtra包将四个子图排列在一起,设置为两列布局
combined_plot <- grid.arrange(plot1, plot2, plot3, plot4, ncol = 2)
# 保存排列后的图表
ggsave("combined_dependence_plot.png", plot = combined_plot, width = 12, height = 10)

上图展示了不同特征对钻石价格的SHAP值的影响。左上图表明克拉重量(carat)越大,SHAP值越高,说明其对价格的正向影响显著。右上图显示了不同切工(cut)等级对价格的影响,发现Ideal和Premium切工的SHAP值较高,表示对价格有正向贡献。左下图展示了颜色(color)对价格的影响,颜色等级越靠近D,SHAP值越高,表示颜色越白对价格的贡献越大。右下图显示了净度(clarity)对价格的影响,净度等级越高(如IF和VVS),SHAP值越高,对价格的正向贡献越大。整体来看,克拉重量和净度对钻石价格的影响最为显著。


5、Waterfall图和Force图



第一个图显示了一个特定钻石样本的SHAP值分解。克拉重量(carat = 0.28)对价格有负面影响(-4047),净度(clarity = VVS2)和切工(cut = Ideal)对价格有正面影响(分别为+616和+0),整体价格预测低于平均水平。

第二个图展示了另一个钻石样本的SHAP值分解。克拉重量(carat = 1.5)对价格有很大的正面影响(+4733),但颜色(color = J)、净度(clarity = SI2)和切工(cut = Good)对价格有负面影响(分别为-1304、-1017和-358),整体价格预测高于平均水平。


6、beeswarm图


上图展示了各特征对钻石价格的SHAP值的分布情况,使用蜜蜂图形可视化。颜色表示特征值的高低(由低到高分别为紫色到黄色)。克拉重量(carat)对价格有显著的正向影响,高克拉重量的钻石(黄色点)具有更高的SHAP值。净度(clarity)和颜色(color)特征对价格的影响较为复杂,较高的净度和较白的颜色对价格有正面影响,但效果不如克拉重量明显。切工(cut)的影响较小且分布较为集中。总体来看,克拉重量是影响价格最显著的特征。





创作时间: