{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Logistic Regression Based Data Generation Function for Uplift Classification Problem\n", "This data generation function uses logistic regression as the underlying data generation model.\n", "It gives better control over feature patterns, that is, how each feature is associated with the outcome baseline and with the treatment effect. It supports 6 different patterns: linear, quadratic, cubic, ReLU, sine, and cosine.\n", "\n", "This notebook shows how to use the function to generate data and visualizes the resulting feature patterns.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import numpy as np\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import Data Generation Function" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The sklearn.utils.testing module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.utils. 
Anything that cannot be imported from sklearn.utils is now part of the private API.\n" ] } ], "source": [ "from causalml.dataset import make_uplift_classification_logistic" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generate Data" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "# Generate data for one control group and three treatment groups.\n", "df, feature_name = make_uplift_classification_logistic(\n", "    n_samples=100000,\n", "    treatment_name=['control', 'treatment1', 'treatment2', 'treatment3'],\n", "    y_name='conversion',\n", "    n_classification_features=10,\n", "    n_classification_informative=5,\n", "    n_classification_redundant=0,\n", "    n_classification_repeated=0,\n", "    n_uplift_dict={'treatment1': 2, 'treatment2': 2, 'treatment3': 3},\n", "    n_mix_informative_uplift_dict={'treatment1': 1, 'treatment2': 1, 'treatment3': 0},\n", "    delta_uplift_dict={'treatment1': 0.05, 'treatment2': 0.02, 'treatment3': -0.05},\n", "    # The six feature patterns described in the introduction.\n", "    feature_association_list=['linear', 'quadratic', 'cubic', 'relu', 'sin', 'cos'],\n", "    random_select_association=False,\n", "    random_seed=20200416,\n", ")" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | treatment_group_key | x1_informative | x1_informative_transformed | x2_informative | x2_informative_transformed | x3_informative | x3_informative_transformed | x4_informative | x4_informative_transformed | x5_informative | ... | conversion_prob | control_conversion_prob | control_true_effect | treatment1_conversion_prob | treatment1_true_effect | treatment2_conversion_prob | treatment2_true_effect | treatment3_conversion_prob | treatment3_true_effect | conversion |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | treatment1 | -0.194205 | -0.192043 | 1.791408 | 1.572609 | 0.678028 | 0.080696 | -0.169306 | -0.683035 | -1.837155 | ... | 0.126770 | 0.076138 | 0.0 | 0.126770 | 0.050632 | 0.087545 | 0.011407 | 0.029396 | -0.046742 | 0 |
| 1 | treatment1 | -0.898070 | -0.894462 | 0.252125 | -0.663393 | -0.842844 | -0.156004 | -0.047769 | -0.683035 | -0.251752 | ... | 0.064278 | 0.070799 | 0.0 | 0.064278 | -0.006522 | 0.101076 | 0.030277 | 0.050778 | -0.020021 | 0 |
| 2 | treatment1 | 0.701002 | 0.701325 | 0.239320 | -0.667867 | 1.700766 | 1.278676 | -0.734568 | -0.683035 | -1.130113 | ... | 0.018480 | 0.014947 | 0.0 | 0.018480 | 0.003534 | 0.018055 | 0.003109 | 0.019327 | 0.004380 | 0 |
| 3 | control | -1.653684 | -1.648524 | -0.119123 | -0.698492 | -0.037645 | -0.000355 | 0.687429 | 0.495943 | -1.427400 | ... | 0.102799 | 0.102799 | 0.0 | 0.101410 | -0.001390 | 0.040230 | -0.062569 | 0.030753 | -0.072046 | 0 |
| 4 | treatment3 | 1.057909 | 1.057498 | -2.019523 | 2.190564 | -0.950180 | -0.223370 | -1.505741 | -0.683035 | -0.399457 | ... | 0.012964 | 0.106241 | 0.0 | 0.171309 | 0.065068 | 0.114526 | 0.008285 | 0.012964 | -0.093277 | 0 |
5 rows × 47 columns
\n", "