Causal Trees/Forests Interpretation with Feature Importance and SHAP Values
[1]:
import pandas as pd
import numpy as np
import multiprocessing as mp
# Seed NumPy's global RNG early; presumably this makes the synthetic_data draws
# below reproducible (assumes synthetic_data uses the global NumPy RNG — TODO confirm).
np.random.seed(42)
from sklearn.model_selection import train_test_split
from sklearn.inspection import permutation_importance
import shap
# NOTE(review): several of the causalml/seaborn imports below are not used in the
# visible cells (e.g. plot_gain, plot_qini, qini_score, timeit, sns) — verify before pruning.
import causalml
from causalml.metrics import plot_gain, plot_qini, qini_score
from causalml.dataset import synthetic_data
from causalml.inference.tree import plot_dist_tree_leaves_values, get_tree_leaves_mask
from causalml.inference.tree import CausalRandomForestRegressor, CausalTreeRegressor
from causalml.inference.tree.utils import timeit
import matplotlib.pyplot as plt
import seaborn as sns
# IPython magic: render inline matplotlib figures at high (retina) resolution.
%config InlineBackend.figure_format = 'retina'
Failed to import duecredit due to No module named 'duecredit'
[2]:
# Record the library versions this notebook was executed with.
# `import importlib` alone does not load the `importlib.metadata` submodule;
# import it explicitly rather than relying on another library having done so.
import importlib.metadata

for libname in ["causalml", "shap"]:
    try:
        print(f"{libname}: {importlib.metadata.version(libname)}")
    except importlib.metadata.PackageNotFoundError:
        # Report a missing package instead of aborting the notebook.
        print(f"{libname}: not installed")
causalml: 0.14.1
shap: 0.42.1
[3]:
# Simulate randomized trial: mode=2
y, X, w, tau, b, e = synthetic_data(mode=2, n=2000, p=10, sigma=3.0)
feature_names = [f'feature_{i}' for i in range(X.shape[1])]
# One frame holding the named covariates followed by the observed outcome,
# the treatment indicator and the simulated true treatment effect.
df = pd.DataFrame(X, columns=feature_names).assign(
    outcome=y,
    treatment=w,
    treatment_effect=tau,
)
[4]:
# Split data to training and testing samples for model validation
df_train, df_test = train_test_split(df, test_size=0.2, random_state=111)
n_train, n_test = df_train.shape[0], df_test.shape[0]
X_train, y_train = df_train[feature_names], df_train['outcome'].values
X_test, y_test = df_test[feature_names], df_test['outcome'].values
treatment_train, treatment_test = df_train['treatment'].values, df_test['treatment'].values
effect_test = df_test['treatment_effect'].values
# One test-set row, kept 2-D (a one-row DataFrame), to explain with SHAP below.
# Select by position: the split shuffles the index, so label 0 is in X_test only
# by chance, and `.loc[[0]]` raises KeyError whenever that row lands in train.
observation = X_test.iloc[[0]]
CausalTreeRegressor
[5]:
# Single causal tree with default hyperparameters, fitted on the raw feature array
# (note: the forest below is fitted on the DataFrame instead).
ctree = CausalTreeRegressor()
ctree.fit(treatment=treatment_train, X=X_train.values, y=y_train)
[5]:
CausalTreeRegressor()
CausalRandomForestRegressor
[6]:
# Causal random forest: 50 trees, each leaf holding at least 200 samples.
# Leave one CPU core free, but never request fewer than 1 job: on a single-CPU
# machine `cpu_count() - 1` is 0, which joblib/sklearn reject as n_jobs.
crforest = CausalRandomForestRegressor(criterion="causal_mse",
                                       min_samples_leaf=200,
                                       control_name=0,
                                       n_estimators=50,
                                       n_jobs=max(1, mp.cpu_count() - 1))
crforest.fit(X=X_train, y=y_train, treatment=treatment_train)
[6]:
CausalRandomForestRegressor(min_samples_leaf=200, n_estimators=50, n_jobs=11)
1. Impurity-based feature importance
[7]:
# Impurity-based (MDI) importances of the single tree and of the forest.
df_importances = pd.DataFrame({
    'tree': ctree.feature_importances_,
    'forest': crforest.feature_importances_,
    'feature': feature_names,
})
# Spread of the per-tree importances inside the forest, drawn as error bars.
forest_std = np.std(
    [estimator.feature_importances_ for estimator in crforest.estimators_],
    axis=0,
)

# One bar chart per model; only the forest has error bars.
for column, title, errors in (
    ('tree', "Causal Tree feature importances", None),
    ('forest', "Causal Forest feature importances", forest_std),
):
    fig, ax = plt.subplots()
    df_importances[column].plot.bar(yerr=errors, ax=ax)
    ax.set_title(title)
    ax.set_ylabel("Mean decrease in impurity")
    ax.set_xticklabels(feature_names, rotation=45)
    plt.show()
