Meta-Learner Benchmarks with Semi-synthetic Data, following Schuler, A., Jung, K., Tibshirani, R., Hastie, T., and Shah, N., "Synth-validation: Selecting the best causal inference method for a given dataset" (2017)

This notebook compares the X-, R-, and T-learners (plus a naive difference-in-means baseline) on the IHDP dataset, using the constrained gradient boosting semi-synthetic framework described in Schuler, A., Jung, K., Tibshirani, R., Hastie, T., and Shah, N. (2017).

[1]:
import numpy as np
import pandas as pd

from causalml.inference.meta import BaseTRegressor
from causalml.inference.meta import BaseXRegressor
from causalml.inference.meta import BaseRRegressor

from causalml.dataset.semiSynthetic import SemiSynthDataGenerator

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.base import clone

from sklearn.linear_model import LogisticRegression, Lasso

from copy import deepcopy

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
/Users/iyarlin/.pyenv/versions/3.11.6/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Failed to import duecredit due to No module named 'duecredit'
[2]:
# Record the installed causalml version so the results below can be reproduced.
# The submodule is imported explicitly: `import importlib` on its own does not
# guarantee that `importlib.metadata` has been loaded as an attribute.
import importlib.metadata
print(importlib.metadata.version('causalml'))
0.15.5
[3]:
# Read the IHDP replication file. It is loaded with header=None, so the column
# names are attached by hand: five outcome/treatment columns followed by
# 25 covariates named '0'..'24'.
data = pd.read_csv('data/ihdp_npci_8.csv', header=None)
cols = ["treatment", "y_factual", "y_cfactual", "mu0", "mu1"]
cols.extend(str(i) for i in range(25))
data.columns = cols

# Treatment indicator, observed outcome, and the first five covariates
# ('0'..'4'), which are the only features used in this notebook.
w = data["treatment"]
y = data["y_factual"]
X = data[list(map(str, range(5)))]
[4]:
class NaiveLearner():
    """Covariate-free baseline: difference in mean outcome between arms.

    Implements the three methods the experiment loops in this notebook call
    on a learner (``fit``, ``predict``, ``estimate_ate``). Covariates ``X``
    are accepted for interface compatibility but never used, so every unit
    receives the same predicted effect.
    """

    def __init__(self):
        # Set by ``fit``; ``None`` marks the learner as not yet fitted.
        self.ate = None

    @staticmethod
    def _diff_in_means(treatment, y):
        """Return mean(y | treatment == 1) - mean(y | treatment == 0)."""
        return y[treatment == 1].mean() - y[treatment == 0].mean()

    def fit(self, X, treatment, y):
        """Store the difference in mean outcomes as ``self.ate``.

        Args:
            X: covariates (ignored).
            treatment: 0/1 treatment indicator, usable as a mask on ``y``.
            y: observed outcomes.
        """
        self.ate = self._diff_in_means(treatment, y)

    def predict(self, X, p=None):
        """Return the fitted ATE repeated once per row of ``X``.

        Args:
            X: covariates; only ``len(X)`` is used.
            p: propensity scores (ignored). Optional, so the learner can be
                called with or without them.

        Returns:
            numpy array of length ``len(X)`` filled with ``self.ate``.

        Raises:
            AttributeError: if ``fit`` has not been called yet.
        """
        if self.ate is None:
            raise AttributeError(
                "NaiveLearner has no fitted 'ate'; call fit() before predict()")
        return np.repeat(self.ate, len(X))

    def estimate_ate(self, X, treatment, y):
        """Return ``[ate]`` computed directly from ``treatment`` and ``y``.

        The result is a one-element list so callers can take element ``[0]``;
        no confidence interval is computed. Does not modify ``self.ate``.
        """
        return [self._diff_in_means(treatment, y)]
