Skip to content

Commit 6c95c65

Browse files
authored
Merge pull request #246 from StochasticTree/doc-vignette-hotfix
Update python demos and vignettes
2 parents 1b622e4 + 5bee405 commit 6c95c65

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

demo/debug/causal_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363

6464
# Run BCF
6565
bcf_model = BCFModel()
66-
bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=1000)
66+
bcf_model.sample(X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train, X_test = X_test, Z_test = Z_test, propensity_test = pi_test, num_gfr=10, num_mcmc=1000)
6767

6868
# Inspect the MCMC (BART) samples
6969
forest_preds_y_mcmc = bcf_model.y_hat_test

demo/debug/multivariate_treatment_causal_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,4 @@
4343

4444
# Run BCF
4545
bcf_model = BCFModel()
46-
bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100)
46+
bcf_model.sample(X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train, X_test = X_test, Z_test = Z_test, propensity_test = pi_test, num_gfr=10, num_mcmc=100)

demo/notebooks/multi_chain.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@
488488
"bcf_model.sample(\n",
489489
" X_train = X_train,\n",
490490
" Z_train = Z_train,\n",
491-
" pi_train = propensity_train,\n",
491+
" propensity_train = propensity_train,\n",
492492
" y_train = y_train,\n",
493493
" num_gfr = num_gfr,\n",
494494
" num_burnin = num_burnin,\n",
@@ -621,7 +621,7 @@
621621
"xbcf_model.sample(\n",
622622
" X_train = X_train,\n",
623623
" Z_train = Z_train,\n",
624-
" pi_train = propensity_train,\n",
624+
" propensity_train = propensity_train,\n",
625625
" y_train = y_train,\n",
626626
" num_gfr = num_gfr,\n",
627627
" num_burnin = 0,\n",
@@ -649,7 +649,7 @@
649649
"bcf_model.sample(\n",
650650
" X_train = X_train,\n",
651651
" Z_train = Z_train,\n",
652-
" pi_train = propensity_train,\n",
652+
" propensity_train = propensity_train,\n",
653653
" y_train = y_train,\n",
654654
" num_gfr = 0,\n",
655655
" num_burnin = num_burnin,\n",

0 commit comments

Comments
 (0)