{ "cells": [ { "cell_type": "markdown", "id": "51363966", "metadata": {}, "source": [ "# DragonNet: JAX vs TF Benchmark\n", "\n", "Compares the JAX (`flax.nnx`) and TF (Keras) DragonNet implementations on the\n", "IHDP semi-synthetic dataset using ATE, MAE, AUUC, and wall-clock training time.\n", "\n", "Reference: Shi, Blei & Veitch (2019) — https://arxiv.org/pdf/1906.02120.pdf" ] }, { "cell_type": "markdown", "id": "2613c834", "metadata": {}, "source": [ "## Setup\n", "\n", "**TF backend:**\n", "```\n", "pip install tensorflow\n", "```\n", "\n", "**JAX backend:**\n", "```\n", "pip install \"causalml[jax]\"\n", "# or: uv pip install jax flax optax orbax-checkpoint\n", "```" ] }, { "cell_type": "code", "execution_count": 1, "id": "d21524db", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "id": "4de019de", "metadata": {}, "outputs": [], "source": [ "import os\n", "import time\n", "import warnings\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "from matplotlib import pyplot as plt\n", "from sklearn.metrics import mean_absolute_error\n", "\n", "from causalml.metrics import auuc_score, plot_gain\n", "\n", "# fmt: off\n", "%matplotlib inline\n", "warnings.filterwarnings(\"ignore\")\n", "plt.style.use(\"fivethirtyeight\")\n", "sns.set_palette(\"Paired\")\n", "plt.rcParams[\"figure.figsize\"] = (12, 8)\n", "# fmt: on\n", "\n", "try:\n", " from causalml.inference.tf import DragonNet as TFDragonNet\n", "\n", " HAS_TF = True\n", "except ImportError:\n", " HAS_TF = False\n", " print(\"[INFO] tensorflow not installed — TF DragonNet will be skipped.\")\n", "\n", "try:\n", " from causalml.inference.jax import DragonNet as JAXDragonNet\n", "\n", " HAS_JAX = True\n", "except ImportError:\n", " HAS_JAX = False\n", " print(\"[INFO] jax/flax not installed — JAX DragonNet will be skipped.\")" ] }, { "cell_type": "markdown", "id": "62cd229a", "metadata": {}, 
"source": [ "## IHDP Dataset\n", "\n", "Semi-synthetic dataset from Hill (2011), used in the original DragonNet paper.\n", "747 observations from the Infant Health and Development Program (IHDP) study.\n", "We use one realisation (`ihdp_npci_3.csv`) to reproduce the notebook benchmark." ] }, { "cell_type": "code", "execution_count": 3, "id": "9633a572", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "n=747, actual ATE = 4.0989\n" ] }, { "data": { "text/html": [ "
|  | treatment | y_factual | y_cfactual | mu0 | mu1 | x1 | x2 | x3 | x4 | x5 | ... | x16 | x17 | x18 | x19 | x20 | x21 | x22 | x23 | x24 | x25 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 5.931652 | 3.500591 | 2.253801 | 7.136441 | -0.528603 | -0.343455 | 1.128554 | 0.161703 | -0.316603 | ... | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 0 | 2.175966 | 5.952101 | 1.257592 | 6.553022 | -1.736945 | -1.802002 | 0.383828 | 2.244320 | -0.629189 | ... | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 0 | 2.180294 | 7.175734 | 2.384100 | 7.192645 | -0.807451 | -0.202946 | -0.360898 | -0.879606 | 0.808706 | ... | 1 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 0 | 3.587662 | 7.787537 | 4.009365 | 7.712456 | 0.390083 | 0.596582 | -1.850350 | -0.879606 | -0.004017 | ... | 1 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 0 | 2.372618 | 5.461871 | 2.481631 | 7.232739 | -1.045229 | -0.602710 | 0.011465 | 0.161703 | 0.683672 | ... | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
5 rows × 30 columns
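The printed `actual ATE` comes from the simulated potential outcomes, but the cell that computes it is not shown here. The following is a minimal sketch, assuming the standard IHDP column layout seen in the preview (`treatment`, `y_factual`, `y_cfactual`, `mu0`, `mu1`). It shows two common definitions of the ground-truth effect: the noisy per-unit effect built from factual and counterfactual outcomes, and the noiseless effect `mu1 - mu0`. Which definition this notebook uses is an assumption, and the helper names are hypothetical. The arrays are the five preview rows only, so neither mean will match the full-sample value of 4.0989.

```python
import numpy as np


def true_ite_from_outcomes(treatment, y_factual, y_cfactual):
    """Per-unit effect y(1) - y(0) from observed and counterfactual outcomes.

    For treated units (t=1) the factual outcome is y(1); for controls it is y(0).
    """
    t = np.asarray(treatment).astype(bool)
    y_f = np.asarray(y_factual, dtype=float)
    y_cf = np.asarray(y_cfactual, dtype=float)
    y1 = np.where(t, y_f, y_cf)  # outcome under treatment
    y0 = np.where(t, y_cf, y_f)  # outcome under control
    return y1 - y0


def true_ite_from_mu(mu0, mu1):
    """Noiseless per-unit effect from the simulated conditional means."""
    return np.asarray(mu1, dtype=float) - np.asarray(mu0, dtype=float)


# Outcome columns of the five preview rows above.
treatment = np.array([1, 0, 0, 0, 0])
y_factual = np.array([5.931652, 2.175966, 2.180294, 3.587662, 2.372618])
y_cfactual = np.array([3.500591, 5.952101, 7.175734, 7.787537, 5.461871])
mu0 = np.array([2.253801, 1.257592, 2.384100, 4.009365, 2.481631])
mu1 = np.array([7.136441, 6.553022, 7.192645, 7.712456, 7.232739])

tau_outcomes = true_ite_from_outcomes(treatment, y_factual, y_cfactual)
tau_mu = true_ite_from_mu(mu0, mu1)
print(f"ATE from outcomes  = {tau_outcomes.mean():.4f}")
print(f"ATE from mu1 - mu0 = {tau_mu.mean():.4f}")
```

The two definitions differ by the outcome noise, which is why per-unit MAE against `mu1 - mu0` is usually smaller than against the noisy counterfactual differences.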

|  | ATE | Actual ATE | MAE | Train time (s) | AUUC |
|---|---|---|---|---|---|
| TF DragonNet | 3.995167 | 4.098887 | 1.199390 | 4.416656 | 0.552354 |
| JAX DragonNet | 3.921982 | 4.098887 | 1.212568 | 1.710895 | 0.552620 |
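The raw columns are easier to compare once they are turned into absolute ATE error and relative training speed. The sketch below recomputes both from the reported numbers. It also includes two hypothetical helpers: `ite_mae`, which assumes the MAE column is the mean absolute error of per-unit effect estimates (equivalent to `sklearn.metrics.mean_absolute_error`), and `timed_fit`, one way to measure wall-clock training time. For JAX, the first call to a jitted function includes compilation time, and work is dispatched asynchronously; whether the timings above include compilation is not stated here.

```python
import time

import numpy as np


def ite_mae(tau_true, tau_hat):
    """Mean absolute error between true and estimated per-unit effects.

    Equivalent to sklearn.metrics.mean_absolute_error(tau_true, tau_hat).
    """
    tau_true = np.asarray(tau_true, dtype=float)
    tau_hat = np.asarray(tau_hat, dtype=float)
    return float(np.mean(np.abs(tau_true - tau_hat)))


def timed_fit(fit_fn, *args, **kwargs):
    """Call fit_fn once and return (result, elapsed wall-clock seconds).

    time.perf_counter is monotonic, so it is suitable for measuring durations.
    For JAX code, fit_fn should wait for device results before returning
    (e.g. by calling .block_until_ready() on an output array); otherwise
    asynchronous dispatch can make the measured time too short.
    """
    start = time.perf_counter()
    result = fit_fn(*args, **kwargs)
    return result, time.perf_counter() - start


# Numbers copied from the results table above.
actual_ate = 4.098887
results = {
    "TF DragonNet": {"ate": 3.995167, "train_time_s": 4.416656},
    "JAX DragonNet": {"ate": 3.921982, "train_time_s": 1.710895},
}

ate_errors = {name: abs(r["ate"] - actual_ate) for name, r in results.items()}
speedup = results["TF DragonNet"]["train_time_s"] / results["JAX DragonNet"]["train_time_s"]

for name, err in ate_errors.items():
    print(f"{name}: |ATE error| = {err:.4f}")
print(f"JAX speedup over TF: {speedup:.2f}x")
```

On these numbers the TF run lands closer to the true ATE (absolute error about 0.104 vs 0.177), the JAX run trains roughly 2.6x faster, and MAE and AUUC are nearly identical. Since the table reflects one IHDP realisation and what appears to be a single run per backend, these gaps should not be read as general differences between the implementations.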