[5]:
def run_experiments(X, w, y, learner_dict, propensity_learner, K=10, n=None, **data_generator_kwargs):
    """Score each learner's effect predictions on held-out semi-synthetic data.

    A ``SemiSynthDataGenerator`` is built from ``data_generator_kwargs``,
    fitted on ``(X, w, y)`` and asked for datasets via ``generate(K=K, n=n)``.
    Every generated dataset is split 80/20 with a fixed seed. A clone of
    ``propensity_learner`` is fitted on the training part and supplies
    propensity scores for the test part. Each learner is deep-copied, fitted
    on the training part, and its test predictions are compared with the
    test ``tau_i`` values using mean squared error.

    Args:
        X, w, y: covariates, treatment and outcome passed to the generator's
            ``fit``.
        learner_dict: mapping of display name -> learner exposing
            ``fit(X=, treatment=, y=)`` and ``predict(X, p=)``.
        propensity_learner: classifier that can be cloned and exposes
            ``fit`` / ``predict_proba``.
        K, n: forwarded to the generator's ``generate``.
        **data_generator_kwargs: forwarded to ``SemiSynthDataGenerator``.

    Returns:
        DataFrame with columns ``q``, ``k``, ``learner``, ``pehe``; one row
        per (outer index, inner index, learner).
    """
    generator = SemiSynthDataGenerator(**data_generator_kwargs)
    generator.fit(X, w, y)
    synthetic = generator.generate(K=K, n=n)

    # Generated frames are read with a fixed schema: covariate columns named
    # '0'..'4' plus 'w', 'y' and 'tau_i'.
    covariate_names = [str(idx) for idx in range(5)]
    records = []

    for q in range(len(synthetic)):
        for k in range(len(synthetic[q])):
            frame = synthetic[q][k]

            # train_test_split yields (train, test) pairs in argument order:
            # X -> 0,1   w -> 2,3   y -> 4,5   tau_i -> 6,7
            parts = train_test_split(
                frame[covariate_names], frame['w'], frame['y'], frame['tau_i'],
                test_size=0.2, random_state=111)
            X_train, X_test = parts[0], parts[1]
            w_train, y_train, tau_test = parts[2], parts[4], parts[7]

            # Propensity scores are only used at prediction time; the
            # learners themselves are fitted without them.
            em = clone(propensity_learner)
            em.fit(X_train, w_train)
            e_hat_test = em.predict_proba(X_test)[:, 1]

            for name, template in learner_dict.items():
                # Work on a copy so the caller's learner objects stay unfitted.
                model = deepcopy(template)
                model.fit(X=X_train, treatment=w_train, y=y_train)
                hat_tau = model.predict(X_test, p=e_hat_test)
                records.append({
                    'q': q,
                    'k': k,
                    'learner': name,
                    'pehe': mean_squared_error(tau_test, hat_tau),
                })

    return pd.DataFrame(records, columns=['q', 'k', 'learner', 'pehe'])

Lasso-based experiments

[6]:
# Naive baseline plus the three causalml meta-learners, each wrapping a
# default-configured Lasso as its base regressor.
learner_dict = {
    'Naive-Learner': NaiveLearner(),
    **{
        label: meta_learner(learner=Lasso())
        for label, meta_learner in [
            ('T-Learner', BaseTRegressor),
            ('X-Learner', BaseXRegressor),
            ('R-Learner', BaseRRegressor),
        ]
    },
}

# L1-penalised logistic regression as the propensity model; Q and B are
# forwarded to the data generator, n to its generate() call.
propensity_learner = LogisticRegression(solver='liblinear', penalty='l1')
df_res_lasso = run_experiments(
    X, w, y,
    learner_dict=learner_dict,
    propensity_learner=propensity_learner,
    n=10000,
    Q=5,
    B=1,
)
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
[7]:
# Distribution of held-out MSE per learner across all generated datasets
# (outliers hidden so the boxes stay readable).
sns.boxplot(
    data=df_res_lasso, x='learner', y='pehe', showfliers=False, linewidth=1
).set(ylabel='PEHE (MSE)', xlabel='', title='All experiments (Lasso)')
plt.show()
../_images/examples_benchmark_semi_synthetic_simulation_studies_8_0.png

