当前位置: 首页 > news >正文

XGBoost + SHAP 一键生成 10 张出版级模型解释图

XGBoost + SHAP 一键生成 10 张出版级模型解释图

拒绝算法黑盒!XGBoost + SHAP 一键生成 10 张出版级模型解释图

现在跑机器学习,大家最头疼的往往不是怎么把 R² 刷高,而是被导师或者业务方灵魂拷问:“你这个模型为什么会得出这个结果?哪个特征起决定性作用?

像 XGBoost、随机森林这种集成树模型,虽然精度吊打传统回归,但“黑盒”属性太强。这时候,SHAP (SHapley Additive exPlanations) 就是我们最好的破局利器。

今天分享一套我压箱底的 Python 自动化脚本。它不仅能完成 XGBoost 的训练与评估,更核心的是:它能一口气生成 10 张高颜值、高分辨率(400 DPI)的 SHAP 可视化图表(包括小提琴图、热力图、瀑布图、依赖图等),直接满足发 Paper 或做汇报的全部需求。


🛠️ 核心代码逻辑拆解

这套脚本主打一个“端到端”,从数据塞进去到美图吐出来一气呵成。为了方便理解,我们把核心操作拆解开来看看。

1. 中文映射与模型训练

在很多实际业务中(比如做城市规划、经济地理分析),我们的特征变量通常是中文(如“人均GDP”、“交通可达性”)。为了防止绘图时出现乱码,脚本里内置了字段映射字典,并在训练前完成了数据清洗和 XGBoost 拟合:

# 核心特征中文化映射
FEATURE_CN_MAP = {"Feat1": "人均GDP","Feat2": "专利/万人",# ... 其他特征
}
TARGET_CN_NAME = "韧性指数"# XGBoost 模型训练
model = xgb.XGBRegressor(n_estimators=1000, learning_rate=0.05, max_depth=6, random_state=42
)
model.fit(X_train, y_train)

2. SHAP 值的核心计算

模型算完了,接下来就是把模型喂给 SHAP 解释器。这一步是所有可视化的基础,它会计算出每个样本、每个特征对最终预测结果的贡献度(SHAP Value)。

# 实例化 SHAP 解释器并计算测试集的 SHAP 值
explainer = shap.Explainer(model)
shap_values_test = explainer(X_test)
shap_mat = shap_values_test.values# 顺手把特征按照重要性(SHAP绝对值均值)排个序,方便后续画图
feature_order = np.argsort(np.abs(shap_mat).mean(axis=0))[::-1]

3. 出版级图表定制(以热力图为例)

