import argparse
import logging
import sys
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state
logger = logging.getLogger("causalml")
def smd(feature, treatment):
    """Calculate the standardized mean difference (SMD) of a feature between
    the treatment and control groups.
The definition is available at
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3144483/#s11title
Args:
feature (pandas.Series): a column of a feature to calculate SMD for
        treatment (pandas.Series): a column that indicates whether a row is
            in the treatment group or not
Returns:
(float): The SMD of the feature
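    Example:
        A minimal sketch (``df`` is a hypothetical DataFrame with a binary
        treatment indicator). The implementation below computes
        (mean_t - mean_c) / sqrt(0.5 * (var_t + var_c))::

            df = pd.DataFrame(
                {"age": [25, 30, 35, 40], "treatment": [1, 1, 0, 0]}
            )
            smd(df["age"], df["treatment"])  # -10 / sqrt(12.5) ~= -2.83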
"""
t = feature[treatment == 1]
c = feature[treatment == 0]
return (t.mean() - c.mean()) / np.sqrt(0.5 * (t.var() + c.var()))
def create_table_one(data, treatment_col, features, with_std=True, with_counts=True):
"""Report balance in input features between the treatment and control groups.
References:
R's tableone at CRAN: https://github.com/kaz-yos/tableone
Python's tableone at PyPi: https://github.com/tompollard/tableone
Args:
data (pandas.DataFrame): total or matched sample data
treatment_col (str): the column name for the treatment
features (list of str): the column names of features
with_std (bool): whether to output std together with mean values as in <mean> (<std>) format
with_counts (bool): whether to include a row counting the total number of samples
Returns:
(pandas.DataFrame): A table with the means and standard deviations in
the treatment and control groups, and the SMD between two groups
for the features.
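    Example:
        A minimal sketch (``df`` is a hypothetical DataFrame with a binary
        ``treatment`` column and numeric feature columns)::

            table1 = create_table_one(
                data=df, treatment_col="treatment", features=["age", "income"]
            )
            # one row per feature plus an "n" row, with columns
            # Control, Treatment and SMD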
"""
t1 = pd.pivot_table(
data[features + [treatment_col]],
columns=treatment_col,
aggfunc=[
lambda x: (
"{:.2f} ({:.2f})".format(x.mean(), x.std())
if with_std
else "{:.2f}".format(x.mean())
)
],
)
t1.columns = t1.columns.droplevel(level=0)
t1["SMD"] = data[features].apply(lambda x: smd(x, data[treatment_col])).round(4)
if with_counts:
n_row = pd.pivot_table(
data[[features[0], treatment_col]], columns=treatment_col, aggfunc=["count"]
)
n_row.columns = n_row.columns.droplevel(level=0)
n_row["SMD"] = ""
n_row.index = ["n"]
t1 = pd.concat([n_row, t1], axis=0)
t1.columns.name = ""
t1.columns = ["Control", "Treatment", "SMD"]
t1.index.name = "Variable"
return t1
class NearestNeighborMatch:
"""
Propensity score matching based on the nearest neighbor algorithm.
Attributes:
        caliper (float): distance threshold for a pair to be considered a
            match
        replace (bool): whether to match with replacement or not
        ratio (int): ratio of control to treatment units to be matched. Used
            only if replace=True.
shuffle (bool): whether to shuffle the treatment group data before
matching
random_state (numpy.random.RandomState or int): RandomState or an int
seed
n_jobs (int): The number of parallel jobs to run for neighbors search.
None means 1 unless in a joblib.parallel_backend context. -1 means using all processors
"""
def __init__(
self,
caliper=0.2,
replace=False,
ratio=1,
shuffle=True,
random_state=None,
n_jobs=-1,
):
"""Initialize a propensity score matching model.
Args:
            caliper (float): distance threshold for a pair to be considered
                a match
            replace (bool): whether to match with replacement or not
            ratio (int): ratio of control to treatment units to be matched.
                Used only if replace=True.
            shuffle (bool): whether to shuffle the treatment group data
                before matching or not
random_state (numpy.random.RandomState or int): RandomState or an
int seed
n_jobs (int): The number of parallel jobs to run for neighbors search.
None means 1 unless in a joblib.parallel_backend context. -1 means using all processors
"""
self.caliper = caliper
self.replace = replace
self.ratio = ratio
self.shuffle = shuffle
self.random_state = check_random_state(random_state)
self.n_jobs = n_jobs
    def match(self, data, treatment_col, score_cols):
"""Find matches from the control group by matching on specified columns
(propensity preferred).
Args:
data (pandas.DataFrame): total input data
treatment_col (str): the column name for the treatment
score_cols (list): list of column names for matching (propensity
column should be included)
Returns:
(pandas.DataFrame): The subset of data consisting of matched
treatment and control group data.
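        Example:
            A minimal sketch (``df`` is a hypothetical DataFrame with a
            binary ``treatment`` column and a propensity score column
            ``score``)::

                psm = NearestNeighborMatch(caliper=0.2, replace=True, ratio=2)
                matched = psm.match(
                    data=df, treatment_col="treatment", score_cols=["score"]
                )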
"""
assert isinstance(score_cols, list), "score_cols must be a list"
treatment = data.loc[data[treatment_col] == 1, score_cols]
control = data.loc[data[treatment_col] == 0, score_cols]
sdcal = self.caliper * np.std(data[score_cols].values)
if self.replace:
scaler = StandardScaler()
scaler.fit(data[score_cols])
treatment_scaled = pd.DataFrame(
scaler.transform(treatment), index=treatment.index
)
control_scaled = pd.DataFrame(
scaler.transform(control), index=control.index
)
            # After standardization each column has SD 1, so the caliper
            # threshold applies directly without rescaling.
sdcal = self.caliper
matching_model = NearestNeighbors(
n_neighbors=self.ratio, n_jobs=self.n_jobs
)
matching_model.fit(control_scaled)
distances, indices = matching_model.kneighbors(treatment_scaled)
            # distances and indices are (n_obs, self.ratio) matrices.
            # Flatten them, and stack treatment_scaled self.ratio times,
            # so that all three become aligned 1-D structures of length
            # n_obs * self.ratio.
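            # e.g. with n_obs=2 and ratio=2, distances [[d11, d12], [d21, d22]]
            # flattens to [d11, d21, d12, d22], lining up with the treatment
            # rows repeated as [t1, t2, t1, t2] below.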
distances = distances.T.flatten()
indices = indices.T.flatten()
treatment_scaled = pd.concat([treatment_scaled] * self.ratio, axis=0)
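            # Scale the Euclidean distance by sqrt(n_dims) so the caliper is
            # applied to the per-column (RMS) distance rather than the total.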
cond = (distances / np.sqrt(len(score_cols))) < sdcal
# Deduplicate the indices of the treatment group
t_idx_matched = np.unique(treatment_scaled.loc[cond].index)
# XXX: Should we deduplicate the indices of the control group too?
c_idx_matched = np.array(control_scaled.iloc[indices[cond]].index)
else:
assert len(score_cols) == 1, (
"Matching on multiple columns is only supported using the "
"replacement method (if matching on multiple columns, set "
"replace=True)."
)
# unpack score_cols for the single-variable matching case
score_col = score_cols[0]
if self.shuffle:
t_indices = self.random_state.permutation(treatment.index)
else:
t_indices = treatment.index
t_idx_matched = []
c_idx_matched = []
control["unmatched"] = True
for t_idx in t_indices:
dist = np.abs(
control.loc[control.unmatched, score_col]
- treatment.loc[t_idx, score_col]
)
c_idx_min = dist.idxmin()
if dist[c_idx_min] <= sdcal:
t_idx_matched.append(t_idx)
c_idx_matched.append(c_idx_min)
control.loc[c_idx_min, "unmatched"] = False
return data.loc[
np.concatenate([np.array(t_idx_matched), np.array(c_idx_matched)])
]
    def match_by_group(self, data, treatment_col, score_cols, groupby_col):
"""Find matches from the control group stratified by groupby_col, by
matching on specified columns (propensity preferred).
Args:
data (pandas.DataFrame): total sample data
treatment_col (str): the column name for the treatment
score_cols (list): list of column names for matching (propensity
column should be included)
groupby_col (str): the column name to be used for stratification
Returns:
(pandas.DataFrame): The subset of data consisting of matched
treatment and control group data.
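        Example:
            A minimal sketch (``df`` is a hypothetical DataFrame with a
            ``region`` column used for stratification)::

                psm = NearestNeighborMatch(replace=True)
                matched = psm.match_by_group(
                    data=df,
                    treatment_col="treatment",
                    score_cols=["score"],
                    groupby_col="region",
                )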
"""
matched = data.groupby(groupby_col).apply(
lambda x: self.match(
data=x, treatment_col=treatment_col, score_cols=score_cols
)
)
return matched.reset_index(level=0, drop=True)
class MatchOptimizer:
def __init__(
self,
treatment_col="is_treatment",
ps_col="pihat",
user_col=None,
matching_covariates=["pihat"],
max_smd=0.1,
max_deviation=0.1,
caliper_range=(0.01, 0.5),
max_pihat_range=(0.95, 0.999),
max_iter_per_param=5,
min_users_per_group=1000,
smd_cols=["pihat"],
dev_cols_transformations={"pihat": np.mean},
dev_factor=1.0,
verbose=True,
):
"""Finds the set of parameters that gives the best matching result.
Score = (number of features with SMD > max_smd)
+ (sum of deviations for important variables
* deviation factor)
        The logic behind the scoring is that we are most concerned with
        minimizing the number of features whose SMD exceeds a certain
        threshold (max_smd). However, we would also like the matched dataset
        not to deviate too much from the original dataset in terms of key
        variable(s), so that we still retain a similar userbase.
Args:
            - treatment_col (str): name of the treatment column
            - ps_col (str): name of the propensity score column
            - user_col (str): name of the user ID column; if None, rows are
                counted instead of users when checking group sizes
            - matching_covariates (list): list of columns to check for
                balance after matching
- max_smd (float): maximum acceptable SMD
- max_deviation (float): maximum acceptable deviation for
important variables
- caliper_range (tuple): low and high bounds for caliper search
range
- max_pihat_range (tuple): low and high bounds for max pihat
search range
            - max_iter_per_param (int): maximum number of search values per
                parameter
- min_users_per_group (int): minimum number of users per group in
matched set
- smd_cols (list): score is more sensitive to these features
exceeding max_smd
- dev_factor (float): importance weight factor for dev_cols
(e.g. dev_factor=1 means a 10% deviation leads to penalty of 1
in score)
- dev_cols_transformations (dict): dict of transformations to be
made on dev_cols
- verbose (bool): boolean flag for printing statements
        Note:
            search_best_match() returns the best matched dataset
            (pandas.DataFrame).
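        Example:
            A minimal sketch (``df`` is a hypothetical DataFrame with an
            ``is_treatment`` flag and a ``pihat`` propensity score column)::

                optimizer = MatchOptimizer(
                    treatment_col="is_treatment",
                    ps_col="pihat",
                    matching_covariates=["pihat", "age"],
                )
                best_matched = optimizer.search_best_match(df)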
"""
self.treatment_col = treatment_col
self.ps_col = ps_col
self.user_col = user_col
self.matching_covariates = matching_covariates
self.max_smd = max_smd
self.max_deviation = max_deviation
self.caliper_range = np.linspace(*caliper_range, num=max_iter_per_param)
self.max_pihat_range = np.linspace(*max_pihat_range, num=max_iter_per_param)
self.max_iter_per_param = max_iter_per_param
self.min_users_per_group = min_users_per_group
self.smd_cols = smd_cols
self.dev_factor = dev_factor
self.dev_cols_transformations = dev_cols_transformations
self.best_params = {}
self.best_score = 1e7 # ideal score is 0
self.verbose = verbose
self.pass_all = False
    def single_match(self, score_cols, pihat_threshold, caliper):
matcher = NearestNeighborMatch(caliper=caliper, replace=True)
df_matched = matcher.match(
data=self.df[self.df[self.ps_col] < pihat_threshold],
treatment_col=self.treatment_col,
score_cols=score_cols,
)
return df_matched
    def check_table_one(self, tableone, matched, score_cols, pihat_threshold, caliper):
# check if better than past runs
smd_values = np.abs(tableone[tableone.index != "n"]["SMD"].astype(float))
num_cols_over_smd = (smd_values >= self.max_smd).sum()
self.cols_to_fix = (
smd_values[smd_values >= self.max_smd]
.sort_values(ascending=False)
.index.values
)
if self.user_col is None:
num_users_per_group = (
matched.reset_index().groupby(self.treatment_col)["index"].count().min()
)
else:
num_users_per_group = (
matched.groupby(self.treatment_col)[self.user_col].count().min()
)
deviations = [
np.abs(
self.original_stats[col]
/ matched[matched[self.treatment_col] == 1][col].mean()
- 1
)
for col in self.dev_cols_transformations.keys()
]
score = num_cols_over_smd
score += len(
[col for col in self.smd_cols if smd_values.loc[col] >= self.max_smd]
)
score += np.sum([dev * 10 * self.dev_factor for dev in deviations])
# check if can be considered as best score
if score < self.best_score and num_users_per_group > self.min_users_per_group:
self.best_score = score
self.best_params = {
"score_cols": score_cols.copy(),
"pihat": pihat_threshold,
"caliper": caliper,
}
self.best_matched = matched.copy()
if self.verbose:
logger.info(
"\tScore: {:.03f} (Best Score: {:.03f})\n".format(
score, self.best_score
)
)
# check if passes all criteria
self.pass_all = (
(num_users_per_group > self.min_users_per_group)
and (num_cols_over_smd == 0)
and all(dev < self.max_deviation for dev in deviations)
)
    def match_and_check(self, score_cols, pihat_threshold, caliper):
if self.verbose:
logger.info(
"Preparing match for: caliper={:.03f}, "
"pihat_threshold={:.03f}, "
"score_cols={}".format(caliper, pihat_threshold, score_cols)
)
df_matched = self.single_match(
score_cols=score_cols, pihat_threshold=pihat_threshold, caliper=caliper
)
tableone = create_table_one(
df_matched, self.treatment_col, self.matching_covariates
)
self.check_table_one(tableone, df_matched, score_cols, pihat_threshold, caliper)
    def search_best_match(self, df):
self.df = df
self.original_stats = {}
for col, trans in self.dev_cols_transformations.items():
self.original_stats[col] = trans(
self.df[self.df[self.treatment_col] == 1][col]
)
# search best max pihat
if self.verbose:
logger.info("SEARCHING FOR BEST PIHAT")
score_cols = [self.ps_col]
caliper = self.caliper_range[-1]
for pihat_threshold in self.max_pihat_range:
self.match_and_check(score_cols, pihat_threshold, caliper)
# search best score_cols
if self.verbose:
logger.info("SEARCHING FOR BEST SCORE_COLS")
pihat_threshold = self.best_params["pihat"]
caliper = self.caliper_range[int(self.caliper_range.shape[0] / 2)]
score_cols = [self.ps_col]
while not self.pass_all:
if len(self.cols_to_fix) == 0:
break
elif np.intersect1d(self.cols_to_fix, score_cols).shape[0] > 0:
break
else:
score_cols.append(self.cols_to_fix[0])
self.match_and_check(score_cols, pihat_threshold, caliper)
# search best caliper
if self.verbose:
logger.info("SEARCHING FOR BEST CALIPER")
score_cols = self.best_params["score_cols"]
pihat_threshold = self.best_params["pihat"]
for caliper in self.caliper_range:
self.match_and_check(score_cols, pihat_threshold, caliper)
# summarize
if self.verbose:
logger.info("\n-----\nBest params are:\n{}".format(self.best_params))
return self.best_matched
if __name__ == "__main__":
from .features import load_data
from .propensity import ElasticNetPropensityModel
TREATMENT_COL = "treatment"
SCORE_COL = "score"
GROUPBY_COL = "group"
parser = argparse.ArgumentParser()
parser.add_argument("--input-file", required=True, dest="input_file")
parser.add_argument("--output-file", required=True, dest="output_file")
parser.add_argument("--treatment-col", default=TREATMENT_COL, dest="treatment_col")
parser.add_argument("--groupby-col", default=GROUPBY_COL, dest="groupby_col")
parser.add_argument("--score-col", default=SCORE_COL, dest="score_col")
parser.add_argument("--feature-cols", nargs="+", required=True, dest="feature_cols")
parser.add_argument(
"--matching-cols", nargs="+", required=True, dest="matching_cols"
)
parser.add_argument("--caliper", type=float, default=0.2)
parser.add_argument("--replace", default=False, action="store_true")
parser.add_argument("--ratio", type=int, default=1)
args = parser.parse_args()
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger.info("Loading data from {}".format(args.input_file))
df = pd.read_csv(args.input_file)
df[args.treatment_col] = df[args.treatment_col].astype(int)
logger.info("shape: {}\n{}".format(df.shape, df.head()))
pm = ElasticNetPropensityModel(random_state=42)
w = df[args.treatment_col].values
X = load_data(
data=df,
features=args.feature_cols,
)
logger.info("Scoring with a propensity model: {}".format(pm))
df[args.score_col] = pm.fit_predict(X, w)
logger.info(
"Balance before matching:\n{}".format(
create_table_one(
data=df, treatment_col=args.treatment_col, features=args.matching_cols
)
)
)
logger.info(
"Matching based on the propensity score with the nearest neighbor model"
)
psm = NearestNeighborMatch(replace=args.replace, ratio=args.ratio, random_state=42)
matched = psm.match_by_group(
data=df,
treatment_col=args.treatment_col,
score_cols=[args.score_col],
groupby_col=args.groupby_col,
)
logger.info("shape: {}\n{}".format(matched.shape, matched.head()))
logger.info(
"Balance after matching:\n{}".format(
create_table_one(
data=matched,
treatment_col=args.treatment_col,
features=args.matching_cols,
)
)
)
matched.to_csv(args.output_file, index=False)
logger.info("Matched data saved as {}".format(args.output_file))