2. Permutation-based feature importance
[8]:
# Permutation importance on the held-out test set: the drop in model.score()
# when one feature's values are shuffled, averaged over n_repeats shuffles.
# Each model is scored on the same input type it was fitted on (the tree on a
# NumPy array, the forest on the DataFrame) to avoid feature-name mismatch warnings.
for name, model, X_eval in (('Causal Tree', ctree, X_test.values),
                            ('Causal Forest', crforest, X_test)):
    result = permutation_importance(model, X_eval, y_test,
                                    n_repeats=50,
                                    random_state=0)
    fig, ax = plt.subplots()
    ax.bar(feature_names, result['importances_mean'], yerr=result['importances_std'])
    ax.set_title(f"{name} permutation feature importances")
    # Permutation importance measures a decrease in score, not in impurity.
    ax.set_ylabel("Mean decrease in score")
    # Rotate the existing categorical tick labels; set_xticklabels without
    # fixed ticks triggers matplotlib's FixedFormatter warning.
    ax.tick_params(axis='x', labelrotation=45)
    plt.show()
3. SHAP values
TreeExplainer
Details: https://shap.readthedocs.io/en/latest/generated/shap.TreeExplainer.html#shap.TreeExplainer
[10]:
tree_explainer = shap.TreeExplainer(model=ctree)
# Base (expected) values of the causal tree, one per treatment output:
# index 0 -> Y|X,T=0, index 1 -> Y|X,T=1
tree_explainer.expected_value
[10]:
array([0.93121212, 1.63459276])
[11]:
# Tree Explainer for treatment=0: force plot of the causal tree's T=0 output for `observation`
treatment_idx = 0
shap.initjs()
shap_values = tree_explainer.shap_values(observation)
shap.force_plot(tree_explainer.expected_value[treatment_idx],
                shap_values[treatment_idx],
                observation)
[11]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
[12]:
# Tree Explainer for treatment=1: force plot of the causal tree's T=1 output for `observation`
treatment_idx = 1
tree_explainer = shap.TreeExplainer(ctree)
shap.initjs()
shap_values = tree_explainer.shap_values(observation)
shap.force_plot(tree_explainer.expected_value[treatment_idx],
                shap_values[treatment_idx],
                observation)
[12]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
[13]:
# Causal Forest explainer for treatment=0: force plot of the forest's T=0 output
# for `observation` (index 0 of expected_value / shap_values).
cforest_explainer = shap.TreeExplainer(crforest)
shap.initjs()
shap_values = cforest_explainer.shap_values(observation)
shap.force_plot(cforest_explainer.expected_value[0],
                shap_values[0],
                observation)
[13]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
[14]:
# Causal Forest explainer for treatment=1: force plot of the forest's T=1 output
# for `observation` (index 1 of expected_value / shap_values).
cforest_explainer = shap.TreeExplainer(crforest)
shap.initjs()
shap_values = cforest_explainer.shap_values(observation)
shap.force_plot(cforest_explainer.expected_value[1],
                shap_values[1],
                observation)
[14]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
[15]:
# Dependence of the forest's SHAP values on feature_0 (colored by feature_2),
# once per treatment output. shap_values() returns one array per output, so
# compute it once instead of re-explaining all of X_test on every iteration.
forest_shap_values = cforest_explainer.shap_values(X_test)
for i in [0, 1]:
    print(f"If treatment = {i},i.e. Y_hat|X,T={i}")
    shap.dependence_plot("feature_0",
                         forest_shap_values[i],
                         X_test,
                         interaction_index="feature_2")
If treatment = 0,i.e. Y_hat|X,T=0
If treatment = 1,i.e. Y_hat|X,T=1
[16]:
# Rank features by importance in the forest (sum of |SHAP value| over the test set)
# and draw dependence plots for the most important ones, per treatment output.
n_top_features = 4
# One explainer call yields SHAP values for both outputs; reuse it in the loop.
forest_shap_values = cforest_explainer.shap_values(X_test)
for i in [0, 1]:
    print(f"If treatment = {i},i.e. Y_hat|X,T={i}")
    shap_values = forest_shap_values[i]
    top_inds = np.argsort(-np.sum(np.abs(shap_values), 0))
    # Distinct loop variable so the outer treatment index `i` is not shadowed.
    for feature_idx in top_inds[:n_top_features]:
        shap.dependence_plot(feature_idx, shap_values, X_test)
If treatment = 0,i.e. Y_hat|X,T=0
If treatment = 1,i.e. Y_hat|X,T=1
[ ]: