Causal Trees/Forests Interpretation with Feature Importance and SHAP Values

[1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
[2]:
import pandas as pd
import numpy as np
import multiprocessing as mp
np.random.seed(42)

from sklearn.model_selection import train_test_split
from sklearn.inspection import permutation_importance
import shap

import causalml
from causalml.metrics import plot_gain, plot_qini, qini_score
from causalml.dataset import synthetic_data
from causalml.inference.tree import plot_dist_tree_leaves_values, get_tree_leaves_mask
from causalml.inference.tree import CausalRandomForestRegressor, CausalTreeRegressor
from causalml.inference.tree.utils import timeit

import matplotlib.pyplot as plt
import seaborn as sns

%config InlineBackend.figure_format = 'retina'
Failed to import duecredit due to No module named 'duecredit'
[3]:
import importlib.metadata
for libname in ["causalml", "shap"]:
    print(f"{libname}: {importlib.metadata.version(libname)}")
causalml: 0.15.3.dev0
shap: 0.40.1.dev640
[4]:
# Simulate randomized trial: mode=2
y, X, w, tau, b, e = synthetic_data(mode=2, n=2000, p=10, sigma=3.0)

df = pd.DataFrame(X)
feature_names = [f'feature_{i}' for i in range(X.shape[1])]
df.columns = feature_names
df['outcome'] = y
df['treatment'] = w
df['treatment_effect'] = tau
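
synthetic_data with mode=2 returns the outcome y, the feature matrix X, the treatment indicator w, the true treatment effect tau, the expected outcome b, and the propensity e. Since tau is known here, a quick sanity check of the assembled frame is possible (a minimal sketch, not part of the original notebook):

# Inspect the simulated data; tau is the true CATE, so its mean is the true ATE
print(df.head())
print(f"True ATE: {tau.mean():.3f}, treated fraction: {w.mean():.3f}")
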
[5]:
# Split the data into training and test samples for model validation
df_train, df_test = train_test_split(df, test_size=0.2, random_state=111)
n_train, n_test = df_train.shape[0], df_test.shape[0]

X_train, y_train = df_train[feature_names], df_train['outcome'].values
X_test, y_test = df_test[feature_names], df_test['outcome'].values
treatment_train, treatment_test = df_train['treatment'].values, df_test['treatment'].values
effect_test = df_test['treatment_effect'].values

observation = X_test.loc[[0]]
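
Note that X_test.loc[[0]] selects by the original row label and only works because label 0 happens to fall into the test split with this random_state; a positional selection is less fragile (a hedged alternative, not what the cells below use):

# Select the first test row by position instead of by label
observation = X_test.iloc[[0]]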

CausalTreeRegressor

[6]:
ctree = CausalTreeRegressor()
ctree.fit(X=X_train.values, y=y_train, treatment=treatment_train)
[6]:
CausalTreeRegressor()

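Since the simulated treatment effect is known, the fitted tree's CATE estimates can be checked directly against it (a minimal sketch; in causalml, predict on a CausalTreeRegressor returns the estimated treatment effect):

# Estimate CATE on the test set and compare against the simulated ground truth
tree_cate = ctree.predict(X_test.values)
print(f"Causal Tree CATE RMSE: {np.sqrt(np.mean((tree_cate - effect_test) ** 2)):.3f}")
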
CausalRandomForestRegressor

[7]:
crforest = CausalRandomForestRegressor(criterion="causal_mse",
                                       min_samples_leaf=200,
                                       control_name=0,
                                       n_estimators=50,
                                       n_jobs=mp.cpu_count() - 1)
crforest.fit(X=X_train, y=y_train, treatment=treatment_train)
[7]:
CausalRandomForestRegressor(min_samples_leaf=200, n_estimators=50, n_jobs=7)

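The same check for the forest, plus the uplift metrics imported at the top: plot_gain and qini_score take a frame with model CATE columns alongside the outcome and treatment columns (a hedged sketch — the outcome_col/treatment_col/treatment_effect_col argument names are assumed from the causalml.metrics convention):

# Collect CATE estimates from both models together with outcome and treatment
df_preds = pd.DataFrame({
    'causal_tree': ctree.predict(X_test.values),
    'causal_forest': crforest.predict(X_test.values),
    'outcome': y_test,
    'treatment': treatment_test,
    'treatment_effect': effect_test,
})

# Cumulative gain curves and Qini scores on the test set
plot_gain(df_preds, outcome_col='outcome', treatment_col='treatment',
          treatment_effect_col='treatment_effect')
print(qini_score(df_preds, outcome_col='outcome', treatment_col='treatment',
                 treatment_effect_col='treatment_effect'))
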
1. Impurity-based feature importance

[8]:
df_importances = pd.DataFrame({'tree': ctree.feature_importances_,
                               'forest': crforest.feature_importances_,
                               'feature': feature_names
                              })
forest_std = np.std([tree.feature_importances_ for tree in crforest.estimators_], axis=0)

fig, ax = plt.subplots()
df_importances['tree'].plot.bar(ax=ax)
ax.set_title("Causal Tree feature importances")
ax.set_ylabel("Mean decrease in impurity")
ax.set_xticklabels(feature_names, rotation=45)
plt.show()

fig, ax = plt.subplots()
df_importances['forest'].plot.bar(yerr=forest_std, ax=ax)
ax.set_title("Causal Forest feature importances")
ax.set_ylabel("Mean decrease in impurity")
ax.set_xticklabels(feature_names, rotation=45)
plt.show()
../_images/examples_causal_trees_interpretation_11_0.png
../_images/examples_causal_trees_interpretation_11_1.png
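
The same numbers in tabular form, ranked by the forest importances (plain pandas, nothing model-specific):

# Rank features by the forest's impurity-based importance
print(df_importances.set_index('feature')[['tree', 'forest']]
      .sort_values('forest', ascending=False).round(3))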

2. Permutation-based feature importance

[9]:
for name, model in zip(('Causal Tree', 'Causal Forest'), (ctree, crforest)):

    imp = permutation_importance(model, X_test, y_test,
                                 n_repeats=50,
                                 random_state=0)

    fig, ax = plt.subplots()
    ax.set_title(f"{name} feature importances")
    ax.set_ylabel("Mean decrease in impurity")
    plt.bar(feature_names, imp['importances_mean'], yerr=imp['importances_std'])
    ax.set_xticklabels(feature_names, rotation=45)
    plt.show()
../_images/examples_causal_trees_interpretation_13_0.png
../_images/examples_causal_trees_interpretation_13_1.png
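
permutation_importance scores the models on the observed outcome, but with synthetic data the true CATE is also available. A hand-rolled variant that permutes one feature at a time and tracks the increase in CATE error for the forest (an illustrative sketch only, not a causalml API):

# Manual permutation importance measured against the known treatment effect
rng = np.random.default_rng(0)
base_mse = np.mean((crforest.predict(X_test.values) - effect_test) ** 2)
for j, name in enumerate(feature_names):
    X_perm = X_test.values.copy()
    X_perm[:, j] = rng.permutation(X_perm[:, j])
    perm_mse = np.mean((crforest.predict(X_perm) - effect_test) ** 2)
    print(f"{name}: +{perm_mse - base_mse:.3f} MSE when permuted")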

SHAP values

TreeExplainer

Details: https://shap.readthedocs.io/en/latest/generated/shap.TreeExplainer.html#shap.TreeExplainer

[16]:
shap.__version__
[16]:
'0.40.1.dev640'
[17]:
tree_explainer = shap.TreeExplainer(ctree)
# Expected values for treatment=0 and treatment=1, i.e. E[Y|X,T=0] and E[Y|X,T=1]
tree_explainer.expected_value
[17]:
array([0.93121212, 1.63459276])
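
These base values are the model's average outputs for the control and treatment arms, so they should sit close to the arm-level mean outcomes in the training data (a rough sanity check, not an exact identity):

# Compare SHAP base values with the average training outcome per arm
print("Control arm mean outcome:  ", y_train[treatment_train == 0].mean())
print("Treatment arm mean outcome:", y_train[treatment_train == 1].mean())
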
[18]:
# Tree Explainer for treatment=0
shap.initjs()
shap_values = tree_explainer.shap_values(observation)
shap.force_plot(tree_explainer.expected_value[0],
                shap_values[0],
                observation)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[18], line 4
      2 shap.initjs()
      3 shap_values = tree_explainer.shap_values(observation)
----> 4 shap.force_plot(tree_explainer.expected_value[0],
      5                 shap_values[0],
      6                 observation)

File ~/dev/shap/shap/plots/_force.py:217, in force(base_value, shap_values, features, feature_names, out_names, link, plot_cmap, matplotlib, show, figsize, ordering_keys, ordering_keys_time_format, text_rotation, contribution_threshold)
    215     display_features = ["" for i in range(len(feature_names))]
    216 else:
--> 217     display_features = features[k, :]
    219 instance = Instance(np.ones((1, len(feature_names))), display_features)
    220 e = AdditiveExplanation(
    221     base_value,
    222     np.sum(shap_values[k, :]) + base_value,
   (...)
    228     DenseData(np.ones((1, len(feature_names))), list(feature_names))
    229 )

IndexError: index 1 is out of bounds for axis 0 with size 1
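
The IndexError here (and in the treatment=1 cell below) suggests this shap build returns a single array of shape (n_samples, n_features, n_outputs) instead of a list with one array per output. Under that assumption, slicing the last axis recovers the per-treatment explanations (a hedged workaround, not part of the original notebook):

# Workaround for array-style multi-output SHAP values:
# take sample 0 and slice the output axis for each treatment arm.
from IPython.display import display

shap_values = tree_explainer.shap_values(observation)  # assumed shape (1, n_features, 2)
for t in (0, 1):
    display(shap.force_plot(tree_explainer.expected_value[t],
                            shap_values[0, :, t],
                            observation))
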
[21]:
# Tree Explainer for treatment=1
tree_explainer = shap.TreeExplainer(ctree)
shap.initjs()
shap_values = tree_explainer.shap_values(observation)
shap.force_plot(tree_explainer.expected_value[1],
                shap_values[1],
                observation)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[21], line 6
      3 shap.initjs()
      4 shap_values = tree_explainer.shap_values(observation)
      5 shap.force_plot(tree_explainer.expected_value[1],
----> 6                 shap_values[1],
      7                 observation)

IndexError: index 1 is out of bounds for axis 0 with size 1
[22]:
# TreeExplainer for the causal forest, treatment=0
cforest_explainer = shap.TreeExplainer(crforest)
shap.initjs()
shap_values = cforest_explainer.shap_values(observation)
shap.force_plot(cforest_explainer.expected_value[0],
                shap_values[0],
                observation)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[22], line 5
      3 shap.initjs()
      4 shap_values = cforest_explainer.shap_values(observation)
----> 5 shap.force_plot(cforest_explainer.expected_value[0],
      6                 shap_values[0],
      7                 observation)

File ~/dev/shap/shap/plots/_force.py:217, in force(base_value, shap_values, features, feature_names, out_names, link, plot_cmap, matplotlib, show, figsize, ordering_keys, ordering_keys_time_format, text_rotation, contribution_threshold)
    215     display_features = ["" for i in range(len(feature_names))]
    216 else:
--> 217     display_features = features[k, :]
    219 instance = Instance(np.ones((1, len(feature_names))), display_features)
    220 e = AdditiveExplanation(
    221     base_value,
    222     np.sum(shap_values[k, :]) + base_value,
   (...)
    228     DenseData(np.ones((1, len(feature_names))), list(feature_names))
    229 )

IndexError: index 1 is out of bounds for axis 0 with size 1
[23]:
# TreeExplainer for the causal forest, treatment=1
cforest_explainer = shap.TreeExplainer(crforest)
shap.initjs()
shap_values = cforest_explainer.shap_values(observation)
shap.force_plot(cforest_explainer.expected_value[1],
                shap_values[1],
                observation)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[23], line 6
      3 shap.initjs()
      4 shap_values = cforest_explainer.shap_values(observation)
      5 shap.force_plot(cforest_explainer.expected_value[1],
----> 6                 shap_values[1],
      7                 observation)

IndexError: index 1 is out of bounds for axis 0 with size 1
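
The two forest force plots fail for the same reason; with the same assumption about the array layout, the last-axis slice applies here as well:

# Same last-axis slicing for the causal forest explainer
from IPython.display import display

shap_values = cforest_explainer.shap_values(observation)  # assumed shape (1, n_features, 2)
for t in (0, 1):
    display(shap.force_plot(cforest_explainer.expected_value[t],
                            shap_values[0, :, t],
                            observation))
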
[24]:
for i in [0, 1]:
    print(f"If treatment = {i},i.e. Y_hat|X,T={i}")
    shap.dependence_plot("feature_0",
                         cforest_explainer.shap_values(X_test)[i],
                         X_test,
                         interaction_index="feature_2")
If treatment = 0,i.e. Y_hat|X,T=0
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[24], line 3
      1 for i in [0, 1]:
      2     print(f"If treatment = {i},i.e. Y_hat|X,T={i}")
----> 3     shap.dependence_plot("feature_0", 
      4                          cforest_explainer.shap_values(X_test)[i], 
      5                          X_test, 
      6                          interaction_index="feature_2")

File ~/dev/shap/shap/plots/_scatter.py:612, in dependence_legacy(ind, shap_values, features, feature_names, display_features, interaction_index, color, axis_color, cmap, dot_size, x_jitter, alpha, title, xmin, xmax, ax, show, ymin, ymax)
    609         pl.show()
    610     return
--> 612 assert shap_values.shape[0] == features.shape[0], \
    613     "'shap_values' and 'features' values must have the same number of rows!"
    614 assert shap_values.shape[1] == features.shape[1], \
    615     "'shap_values' must have the same number of columns as 'features'!"
    617 # get both the raw and display feature values

AssertionError: 'shap_values' and 'features' values must have the same number of rows!
../_images/examples_causal_trees_interpretation_21_2.png
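
With the array layout assumed above, the per-treatment SHAP matrix for the whole test set comes from slicing the last axis rather than indexing the first (hedged sketch):

# Per-treatment dependence plots with array-style SHAP values
shap_values_all = cforest_explainer.shap_values(X_test)  # assumed shape (n_test, n_features, 2)
for t in (0, 1):
    print(f"If treatment = {t}, i.e. Y_hat|X,T={t}")
    shap.dependence_plot("feature_0",
                         shap_values_all[:, :, t],
                         X_test,
                         interaction_index="feature_2")
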
[25]:
# Sort the feature indexes by their importance in the model
# (sum of SHAP value magnitudes over the test dataset)


for i in [0, 1]:
    print(f"If treatment = {i},i.e. Y_hat|X,T={i}")
    shap_values = cforest_explainer.shap_values(X_test)[i]
    top_inds = np.argsort(-np.sum(np.abs(shap_values), 0))

    # Make SHAP dependence plots of the four most important features
    for i in range(4):
        shap.dependence_plot(top_inds[i], shap_values, X_test)
If treatment = 0,i.e. Y_hat|X,T=0
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[25], line 12
     10 # Make SHAP plots of the three most important features
     11 for i in range(4):
---> 12     shap.dependence_plot(top_inds[i], shap_values, X_test)

File ~/dev/shap/shap/plots/_scatter.py:572, in dependence_legacy(ind, shap_values, features, feature_names, display_features, interaction_index, color, axis_color, cmap, dot_size, x_jitter, alpha, title, xmin, xmax, ax, show, ymin, ymax)
    570 if not hasattr(ind, "__len__"):
    571     if interaction_index == "auto":
--> 572         interaction_index = approximate_interactions(ind, shap_values, features)[0]
    573     interaction_index = convert_name(interaction_index, shap_values, feature_names)
    574 categorical_interaction = False

File ~/dev/shap/shap/utils/_general.py:123, in approximate_interactions(index, shap_values, X, feature_names)
    121 x = X[inds, index]
    122 srt = np.argsort(x)
--> 123 shap_ref = shap_values[inds, index]
    124 shap_ref = shap_ref[srt]
    125 inc = max(min(int(len(x) / 10.0), 50), 1)

IndexError: index 10 is out of bounds for axis 0 with size 10
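
The same slicing applies here; using a separate inner loop variable also avoids shadowing the treatment index (hedged sketch):

# Rank features by total |SHAP| per treatment arm, then plot the top four
shap_values_all = cforest_explainer.shap_values(X_test)  # assumed shape (n_test, n_features, 2)
for t in (0, 1):
    print(f"If treatment = {t}, i.e. Y_hat|X,T={t}")
    sv = shap_values_all[:, :, t]
    top_inds = np.argsort(-np.sum(np.abs(sv), axis=0))
    for k in range(4):
        shap.dependence_plot(top_inds[k], sv, X_test)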