很多直接调 shap.plots 画出来的默认图表,颜色比较暗淡。脚本里我对 Matplotlib 进行了深度的客制化,统一使用了极简、明亮的配色(如 #5DADE2 亮蓝、#C0392B 亮红),并去除了冗余的边框,非常适合直接放进论文里。

以“全样本 SHAP 热力图”的生成为例:

# 提取排序后的数据并设置颜色阈值
heat_data = shap_mat[sample_order][:, top_idx].T
vmax = np.percentile(np.abs(heat_data), 98)# 自定义高颜值热力图
fig_h, ax_h = plt.subplots(figsize=(16, 9), dpi=150)
# 使用 RdBu_r 红蓝渐变色带,清晰对比正负贡献
im_h = ax_h.imshow(heat_data, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax)ax_h.set_title("SHAP Values Heatmap", fontsize=20, pad=12, fontweight="bold")
# 去除多余的网格线,保持画面干净
# ... (详见文末完整代码)

💻 完整源码(拿去即用)

以下是完整的 Python 脚本。运行前请确保安装了 xgboost, shap, pandas, matplotlib, scikit-learn 等依赖。

你只需要把 FILE_PATH 改成你自己的数据路径,调整一下映射字典,运行后在同级目录下就会自动生成 10 张高清美图和一个模型指标评估表(包含 R², RMSE, MAE)。

# -*- coding: utf-8 -*-
"""
功能:XGBoost 模型训练及 10 种出版级 SHAP 可视化出图
特点:高分辨率 (400 DPI)、明亮极简配色、支持中文字段映射
"""import os
import sys
import importlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xgboost as xgb
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_splitif __name__ == "__main__":current_dir = os.path.dirname(os.path.abspath(__file__))sys.path = [p for p in sys.path if os.path.abspath(p or ".") != current_dir]shap = importlib.import_module("shap")# ==================== 参数配置区 ====================
FIG_DPI = 400
TOP_N = 15
RANDOM_HEATMAP_SAMPLES = 20# 数据路径请替换为您自己的 CSV
FILE_PATH = "./data/dataset.csv"
TARGET_COL = "XHM"# 字段映射(为了图表展示更直观)
FEATURE_CN_MAP = {"Feat1": "人均GDP","Feat2": "专利/万人","Feat3": "对外开放度","Feat4": "产业高级化","Feat5": "科技支出占比","Feat6": "交通可达性","Feat7": "普惠金融指数",
}
TARGET_CN_NAME = "韧性指数"# 解决图表中文字体显示问题
plt.rcParams["font.family"] = ["SimSun", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False
# ====================================================# 1. 数据加载与清洗
df = pd.read_csv(FILE_PATH)
df_numeric = df.select_dtypes(include=[np.number])
df_numeric = df_numeric.loc[:, df_numeric.nunique(dropna=True) > 1]
df_numeric = df_numeric.replace([np.inf, -np.inf], np.nan).dropna()if TARGET_COL not in df_numeric.columns:raise ValueError(f"目标列 `{TARGET_COL}` 不在数据中。")X = df_numeric.drop(columns=[TARGET_COL], errors="ignore")
y = df_numeric[TARGET_COL]
X = X.rename(columns=FEATURE_CN_MAP)
y = y.rename(TARGET_CN_NAME)X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=None
)# 2. 模型训练与评估
model = xgb.XGBRegressor(n_estimators=1000, learning_rate=0.05, max_depth=6, random_state=42
)
model.fit(X_train, y_train)y_pred = model.predict(X_test)
test_r2 = r2_score(y_test, y_pred)
test_rmse = np.sqrt(mean_squared_error(y_test, y_pred))
test_mae = mean_absolute_error(y_test, y_pred)metrics_df = pd.DataFrame({"Metric": ["R2", "RMSE", "MAE"],"Value": [test_r2, test_rmse, test_mae],}
)
metrics_df.to_csv("model_test_metrics.csv", index=False, encoding="utf-8-sig")# 3. 计算 SHAP 值
explainer = shap.Explainer(model)
shap_values_test = explainer(X_test)
shap_mat = shap_values_test.valuesfeature_order = np.argsort(np.abs(shap_mat).mean(axis=0))[::-1]
top_n = min(TOP_N, X_test.shape[1])
top_idx = feature_order[:top_n]
top_feature_names = [X_test.columns[i] for i in top_idx]
mean_abs_shap = np.abs(shap_mat).mean(axis=0)
top_importance = mean_abs_shap[top_idx]# ==================== 开始批量绘图 ====================# 1) Violin Plot (小提琴图)
plt.figure(figsize=(12, 8), dpi=150)
shap.summary_plot(shap_values_test,X_test,plot_type="violin",max_display=top_n,color="#5DADE2",show=False,
)
ax_v = plt.gca()
ax_v.set_title("SHAP Value Distribution (Violin Plot)", fontsize=20, pad=14, fontweight="bold")
ax_v.set_xlabel("SHAP Value", fontsize=16)
ax_v.set_ylabel("")
ax_v.tick_params(axis="both", labelsize=13)
ax_v.grid(axis="x", linestyle="--", alpha=0.2)
plt.tight_layout()
plt.savefig("shap_violin.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close()# 2) 全样本热力图
sample_order = np.argsort(np.abs(shap_mat).sum(axis=1))[::-1]
heat_data = shap_mat[sample_order][:, top_idx].T
vmax = np.percentile(np.abs(heat_data), 98)fig_h, ax_h = plt.subplots(figsize=(16, 9), dpi=150)
im_h = ax_h.imshow(heat_data, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax)
ax_h.set_title("SHAP Values Heatmap", fontsize=20, pad=12, fontweight="bold")
ax_h.set_ylabel("Feature", fontsize=14)
ax_h.set_xlabel("Sample Index (sorted by total |SHAP|)", fontsize=14)
ax_h.set_yticks(np.arange(len(top_feature_names)))
ax_h.set_yticklabels(top_feature_names, fontsize=11)
ax_h.set_xticks(np.linspace(0, heat_data.shape[1] - 1, min(6, heat_data.shape[1])).astype(int))
ax_h.tick_params(axis="x", labelsize=10)
cbar_h = fig_h.colorbar(im_h, ax=ax_h, fraction=0.03, pad=0.02)
cbar_h.set_label("SHAP Value", fontsize=13)
cbar_h.ax.tick_params(labelsize=10)
fig_h.tight_layout()
fig_h.savefig("shap_heatmap.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close(fig_h)# 3) 20个随机样本热力图
rng = np.random.default_rng(42)
sample_count = min(RANDOM_HEATMAP_SAMPLES, X_test.shape[0])
rand_idx = np.sort(rng.choice(X_test.shape[0], size=sample_count, replace=False))
top12_idx = feature_order[: min(12, len(feature_order))]
random_heat_data = shap_mat[rand_idx][:, top12_idx]
random_feature_labels = [X_test.columns[i] for i in top12_idx]
random_sample_labels = [f"样本 {i}" for i in rand_idx]
vmax2 = np.percentile(np.abs(random_heat_data), 98)fig_r, ax_r = plt.subplots(figsize=(13, 10), dpi=150)
im_r = ax_r.imshow(random_heat_data, aspect="auto", cmap="RdBu_r", vmin=-vmax2, vmax=vmax2)
ax_r.set_title("SHAP Heatmap - 20 Random Samples", fontsize=20, pad=12, fontweight="bold")
ax_r.set_xlabel("Features", fontsize=14, fontweight="bold")
ax_r.set_ylabel("Samples", fontsize=14, fontweight="bold")
ax_r.set_xticks(np.arange(len(random_feature_labels)))
ax_r.set_xticklabels(random_feature_labels, rotation=40, ha="right", fontsize=11)
ax_r.set_yticks(np.arange(len(random_sample_labels)))
ax_r.set_yticklabels(random_sample_labels, fontsize=10)
cbar_r = fig_r.colorbar(im_r, ax=ax_r, fraction=0.036, pad=0.04)
cbar_r.set_label("SHAP Value", fontsize=13, fontweight="bold")
cbar_r.ax.tick_params(labelsize=10)
fig_r.tight_layout()
fig_r.savefig("shap_heatmap_20samples.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close(fig_r)# 4) Waterfall 图 (单样本解释)
waterfall_idx = min(5, len(shap_values_test) - 1)
plt.figure(figsize=(12, 9), dpi=150)
shap.plots.waterfall(shap_values_test[waterfall_idx], max_display=10, show=False)
ax_w = plt.gca()
ax_w.set_title(f"SHAP Waterfall Plot - Sample {waterfall_idx}", fontsize=20, pad=14, fontweight="bold")
ax_w.tick_params(axis="both", labelsize=12)
plt.tight_layout()
plt.savefig("shap_waterfall_sample5.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close()# 5) 测试集预测散点图
fig_p, ax_p = plt.subplots(figsize=(8, 8), dpi=150)
ax_p.scatter(y_test, y_pred, s=86, color="#2E86C1", alpha=0.82, edgecolor="white", linewidth=0.8)
min_value = min(y_test.min(), y_pred.min())
max_value = max(y_test.max(), y_pred.max())
padding = (max_value - min_value) * 0.08 if max_value > min_value else 0.05
line_min = min_value - padding
line_max = max_value + padding
ax_p.plot([line_min, line_max], [line_min, line_max], color="#C0392B", linewidth=2.0, linestyle="--")
ax_p.set_xlim(line_min, line_max)
ax_p.set_ylim(line_min, line_max)
ax_p.set_title("Test Set Prediction Performance", fontsize=20, pad=14, fontweight="bold")
ax_p.set_xlabel(f"Actual {TARGET_CN_NAME}", fontsize=14)
ax_p.set_ylabel(f"Predicted {TARGET_CN_NAME}", fontsize=14)
ax_p.grid(linestyle="--", alpha=0.25)
ax_p.text(0.05,0.95,f"R² = {test_r2:.4f}\nRMSE = {test_rmse:.4f}\nMAE = {test_mae:.4f}",transform=ax_p.transAxes,va="top",fontsize=13,bbox={"boxstyle": "round,pad=0.35", "facecolor": "white", "edgecolor": "#D0D3D4", "alpha": 0.92},
)
fig_p.tight_layout()
fig_p.savefig("model_prediction_performance.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close(fig_p)# 6) 特征重要性条形图
bar_order = top_idx[::-1]
bar_names = [X_test.columns[i] for i in bar_order]
bar_values = mean_abs_shap[bar_order]fig_b, ax_b = plt.subplots(figsize=(12, 8), dpi=150)
colors = plt.cm.Blues(np.linspace(0.35, 0.95, len(bar_values)))
ax_b.barh(bar_names, bar_values, color=colors, edgecolor="white", linewidth=1.0)
ax_b.set_title("Mean |SHAP| Feature Importance", fontsize=20, pad=14, fontweight="bold")
ax_b.set_xlabel("Mean Absolute SHAP Value", fontsize=14)
ax_b.tick_params(axis="both", labelsize=12)
ax_b.grid(axis="x", linestyle="--", alpha=0.25)
for spine in ["top", "right", "left"]:ax_b.spines[spine].set_visible(False)
for value, name in zip(bar_values, bar_names):ax_b.text(value, name, f" {value:.4f}", va="center", fontsize=10)
fig_b.tight_layout()
fig_b.savefig("shap_feature_importance_bar.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close(fig_b)# 7) Beeswarm (蜂群图)
plt.figure(figsize=(12, 8), dpi=150)
shap.summary_plot(shap_values_test,X_test,plot_type="dot",max_display=top_n,show=False,
)
ax_s = plt.gca()
ax_s.set_title("SHAP Beeswarm Summary", fontsize=20, pad=14, fontweight="bold")
ax_s.set_xlabel("SHAP Value", fontsize=16)
ax_s.tick_params(axis="both", labelsize=12)
plt.tight_layout()
plt.savefig("shap_beeswarm.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close()# 8-9) 依赖图 (Dependence Plots) - 自动画出前两个最重要的特征
def save_dependence_plot(feature_idx, rank):feature_name = X_test.columns[feature_idx]feature_values = X_test.iloc[:, feature_idx]feature_shap_values = shap_mat[:, feature_idx]fig_d, ax_d = plt.subplots(figsize=(10, 7), dpi=150)scatter = ax_d.scatter(feature_values,feature_shap_values,c=feature_values,cmap="coolwarm",s=78,alpha=0.85,edgecolor="white",linewidth=0.7,)ax_d.axhline(0, color="#777777", linewidth=1.2, linestyle="--", alpha=0.7)ax_d.set_title(f"Dependence Plot - {feature_name}", fontsize=18, pad=12, fontweight="bold")ax_d.set_xlabel(feature_name, fontsize=14)ax_d.set_ylabel("SHAP Value", fontsize=14)ax_d.tick_params(axis="both", labelsize=11)ax_d.grid(linestyle="--", alpha=0.22)cbar_d = fig_d.colorbar(scatter, ax=ax_d, fraction=0.045, pad=0.04)cbar_d.set_label("Feature Value", fontsize=12)cbar_d.ax.tick_params(labelsize=10)fig_d.tight_layout()fig_d.savefig(f"shap_dependence_top{rank}.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")plt.close(fig_d)for rank, feature_idx in enumerate(feature_order[: min(2, len(feature_order))], start=1):save_dependence_plot(feature_idx, rank)# 10) 累计贡献比例图
importance_pct = top_importance / top_importance.sum() * 100
cumulative_pct = np.cumsum(importance_pct)fig_c, ax_c = plt.subplots(figsize=(13, 7), dpi=150)
x_pos = np.arange(len(top_feature_names))
ax_c.bar(x_pos, importance_pct, color="#5DADE2", edgecolor="white", linewidth=1.0)
ax_c.set_title("SHAP Contribution Share", fontsize=20, pad=14, fontweight="bold")
ax_c.set_ylabel("Contribution Share (%)", fontsize=14)
ax_c.set_xticks(x_pos)
ax_c.set_xticklabels(top_feature_names, rotation=35, ha="right", fontsize=11)
ax_c.tick_params(axis="y", labelsize=11)
ax_c.grid(axis="y", linestyle="--", alpha=0.25)ax_c2 = ax_c.twinx()
ax_c2.plot(x_pos, cumulative_pct, color="#D35400", marker="o", linewidth=2.6)
ax_c2.set_ylabel("Cumulative Share (%)", fontsize=14)
ax_c2.set_ylim(0, 105)
ax_c2.tick_params(axis="y", labelsize=11)
for spine in ["top"]:ax_c.spines[spine].set_visible(False)ax_c2.spines[spine].set_visible(False)
fig_c.tight_layout()
fig_c.savefig("shap_contribution_share.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close(fig_c)print("模型与图表生成完毕,查看本地文件即可!")

最后成品:
ChatGPT Image 2026年5月31日 13_39_33

有了这套代码,以后只要换个数据集,修改一下 FEATURE_CN_MAP,一套精美的可视化解释图就出来了,直接拉满报告的专业度。欢迎大家在本地跑一跑!

http://www.zskr.cn/news/1434430.html

相关文章:

  • 如何用Untrunc快速修复损坏的MP4视频文件:终极完整指南
  • 终极解决方案:用.NET Windows Desktop Runtime彻底告别Windows应用部署难题
  • 低查重AI写教材大揭秘!高效工具推荐,快速生成优质教材!
  • 彻底解放你的Mac光标:Mousecape自定义鼠标指针完全指南
  • 无锡木木金银回收:滨湖专业的首饰回收选哪家 - LYL仔仔
  • Foresight研究报告【20260013】
  • WebLaTeX:3分钟掌握云端LaTeX写作的终极免费解决方案
  • 上海湘杰仪器仪表:淮安海绵压陷试验机怎么联系 - LYL仔仔
  • 如何用ChineseSubFinder实现影视库全自动中文字幕管理?
  • Linux下手动安装JDK
  • 5分钟解锁游戏性能:DLSS Swapper如何智能管理你的DLSS版本
  • 3个关键技巧解决ODrive电机控制中的常见性能问题
  • 基于74HC系列芯片与L293D的硬件密码锁电机驱动电路设计
  • 如何高效构建12306分布式购票系统:从零到一的完整实战指南
  • Arduino弯曲传感器与Unity交互:打造物理游戏控制器全流程指南
  • AI大模型小白入门必看:收藏这份高效学习指南,拥抱智能未来!
  • 从二极管单向导电到PCB设计:打造电压反接报警器的全流程实战
  • 抖音批量下载工具终极指南:一键获取无水印视频与完整资源
  • 揭秘AI教材写作:低查重AI工具,一键生成逻辑连贯的专业教材!
  • 大庆市窗老大门窗维修:龙凤门窗五金件更换推荐几家公司 - LYL仔仔
  • 2026 全国短视频培训机构十大综合排行榜,十大短视频培训机构最新排名 - 全国职业学校推荐官
  • 快速实现HTML转Word文档的完整指南:html-to-docx终极解决方案
  • 从零开始硬件开发:电路设计、焊接与嵌入式系统入门实践
  • 石家庄略钢商贸:无极专业的镀锌圆管批发选哪家 - LYL仔仔
  • 别再死记硬背三重循环了!用Java手把手带你理解Floyd算法的动态规划本质
  • 实测才敢推 2026 最新降AI率工具测评与推荐 - 降AI小能手
  • ChineseSubFinder:让影视字幕下载像呼吸一样简单
  • 石家庄迪奥回收指南:闲置包包这样出手,省心又划算 - 奢侈品回收测评
  • Linux系统中如何杀死一个进程
  • langchain如何初始化模型?一文详解