Causal Trees/Forests Interpretation with Feature Importance and SHAP Values
[1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
[2]:
import pandas as pd
import numpy as np
import multiprocessing as mp
np.random.seed(42)
from sklearn.model_selection import train_test_split
from sklearn.inspection import permutation_importance
import shap
import causalml
from causalml.metrics import plot_gain, plot_qini, qini_score
from causalml.dataset import synthetic_data
from causalml.inference.tree import plot_dist_tree_leaves_values, get_tree_leaves_mask
from causalml.inference.tree import CausalRandomForestRegressor, CausalTreeRegressor
from causalml.inference.tree.utils import timeit
import matplotlib.pyplot as plt
import seaborn as sns
%config InlineBackend.figure_format = 'retina'
[3]:
import importlib.metadata
for libname in ["causalml", "shap"]:
    print(f"{libname}: {importlib.metadata.version(libname)}")
causalml: 0.15.3.dev0
shap: 0.40.1.dev640
[4]:
# Simulate randomized trial: mode=2
y, X, w, tau, b, e = synthetic_data(mode=2, n=2000, p=10, sigma=3.0)
df = pd.DataFrame(X)
feature_names = [f'feature_{i}' for i in range(X.shape[1])]
df.columns = feature_names
df['outcome'] = y
df['treatment'] = w
df['treatment_effect'] = tau
[5]:
# Split the data into training and testing samples for model validation
df_train, df_test = train_test_split(df, test_size=0.2, random_state=111)
n_train, n_test = df_train.shape[0], df_test.shape[0]
X_train, y_train = df_train[feature_names], df_train['outcome'].values
X_test, y_test = df_test[feature_names], df_test['outcome'].values
treatment_train, treatment_test = df_train['treatment'].values, df_test['treatment'].values
effect_test = df_test['treatment_effect'].values
observation = X_test.loc[[0]]
CausalTreeRegressor
[6]:
ctree = CausalTreeRegressor()
ctree.fit(X=X_train.values, y=y_train, treatment=treatment_train)
[6]:
CausalTreeRegressor()
CausalRandomForestRegressor
[7]:
crforest = CausalRandomForestRegressor(criterion="causal_mse",
                                       min_samples_leaf=200,
                                       control_name=0,
                                       n_estimators=50,
                                       n_jobs=mp.cpu_count() - 1)
crforest.fit(X=X_train, y=y_train, treatment=treatment_train)
[7]:
CausalRandomForestRegressor(min_samples_leaf=200, n_estimators=50, n_jobs=7)
1. Impurity-based feature importance
[8]:
df_importances = pd.DataFrame({'tree': ctree.feature_importances_,
                               'forest': crforest.feature_importances_,
                               'feature': feature_names
                               })
forest_std = np.std([tree.feature_importances_ for tree in crforest.estimators_], axis=0)
fig, ax = plt.subplots()
df_importances['tree'].plot.bar(ax=ax)
ax.set_title("Causal Tree feature importances")
ax.set_ylabel("Mean decrease in impurity")
ax.set_xticklabels(feature_names, rotation=45)
plt.show()
fig, ax = plt.subplots()
df_importances['forest'].plot.bar(yerr=forest_std, ax=ax)
ax.set_title("Causal Forest feature importances")
ax.set_ylabel("Mean decrease in impurity")
ax.set_xticklabels(feature_names, rotation=45)
plt.show()


2. Permutation-based feature importance
[9]:
for name, model in zip(('Causal Tree', 'Causal Forest'), (ctree, crforest)):
    imp = permutation_importance(model, X_test, y_test,
                                 n_repeats=50,
                                 random_state=0)
    fig, ax = plt.subplots()
    ax.set_title(f"{name} feature importances")
    ax.set_ylabel("Mean decrease in score (permutation importance)")
    plt.bar(feature_names, imp['importances_mean'], yerr=imp['importances_std'])
    ax.set_xticklabels(feature_names, rotation=45)
    plt.show()


SHAP values
TreeExplainer
Details: https://shap.readthedocs.io/en/latest/generated/shap.TreeExplainer.html#shap.TreeExplainer
[16]:
shap.__version__
[16]:
'0.40.1.dev640'
[17]:
tree_explainer = shap.TreeExplainer(ctree)
# Expected values for treatment=0 and treatment=1, i.e. E[Y|X, T=0] and E[Y|X, T=1]
tree_explainer.expected_value
[17]:
array([0.93121212, 1.63459276])
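From shap's point of view the causal tree is a two-output model (a control prediction and a treatment prediction), which is why expected_value has two entries. Recent shap releases return the values for such multi-output trees as a single array with the output dimension last, rather than a list of per-output arrays, so the per-treatment contributions below are obtained by slicing the last axis. A minimal shape check, assuming this newer shap_values layout:
# Sanity check on the SHAP output layout (assumption: a recent shap release
# that stacks the two treatment arms on the last axis of one array)
sv = tree_explainer.shap_values(X_test)
print(sv.shape)  # expected: (n_test, n_features, 2); slice [:, :, t] for arm t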
[18]:
# Tree Explainer force plot for treatment=0
shap.initjs()
shap_values = tree_explainer.shap_values(observation)
# Slice the last axis to select the contributions for the control arm (treatment=0)
shap.force_plot(tree_explainer.expected_value[0],
                shap_values[:, :, 0],
                observation)
[21]:
# Tree Explainer force plot for treatment=1
shap.initjs()
shap_values = tree_explainer.shap_values(observation)
# Slice the last axis to select the contributions for the treatment arm (treatment=1)
shap.force_plot(tree_explainer.expected_value[1],
                shap_values[:, :, 1],
                observation)
[22]:
# Forest TreeExplainer force plot for treatment=0
cforest_explainer = shap.TreeExplainer(crforest)
shap.initjs()
shap_values = cforest_explainer.shap_values(observation)
shap.force_plot(cforest_explainer.expected_value[0],
                shap_values[:, :, 0],
                observation)
[23]:
# Forest TreeExplainer force plot for treatment=1
shap.initjs()
shap_values = cforest_explainer.shap_values(observation)
shap.force_plot(cforest_explainer.expected_value[1],
                shap_values[:, :, 1],
                observation)
[24]:
for i in [0, 1]:
    print(f"If treatment = {i}, i.e. Y_hat|X, T={i}")
    # Slice the last axis to select the SHAP values for the given treatment arm
    shap.dependence_plot("feature_0",
                         cforest_explainer.shap_values(X_test)[:, :, i],
                         X_test,
                         interaction_index="feature_2")
[25]:
# Sort the feature indexes by their importance in the model
# (sum of SHAP value magnitudes over the validation dataset)
for i in [0, 1]:
    print(f"If treatment = {i}, i.e. Y_hat|X, T={i}")
    shap_values = cforest_explainer.shap_values(X_test)[:, :, i]
    top_inds = np.argsort(-np.sum(np.abs(shap_values), 0))
    # Make SHAP plots of the four most important features
    for rank in range(4):
        shap.dependence_plot(top_inds[rank], shap_values, X_test)
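Since SHAP values are additive for each output, the per-feature difference between the two arms decomposes the model's estimated treatment effect (the difference between the treatment and control predictions). The sketch below illustrates that idea under the same (n_samples, n_features, 2) layout assumption; the variable names are illustrative rather than part of the causalml or shap APIs:
# SHAP attribution of the estimated treatment effect (CATE):
# the per-feature differences, plus the difference of expected values,
# sum to the predicted uplift for each observation.
sv_forest = cforest_explainer.shap_values(X_test)    # assumed shape: (n_test, n_features, 2)
cate_shap = sv_forest[:, :, 1] - sv_forest[:, :, 0]  # per-feature uplift attribution
cate_base = (cforest_explainer.expected_value[1]
             - cforest_explainer.expected_value[0])  # baseline uplift
shap.summary_plot(cate_shap, X_test, feature_names=feature_names)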