DragonNet: JAX vs TF Benchmark

Compares the JAX (flax.nnx) and TF (Keras) DragonNet implementations on the IHDP semi-synthetic dataset using ATE, MAE, AUUC, and wall-clock training time.

Reference: Shi, Blei & Veitch (2019) — https://arxiv.org/pdf/1906.02120.pdf

Setup

TF backend:

pip install tensorflow

JAX backend:

pip install "causalml[jax]"
# or: uv pip install jax flax optax orbax-checkpoint
[1]:
%load_ext autoreload
%autoreload 2
[2]:
import os
import time
import warnings

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.metrics import mean_absolute_error

from causalml.metrics import auuc_score, plot_gain

# fmt: off
%matplotlib inline
warnings.filterwarnings("ignore")
plt.style.use("fivethirtyeight")
sns.set_palette("Paired")
plt.rcParams["figure.figsize"] = (12, 8)
# fmt: on

try:
    from causalml.inference.tf import DragonNet as TFDragonNet

    HAS_TF = True
except ImportError:
    HAS_TF = False
    print("[INFO] tensorflow not installed — TF DragonNet will be skipped.")

try:
    from causalml.inference.jax import DragonNet as JAXDragonNet

    HAS_JAX = True
except ImportError:
    HAS_JAX = False
    print("[INFO] jax/flax not installed — JAX DragonNet will be skipped.")

IHDP Dataset

Semi-synthetic dataset from Hill (2011), used in the original DragonNet paper. 747 observations from the Infant Health and Development Program (IHDP) study. We use one realisation (ihdp_npci_3.csv) to reproduce the notebook benchmark.

[3]:
DATA_PATH = os.path.join(os.path.dirname(os.path.abspath("__file__")), "data", "ihdp_npci_3.csv")

df = pd.read_csv(DATA_PATH, header=None)
cols = ["treatment", "y_factual", "y_cfactual", "mu0", "mu1"] + [
    f"x{i}" for i in range(1, 26)
]
df.columns = cols

X = df.loc[:, "x1":].values
treatment = df["treatment"].values
y = df["y_factual"].values
tau = df.apply(
    lambda d: d["y_factual"] - d["y_cfactual"]
    if d["treatment"] == 1
    else d["y_cfactual"] - d["y_factual"],
    axis=1,
).values

print(f"n={len(y)},  actual ATE = {tau.mean():.4f}")
df.head()
n=747,  actual ATE = 4.0989
[3]:
treatment y_factual y_cfactual mu0 mu1 x1 x2 x3 x4 x5 ... x16 x17 x18 x19 x20 x21 x22 x23 x24 x25
0 1 5.931652 3.500591 2.253801 7.136441 -0.528603 -0.343455 1.128554 0.161703 -0.316603 ... 1 1 1 1 0 0 0 0 0 0
1 0 2.175966 5.952101 1.257592 6.553022 -1.736945 -1.802002 0.383828 2.244320 -0.629189 ... 1 1 1 1 0 0 0 0 0 0
2 0 2.180294 7.175734 2.384100 7.192645 -0.807451 -0.202946 -0.360898 -0.879606 0.808706 ... 1 0 1 1 0 0 0 0 0 0
3 0 3.587662 7.787537 4.009365 7.712456 0.390083 0.596582 -1.850350 -0.879606 -0.004017 ... 1 0 1 1 0 0 0 0 0 0
4 0 2.372618 5.461871 2.481631 7.232739 -1.045229 -0.602710 0.011465 0.161703 0.683672 ... 1 1 1 1 0 0 0 0 0 0

5 rows × 30 columns

Train DragonNet

[4]:
SEED = 42
HYPERPARAMS = dict(neurons_per_layer=200, targeted_reg=True, verbose=False)

np.random.seed(SEED)

results = {}  # label -> {"ite": array, "train_time": float}

for label, cls, available in [
    ("TF DragonNet", TFDragonNet if HAS_TF else None, HAS_TF),
    ("JAX DragonNet", JAXDragonNet if HAS_JAX else None, HAS_JAX),
]:
    if not available:
        print(f"[{label}] skipped (backend not installed).")
        continue
    print(f"[{label}] fitting...")
    model = cls(**HYPERPARAMS)
    t0 = time.perf_counter()
    model.fit(X, treatment, y)
    train_time = time.perf_counter() - t0
    ite = model.predict_tau(X).ravel()
    results[label] = {"ite": ite, "train_time": train_time}
    print(f"[{label}] done — ATE={ite.mean():.4f},  train time={train_time:.2f}s")
[TF DragonNet] fitting...
24/24 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
[TF DragonNet] done — ATE=3.9952,  train time=4.42s
[JAX DragonNet] fitting...
[JAX DragonNet] done — ATE=3.9220,  train time=1.71s

Results

[5]:
# Build df_preds following the same convention as dragonnet_example.ipynb:
# model columns + reserved columns tau / w / y
preds_data = {label: v["ite"] for label, v in results.items()}
preds_data.update({"tau": tau, "w": treatment, "y": y})
df_preds = pd.DataFrame(preds_data)

rows = []
for label, v in results.items():
    ite = v["ite"]
    rows.append(
        {
            "ATE": float(ite.mean()),
            "Actual ATE": float(tau.mean()),
            "MAE": float(mean_absolute_error(tau, ite)),
            "Train time (s)": float(v["train_time"]),
        }
    )

df_result = pd.DataFrame(rows, index=list(results.keys()))
df_result["AUUC"] = auuc_score(
    df_preds, outcome_col="y", treatment_col="w", treatment_effect_col="tau"
)

df_result
[5]:
ATE Actual ATE MAE Train time (s) AUUC
TF DragonNet 3.995167 4.098887 1.199390 4.416656 0.552354
JAX DragonNet 3.921982 4.098887 1.212568 1.710895 0.552620
[6]:
plot_gain(df_preds, outcome_col="y", treatment_col="w", treatment_effect_col="tau")
plt.title("Cumulative Gain — JAX vs TF DragonNet (IHDP)")
plt.tight_layout()
plt.show()
../_images/examples_dragonnet_jax_vs_tf_10_0.png

Note on training time: JAX train time includes one-time JIT compilation on the first mini-batch. TF time includes Keras graph building overhead. Both are measured end-to-end from fit() call to return.