The paper discusses benchmarking ATE estimation, so let’s do that as well:

[8]:
def run_experiments2(X, w, y, learner_dict, K=10, n = None, **data_generator_kwargs):
    """Measure each learner's absolute ATE error on semi-synthetic data.

    A ``SemiSynthDataGenerator`` is built from ``data_generator_kwargs``,
    fitted on ``(X, w, y)`` and asked for datasets via ``generate(K=K, n=n)``.
    Every generated dataset is split 80/20 with a fixed seed. Each learner is
    deep-copied and its ``estimate_ate`` is called on the training part; the
    first element of the result is compared with the mean of ``tau_i`` over
    the test part.

    Args:
        X, w, y: covariates, treatment and outcome passed to the generator's
            ``fit``.
        learner_dict: mapping of display name -> learner exposing
            ``estimate_ate(X=, treatment=, y=)`` whose first element is the
            point estimate.
        K, n: forwarded to the generator's ``generate``.
        **data_generator_kwargs: forwarded to ``SemiSynthDataGenerator``.

    Returns:
        DataFrame with columns ``q``, ``k``, ``learner``, ``ate_diff``,
        ``true_ate``; one row per (outer index, inner index, learner).
    """
    generator = SemiSynthDataGenerator(**data_generator_kwargs)
    generator.fit(X, w, y)
    synthetic = generator.generate(K=K, n=n)

    # Generated frames are read with a fixed schema: covariate columns named
    # '0'..'4' plus 'w', 'y' and 'tau_i'.
    covariate_names = [str(idx) for idx in range(5)]
    rows = []

    for q in range(len(synthetic)):
        for k in range(len(synthetic[q])):
            frame = synthetic[q][k]

            # train_test_split yields (train, test) pairs in argument order:
            # X -> 0,1   w -> 2,3   y -> 4,5   tau_i -> 6,7
            parts = train_test_split(
                frame[covariate_names], frame['w'], frame['y'], frame['tau_i'],
                test_size=0.2, random_state=111)
            X_train, w_train, y_train = parts[0], parts[2], parts[4]
            tau_test = parts[7]

            # NOTE(review): the reference ATE is the mean effect on the test
            # split while the estimate is computed on the training split;
            # confirm this is intended rather than using the training units.
            true_ate = tau_test.mean()

            for name, template in learner_dict.items():
                # Work on a copy so the caller's learner objects stay unfitted.
                model = deepcopy(template)
                estimate = model.estimate_ate(X=X_train, treatment=w_train, y=y_train)
                ate_hat = float(estimate[0])
                rows.append({
                    'q': q,
                    'k': k,
                    'learner': name,
                    'ate_diff': np.abs(ate_hat - true_ate),
                    'true_ate': true_ate,
                })

    return pd.DataFrame(rows, columns=['q', 'k', 'learner', 'ate_diff', 'true_ate'])
[9]:
# Fresh set of learners for the ATE benchmark: the naive baseline plus the
# three causalml meta-learners, each wrapping a default-configured Lasso.
learner_dict = {
    'Naive-Learner': NaiveLearner(),
    **{
        label: meta_learner(learner=Lasso())
        for label, meta_learner in [
            ('T-Learner', BaseTRegressor),
            ('X-Learner', BaseXRegressor),
            ('R-Learner', BaseRRegressor),
        ]
    },
}

# Q and B are forwarded to the data generator, n to its generate() call.
df_res_lasso2 = run_experiments2(
    X, w, y,
    learner_dict=learner_dict,
    n=10000,
    Q=5,
    B=1,
)
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
Failed to import duecredit due to No module named 'duecredit'
[10]:
# Distribution of absolute ATE error per learner across all generated
# datasets (outliers hidden so the boxes stay readable).
sns.boxplot(
    data=df_res_lasso2, x='learner', y='ate_diff', showfliers=False, linewidth=1
).set(ylabel='ATE absolute diff', xlabel='', title='All experiments (Lasso)')
plt.show()
../_images/examples_benchmark_semi_synthetic_simulation_studies_12_0.png