"""Helpers for benchmarking treatment effect estimators on synthetic data:
single-simulation predictions, multi-simulation summaries (Abs % Error of ATE,
MSE, KL Divergence), AUUC scores, and the accompanying plots."""
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error as mse
from sklearn.metrics import auc
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from xgboost import XGBRegressor
from scipy.stats import entropy
import warnings
from causalml.inference.meta import (
BaseXRegressor,
BaseRRegressor,
BaseSRegressor,
BaseTRegressor,
)
from causalml.inference.tree.causal.causaltree import CausalTreeRegressor
from causalml.propensity import ElasticNetPropensityModel
from causalml.metrics import plot_gain, get_cumgain
plt.style.use("fivethirtyeight")
warnings.filterwarnings("ignore")
KEY_GENERATED_DATA = "generated_data"
KEY_ACTUAL = "Actuals"
RANDOM_SEED = 42
def get_synthetic_preds(synthetic_data_func, n=1000, estimators={}):
"""Generate predictions for synthetic data using specified function (single simulation)
Args:
synthetic_data_func (function): synthetic data generation function
n (int, optional): number of samples
estimators (dict of object): dict of names and objects of treatment effect estimators
Returns:
(dict): dict of the actual and estimates of treatment effects
"""
y, X, w, tau, b, e = synthetic_data_func(n=n)
preds_dict = {}
preds_dict[KEY_ACTUAL] = tau
preds_dict[KEY_GENERATED_DATA] = {
"y": y,
"X": X,
"w": w,
"tau": tau,
"b": b,
"e": e,
}
    # Estimate the propensity score p_hat, since the true propensity e would not be observed in real life
p_model = ElasticNetPropensityModel()
p_hat = p_model.fit_predict(X, w)
if estimators:
for name, learner in estimators.items():
try:
preds_dict[name] = learner.fit_predict(
X=X, treatment=w, y=y, p=p_hat
).flatten()
            except TypeError:
                # fall back for estimators whose fit_predict() does not take a propensity argument
                preds_dict[name] = learner.fit_predict(X=X, treatment=w, y=y).flatten()
else:
for base_learner, label_l in zip(
[BaseSRegressor, BaseTRegressor, BaseXRegressor, BaseRRegressor],
["S", "T", "X", "R"],
):
for model, label_m in zip([LinearRegression, XGBRegressor], ["LR", "XGB"]):
learner = base_learner(model())
model_name = "{} Learner ({})".format(label_l, label_m)
try:
preds_dict[model_name] = learner.fit_predict(
X=X, treatment=w, y=y, p=p_hat
).flatten()
                except TypeError:
                    # fall back for learners that do not take a propensity argument
preds_dict[model_name] = learner.fit_predict(
X=X, treatment=w, y=y
).flatten()
    # add a single causal tree as a non-meta-learner baseline
    learner = CausalTreeRegressor(random_state=RANDOM_SEED)
    preds_dict["Causal Tree"] = learner.fit_predict(X=X, treatment=w, y=y).flatten()
return preds_dict
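# Example usage (a minimal sketch; assumes a generator with the
# (y, X, w, tau, b, e) return shape expected above, e.g.
# causalml.dataset.simulate_nuisance_and_easy_treatment):
#
#   from causalml.dataset import simulate_nuisance_and_easy_treatment
#
#   preds = get_synthetic_preds(simulate_nuisance_and_easy_treatment, n=1000)
#   print(preds[KEY_ACTUAL].mean())        # true ATE in this simulation
#   print(preds["S Learner (LR)"].mean())  # S-learner (linear) estimate of the ATE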
def get_synthetic_summary(synthetic_data_func, n=1000, k=1, estimators={}):
"""Generate a summary for predictions on synthetic data using specified function
Args:
synthetic_data_func (function): synthetic data generation function
n (int, optional): number of samples per simulation
        k (int, optional): number of simulations
        estimators (dict of object): dict of names and objects of treatment effect estimators
    Returns:
        (pandas.DataFrame): mean evaluation metrics (Abs % Error of ATE, MSE, KL Divergence) by learner
    """
summaries = []
for i in range(k):
synthetic_preds = get_synthetic_preds(
synthetic_data_func, n=n, estimators=estimators
)
actuals = synthetic_preds[KEY_ACTUAL]
synthetic_summary = pd.DataFrame(
{
label: [preds.mean(), mse(preds, actuals)]
for label, preds in synthetic_preds.items()
if label != KEY_GENERATED_DATA
},
index=["ATE", "MSE"],
).T
synthetic_summary["Abs % Error of ATE"] = np.abs(
(synthetic_summary["ATE"] / synthetic_summary.loc[KEY_ACTUAL, "ATE"]) - 1
)
        # KL divergence between each learner's prediction distribution and the
        # true treatment effect distribution, over a shared 100-bin histogram
        for label in synthetic_summary.index:
            stacked_values = np.hstack((synthetic_preds[label], actuals))
            # limit the bin range to the 0.1th-99.9th percentiles to drop extreme outliers
            stacked_low = np.percentile(stacked_values, 0.1)
            stacked_high = np.percentile(stacked_values, 99.9)
            bins = np.linspace(stacked_low, stacked_high, 100)
            distr = np.histogram(synthetic_preds[label], bins=bins)[0]
            # floor/cap the bin masses so the KL divergence stays finite
            distr = np.clip(distr / distr.sum(), 0.001, 0.999)
            true_distr = np.histogram(actuals, bins=bins)[0]
            true_distr = np.clip(true_distr / true_distr.sum(), 0.001, 0.999)
            kl = entropy(distr, true_distr)
            synthetic_summary.loc[label, "KL Divergence"] = kl
summaries.append(synthetic_summary)
    # element-wise average of the per-simulation metric tables
    summary = sum(summaries) / k
return summary[["Abs % Error of ATE", "MSE", "KL Divergence"]]
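# Example usage (sketch, same generator assumption as above):
#
#   summary = get_synthetic_summary(simulate_nuisance_and_easy_treatment, n=1000, k=5)
#   print(summary)  # one row per learner: Abs % Error of ATE, MSE, KL Divergence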
def scatter_plot_summary(synthetic_summary, k, drop_learners=[], drop_cols=[]):
"""Generates a scatter plot comparing learner performance. Each learner's performance is plotted as a point in the
(Abs % Error of ATE, MSE) space.
Args:
synthetic_summary (pd.DataFrame): summary generated by get_synthetic_summary()
k (int): number of simulations (used only for plot title text)
drop_learners (list, optional): list of learners (str) to omit when plotting
drop_cols (list, optional): list of metrics (str) to omit when plotting
"""
plot_data = synthetic_summary.drop(drop_learners).drop(drop_cols, axis=1)
fig, ax = plt.subplots()
fig.set_size_inches(12, 8)
xs = plot_data["Abs % Error of ATE"]
ys = plot_data["MSE"]
ax.scatter(xs, ys)
ylim = ax.get_ylim()
xlim = ax.get_xlim()
    # jitter the learner labels horizontally at random to reduce overlap
    for i, txt in enumerate(plot_data.index):
        ax.annotate(
            txt,
            (
                xs.iloc[i] - np.random.binomial(1, 0.5) * xlim[1] * 0.04,
                ys.iloc[i] - ylim[1] * 0.03,
            ),
        )
ax.set_xlabel("Abs % Error of ATE")
ax.set_ylabel("MSE")
ax.set_title("Learner Performance (averaged over k={} simulations)".format(k))
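# Example usage (sketch, visualizing the summary from get_synthetic_summary()):
#
#   scatter_plot_summary(summary, k=5, drop_learners=["Causal Tree"])
#   plt.show()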
def bar_plot_summary(
synthetic_summary,
k,
drop_learners=[],
drop_cols=[],
sort_cols=["MSE", "Abs % Error of ATE"],
):
"""Generates a bar plot comparing learner performance.
Args:
synthetic_summary (pd.DataFrame): summary generated by get_synthetic_summary()
k (int): number of simulations (used only for plot title text)
drop_learners (list, optional): list of learners (str) to omit when plotting
drop_cols (list, optional): list of metrics (str) to omit when plotting
sort_cols (list, optional): list of metrics (str) to sort on when plotting
"""
plot_data = synthetic_summary.sort_values(sort_cols, ascending=True)
plot_data = plot_data.drop(drop_learners + [KEY_ACTUAL]).drop(drop_cols, axis=1)
plot_data.plot(kind="bar", figsize=(12, 8))
plt.xticks(rotation=30)
plt.title("Learner Performance (averaged over k={} simulations)".format(k))
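# Example usage (sketch):
#
#   bar_plot_summary(summary, k=5, drop_cols=["KL Divergence"])
#   plt.show()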
def distr_plot_single_sim(
synthetic_preds,
kind="kde",
drop_learners=[],
bins=50,
histtype="step",
alpha=1,
linewidth=1,
bw_method=1,
):
"""Plots the distribution of each learner's predictions (for a single simulation).
Kernel Density Estimation (kde) and actual histogram plots supported.
Args:
synthetic_preds (dict): dictionary of predictions generated by get_synthetic_preds()
kind (str, optional): 'kde' or 'hist'
drop_learners (list, optional): list of learners (str) to omit when plotting
bins (int, optional): number of bins to plot if kind set to 'hist'
histtype (str, optional): histogram type if kind set to 'hist'
alpha (float, optional): alpha (transparency) for plotting
linewidth (int, optional): line width for plotting
bw_method (float, optional): parameter for kde
"""
preds_for_plot = synthetic_preds.copy()
    # drop the generated data; keep only the actuals and learner predictions
    del preds_for_plot[KEY_GENERATED_DATA]
global_lower = np.percentile(np.hstack(list(preds_for_plot.values())), 1)
global_upper = np.percentile(np.hstack(list(preds_for_plot.values())), 99)
learners = list(preds_for_plot.keys())
learners = [learner for learner in learners if learner not in drop_learners]
# Plotting
plt.figure(figsize=(12, 8))
colors = [
"black",
"red",
"blue",
"green",
"cyan",
"brown",
"grey",
"pink",
"orange",
"yellow",
]
for i, (k, v) in enumerate(preds_for_plot.items()):
if k in learners:
if kind == "kde":
v = pd.Series(v.flatten())
v = v[v.between(global_lower, global_upper)]
v.plot(
kind="kde",
bw_method=bw_method,
label=k,
linewidth=linewidth,
color=colors[i],
)
elif kind == "hist":
plt.hist(
v,
bins=np.linspace(global_lower, global_upper, bins),
label=k,
histtype=histtype,
alpha=alpha,
linewidth=linewidth,
color=colors[i],
)
            else:
                # unsupported values of `kind` are silently skipped
                pass
plt.xlim(global_lower, global_upper)
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.title("Distribution from a Single Simulation")
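# Example usage (sketch, reusing single-simulation predictions from get_synthetic_preds()):
#
#   preds = get_synthetic_preds(simulate_nuisance_and_easy_treatment, n=1000)
#   distr_plot_single_sim(preds, kind="kde", drop_learners=["Causal Tree"])
#   plt.show()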
def scatter_plot_single_sim(synthetic_preds):
"""Creates a grid of scatter plots comparing each learner's predictions with the truth (for a single simulation).
Args:
synthetic_preds (dict): dictionary of predictions generated by get_synthetic_preds() or
get_synthetic_preds_holdout()
"""
preds_for_plot = synthetic_preds.copy()
    # drop the generated data; keep only the actuals and learner predictions
    del preds_for_plot[KEY_GENERATED_DATA]
    n_row = int(np.ceil(len(preds_for_plot.keys()) / 3))
    # three columns of subplots; scale the figure height with the number of rows
    fig, axes = plt.subplots(n_row, 3, figsize=(15, 5 * n_row))
axes = np.ravel(axes)
for i, (label, preds) in enumerate(preds_for_plot.items()):
axes[i].scatter(preds_for_plot[KEY_ACTUAL], preds, s=2, label="Predictions")
axes[i].set_title(label, size=12)
axes[i].set_xlabel("Actual", size=10)
axes[i].set_ylabel("Prediction", size=10)
        # draw the y = x diagonal as the perfect-model reference line
        lims = axes[i].get_xlim()
        axes[i].plot(
            [lims[0], lims[1]],
            [lims[0], lims[1]],
            label="Perfect Model",
            linewidth=1,
            color="grey",
        )
axes[i].legend(loc=2, prop={"size": 10})
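# Example usage (sketch, with the same predictions dictionary as above):
#
#   scatter_plot_single_sim(preds)
#   plt.show()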
def get_synthetic_preds_holdout(
synthetic_data_func, n=1000, valid_size=0.2, estimators={}
):
"""Generate predictions for synthetic data using specified function (single simulation) for train and holdout
Args:
synthetic_data_func (function): synthetic data generation function
n (int, optional): number of samples
        valid_size (float, optional): validation/holdout data size
        estimators (dict of object): dict of names and objects of treatment effect estimators
            (note: currently unused here; the default learner grid below is always fit)
Returns:
(tuple): synthetic training and validation data dictionaries:
- preds_dict_train (dict): synthetic training data dictionary
- preds_dict_valid (dict): synthetic validation data dictionary
"""
y, X, w, tau, b, e = synthetic_data_func(n=n)
(
X_train,
X_val,
y_train,
y_val,
w_train,
w_val,
tau_train,
tau_val,
b_train,
b_val,
e_train,
e_val,
) = train_test_split(
X, y, w, tau, b, e, test_size=valid_size, random_state=RANDOM_SEED, shuffle=True
)
preds_dict_train = {}
preds_dict_valid = {}
preds_dict_train[KEY_ACTUAL] = tau_train
preds_dict_valid[KEY_ACTUAL] = tau_val
preds_dict_train["generated_data"] = {
"y": y_train,
"X": X_train,
"w": w_train,
"tau": tau_train,
"b": b_train,
"e": e_train,
}
preds_dict_valid["generated_data"] = {
"y": y_val,
"X": X_val,
"w": w_val,
"tau": tau_val,
"b": b_val,
"e": e_val,
}
    # Estimate the propensity score p_hat, since the true propensity e would not be observed in real life
p_model = ElasticNetPropensityModel()
p_hat_train = p_model.fit_predict(X_train, w_train)
p_hat_val = p_model.fit_predict(X_val, w_val)
for base_learner, label_l in zip(
[BaseSRegressor, BaseTRegressor, BaseXRegressor, BaseRRegressor],
["S", "T", "X", "R"],
):
for model, label_m in zip([LinearRegression, XGBRegressor], ["LR", "XGB"]):
            # the R-learner needs p_hat at fit time; the other learners may take it at predict time
            if label_l != "R":
learner = base_learner(model())
# fit the model on training data only
learner.fit(X=X_train, treatment=w_train, y=y_train)
try:
preds_dict_train["{} Learner ({})".format(label_l, label_m)] = (
learner.predict(X=X_train, p=p_hat_train).flatten()
)
preds_dict_valid["{} Learner ({})".format(label_l, label_m)] = (
learner.predict(X=X_val, p=p_hat_val).flatten()
)
                except TypeError:
                    # fall back for learners whose predict() does not take a propensity argument
preds_dict_train["{} Learner ({})".format(label_l, label_m)] = (
learner.predict(
X=X_train, treatment=w_train, y=y_train
).flatten()
)
preds_dict_valid["{} Learner ({})".format(label_l, label_m)] = (
learner.predict(X=X_val, treatment=w_val, y=y_val).flatten()
)
else:
learner = base_learner(model())
learner.fit(X=X_train, p=p_hat_train, treatment=w_train, y=y_train)
preds_dict_train["{} Learner ({})".format(label_l, label_m)] = (
learner.predict(X=X_train).flatten()
)
preds_dict_valid["{} Learner ({})".format(label_l, label_m)] = (
learner.predict(X=X_val).flatten()
)
return preds_dict_train, preds_dict_valid
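# Example usage (sketch, same generator assumption as above):
#
#   preds_train, preds_valid = get_synthetic_preds_holdout(
#       simulate_nuisance_and_easy_treatment, n=2000, valid_size=0.2
#   )
#   print(preds_valid[KEY_ACTUAL].mean())  # true ATE on the holdout split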
def get_synthetic_summary_holdout(synthetic_data_func, n=1000, valid_size=0.2, k=1):
"""Generate a summary for predictions on synthetic data for train and holdout using specified function
Args:
synthetic_data_func (function): synthetic data generation function
n (int, optional): number of samples per simulation
        valid_size (float, optional): validation/holdout data size
k (int, optional): number of simulations
Returns:
(tuple): summary evaluation metrics of predictions for train and validation:
- summary_train (pandas.DataFrame): training data evaluation summary
            - summary_validation (pandas.DataFrame): validation data evaluation summary
"""
summaries_train = []
summaries_validation = []
for i in range(k):
preds_dict_train, preds_dict_valid = get_synthetic_preds_holdout(
synthetic_data_func, n=n, valid_size=valid_size
)
actuals_train = preds_dict_train[KEY_ACTUAL]
actuals_validation = preds_dict_valid[KEY_ACTUAL]
synthetic_summary_train = pd.DataFrame(
{
label: [preds.mean(), mse(preds, actuals_train)]
for label, preds in preds_dict_train.items()
                if label != KEY_GENERATED_DATA
},
index=["ATE", "MSE"],
).T
synthetic_summary_train["Abs % Error of ATE"] = np.abs(
(
synthetic_summary_train["ATE"]
/ synthetic_summary_train.loc[KEY_ACTUAL, "ATE"]
)
- 1
)
synthetic_summary_validation = pd.DataFrame(
{
label: [preds.mean(), mse(preds, actuals_validation)]
for label, preds in preds_dict_valid.items()
                if label != KEY_GENERATED_DATA
},
index=["ATE", "MSE"],
).T
synthetic_summary_validation["Abs % Error of ATE"] = np.abs(
(
synthetic_summary_validation["ATE"]
/ synthetic_summary_validation.loc[KEY_ACTUAL, "ATE"]
)
- 1
)
# calculate kl divergence for training
for label in synthetic_summary_train.index:
stacked_values = np.hstack((preds_dict_train[label], actuals_train))
stacked_low = np.percentile(stacked_values, 0.1)
stacked_high = np.percentile(stacked_values, 99.9)
bins = np.linspace(stacked_low, stacked_high, 100)
distr = np.histogram(preds_dict_train[label], bins=bins)[0]
distr = np.clip(distr / distr.sum(), 0.001, 0.999)
true_distr = np.histogram(actuals_train, bins=bins)[0]
true_distr = np.clip(true_distr / true_distr.sum(), 0.001, 0.999)
kl = entropy(distr, true_distr)
synthetic_summary_train.loc[label, "KL Divergence"] = kl
# calculate kl divergence for validation
for label in synthetic_summary_validation.index:
stacked_values = np.hstack((preds_dict_valid[label], actuals_validation))
stacked_low = np.percentile(stacked_values, 0.1)
stacked_high = np.percentile(stacked_values, 99.9)
bins = np.linspace(stacked_low, stacked_high, 100)
distr = np.histogram(preds_dict_valid[label], bins=bins)[0]
distr = np.clip(distr / distr.sum(), 0.001, 0.999)
true_distr = np.histogram(actuals_validation, bins=bins)[0]
true_distr = np.clip(true_distr / true_distr.sum(), 0.001, 0.999)
kl = entropy(distr, true_distr)
synthetic_summary_validation.loc[label, "KL Divergence"] = kl
summaries_train.append(synthetic_summary_train)
summaries_validation.append(synthetic_summary_validation)
    # element-wise average of the per-simulation metric tables
    summary_train = sum(summaries_train) / k
    summary_validation = sum(summaries_validation) / k
return (
summary_train[["Abs % Error of ATE", "MSE", "KL Divergence"]],
summary_validation[["Abs % Error of ATE", "MSE", "KL Divergence"]],
)
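# Example usage (sketch):
#
#   train_summary, validation_summary = get_synthetic_summary_holdout(
#       simulate_nuisance_and_easy_treatment, n=2000, valid_size=0.2, k=5
#   )
#   print(validation_summary)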
def scatter_plot_summary_holdout(
train_summary,
validation_summary,
k,
label=["Train", "Validation"],
drop_learners=[],
drop_cols=[],
):
"""Generates a scatter plot comparing learner performance by training and validation.
Args:
train_summary (pd.DataFrame): summary for training synthetic data generated by get_synthetic_summary_holdout()
validation_summary (pd.DataFrame): summary for validation synthetic data generated by
get_synthetic_summary_holdout()
        k (int): number of simulations (used only for plot title text)
        label (list, optional): legend labels for the train and validation points
drop_learners (list, optional): list of learners (str) to omit when plotting
drop_cols (list, optional): list of metrics (str) to omit when plotting
"""
train_summary = train_summary.drop(drop_learners).drop(drop_cols, axis=1)
validation_summary = validation_summary.drop(drop_learners).drop(drop_cols, axis=1)
plot_data = pd.concat([train_summary, validation_summary])
plot_data["label"] = [i.replace("Train", "") for i in plot_data.index]
plot_data["label"] = [i.replace("Validation", "") for i in plot_data.label]
fig, ax = plt.subplots()
fig.set_size_inches(12, 8)
xs = plot_data["Abs % Error of ATE"]
ys = plot_data["MSE"]
group = np.array(
[label[0]] * train_summary.shape[0] + [label[1]] * validation_summary.shape[0]
)
cdict = {label[0]: "red", label[1]: "blue"}
    for g in np.unique(group):
        ix = np.where(group == g)[0].tolist()
        ax.scatter(xs.iloc[ix], ys.iloc[ix], c=cdict[g], label=g, s=100)
    # annotate each learner once, using the training points only
    for i, txt in enumerate(plot_data.label[: train_summary.shape[0]]):
        ax.annotate(txt, (xs.iloc[i] + 0.005, ys.iloc[i]))
ax.set_xlabel("Abs % Error of ATE")
ax.set_ylabel("MSE")
ax.set_title("Learner Performance (averaged over k={} simulations)".format(k))
ax.legend(loc="center left", bbox_to_anchor=(1.1, 0.5))
plt.show()
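# Example usage (sketch, with the holdout summaries from above):
#
#   scatter_plot_summary_holdout(
#       train_summary, validation_summary, k=5, drop_learners=["Causal Tree"]
#   )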
def bar_plot_summary_holdout(
train_summary, validation_summary, k, drop_learners=[], drop_cols=[]
):
"""Generates a bar plot comparing learner performance by training and validation
Args:
train_summary (pd.DataFrame): summary for training synthetic data generated by get_synthetic_summary_holdout()
validation_summary (pd.DataFrame): summary for validation synthetic data generated by
get_synthetic_summary_holdout()
k (int): number of simulations (used only for plot title text)
drop_learners (list, optional): list of learners (str) to omit when plotting
drop_cols (list, optional): list of metrics (str) to omit when plotting
"""
train_summary = train_summary.drop([KEY_ACTUAL])
train_summary["Learner"] = train_summary.index
validation_summary = validation_summary.drop([KEY_ACTUAL])
validation_summary["Learner"] = validation_summary.index
    for metric in ["Abs % Error of ATE", "MSE", "KL Divergence"]:
        # drop_cols lists metrics, so skip dropped metrics rather than dropping columns
        if metric in drop_cols:
            continue
        plot_data_sub = pd.DataFrame(train_summary.Learner).reset_index(drop=True)
        plot_data_sub["train"] = train_summary[metric].values
        plot_data_sub["validation"] = validation_summary[metric].values
        plot_data_sub = plot_data_sub.set_index("Learner")
        plot_data_sub = plot_data_sub.drop(drop_learners)
        plot_data_sub = plot_data_sub.sort_values("train", ascending=True)
plot_data_sub.plot(kind="bar", color=["red", "blue"], figsize=(12, 8))
plt.xticks(rotation=30)
plt.title(
"Learner Performance of {} (averaged over k={} simulations)".format(
metric, k
)
)
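# Example usage (sketch):
#
#   bar_plot_summary_holdout(train_summary, validation_summary, k=5)
#   plt.show()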
def get_synthetic_auuc(
synthetic_preds,
drop_learners=[],
outcome_col="y",
treatment_col="w",
treatment_effect_col="tau",
plot=True,
):
"""Get auuc values for cumulative gains of model estimates in quantiles.
For details, reference get_cumgain() and plot_gain()
Args:
synthetic_preds (dict): dictionary of predictions generated by get_synthetic_preds()
or get_synthetic_preds_holdout()
outcome_col (str, optional): the column name for the actual outcome
treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
treatment_effect_col (str, optional): the column name for the true treatment effect
plot (boolean,optional): plot the cumulative gain chart or not
Returns:
(pandas.DataFrame): auuc values by learner for cumulative gains of model estimates
"""
synthetic_preds_df = synthetic_preds.copy()
generated_data = synthetic_preds_df.pop(KEY_GENERATED_DATA)
synthetic_preds_df = pd.DataFrame(synthetic_preds_df)
synthetic_preds_df = synthetic_preds_df.drop(drop_learners, axis=1)
synthetic_preds_df["y"] = generated_data[outcome_col]
synthetic_preds_df["w"] = generated_data[treatment_col]
if treatment_effect_col in generated_data.keys():
synthetic_preds_df["tau"] = generated_data[treatment_effect_col]
    # get_cumgain() needs either the observed outcome and treatment columns
    # or the true treatment effect column
    assert (
        (outcome_col in synthetic_preds_df.columns)
        and (treatment_col in synthetic_preds_df.columns)
    ) or (treatment_effect_col in synthetic_preds_df.columns)
cumlift = get_cumgain(
synthetic_preds_df,
outcome_col="y",
treatment_col="w",
treatment_effect_col="tau",
)
auuc_df = pd.DataFrame(cumlift.columns)
auuc_df.columns = ["Learner"]
auuc_df["cum_gain_auuc"] = [
auc(cumlift.index.values / 100, cumlift[learner].values)
for learner in cumlift.columns
]
auuc_df = auuc_df.sort_values("cum_gain_auuc", ascending=False)
if plot:
plot_gain(
synthetic_preds_df,
outcome_col=outcome_col,
treatment_col=treatment_col,
treatment_effect_col=treatment_effect_col,
)
return auuc_df
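# Example usage (sketch; works with either a single-simulation or a holdout
# predictions dictionary):
#
#   preds = get_synthetic_preds(simulate_nuisance_and_easy_treatment, n=1000)
#   auuc_df = get_synthetic_auuc(preds, plot=False)
#   print(auuc_df.head())  # learners ranked by cumulative-gain AUUC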