DragonNet vs Meta-Learners Benchmark with IHDP + Synthetic Datasets

DragonNet requires TensorFlow. If you haven't installed it yet, do so with:

pip install tensorflow
[1]:
%load_ext autoreload
%autoreload 2
[2]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error as mse
from scipy.stats import entropy
import warnings

from causalml.inference.meta import LRSRegressor
from causalml.inference.meta import XGBTRegressor, MLPTRegressor
from causalml.inference.meta import BaseXRegressor, BaseRRegressor, BaseSRegressor, BaseTRegressor
from causalml.inference.tf import DragonNet
from causalml.match import NearestNeighborMatch, MatchOptimizer, create_table_one
from causalml.propensity import ElasticNetPropensityModel
from causalml.dataset.regression import *
from causalml.metrics import *

import os, sys

%matplotlib inline

warnings.filterwarnings('ignore')
plt.style.use('fivethirtyeight')
sns.set_palette('Paired')
plt.rcParams['figure.figsize'] = (12,8)

IHDP semi-synthetic dataset

Hill introduced a semi-synthetic dataset constructed from the Infant Health and Development Program (IHDP). This dataset is based on a randomized experiment investigating the effect of home visits by specialists on future cognitive test scores. The data contain 747 observations (rows). The IHDP simulation is a widely used benchmark for evaluating neural-network-based treatment effect estimation methods.

The original paper uses 1,000 realizations from the NPCI package, but for illustration purposes we use a single realization (one dataset) as an example below.

[3]:
df = pd.read_csv('data/ihdp_npci_3.csv', header=None)
cols = ["treatment", "y_factual", "y_cfactual", "mu0", "mu1"] + [f'x{i}' for i in range(1, 26)]
df.columns = cols
[4]:
df.shape
[4]:
(747, 30)
[5]:
df.head()
[5]:
treatment y_factual y_cfactual mu0 mu1 x1 x2 x3 x4 x5 ... x16 x17 x18 x19 x20 x21 x22 x23 x24 x25
0 1 5.931652 3.500591 2.253801 7.136441 -0.528603 -0.343455 1.128554 0.161703 -0.316603 ... 1 1 1 1 0 0 0 0 0 0
1 0 2.175966 5.952101 1.257592 6.553022 -1.736945 -1.802002 0.383828 2.244320 -0.629189 ... 1 1 1 1 0 0 0 0 0 0
2 0 2.180294 7.175734 2.384100 7.192645 -0.807451 -0.202946 -0.360898 -0.879606 0.808706 ... 1 0 1 1 0 0 0 0 0 0
3 0 3.587662 7.787537 4.009365 7.712456 0.390083 0.596582 -1.850350 -0.879606 -0.004017 ... 1 0 1 1 0 0 0 0 0 0
4 0 2.372618 5.461871 2.481631 7.232739 -1.045229 -0.602710 0.011465 0.161703 0.683672 ... 1 1 1 1 0 0 0 0 0 0

5 rows × 30 columns

[6]:
pd.Series(df['treatment']).value_counts(normalize=True)
[6]:
0    0.813922
1    0.186078
Name: treatment, dtype: float64
[7]:
X = df.loc[:,'x1':]
treatment = df['treatment']
y = df['y_factual']
tau = df.apply(lambda d: d['y_factual'] - d['y_cfactual'] if d['treatment']==1
               else d['y_cfactual'] - d['y_factual'],
               axis=1)
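The row-wise `apply` above can also be vectorized with `np.where`, which is faster on larger frames. A self-contained sketch on a toy frame whose columns mirror the IHDP data (the values are made up for illustration):

```python
import numpy as np
import pandas as pd

# Toy frame with the same columns as the IHDP data (values are illustrative).
toy = pd.DataFrame({
    'treatment':  [1, 0, 0, 1],
    'y_factual':  [5.0, 2.0, 3.0, 6.0],
    'y_cfactual': [3.0, 4.0, 5.0, 2.0],
})

# True ITE is y(1) - y(0). For treated rows the factual outcome is y(1);
# for control rows it is y(0), so the difference flips sign.
tau_vec = np.where(toy['treatment'] == 1,
                   toy['y_factual'] - toy['y_cfactual'],
                   toy['y_cfactual'] - toy['y_factual'])
print(tau_vec)  # [2. 2. 2. 4.]
```

The same one-liner applied to `df` reproduces the `tau` series computed above.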
[9]:
p_model = ElasticNetPropensityModel()
p = p_model.fit_predict(X, treatment)
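Before passing propensity scores to the X- and R-learners, it is worth checking overlap: scores piled up near 0 or 1 destabilize the inverse-propensity terms. A minimal sketch on synthetic data, with scikit-learn's `LogisticRegression` standing in for `ElasticNetPropensityModel` (in the notebook you would inspect and, if needed, clip `p` directly):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X_syn = rng.normal(size=(500, 5))
# Treatment probability depends on the first covariate, so scores are informative.
t_syn = (rng.random(500) < 1 / (1 + np.exp(-X_syn[:, 0]))).astype(int)

lr = LogisticRegression().fit(X_syn, t_syn)
p_syn = lr.predict_proba(X_syn)[:, 1]

# Clip away extreme scores to stabilize the downstream learners.
p_clipped = np.clip(p_syn, 0.05, 0.95)
print(p_clipped.min(), p_clipped.max())
```

The clipping bounds (0.05/0.95 here) are a tuning choice; `ElasticNetPropensityModel` applies its own clipping internally.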
[10]:
s_learner = BaseSRegressor(LGBMRegressor())
s_ate = s_learner.estimate_ate(X, treatment, y)[0]
s_ite = s_learner.fit_predict(X, treatment, y)

t_learner = BaseTRegressor(LGBMRegressor())
t_ate = t_learner.estimate_ate(X, treatment, y)[0][0]
t_ite = t_learner.fit_predict(X, treatment, y)

x_learner = BaseXRegressor(LGBMRegressor())
x_ate = x_learner.estimate_ate(X, treatment, y, p)[0][0]
x_ite = x_learner.fit_predict(X, treatment, y, p)

r_learner = BaseRRegressor(LGBMRegressor())
r_ate = r_learner.estimate_ate(X, treatment, y, p)[0][0]
r_ite = r_learner.fit_predict(X, treatment, y, p)
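With the four meta-learner ATEs in hand, a small comparison table against the known truth makes the benchmark readable at a glance. A self-contained sketch with made-up numbers (in the notebook you would substitute `s_ate`, `t_ate`, `x_ate`, `r_ate`, and `tau.mean()`):

```python
import pandas as pd

# Illustrative ATE estimates; replace with the learners' actual outputs.
true_ate = 4.0
estimates = {'S-learner': 3.6, 'T-learner': 4.3, 'X-learner': 4.1, 'R-learner': 3.9}

summary = pd.DataFrame({'ATE': estimates})
summary['abs_error'] = (summary['ATE'] - true_ate).abs()
print(summary.sort_values('abs_error'))
```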
[11]:
dragon = DragonNet(neurons_per_layer=200, targeted_reg=True)
dragon_ite = dragon.fit_predict(X, treatment, y, return_components=False)
dragon_ate = dragon_ite.mean()
Epoch 1/30
10/10 [==============================] - 5s 156ms/step - loss: 1790.3492 - regression_loss: 864.6742 - binary_classification_loss: 41.3394 - treatment_accuracy: 0.7299 - track_epsilon: 0.0063 - val_loss: 242.1589 - val_regression_loss: 87.0011 - val_binary_classification_loss: 32.6806 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0055
Epoch 2/30
10/10 [==============================] - 0s 7ms/step - loss: 311.9302 - regression_loss: 135.2392 - binary_classification_loss: 32.8420 - treatment_accuracy: 0.8643 - track_epsilon: 0.0059 - val_loss: 230.2209 - val_regression_loss: 79.8030 - val_binary_classification_loss: 34.3533 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0047
Epoch 3/30
10/10 [==============================] - 0s 6ms/step - loss: 274.1216 - regression_loss: 118.1561 - binary_classification_loss: 31.3200 - treatment_accuracy: 0.8169 - track_epsilon: 0.0044 - val_loss: 238.4452 - val_regression_loss: 82.0530 - val_binary_classification_loss: 36.2376 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0049
Epoch 4/30
10/10 [==============================] - 0s 6ms/step - loss: 205.4690 - regression_loss: 85.9585 - binary_classification_loss: 27.2440 - treatment_accuracy: 0.8606 - track_epsilon: 0.0053 - val_loss: 235.7122 - val_regression_loss: 78.5524 - val_binary_classification_loss: 39.7929 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0057
Epoch 1/300
10/10 [==============================] - 1s 41ms/step - loss: 195.6840 - regression_loss: 80.7820 - binary_classification_loss: 27.7316 - treatment_accuracy: 0.8497 - track_epsilon: 0.0054 - val_loss: 207.3960 - val_regression_loss: 67.6306 - val_binary_classification_loss: 38.1122 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0177
Epoch 2/300
10/10 [==============================] - 0s 6ms/step - loss: 183.9956 - regression_loss: 75.3269 - binary_classification_loss: 26.4330 - treatment_accuracy: 0.8622 - track_epsilon: 0.0182 - val_loss: 197.0267 - val_regression_loss: 64.0559 - val_binary_classification_loss: 38.4298 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0117
Epoch 3/300
10/10 [==============================] - 0s 6ms/step - loss: 178.8321 - regression_loss: 72.7892 - binary_classification_loss: 26.7587 - treatment_accuracy: 0.8555 - track_epsilon: 0.0081 - val_loss: 195.6257 - val_regression_loss: 63.5609 - val_binary_classification_loss: 38.2400 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0073
Epoch 4/300
10/10 [==============================] - 0s 6ms/step - loss: 177.0419 - regression_loss: 71.8475 - binary_classification_loss: 27.1255 - treatment_accuracy: 0.8521 - track_epsilon: 0.0091 - val_loss: 200.6521 - val_regression_loss: 65.3493 - val_binary_classification_loss: 37.6216 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0082
...
Epoch 00079: ReduceLROnPlateau reducing learning rate to 7.81249980263965e-08.
Epoch 80/300
10/10 [==============================] - 0s 6ms/step - loss: 145.6183 - regression_loss: 57.3271 - binary_classification_loss: 25.1687 - treatment_accuracy: 0.8574 - track_epsilon: 0.0021 - val_loss: 181.2945 - val_regression_loss: 58.1152 - val_binary_classification_loss: 36.7475 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0021
Epoch 81/300
10/10 [==============================] - 0s 6ms/step - loss: 142.3395 - regression_loss: 54.9129 - binary_classification_loss: 26.6164 - treatment_accuracy: 0.8461 - track_epsilon: 0.0021 - val_loss: 181.3037 - val_regression_loss: 58.1177 - val_binary_classification_loss: 36.7449 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0022
[12]:
df_preds = pd.DataFrame([s_ite.ravel(),
                         t_ite.ravel(),
                         x_ite.ravel(),
                         r_ite.ravel(),
                         dragon_ite.ravel(),
                         tau.ravel(),
                         treatment.ravel(),
                         y.ravel()],
                        index=['S', 'T', 'X', 'R', 'dragonnet', 'tau', 'w', 'y']).T

df_cumgain = get_cumgain(df_preds)
[13]:
df_result = pd.DataFrame([s_ate, t_ate, x_ate, r_ate, dragon_ate, tau.mean()],
                         index=['S', 'T', 'X', 'R', 'dragonnet', 'actual'], columns=['ATE'])
df_result['MAE'] = [mean_absolute_error(t, p)
                    for t, p in zip([s_ite, t_ite, x_ite, r_ite, dragon_ite],
                                    [tau.values.reshape(-1, 1)] * 5)] + [None]
df_result['AUUC'] = auuc_score(df_preds)
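`auuc_score` summarizes each model's cumulative gain curve into a single number. A minimal sketch of the idea (illustrative only, not causalml's exact implementation; `simple_auuc` is a hypothetical helper): rank observations by predicted uplift, track the cumulative difference in mean outcomes between treated and control, scale by population size, normalize by the final value, and average.

```python
import numpy as np
import pandas as pd

def simple_auuc(score, w, y):
    """AUUC sketch: rank by predicted uplift, accumulate the lift
    (difference in mean outcomes, treated vs. control, scaled by the
    cumulative population), normalize, and average the curve."""
    df = pd.DataFrame({'score': score, 'w': w, 'y': y})
    df = df.sort_values('score', ascending=False).reset_index(drop=True)
    n = pd.Series(np.arange(1, len(df) + 1))
    n_t = df['w'].cumsum()          # cumulative treated count
    n_c = n - n_t                   # cumulative control count
    mean_t = (df['y'] * df['w']).cumsum() / n_t.where(n_t > 0)
    mean_c = (df['y'] * (1 - df['w'])).cumsum() / n_c.where(n_c > 0)
    gain = ((mean_t - mean_c) * n).fillna(0)      # cumulative lift
    return float((gain / gain.iloc[-1]).mean())   # normalized area under the curve
```

A model that ranks by the true effect should score higher than a random ranking; a random ranking hovers near 0.5.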
[14]:
df_result
[14]:
           ATE       MAE       AUUC
S          4.054511  1.027666  0.575822
T          4.100199  0.980788  0.580929
X          4.020918  1.116303  0.564651
R          4.257976  1.665557  0.556855
dragonnet  4.006536  1.165061  0.556426
actual     4.098887  NaN       NaN
[15]:
plot_gain(df_preds)
../_images/examples_dragonnet_example_16_0.png

Synthetic Data Generation

[16]:
y, X, w, tau, b, e = simulate_nuisance_and_easy_treatment(n=1000)

X_train, X_val, y_train, y_val, w_train, w_val, tau_train, tau_val, b_train, b_val, e_train, e_val = \
    train_test_split(X, y, w, tau, b, e, test_size=0.2, random_state=123, shuffle=True)

preds_dict_train = {}
preds_dict_valid = {}

preds_dict_train['Actuals'] = tau_train
preds_dict_valid['Actuals'] = tau_val

preds_dict_train['generated_data'] = {
    'y': y_train,
    'X': X_train,
    'w': w_train,
    'tau': tau_train,
    'b': b_train,
    'e': e_train}
preds_dict_valid['generated_data'] = {
    'y': y_val,
    'X': X_val,
    'w': w_val,
    'tau': tau_val,
    'b': b_val,
    'e': e_val}

# Predict p_hat because e would not be directly observed in real-life
p_model = ElasticNetPropensityModel()
p_hat_train = p_model.fit_predict(X_train, w_train)
p_hat_val = p_model.fit_predict(X_val, w_val)

for base_learner, label_l in zip([BaseSRegressor, BaseTRegressor, BaseXRegressor, BaseRRegressor],
                                 ['S', 'T', 'X', 'R']):
    for model, label_m in zip([LinearRegression, XGBRegressor], ['LR', 'XGB']):
        # The R-Learner needs p_hat at fit time; the other learners take it
        # (if at all) at predict time
        if label_l != 'R':
            learner = base_learner(model())
            # fit the model on training data only
            learner.fit(X=X_train, treatment=w_train, y=y_train)
            # not every meta-learner's predict() accepts p; fall back to the
            # (X, treatment, y) signature when it doesn't
            try:
                preds_dict_train['{} Learner ({})'.format(
                    label_l, label_m)] = learner.predict(X=X_train, p=p_hat_train).flatten()
                preds_dict_valid['{} Learner ({})'.format(
                    label_l, label_m)] = learner.predict(X=X_val, p=p_hat_val).flatten()
            except TypeError:
                preds_dict_train['{} Learner ({})'.format(
                    label_l, label_m)] = learner.predict(X=X_train, treatment=w_train, y=y_train).flatten()
                preds_dict_valid['{} Learner ({})'.format(
                    label_l, label_m)] = learner.predict(X=X_val, treatment=w_val, y=y_val).flatten()
        else:
            learner = base_learner(model())
            learner.fit(X=X_train, p=p_hat_train, treatment=w_train, y=y_train)
            preds_dict_train['{} Learner ({})'.format(
                label_l, label_m)] = learner.predict(X=X_train).flatten()
            preds_dict_valid['{} Learner ({})'.format(
                label_l, label_m)] = learner.predict(X=X_val).flatten()

learner = DragonNet(verbose=False)
learner.fit(X_train, treatment=w_train, y=y_train)
preds_dict_train['DragonNet'] = learner.predict_tau(X=X_train).flatten()
preds_dict_valid['DragonNet'] = learner.predict_tau(X=X_val).flatten()
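In the cell above, `ElasticNetPropensityModel` is fit and scored on the same split. An out-of-fold (cross-fitted) alternative built with scikit-learn is sketched below — this is an illustration, not what the notebook does, and `cross_fitted_propensity` is a hypothetical helper:

```python
import numpy as np
from sklearn.linear_model import LogisticRegressionCV
from sklearn.model_selection import cross_val_predict

def cross_fitted_propensity(X, w, n_splits=5):
    """Out-of-fold propensity scores: each observation is scored by a
    model that never saw it during training."""
    model = LogisticRegressionCV(Cs=10, cv=3, max_iter=1000)
    p = cross_val_predict(model, X, w, cv=n_splits, method='predict_proba')[:, 1]
    return np.clip(p, 0.01, 0.99)  # trim extreme scores for stability
```

Cross-fitting avoids scoring an observation with a model that was trained on it, which can otherwise make downstream X-/R-Learner estimates optimistic.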
[17]:
actuals_train = preds_dict_train['Actuals']
actuals_validation = preds_dict_valid['Actuals']

synthetic_summary_train = pd.DataFrame({label: [preds.mean(), mse(preds, actuals_train)] for label, preds
                                        in preds_dict_train.items() if 'generated' not in label.lower()},
                                       index=['ATE', 'MSE']).T
synthetic_summary_train['Abs % Error of ATE'] = np.abs(
    (synthetic_summary_train['ATE']/synthetic_summary_train.loc['Actuals', 'ATE']) - 1)

synthetic_summary_validation = pd.DataFrame({label: [preds.mean(), mse(preds, actuals_validation)]
                                             for label, preds in preds_dict_valid.items()
                                             if 'generated' not in label.lower()},
                                            index=['ATE', 'MSE']).T
synthetic_summary_validation['Abs % Error of ATE'] = np.abs(
    (synthetic_summary_validation['ATE']/synthetic_summary_validation.loc['Actuals', 'ATE']) - 1)

# calculate kl divergence for training
for label in synthetic_summary_train.index:
    stacked_values = np.hstack((preds_dict_train[label], actuals_train))
    stacked_low = np.percentile(stacked_values, 0.1)
    stacked_high = np.percentile(stacked_values, 99.9)
    bins = np.linspace(stacked_low, stacked_high, 100)

    distr = np.histogram(preds_dict_train[label], bins=bins)[0]
    distr = np.clip(distr/distr.sum(), 0.001, 0.999)
    true_distr = np.histogram(actuals_train, bins=bins)[0]
    true_distr = np.clip(true_distr/true_distr.sum(), 0.001, 0.999)

    kl = entropy(distr, true_distr)
    synthetic_summary_train.loc[label, 'KL Divergence'] = kl

# calculate kl divergence for validation
for label in synthetic_summary_validation.index:
    stacked_values = np.hstack((preds_dict_valid[label], actuals_validation))
    stacked_low = np.percentile(stacked_values, 0.1)
    stacked_high = np.percentile(stacked_values, 99.9)
    bins = np.linspace(stacked_low, stacked_high, 100)

    distr = np.histogram(preds_dict_valid[label], bins=bins)[0]
    distr = np.clip(distr/distr.sum(), 0.001, 0.999)
    true_distr = np.histogram(actuals_validation, bins=bins)[0]
    true_distr = np.clip(true_distr/true_distr.sum(), 0.001, 0.999)

    kl = entropy(distr, true_distr)
    synthetic_summary_validation.loc[label, 'KL Divergence'] = kl
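The two loops above are identical except for their inputs. A small helper removes the duplication — a refactor sketch, not part of the notebook, and `kl_to_actuals` is a hypothetical name:

```python
import numpy as np
from scipy.stats import entropy

def kl_to_actuals(preds, actuals, n_bins=100, clip=(0.001, 0.999)):
    """KL divergence between the histogram of predictions and that of
    actuals, over shared bins spanning the pooled 0.1-99.9 percentile
    range (mirrors the loops above)."""
    stacked = np.hstack((preds, actuals))
    bins = np.linspace(np.percentile(stacked, 0.1),
                       np.percentile(stacked, 99.9), n_bins)
    p = np.histogram(preds, bins=bins)[0]
    p = np.clip(p / p.sum(), *clip)
    q = np.histogram(actuals, bins=bins)[0]
    q = np.clip(q / q.sum(), *clip)
    return entropy(p, q)
```

With this, each loop body reduces to e.g. `synthetic_summary_train.loc[label, 'KL Divergence'] = kl_to_actuals(preds_dict_train[label], actuals_train)`.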
[18]:
df_preds_train = pd.DataFrame([preds_dict_train['S Learner (LR)'].ravel(),
                               preds_dict_train['S Learner (XGB)'].ravel(),
                               preds_dict_train['T Learner (LR)'].ravel(),
                               preds_dict_train['T Learner (XGB)'].ravel(),
                               preds_dict_train['X Learner (LR)'].ravel(),
                               preds_dict_train['X Learner (XGB)'].ravel(),
                               preds_dict_train['R Learner (LR)'].ravel(),
                               preds_dict_train['R Learner (XGB)'].ravel(),
                               preds_dict_train['DragonNet'].ravel(),
                               preds_dict_train['generated_data']['tau'].ravel(),
                               preds_dict_train['generated_data']['w'].ravel(),
                               preds_dict_train['generated_data']['y'].ravel()],
                              index=['S Learner (LR)','S Learner (XGB)',
                                     'T Learner (LR)','T Learner (XGB)',
                                     'X Learner (LR)','X Learner (XGB)',
                                     'R Learner (LR)','R Learner (XGB)',
                                     'DragonNet','tau','w','y']).T

synthetic_summary_train['AUUC'] = auuc_score(df_preds_train).iloc[:-1]
[19]:
df_preds_validation = pd.DataFrame([preds_dict_valid['S Learner (LR)'].ravel(),
                               preds_dict_valid['S Learner (XGB)'].ravel(),
                               preds_dict_valid['T Learner (LR)'].ravel(),
                               preds_dict_valid['T Learner (XGB)'].ravel(),
                               preds_dict_valid['X Learner (LR)'].ravel(),
                               preds_dict_valid['X Learner (XGB)'].ravel(),
                               preds_dict_valid['R Learner (LR)'].ravel(),
                               preds_dict_valid['R Learner (XGB)'].ravel(),
                               preds_dict_valid['DragonNet'].ravel(),
                               preds_dict_valid['generated_data']['tau'].ravel(),
                               preds_dict_valid['generated_data']['w'].ravel(),
                               preds_dict_valid['generated_data']['y'].ravel()],
                              index=['S Learner (LR)','S Learner (XGB)',
                                     'T Learner (LR)','T Learner (XGB)',
                                     'X Learner (LR)','X Learner (XGB)',
                                     'R Learner (LR)','R Learner (XGB)',
                                     'DragonNet','tau','w','y']).T

synthetic_summary_validation['AUUC'] = auuc_score(df_preds_validation).iloc[:-1]
[20]:
synthetic_summary_train
[20]:
                 ATE       MSE       Abs % Error of ATE  KL Divergence  AUUC
Actuals          0.484486  0.000000  0.000000            0.000000       NaN
S Learner (LR)   0.528743  0.044194  0.091349            3.473087       0.508067
S Learner (XGB)  0.358208  0.310652  0.260643            0.817620       0.544115
T Learner (LR)   0.493815  0.022688  0.019255            0.289978       0.610855
T Learner (XGB)  0.397053  1.350928  0.180466            1.452143       0.521719
X Learner (LR)   0.493815  0.022688  0.019255            0.289978       0.610855
X Learner (XGB)  0.341352  0.620992  0.295435            1.086086       0.534827
R Learner (LR)   0.457692  0.028116  0.055304            0.335083       0.614414
R Learner (XGB)  0.434709  4.575591  0.102741            1.907325       0.505088
DragonNet        0.410899  0.044120  0.151888            0.467829       0.611620
[21]:
synthetic_summary_validation
[21]:
                 ATE       MSE       Abs % Error of ATE  KL Divergence  AUUC
Actuals          0.511242  0.000000  0.000000            0.000000       NaN
S Learner (LR)   0.528743  0.042236  0.034233            4.574498       0.495423
S Learner (XGB)  0.434208  0.260496  0.150680            0.854890       0.544206
T Learner (LR)   0.541503  0.025840  0.059191            0.686602       0.604712
T Learner (XGB)  0.483404  0.679398  0.054452            1.215394       0.526918
X Learner (LR)   0.541503  0.025840  0.059191            0.686602       0.604712
X Learner (XGB)  0.330427  0.344865  0.353678            1.227041       0.536599
R Learner (LR)   0.510236  0.030801  0.001967            0.654228       0.608133
R Learner (XGB)  0.417823  1.990451  0.182730            1.650560       0.504991
DragonNet        0.462146  0.043679  0.096032            0.825673       0.605744
[22]:
plot_gain(df_preds_train)
../_images/examples_dragonnet_example_24_0.png
[23]:
plot_gain(df_preds_validation)
../_images/examples_dragonnet_example_25_0.png