CEVAE vs. Meta-Learners Benchmark with IHDP + Synthetic Datasets

[1]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import torch

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error as mse
from scipy.stats import entropy
import warnings
import logging

from causalml.inference.meta import BaseXRegressor, BaseRRegressor, BaseSRegressor, BaseTRegressor
from causalml.inference.nn import CEVAE
from causalml.propensity import ElasticNetPropensityModel
from causalml.metrics import *
from causalml.dataset import simulate_hidden_confounder

%matplotlib inline

warnings.filterwarnings('ignore')
logger = logging.getLogger('causalml')
logger.setLevel(logging.DEBUG)

# Global plot styling: apply the theme first, then the seaborn colour cycle
# (which overrides the theme's cycle), then the default figure size in inches.
plt.style.use('fivethirtyeight')
sns.set_palette('Paired')
plt.rcParams.update({'figure.figsize': (12, 8)})

IHDP semi-synthetic dataset

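Hill (2011) introduced a semi-synthetic dataset constructed from the Infant Health and Development Program (IHDP). The underlying data come from a randomized experiment investigating the effect of home visits by specialists on children's future cognitive test scores. Hill made the data observational by removing a non-random subset of the treated children, and simulated the outcomes. As a result, both potential outcomes — and hence the true treatment effect — are known for every unit. The IHDP simulation is considered the de facto standard benchmark for neural-network treatment-effect estimation methods.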

[2]:
# Load the IHDP realizations ihdp_npci_1.csv ... ihdp_npci_9.csv.
# NOTE(review): range(1, 10) stops at file 9; if ./data also ships
# ihdp_npci_10.csv it is not loaded here — confirm this is intended.
# The files are read into a list and concatenated once (growing a frame with
# concat inside a loop is quadratic and warns on the initial empty frame).
# They are stacked in reverse file order (9, 8, ..., 1) to reproduce the row
# order of the previous prepend-in-a-loop version, which the seeded
# train/test split further down depends on.
ihdp_files = [f'./data/ihdp_npci_{i}.csv' for i in range(1, 10)]
df = pd.concat([pd.read_csv(path, header=None) for path in reversed(ihdp_files)])
cols = ["treatment", "y_factual", "y_cfactual", "mu0", "mu1"] + list(range(25))
df.columns = cols
print(df.shape)

# replicate the data 100 times
replications = 100
df = pd.concat([df] * replications, ignore_index=True)
print(df.shape)
(6723, 30)
(672300, 30)
[3]:
# Covariate columns 6..24 are the binary features; whatever remains of the
# 25 covariate columns (0..5) is continuous.
binfeats = list(range(6, 25))
contfeats = sorted(set(range(25)) - set(binfeats))

# Column order used to build X: binary features first, continuous after.
perm = [*binfeats, *contfeats]
[4]:
# Give the frame a clean 0..n-1 index (discarding the old one), then preview
# its first five rows.
df = df.reset_index(drop=True)
df.iloc[:5]
[4]:
treatment y_factual y_cfactual mu0 mu1 0 1 2 3 4 ... 15 16 17 18 19 20 21 22 23 24
0 1 49.647921 34.950762 37.173291 50.383798 -0.528603 -0.343455 1.128554 0.161703 -0.316603 ... 1 1 1 1 0 0 0 0 0 0
1 0 16.073412 49.435313 16.087249 49.546234 -1.736945 -1.802002 0.383828 2.244320 -0.629189 ... 1 1 1 1 0 0 0 0 0 0
2 0 19.643007 48.598210 18.044855 49.661068 -0.807451 -0.202946 -0.360898 -0.879606 0.808706 ... 1 0 1 1 0 0 0 0 0 0
3 0 26.368322 49.715204 24.605964 49.971196 0.390083 0.596582 -1.850350 -0.879606 -0.004017 ... 1 0 1 1 0 0 0 0 0 0
4 0 20.258893 51.147418 20.612816 49.794120 -1.045229 -0.602710 0.011465 0.161703 0.683672 ... 1 1 1 1 0 0 0 0 0 0

5 rows × 30 columns

[5]:
# Pull model inputs and ground-truth quantities out of the IHDP frame.
# Covariates are reordered binary-first, continuous-last via `perm`.
X = df[perm].values
treatment = df['treatment'].values
y = df['y_factual'].values
y_cf = df['y_cfactual'].values
# Per-row effect y(1) - y(0): for treated rows the factual outcome is y(1),
# for control rows the counterfactual outcome is y(1).
# Vectorised with np.where instead of a row-wise df.apply(axis=1), which is
# very slow on the replicated frame. The result is still wrapped in a Series
# aligned to df.index, so later code sees the same type as before.
tau = pd.Series(np.where(treatment == 1, y - y_cf, y_cf - y), index=df.index)
mu_0 = df['mu0'].values
mu_1 = df['mu1'].values
[6]:
# Separate train and validation sets.
# `df` was built as `replications` stacked, identical copies of the same base
# rows (pd.concat([df]*replications, ignore_index=True)), so row r and row
# r + k * n_base hold the same unit. A plain row-level split would place exact
# duplicates of nearly every validation row in the training set (label
# leakage). Instead, split the base rows once and keep every replica of a base
# row on the same side of the split.
n_base = X.shape[0] // replications
assert n_base * replications == X.shape[0], 'row count is not a multiple of `replications`'
base_itr, base_ite = train_test_split(np.arange(n_base), test_size=0.2, random_state=1)
# Training keeps all replicas of its base rows, so the training-set size stays
# roughly what it was. It is shuffled with a fixed seed so rows are not grouped
# replica by replica.
itr = (np.arange(replications)[:, None] * n_base + base_itr[None, :]).ravel()
itr = np.random.RandomState(1).permutation(itr)
# Validation needs only one copy of each held-out row, since replicas are
# exact duplicates. Base row r lives at position r in the first replica.
ite = base_ite
X_train, treatment_train, y_train, y_cf_train, tau_train, mu_0_train, mu_1_train = X[itr], treatment[itr], y[itr], y_cf[itr], tau[itr], mu_0[itr], mu_1[itr]
X_val, treatment_val, y_val, y_cf_val, tau_val, mu_0_val, mu_1_val = X[ite], treatment[ite], y[ite], y_cf[ite], tau[ite], mu_0[ite], mu_1[ite]

CEVAE Model

[7]:
# CEVAE model settings (passed verbatim to the CEVAE constructor below).
outcome_dist = "normal"
latent_dim = 20
hidden_dim = 200
num_epochs = 5
batch_size = 1000
learning_rate = 0.001
learning_rate_decay = 0.01
num_layers = 2

# Seed torch's and numpy's global RNGs so that stochastic steps drawing from
# them (e.g. network weight initialisation) reproduce on Restart & Run All.
# NOTE(review): confirm CEVAE/Pyro uses no other RNG source.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
[8]:
# Instantiate the CEVAE estimator from the settings in the config cell.
cevae = CEVAE(**dict(
    outcome_dist=outcome_dist,
    latent_dim=latent_dim,
    hidden_dim=hidden_dim,
    num_epochs=num_epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    learning_rate_decay=learning_rate_decay,
    num_layers=num_layers,
))
[9]:
# Fit CEVAE on the training split. Each numpy array is converted to a
# torch.float (float32) tensor before being passed to CEVAE.fit; `losses`
# holds whatever fit() returns.
losses = cevae.fit(**{
    arg_name: torch.tensor(values, dtype=torch.float)
    for arg_name, values in (('X', X_train),
                             ('treatment', treatment_train),
                             ('y', y_train))
})
INFO     Training with 538 minibatches per epoch
DEBUG    step     0 loss = 1021.35
DEBUG    step     1 loss = 421.484
DEBUG    step     2 loss = 338.296
DEBUG    step     3 loss = 319.514
DEBUG    step     4 loss = 217.484
DEBUG    step     5 loss = 237.474
DEBUG    step     6 loss = 242.367
DEBUG    step     7 loss = 236.713
DEBUG    step     8 loss = 200.399
DEBUG    step     9 loss = 201.788
DEBUG    step    10 loss = 220.049
DEBUG    step    11 loss = 213.79
DEBUG    step    12 loss = 190.921
DEBUG    step    13 loss = 196.359
DEBUG    step    14 loss = 189.747
DEBUG    step    15 loss = 167.321
DEBUG    step    16 loss = 159.207
DEBUG    step    17 loss = 154.599
DEBUG    step    18 loss = 150.961
DEBUG    step    19 loss = 149.938
DEBUG    step    20 loss = 134.768
DEBUG    step    21 loss = 140.833
DEBUG    step    22 loss = 146.769
DEBUG    step    23 loss = 132.524
DEBUG    step    24 loss = 134.194
DEBUG    step    25 loss = 130.618
DEBUG    step    26 loss = 136.787
DEBUG    step    27 loss = 126.727
DEBUG    step    28 loss = 120.942
DEBUG    step    29 loss = 118.619
DEBUG    step    30 loss = 120.946
DEBUG    step    31 loss = 110.782
DEBUG    step    32 loss = 120.907
DEBUG    step    33 loss = 106.87
DEBUG    step    34 loss = 95.3908
DEBUG    step    35 loss = 104.229
DEBUG    step    36 loss = 100.688
DEBUG    step    37 loss = 102.31
DEBUG    step    38 loss = 96.3181
DEBUG    step    39 loss = 92.0119
DEBUG    step    40 loss = 101.374
DEBUG    step    41 loss = 95.1874
DEBUG    step    42 loss = 91.693
DEBUG    step    43 loss = 83.7838
DEBUG    step    44 loss = 76.9446
DEBUG    step    45 loss = 77.8403
DEBUG    step    46 loss = 81.372
DEBUG    step    47 loss = 82.7198
DEBUG    step    48 loss = 72.8519
DEBUG    step    49 loss = 76.6569
DEBUG    step    50 loss = 75.7397
DEBUG    step    51 loss = 79.6319
DEBUG    step    52 loss = 79.2719
DEBUG    step    53 loss = 74.6354
DEBUG    step    54 loss = 68.5501
DEBUG    step    55 loss = 72.5121
DEBUG    step    56 loss = 65.3819
DEBUG    step    57 loss = 68.0494
DEBUG    step    58 loss = 69.0703
DEBUG    step    59 loss = 67.7917
DEBUG    step    60 loss = 66.9287
DEBUG    step    61 loss = 58.5794
DEBUG    step    62 loss = 59.4718
DEBUG    step    63 loss = 62.9541
DEBUG    step    64 loss = 60.0412
DEBUG    step    65 loss = 57.8926
DEBUG    step    66 loss = 57.5324
DEBUG    step    67 loss = 56.5494
DEBUG    step    68 loss = 52.2587
DEBUG    step    69 loss = 55.7073
DEBUG    step    70 loss = 54.979
DEBUG    step    71 loss = 55.4208
DEBUG    step    72 loss = 54.7927
DEBUG    step    73 loss = 49.0343
DEBUG    step    74 loss = 53.8712
DEBUG    step    75 loss = 50.4505
DEBUG    step    76 loss = 49.2015
DEBUG    step    77 loss = 49.1161
DEBUG    step    78 loss = 51.0351
DEBUG    step    79 loss = 47.8925
DEBUG    step    80 loss = 48.4682
DEBUG    step    81 loss = 47.0941
DEBUG    step    82 loss = 44.807
DEBUG    step    83 loss = 43.6143
DEBUG    step    84 loss = 48.9903
DEBUG    step    85 loss = 46.6454
DEBUG    step    86 loss = 46.2746
DEBUG    step    87 loss = 47.5599
DEBUG    step    88 loss = 45.7764
DEBUG    step    89 loss = 42.9916
DEBUG    step    90 loss = 43.2444
DEBUG    step    91 loss = 43.616
DEBUG    step    92 loss = 41.0364
DEBUG    step    93 loss = 40.7751
DEBUG    step    94 loss = 39.693
DEBUG    step    95 loss = 41.2092
DEBUG    step    96 loss = 41.3535
DEBUG    step    97 loss = 39.0969
DEBUG    step    98 loss = 39.176
DEBUG    step    99 loss = 41.4575
DEBUG    step   100 loss = 40.5371
DEBUG    step   101 loss = 39.4805
DEBUG    step   102 loss = 37.7776
DEBUG    step   103 loss = 36.5425
DEBUG    step   104 loss = 37.3177
DEBUG    step   105 loss = 37.9773
DEBUG    step   106 loss = 36.8961
DEBUG    step   107 loss = 36.6936
DEBUG    step   108 loss = 35.1503
DEBUG    step   109 loss = 37.8622
DEBUG    step   110 loss = 36.6135
DEBUG    step   111 loss = 34.6556
DEBUG    step   112 loss = 32.9034
DEBUG    step   113 loss = 35.928
DEBUG    step   114 loss = 35.6375
DEBUG    step   115 loss = 34.8875
DEBUG    step   116 loss = 32.4369
DEBUG    step   117 loss = 35.5889
DEBUG    step   118 loss = 33.3445
DEBUG    step   119 loss = 35.3891
DEBUG    step   120 loss = 32.7132
DEBUG    step   121 loss = 32.4759
DEBUG    step   122 loss = 33.143
DEBUG    step   123 loss = 31.3498
DEBUG    step   124 loss = 31.6331
DEBUG    step   125 loss = 33.2434
DEBUG    step   126 loss = 31.1028
DEBUG    step   127 loss = 32.8674
DEBUG    step   128 loss = 32.8578
DEBUG    step   129 loss = 32.625
DEBUG    step   130 loss = 31.8448
DEBUG    step   131 loss = 30.8554
DEBUG    step   132 loss = 31.9763
DEBUG    step   133 loss = 29.6616
DEBUG    step   134 loss = 30.0425
DEBUG    step   135 loss = 30.836
DEBUG    step   136 loss = 31.0736
DEBUG    step   137 loss = 30.8878
DEBUG    step   138 loss = 30.43
DEBUG    step   139 loss = 30.6093
DEBUG    step   140 loss = 30.7339
DEBUG    step   141 loss = 30.0207
DEBUG    step   142 loss = 29.3626
DEBUG    step   143 loss = 29.7463
DEBUG    step   144 loss = 29.4184
DEBUG    step   145 loss = 29.2421
DEBUG    step   146 loss = 29.7529
DEBUG    step   147 loss = 29.3111
DEBUG    step   148 loss = 28.7811
DEBUG    step   149 loss = 29.3185
DEBUG    step   150 loss = 28.3709
DEBUG    step   151 loss = 30.2563
DEBUG    step   152 loss = 29.5989
DEBUG    step   153 loss = 28.8563
DEBUG    step   154 loss = 27.3948
DEBUG    step   155 loss = 28.3484
DEBUG    step   156 loss = 29.0616
DEBUG    step   157 loss = 28.8883
DEBUG    step   158 loss = 27.0463
DEBUG    step   159 loss = 27.3796
DEBUG    step   160 loss = 29.0732
DEBUG    step   161 loss = 26.8263
DEBUG    step   162 loss = 27.2883
DEBUG    step   163 loss = 28.6272
DEBUG    step   164 loss = 26.7478
DEBUG    step   165 loss = 27.6244
DEBUG    step   166 loss = 26.3508
DEBUG    step   167 loss = 26.1734
DEBUG    step   168 loss = 26.4877
DEBUG    step   169 loss = 26.9542
DEBUG    step   170 loss = 27.5395
DEBUG    step   171 loss = 26.4924
DEBUG    step   172 loss = 26.2203
DEBUG    step   173 loss = 26.039
DEBUG    step   174 loss = 25.7883
DEBUG    step   175 loss = 25.7104
DEBUG    step   176 loss = 25.9135
DEBUG    step   177 loss = 25.8419
DEBUG    step   178 loss = 26.897
DEBUG    step   179 loss = 24.8235
DEBUG    step   180 loss = 25.8669
DEBUG    step   181 loss = 26.442
DEBUG    step   182 loss = 24.7512
DEBUG    step   183 loss = 25.4444
DEBUG    step   184 loss = 25.7225
DEBUG    step   185 loss = 24.9703
DEBUG    step   186 loss = 25.5197
DEBUG    step   187 loss = 25.3311
DEBUG    step   188 loss = 25.0711
DEBUG    step   189 loss = 25.5542
DEBUG    step   190 loss = 25.2289
DEBUG    step   191 loss = 24.9589
DEBUG    step   192 loss = 24.5436
DEBUG    step   193 loss = 24.4451
DEBUG    step   194 loss = 23.3428
DEBUG    step   195 loss = 24.6046
DEBUG    step   196 loss = 25.1871
DEBUG    step   197 loss = 24.1005
DEBUG    step   198 loss = 24.287
DEBUG    step   199 loss = 24.4165
DEBUG    step   200 loss = 24.5855
DEBUG    step   201 loss = 23.2874
DEBUG    step   202 loss = 23.8787
DEBUG    step   203 loss = 24.5806
DEBUG    step   204 loss = 24.0906
DEBUG    step   205 loss = 25.0818
DEBUG    step   206 loss = 23.9177
DEBUG    step   207 loss = 25.0566
DEBUG    step   208 loss = 23.0722
DEBUG    step   209 loss = 23.8822
DEBUG    step   210 loss = 24.3339
DEBUG    step   211 loss = 24.7321
DEBUG    step   212 loss = 22.9672
DEBUG    step   213 loss = 23.6966
DEBUG    step   214 loss = 23.0869
DEBUG    step   215 loss = 23.5599
DEBUG    step   216 loss = 23.6307
DEBUG    step   217 loss = 23.1928
DEBUG    step   218 loss = 23.9375
DEBUG    step   219 loss = 23.65
DEBUG    step   220 loss = 22.5324
DEBUG    step   221 loss = 23.7082
DEBUG    step   222 loss = 22.854
DEBUG    step   223 loss = 21.8886
DEBUG    step   224 loss = 23.4573
DEBUG    step   225 loss = 22.4752
DEBUG    step   226 loss = 22.2281
DEBUG    step   227 loss = 22.6597
DEBUG    step   228 loss = 22.8313
DEBUG    step   229 loss = 22.8756
DEBUG    step   230 loss = 22.1289
DEBUG    step   231 loss = 22.6235
DEBUG    step   232 loss = 22.0739
DEBUG    step   233 loss = 22.7643
DEBUG    step   234 loss = 21.5396
DEBUG    step   235 loss = 21.5537
DEBUG    step   236 loss = 21.8743
DEBUG    step   237 loss = 22.6117
DEBUG    step   238 loss = 22.8206
DEBUG    step   239 loss = 22.8641
DEBUG    step   240 loss = 22.5666
DEBUG    step   241 loss = 22.3578
DEBUG    step   242 loss = 23.3638
DEBUG    step   243 loss = 22.1094
DEBUG    step   244 loss = 22.1056
DEBUG    step   245 loss = 22.1651
DEBUG    step   246 loss = 21.4072
DEBUG    step   247 loss = 21.4627
DEBUG    step   248 loss = 21.2096
DEBUG    step   249 loss = 21.3499
DEBUG    step   250 loss = 21.4386
DEBUG    step   251 loss = 21.3385
DEBUG    step   252 loss = 21.3782
DEBUG    step   253 loss = 20.7455
DEBUG    step   254 loss = 22.3244
DEBUG    step   255 loss = 21.1068
DEBUG    step   256 loss = 21.5648
DEBUG    step   257 loss = 21.5746
DEBUG    step   258 loss = 21.6169
DEBUG    step   259 loss = 21.2303
DEBUG    step   260 loss = 21.8207
DEBUG    step   261 loss = 21.2217
DEBUG    step   262 loss = 22.4259
DEBUG    step   263 loss = 21.2911
DEBUG    step   264 loss = 21.9783
DEBUG    step   265 loss = 120.585
DEBUG    step   266 loss = 22.3958
DEBUG    step   267 loss = 21.1204
DEBUG    step   268 loss = 20.3405
DEBUG    step   269 loss = 19.9695
DEBUG    step   270 loss = 21.6718
DEBUG    step   271 loss = 20.8654
DEBUG    step   272 loss = 20.4101
DEBUG    step   273 loss = 20.769
DEBUG    step   274 loss = 20.5526
DEBUG    step   275 loss = 20.026
DEBUG    step   276 loss = 20.2413
DEBUG    step   277 loss = 20.0747
DEBUG    step   278 loss = 20.6848
DEBUG    step   279 loss = 20.0956
DEBUG    step   280 loss = 20.667
DEBUG    step   281 loss = 19.8283
DEBUG    step   282 loss = 19.8651
DEBUG    step   283 loss = 19.4686
DEBUG    step   284 loss = 19.7195
DEBUG    step   285 loss = 20.1469
DEBUG    step   286 loss = 19.8956
DEBUG    step   287 loss = 20.3657
DEBUG    step   288 loss = 20.1624
DEBUG    step   289 loss = 20.8871
DEBUG    step   290 loss = 19.7327
DEBUG    step   291 loss = 19.3476
DEBUG    step   292 loss = 19.841
DEBUG    step   293 loss = 20.0052
DEBUG    step   294 loss = 19.7133
DEBUG    step   295 loss = 19.7911
DEBUG    step   296 loss = 18.6917
DEBUG    step   297 loss = 19.795
DEBUG    step   298 loss = 19.1175
DEBUG    step   299 loss = 20.1492
DEBUG    step   300 loss = 19.7831
DEBUG    step   301 loss = 19.7247
DEBUG    step   302 loss = 19.5755
DEBUG    step   303 loss = 19.9661
DEBUG    step   304 loss = 18.2884
DEBUG    step   305 loss = 19.6565
DEBUG    step   306 loss = 19.6213
DEBUG    step   307 loss = 19.2026
DEBUG    step   308 loss = 19.8699
DEBUG    step   309 loss = 18.7806
DEBUG    step   310 loss = 18.8876
DEBUG    step   311 loss = 19.3982
DEBUG    step   312 loss = 19.1813
DEBUG    step   313 loss = 18.9337
DEBUG    step   314 loss = 18.2574
DEBUG    step   315 loss = 19.0662
DEBUG    step   316 loss = 19.1584
DEBUG    step   317 loss = 18.1926
DEBUG    step   318 loss = 18.7658
DEBUG    step   319 loss = 18.2249
DEBUG    step   320 loss = 19.003
DEBUG    step   321 loss = 19.0593
DEBUG    step   322 loss = 18.9254
DEBUG    step   323 loss = 19.0602
DEBUG    step   324 loss = 18.5273
DEBUG    step   325 loss = 18.2321
DEBUG    step   326 loss = 18.354
DEBUG    step   327 loss = 18.2741
DEBUG    step   328 loss = 18.544
DEBUG    step   329 loss = 18.3197
DEBUG    step   330 loss = 18.8422
DEBUG    step   331 loss = 18.4199
DEBUG    step   332 loss = 17.7382
DEBUG    step   333 loss = 18.1209
DEBUG    step   334 loss = 18.4557
DEBUG    step   335 loss = 18.5937
DEBUG    step   336 loss = 17.7678
DEBUG    step   337 loss = 19.1363
DEBUG    step   338 loss = 18.0725
DEBUG    step   339 loss = 18.3309
DEBUG    step   340 loss = 17.9822
DEBUG    step   341 loss = 17.7317
DEBUG    step   342 loss = 18.1821
DEBUG    step   343 loss = 18.1704
DEBUG    step   344 loss = 18.0436
DEBUG    step   345 loss = 17.3161
DEBUG    step   346 loss = 17.1744
DEBUG    step   347 loss = 18.4531
DEBUG    step   348 loss = 17.097
DEBUG    step   349 loss = 17.2031
DEBUG    step   350 loss = 17.7855
DEBUG    step   351 loss = 17.3887
DEBUG    step   352 loss = 18.1904
DEBUG    step   353 loss = 16.9673
DEBUG    step   354 loss = 17.6665
DEBUG    step   355 loss = 17.9181
DEBUG    step   356 loss = 17.3892
DEBUG    step   357 loss = 18.6147
DEBUG    step   358 loss = 17.0139
DEBUG    step   359 loss = 17.4958
DEBUG    step   360 loss = 16.8143
DEBUG    step   361 loss = 16.8076
DEBUG    step   362 loss = 17.2509
DEBUG    step   363 loss = 16.6091
DEBUG    step   364 loss = 16.5105
DEBUG    step   365 loss = 16.8734
DEBUG    step   366 loss = 16.7367
DEBUG    step   367 loss = 16.3754
DEBUG    step   368 loss = 16.7072
DEBUG    step   369 loss = 16.6687
DEBUG    step   370 loss = 16.4918
DEBUG    step   371 loss = 17.4622
DEBUG    step   372 loss = 16.5902
DEBUG    step   373 loss = 17.0211
DEBUG    step   374 loss = 16.1971
DEBUG    step   375 loss = 17.1127
DEBUG    step   376 loss = 17.0151
DEBUG    step   377 loss = 16.5271
DEBUG    step   378 loss = 15.7553
DEBUG    step   379 loss = 17.5206
DEBUG    step   380 loss = 16.1141
DEBUG    step   381 loss = 16.0002
DEBUG    step   382 loss = 16.7775
DEBUG    step   383 loss = 16.0455
DEBUG    step   384 loss = 16.4851
DEBUG    step   385 loss = 15.9572
DEBUG    step   386 loss = 16.045
DEBUG    step   387 loss = 16.3194
DEBUG    step   388 loss = 16.827
DEBUG    step   389 loss = 16.818
DEBUG    step   390 loss = 16.5154
DEBUG    step   391 loss = 16.4575
DEBUG    step   392 loss = 16.3866
DEBUG    step   393 loss = 16.7649
DEBUG    step   394 loss = 16.3661
DEBUG    step   395 loss = 16.0388
DEBUG    step   396 loss = 16.3603
DEBUG    step   397 loss = 15.9295
DEBUG    step   398 loss = 16.2829
DEBUG    step   399 loss = 15.7255
DEBUG    step   400 loss = 15.9625
DEBUG    step   401 loss = 16.2877
DEBUG    step   402 loss = 15.9384
DEBUG    step   403 loss = 15.7691
DEBUG    step   404 loss = 15.3813
DEBUG    step   405 loss = 16.3497
DEBUG    step   406 loss = 15.6471
DEBUG    step   407 loss = 15.7245
DEBUG    step   408 loss = 15.5237
DEBUG    step   409 loss = 15.4977
DEBUG    step   410 loss = 15.7544
DEBUG    step   411 loss = 16.4454
DEBUG    step   412 loss = 15.8385
DEBUG    step   413 loss = 15.8783
DEBUG    step   414 loss = 14.5518
DEBUG    step   415 loss = 15.248
DEBUG    step   416 loss = 15.4766
DEBUG    step   417 loss = 15.1702
DEBUG    step   418 loss = 15.0027
DEBUG    step   419 loss = 14.7798
DEBUG    step   420 loss = 14.2242
DEBUG    step   421 loss = 14.7344
DEBUG    step   422 loss = 15.3192
DEBUG    step   423 loss = 14.5862
DEBUG    step   424 loss = 14.8549
DEBUG    step   425 loss = 15.1208
DEBUG    step   426 loss = 15.6343
DEBUG    step   427 loss = 14.9648
DEBUG    step   428 loss = 15.8638
DEBUG    step   429 loss = 14.7795
DEBUG    step   430 loss = 15.1229
DEBUG    step   431 loss = 14.9709
DEBUG    step   432 loss = 15.3807
DEBUG    step   433 loss = 14.2497
DEBUG    step   434 loss = 15.0741
DEBUG    step   435 loss = 13.8058
DEBUG    step   436 loss = 15.0915
DEBUG    step   437 loss = 15.2831
DEBUG    step   438 loss = 15.0772
DEBUG    step   439 loss = 15.8433
DEBUG    step   440 loss = 15.3281
DEBUG    step   441 loss = 14.7288
DEBUG    step   442 loss = 15.1505
DEBUG    step   443 loss = 15.3472
DEBUG    step   444 loss = 13.545
DEBUG    step   445 loss = 14.6441
DEBUG    step   446 loss = 14.0351
DEBUG    step   447 loss = 14.0212
DEBUG    step   448 loss = 14.1237
DEBUG    step   449 loss = 14.4073
DEBUG    step   450 loss = 14.4118
DEBUG    step   451 loss = 13.9406
DEBUG    step   452 loss = 15.0758
DEBUG    step   453 loss = 14.9103
DEBUG    step   454 loss = 14.3315
DEBUG    step   455 loss = 13.8796
DEBUG    step   456 loss = 13.9354
DEBUG    step   457 loss = 13.8283
DEBUG    step   458 loss = 14.8273
DEBUG    step   459 loss = 14.4759
DEBUG    step   460 loss = 14.5714
DEBUG    step   461 loss = 14.0121
DEBUG    step   462 loss = 14.393
DEBUG    step   463 loss = 14.4324
DEBUG    step   464 loss = 14.0807
DEBUG    step   465 loss = 14.3522
DEBUG    step   466 loss = 14.4154
DEBUG    step   467 loss = 13.1898
DEBUG    step   468 loss = 14.06
DEBUG    step   469 loss = 20.7401
DEBUG    step   470 loss = 14.2803
DEBUG    step   471 loss = 14.287
DEBUG    step   472 loss = 14.0215
DEBUG    step   473 loss = 13.4496
DEBUG    step   474 loss = 14.033
DEBUG    step   475 loss = 14.4732
DEBUG    step   476 loss = 13.7291
DEBUG    step   477 loss = 13.0513
DEBUG    step   478 loss = 13.6051
DEBUG    step   479 loss = 13.5316
DEBUG    step   480 loss = 13.5474
DEBUG    step   481 loss = 13.7794
DEBUG    step   482 loss = 13.8363
DEBUG    step   483 loss = 13.2939
DEBUG    step   484 loss = 13.3987
DEBUG    step   485 loss = 13.4694
DEBUG    step   486 loss = 13.0736
DEBUG    step   487 loss = 12.9663
DEBUG    step   488 loss = 13.4017
DEBUG    step   489 loss = 13.1387
DEBUG    step   490 loss = 12.8554
DEBUG    step   491 loss = 13.7535
DEBUG    step   492 loss = 13.0516
DEBUG    step   493 loss = 12.9229
DEBUG    step   494 loss = 13.0794
DEBUG    step   495 loss = 12.6742
DEBUG    step   496 loss = 12.5159
DEBUG    step   497 loss = 13.8863
DEBUG    step   498 loss = 13.275
DEBUG    step   499 loss = 13.8195
DEBUG    step   500 loss = 14.2111
DEBUG    step   501 loss = 12.8113
DEBUG    step   502 loss = 13.5611
DEBUG    step   503 loss = 13.1597
DEBUG    step   504 loss = 12.7698
DEBUG    step   505 loss = 12.655
DEBUG    step   506 loss = 13.3424
DEBUG    step   507 loss = 13.0807
DEBUG    step   508 loss = 13.4257
DEBUG    step   509 loss = 12.769
DEBUG    step   510 loss = 13.2426
DEBUG    step   511 loss = 13.7624
DEBUG    step   512 loss = 13.4707
DEBUG    step   513 loss = 12.6719
DEBUG    step   514 loss = 12.7837
DEBUG    step   515 loss = 12.3574
DEBUG    step   516 loss = 12.4319
DEBUG    step   517 loss = 12.2339
DEBUG    step   518 loss = 12.5959
DEBUG    step   519 loss = 12.9824
DEBUG    step   520 loss = 12.7877
DEBUG    step   521 loss = 13.0799
DEBUG    step   522 loss = 12.6134
DEBUG    step   523 loss = 12.0151
DEBUG    step   524 loss = 13.6236
DEBUG    step   525 loss = 13.0926
DEBUG    step   526 loss = 12.7921
DEBUG    step   527 loss = 12.3066
DEBUG    step   528 loss = 12.657
DEBUG    step   529 loss = 12.1989
DEBUG    step   530 loss = 12.6969
DEBUG    step   531 loss = 12.205
DEBUG    step   532 loss = 12.7905
DEBUG    step   533 loss = 12.6645
DEBUG    step   534 loss = 11.9637
DEBUG    step   535 loss = 12.3953
DEBUG    step   536 loss = 12.326
DEBUG    step   537 loss = 12.3011
DEBUG    step   538 loss = 12.3628
DEBUG    step   539 loss = 13.1567
DEBUG    step   540 loss = 12.5927
DEBUG    step   541 loss = 12.5462
DEBUG    step   542 loss = 12.2117
DEBUG    step   543 loss = 11.9447
DEBUG    step   544 loss = 12.5186
DEBUG    step   545 loss = 11.6064
DEBUG    step   546 loss = 12.1038
DEBUG    step   547 loss = 12.4013
DEBUG    step   548 loss = 12.1646
DEBUG    step   549 loss = 11.6217
DEBUG    step   550 loss = 11.7608
DEBUG    step   551 loss = 12.044
DEBUG    step   552 loss = 11.5987
DEBUG    step   553 loss = 12.2336
DEBUG    step   554 loss = 11.6134
DEBUG    step   555 loss = 12.212
DEBUG    step   556 loss = 11.7942
DEBUG    step   557 loss = 11.8134
DEBUG    step   558 loss = 11.8879
DEBUG    step   559 loss = 11.5601
DEBUG    step   560 loss = 11.8819
DEBUG    step   561 loss = 11.2771
DEBUG    step   562 loss = 12.6852
DEBUG    step   563 loss = 11.8853
DEBUG    step   564 loss = 11.8232
DEBUG    step   565 loss = 12.2208
DEBUG    step   566 loss = 11.8434
DEBUG    step   567 loss = 10.8617
DEBUG    step   568 loss = 11.9089
DEBUG    step   569 loss = 12.8768
DEBUG    step   570 loss = 11.7326
DEBUG    step   571 loss = 11.6924
DEBUG    step   572 loss = 12.071
DEBUG    step   573 loss = 11.4507
DEBUG    step   574 loss = 11.9765
DEBUG    step   575 loss = 12.3481
DEBUG    step   576 loss = 10.7076
DEBUG    step   577 loss = 11.2173
DEBUG    step   578 loss = 11.6225
DEBUG    step   579 loss = 11.7975
DEBUG    step   580 loss = 11.4295
DEBUG    step   581 loss = 11.7824
DEBUG    step   582 loss = 12.1286
DEBUG    step   583 loss = 10.932
DEBUG    step   584 loss = 11.9352
DEBUG    step   585 loss = 11.4005
DEBUG    step   586 loss = 11.1264
DEBUG    step   587 loss = 10.3828
DEBUG    step   588 loss = 10.6477
DEBUG    step   589 loss = 11.2266
DEBUG    step   590 loss = 11.7988
DEBUG    step   591 loss = 11.1602
DEBUG    step   592 loss = 11.2809
DEBUG    step   593 loss = 11.0131
DEBUG    step   594 loss = 11.3859
DEBUG    step   595 loss = 11.1015
DEBUG    step   596 loss = 11.4198
DEBUG    step   597 loss = 10.501
DEBUG    step   598 loss = 11.206
DEBUG    step   599 loss = 11.2975
DEBUG    step   600 loss = 10.0333
DEBUG    step   601 loss = 9.98137
DEBUG    step   602 loss = 12.6949
DEBUG    step   603 loss = 11.1914
DEBUG    step   604 loss = 10.2179
DEBUG    step   605 loss = 10.8835
DEBUG    step   606 loss = 10.3426
DEBUG    step   607 loss = 10.9994
DEBUG    step   608 loss = 10.4913
DEBUG    step   609 loss = 10.5934
DEBUG    step   610 loss = 11.2756
DEBUG    step   611 loss = 10.6515
DEBUG    step   612 loss = 10.634
DEBUG    step   613 loss = 10.6894
DEBUG    step   614 loss = 10.4173
DEBUG    step   615 loss = 10.3444
DEBUG    step   616 loss = 16.9274
DEBUG    step   617 loss = 10.6686
DEBUG    step   618 loss = 10.6302
DEBUG    step   619 loss = 11.4147
DEBUG    step   620 loss = 10.4305
DEBUG    step   621 loss = 9.93963
DEBUG    step   622 loss = 10.2567
DEBUG    step   623 loss = 10.4703
DEBUG    step   624 loss = 10.5793
DEBUG    step   625 loss = 10.7117
DEBUG    step   626 loss = 10.6469
DEBUG    step   627 loss = 10.6067
DEBUG    step   628 loss = 10.2047
DEBUG    step   629 loss = 10.7753
DEBUG    step   630 loss = 9.84085
DEBUG    step   631 loss = 9.8512
DEBUG    step   632 loss = 9.90551
DEBUG    step   633 loss = 10.2306
DEBUG    step   634 loss = 10.4
DEBUG    step   635 loss = 9.96456
DEBUG    step   636 loss = 10.0543
DEBUG    step   637 loss = 10.4722
DEBUG    step   638 loss = 10.2624
DEBUG    step   639 loss = 9.8927
DEBUG    step   640 loss = 9.74269
DEBUG    step   641 loss = 10.0714
DEBUG    step   642 loss = 9.4886
DEBUG    step   643 loss = 11.2356
DEBUG    step   644 loss = 10.4613
DEBUG    step   645 loss = 9.92244
DEBUG    step   646 loss = 10.5003
DEBUG    step   647 loss = 9.28321
DEBUG    step   648 loss = 10.0217
DEBUG    step   649 loss = 9.95832
DEBUG    step   650 loss = 9.89816
DEBUG    step   651 loss = 9.97542
DEBUG    step   652 loss = 9.11257
DEBUG    step   653 loss = 9.9837
DEBUG    step   654 loss = 10.1827
DEBUG    step   655 loss = 10.101
DEBUG    step   656 loss = 9.23931
DEBUG    step   657 loss = 8.75782
DEBUG    step   658 loss = 9.40421
DEBUG    step   659 loss = 9.13174
DEBUG    step   660 loss = 9.68286
DEBUG    step   661 loss = 10.4162
DEBUG    step   662 loss = 8.75674
DEBUG    step   663 loss = 10.001
DEBUG    step   664 loss = 9.40904
DEBUG    step   665 loss = 10.1505
DEBUG    step   666 loss = 10.1748
DEBUG    step   667 loss = 10.2148
DEBUG    step   668 loss = 10.2481
DEBUG    step   669 loss = 9.96609
DEBUG    step   670 loss = 9.65714
DEBUG    step   671 loss = 9.60848
DEBUG    step   672 loss = 9.84922
DEBUG    step   673 loss = 10.0371
DEBUG    step   674 loss = 9.28353
DEBUG    step   675 loss = 9.06586
DEBUG    step   676 loss = 9.44504
DEBUG    step   677 loss = 9.66529
DEBUG    step   678 loss = 9.7542
DEBUG    step   679 loss = 9.10189
DEBUG    step   680 loss = 9.36793
DEBUG    step   681 loss = 9.47525
DEBUG    step   682 loss = 9.98902
DEBUG    step   683 loss = 9.58746
DEBUG    step   684 loss = 8.77309
DEBUG    step   685 loss = 9.58264
DEBUG    step   686 loss = 9.774
DEBUG    step   687 loss = 10.1397
DEBUG    step   688 loss = 10.2031
DEBUG    step   689 loss = 8.85642
DEBUG    step   690 loss = 8.65729
DEBUG    step   691 loss = 9.30864
DEBUG    step   692 loss = 9.08819
DEBUG    step   693 loss = 8.79863
DEBUG    step   694 loss = 9.54987
DEBUG    step   695 loss = 8.96493
DEBUG    step   696 loss = 8.57488
DEBUG    step   697 loss = 9.37986
DEBUG    step   698 loss = 9.12005
DEBUG    step   699 loss = 9.55977
DEBUG    step   700 loss = 9.71548
DEBUG    step   701 loss = 8.66767
DEBUG    step   702 loss = 9.24891
DEBUG    step   703 loss = 8.96681
DEBUG    step   704 loss = 8.50462
DEBUG    step   705 loss = 8.97093
DEBUG    step   706 loss = 8.42754
DEBUG    step   707 loss = 8.31459
DEBUG    step   708 loss = 8.92468
DEBUG    step   709 loss = 8.62381
DEBUG    step   710 loss = 8.99014
DEBUG    step   711 loss = 9.12061
DEBUG    step   712 loss = 9.1673
DEBUG    step   713 loss = 8.71748
DEBUG    step   714 loss = 9.10944
DEBUG    step   715 loss = 8.2948
DEBUG    step   716 loss = 9.03157
DEBUG    step   717 loss = 8.86918
DEBUG    step   718 loss = 8.4948
DEBUG    step   719 loss = 8.20143
DEBUG    step   720 loss = 9.02752
DEBUG    step   721 loss = 9.07482
DEBUG    step   722 loss = 8.47549
DEBUG    step   723 loss = 8.6139
DEBUG    step   724 loss = 8.71389
DEBUG    step   725 loss = 8.71019
DEBUG    step   726 loss = 9.34067
DEBUG    step   727 loss = 8.33531
DEBUG    step   728 loss = 8.50657
DEBUG    step   729 loss = 7.92335
DEBUG    step   730 loss = 8.73418
DEBUG    step   731 loss = 7.50367
DEBUG    step   732 loss = 8.30074
DEBUG    step   733 loss = 8.10457
DEBUG    step   734 loss = 8.57933
DEBUG    step   735 loss = 8.29648
DEBUG    step   736 loss = 9.08495
DEBUG    step   737 loss = 9.19558
DEBUG    step   738 loss = 7.57463
DEBUG    step   739 loss = 8.25734
DEBUG    step   740 loss = 8.1562
DEBUG    step   741 loss = 8.13552
DEBUG    step   742 loss = 8.61787
DEBUG    step   743 loss = 7.84507
DEBUG    step   744 loss = 8.50339
DEBUG    step   745 loss = 9.99432
DEBUG    step   746 loss = 8.67392
DEBUG    step   747 loss = 7.62062
DEBUG    step   748 loss = 8.47083
DEBUG    step   749 loss = 7.59856
DEBUG    step   750 loss = 8.73944
DEBUG    step   751 loss = 7.82123
DEBUG    step   752 loss = 8.3673
DEBUG    step   753 loss = 8.05969
DEBUG    step   754 loss = 7.67401
DEBUG    step   755 loss = 8.23807
DEBUG    step   756 loss = 7.85361
DEBUG    step   757 loss = 8.29006
DEBUG    step   758 loss = 7.93663
DEBUG    step   759 loss = 7.14638
DEBUG    step   760 loss = 7.75548
DEBUG    step   761 loss = 7.23605
DEBUG    step   762 loss = 8.39854
DEBUG    step   763 loss = 8.36651
DEBUG    step   764 loss = 8.08217
DEBUG    step   765 loss = 8.51663
DEBUG    step   766 loss = 17.1032
DEBUG    step   767 loss = 8.11124
DEBUG    step   768 loss = 8.07747
DEBUG    step   769 loss = 7.82815
DEBUG    step   770 loss = 9.03203
DEBUG    step   771 loss = 8.53237
DEBUG    step   772 loss = 7.96279
DEBUG    step   773 loss = 8.05574
DEBUG    step   774 loss = 7.76004
DEBUG    step   775 loss = 7.35636
DEBUG    step   776 loss = 8.11715
DEBUG    step   777 loss = 8.26839
DEBUG    step   778 loss = 8.3788
DEBUG    step   779 loss = 8.4216
DEBUG    step   780 loss = 8.70143
DEBUG    step   781 loss = 7.68424
DEBUG    step   782 loss = 7.71564
DEBUG    step   783 loss = 8.99345
DEBUG    step   784 loss = 7.84072
DEBUG    step   785 loss = 7.97106
DEBUG    step   786 loss = 8.17313
DEBUG    step   787 loss = 8.43836
DEBUG    step   788 loss = 8.48604
DEBUG    step   789 loss = 7.89398
DEBUG    step   790 loss = 7.66896
DEBUG    step   791 loss = 7.93176
DEBUG    step   792 loss = 7.50743
DEBUG    step   793 loss = 7.35892
DEBUG    step   794 loss = 8.19966
DEBUG    step   795 loss = 8.04621
DEBUG    step   796 loss = 7.20783
DEBUG    step   797 loss = 7.82553
DEBUG    step   798 loss = 7.99542
DEBUG    step   799 loss = 7.39769
DEBUG    step   800 loss = 7.53701
DEBUG    step   801 loss = 7.24536
DEBUG    step   802 loss = 7.33658
DEBUG    step   803 loss = 7.342
DEBUG    step   804 loss = 7.75321
DEBUG    step   805 loss = 6.91357
DEBUG    step   806 loss = 7.52435
DEBUG    step   807 loss = 7.56103
DEBUG    step   808 loss = 7.79389
DEBUG    step   809 loss = 8.33436
DEBUG    step   810 loss = 7.46276
DEBUG    step   811 loss = 7.03648
DEBUG    step   812 loss = 7.09304
DEBUG    step   813 loss = 7.55697
DEBUG    step   814 loss = 7.74993
DEBUG    step   815 loss = 7.77072
DEBUG    step   816 loss = 7.57071
DEBUG    step   817 loss = 7.87914
DEBUG    step   818 loss = 7.59507
DEBUG    step   819 loss = 7.95819
DEBUG    step   820 loss = 7.26536
DEBUG    step   821 loss = 7.76702
DEBUG    step   822 loss = 6.81672
DEBUG    step   823 loss = 7.69591
DEBUG    step   824 loss = 7.49277
DEBUG    step   825 loss = 7.71589
DEBUG    step   826 loss = 7.54939
DEBUG    step   827 loss = 7.14454
DEBUG    step   828 loss = 6.54073
DEBUG    step   829 loss = 7.31939
DEBUG    step   830 loss = 8.24107
DEBUG    step   831 loss = 7.75897
DEBUG    step   832 loss = 7.0123
DEBUG    step   833 loss = 6.6658
DEBUG    step   834 loss = 7.17121
DEBUG    step   835 loss = 7.8772
DEBUG    step   836 loss = 6.91091
DEBUG    step   837 loss = 7.24767
DEBUG    step   838 loss = 7.3708
DEBUG    step   839 loss = 6.72671
DEBUG    step   840 loss = 6.91319
DEBUG    step   841 loss = 7.38147
DEBUG    step   842 loss = 6.73919
DEBUG    step   843 loss = 7.1541
DEBUG    step   844 loss = 7.09714
DEBUG    step   845 loss = 7.6505
DEBUG    step   846 loss = 6.37122
DEBUG    step   847 loss = 7.15714
DEBUG    step   848 loss = 6.78871
DEBUG    step   849 loss = 6.43234
DEBUG    step   850 loss = 6.64114
DEBUG    step   851 loss = 6.98987
DEBUG    step   852 loss = 7.51277
DEBUG    step   853 loss = 7.34095
DEBUG    step   854 loss = 7.5216
DEBUG    step   855 loss = 6.37953
DEBUG    step   856 loss = 7.08232
DEBUG    step   857 loss = 6.96187
DEBUG    step   858 loss = 6.12791
DEBUG    step   859 loss = 6.71254
DEBUG    step   860 loss = 6.15329
DEBUG    step   861 loss = 6.74574
DEBUG    step   862 loss = 7.24058
DEBUG    step   863 loss = 6.16476
DEBUG    step   864 loss = 7.61778
DEBUG    step   865 loss = 6.35608
DEBUG    step   866 loss = 6.53307
DEBUG    step   867 loss = 6.36949
DEBUG    step   868 loss = 6.71838
DEBUG    step   869 loss = 7.3967
DEBUG    step   870 loss = 6.65597
DEBUG    step   871 loss = 6.77125
DEBUG    step   872 loss = 6.67395
DEBUG    step   873 loss = 6.40736
DEBUG    step   874 loss = 6.35543
DEBUG    step   875 loss = 6.74703
DEBUG    step   876 loss = 6.58434
DEBUG    step   877 loss = 6.62172
DEBUG    step   878 loss = 6.65244
DEBUG    step   879 loss = 6.97937
DEBUG    step   880 loss = 6.42221
DEBUG    step   881 loss = 6.84026
DEBUG    step   882 loss = 6.72631
DEBUG    step   883 loss = 6.90398
DEBUG    step   884 loss = 6.6266
DEBUG    step   885 loss = 6.51678
DEBUG    step   886 loss = 6.65169
DEBUG    step   887 loss = 6.63095
DEBUG    step   888 loss = 6.24306
DEBUG    step   889 loss = 7.46224
DEBUG    step   890 loss = 6.84275
DEBUG    step   891 loss = 6.19764
DEBUG    step   892 loss = 7.16809
DEBUG    step   893 loss = 6.57301
DEBUG    step   894 loss = 6.72905
DEBUG    step   895 loss = 7.3967
DEBUG    step   896 loss = 6.78504
DEBUG    step   897 loss = 6.52102
DEBUG    step   898 loss = 6.07938
DEBUG    step   899 loss = 5.95618
DEBUG    step   900 loss = 6.14126
DEBUG    step   901 loss = 5.67246
DEBUG    step   902 loss = 5.59678
DEBUG    step   903 loss = 6.5394
DEBUG    step   904 loss = 6.4651
DEBUG    step   905 loss = 6.64771
DEBUG    step   906 loss = 6.44477
DEBUG    step   907 loss = 5.17112
DEBUG    step   908 loss = 5.80493
DEBUG    step   909 loss = 6.36914
DEBUG    step   910 loss = 6.68615
DEBUG    step   911 loss = 5.53628
DEBUG    step   912 loss = 6.51742
DEBUG    step   913 loss = 6.95286
DEBUG    step   914 loss = 7.2883
DEBUG    step   915 loss = 6.09494
DEBUG    step   916 loss = 6.74383
DEBUG    step   917 loss = 6.3917
DEBUG    step   918 loss = 6.25799
DEBUG    step   919 loss = 6.55483
DEBUG    step   920 loss = 6.44743
DEBUG    step   921 loss = 5.77905
DEBUG    step   922 loss = 5.98885
DEBUG    step   923 loss = 5.83527
DEBUG    step   924 loss = 5.93447
DEBUG    step   925 loss = 5.9199
DEBUG    step   926 loss = 6.01515
DEBUG    step   927 loss = 6.14634
DEBUG    step   928 loss = 5.77208
DEBUG    step   929 loss = 6.78369
DEBUG    step   930 loss = 6.21236
DEBUG    step   931 loss = 5.98394
DEBUG    step   932 loss = 6.51115
DEBUG    step   933 loss = 6.44652
DEBUG    step   934 loss = 5.83554
DEBUG    step   935 loss = 6.30905
DEBUG    step   936 loss = 5.93238
DEBUG    step   937 loss = 6.50758
DEBUG    step   938 loss = 5.93256
DEBUG    step   939 loss = 6.06647
DEBUG    step   940 loss = 6.03391
DEBUG    step   941 loss = 5.51953
DEBUG    step   942 loss = 6.03728
DEBUG    step   943 loss = 6.18949
DEBUG    step   944 loss = 6.10855
DEBUG    step   945 loss = 5.92263
DEBUG    step   946 loss = 6.72183
DEBUG    step   947 loss = 6.11911
DEBUG    step   948 loss = 5.84314
DEBUG    step   949 loss = 6.02928
DEBUG    step   950 loss = 5.82459
DEBUG    step   951 loss = 5.98588
DEBUG    step   952 loss = 5.75092
DEBUG    step   953 loss = 6.19303
DEBUG    step   954 loss = 5.78729
DEBUG    step   955 loss = 5.9059
DEBUG    step   956 loss = 5.31694
DEBUG    step   957 loss = 5.71936
DEBUG    step   958 loss = 6.06149
DEBUG    step   959 loss = 4.93583
DEBUG    step   960 loss = 5.8746
DEBUG    step   961 loss = 5.81154
DEBUG    step   962 loss = 6.22302
DEBUG    step   963 loss = 4.62915
DEBUG    step   964 loss = 6.26837
DEBUG    step   965 loss = 6.9227
DEBUG    step   966 loss = 5.69589
DEBUG    step   967 loss = 4.89925
DEBUG    step   968 loss = 5.95339
DEBUG    step   969 loss = 5.41167
DEBUG    step   970 loss = 5.61495
DEBUG    step   971 loss = 6.08719
DEBUG    step   972 loss = 5.70671
DEBUG    step   973 loss = 6.29176
DEBUG    step   974 loss = 5.96967
DEBUG    step   975 loss = 5.64207
DEBUG    step   976 loss = 6.11389
DEBUG    step   977 loss = 5.4677
DEBUG    step   978 loss = 5.26326
DEBUG    step   979 loss = 5.63665
DEBUG    step   980 loss = 5.47218
DEBUG    step   981 loss = 5.76207
DEBUG    step   982 loss = 5.25431
DEBUG    step   983 loss = 5.11318
DEBUG    step   984 loss = 5.23281
DEBUG    step   985 loss = 4.9322
DEBUG    step   986 loss = 5.19766
DEBUG    step   987 loss = 5.32089
DEBUG    step   988 loss = 5.56581
DEBUG    step   989 loss = 5.68178
DEBUG    step   990 loss = 4.37302
DEBUG    step   991 loss = 5.50948
DEBUG    step   992 loss = 5.3806
DEBUG    step   993 loss = 6.08309
DEBUG    step   994 loss = 5.74113
DEBUG    step   995 loss = 5.29156
DEBUG    step   996 loss = 6.09862
DEBUG    step   997 loss = 4.34491
DEBUG    step   998 loss = 4.74828
DEBUG    step   999 loss = 5.1352
DEBUG    step  1000 loss = 5.90098
DEBUG    step  1001 loss = 5.65187
DEBUG    step  1002 loss = 4.99241
DEBUG    step  1003 loss = 4.93651
DEBUG    step  1004 loss = 5.71697
DEBUG    step  1005 loss = 5.12284
DEBUG    step  1006 loss = 6.20878
DEBUG    step  1007 loss = 5.12986
DEBUG    step  1008 loss = 4.9672
DEBUG    step  1009 loss = 5.65217
DEBUG    step  1010 loss = 5.48825
DEBUG    step  1011 loss = 5.54487
DEBUG    step  1012 loss = 5.84657
DEBUG    step  1013 loss = 5.74514
DEBUG    step  1014 loss = 5.23785
DEBUG    step  1015 loss = 4.71362
DEBUG    step  1016 loss = 4.36813
DEBUG    step  1017 loss = 5.45256
DEBUG    step  1018 loss = 5.15537
DEBUG    step  1019 loss = 5.42831
DEBUG    step  1020 loss = 5.17
DEBUG    step  1021 loss = 4.94556
DEBUG    step  1022 loss = 5.84439
DEBUG    step  1023 loss = 5.11129
DEBUG    step  1024 loss = 4.68024
DEBUG    step  1025 loss = 4.6169
DEBUG    step  1026 loss = 4.95606
DEBUG    step  1027 loss = 4.74444
DEBUG    step  1028 loss = 4.27131
DEBUG    step  1029 loss = 4.88013
DEBUG    step  1030 loss = 4.77623
DEBUG    step  1031 loss = 5.86898
DEBUG    step  1032 loss = 5.16058
DEBUG    step  1033 loss = 4.97931
DEBUG    step  1034 loss = 5.05067
DEBUG    step  1035 loss = 5.13984
DEBUG    step  1036 loss = 5.39295
DEBUG    step  1037 loss = 4.95942
DEBUG    step  1038 loss = 5.33035
DEBUG    step  1039 loss = 4.99434
DEBUG    step  1040 loss = 4.98677
DEBUG    step  1041 loss = 4.65488
DEBUG    step  1042 loss = 4.61823
DEBUG    step  1043 loss = 4.68538
DEBUG    step  1044 loss = 4.55243
DEBUG    step  1045 loss = 4.72619
DEBUG    step  1046 loss = 4.88855
DEBUG    step  1047 loss = 4.91348
DEBUG    step  1048 loss = 4.14682
DEBUG    step  1049 loss = 5.40462
DEBUG    step  1050 loss = 4.9091
DEBUG    step  1051 loss = 4.81781
DEBUG    step  1052 loss = 4.87586
DEBUG    step  1053 loss = 5.02846
DEBUG    step  1054 loss = 5.07139
DEBUG    step  1055 loss = 4.59791
DEBUG    step  1056 loss = 4.63243
DEBUG    step  1057 loss = 5.06353
DEBUG    step  1058 loss = 3.85668
DEBUG    step  1059 loss = 5.28508
DEBUG    step  1060 loss = 5.2355
DEBUG    step  1061 loss = 4.07526
DEBUG    step  1062 loss = 4.13481
DEBUG    step  1063 loss = 5.15536
DEBUG    step  1064 loss = 4.30691
DEBUG    step  1065 loss = 4.27459
DEBUG    step  1066 loss = 4.41401
DEBUG    step  1067 loss = 4.55242
DEBUG    step  1068 loss = 5.11923
DEBUG    step  1069 loss = 4.62136
DEBUG    step  1070 loss = 4.88281
DEBUG    step  1071 loss = 6.58954
DEBUG    step  1072 loss = 4.35964
DEBUG    step  1073 loss = 4.70629
DEBUG    step  1074 loss = 4.33995
DEBUG    step  1075 loss = 4.68683
DEBUG    step  1076 loss = 4.2739
DEBUG    step  1077 loss = 3.67668
DEBUG    step  1078 loss = 4.68557
DEBUG    step  1079 loss = 4.38688
DEBUG    step  1080 loss = 4.37331
DEBUG    step  1081 loss = 4.81933
DEBUG    step  1082 loss = 4.4695
DEBUG    step  1083 loss = 4.97354
DEBUG    step  1084 loss = 4.51781
DEBUG    step  1085 loss = 4.12469
DEBUG    step  1086 loss = 6.42285
DEBUG    step  1087 loss = 5.01891
DEBUG    step  1088 loss = 4.62022
DEBUG    step  1089 loss = 4.87794
DEBUG    step  1090 loss = 4.91586
DEBUG    step  1091 loss = 4.10107
DEBUG    step  1092 loss = 4.64939
DEBUG    step  1093 loss = 5.02957
DEBUG    step  1094 loss = 4.41712
DEBUG    step  1095 loss = 4.42776
DEBUG    step  1096 loss = 4.28038
DEBUG    step  1097 loss = 4.93038
DEBUG    step  1098 loss = 4.39647
DEBUG    step  1099 loss = 4.14815
DEBUG    step  1100 loss = 4.47418
DEBUG    step  1101 loss = 4.53913
DEBUG    step  1102 loss = 4.18599
DEBUG    step  1103 loss = 4.42585
DEBUG    step  1104 loss = 4.52254
DEBUG    step  1105 loss = 3.73001
DEBUG    step  1106 loss = 3.80091
DEBUG    step  1107 loss = 4.65234
DEBUG    step  1108 loss = 4.22851
DEBUG    step  1109 loss = 3.80812
DEBUG    step  1110 loss = 4.85446
DEBUG    step  1111 loss = 3.86523
DEBUG    step  1112 loss = 4.18319
DEBUG    step  1113 loss = 4.21953
DEBUG    step  1114 loss = 5.04039
DEBUG    step  1115 loss = 4.80243
DEBUG    step  1116 loss = 4.30441
DEBUG    step  1117 loss = 5.39042
DEBUG    step  1118 loss = 4.25597
DEBUG    step  1119 loss = 5.07854
DEBUG    step  1120 loss = 4.12041
DEBUG    step  1121 loss = 3.47527
DEBUG    step  1122 loss = 4.13058
DEBUG    step  1123 loss = 3.55016
DEBUG    step  1124 loss = 4.84087
DEBUG    step  1125 loss = 4.22556
DEBUG    step  1126 loss = 4.61652
DEBUG    step  1127 loss = 4.38913
DEBUG    step  1128 loss = 4.1752
DEBUG    step  1129 loss = 4.35237
DEBUG    step  1130 loss = 4.11809
DEBUG    step  1131 loss = 4.52757
DEBUG    step  1132 loss = 3.64453
DEBUG    step  1133 loss = 3.92684
DEBUG    step  1134 loss = 4.419
DEBUG    step  1135 loss = 4.53101
DEBUG    step  1136 loss = 4.20247
DEBUG    step  1137 loss = 4.4274
DEBUG    step  1138 loss = 4.00318
DEBUG    step  1139 loss = 6.42864
DEBUG    step  1140 loss = 4.00687
DEBUG    step  1141 loss = 4.74919
DEBUG    step  1142 loss = 3.83376
DEBUG    step  1143 loss = 4.00634
DEBUG    step  1144 loss = 3.43185
DEBUG    step  1145 loss = 3.91977
DEBUG    step  1146 loss = 3.8136
DEBUG    step  1147 loss = 4.02812
DEBUG    step  1148 loss = 4.1181
DEBUG    step  1149 loss = 3.40067
DEBUG    step  1150 loss = 3.87853
DEBUG    step  1151 loss = 4.30686
DEBUG    step  1152 loss = 4.22774
DEBUG    step  1153 loss = 4.38618
DEBUG    step  1154 loss = 4.56262
DEBUG    step  1155 loss = 4.45982
DEBUG    step  1156 loss = 4.59891
DEBUG    step  1157 loss = 4.44961
DEBUG    step  1158 loss = 4.0087
DEBUG    step  1159 loss = 4.88411
DEBUG    step  1160 loss = 3.81384
DEBUG    step  1161 loss = 3.60741
DEBUG    step  1162 loss = 4.1445
DEBUG    step  1163 loss = 4.40349
DEBUG    step  1164 loss = 3.83159
DEBUG    step  1165 loss = 3.76538
DEBUG    step  1166 loss = 4.21465
DEBUG    step  1167 loss = 3.94987
DEBUG    step  1168 loss = 4.0818
DEBUG    step  1169 loss = 4.06183
DEBUG    step  1170 loss = 3.47987
DEBUG    step  1171 loss = 3.67692
DEBUG    step  1172 loss = 4.20745
DEBUG    step  1173 loss = 3.84148
DEBUG    step  1174 loss = 3.49437
DEBUG    step  1175 loss = 3.67877
DEBUG    step  1176 loss = 3.95581
DEBUG    step  1177 loss = 4.26368
DEBUG    step  1178 loss = 3.89446
DEBUG    step  1179 loss = 3.66383
DEBUG    step  1180 loss = 4.65264
DEBUG    step  1181 loss = 3.91674
DEBUG    step  1182 loss = 3.80197
DEBUG    step  1183 loss = 3.24795
DEBUG    step  1184 loss = 4.25066
DEBUG    step  1185 loss = 3.59737
DEBUG    step  1186 loss = 4.23543
DEBUG    step  1187 loss = 4.40551
DEBUG    step  1188 loss = 3.06393
DEBUG    step  1189 loss = 3.78871
DEBUG    step  1190 loss = 4.47356
DEBUG    step  1191 loss = 3.01607
DEBUG    step  1192 loss = 3.5921
DEBUG    step  1193 loss = 4.14678
DEBUG    step  1194 loss = 4.06156
DEBUG    step  1195 loss = 3.63912
DEBUG    step  1196 loss = 3.80904
DEBUG    step  1197 loss = 3.94498
DEBUG    step  1198 loss = 4.46766
DEBUG    step  1199 loss = 3.94135
DEBUG    step  1200 loss = 3.16809
DEBUG    step  1201 loss = 4.44084
DEBUG    step  1202 loss = 4.10566
DEBUG    step  1203 loss = 3.80488
DEBUG    step  1204 loss = 3.19777
DEBUG    step  1205 loss = 2.95526
DEBUG    step  1206 loss = 4.49641
DEBUG    step  1207 loss = 4.23787
DEBUG    step  1208 loss = 3.70975
DEBUG    step  1209 loss = 3.79127
DEBUG    step  1210 loss = 3.59221
DEBUG    step  1211 loss = 3.88194
DEBUG    step  1212 loss = 3.40576
DEBUG    step  1213 loss = 3.87329
DEBUG    step  1214 loss = 3.49796
DEBUG    step  1215 loss = 3.24266
DEBUG    step  1216 loss = 3.73337
DEBUG    step  1217 loss = 3.64298
DEBUG    step  1218 loss = 3.20159
DEBUG    step  1219 loss = 2.85318
DEBUG    step  1220 loss = 3.73986
DEBUG    step  1221 loss = 3.01543
DEBUG    step  1222 loss = 3.32277
DEBUG    step  1223 loss = 2.74171
DEBUG    step  1224 loss = 3.70805
DEBUG    step  1225 loss = 3.61112
DEBUG    step  1226 loss = 2.88479
DEBUG    step  1227 loss = 3.65801
DEBUG    step  1228 loss = 4.02943
DEBUG    step  1229 loss = 2.83562
DEBUG    step  1230 loss = 3.24228
DEBUG    step  1231 loss = 3.2782
DEBUG    step  1232 loss = 3.59486
DEBUG    step  1233 loss = 3.65803
DEBUG    step  1234 loss = 2.6809
DEBUG    step  1235 loss = 3.3619
DEBUG    step  1236 loss = 3.39297
DEBUG    step  1237 loss = 3.81023
DEBUG    step  1238 loss = 3.22556
DEBUG    step  1239 loss = 3.19648
DEBUG    step  1240 loss = 4.0888
DEBUG    step  1241 loss = 3.74848
DEBUG    step  1242 loss = 2.87371
DEBUG    step  1243 loss = 2.63874
DEBUG    step  1244 loss = 3.5867
DEBUG    step  1245 loss = 2.79683
DEBUG    step  1246 loss = 2.68036
DEBUG    step  1247 loss = 3.90314
DEBUG    step  1248 loss = 2.79271
DEBUG    step  1249 loss = 3.35704
DEBUG    step  1250 loss = 3.22364
DEBUG    step  1251 loss = 4.49007
DEBUG    step  1252 loss = 3.48859
DEBUG    step  1253 loss = 3.53123
DEBUG    step  1254 loss = 3.95726
DEBUG    step  1255 loss = 3.76191
DEBUG    step  1256 loss = 3.16396
DEBUG    step  1257 loss = 3.27892
DEBUG    step  1258 loss = 3.61666
DEBUG    step  1259 loss = 2.60104
DEBUG    step  1260 loss = 3.61282
DEBUG    step  1261 loss = 3.39698
DEBUG    step  1262 loss = 3.25254
DEBUG    step  1263 loss = 3.60338
DEBUG    step  1264 loss = 3.24701
DEBUG    step  1265 loss = 2.68532
DEBUG    step  1266 loss = 3.48767
DEBUG    step  1267 loss = 3.38295
DEBUG    step  1268 loss = 3.05102
DEBUG    step  1269 loss = 2.66065
DEBUG    step  1270 loss = 4.91023
DEBUG    step  1271 loss = 3.58709
DEBUG    step  1272 loss = 2.62444
DEBUG    step  1273 loss = 3.1492
DEBUG    step  1274 loss = 2.40123
DEBUG    step  1275 loss = 3.45261
DEBUG    step  1276 loss = 3.09002
DEBUG    step  1277 loss = 3.43325
DEBUG    step  1278 loss = 3.65285
DEBUG    step  1279 loss = 5.20928
DEBUG    step  1280 loss = 3.18166
DEBUG    step  1281 loss = 2.98796
DEBUG    step  1282 loss = 3.51501
DEBUG    step  1283 loss = 3.69819
DEBUG    step  1284 loss = 2.9171
DEBUG    step  1285 loss = 3.58279
DEBUG    step  1286 loss = 3.22799
DEBUG    step  1287 loss = 2.95054
DEBUG    step  1288 loss = 2.73463
DEBUG    step  1289 loss = 2.94937
DEBUG    step  1290 loss = 3.66875
DEBUG    step  1291 loss = 5.37338
DEBUG    step  1292 loss = 3.4862
DEBUG    step  1293 loss = 3.53109
DEBUG    step  1294 loss = 3.13318
DEBUG    step  1295 loss = 3.44508
DEBUG    step  1296 loss = 3.03238
DEBUG    step  1297 loss = 3.20079
DEBUG    step  1298 loss = 2.97329
DEBUG    step  1299 loss = 2.847
DEBUG    step  1300 loss = 2.9055
DEBUG    step  1301 loss = 2.11617
DEBUG    step  1302 loss = 3.67571
DEBUG    step  1303 loss = 3.05302
DEBUG    step  1304 loss = 2.67335
DEBUG    step  1305 loss = 3.19011
DEBUG    step  1306 loss = 2.28169
DEBUG    step  1307 loss = 3.15299
DEBUG    step  1308 loss = 2.48567
DEBUG    step  1309 loss = 3.02921
DEBUG    step  1310 loss = 2.74102
DEBUG    step  1311 loss = 2.92383
DEBUG    step  1312 loss = 3.50952
DEBUG    step  1313 loss = 3.4817
DEBUG    step  1314 loss = 2.90958
DEBUG    step  1315 loss = 3.17264
DEBUG    step  1316 loss = 3.00095
DEBUG    step  1317 loss = 3.28235
DEBUG    step  1318 loss = 3.1123
DEBUG    step  1319 loss = 3.19697
DEBUG    step  1320 loss = 3.23534
DEBUG    step  1321 loss = 2.62485
DEBUG    step  1322 loss = 2.39473
DEBUG    step  1323 loss = 2.65671
DEBUG    step  1324 loss = 2.6517
DEBUG    step  1325 loss = 2.83837
DEBUG    step  1326 loss = 2.96297
DEBUG    step  1327 loss = 3.27864
DEBUG    step  1328 loss = 2.8699
DEBUG    step  1329 loss = 2.41302
DEBUG    step  1330 loss = 2.75787
DEBUG    step  1331 loss = 2.02633
DEBUG    step  1332 loss = 2.64443
DEBUG    step  1333 loss = 3.00131
DEBUG    step  1334 loss = 2.90105
DEBUG    step  1335 loss = 2.53407
DEBUG    step  1336 loss = 2.69649
DEBUG    step  1337 loss = 3.10092
DEBUG    step  1338 loss = 2.40056
DEBUG    step  1339 loss = 2.89754
DEBUG    step  1340 loss = 3.58338
DEBUG    step  1341 loss = 2.91623
DEBUG    step  1342 loss = 3.01027
DEBUG    step  1343 loss = 2.88131
DEBUG    step  1344 loss = 2.61064
DEBUG    step  1345 loss = 3.21264
DEBUG    step  1346 loss = 3.68778
DEBUG    step  1347 loss = 3.20522
DEBUG    step  1348 loss = 3.02826
DEBUG    step  1349 loss = 2.26471
DEBUG    step  1350 loss = 1.86408
DEBUG    step  1351 loss = 2.38076
DEBUG    step  1352 loss = 3.04889
DEBUG    step  1353 loss = 2.88127
DEBUG    step  1354 loss = 2.29979
DEBUG    step  1355 loss = 2.32288
DEBUG    step  1356 loss = 2.58144
DEBUG    step  1357 loss = 3.13952
DEBUG    step  1358 loss = 2.64957
DEBUG    step  1359 loss = 2.66308
DEBUG    step  1360 loss = 2.4935
DEBUG    step  1361 loss = 2.44679
DEBUG    step  1362 loss = 2.35046
DEBUG    step  1363 loss = 2.68055
DEBUG    step  1364 loss = 2.70021
DEBUG    step  1365 loss = 2.92847
DEBUG    step  1366 loss = 2.65287
DEBUG    step  1367 loss = 3.36018
DEBUG    step  1368 loss = 3.14083
DEBUG    step  1369 loss = 3.2839
DEBUG    step  1370 loss = 2.87706
DEBUG    step  1371 loss = 2.28323
DEBUG    step  1372 loss = 2.71482
DEBUG    step  1373 loss = 3.14818
DEBUG    step  1374 loss = 1.91019
DEBUG    step  1375 loss = 3.26189
DEBUG    step  1376 loss = 2.32266
DEBUG    step  1377 loss = 2.58565
DEBUG    step  1378 loss = 2.78616
DEBUG    step  1379 loss = 2.61887
DEBUG    step  1380 loss = 1.77536
DEBUG    step  1381 loss = 2.46593
DEBUG    step  1382 loss = 2.03291
DEBUG    step  1383 loss = 2.25107
DEBUG    step  1384 loss = 2.02538
DEBUG    step  1385 loss = 2.64462
DEBUG    step  1386 loss = 2.52711
DEBUG    step  1387 loss = 2.82251
DEBUG    step  1388 loss = 1.84549
DEBUG    step  1389 loss = 2.80308
DEBUG    step  1390 loss = 2.50824
DEBUG    step  1391 loss = 2.32621
DEBUG    step  1392 loss = 2.47522
DEBUG    step  1393 loss = 2.25115
DEBUG    step  1394 loss = 2.13335
DEBUG    step  1395 loss = 2.34713
DEBUG    step  1396 loss = 2.70859
DEBUG    step  1397 loss = 2.40365
DEBUG    step  1398 loss = 1.77973
DEBUG    step  1399 loss = 2.20398
DEBUG    step  1400 loss = 2.03752
DEBUG    step  1401 loss = 2.92017
DEBUG    step  1402 loss = 2.30887
DEBUG    step  1403 loss = 2.55533
DEBUG    step  1404 loss = 3.27081
DEBUG    step  1405 loss = 2.00323
DEBUG    step  1406 loss = 2.58616
DEBUG    step  1407 loss = 2.32837
DEBUG    step  1408 loss = 2.62355
DEBUG    step  1409 loss = 2.55319
DEBUG    step  1410 loss = 2.91456
DEBUG    step  1411 loss = 2.51186
DEBUG    step  1412 loss = 2.58023
DEBUG    step  1413 loss = 2.11317
DEBUG    step  1414 loss = 2.72763
DEBUG    step  1415 loss = 2.46438
DEBUG    step  1416 loss = 2.66077
DEBUG    step  1417 loss = 3.45261
DEBUG    step  1418 loss = 1.30968
DEBUG    step  1419 loss = 2.02033
DEBUG    step  1420 loss = 1.66572
DEBUG    step  1421 loss = 2.63344
DEBUG    step  1422 loss = 2.79048
DEBUG    step  1423 loss = 2.36907
DEBUG    step  1424 loss = 2.09989
DEBUG    step  1425 loss = 1.90149
DEBUG    step  1426 loss = 1.62709
DEBUG    step  1427 loss = 1.95195
DEBUG    step  1428 loss = 1.51384
DEBUG    step  1429 loss = 2.89507
DEBUG    step  1430 loss = 2.15085
DEBUG    step  1431 loss = 3.11155
DEBUG    step  1432 loss = 2.44331
DEBUG    step  1433 loss = 2.20407
DEBUG    step  1434 loss = 2.08581
DEBUG    step  1435 loss = 2.42461
DEBUG    step  1436 loss = 1.99394
DEBUG    step  1437 loss = 2.04695
DEBUG    step  1438 loss = 2.82294
DEBUG    step  1439 loss = 2.33058
DEBUG    step  1440 loss = 2.10667
DEBUG    step  1441 loss = 2.3715
DEBUG    step  1442 loss = 2.13589
DEBUG    step  1443 loss = 2.0997
DEBUG    step  1444 loss = 2.40378
DEBUG    step  1445 loss = 2.69322
DEBUG    step  1446 loss = 2.3217
DEBUG    step  1447 loss = 3.06968
DEBUG    step  1448 loss = 2.19487
DEBUG    step  1449 loss = 2.62741
DEBUG    step  1450 loss = 1.93388
DEBUG    step  1451 loss = 2.23005
DEBUG    step  1452 loss = 2.05846
DEBUG    step  1453 loss = 2.37242
DEBUG    step  1454 loss = 1.70136
DEBUG    step  1455 loss = 2.47376
DEBUG    step  1456 loss = 2.62243
DEBUG    step  1457 loss = 2.22
DEBUG    step  1458 loss = 2.60625
DEBUG    step  1459 loss = 1.61209
DEBUG    step  1460 loss = 2.40373
DEBUG    step  1461 loss = 3.32855
DEBUG    step  1462 loss = 2.61678
DEBUG    step  1463 loss = 3.63504
DEBUG    step  1464 loss = 2.30637
DEBUG    step  1465 loss = 2.62554
DEBUG    step  1466 loss = 2.52577
DEBUG    step  1467 loss = 2.04929
DEBUG    step  1468 loss = 2.80166
DEBUG    step  1469 loss = 2.27281
DEBUG    step  1470 loss = 2.53645
DEBUG    step  1471 loss = 2.23338
DEBUG    step  1472 loss = 2.09672
DEBUG    step  1473 loss = 2.42459
DEBUG    step  1474 loss = 2.39755
DEBUG    step  1475 loss = 2.70626
DEBUG    step  1476 loss = 2.14803
DEBUG    step  1477 loss = 2.12395
DEBUG    step  1478 loss = 2.0754
DEBUG    step  1479 loss = 2.52702
DEBUG    step  1480 loss = 2.14769
DEBUG    step  1481 loss = 1.52042
DEBUG    step  1482 loss = 2.93158
DEBUG    step  1483 loss = 2.05924
DEBUG    step  1484 loss = 2.20132
DEBUG    step  1485 loss = 2.50342
DEBUG    step  1486 loss = 2.16502
DEBUG    step  1487 loss = 2.30084
DEBUG    step  1488 loss = 1.63317
DEBUG    step  1489 loss = 1.89554
DEBUG    step  1490 loss = 1.68024
DEBUG    step  1491 loss = 1.84459
DEBUG    step  1492 loss = 1.63598
DEBUG    step  1493 loss = 1.38678
DEBUG    step  1494 loss = 1.71994
DEBUG    step  1495 loss = 1.81303
DEBUG    step  1496 loss = 2.59038
DEBUG    step  1497 loss = 1.6169
DEBUG    step  1498 loss = 1.90588
DEBUG    step  1499 loss = 2.14643
DEBUG    step  1500 loss = 2.01967
DEBUG    step  1501 loss = 1.91788
DEBUG    step  1502 loss = 1.75204
DEBUG    step  1503 loss = 2.31053
DEBUG    step  1504 loss = 2.12471
DEBUG    step  1505 loss = 2.22645
DEBUG    step  1506 loss = 2.04981
DEBUG    step  1507 loss = 1.88154
DEBUG    step  1508 loss = 1.58932
DEBUG    step  1509 loss = 1.74206
DEBUG    step  1510 loss = 2.37344
DEBUG    step  1511 loss = 1.17495
DEBUG    step  1512 loss = 1.82669
DEBUG    step  1513 loss = 1.3465
DEBUG    step  1514 loss = 1.10967
DEBUG    step  1515 loss = 1.68837
DEBUG    step  1516 loss = 2.49356
DEBUG    step  1517 loss = 1.35455
DEBUG    step  1518 loss = 1.27578
DEBUG    step  1519 loss = 1.65972
DEBUG    step  1520 loss = 1.66863
DEBUG    step  1521 loss = 1.89212
DEBUG    step  1522 loss = 1.54516
DEBUG    step  1523 loss = 1.393
DEBUG    step  1524 loss = 1.88502
DEBUG    step  1525 loss = 2.90167
DEBUG    step  1526 loss = 1.52293
DEBUG    step  1527 loss = 1.99959
DEBUG    step  1528 loss = 1.23991
DEBUG    step  1529 loss = 2.5743
DEBUG    step  1530 loss = 1.36191
DEBUG    step  1531 loss = 1.72816
DEBUG    step  1532 loss = 1.58642
DEBUG    step  1533 loss = 1.48767
DEBUG    step  1534 loss = 1.89661
DEBUG    step  1535 loss = 2.36828
DEBUG    step  1536 loss = 1.07969
DEBUG    step  1537 loss = 1.76135
DEBUG    step  1538 loss = 1.71266
DEBUG    step  1539 loss = 1.89935
DEBUG    step  1540 loss = 1.46401
DEBUG    step  1541 loss = 0.630489
DEBUG    step  1542 loss = 1.97178
DEBUG    step  1543 loss = 1.54882
DEBUG    step  1544 loss = 1.59709
DEBUG    step  1545 loss = 1.05165
DEBUG    step  1546 loss = 1.80869
DEBUG    step  1547 loss = 2.13186
DEBUG    step  1548 loss = 2.48523
DEBUG    step  1549 loss = 1.36797
DEBUG    step  1550 loss = 2.11571
DEBUG    step  1551 loss = 1.90579
DEBUG    step  1552 loss = 1.53151
DEBUG    step  1553 loss = 1.99713
DEBUG    step  1554 loss = 2.22942
DEBUG    step  1555 loss = 2.03508
DEBUG    step  1556 loss = 1.91097
DEBUG    step  1557 loss = 1.64553
DEBUG    step  1558 loss = 2.31868
DEBUG    step  1559 loss = 1.88206
DEBUG    step  1560 loss = 1.84929
DEBUG    step  1561 loss = 1.74253
DEBUG    step  1562 loss = 1.55262
DEBUG    step  1563 loss = 1.24187
DEBUG    step  1564 loss = 2.21666
DEBUG    step  1565 loss = 1.54179
DEBUG    step  1566 loss = 1.18126
DEBUG    step  1567 loss = 1.60436
DEBUG    step  1568 loss = 1.62646
DEBUG    step  1569 loss = 1.13235
DEBUG    step  1570 loss = 1.73874
DEBUG    step  1571 loss = 2.98272
DEBUG    step  1572 loss = 1.97496
DEBUG    step  1573 loss = 1.40697
DEBUG    step  1574 loss = 1.75862
DEBUG    step  1575 loss = 2.24646
DEBUG    step  1576 loss = 1.71452
DEBUG    step  1577 loss = 2.13269
DEBUG    step  1578 loss = 1.87098
DEBUG    step  1579 loss = 0.903461
DEBUG    step  1580 loss = 1.25201
DEBUG    step  1581 loss = 1.8638
DEBUG    step  1582 loss = 1.8996
DEBUG    step  1583 loss = 1.43805
DEBUG    step  1584 loss = 1.15156
DEBUG    step  1585 loss = 1.41428
DEBUG    step  1586 loss = 1.13043
DEBUG    step  1587 loss = 0.838783
DEBUG    step  1588 loss = 0.782387
DEBUG    step  1589 loss = 1.6801
DEBUG    step  1590 loss = 2.16813
DEBUG    step  1591 loss = 2.3584
DEBUG    step  1592 loss = 2.03198
DEBUG    step  1593 loss = 1.6852
DEBUG    step  1594 loss = 1.6894
DEBUG    step  1595 loss = 2.05611
DEBUG    step  1596 loss = 2.04665
DEBUG    step  1597 loss = 1.44473
DEBUG    step  1598 loss = 2.35641
DEBUG    step  1599 loss = 1.77884
DEBUG    step  1600 loss = 1.29297
DEBUG    step  1601 loss = 1.44123
DEBUG    step  1602 loss = 1.03164
DEBUG    step  1603 loss = 1.97062
DEBUG    step  1604 loss = 1.84778
DEBUG    step  1605 loss = 1.97628
DEBUG    step  1606 loss = 1.80254
DEBUG    step  1607 loss = 1.53044
DEBUG    step  1608 loss = 1.69098
DEBUG    step  1609 loss = 1.92866
DEBUG    step  1610 loss = 1.70258
DEBUG    step  1611 loss = 1.76521
DEBUG    step  1612 loss = 1.52449
DEBUG    step  1613 loss = 1.15307
DEBUG    step  1614 loss = 1.88707
DEBUG    step  1615 loss = 1.61141
DEBUG    step  1616 loss = 1.23801
DEBUG    step  1617 loss = 1.51574
DEBUG    step  1618 loss = 1.26473
DEBUG    step  1619 loss = 1.24652
DEBUG    step  1620 loss = 1.06793
DEBUG    step  1621 loss = 1.89787
DEBUG    step  1622 loss = 1.49286
DEBUG    step  1623 loss = 0.830939
DEBUG    step  1624 loss = 1.66349
DEBUG    step  1625 loss = 1.17004
DEBUG    step  1626 loss = 1.24293
DEBUG    step  1627 loss = 1.90752
DEBUG    step  1628 loss = 2.46158
DEBUG    step  1629 loss = 1.45676
DEBUG    step  1630 loss = 1.70154
DEBUG    step  1631 loss = 1.18527
DEBUG    step  1632 loss = 1.32646
DEBUG    step  1633 loss = 1.34788
DEBUG    step  1634 loss = 1.57518
DEBUG    step  1635 loss = 1.92275
DEBUG    step  1636 loss = 1.85572
DEBUG    step  1637 loss = 1.18637
DEBUG    step  1638 loss = 0.775541
DEBUG    step  1639 loss = 1.3429
DEBUG    step  1640 loss = 1.74344
DEBUG    step  1641 loss = 1.40233
DEBUG    step  1642 loss = 1.9051
DEBUG    step  1643 loss = 1.16771
DEBUG    step  1644 loss = 1.1377
DEBUG    step  1645 loss = 1.73862
DEBUG    step  1646 loss = 0.958234
DEBUG    step  1647 loss = 1.11713
DEBUG    step  1648 loss = 0.944722
DEBUG    step  1649 loss = 3.08687
DEBUG    step  1650 loss = 1.27105
DEBUG    step  1651 loss = 0.857286
DEBUG    step  1652 loss = 1.52856
DEBUG    step  1653 loss = 1.96828
DEBUG    step  1654 loss = 0.92382
DEBUG    step  1655 loss = 2.05783
DEBUG    step  1656 loss = 1.16256
DEBUG    step  1657 loss = 1.42272
DEBUG    step  1658 loss = 1.07507
DEBUG    step  1659 loss = 1.64777
DEBUG    step  1660 loss = 0.919807
DEBUG    step  1661 loss = 0.726715
DEBUG    step  1662 loss = 1.57691
DEBUG    step  1663 loss = 1.38782
DEBUG    step  1664 loss = 1.26784
DEBUG    step  1665 loss = 1.64389
DEBUG    step  1666 loss = 0.984072
DEBUG    step  1667 loss = 1.65232
DEBUG    step  1668 loss = 1.8319
DEBUG    step  1669 loss = 1.46141
DEBUG    step  1670 loss = 0.989564
DEBUG    step  1671 loss = 1.60373
DEBUG    step  1672 loss = 1.79838
DEBUG    step  1673 loss = 1.0971
DEBUG    step  1674 loss = 1.6531
DEBUG    step  1675 loss = 0.569279
DEBUG    step  1676 loss = 1.1229
DEBUG    step  1677 loss = 2.09242
DEBUG    step  1678 loss = 1.25957
DEBUG    step  1679 loss = 1.20155
DEBUG    step  1680 loss = 0.445877
DEBUG    step  1681 loss = 1.06367
DEBUG    step  1682 loss = 1.53222
DEBUG    step  1683 loss = 1.46691
DEBUG    step  1684 loss = 1.33858
DEBUG    step  1685 loss = 1.34251
DEBUG    step  1686 loss = 1.41284
DEBUG    step  1687 loss = 1.13937
DEBUG    step  1688 loss = 2.37319
DEBUG    step  1689 loss = 0.934886
DEBUG    step  1690 loss = 0.989814
DEBUG    step  1691 loss = 1.37887
DEBUG    step  1692 loss = 1.40474
DEBUG    step  1693 loss = 1.73022
DEBUG    step  1694 loss = 0.660628
DEBUG    step  1695 loss = 1.47228
DEBUG    step  1696 loss = 1.16098
DEBUG    step  1697 loss = 1.3503
DEBUG    step  1698 loss = 1.31396
DEBUG    step  1699 loss = 2.02182
DEBUG    step  1700 loss = 0.960196
DEBUG    step  1701 loss = 1.45575
DEBUG    step  1702 loss = 1.09297
DEBUG    step  1703 loss = 1.27731
DEBUG    step  1704 loss = 1.63084
DEBUG    step  1705 loss = 1.46701
DEBUG    step  1706 loss = 1.58075
DEBUG    step  1707 loss = 2.77646
DEBUG    step  1708 loss = 1.66917
DEBUG    step  1709 loss = 1.53974
DEBUG    step  1710 loss = 0.746076
DEBUG    step  1711 loss = 0.787667
DEBUG    step  1712 loss = 1.48705
DEBUG    step  1713 loss = 1.15223
DEBUG    step  1714 loss = 0.74432
DEBUG    step  1715 loss = 1.20326
DEBUG    step  1716 loss = 1.05584
DEBUG    step  1717 loss = 1.25595
DEBUG    step  1718 loss = 1.63639
DEBUG    step  1719 loss = 1.18738
DEBUG    step  1720 loss = 0.997565
DEBUG    step  1721 loss = 1.59334
DEBUG    step  1722 loss = 1.18497
DEBUG    step  1723 loss = 1.39869
DEBUG    step  1724 loss = 1.13685
DEBUG    step  1725 loss = 0.477479
DEBUG    step  1726 loss = 1.42541
DEBUG    step  1727 loss = 1.47176
DEBUG    step  1728 loss = 2.13344
DEBUG    step  1729 loss = 0.989916
DEBUG    step  1730 loss = 1.00084
DEBUG    step  1731 loss = 1.31844
DEBUG    step  1732 loss = 1.44907
DEBUG    step  1733 loss = 1.14411
DEBUG    step  1734 loss = 0.997098
DEBUG    step  1735 loss = 1.22144
DEBUG    step  1736 loss = 1.65521
DEBUG    step  1737 loss = 1.04064
DEBUG    step  1738 loss = 1.40232
DEBUG    step  1739 loss = 1.21052
DEBUG    step  1740 loss = 0.52208
DEBUG    step  1741 loss = 0.96464
DEBUG    step  1742 loss = 0.922535
DEBUG    step  1743 loss = 0.57069
DEBUG    step  1744 loss = 1.29497
DEBUG    step  1745 loss = 0.764636
DEBUG    step  1746 loss = 0.596204
DEBUG    step  1747 loss = 1.47739
DEBUG    step  1748 loss = 0.704551
DEBUG    step  1749 loss = 1.13051
DEBUG    step  1750 loss = 1.81735
DEBUG    step  1751 loss = 1.15569
DEBUG    step  1752 loss = 0.62525
DEBUG    step  1753 loss = -0.14409
DEBUG    step  1754 loss = 0.819491
DEBUG    step  1755 loss = 0.584971
DEBUG    step  1756 loss = 1.50396
DEBUG    step  1757 loss = 1.12784
DEBUG    step  1758 loss = 1.37416
DEBUG    step  1759 loss = 0.944302
DEBUG    step  1760 loss = 0.708327
DEBUG    step  1761 loss = 1.51183
DEBUG    step  1762 loss = 0.951956
DEBUG    step  1763 loss = 1.13992
DEBUG    step  1764 loss = -0.0584559
DEBUG    step  1765 loss = 0.941625
DEBUG    step  1766 loss = 1.46371
DEBUG    step  1767 loss = 1.36433
DEBUG    step  1768 loss = 0.560516
DEBUG    step  1769 loss = 1.35952
DEBUG    step  1770 loss = 1.01687
DEBUG    step  1771 loss = 1.21911
DEBUG    step  1772 loss = 1.8578
DEBUG    step  1773 loss = 0.774448
DEBUG    step  1774 loss = 1.37295
DEBUG    step  1775 loss = 1.18173
DEBUG    step  1776 loss = 1.66936
DEBUG    step  1777 loss = 0.860755
DEBUG    step  1778 loss = 1.32138
DEBUG    step  1779 loss = 0.898082
DEBUG    step  1780 loss = 1.12301
DEBUG    step  1781 loss = 0.960121
DEBUG    step  1782 loss = 1.20348
DEBUG    step  1783 loss = 0.758963
DEBUG    step  1784 loss = 0.862989
DEBUG    step  1785 loss = 1.21436
DEBUG    step  1786 loss = 0.458139
DEBUG    step  1787 loss = 1.46172
DEBUG    step  1788 loss = 0.843393
DEBUG    step  1789 loss = 0.533864
DEBUG    step  1790 loss = 0.960291
DEBUG    step  1791 loss = 0.630529
DEBUG    step  1792 loss = 1.45164
DEBUG    step  1793 loss = 0.664835
DEBUG    step  1794 loss = 0.710118
DEBUG    step  1795 loss = 0.719209
DEBUG    step  1796 loss = 0.810381
DEBUG    step  1797 loss = 0.138259
DEBUG    step  1798 loss = 1.22091
DEBUG    step  1799 loss = 0.446191
DEBUG    step  1800 loss = 1.12451
DEBUG    step  1801 loss = 0.847999
DEBUG    step  1802 loss = 1.09745
DEBUG    step  1803 loss = 1.45925
DEBUG    step  1804 loss = 0.713525
DEBUG    step  1805 loss = 0.953999
DEBUG    step  1806 loss = 1.14265
DEBUG    step  1807 loss = 0.244373
DEBUG    step  1808 loss = 1.06263
DEBUG    step  1809 loss = 0.771337
DEBUG    step  1810 loss = 1.0411
DEBUG    step  1811 loss = 1.37541
DEBUG    step  1812 loss = 1.5398
DEBUG    step  1813 loss = 1.04689
DEBUG    step  1814 loss = 1.50583
DEBUG    step  1815 loss = 0.278969
DEBUG    step  1816 loss = 0.303059
DEBUG    step  1817 loss = 0.843962
DEBUG    step  1818 loss = 0.360989
DEBUG    step  1819 loss = 1.42488
DEBUG    step  1820 loss = 0.334529
DEBUG    step  1821 loss = 1.15429
DEBUG    step  1822 loss = 0.942839
DEBUG    step  1823 loss = -0.0623802
DEBUG    step  1824 loss = 1.2242
DEBUG    step  1825 loss = 0.110633
DEBUG    step  1826 loss = 1.04671
DEBUG    step  1827 loss = 0.814721
DEBUG    step  1828 loss = 0.981389
DEBUG    step  1829 loss = 0.374465
DEBUG    step  1830 loss = 0.682603
DEBUG    step  1831 loss = 0.888044
DEBUG    step  1832 loss = 1.00653
DEBUG    step  1833 loss = -0.192628
DEBUG    step  1834 loss = 1.33105
DEBUG    step  1835 loss = -0.292317
DEBUG    step  1836 loss = 1.40156
DEBUG    step  1837 loss = 0.548849
DEBUG    step  1838 loss = 0.733393
DEBUG    step  1839 loss = 0.737875
DEBUG    step  1840 loss = 0.953065
DEBUG    step  1841 loss = 1.35565
DEBUG    step  1842 loss = 0.334132
DEBUG    step  1843 loss = 0.527886
DEBUG    step  1844 loss = 0.728576
DEBUG    step  1845 loss = 0.971659
DEBUG    step  1846 loss = 1.0362
DEBUG    step  1847 loss = 1.1995
DEBUG    step  1848 loss = 0.74542
DEBUG    step  1849 loss = 0.822038
DEBUG    step  1850 loss = 0.14102
DEBUG    step  1851 loss = 0.351881
DEBUG    step  1852 loss = 0.718691
DEBUG    step  1853 loss = 0.454031
DEBUG    step  1854 loss = 1.34327
DEBUG    step  1855 loss = 1.12586
DEBUG    step  1856 loss = 0.794541
DEBUG    step  1857 loss = 0.881259
DEBUG    step  1858 loss = 0.402362
DEBUG    step  1859 loss = 0.490797
DEBUG    step  1860 loss = 0.12956
DEBUG    step  1861 loss = 1.00601
DEBUG    step  1862 loss = 0.0126683
DEBUG    step  1863 loss = 0.367983
DEBUG    step  1864 loss = 0.519085
DEBUG    step  1865 loss = 1.5708
DEBUG    step  1866 loss = 1.47664
DEBUG    step  1867 loss = 0.891001
DEBUG    step  1868 loss = 1.33164
DEBUG    step  1869 loss = 1.43242
DEBUG    step  1870 loss = 1.57703
DEBUG    step  1871 loss = 0.409759
DEBUG    step  1872 loss = 0.481442
DEBUG    step  1873 loss = 0.433702
DEBUG    step  1874 loss = 0.102985
DEBUG    step  1875 loss = 1.07597
DEBUG    step  1876 loss = 0.628031
DEBUG    step  1877 loss = -0.0152627
DEBUG    step  1878 loss = 0.482545
DEBUG    step  1879 loss = 1.55648
DEBUG    step  1880 loss = 0.844998
DEBUG    step  1881 loss = 0.42592
DEBUG    step  1882 loss = -0.0152035
DEBUG    step  1883 loss = -0.0997669
DEBUG    step  1884 loss = 1.01354
DEBUG    step  1885 loss = 0.490207
DEBUG    step  1886 loss = 0.736687
DEBUG    step  1887 loss = 0.433603
DEBUG    step  1888 loss = 1.07525
DEBUG    step  1889 loss = 0.678383
DEBUG    step  1890 loss = 0.980835
DEBUG    step  1891 loss = 0.470526
DEBUG    step  1892 loss = 0.591348
DEBUG    step  1893 loss = 0.496179
DEBUG    step  1894 loss = 0.164359
DEBUG    step  1895 loss = 0.505431
DEBUG    step  1896 loss = 0.848054
DEBUG    step  1897 loss = 1.22015
DEBUG    step  1898 loss = 0.21223
DEBUG    step  1899 loss = 0.804585
DEBUG    step  1900 loss = 0.337482
DEBUG    step  1901 loss = 0.380753
DEBUG    step  1902 loss = 1.09557
DEBUG    step  1903 loss = 0.452767
DEBUG    step  1904 loss = 0.505589
DEBUG    step  1905 loss = 0.533463
DEBUG    step  1906 loss = 0.732611
DEBUG    step  1907 loss = 0.457369
DEBUG    step  1908 loss = 0.397615
DEBUG    step  1909 loss = 0.304795
DEBUG    step  1910 loss = 0.832857
DEBUG    step  1911 loss = 0.776005
DEBUG    step  1912 loss = 0.0557357
DEBUG    step  1913 loss = 1.06473
DEBUG    step  1914 loss = 0.621938
DEBUG    step  1915 loss = 3.8174
DEBUG    step  1916 loss = 0.834741
DEBUG    step  1917 loss = 0.432647
DEBUG    step  1918 loss = 1.0107
DEBUG    step  1919 loss = 0.887171
DEBUG    step  1920 loss = 0.214395
DEBUG    step  1921 loss = 0.27015
DEBUG    step  1922 loss = 0.723923
DEBUG    step  1923 loss = 0.0225524
DEBUG    step  1924 loss = 0.311126
DEBUG    step  1925 loss = 0.163129
DEBUG    step  1926 loss = 1.0852
DEBUG    step  1927 loss = 0.845341
DEBUG    step  1928 loss = 0.067302
DEBUG    step  1929 loss = 1.81058
DEBUG    step  1930 loss = 0.711902
DEBUG    step  1931 loss = 0.544337
DEBUG    step  1932 loss = 0.729942
DEBUG    step  1933 loss = 0.281568
DEBUG    step  1934 loss = 0.746916
DEBUG    step  1935 loss = 0.731851
DEBUG    step  1936 loss = 0.861581
DEBUG    step  1937 loss = 0.587285
DEBUG    step  1938 loss = 0.375893
DEBUG    step  1939 loss = 0.52338
DEBUG    step  1940 loss = 0.0507239
DEBUG    step  1941 loss = 0.544204
DEBUG    step  1942 loss = 0.139653
DEBUG    step  1943 loss = 0.603852
DEBUG    step  1944 loss = 0.591492
DEBUG    step  1945 loss = 0.211932
DEBUG    step  1946 loss = 0.632158
DEBUG    step  1947 loss = 0.613739
DEBUG    step  1948 loss = 1.12637
DEBUG    step  1949 loss = 0.655486
DEBUG    step  1950 loss = 0.687108
DEBUG    step  1951 loss = 0.224532
DEBUG    step  1952 loss = 0.675569
DEBUG    step  1953 loss = 1.16836
DEBUG    step  1954 loss = 0.575642
DEBUG    step  1955 loss = 0.314398
DEBUG    step  1956 loss = 0.949717
DEBUG    step  1957 loss = 1.06026
DEBUG    step  1958 loss = 0.894075
DEBUG    step  1959 loss = 0.268737
DEBUG    step  1960 loss = -0.0684191
DEBUG    step  1961 loss = 0.301358
DEBUG    step  1962 loss = 0.670349
DEBUG    step  1963 loss = 0.631736
DEBUG    step  1964 loss = 1.17734
DEBUG    step  1965 loss = -0.0977912
DEBUG    step  1966 loss = 0.872278
DEBUG    step  1967 loss = 0.0835433
DEBUG    step  1968 loss = -0.0705985
DEBUG    step  1969 loss = 0.193565
DEBUG    step  1970 loss = 0.817641
DEBUG    step  1971 loss = 1.54214
DEBUG    step  1972 loss = -0.0112863
DEBUG    step  1973 loss = 0.170732
DEBUG    step  1974 loss = 0.437139
DEBUG    step  1975 loss = -0.0416076
DEBUG    step  1976 loss = 0.201051
DEBUG    step  1977 loss = 0.663106
DEBUG    step  1978 loss = 0.647153
DEBUG    step  1979 loss = 0.138818
DEBUG    step  1980 loss = 0.0719861
DEBUG    step  1981 loss = 1.12457
DEBUG    step  1982 loss = 0.123392
DEBUG    step  1983 loss = 0.35576
DEBUG    step  1984 loss = 0.187577
DEBUG    step  1985 loss = 0.158135
DEBUG    step  1986 loss = 0.172388
DEBUG    step  1987 loss = 0.864039
DEBUG    step  1988 loss = 0.522948
DEBUG    step  1989 loss = 0.218993
DEBUG    step  1990 loss = 0.958601
DEBUG    step  1991 loss = 0.0281422
DEBUG    step  1992 loss = 0.15538
DEBUG    step  1993 loss = 0.298106
DEBUG    step  1994 loss = 0.192198
DEBUG    step  1995 loss = -0.437914
DEBUG    step  1996 loss = 0.17182
DEBUG    step  1997 loss = 0.625345
DEBUG    step  1998 loss = 0.443585
DEBUG    step  1999 loss = -0.0372677
DEBUG    step  2000 loss = 0.0965499
DEBUG    step  2001 loss = 0.684757
DEBUG    step  2002 loss = 0.0434506
DEBUG    step  2003 loss = 0.179006
DEBUG    step  2004 loss = 0.585443
DEBUG    step  2005 loss = 0.75187
DEBUG    step  2006 loss = -0.19287
DEBUG    step  2007 loss = 0.753149
DEBUG    step  2008 loss = 0.524784
DEBUG    step  2009 loss = 0.500014
DEBUG    step  2010 loss = 0.68905
DEBUG    step  2011 loss = 0.508104
DEBUG    step  2012 loss = 1.12944
DEBUG    step  2013 loss = 0.636447
DEBUG    step  2014 loss = 1.07191
DEBUG    step  2015 loss = 0.620359
DEBUG    step  2016 loss = -0.0672604
DEBUG    step  2017 loss = 0.12611
DEBUG    step  2018 loss = -0.160067
DEBUG    step  2019 loss = 0.560006
DEBUG    step  2020 loss = -0.0938559
DEBUG    step  2021 loss = 0.2633
DEBUG    step  2022 loss = -0.24172
DEBUG    step  2023 loss = 0.23306
DEBUG    step  2024 loss = -0.119578
DEBUG    step  2025 loss = 0.304582
DEBUG    step  2026 loss = 0.222591
DEBUG    step  2027 loss = 0.47586
DEBUG    step  2028 loss = 0.504828
DEBUG    step  2029 loss = 0.422783
DEBUG    step  2030 loss = 0.346542
DEBUG    step  2031 loss = 0.22548
DEBUG    step  2032 loss = 0.0345138
DEBUG    step  2033 loss = 0.727085
DEBUG    step  2034 loss = 0.438053
DEBUG    step  2035 loss = -0.163181
DEBUG    step  2036 loss = 0.816675
DEBUG    step  2037 loss = 0.0115353
DEBUG    step  2038 loss = 0.768062
DEBUG    step  2039 loss = 0.24584
DEBUG    step  2040 loss = 0.290391
DEBUG    step  2041 loss = 0.955838
DEBUG    step  2042 loss = 0.185171
DEBUG    step  2043 loss = -0.360956
DEBUG    step  2044 loss = 0.12458
DEBUG    step  2045 loss = 0.00191054
DEBUG    step  2046 loss = 0.0451765
DEBUG    step  2047 loss = 0.215519
DEBUG    step  2048 loss = 0.159755
DEBUG    step  2049 loss = 0.917712
DEBUG    step  2050 loss = -0.26462
DEBUG    step  2051 loss = 0.310773
DEBUG    step  2052 loss = -0.0363671
DEBUG    step  2053 loss = 0.0293219
DEBUG    step  2054 loss = -0.00587582
DEBUG    step  2055 loss = 0.471752
DEBUG    step  2056 loss = 0.238597
DEBUG    step  2057 loss = 0.0422264
DEBUG    step  2058 loss = -0.543846
DEBUG    step  2059 loss = 0.777388
DEBUG    step  2060 loss = -0.693749
DEBUG    step  2061 loss = 0.0994059
DEBUG    step  2062 loss = -0.286047
DEBUG    step  2063 loss = 0.766898
DEBUG    step  2064 loss = -0.142116
DEBUG    step  2065 loss = 0.883171
DEBUG    step  2066 loss = 0.180947
DEBUG    step  2067 loss = 0.210857
DEBUG    step  2068 loss = 0.118777
DEBUG    step  2069 loss = -0.141074
DEBUG    step  2070 loss = 0.363284
DEBUG    step  2071 loss = 0.39178
DEBUG    step  2072 loss = 0.305299
DEBUG    step  2073 loss = 0.545026
DEBUG    step  2074 loss = -0.226126
DEBUG    step  2075 loss = 0.169667
DEBUG    step  2076 loss = -0.336501
DEBUG    step  2077 loss = 0.965252
DEBUG    step  2078 loss = -0.170774
DEBUG    step  2079 loss = 0.0928747
DEBUG    step  2080 loss = 0.134985
DEBUG    step  2081 loss = 0.0768925
DEBUG    step  2082 loss = 0.207024
DEBUG    step  2083 loss = -0.157205
DEBUG    step  2084 loss = -0.13322
DEBUG    step  2085 loss = 0.262412
DEBUG    step  2086 loss = 0.327786
DEBUG    step  2087 loss = -0.0993449
DEBUG    step  2088 loss = 0.244769
DEBUG    step  2089 loss = -0.0589051
DEBUG    step  2090 loss = 0.332496
DEBUG    step  2091 loss = 0.925634
DEBUG    step  2092 loss = -0.257988
DEBUG    step  2093 loss = 0.518207
DEBUG    step  2094 loss = 0.286856
DEBUG    step  2095 loss = -0.300405
DEBUG    step  2096 loss = -0.0130847
DEBUG    step  2097 loss = 0.519027
DEBUG    step  2098 loss = 0.318041
DEBUG    step  2099 loss = -0.133822
DEBUG    step  2100 loss = -0.076749
DEBUG    step  2101 loss = 0.0152595
DEBUG    step  2102 loss = 0.678585
DEBUG    step  2103 loss = -0.164601
DEBUG    step  2104 loss = 0.384856
DEBUG    step  2105 loss = 0.0680997
DEBUG    step  2106 loss = -0.0351076
DEBUG    step  2107 loss = 0.231791
DEBUG    step  2108 loss = -0.117496
DEBUG    step  2109 loss = -0.0222189
DEBUG    step  2110 loss = -0.0573999
DEBUG    step  2111 loss = 0.524485
DEBUG    step  2112 loss = 0.0913248
DEBUG    step  2113 loss = 0.280226
DEBUG    step  2114 loss = 0.318695
DEBUG    step  2115 loss = 0.039408
DEBUG    step  2116 loss = 0.0231956
DEBUG    step  2117 loss = -0.144188
DEBUG    step  2118 loss = -0.249522
DEBUG    step  2119 loss = 0.182491
DEBUG    step  2120 loss = -0.137275
DEBUG    step  2121 loss = -0.116535
DEBUG    step  2122 loss = -0.502473
DEBUG    step  2123 loss = 0.106871
DEBUG    step  2124 loss = 0.219624
DEBUG    step  2125 loss = 0.236981
DEBUG    step  2126 loss = 0.308991
DEBUG    step  2127 loss = 0.361933
DEBUG    step  2128 loss = -0.0891354
DEBUG    step  2129 loss = 0.375717
DEBUG    step  2130 loss = 0.458
DEBUG    step  2131 loss = 0.804599
DEBUG    step  2132 loss = -0.850078
DEBUG    step  2133 loss = -0.565978
DEBUG    step  2134 loss = 0.395504
DEBUG    step  2135 loss = 0.0360778
DEBUG    step  2136 loss = 0.262763
DEBUG    step  2137 loss = 0.173679
DEBUG    step  2138 loss = 0.245434
DEBUG    step  2139 loss = -0.325045
DEBUG    step  2140 loss = 0.197687
DEBUG    step  2141 loss = 0.10554
DEBUG    step  2142 loss = 0.629076
DEBUG    step  2143 loss = -0.444622
DEBUG    step  2144 loss = 0.29245
DEBUG    step  2145 loss = -0.169153
DEBUG    step  2146 loss = -0.122091
DEBUG    step  2147 loss = -0.482058
DEBUG    step  2148 loss = -0.145807
DEBUG    step  2149 loss = -0.321955
DEBUG    step  2150 loss = -0.204977
DEBUG    step  2151 loss = 0.260222
DEBUG    step  2152 loss = -0.0221428
DEBUG    step  2153 loss = -0.299182
DEBUG    step  2154 loss = 0.492136
DEBUG    step  2155 loss = -0.512058
DEBUG    step  2156 loss = -0.701374
DEBUG    step  2157 loss = 0.616286
DEBUG    step  2158 loss = -0.580705
DEBUG    step  2159 loss = 0.543072
DEBUG    step  2160 loss = -0.271091
DEBUG    step  2161 loss = -0.152006
DEBUG    step  2162 loss = -0.0906625
DEBUG    step  2163 loss = -0.341321
DEBUG    step  2164 loss = -0.0973744
DEBUG    step  2165 loss = 0.335691
DEBUG    step  2166 loss = -0.513224
DEBUG    step  2167 loss = 0.441127
DEBUG    step  2168 loss = -0.195149
DEBUG    step  2169 loss = -0.155654
DEBUG    step  2170 loss = 0.146065
DEBUG    step  2171 loss = -0.157879
DEBUG    step  2172 loss = 0.427397
DEBUG    step  2173 loss = -0.264271
DEBUG    step  2174 loss = 0.255104
DEBUG    step  2175 loss = 0.143516
DEBUG    step  2176 loss = -0.144723
DEBUG    step  2177 loss = 0.362921
DEBUG    step  2178 loss = 0.085199
DEBUG    step  2179 loss = 0.166598
DEBUG    step  2180 loss = -0.529532
DEBUG    step  2181 loss = -0.318048
DEBUG    step  2182 loss = -0.0852365
DEBUG    step  2183 loss = -0.226952
DEBUG    step  2184 loss = 0.372169
DEBUG    step  2185 loss = 0.46677
DEBUG    step  2186 loss = -0.0550372
DEBUG    step  2187 loss = 0.123473
DEBUG    step  2188 loss = -0.709439
DEBUG    step  2189 loss = 0.627293
DEBUG    step  2190 loss = -0.932047
DEBUG    step  2191 loss = -0.0653693
DEBUG    step  2192 loss = 0.694153
DEBUG    step  2193 loss = -0.0535071
DEBUG    step  2194 loss = -0.691768
DEBUG    step  2195 loss = -0.0777673
DEBUG    step  2196 loss = -0.0291022
DEBUG    step  2197 loss = 0.0775634
DEBUG    step  2198 loss = -0.00225392
DEBUG    step  2199 loss = 0.467416
DEBUG    step  2200 loss = -0.0729818
DEBUG    step  2201 loss = -0.174586
DEBUG    step  2202 loss = -0.0735762
DEBUG    step  2203 loss = -0.291103
DEBUG    step  2204 loss = 0.206642
DEBUG    step  2205 loss = -0.35946
DEBUG    step  2206 loss = 0.0623758
DEBUG    step  2207 loss = -0.0335207
DEBUG    step  2208 loss = -0.322341
DEBUG    step  2209 loss = -0.164268
DEBUG    step  2210 loss = -0.298333
DEBUG    step  2211 loss = -0.542928
DEBUG    step  2212 loss = 0.818519
DEBUG    step  2213 loss = -0.175861
DEBUG    step  2214 loss = -1.18826
DEBUG    step  2215 loss = 0.020086
DEBUG    step  2216 loss = -1.07731
DEBUG    step  2217 loss = 0.861459
DEBUG    step  2218 loss = -0.30791
DEBUG    step  2219 loss = 12.8663
DEBUG    step  2220 loss = 0.110738
DEBUG    step  2221 loss = 0.415476
DEBUG    step  2222 loss = -0.0830224
DEBUG    step  2223 loss = 0.026601
DEBUG    step  2224 loss = -0.484626
DEBUG    step  2225 loss = -0.643493
DEBUG    step  2226 loss = -0.531596
DEBUG    step  2227 loss = -0.159798
DEBUG    step  2228 loss = 0.444723
DEBUG    step  2229 loss = -0.209576
DEBUG    step  2230 loss = -0.117957
DEBUG    step  2231 loss = 0.26718
DEBUG    step  2232 loss = -0.623983
DEBUG    step  2233 loss = -0.134441
DEBUG    step  2234 loss = -1.03047
DEBUG    step  2235 loss = 0.10526
DEBUG    step  2236 loss = -0.168391
DEBUG    step  2237 loss = -0.325326
DEBUG    step  2238 loss = -0.636917
DEBUG    step  2239 loss = -1.01447
DEBUG    step  2240 loss = -0.137275
DEBUG    step  2241 loss = -0.0928798
DEBUG    step  2242 loss = 0.521724
DEBUG    step  2243 loss = -0.726267
DEBUG    step  2244 loss = -0.151048
DEBUG    step  2245 loss = -0.0553814
DEBUG    step  2246 loss = -0.0806889
DEBUG    step  2247 loss = -0.265405
DEBUG    step  2248 loss = -0.605389
DEBUG    step  2249 loss = 0.609598
DEBUG    step  2250 loss = 0.201578
DEBUG    step  2251 loss = -0.301686
DEBUG    step  2252 loss = 0.254437
DEBUG    step  2253 loss = 0.53236
DEBUG    step  2254 loss = -0.405195
DEBUG    step  2255 loss = -0.0701203
DEBUG    step  2256 loss = -0.2183
DEBUG    step  2257 loss = -0.766243
DEBUG    step  2258 loss = -0.732259
DEBUG    step  2259 loss = -0.142207
DEBUG    step  2260 loss = -0.15166
DEBUG    step  2261 loss = -0.700015
DEBUG    step  2262 loss = 0.0802323
DEBUG    step  2263 loss = 0.313499
DEBUG    step  2264 loss = 0.283268
DEBUG    step  2265 loss = -0.458733
DEBUG    step  2266 loss = 0.169434
DEBUG    step  2267 loss = 0.0517936
DEBUG    step  2268 loss = -0.303608
DEBUG    step  2269 loss = 0.273257
DEBUG    step  2270 loss = -0.392904
DEBUG    step  2271 loss = 0.44848
DEBUG    step  2272 loss = -0.703877
DEBUG    step  2273 loss = -1.01002
DEBUG    step  2274 loss = 0.359133
DEBUG    step  2275 loss = 0.212775
DEBUG    step  2276 loss = -0.519192
DEBUG    step  2277 loss = -0.2437
DEBUG    step  2278 loss = -0.667431
DEBUG    step  2279 loss = -0.996026
DEBUG    step  2280 loss = 0.273185
DEBUG    step  2281 loss = -0.00770547
DEBUG    step  2282 loss = -0.162126
DEBUG    step  2283 loss = 0.175816
DEBUG    step  2284 loss = 0.0773304
DEBUG    step  2285 loss = -0.512412
DEBUG    step  2286 loss = -0.607146
DEBUG    step  2287 loss = 0.182539
DEBUG    step  2288 loss = -0.694855
DEBUG    step  2289 loss = 0.335107
DEBUG    step  2290 loss = 0.351011
DEBUG    step  2291 loss = -0.367074
DEBUG    step  2292 loss = 0.961813
DEBUG    step  2293 loss = 0.319814
DEBUG    step  2294 loss = -0.0375465
DEBUG    step  2295 loss = -0.685502
DEBUG    step  2296 loss = 0.702536
DEBUG    step  2297 loss = -0.0365256
DEBUG    step  2298 loss = 0.297325
DEBUG    step  2299 loss = -0.161133
DEBUG    step  2300 loss = -0.0621092
DEBUG    step  2301 loss = -0.524049
DEBUG    step  2302 loss = -0.428477
DEBUG    step  2303 loss = -0.481184
DEBUG    step  2304 loss = -0.582241
DEBUG    step  2305 loss = -0.22409
DEBUG    step  2306 loss = -0.0466428
DEBUG    step  2307 loss = -0.807201
DEBUG    step  2308 loss = -0.418819
DEBUG    step  2309 loss = -0.11762
DEBUG    step  2310 loss = -0.00959172
DEBUG    step  2311 loss = -0.00444585
DEBUG    step  2312 loss = 0.043913
DEBUG    step  2313 loss = 0.571166
DEBUG    step  2314 loss = -0.537292
DEBUG    step  2315 loss = 0.270969
DEBUG    step  2316 loss = -0.212546
DEBUG    step  2317 loss = 0.112569
DEBUG    step  2318 loss = -0.455186
DEBUG    step  2319 loss = -0.424695
DEBUG    step  2320 loss = -0.464438
DEBUG    step  2321 loss = -0.473156
DEBUG    step  2322 loss = -0.105536
DEBUG    step  2323 loss = -0.198469
DEBUG    step  2324 loss = 0.422803
DEBUG    step  2325 loss = 0.887627
DEBUG    step  2326 loss = -0.685745
DEBUG    step  2327 loss = -0.656979
DEBUG    step  2328 loss = -1.1468
DEBUG    step  2329 loss = -0.416101
DEBUG    step  2330 loss = -0.0506251
DEBUG    step  2331 loss = 0.38371
DEBUG    step  2332 loss = -0.410896
DEBUG    step  2333 loss = -0.490316
DEBUG    step  2334 loss = -0.148082
DEBUG    step  2335 loss = -1.2066
DEBUG    step  2336 loss = -0.480291
DEBUG    step  2337 loss = -0.564195
DEBUG    step  2338 loss = -0.051699
DEBUG    step  2339 loss = 0.554887
DEBUG    step  2340 loss = 0.464537
DEBUG    step  2341 loss = -0.586118
DEBUG    step  2342 loss = -0.224842
DEBUG    step  2343 loss = 0.140776
DEBUG    step  2344 loss = 0.0989285
DEBUG    step  2345 loss = -0.140234
DEBUG    step  2346 loss = -0.220834
DEBUG    step  2347 loss = 0.358295
DEBUG    step  2348 loss = -0.935413
DEBUG    step  2349 loss = -0.797103
DEBUG    step  2350 loss = -0.370552
DEBUG    step  2351 loss = -0.255635
DEBUG    step  2352 loss = 0.0331677
DEBUG    step  2353 loss = -0.0654061
DEBUG    step  2354 loss = -0.792516
DEBUG    step  2355 loss = 1.00517
DEBUG    step  2356 loss = -0.0650678
DEBUG    step  2357 loss = 0.100208
DEBUG    step  2358 loss = 0.315501
DEBUG    step  2359 loss = -0.196945
DEBUG    step  2360 loss = -0.706372
DEBUG    step  2361 loss = 0.134541
DEBUG    step  2362 loss = -0.114532
DEBUG    step  2363 loss = -0.661938
DEBUG    step  2364 loss = -0.826783
DEBUG    step  2365 loss = 0.561703
DEBUG    step  2366 loss = -0.380749
DEBUG    step  2367 loss = -0.599982
DEBUG    step  2368 loss = -0.552984
DEBUG    step  2369 loss = -0.809876
DEBUG    step  2370 loss = -0.41806
DEBUG    step  2371 loss = -0.293652
DEBUG    step  2372 loss = 0.019794
DEBUG    step  2373 loss = 0.366571
DEBUG    step  2374 loss = -0.330331
DEBUG    step  2375 loss = -0.108959
DEBUG    step  2376 loss = 0.0823981
DEBUG    step  2377 loss = -0.122074
DEBUG    step  2378 loss = 0.104684
DEBUG    step  2379 loss = -0.245806
DEBUG    step  2380 loss = -0.458836
DEBUG    step  2381 loss = -0.728625
DEBUG    step  2382 loss = 0.366162
DEBUG    step  2383 loss = -0.402356
DEBUG    step  2384 loss = -0.915713
DEBUG    step  2385 loss = 0.25255
DEBUG    step  2386 loss = -0.596414
DEBUG    step  2387 loss = 0.191845
DEBUG    step  2388 loss = 0.173331
DEBUG    step  2389 loss = -0.235943
DEBUG    step  2390 loss = -0.578616
DEBUG    step  2391 loss = -0.387393
DEBUG    step  2392 loss = -0.509603
DEBUG    step  2393 loss = -0.0789079
DEBUG    step  2394 loss = -0.146879
DEBUG    step  2395 loss = -0.162622
DEBUG    step  2396 loss = -0.580962
DEBUG    step  2397 loss = -0.704767
DEBUG    step  2398 loss = -0.471613
DEBUG    step  2399 loss = -0.18096
DEBUG    step  2400 loss = -0.162947
DEBUG    step  2401 loss = 0.0571842
DEBUG    step  2402 loss = -0.707115
DEBUG    step  2403 loss = -0.812926
DEBUG    step  2404 loss = 0.680889
DEBUG    step  2405 loss = 0.158955
DEBUG    step  2406 loss = -0.636955
DEBUG    step  2407 loss = -0.821936
DEBUG    step  2408 loss = 0.0161349
DEBUG    step  2409 loss = 2.05343
DEBUG    step  2410 loss = -0.449846
DEBUG    step  2411 loss = -0.112297
DEBUG    step  2412 loss = 0.23516
DEBUG    step  2413 loss = 0.598729
DEBUG    step  2414 loss = -0.637791
DEBUG    step  2415 loss = -0.0771543
DEBUG    step  2416 loss = -0.720933
DEBUG    step  2417 loss = -0.324247
DEBUG    step  2418 loss = -0.615081
DEBUG    step  2419 loss = -0.489061
DEBUG    step  2420 loss = -0.81913
DEBUG    step  2421 loss = -0.291852
DEBUG    step  2422 loss = -0.279411
DEBUG    step  2423 loss = -0.168712
DEBUG    step  2424 loss = -0.823371
DEBUG    step  2425 loss = -0.956634
DEBUG    step  2426 loss = 0.283457
DEBUG    step  2427 loss = 0.194569
DEBUG    step  2428 loss = -0.838871
DEBUG    step  2429 loss = -0.0047413
DEBUG    step  2430 loss = -0.559076
DEBUG    step  2431 loss = -0.689148
DEBUG    step  2432 loss = -0.299682
DEBUG    step  2433 loss = -0.884385
DEBUG    step  2434 loss = -0.595315
DEBUG    step  2435 loss = -1.11435
DEBUG    step  2436 loss = 0.0495753
DEBUG    step  2437 loss = -0.0852002
DEBUG    step  2438 loss = -0.15404
DEBUG    step  2439 loss = -0.266736
DEBUG    step  2440 loss = 0.195932
DEBUG    step  2441 loss = 0.185633
DEBUG    step  2442 loss = -0.863258
DEBUG    step  2443 loss = -0.382026
DEBUG    step  2444 loss = 0.252158
DEBUG    step  2445 loss = -0.448511
DEBUG    step  2446 loss = -0.179625
DEBUG    step  2447 loss = -0.114999
DEBUG    step  2448 loss = -1.00638
DEBUG    step  2449 loss = -0.0562548
DEBUG    step  2450 loss = 0.120608
DEBUG    step  2451 loss = -0.248703
DEBUG    step  2452 loss = 0.580167
DEBUG    step  2453 loss = -0.403365
DEBUG    step  2454 loss = -0.427596
DEBUG    step  2455 loss = -0.386274
DEBUG    step  2456 loss = -0.0709784
DEBUG    step  2457 loss = -0.478124
DEBUG    step  2458 loss = -0.427781
DEBUG    step  2459 loss = 0.213299
DEBUG    step  2460 loss = 0.185551
DEBUG    step  2461 loss = -1.15001
DEBUG    step  2462 loss = -0.908913
DEBUG    step  2463 loss = -0.296839
DEBUG    step  2464 loss = -0.213982
DEBUG    step  2465 loss = -0.139768
DEBUG    step  2466 loss = -0.554577
DEBUG    step  2467 loss = -1.29373
DEBUG    step  2468 loss = 0.168238
DEBUG    step  2469 loss = 0.134877
DEBUG    step  2470 loss = -0.255521
DEBUG    step  2471 loss = -0.750256
DEBUG    step  2472 loss = -0.0114451
DEBUG    step  2473 loss = -0.410735
DEBUG    step  2474 loss = 0.218873
DEBUG    step  2475 loss = -0.141217
DEBUG    step  2476 loss = -0.78113
DEBUG    step  2477 loss = -0.143108
DEBUG    step  2478 loss = -0.0878578
DEBUG    step  2479 loss = 0.498992
DEBUG    step  2480 loss = -0.385873
DEBUG    step  2481 loss = 0.697456
DEBUG    step  2482 loss = -0.330902
DEBUG    step  2483 loss = -0.416052
DEBUG    step  2484 loss = -0.0582824
DEBUG    step  2485 loss = -0.749726
DEBUG    step  2486 loss = -0.705093
DEBUG    step  2487 loss = -0.366732
DEBUG    step  2488 loss = 0.0636343
DEBUG    step  2489 loss = -0.428274
DEBUG    step  2490 loss = -0.97996
DEBUG    step  2491 loss = -0.721423
DEBUG    step  2492 loss = -0.901971
DEBUG    step  2493 loss = -0.821726
DEBUG    step  2494 loss = -0.48277
DEBUG    step  2495 loss = 0.159761
DEBUG    step  2496 loss = -0.802472
DEBUG    step  2497 loss = -0.687559
DEBUG    step  2498 loss = -0.256268
DEBUG    step  2499 loss = -0.571636
DEBUG    step  2500 loss = -0.184076
DEBUG    step  2501 loss = -0.0532485
DEBUG    step  2502 loss = 0.0489593
DEBUG    step  2503 loss = -0.699592
DEBUG    step  2504 loss = -0.964232
DEBUG    step  2505 loss = -0.33835
DEBUG    step  2506 loss = -0.425566
DEBUG    step  2507 loss = -0.0965802
DEBUG    step  2508 loss = -0.745661
DEBUG    step  2509 loss = -0.103916
DEBUG    step  2510 loss = -0.489986
DEBUG    step  2511 loss = -1.22721
DEBUG    step  2512 loss = -0.573065
DEBUG    step  2513 loss = -0.8967
DEBUG    step  2514 loss = -0.714046
DEBUG    step  2515 loss = -0.893781
DEBUG    step  2516 loss = 0.465743
DEBUG    step  2517 loss = -0.941392
DEBUG    step  2518 loss = -0.858442
DEBUG    step  2519 loss = -0.18183
DEBUG    step  2520 loss = -0.380441
DEBUG    step  2521 loss = -0.374258
DEBUG    step  2522 loss = -0.682367
DEBUG    step  2523 loss = -0.821137
DEBUG    step  2524 loss = -0.445525
DEBUG    step  2525 loss = -0.97567
DEBUG    step  2526 loss = -0.547556
DEBUG    step  2527 loss = -0.853315
DEBUG    step  2528 loss = 0.114161
DEBUG    step  2529 loss = -0.579036
DEBUG    step  2530 loss = 0.0135827
DEBUG    step  2531 loss = -0.0582753
DEBUG    step  2532 loss = -0.140801
DEBUG    step  2533 loss = -0.182517
DEBUG    step  2534 loss = -0.829945
DEBUG    step  2535 loss = -0.0669306
DEBUG    step  2536 loss = -0.467228
DEBUG    step  2537 loss = -0.584846
DEBUG    step  2538 loss = -0.273549
DEBUG    step  2539 loss = -0.00248221
DEBUG    step  2540 loss = -0.345479
DEBUG    step  2541 loss = -0.515946
DEBUG    step  2542 loss = -0.103854
DEBUG    step  2543 loss = 0.187452
DEBUG    step  2544 loss = -0.154338
DEBUG    step  2545 loss = -0.915668
DEBUG    step  2546 loss = -0.75074
DEBUG    step  2547 loss = -0.235062
DEBUG    step  2548 loss = -0.615748
DEBUG    step  2549 loss = 0.163511
DEBUG    step  2550 loss = -0.558204
DEBUG    step  2551 loss = -0.429658
DEBUG    step  2552 loss = -0.527625
DEBUG    step  2553 loss = -0.663658
DEBUG    step  2554 loss = -0.866039
DEBUG    step  2555 loss = -0.0667327
DEBUG    step  2556 loss = -1.14744
DEBUG    step  2557 loss = -0.599862
DEBUG    step  2558 loss = -0.628051
DEBUG    step  2559 loss = -1.02429
DEBUG    step  2560 loss = -0.812641
DEBUG    step  2561 loss = -0.207669
DEBUG    step  2562 loss = -0.346239
DEBUG    step  2563 loss = -0.42864
DEBUG    step  2564 loss = -0.769289
DEBUG    step  2565 loss = -0.442619
DEBUG    step  2566 loss = -0.551839
DEBUG    step  2567 loss = -0.434892
DEBUG    step  2568 loss = -0.822885
DEBUG    step  2569 loss = -0.0774252
DEBUG    step  2570 loss = -0.962704
DEBUG    step  2571 loss = 0.382489
DEBUG    step  2572 loss = -0.340682
DEBUG    step  2573 loss = -0.42353
DEBUG    step  2574 loss = -0.0114898
DEBUG    step  2575 loss = -0.210306
DEBUG    step  2576 loss = -0.625316
DEBUG    step  2577 loss = -0.61977
DEBUG    step  2578 loss = -0.641895
DEBUG    step  2579 loss = -0.158468
DEBUG    step  2580 loss = -0.376173
DEBUG    step  2581 loss = -0.562516
DEBUG    step  2582 loss = -0.606728
DEBUG    step  2583 loss = -0.486623
DEBUG    step  2584 loss = -0.253736
DEBUG    step  2585 loss = 0.342148
DEBUG    step  2586 loss = -0.165116
DEBUG    step  2587 loss = -0.173551
DEBUG    step  2588 loss = -1.46536
DEBUG    step  2589 loss = 0.0896398
DEBUG    step  2590 loss = -0.545322
DEBUG    step  2591 loss = -0.406094
DEBUG    step  2592 loss = -0.918525
DEBUG    step  2593 loss = -0.894497
DEBUG    step  2594 loss = -0.578103
DEBUG    step  2595 loss = -0.553256
DEBUG    step  2596 loss = -0.593555
DEBUG    step  2597 loss = 0.00266581
DEBUG    step  2598 loss = -0.0584986
DEBUG    step  2599 loss = -0.607323
DEBUG    step  2600 loss = 0.38463
DEBUG    step  2601 loss = -0.481794
DEBUG    step  2602 loss = -0.902755
DEBUG    step  2603 loss = -0.823573
DEBUG    step  2604 loss = -0.581352
DEBUG    step  2605 loss = -0.546566
DEBUG    step  2606 loss = -1.1963
DEBUG    step  2607 loss = -0.66562
DEBUG    step  2608 loss = -0.885256
DEBUG    step  2609 loss = -0.510776
DEBUG    step  2610 loss = -0.414367
DEBUG    step  2611 loss = -0.63994
DEBUG    step  2612 loss = -0.993912
DEBUG    step  2613 loss = -1.01504
DEBUG    step  2614 loss = 0.596202
DEBUG    step  2615 loss = 0.482037
DEBUG    step  2616 loss = -0.301577
DEBUG    step  2617 loss = -1.49396
DEBUG    step  2618 loss = -0.392669
DEBUG    step  2619 loss = -0.324627
DEBUG    step  2620 loss = 0.619205
DEBUG    step  2621 loss = -0.269684
DEBUG    step  2622 loss = -0.661252
DEBUG    step  2623 loss = -0.774471
DEBUG    step  2624 loss = -1.18561
DEBUG    step  2625 loss = -0.275053
DEBUG    step  2626 loss = -0.887767
DEBUG    step  2627 loss = 0.287073
DEBUG    step  2628 loss = -0.905378
DEBUG    step  2629 loss = 0.0570901
DEBUG    step  2630 loss = -0.351999
DEBUG    step  2631 loss = -0.118707
DEBUG    step  2632 loss = -0.671623
DEBUG    step  2633 loss = -0.681996
DEBUG    step  2634 loss = -0.521377
DEBUG    step  2635 loss = -0.617793
DEBUG    step  2636 loss = -0.603524
DEBUG    step  2637 loss = -0.821486
DEBUG    step  2638 loss = -0.356088
DEBUG    step  2639 loss = -0.536534
DEBUG    step  2640 loss = -0.747998
DEBUG    step  2641 loss = -0.439992
DEBUG    step  2642 loss = -0.0627628
DEBUG    step  2643 loss = 0.331022
DEBUG    step  2644 loss = -0.441603
DEBUG    step  2645 loss = -0.515788
DEBUG    step  2646 loss = -0.475961
DEBUG    step  2647 loss = -0.401744
DEBUG    step  2648 loss = 0.262217
DEBUG    step  2649 loss = -0.831643
DEBUG    step  2650 loss = -1.12754
DEBUG    step  2651 loss = -0.829439
DEBUG    step  2652 loss = 0.0111126
DEBUG    step  2653 loss = 0.0545446
DEBUG    step  2654 loss = -0.34779
DEBUG    step  2655 loss = -0.239686
DEBUG    step  2656 loss = -0.0659961
DEBUG    step  2657 loss = -0.0800167
DEBUG    step  2658 loss = -0.56742
DEBUG    step  2659 loss = -1.12966
DEBUG    step  2660 loss = -0.735846
DEBUG    step  2661 loss = -0.857747
DEBUG    step  2662 loss = -0.626603
DEBUG    step  2663 loss = 0.501296
DEBUG    step  2664 loss = -0.345909
DEBUG    step  2665 loss = -0.48826
DEBUG    step  2666 loss = -0.425832
DEBUG    step  2667 loss = -0.622227
DEBUG    step  2668 loss = 0.0905803
DEBUG    step  2669 loss = -0.934806
DEBUG    step  2670 loss = -0.55195
DEBUG    step  2671 loss = 0.285835
DEBUG    step  2672 loss = -0.62289
DEBUG    step  2673 loss = -0.438078
DEBUG    step  2674 loss = -0.351686
DEBUG    step  2675 loss = -0.476577
DEBUG    step  2676 loss = -0.894385
DEBUG    step  2677 loss = -0.258823
DEBUG    step  2678 loss = -0.413825
DEBUG    step  2679 loss = -0.737152
DEBUG    step  2680 loss = -0.756135
DEBUG    step  2681 loss = -0.475365
DEBUG    step  2682 loss = -0.271527
DEBUG    step  2683 loss = -0.628242
DEBUG    step  2684 loss = -1.36686
DEBUG    step  2685 loss = -0.608447
DEBUG    step  2686 loss = -0.685795
DEBUG    step  2687 loss = -0.240269
DEBUG    step  2688 loss = 0.146378
DEBUG    step  2689 loss = -1.10885
[10]:
# Per-sample treatment-effect (ITE) estimates from the fitted CEVAE,
# computed for the training split first and then the validation split.
ite_train, ite_val = [cevae.predict(split) for split in (X_train, X_val)]
INFO     Evaluating 538 minibatches
DEBUG    batch ate = 0.62191
DEBUG    batch ate = 0.613137
DEBUG    batch ate = 0.688279
DEBUG    batch ate = 0.530233
DEBUG    batch ate = 0.814089
DEBUG    batch ate = 0.623182
DEBUG    batch ate = 0.657884
DEBUG    batch ate = 0.594205
DEBUG    batch ate = 0.319953
DEBUG    batch ate = 0.557599
DEBUG    batch ate = 0.718177
DEBUG    batch ate = 0.441256
DEBUG    batch ate = 0.654653
DEBUG    batch ate = 0.70725
DEBUG    batch ate = 0.715862
DEBUG    batch ate = 0.193786
DEBUG    batch ate = 0.557451
DEBUG    batch ate = 0.788378
DEBUG    batch ate = 0.605489
DEBUG    batch ate = 0.669786
DEBUG    batch ate = 0.852794
DEBUG    batch ate = 0.755987
DEBUG    batch ate = 0.510262
DEBUG    batch ate = 0.502153
DEBUG    batch ate = 0.254691
DEBUG    batch ate = 0.369999
DEBUG    batch ate = 0.59401
DEBUG    batch ate = 0.608015
DEBUG    batch ate = 0.661765
DEBUG    batch ate = 0.25462
DEBUG    batch ate = 0.771231
DEBUG    batch ate = 0.530303
DEBUG    batch ate = 0.566246
DEBUG    batch ate = 0.683882
DEBUG    batch ate = 0.616635
DEBUG    batch ate = 0.324804
DEBUG    batch ate = 0.383451
DEBUG    batch ate = 0.690402
DEBUG    batch ate = 0.558513
DEBUG    batch ate = 0.618007
DEBUG    batch ate = 0.551096
DEBUG    batch ate = 0.462644
DEBUG    batch ate = 0.615761
DEBUG    batch ate = 0.543891
DEBUG    batch ate = 0.432806
DEBUG    batch ate = 0.562174
DEBUG    batch ate = 0.654926
DEBUG    batch ate = 0.421796
DEBUG    batch ate = 0.719893
DEBUG    batch ate = 0.454017
DEBUG    batch ate = 0.699385
DEBUG    batch ate = 0.54048
DEBUG    batch ate = 0.333772
DEBUG    batch ate = 0.737522
DEBUG    batch ate = 0.5696
DEBUG    batch ate = 0.467629
DEBUG    batch ate = 0.601579
DEBUG    batch ate = 0.509313
DEBUG    batch ate = 0.385523
DEBUG    batch ate = 0.510085
DEBUG    batch ate = 0.661952
DEBUG    batch ate = 0.600664
DEBUG    batch ate = 0.066584
DEBUG    batch ate = 0.552528
DEBUG    batch ate = 0.467475
DEBUG    batch ate = 0.539326
DEBUG    batch ate = 0.694311
DEBUG    batch ate = 0.198014
DEBUG    batch ate = 0.61709
DEBUG    batch ate = 0.408558
DEBUG    batch ate = 0.684187
DEBUG    batch ate = 0.447501
DEBUG    batch ate = 0.347885
DEBUG    batch ate = 0.561035
DEBUG    batch ate = 0.617192
DEBUG    batch ate = 0.81278
DEBUG    batch ate = 0.61961
DEBUG    batch ate = 1.01213
DEBUG    batch ate = 0.345585
DEBUG    batch ate = 0.51818
DEBUG    batch ate = 0.436719
DEBUG    batch ate = 0.604546
DEBUG    batch ate = 0.706353
DEBUG    batch ate = 0.661419
DEBUG    batch ate = 0.787418
DEBUG    batch ate = 0.61231
DEBUG    batch ate = 0.629355
DEBUG    batch ate = 0.550861
DEBUG    batch ate = 0.472948
DEBUG    batch ate = 0.594738
DEBUG    batch ate = 0.844747
DEBUG    batch ate = 0.682486
DEBUG    batch ate = 0.607738
DEBUG    batch ate = 0.49322
DEBUG    batch ate = 0.547857
DEBUG    batch ate = 0.255665
DEBUG    batch ate = 0.564768
DEBUG    batch ate = 0.34345
DEBUG    batch ate = 0.40075
DEBUG    batch ate = 0.72982
DEBUG    batch ate = 0.878728
DEBUG    batch ate = 0.860621
DEBUG    batch ate = 0.544359
DEBUG    batch ate = 0.777127
DEBUG    batch ate = 0.590297
DEBUG    batch ate = 0.880415
DEBUG    batch ate = 0.67375
DEBUG    batch ate = 0.784914
DEBUG    batch ate = 0.511374
DEBUG    batch ate = 0.327954
DEBUG    batch ate = 0.628989
DEBUG    batch ate = 0.529468
DEBUG    batch ate = 0.688235
DEBUG    batch ate = 0.872871
DEBUG    batch ate = 0.3485
DEBUG    batch ate = 0.572016
DEBUG    batch ate = 0.565154
DEBUG    batch ate = 0.588927
DEBUG    batch ate = 0.520636
DEBUG    batch ate = 0.345301
DEBUG    batch ate = 0.611386
DEBUG    batch ate = 0.702772
DEBUG    batch ate = 0.764302
DEBUG    batch ate = 0.638517
DEBUG    batch ate = 0.498749
DEBUG    batch ate = 0.922372
DEBUG    batch ate = 0.648347
DEBUG    batch ate = 0.930839
DEBUG    batch ate = 0.841956
DEBUG    batch ate = 0.687886
DEBUG    batch ate = 0.804776
DEBUG    batch ate = 0.550305
DEBUG    batch ate = 0.625526
DEBUG    batch ate = 0.856957
DEBUG    batch ate = 0.470616
DEBUG    batch ate = 0.507122
DEBUG    batch ate = 0.358198
DEBUG    batch ate = 0.6335
DEBUG    batch ate = 0.473881
DEBUG    batch ate = 0.415356
DEBUG    batch ate = 0.309733
DEBUG    batch ate = 0.290068
DEBUG    batch ate = 0.470317
DEBUG    batch ate = 0.668486
DEBUG    batch ate = 0.580281
DEBUG    batch ate = 0.772137
DEBUG    batch ate = 0.490976
DEBUG    batch ate = 0.511012
DEBUG    batch ate = 0.441551
DEBUG    batch ate = 0.575225
DEBUG    batch ate = 0.591247
DEBUG    batch ate = 0.368313
DEBUG    batch ate = 0.350138
DEBUG    batch ate = 0.603038
DEBUG    batch ate = 0.241947
DEBUG    batch ate = 0.599275
DEBUG    batch ate = 0.41003
DEBUG    batch ate = 0.447525
DEBUG    batch ate = 0.79099
DEBUG    batch ate = 0.506499
DEBUG    batch ate = 0.61826
DEBUG    batch ate = 0.651964
DEBUG    batch ate = 0.52761
DEBUG    batch ate = 0.888067
DEBUG    batch ate = 0.367077
DEBUG    batch ate = 0.524761
DEBUG    batch ate = 0.6165
DEBUG    batch ate = 0.72863
DEBUG    batch ate = 0.516559
DEBUG    batch ate = 0.385291
DEBUG    batch ate = 0.660073
DEBUG    batch ate = 0.465947
DEBUG    batch ate = 0.586065
DEBUG    batch ate = 0.533599
DEBUG    batch ate = 0.916433
DEBUG    batch ate = 0.658235
DEBUG    batch ate = 0.770213
DEBUG    batch ate = 0.634768
DEBUG    batch ate = 0.887955
DEBUG    batch ate = 0.374664
DEBUG    batch ate = 0.649699
DEBUG    batch ate = 0.550386
DEBUG    batch ate = 0.516355
DEBUG    batch ate = 0.425265
DEBUG    batch ate = 0.264789
DEBUG    batch ate = 0.775339
DEBUG    batch ate = 0.636203
DEBUG    batch ate = 0.507562
DEBUG    batch ate = 0.885973
DEBUG    batch ate = 0.951861
DEBUG    batch ate = 0.370282
DEBUG    batch ate = 0.69922
DEBUG    batch ate = 0.956577
DEBUG    batch ate = 0.789856
DEBUG    batch ate = 0.726278
DEBUG    batch ate = 0.165073
DEBUG    batch ate = 0.530907
DEBUG    batch ate = 0.602567
DEBUG    batch ate = 0.682041
DEBUG    batch ate = 0.54427
DEBUG    batch ate = 0.787318
DEBUG    batch ate = 0.491623
DEBUG    batch ate = 0.794449
DEBUG    batch ate = 0.928849
DEBUG    batch ate = 0.771662
DEBUG    batch ate = 0.722534
DEBUG    batch ate = 0.611424
DEBUG    batch ate = 0.754558
DEBUG    batch ate = 0.466829
DEBUG    batch ate = 0.623566
DEBUG    batch ate = 0.595247
DEBUG    batch ate = 0.790067
DEBUG    batch ate = 0.218814
DEBUG    batch ate = 0.551078
DEBUG    batch ate = 0.561368
DEBUG    batch ate = 0.823733
DEBUG    batch ate = 0.725582
DEBUG    batch ate = 0.685417
DEBUG    batch ate = 0.573616
DEBUG    batch ate = 0.408314
DEBUG    batch ate = 0.420605
DEBUG    batch ate = 0.699393
DEBUG    batch ate = 0.485361
DEBUG    batch ate = 0.470607
DEBUG    batch ate = 0.672379
DEBUG    batch ate = 0.515571
DEBUG    batch ate = 0.837184
DEBUG    batch ate = 0.383294
DEBUG    batch ate = 0.631237
DEBUG    batch ate = 0.660588
DEBUG    batch ate = 0.454409
DEBUG    batch ate = 0.277474
DEBUG    batch ate = 1.08705
DEBUG    batch ate = 0.542072
DEBUG    batch ate = 0.667987
DEBUG    batch ate = 0.474515
DEBUG    batch ate = 0.462981
DEBUG    batch ate = 0.581607
DEBUG    batch ate = 0.539565
DEBUG    batch ate = 0.740687
DEBUG    batch ate = 0.672987
DEBUG    batch ate = 0.725537
DEBUG    batch ate = 0.683099
DEBUG    batch ate = 0.695347
DEBUG    batch ate = 0.533302
DEBUG    batch ate = 0.625668
DEBUG    batch ate = 0.744886
DEBUG    batch ate = 0.686994
DEBUG    batch ate = 0.572683
DEBUG    batch ate = 0.431316
DEBUG    batch ate = 0.521101
DEBUG    batch ate = 0.651604
DEBUG    batch ate = 0.514384
DEBUG    batch ate = 0.471155
DEBUG    batch ate = 0.759972
DEBUG    batch ate = 0.633456
DEBUG    batch ate = 0.52144
DEBUG    batch ate = 0.675739
DEBUG    batch ate = 0.713319
DEBUG    batch ate = 0.749301
DEBUG    batch ate = 0.637229
DEBUG    batch ate = 0.690767
DEBUG    batch ate = 0.638464
DEBUG    batch ate = 0.804409
DEBUG    batch ate = 0.379763
DEBUG    batch ate = 0.939645
DEBUG    batch ate = 0.566416
DEBUG    batch ate = 0.722778
DEBUG    batch ate = 0.875249
DEBUG    batch ate = 0.585553
DEBUG    batch ate = 0.452997
DEBUG    batch ate = 0.660046
DEBUG    batch ate = 0.523958
DEBUG    batch ate = 0.743689
DEBUG    batch ate = 0.281901
DEBUG    batch ate = 0.79823
DEBUG    batch ate = 0.501476
DEBUG    batch ate = 0.27024
DEBUG    batch ate = 0.661638
DEBUG    batch ate = 0.530568
DEBUG    batch ate = 0.276738
DEBUG    batch ate = 0.734873
DEBUG    batch ate = 0.547245
DEBUG    batch ate = 0.642462
DEBUG    batch ate = 0.69965
DEBUG    batch ate = 0.544179
DEBUG    batch ate = 0.501292
DEBUG    batch ate = 0.782594
DEBUG    batch ate = 0.718873
DEBUG    batch ate = 0.53492
DEBUG    batch ate = 0.602767
DEBUG    batch ate = 0.642604
DEBUG    batch ate = 0.899802
DEBUG    batch ate = 0.345271
DEBUG    batch ate = 0.408736
DEBUG    batch ate = 0.503462
DEBUG    batch ate = 0.548023
DEBUG    batch ate = 0.869944
DEBUG    batch ate = 0.712165
DEBUG    batch ate = 0.840788
DEBUG    batch ate = 0.802797
DEBUG    batch ate = 0.448752
DEBUG    batch ate = 0.489339
DEBUG    batch ate = 0.760921
DEBUG    batch ate = 0.549896
DEBUG    batch ate = 0.337833
DEBUG    batch ate = 0.489319
DEBUG    batch ate = 0.349298
DEBUG    batch ate = 0.0851573
DEBUG    batch ate = 0.701312
DEBUG    batch ate = 0.426929
DEBUG    batch ate = 0.52591
DEBUG    batch ate = 0.45672
DEBUG    batch ate = 0.691007
DEBUG    batch ate = 0.681652
DEBUG    batch ate = 0.414373
DEBUG    batch ate = 0.43001
DEBUG    batch ate = 0.698964
DEBUG    batch ate = 0.569967
DEBUG    batch ate = 0.670148
DEBUG    batch ate = 0.612077
DEBUG    batch ate = 0.559155
DEBUG    batch ate = 0.839547
DEBUG    batch ate = 0.704653
DEBUG    batch ate = 0.44604
DEBUG    batch ate = 0.608618
DEBUG    batch ate = 0.744417
DEBUG    batch ate = 0.340019
DEBUG    batch ate = 0.469705
DEBUG    batch ate = 0.859227
DEBUG    batch ate = 0.732652
DEBUG    batch ate = 0.624253
DEBUG    batch ate = 0.767217
DEBUG    batch ate = 0.431167
DEBUG    batch ate = 0.712165
DEBUG    batch ate = 0.576947
DEBUG    batch ate = 0.546332
DEBUG    batch ate = 0.52999
DEBUG    batch ate = 0.349895
DEBUG    batch ate = 0.625377
DEBUG    batch ate = 0.564784
DEBUG    batch ate = 0.827983
DEBUG    batch ate = 0.402039
DEBUG    batch ate = 0.732634
DEBUG    batch ate = 0.828913
DEBUG    batch ate = 0.580144
DEBUG    batch ate = 0.568022
DEBUG    batch ate = 0.561761
DEBUG    batch ate = 0.294596
DEBUG    batch ate = 0.636919
DEBUG    batch ate = 0.655477
DEBUG    batch ate = 0.925995
DEBUG    batch ate = 0.729636
DEBUG    batch ate = 0.550091
DEBUG    batch ate = 0.558647
DEBUG    batch ate = 0.673149
DEBUG    batch ate = 0.657379
DEBUG    batch ate = 0.553136
DEBUG    batch ate = 0.784905
DEBUG    batch ate = 0.72343
DEBUG    batch ate = 0.872444
DEBUG    batch ate = 0.594647
DEBUG    batch ate = 0.815522
DEBUG    batch ate = 0.882869
DEBUG    batch ate = 0.505135
DEBUG    batch ate = 0.608259
DEBUG    batch ate = 0.438947
DEBUG    batch ate = 0.642148
DEBUG    batch ate = 0.42703
DEBUG    batch ate = 0.492255
DEBUG    batch ate = 1.01806
DEBUG    batch ate = 0.488789
DEBUG    batch ate = 0.353427
DEBUG    batch ate = 0.697426
DEBUG    batch ate = 0.454108
DEBUG    batch ate = 0.585995
DEBUG    batch ate = 0.898554
DEBUG    batch ate = 0.462355
DEBUG    batch ate = 0.847193
DEBUG    batch ate = 0.435861
DEBUG    batch ate = 0.350475
DEBUG    batch ate = 0.494122
DEBUG    batch ate = 0.641375
DEBUG    batch ate = 1.05287
DEBUG    batch ate = 0.560613
DEBUG    batch ate = 0.622122
DEBUG    batch ate = 0.617646
DEBUG    batch ate = 0.438831
DEBUG    batch ate = 0.413241
DEBUG    batch ate = 0.709999
DEBUG    batch ate = 0.393058
DEBUG    batch ate = 0.577082
DEBUG    batch ate = 0.449773
DEBUG    batch ate = 0.409307
DEBUG    batch ate = 0.717688
DEBUG    batch ate = 0.680811
DEBUG    batch ate = 0.636654
DEBUG    batch ate = 0.537257
DEBUG    batch ate = 0.485248
DEBUG    batch ate = 0.611201
DEBUG    batch ate = 0.66029
DEBUG    batch ate = 0.621785
DEBUG    batch ate = 0.656557
DEBUG    batch ate = 0.50069
DEBUG    batch ate = 0.531677
DEBUG    batch ate = 0.539529
DEBUG    batch ate = 0.7621
DEBUG    batch ate = 0.34175
DEBUG    batch ate = 0.573927
DEBUG    batch ate = 0.698847
DEBUG    batch ate = 0.687271
DEBUG    batch ate = 0.625974
DEBUG    batch ate = 0.623745
DEBUG    batch ate = 0.542737
DEBUG    batch ate = 0.203161
DEBUG    batch ate = 0.656258
DEBUG    batch ate = 0.20316
DEBUG    batch ate = 0.333921
DEBUG    batch ate = 0.503528
DEBUG    batch ate = 0.274319
DEBUG    batch ate = 0.435086
DEBUG    batch ate = 0.577274
DEBUG    batch ate = 0.404617
DEBUG    batch ate = 0.488066
DEBUG    batch ate = 0.804592
DEBUG    batch ate = 0.731865
DEBUG    batch ate = 0.751529
DEBUG    batch ate = 0.847831
DEBUG    batch ate = 0.737108
DEBUG    batch ate = 0.403549
DEBUG    batch ate = 0.659598
DEBUG    batch ate = 0.777456
DEBUG    batch ate = 0.655091
DEBUG    batch ate = 0.805262
DEBUG    batch ate = 0.578173
DEBUG    batch ate = 0.749979
DEBUG    batch ate = 0.645467
DEBUG    batch ate = 0.765642
DEBUG    batch ate = 0.221318
DEBUG    batch ate = 0.566684
DEBUG    batch ate = 0.885021
DEBUG    batch ate = 0.798495
DEBUG    batch ate = 0.749958
DEBUG    batch ate = 0.404101
DEBUG    batch ate = 0.597844
DEBUG    batch ate = 0.548862
DEBUG    batch ate = 0.633423
DEBUG    batch ate = 0.58442
DEBUG    batch ate = 0.406284
DEBUG    batch ate = 0.497425
DEBUG    batch ate = 0.64323
DEBUG    batch ate = 0.764823
DEBUG    batch ate = 0.719326
DEBUG    batch ate = 0.850669
DEBUG    batch ate = 0.567251
DEBUG    batch ate = 0.531746
DEBUG    batch ate = 0.422011
DEBUG    batch ate = 0.469137
DEBUG    batch ate = 0.568481
DEBUG    batch ate = 0.336506
DEBUG    batch ate = 0.785506
DEBUG    batch ate = 0.771601
DEBUG    batch ate = 0.790584
DEBUG    batch ate = 0.756722
DEBUG    batch ate = 0.558484
DEBUG    batch ate = 0.565823
DEBUG    batch ate = 0.85092
DEBUG    batch ate = 0.836311
DEBUG    batch ate = 0.36647
DEBUG    batch ate = 0.671067
DEBUG    batch ate = 0.678834
DEBUG    batch ate = 0.7427
DEBUG    batch ate = 0.380171
DEBUG    batch ate = 0.702751
DEBUG    batch ate = 0.821684
DEBUG    batch ate = 0.183044
DEBUG    batch ate = 0.71705
DEBUG    batch ate = 0.650429
DEBUG    batch ate = 0.647615
DEBUG    batch ate = 0.590948
DEBUG    batch ate = 0.32329
DEBUG    batch ate = 0.8901
DEBUG    batch ate = 0.56427
DEBUG    batch ate = 0.335077
DEBUG    batch ate = 0.777793
DEBUG    batch ate = 0.669449
DEBUG    batch ate = 0.794569
DEBUG    batch ate = 0.455826
DEBUG    batch ate = 0.237244
DEBUG    batch ate = 0.449816
DEBUG    batch ate = 0.544514
DEBUG    batch ate = 0.426984
DEBUG    batch ate = 0.440946
DEBUG    batch ate = 0.331075
DEBUG    batch ate = 0.486034
DEBUG    batch ate = 0.518074
DEBUG    batch ate = 0.508189
DEBUG    batch ate = 0.7412
DEBUG    batch ate = 0.744264
DEBUG    batch ate = 0.23702
DEBUG    batch ate = 0.724052
DEBUG    batch ate = 0.26753
DEBUG    batch ate = 0.45962
DEBUG    batch ate = 0.447174
DEBUG    batch ate = 0.615098
DEBUG    batch ate = 0.665408
DEBUG    batch ate = 0.227405
DEBUG    batch ate = 0.567846
DEBUG    batch ate = 0.642301
DEBUG    batch ate = 0.572763
DEBUG    batch ate = 0.492713
DEBUG    batch ate = 0.495091
DEBUG    batch ate = 0.387373
DEBUG    batch ate = 0.536913
DEBUG    batch ate = 0.70732
DEBUG    batch ate = 0.57493
DEBUG    batch ate = 0.575226
DEBUG    batch ate = 0.820646
DEBUG    batch ate = 0.299924
DEBUG    batch ate = 0.521718
DEBUG    batch ate = 0.201825
DEBUG    batch ate = 0.575455
DEBUG    batch ate = 0.34346
DEBUG    batch ate = 0.511799
DEBUG    batch ate = 0.577593
DEBUG    batch ate = 0.606313
DEBUG    batch ate = 0.479831
DEBUG    batch ate = 0.430969
DEBUG    batch ate = 0.68106
DEBUG    batch ate = 0.393857
DEBUG    batch ate = 0.592259
DEBUG    batch ate = 0.904887
DEBUG    batch ate = 1.1646
DEBUG    batch ate = 0.462751
DEBUG    batch ate = 0.849577
DEBUG    batch ate = 0.675505
DEBUG    batch ate = 0.655771
DEBUG    batch ate = 0.433719
INFO     Evaluating 135 minibatches
DEBUG    batch ate = 0.228577
DEBUG    batch ate = 0.602583
DEBUG    batch ate = 0.802412
DEBUG    batch ate = 0.445214
DEBUG    batch ate = 0.569569
DEBUG    batch ate = 0.816098
DEBUG    batch ate = 0.799774
DEBUG    batch ate = 0.580379
DEBUG    batch ate = 0.705277
DEBUG    batch ate = 0.472644
DEBUG    batch ate = 0.425481
DEBUG    batch ate = 0.529719
DEBUG    batch ate = 1.03265
DEBUG    batch ate = 0.702212
DEBUG    batch ate = 0.716867
DEBUG    batch ate = 0.732634
DEBUG    batch ate = 0.479447
DEBUG    batch ate = 0.751748
DEBUG    batch ate = 0.372753
DEBUG    batch ate = 0.743915
DEBUG    batch ate = 0.695771
DEBUG    batch ate = 0.486699
DEBUG    batch ate = 0.617069
DEBUG    batch ate = 0.924266
DEBUG    batch ate = 0.41445
DEBUG    batch ate = 0.51611
DEBUG    batch ate = 0.570871
DEBUG    batch ate = 0.52222
DEBUG    batch ate = 0.550225
DEBUG    batch ate = 0.827474
DEBUG    batch ate = 0.660622
DEBUG    batch ate = 0.435264
DEBUG    batch ate = 0.252852
DEBUG    batch ate = 0.521581
DEBUG    batch ate = 0.620552
DEBUG    batch ate = 0.46738
DEBUG    batch ate = 0.469133
DEBUG    batch ate = 0.769782
DEBUG    batch ate = 0.641767
DEBUG    batch ate = 0.61662
DEBUG    batch ate = 0.497127
DEBUG    batch ate = 0.541457
DEBUG    batch ate = 0.950244
DEBUG    batch ate = 0.475156
DEBUG    batch ate = 0.752711
DEBUG    batch ate = 0.301103
DEBUG    batch ate = 0.843295
DEBUG    batch ate = 0.374278
DEBUG    batch ate = 0.686422
DEBUG    batch ate = 0.558687
DEBUG    batch ate = 0.66816
DEBUG    batch ate = 0.756011
DEBUG    batch ate = 0.268842
DEBUG    batch ate = 0.467443
DEBUG    batch ate = 0.7511
DEBUG    batch ate = 0.644642
DEBUG    batch ate = 0.763036
DEBUG    batch ate = 0.590393
DEBUG    batch ate = 0.693136
DEBUG    batch ate = 0.486587
DEBUG    batch ate = 0.604928
DEBUG    batch ate = 0.711657
DEBUG    batch ate = 0.606803
DEBUG    batch ate = 0.514715
DEBUG    batch ate = 0.755621
DEBUG    batch ate = 0.563381
DEBUG    batch ate = 0.658584
DEBUG    batch ate = 0.309254
DEBUG    batch ate = 0.186426
DEBUG    batch ate = 0.642211
DEBUG    batch ate = 0.726449
DEBUG    batch ate = 0.609017
DEBUG    batch ate = 0.693574
DEBUG    batch ate = 0.619707
DEBUG    batch ate = 0.711907
DEBUG    batch ate = 0.763202
DEBUG    batch ate = 0.583925
DEBUG    batch ate = 0.732382
DEBUG    batch ate = 0.598957
DEBUG    batch ate = 0.61077
DEBUG    batch ate = 0.407628
DEBUG    batch ate = 0.813409
DEBUG    batch ate = 0.879196
DEBUG    batch ate = 0.59526
DEBUG    batch ate = 0.597031
DEBUG    batch ate = 0.404295
DEBUG    batch ate = 0.444806
DEBUG    batch ate = 0.976863
DEBUG    batch ate = 0.191305
DEBUG    batch ate = 0.55377
DEBUG    batch ate = 1.03828
DEBUG    batch ate = 0.478516
DEBUG    batch ate = 0.925168
DEBUG    batch ate = 0.605732
DEBUG    batch ate = 0.321156
DEBUG    batch ate = 0.47538
DEBUG    batch ate = 0.750148
DEBUG    batch ate = 0.468002
DEBUG    batch ate = 0.483354
DEBUG    batch ate = 0.727932
DEBUG    batch ate = 0.499526
DEBUG    batch ate = 0.505064
DEBUG    batch ate = 1.03597
DEBUG    batch ate = 0.528672
DEBUG    batch ate = 0.713761
DEBUG    batch ate = 0.657063
DEBUG    batch ate = 0.677198
DEBUG    batch ate = 0.761366
DEBUG    batch ate = 0.569046
DEBUG    batch ate = 0.806944
DEBUG    batch ate = 0.512402
DEBUG    batch ate = 0.638473
DEBUG    batch ate = 0.594415
DEBUG    batch ate = 0.662585
DEBUG    batch ate = 0.815776
DEBUG    batch ate = 0.547243
DEBUG    batch ate = 0.446772
DEBUG    batch ate = 0.609724
DEBUG    batch ate = 0.672535
DEBUG    batch ate = 0.294262
DEBUG    batch ate = 0.650225
DEBUG    batch ate = 0.437027
DEBUG    batch ate = 0.395884
DEBUG    batch ate = 0.457884
DEBUG    batch ate = 0.381654
DEBUG    batch ate = 0.474322
DEBUG    batch ate = 0.636114
DEBUG    batch ate = 0.433205
DEBUG    batch ate = 0.340026
DEBUG    batch ate = 0.631428
DEBUG    batch ate = 0.465448
DEBUG    batch ate = 0.438805
DEBUG    batch ate = 0.50323
DEBUG    batch ate = 0.522954
DEBUG    batch ate = 0.58916
[11]:
# CEVAE's average treatment effect on each split is the mean of its ITE estimates.
ate_train, ate_val = ite_train.mean(), ite_val.mean()
print(ate_train, ate_val)
0.58953923 0.5956359

Meta Learners

[12]:
# Fit the propensity model on the training split only, then score the
# validation split with that same fitted model. Refitting on the validation
# rows would use their treatment assignments and produce propensities from a
# different model than the one used for training.
p_model = ElasticNetPropensityModel()
p_train = p_model.fit_predict(X_train, treatment_train)
p_val = p_model.predict(X_val)
[13]:
# Fit the four meta-learners (S, T, X, R) with LightGBM base models on the
# training split. For each learner: estimate_ate yields the ATE point estimate,
# fit_predict yields per-sample ITEs on the training data, and predict scores
# the validation split using the model state left by fit_predict.
# NOTE(review): the S-learner indexes estimate_ate(...) with [0] while T/X/R use
# [0][0] — presumably because BaseSRegressor.estimate_ate returns only the ATE
# array by default whereas the others return an (ate, lb, ub) tuple; confirm
# against the installed causalml version before harmonizing these.
s_learner = BaseSRegressor(LGBMRegressor())
s_ate = s_learner.estimate_ate(X_train, treatment_train, y_train)[0]
s_ite_train = s_learner.fit_predict(X_train, treatment_train, y_train)
s_ite_val = s_learner.predict(X_val)

# treatment/outcome are also passed to predict here; presumably only used for
# diagnostic logging of the validation fit — TODO confirm
t_learner = BaseTRegressor(LGBMRegressor())
t_ate = t_learner.estimate_ate(X_train, treatment_train, y_train)[0][0]
t_ite_train = t_learner.fit_predict(X_train, treatment_train, y_train)
t_ite_val = t_learner.predict(X_val, treatment_val, y_val)

# The X- and R-learners additionally consume the propensity scores p_train / p_val.
x_learner = BaseXRegressor(LGBMRegressor())
x_ate = x_learner.estimate_ate(X_train, treatment_train, y_train, p_train)[0][0]
x_ite_train = x_learner.fit_predict(X_train, treatment_train, y_train, p_train)
x_ite_val = x_learner.predict(X_val, treatment_val, y_val, p_val)

r_learner = BaseRRegressor(LGBMRegressor())
r_ate = r_learner.estimate_ate(X_train, treatment_train, y_train, p_train)[0][0]
r_ite_train = r_learner.fit_predict(X_train, treatment_train, y_train, p_train)
r_ite_val = r_learner.predict(X_val)

Model Results Comparison

Training

[14]:
# One row per training sample: each model's ITE estimate, followed by the true
# effect (tau), the treatment indicator (w) and the observed outcome (y).
# Every input is flattened to 1-D before being stacked and transposed.
df_preds_train = pd.DataFrame(
    [arr.ravel() for arr in (s_ite_train, t_ite_train, x_ite_train, r_ite_train,
                             ite_train, tau_train, treatment_train, y_train)],
    index=['S', 'T', 'X', 'R', 'CEVAE', 'tau', 'w', 'y'],
).T

df_cumgain_train = get_cumgain(df_preds_train)
[15]:
# Training-split summary: ATE per model next to the true average effect,
# MAE of the per-sample ITE estimates against the true effect tau, and AUUC.
df_result_train = pd.DataFrame([s_ate, t_ate, x_ate, r_ate, ate_train, tau_train.mean()],
                               index=['S','T','X','R','CEVAE','actual'], columns=['ATE'])
# sklearn's signature is mean_absolute_error(y_true, y_pred); the trailing None
# fills the 'actual' row, which has no MAE.
df_result_train['MAE'] = [mean_absolute_error(tau_train.values.reshape(-1, 1), ite_pred)
                          for ite_pred in (s_ite_train, t_ite_train, x_ite_train, r_ite_train, ite_train)
                          ] + [None]
df_result_train['AUUC'] = auuc_score(df_preds_train)
[16]:
# Display the training-split comparison table (ATE, MAE, AUUC per model).
df_result_train
[16]:
ATE MAE AUUC
S 4.690540 4.581416 0.684130
T 4.708557 4.715296 0.684878
X 4.555315 4.549527 0.671956
R 0.714936 5.991034 0.586835
CEVAE 0.589539 6.238858 0.566627
actual 4.755900 NaN NaN
[17]:
# Gain curves for every model column on the training split.
plot_gain(df_preds_train)
../_images/examples_cevae_example_22_0.png

Validation

[18]:
# One row per validation sample: each model's ITE estimate, followed by the true
# effect (tau), the treatment indicator (w) and the observed outcome (y).
# Every input is flattened to 1-D before being stacked and transposed.
df_preds_val = pd.DataFrame(
    [arr.ravel() for arr in (s_ite_val, t_ite_val, x_ite_val, r_ite_val,
                             ite_val, tau_val, treatment_val, y_val)],
    index=['S', 'T', 'X', 'R', 'CEVAE', 'tau', 'w', 'y'],
).T

df_cumgain_val = get_cumgain(df_preds_val)
[19]:
# Validation-split summary: ATE per model (mean of its validation ITEs) next to
# the true average effect, MAE of the ITE estimates against tau, and AUUC.
df_result_val = pd.DataFrame([s_ite_val.mean(), t_ite_val.mean(), x_ite_val.mean(), r_ite_val.mean(), ate_val, tau_val.mean()],
                              index=['S','T','X','R','CEVAE','actual'], columns=['ATE'])
# sklearn's signature is mean_absolute_error(y_true, y_pred); the trailing None
# fills the 'actual' row, which has no MAE.
df_result_val['MAE'] = [mean_absolute_error(tau_val.values.reshape(-1, 1), ite_pred)
                        for ite_pred in (s_ite_val, t_ite_val, x_ite_val, r_ite_val, ite_val)
                        ] + [None]
df_result_val['AUUC'] = auuc_score(df_preds_val)
[20]:
# Display the validation-split comparison table (ATE, MAE, AUUC per model).
df_result_val
[20]:
ATE MAE AUUC
S 4.690676 4.582191 0.683782
T 4.709923 4.717909 0.684032
X 4.560680 4.544644 0.671907
R 0.761550 5.997526 0.586110
CEVAE 0.595636 6.241192 0.566356
actual 4.774991 NaN NaN
[21]:
# Gain curves for every model column on the validation split.
plot_gain(df_preds_val)
../_images/examples_cevae_example_27_0.png

Synthetic Data

[23]:
# Simulate data with a hidden confounder and split it 80/20 into train/validation.
y, X, w, tau, b, e = simulate_hidden_confounder(n=100000, p=5, sigma=1.0, adj=0.)

X_train, X_val, y_train, y_val, w_train, w_val, tau_train, tau_val, b_train, b_val, e_train, e_val = \
    train_test_split(X, y, w, tau, b, e, test_size=0.2, random_state=123, shuffle=True)

# Prediction registries: true effects first, then the raw generated data,
# then one entry per fitted model.
preds_dict_train = {'Actuals': tau_train}
preds_dict_valid = {'Actuals': tau_val}

preds_dict_train['generated_data'] = {
    'y': y_train,
    'X': X_train,
    'w': w_train,
    'tau': tau_train,
    'b': b_train,
    'e': e_train}
preds_dict_valid['generated_data'] = {
    'y': y_val,
    'X': X_val,
    'w': w_val,
    'tau': tau_val,
    'b': b_val,
    'e': e_val}

# Predict p_hat because e would not be directly observed in real life.
# Fit on the training split only and reuse the fitted model for validation;
# refitting on validation rows would use their treatment assignments.
p_model = ElasticNetPropensityModel()
p_hat_train = p_model.fit_predict(X_train, w_train)
p_hat_val = p_model.predict(X_val)

for base_learner, label_l in zip([BaseSRegressor, BaseTRegressor, BaseXRegressor, BaseRRegressor],
                                 ['S', 'T', 'X', 'R']):
    for model, label_m in zip([LinearRegression, XGBRegressor], ['LR', 'XGB']):
        key = '{} Learner ({})'.format(label_l, label_m)
        learner = base_learner(model())
        # Dispatch explicitly on the learner type rather than probing predict()
        # with `except TypeError`, which would also hide genuine TypeErrors
        # raised inside a learner. All learners are fit on training data only.
        if label_l == 'R':
            # The R-learner consumes p_hat at fit time; predict takes X only.
            learner.fit(X=X_train, p=p_hat_train, treatment=w_train, y=y_train)
            preds_dict_train[key] = learner.predict(X=X_train).flatten()
            preds_dict_valid[key] = learner.predict(X=X_val).flatten()
        elif label_l == 'X':
            # The X-learner takes a propensity; use the same p_hat at fit and predict.
            learner.fit(X=X_train, treatment=w_train, y=y_train, p=p_hat_train)
            preds_dict_train[key] = learner.predict(X=X_train, p=p_hat_train).flatten()
            preds_dict_valid[key] = learner.predict(X=X_val, p=p_hat_val).flatten()
        else:
            # S/T learners: presumably their predict() has no `p` argument (these
            # are the learners the previous TypeError fallback handled).
            learner.fit(X=X_train, treatment=w_train, y=y_train)
            preds_dict_train[key] = learner.predict(X=X_train, treatment=w_train, y=y_train).flatten()
            preds_dict_valid[key] = learner.predict(X=X_val, treatment=w_val, y=y_val).flatten()

# cevae model settings
outcome_dist = "normal"
latent_dim = 20
hidden_dim = 200
num_epochs = 5
batch_size = 1000
learning_rate = 1e-3
learning_rate_decay = 0.1
num_layers = 3
num_samples = 10

cevae = CEVAE(outcome_dist=outcome_dist,
              latent_dim=latent_dim,
              hidden_dim=hidden_dim,
              num_epochs=num_epochs,
              batch_size=batch_size,
              learning_rate=learning_rate,
              learning_rate_decay=learning_rate_decay,
              num_layers=num_layers,
              num_samples=num_samples)

# fit (CEVAE training takes float tensors)
losses = cevae.fit(X=torch.tensor(X_train, dtype=torch.float),
                   treatment=torch.tensor(w_train, dtype=torch.float),
                   y=torch.tensor(y_train, dtype=torch.float))

preds_dict_train['CEVAE'] = cevae.predict(X_train).flatten()
preds_dict_valid['CEVAE'] = cevae.predict(X_val).flatten()
INFO     Training with 80 minibatches per epoch
DEBUG    step     0 loss = 14.0534
DEBUG    step     1 loss = 13.2864
DEBUG    step     2 loss = 13.0712
DEBUG    step     3 loss = 12.4646
DEBUG    step     4 loss = 12.0247
DEBUG    step     5 loss = 11.5239
DEBUG    step     6 loss = 11.2934
DEBUG    step     7 loss = 11.3141
DEBUG    step     8 loss = 10.8347
DEBUG    step     9 loss = 10.7364
DEBUG    step    10 loss = 10.5978
DEBUG    step    11 loss = 10.2533
DEBUG    step    12 loss = 10.131
DEBUG    step    13 loss = 10.0307
DEBUG    step    14 loss = 9.57977
DEBUG    step    15 loss = 9.79295
DEBUG    step    16 loss = 9.46927
DEBUG    step    17 loss = 9.57581
DEBUG    step    18 loss = 9.24119
DEBUG    step    19 loss = 9.34084
DEBUG    step    20 loss = 9.32529
DEBUG    step    21 loss = 9.40313
DEBUG    step    22 loss = 9.27057
DEBUG    step    23 loss = 9.05239
DEBUG    step    24 loss = 9.17952
DEBUG    step    25 loss = 8.93083
DEBUG    step    26 loss = 8.88059
DEBUG    step    27 loss = 9.06328
DEBUG    step    28 loss = 8.97881
DEBUG    step    29 loss = 8.7639
DEBUG    step    30 loss = 8.80499
DEBUG    step    31 loss = 8.87173
DEBUG    step    32 loss = 8.56747
DEBUG    step    33 loss = 8.61066
DEBUG    step    34 loss = 8.79932
DEBUG    step    35 loss = 8.62871
DEBUG    step    36 loss = 8.54852
DEBUG    step    37 loss = 8.38022
DEBUG    step    38 loss = 8.31573
DEBUG    step    39 loss = 8.53857
DEBUG    step    40 loss = 8.57149
DEBUG    step    41 loss = 8.25793
DEBUG    step    42 loss = 8.54684
DEBUG    step    43 loss = 8.47699
DEBUG    step    44 loss = 8.3233
DEBUG    step    45 loss = 8.40228
DEBUG    step    46 loss = 8.14949
DEBUG    step    47 loss = 8.2015
DEBUG    step    48 loss = 8.07472
DEBUG    step    49 loss = 8.16795
DEBUG    step    50 loss = 8.34108
DEBUG    step    51 loss = 8.57682
DEBUG    step    52 loss = 8.24426
DEBUG    step    53 loss = 8.33251
DEBUG    step    54 loss = 8.10115
DEBUG    step    55 loss = 8.67902
DEBUG    step    56 loss = 8.14677
DEBUG    step    57 loss = 8.1041
DEBUG    step    58 loss = 8.15102
DEBUG    step    59 loss = 8.00679
DEBUG    step    60 loss = 8.0271
DEBUG    step    61 loss = 7.96041
DEBUG    step    62 loss = 7.82294
DEBUG    step    63 loss = 8.13456
DEBUG    step    64 loss = 8.23367
DEBUG    step    65 loss = 8.1886
DEBUG    step    66 loss = 8.11654
DEBUG    step    67 loss = 8.22645
DEBUG    step    68 loss = 8.29743
DEBUG    step    69 loss = 8.24127
DEBUG    step    70 loss = 7.86166
DEBUG    step    71 loss = 8.22115
DEBUG    step    72 loss = 7.8913
DEBUG    step    73 loss = 7.96265
DEBUG    step    74 loss = 7.96243
DEBUG    step    75 loss = 7.99336
DEBUG    step    76 loss = 7.97742
DEBUG    step    77 loss = 7.90728
DEBUG    step    78 loss = 7.79539
DEBUG    step    79 loss = 8.1732
DEBUG    step    80 loss = 8.05217
DEBUG    step    81 loss = 8.34642
DEBUG    step    82 loss = 8.03199
DEBUG    step    83 loss = 7.64226
DEBUG    step    84 loss = 7.60438
DEBUG    step    85 loss = 7.5962
DEBUG    step    86 loss = 7.85927
DEBUG    step    87 loss = 7.98567
DEBUG    step    88 loss = 7.82793
DEBUG    step    89 loss = 7.90716
DEBUG    step    90 loss = 7.71277
DEBUG    step    91 loss = 7.97724
DEBUG    step    92 loss = 7.84886
DEBUG    step    93 loss = 7.88323
DEBUG    step    94 loss = 7.58179
DEBUG    step    95 loss = 7.89912
DEBUG    step    96 loss = 7.67735
DEBUG    step    97 loss = 7.84808
DEBUG    step    98 loss = 7.66705
DEBUG    step    99 loss = 7.65615
DEBUG    step   100 loss = 7.73811
DEBUG    step   101 loss = 7.64997
DEBUG    step   102 loss = 8.36613
DEBUG    step   103 loss = 7.72687
DEBUG    step   104 loss = 7.68498
DEBUG    step   105 loss = 7.50849
DEBUG    step   106 loss = 7.63987
DEBUG    step   107 loss = 7.75501
DEBUG    step   108 loss = 7.62423
DEBUG    step   109 loss = 7.66921
DEBUG    step   110 loss = 7.50166
DEBUG    step   111 loss = 7.62314
DEBUG    step   112 loss = 7.80907
DEBUG    step   113 loss = 7.65659
DEBUG    step   114 loss = 7.55159
DEBUG    step   115 loss = 7.60577
DEBUG    step   116 loss = 7.36759
DEBUG    step   117 loss = 7.43037
DEBUG    step   118 loss = 7.41372
DEBUG    step   119 loss = 7.58245
DEBUG    step   120 loss = 7.75382
DEBUG    step   121 loss = 7.75345
DEBUG    step   122 loss = 7.71091
DEBUG    step   123 loss = 7.61762
DEBUG    step   124 loss = 7.5415
DEBUG    step   125 loss = 7.70995
DEBUG    step   126 loss = 7.43083
DEBUG    step   127 loss = 7.62284
DEBUG    step   128 loss = 7.57494
DEBUG    step   129 loss = 7.43229
DEBUG    step   130 loss = 7.417
DEBUG    step   131 loss = 7.36716
DEBUG    step   132 loss = 7.58527
DEBUG    step   133 loss = 7.61684
DEBUG    step   134 loss = 7.55247
DEBUG    step   135 loss = 7.54181
DEBUG    step   136 loss = 7.47493
DEBUG    step   137 loss = 7.65583
DEBUG    step   138 loss = 7.33769
DEBUG    step   139 loss = 7.36649
DEBUG    step   140 loss = 7.3634
DEBUG    step   141 loss = 7.50731
DEBUG    step   142 loss = 7.60657
DEBUG    step   143 loss = 7.38694
DEBUG    step   144 loss = 7.3596
DEBUG    step   145 loss = 7.42744
DEBUG    step   146 loss = 7.46609
DEBUG    step   147 loss = 7.44444
DEBUG    step   148 loss = 7.44656
DEBUG    step   149 loss = 7.32834
DEBUG    step   150 loss = 7.63049
DEBUG    step   151 loss = 7.43903
DEBUG    step   152 loss = 7.28372
DEBUG    step   153 loss = 7.28897
DEBUG    step   154 loss = 7.3515
DEBUG    step   155 loss = 7.29871
DEBUG    step   156 loss = 7.47948
DEBUG    step   157 loss = 7.56888
DEBUG    step   158 loss = 7.50302
DEBUG    step   159 loss = 7.14918
DEBUG    step   160 loss = 7.34611
DEBUG    step   161 loss = 7.04855
DEBUG    step   162 loss = 7.38615
DEBUG    step   163 loss = 7.39172
DEBUG    step   164 loss = 7.35778
DEBUG    step   165 loss = 7.39445
DEBUG    step   166 loss = 7.41489
DEBUG    step   167 loss = 7.36096
DEBUG    step   168 loss = 7.49107
DEBUG    step   169 loss = 7.31799
DEBUG    step   170 loss = 7.34851
DEBUG    step   171 loss = 7.17355
DEBUG    step   172 loss = 7.38851
DEBUG    step   173 loss = 7.35425
DEBUG    step   174 loss = 7.39068
DEBUG    step   175 loss = 7.08015
DEBUG    step   176 loss = 7.05245
DEBUG    step   177 loss = 7.43696
DEBUG    step   178 loss = 7.32325
DEBUG    step   179 loss = 7.31021
DEBUG    step   180 loss = 7.32132
DEBUG    step   181 loss = 7.34862
DEBUG    step   182 loss = 7.2863
DEBUG    step   183 loss = 7.04851
DEBUG    step   184 loss = 7.09608
DEBUG    step   185 loss = 7.30419
DEBUG    step   186 loss = 7.57377
DEBUG    step   187 loss = 7.17361
DEBUG    step   188 loss = 7.14099
DEBUG    step   189 loss = 7.0449
DEBUG    step   190 loss = 7.33529
DEBUG    step   191 loss = 8.26479
DEBUG    step   192 loss = 7.07407
DEBUG    step   193 loss = 7.17149
DEBUG    step   194 loss = 7.18364
DEBUG    step   195 loss = 7.27539
DEBUG    step   196 loss = 7.32838
DEBUG    step   197 loss = 7.26303
DEBUG    step   198 loss = 7.17846
DEBUG    step   199 loss = 7.43274
DEBUG    step   200 loss = 7.05834
DEBUG    step   201 loss = 7.06987
DEBUG    step   202 loss = 7.23815
DEBUG    step   203 loss = 7.2454
DEBUG    step   204 loss = 7.29509
DEBUG    step   205 loss = 7.13663
DEBUG    step   206 loss = 6.96725
DEBUG    step   207 loss = 7.11374
DEBUG    step   208 loss = 6.93604
DEBUG    step   209 loss = 7.14596
DEBUG    step   210 loss = 7.12832
DEBUG    step   211 loss = 7.16911
DEBUG    step   212 loss = 6.9426
DEBUG    step   213 loss = 7.18095
DEBUG    step   214 loss = 7.06178
DEBUG    step   215 loss = 7.10941
DEBUG    step   216 loss = 7.11186
DEBUG    step   217 loss = 7.20186
DEBUG    step   218 loss = 7.27586
DEBUG    step   219 loss = 7.1021
DEBUG    step   220 loss = 6.94478
DEBUG    step   221 loss = 7.09795
DEBUG    step   222 loss = 6.88571
DEBUG    step   223 loss = 7.03089
DEBUG    step   224 loss = 7.23866
DEBUG    step   225 loss = 7.10442
DEBUG    step   226 loss = 6.95982
DEBUG    step   227 loss = 8.71509
DEBUG    step   228 loss = 6.93005
DEBUG    step   229 loss = 7.2101
DEBUG    step   230 loss = 7.23326
DEBUG    step   231 loss = 6.94798
DEBUG    step   232 loss = 6.83511
DEBUG    step   233 loss = 6.99621
DEBUG    step   234 loss = 6.79696
DEBUG    step   235 loss = 7.21458
DEBUG    step   236 loss = 6.97841
DEBUG    step   237 loss = 7.12467
DEBUG    step   238 loss = 6.98927
DEBUG    step   239 loss = 7.13294
DEBUG    step   240 loss = 7.17033
DEBUG    step   241 loss = 7.09788
DEBUG    step   242 loss = 6.98868
DEBUG    step   243 loss = 7.0711
DEBUG    step   244 loss = 7.10628
DEBUG    step   245 loss = 7.12893
DEBUG    step   246 loss = 6.94537
DEBUG    step   247 loss = 6.98222
DEBUG    step   248 loss = 7.12801
DEBUG    step   249 loss = 6.94684
DEBUG    step   250 loss = 7.01901
DEBUG    step   251 loss = 7.03228
DEBUG    step   252 loss = 7.14612
DEBUG    step   253 loss = 7.04241
DEBUG    step   254 loss = 6.92232
DEBUG    step   255 loss = 7.02093
DEBUG    step   256 loss = 6.98689
DEBUG    step   257 loss = 6.97682
DEBUG    step   258 loss = 6.99232
DEBUG    step   259 loss = 7.01528
DEBUG    step   260 loss = 6.86835
DEBUG    step   261 loss = 7.00633
DEBUG    step   262 loss = 7.06246
DEBUG    step   263 loss = 6.90189
DEBUG    step   264 loss = 7.07629
DEBUG    step   265 loss = 6.88559
DEBUG    step   266 loss = 6.92606
DEBUG    step   267 loss = 6.8929
DEBUG    step   268 loss = 6.83142
DEBUG    step   269 loss = 6.73955
DEBUG    step   270 loss = 6.81085
DEBUG    step   271 loss = 6.87084
DEBUG    step   272 loss = 6.88125
DEBUG    step   273 loss = 6.94562
DEBUG    step   274 loss = 6.9711
DEBUG    step   275 loss = 7.01001
DEBUG    step   276 loss = 6.91986
DEBUG    step   277 loss = 6.92239
DEBUG    step   278 loss = 6.70706
DEBUG    step   279 loss = 6.84017
DEBUG    step   280 loss = 7.09178
DEBUG    step   281 loss = 6.7313
DEBUG    step   282 loss = 6.79816
DEBUG    step   283 loss = 6.86953
DEBUG    step   284 loss = 6.92598
DEBUG    step   285 loss = 7.0731
DEBUG    step   286 loss = 6.91421
DEBUG    step   287 loss = 6.76945
DEBUG    step   288 loss = 6.74834
DEBUG    step   289 loss = 6.84824
DEBUG    step   290 loss = 6.88344
DEBUG    step   291 loss = 6.85244
DEBUG    step   292 loss = 6.922
DEBUG    step   293 loss = 9.57555
DEBUG    step   294 loss = 6.83098
DEBUG    step   295 loss = 7.43121
DEBUG    step   296 loss = 6.95061
DEBUG    step   297 loss = 6.79967
DEBUG    step   298 loss = 6.7929
DEBUG    step   299 loss = 6.7355
DEBUG    step   300 loss = 7.01345
DEBUG    step   301 loss = 6.83328
DEBUG    step   302 loss = 6.62454
DEBUG    step   303 loss = 6.84473
DEBUG    step   304 loss = 9.05065
DEBUG    step   305 loss = 7.038
DEBUG    step   306 loss = 6.60419
DEBUG    step   307 loss = 6.80575
DEBUG    step   308 loss = 6.73912
DEBUG    step   309 loss = 6.47463
DEBUG    step   310 loss = 6.84484
DEBUG    step   311 loss = 6.73429
DEBUG    step   312 loss = 6.89219
DEBUG    step   313 loss = 7.05905
DEBUG    step   314 loss = 6.82365
DEBUG    step   315 loss = 6.72354
DEBUG    step   316 loss = 6.54532
DEBUG    step   317 loss = 6.95339
DEBUG    step   318 loss = 7.0503
DEBUG    step   319 loss = 6.78209
DEBUG    step   320 loss = 6.59514
DEBUG    step   321 loss = 6.89779
DEBUG    step   322 loss = 6.72151
DEBUG    step   323 loss = 6.90015
DEBUG    step   324 loss = 7.00599
DEBUG    step   325 loss = 6.85437
DEBUG    step   326 loss = 6.89033
DEBUG    step   327 loss = 6.7871
DEBUG    step   328 loss = 6.8493
DEBUG    step   329 loss = 6.80922
DEBUG    step   330 loss = 6.96322
DEBUG    step   331 loss = 6.84506
DEBUG    step   332 loss = 6.87015
DEBUG    step   333 loss = 6.88979
DEBUG    step   334 loss = 6.64982
DEBUG    step   335 loss = 6.86292
DEBUG    step   336 loss = 6.92489
DEBUG    step   337 loss = 6.62396
DEBUG    step   338 loss = 6.84564
DEBUG    step   339 loss = 6.62305
DEBUG    step   340 loss = 7.36375
DEBUG    step   341 loss = 6.73599
DEBUG    step   342 loss = 6.80353
DEBUG    step   343 loss = 6.96371
DEBUG    step   344 loss = 6.89915
DEBUG    step   345 loss = 6.64238
DEBUG    step   346 loss = 6.51934
DEBUG    step   347 loss = 6.78445
DEBUG    step   348 loss = 6.94965
DEBUG    step   349 loss = 6.78796
DEBUG    step   350 loss = 6.77106
DEBUG    step   351 loss = 6.7466
DEBUG    step   352 loss = 6.77313
DEBUG    step   353 loss = 6.70463
DEBUG    step   354 loss = 6.96683
DEBUG    step   355 loss = 6.73415
DEBUG    step   356 loss = 6.73694
DEBUG    step   357 loss = 6.60738
DEBUG    step   358 loss = 9.84151
DEBUG    step   359 loss = 6.84548
DEBUG    step   360 loss = 6.57425
DEBUG    step   361 loss = 6.78442
DEBUG    step   362 loss = 6.68523
DEBUG    step   363 loss = 6.93113
DEBUG    step   364 loss = 9.26669
DEBUG    step   365 loss = 6.71749
DEBUG    step   366 loss = 6.60656
DEBUG    step   367 loss = 6.7795
DEBUG    step   368 loss = 6.55477
DEBUG    step   369 loss = 6.73777
DEBUG    step   370 loss = 6.80791
DEBUG    step   371 loss = 6.75802
DEBUG    step   372 loss = 6.80779
DEBUG    step   373 loss = 6.82983
DEBUG    step   374 loss = 6.5821
DEBUG    step   375 loss = 6.81309
DEBUG    step   376 loss = 6.58409
DEBUG    step   377 loss = 6.59094
DEBUG    step   378 loss = 6.59232
DEBUG    step   379 loss = 7.0035
DEBUG    step   380 loss = 6.65775
DEBUG    step   381 loss = 6.61621
DEBUG    step   382 loss = 6.6329
DEBUG    step   383 loss = 6.63025
DEBUG    step   384 loss = 6.61858
DEBUG    step   385 loss = 6.63814
DEBUG    step   386 loss = 6.50298
DEBUG    step   387 loss = 6.62591
DEBUG    step   388 loss = 6.56514
DEBUG    step   389 loss = 6.67944
DEBUG    step   390 loss = 6.80612
DEBUG    step   391 loss = 6.61369
DEBUG    step   392 loss = 6.85104
DEBUG    step   393 loss = 6.61612
DEBUG    step   394 loss = 6.55337
DEBUG    step   395 loss = 6.76919
DEBUG    step   396 loss = 6.66491
DEBUG    step   397 loss = 6.57224
DEBUG    step   398 loss = 6.54065
DEBUG    step   399 loss = 6.73794
INFO     Evaluating 80 minibatches
DEBUG    batch ate = 0.823513
DEBUG    batch ate = 0.824189
DEBUG    batch ate = 0.820978
DEBUG    batch ate = 0.822631
DEBUG    batch ate = 0.823555
DEBUG    batch ate = 0.822441
DEBUG    batch ate = 0.823683
DEBUG    batch ate = 0.822339
DEBUG    batch ate = 0.823964
DEBUG    batch ate = 0.823921
DEBUG    batch ate = 0.825266
DEBUG    batch ate = 0.822931
DEBUG    batch ate = 0.823049
DEBUG    batch ate = 0.824161
DEBUG    batch ate = 0.821918
DEBUG    batch ate = 0.824303
DEBUG    batch ate = 0.823845
DEBUG    batch ate = 0.822578
DEBUG    batch ate = 0.825122
DEBUG    batch ate = 0.823321
DEBUG    batch ate = 0.823198
DEBUG    batch ate = 0.823159
DEBUG    batch ate = 0.823571
DEBUG    batch ate = 0.822972
DEBUG    batch ate = 0.82311
DEBUG    batch ate = 0.821233
DEBUG    batch ate = 0.824326
DEBUG    batch ate = 0.823645
DEBUG    batch ate = 0.8233
DEBUG    batch ate = 0.821567
DEBUG    batch ate = 0.820404
DEBUG    batch ate = 0.821521
DEBUG    batch ate = 0.82027
DEBUG    batch ate = 0.824084
DEBUG    batch ate = 0.824593
DEBUG    batch ate = 0.823614
DEBUG    batch ate = 0.820698
DEBUG    batch ate = 0.824454
DEBUG    batch ate = 0.819246
DEBUG    batch ate = 0.823614
DEBUG    batch ate = 0.822471
DEBUG    batch ate = 0.822809
DEBUG    batch ate = 0.82155
DEBUG    batch ate = 0.822985
DEBUG    batch ate = 0.821966
DEBUG    batch ate = 0.822152
DEBUG    batch ate = 0.824818
DEBUG    batch ate = 0.821926
DEBUG    batch ate = 0.821183
DEBUG    batch ate = 0.821644
DEBUG    batch ate = 0.823652
DEBUG    batch ate = 0.822925
DEBUG    batch ate = 0.822612
DEBUG    batch ate = 0.824216
DEBUG    batch ate = 0.824456
DEBUG    batch ate = 0.822995
DEBUG    batch ate = 0.823972
DEBUG    batch ate = 0.821021
DEBUG    batch ate = 0.822201
DEBUG    batch ate = 0.821493
DEBUG    batch ate = 0.823859
DEBUG    batch ate = 0.819778
DEBUG    batch ate = 0.822789
DEBUG    batch ate = 0.825457
DEBUG    batch ate = 0.824181
DEBUG    batch ate = 0.821647
DEBUG    batch ate = 0.82509
DEBUG    batch ate = 0.821287
DEBUG    batch ate = 0.824007
DEBUG    batch ate = 0.821076
DEBUG    batch ate = 0.823777
DEBUG    batch ate = 0.822884
DEBUG    batch ate = 0.824057
DEBUG    batch ate = 0.820844
DEBUG    batch ate = 0.821426
DEBUG    batch ate = 0.82413
DEBUG    batch ate = 0.822516
DEBUG    batch ate = 0.823242
DEBUG    batch ate = 0.820823
DEBUG    batch ate = 0.822049
INFO     Evaluating 20 minibatches
DEBUG    batch ate = 0.823355
DEBUG    batch ate = 0.826493
DEBUG    batch ate = 0.825423
DEBUG    batch ate = 0.825241
DEBUG    batch ate = 0.823623
DEBUG    batch ate = 0.823627
DEBUG    batch ate = 0.821589
DEBUG    batch ate = 0.824463
DEBUG    batch ate = 0.821071
DEBUG    batch ate = 0.820596
DEBUG    batch ate = 0.823198
DEBUG    batch ate = 0.820816
DEBUG    batch ate = 0.823484
DEBUG    batch ate = 0.823282
DEBUG    batch ate = 0.825439
DEBUG    batch ate = 0.822407
DEBUG    batch ate = 0.822365
DEBUG    batch ate = 0.825534
DEBUG    batch ate = 0.822151
DEBUG    batch ate = 0.823306
[24]:
def _kl_divergence(preds, actuals):
    """Histogram-based KL divergence KL(preds || actuals).

    Both arrays are binned on a shared grid of 100 edges spanning the 0.1th to
    99.9th percentile of their pooled values (np.histogram ignores values
    outside that range). Each histogram is turned into proportions and clipped
    to [0.001, 0.999] so empty bins do not make the divergence infinite;
    scipy.stats.entropy renormalises the clipped vectors before computing
    sum(p * log(p / q)).
    """
    pooled = np.hstack((preds, actuals))
    bins = np.linspace(np.percentile(pooled, 0.1), np.percentile(pooled, 99.9), 100)

    pred_counts = np.histogram(preds, bins=bins)[0]
    true_counts = np.histogram(actuals, bins=bins)[0]
    pred_distr = np.clip(pred_counts / pred_counts.sum(), 0.001, 0.999)
    true_distr = np.clip(true_counts / true_counts.sum(), 0.001, 0.999)
    return entropy(pred_distr, true_distr)


def _summarize_predictions(preds_dict, actuals):
    """Build the per-model summary table for one data split.

    Args:
        preds_dict: mapping label -> 1-D prediction array; must contain an
            'Actuals' entry. Entries whose label contains 'generated' (the
            simulated 'tau'/'w'/'y' data) are skipped.
        actuals: the reference array every model is compared against.

    Returns:
        DataFrame indexed by label with columns 'ATE' (mean prediction),
        'MSE' (vs. actuals), 'Abs % Error of ATE' and 'KL Divergence'.
    """
    model_preds = {label: preds for label, preds in preds_dict.items()
                   if 'generated' not in label.lower()}

    summary = pd.DataFrame({label: [preds.mean(), mse(preds, actuals)]
                            for label, preds in model_preds.items()},
                           index=['ATE', 'MSE']).T
    # Despite the column name this is a fraction (0.1 == 10%), not a percentage.
    summary['Abs % Error of ATE'] = np.abs(
        (summary['ATE'] / summary.loc['Actuals', 'ATE']) - 1)
    summary['KL Divergence'] = [_kl_divergence(model_preds[label], actuals)
                                for label in summary.index]
    return summary


# The same computation applied to both splits (previously copy-pasted per split).
actuals_train = preds_dict_train['Actuals']
actuals_validation = preds_dict_valid['Actuals']

synthetic_summary_train = _summarize_predictions(preds_dict_train, actuals_train)
synthetic_summary_validation = _summarize_predictions(preds_dict_valid, actuals_validation)
[25]:
# Model columns, in the order used by the summary table and the gain plot.
uplift_model_labels = ['S Learner (LR)', 'S Learner (XGB)',
                       'T Learner (LR)', 'T Learner (XGB)',
                       'X Learner (LR)', 'X Learner (XGB)',
                       'R Learner (LR)', 'R Learner (XGB)',
                       'CEVAE']
generated_train = preds_dict_train['generated_data']

# Build the frame column-wise from one label list. The original stacked 12 rows
# of length n and transposed, which routes through an object-dtype intermediate
# with one column per sample and forces every label to be typed twice.
# The final cast keeps the all-float64 columns the transpose form produced.
df_preds_train = pd.DataFrame(
    {**{label: preds_dict_train[label].ravel() for label in uplift_model_labels},
     **{col: generated_train[col].ravel() for col in ['tau', 'w', 'y']}}
).astype('float64')

# auuc_score returns one score per column plus (presumably) a trailing baseline
# entry that .iloc[:-1] drops; assignment aligns on label, so 'Actuals' gets NaN.
synthetic_summary_train['AUUC'] = auuc_score(df_preds_train).iloc[:-1]
[26]:
# Model columns, in the order used by the summary table and the gain plot.
uplift_model_labels = ['S Learner (LR)', 'S Learner (XGB)',
                       'T Learner (LR)', 'T Learner (XGB)',
                       'X Learner (LR)', 'X Learner (XGB)',
                       'R Learner (LR)', 'R Learner (XGB)',
                       'CEVAE']
generated_valid = preds_dict_valid['generated_data']

# Build the frame column-wise from one label list. The original stacked 12 rows
# of length n and transposed, which routes through an object-dtype intermediate
# with one column per sample and forces every label to be typed twice.
# The final cast keeps the all-float64 columns the transpose form produced.
df_preds_validation = pd.DataFrame(
    {**{label: preds_dict_valid[label].ravel() for label in uplift_model_labels},
     **{col: generated_valid[col].ravel() for col in ['tau', 'w', 'y']}}
).astype('float64')

# auuc_score returns one score per column plus (presumably) a trailing baseline
# entry that .iloc[:-1] drops; assignment aligns on label, so 'Actuals' gets NaN.
synthetic_summary_validation['AUUC'] = auuc_score(df_preds_validation).iloc[:-1]
[27]:
# Training-split summary: one row per model plus 'Actuals', with the mean
# prediction (ATE), MSE against preds_dict_train['Actuals'], the fractional ATE
# error, the histogram KL divergence and AUUC (NaN for the 'Actuals' row).
synthetic_summary_train
[27]:
ATE MSE Abs % Error of ATE KL Divergence AUUC
Actuals 0.726115 0.000000 0.000000 0.000000 NaN
S Learner (LR) 0.832336 0.062462 0.146287 6.278413 0.499991
S Learner (XGB) 0.807743 0.039735 0.112417 2.551297 0.554885
T Learner (LR) 0.833364 0.059665 0.147703 3.312696 0.523272
T Learner (XGB) 0.803592 0.040524 0.106701 2.565715 0.553197
X Learner (LR) 0.833364 0.059665 0.147703 3.312696 0.523272
X Learner (XGB) 0.803349 0.038580 0.106367 2.500947 0.555391
R Learner (LR) 0.833845 0.060239 0.148365 3.511157 0.523214
R Learner (XGB) 0.735442 0.046848 0.012845 2.836128 0.539213
CEVAE 0.822853 0.058177 0.133227 3.157059 0.519150
[28]:
# Validation-split summary: same columns as the training table, computed
# against preds_dict_valid['Actuals'].
synthetic_summary_validation
[28]:
ATE MSE Abs % Error of ATE KL Divergence AUUC
Actuals 0.728371 0.000000 0.000000 0.000000 NaN
S Learner (LR) 0.832336 0.061983 0.142736 6.278413 0.499967
S Learner (XGB) 0.808844 0.040638 0.110483 2.548714 0.553011
T Learner (LR) 0.833805 0.059305 0.144753 3.316884 0.522972
T Learner (XGB) 0.803766 0.042424 0.103512 2.561688 0.549279
X Learner (LR) 0.833805 0.059305 0.144753 3.316884 0.522972
X Learner (XGB) 0.803530 0.039699 0.103187 2.489822 0.553039
R Learner (LR) 0.834179 0.059851 0.145266 3.512746 0.522887
R Learner (XGB) 0.736147 0.046685 0.010675 2.747596 0.536579
CEVAE 0.823373 0.057690 0.130430 3.152161 0.519573
[29]:
# Gain curves for the training predictions; the 'tau'/'w'/'y' columns are
# presumably picked up as effect/treatment/outcome via plot_gain's default
# column names — confirm against causalml.metrics.
plot_gain(df_preds_train)
../_images/examples_cevae_example_35_0.png
[30]:
# Gain curves for the validation predictions; same column layout as the
# training frame ('tau'/'w'/'y' presumably read via plot_gain's defaults).
plot_gain(df_preds_validation)
../_images/examples_cevae_example_36_0.png