Skip to content

Commit 53156f1

Browse files
authored
Merge pull request #250 from StochasticTree/release-prep-0.2.1
Prepare for CRAN / PyPI 0.2.1 patch release
2 parents 10d8215 + fe5d1a3 commit 53156f1

File tree

18 files changed

+1849
-247
lines changed

18 files changed

+1849
-247
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
name: Unit Tests and Slow Running API Integration Tests for R and Python
2+
3+
on:
4+
workflow_dispatch:
5+
6+
jobs:
7+
testing:
8+
name: test-slow-api-combinations
9+
runs-on: ${{ matrix.os }}
10+
11+
strategy:
12+
fail-fast: false
13+
matrix:
14+
os: [ubuntu-latest, windows-latest, macos-latest]
15+
16+
steps:
17+
- name: Prevent conversion of line endings on Windows
18+
if: startsWith(matrix.os, 'windows')
19+
shell: pwsh
20+
run: git config --global core.autocrlf false
21+
22+
- name: Checkout repository
23+
uses: actions/checkout@v4
24+
with:
25+
submodules: 'recursive'
26+
27+
- name: Setup Python 3.10
28+
uses: actions/setup-python@v5
29+
with:
30+
python-version: "3.10"
31+
cache: "pip"
32+
33+
- name: Set up openmp (macos)
34+
# Set up OpenMP on macOS, since Apple's Clang compiler suite does not ship with it
35+
if: matrix.os == 'macos-latest'
36+
run: |
37+
brew install libomp
38+
39+
- name: Install Package with Relevant Dependencies
40+
run: |
41+
pip install --upgrade pip
42+
pip install -r requirements.txt
43+
pip install .
44+
45+
- name: Run Pytest with Slow Running API Tests Enabled
46+
run: |
47+
pytest --runslow test/python
48+
49+
- name: Setup Pandoc for R
50+
uses: r-lib/actions/setup-pandoc@v2
51+
52+
- name: Setup R
53+
uses: r-lib/actions/setup-r@v2
54+
with:
55+
use-public-rspm: true
56+
57+
- name: Setup R Package Dependencies
58+
uses: r-lib/actions/setup-r-dependencies@v2
59+
with:
60+
extra-packages: any::testthat, any::decor, any::rcmdcheck
61+
needs: check
62+
63+
- name: Create a CRAN-ready version of the R package
64+
run: |
65+
Rscript cran-bootstrap.R 0 0 1
66+
67+
- name: Run CRAN Checks with Slow Running API Tests Enabled
68+
uses: r-lib/actions/check-r-package@v2
69+
env:
70+
RUN_SLOW_TESTS: true
71+
with:
72+
working-directory: 'stochtree_cran'

CHANGELOG.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,15 @@
11
# Changelog
22

3-
# stochtree (development version)
4-
5-
## New Features
6-
7-
## Computational Improvements
3+
# stochtree 0.2.1
84

95
## Bug Fixes
106

11-
* Predict random effects correctly in R for univariate random effects models ([#248](https://github.com/StochasticTree/stochtree/pull/248))
12-
13-
## Documentation Improvements
7+
* Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248))
148

159
## Other Changes
1610

11+
* Encode expectations about which combinations of BART / BCF features work together and ensure that a warning is raised for incompatible combinations ([#250](https://github.com/StochasticTree/stochtree/pull/250))
12+
1713
# stochtree 0.2.0
1814

1915
## New Features

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: stochtree
22
Title: Stochastic Tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference
3-
Version: 0.2.0.9000
3+
Version: 0.2.1
44
Authors@R:
55
c(
66
person("Drew", "Herren", email = "drewherrenopensource@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-4109-6611")),

Doxyfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ PROJECT_NAME = "StochTree"
4848
# could be handy for archiving the generated documentation or if some version
4949
# control system is used.
5050

51-
PROJECT_NUMBER = 0.2.0.9000
51+
PROJECT_NUMBER = 0.2.1
5252

5353
# Using the PROJECT_BRIEF tag one can provide an optional one line description
5454
# for a project that appears at the top of each page and should give viewer a

NEWS.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
1-
# stochtree (development version)
2-
3-
## New Features
4-
5-
## Computational Improvements
1+
# stochtree 0.2.1
62

73
## Bug Fixes
84

9-
* Predict random effects correctly in R for univariate random effects models ([#248](https://github.com/StochasticTree/stochtree/pull/248))
10-
11-
## Documentation Improvements
5+
* Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248))
126

137
## Other Changes
148

9+
* Encode expectations about which combinations of BART / BCF features work together and ensure that a warning is raised for incompatible combinations ([#250](https://github.com/StochasticTree/stochtree/pull/250))
10+
1511
# stochtree 0.2.0
1612

1713
## New Features

R/bart.R

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,16 @@ bart <- function(
835835
}
836836
}
837837

838+
# Runtime checks for variance forest
839+
if (include_variance_forest) {
840+
if (sample_sigma2_global) {
841+
warning(
842+
"Global error variance will not be sampled with a heteroskedasticity forest"
843+
)
844+
sample_sigma2_global <- F
845+
}
846+
}
847+
838848
# Handle standardization, prior calibration, and initialization of forest
839849
# differently for binary and continuous outcomes
840850
if (probit_outcome_model) {
@@ -2124,7 +2134,6 @@ predict.bartmodel <- function(
21242134
X <- preprocessPredictionData(X, train_set_metadata)
21252135

21262136
# Recode group IDs to an integer vector (in case they were passed as, for example, a vector of county names)
2127-
has_rfx <- FALSE
21282137
if (predict_rfx) {
21292138
if (!is.null(rfx_group_ids)) {
21302139
rfx_unique_group_ids <- object$rfx_unique_group_ids
@@ -2135,7 +2144,6 @@ predict.bartmodel <- function(
21352144
)
21362145
}
21372146
rfx_group_ids <- as.integer(group_ids_factor)
2138-
has_rfx <- TRUE
21392147
}
21402148
}
21412149

R/bcf.R

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -897,14 +897,7 @@ bcf <- function(
897897
# Handle multivariate treatment
898898
has_multivariate_treatment <- ncol(Z_train) > 1
899899
if (has_multivariate_treatment) {
900-
# Disable adaptive coding, internal propensity model, and
901-
# leaf scale sampling if treatment is multivariate
902-
if (adaptive_coding) {
903-
warning(
904-
"Adaptive coding is incompatible with multivariate treatment and will be ignored"
905-
)
906-
adaptive_coding <- FALSE
907-
}
900+
# Disable internal propensity model and leaf scale sampling if treatment is multivariate
908901
if (is.null(propensity_train)) {
909902
if (propensity_covariate != "none") {
910903
warning(
@@ -949,21 +942,31 @@ bcf <- function(
949942
}
950943
has_basis_rfx <- TRUE
951944
num_basis_rfx <- ncol(rfx_basis_train)
952-
} else if (rfx_model_spec == "intercept_only") {
953-
rfx_basis_train <- matrix(
954-
rep(1, nrow(X_train)),
955-
nrow = nrow(X_train),
956-
ncol = 1
957-
)
958-
has_basis_rfx <- TRUE
959-
num_basis_rfx <- 1
960945
} else if (rfx_model_spec == "intercept_plus_treatment") {
961-
rfx_basis_train <- cbind(
962-
rep(1, nrow(X_train)),
963-
Z_train
964-
)
965-
has_basis_rfx <- TRUE
966-
num_basis_rfx <- 1 + ncol(Z_train)
946+
if (has_multivariate_treatment) {
947+
warning(
948+
"Random effects `intercept_plus_treatment` specification is not currently implemented for multivariate treatments. This model will be fit under the `intercept_only` specification instead. Please provide a custom `rfx_basis_train` if you wish to have random slopes on multivariate treatment variables."
949+
)
950+
rfx_model_spec <- "intercept_only"
951+
}
952+
}
953+
if (is.null(rfx_basis_train)) {
954+
if (rfx_model_spec == "intercept_only") {
955+
rfx_basis_train <- matrix(
956+
rep(1, nrow(X_train)),
957+
nrow = nrow(X_train),
958+
ncol = 1
959+
)
960+
has_basis_rfx <- TRUE
961+
num_basis_rfx <- 1
962+
} else {
963+
rfx_basis_train <- cbind(
964+
rep(1, nrow(X_train)),
965+
Z_train
966+
)
967+
has_basis_rfx <- TRUE
968+
num_basis_rfx <- 1 + ncol(Z_train)
969+
}
967970
}
968971
num_rfx_groups <- length(unique(rfx_group_ids_train))
969972
num_rfx_components <- ncol(rfx_basis_train)
@@ -1021,15 +1024,21 @@ bcf <- function(
10211024
y_train <- as.matrix(y_train)
10221025
}
10231026

1024-
# Check whether treatment is binary (specifically 0-1 binary)
1025-
binary_treatment <- length(unique(Z_train)) == 2
1026-
if (binary_treatment) {
1027-
unique_treatments <- sort(unique(Z_train))
1028-
if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE
1027+
# Check whether treatment is binary and univariate (specifically 0-1 binary)
1028+
binary_treatment <- FALSE
1029+
if (!has_multivariate_treatment) {
1030+
binary_treatment <- length(unique(Z_train)) == 2
1031+
if (binary_treatment) {
1032+
unique_treatments <- sort(unique(Z_train))
1033+
if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE
1034+
}
10291035
}
10301036

10311037
# Adaptive coding will be ignored for continuous / ordered categorical treatments
10321038
if ((!binary_treatment) && (adaptive_coding)) {
1039+
warning(
1040+
"Adaptive coding is only compatible with binary (univariate) treatment and, as a result, will be ignored in sampling this model"
1041+
)
10331042
adaptive_coding <- FALSE
10341043
}
10351044

R/posterior_transformation.R

Lines changed: 64 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -409,19 +409,31 @@ compute_contrast_bart_model <- function(
409409
"rfx_group_ids_0 and rfx_group_ids_1 must be provided for this model"
410410
)
411411
}
412-
if ((has_rfx) && (is.null(rfx_basis_0) || is.null(rfx_basis_1))) {
413-
stop(
414-
"rfx_basis_0 and rfx_basis_1 must be provided for this model"
415-
)
416-
}
417-
if (
418-
(object$model_params$num_rfx_basis > 0) &&
419-
((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) ||
420-
(ncol(rfx_basis_1) != object$model_params$num_rfx_basis))
421-
) {
422-
stop(
423-
"rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model"
424-
)
412+
if (has_rfx) {
413+
if (object$model_params$rfx_model_spec == "custom") {
414+
if ((is.null(rfx_basis_0) || is.null(rfx_basis_1))) {
415+
stop(
416+
"A user-provided basis (`rfx_basis_0` and `rfx_basis_1`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
417+
)
418+
}
419+
if (!is.matrix(rfx_basis_0) || !is.matrix(rfx_basis_1)) {
420+
stop("'rfx_basis_0' and 'rfx_basis_1' must be matrices")
421+
}
422+
if ((nrow(rfx_basis_0) != nrow(X)) || (nrow(rfx_basis_1) != nrow(X))) {
423+
stop(
424+
"'rfx_basis_0' and 'rfx_basis_1' must have the same number of rows as 'X'"
425+
)
426+
}
427+
if (
428+
(object$model_params$num_rfx_basis > 0) &&
429+
((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) ||
430+
(ncol(rfx_basis_1) != object$model_params$num_rfx_basis))
431+
) {
432+
stop(
433+
"rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model"
434+
)
435+
}
436+
}
425437
}
426438

427439
# Predict for the control arm
@@ -574,16 +586,22 @@ sample_bcf_posterior_predictive <- function(
574586
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
575587
)
576588
}
577-
if (is.null(rfx_basis)) {
578-
stop(
579-
"'rfx_basis' must be provided in order to compute the requested intervals"
580-
)
581-
}
582-
if (!is.matrix(rfx_basis)) {
583-
stop("'rfx_basis' must be a matrix")
589+
590+
if (model_object$model_params$rfx_model_spec == "custom") {
591+
if (is.null(rfx_basis)) {
592+
stop(
593+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
594+
)
595+
}
584596
}
585-
if (nrow(rfx_basis) != nrow(X)) {
586-
stop("'rfx_basis' must have the same number of rows as 'X'")
597+
598+
if (!is.null(rfx_basis)) {
599+
if (!is.matrix(rfx_basis)) {
600+
stop("'rfx_basis' must be a matrix")
601+
}
602+
if (nrow(rfx_basis) != nrow(X)) {
603+
stop("'rfx_basis' must have the same number of rows as 'X'")
604+
}
587605
}
588606
}
589607

@@ -735,16 +753,18 @@ sample_bart_posterior_predictive <- function(
735753
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
736754
)
737755
}
738-
if (is.null(rfx_basis)) {
739-
stop(
740-
"'rfx_basis' must be provided in order to compute the requested intervals"
741-
)
742-
}
743-
if (!is.matrix(rfx_basis)) {
744-
stop("'rfx_basis' must be a matrix")
745-
}
746-
if (nrow(rfx_basis) != nrow(X)) {
747-
stop("'rfx_basis' must have the same number of rows as 'X'")
756+
if (model_object$model_params$rfx_model_spec == "custom") {
757+
if (is.null(rfx_basis)) {
758+
stop(
759+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
760+
)
761+
}
762+
if (!is.matrix(rfx_basis)) {
763+
stop("'rfx_basis' must be a matrix")
764+
}
765+
if (nrow(rfx_basis) != nrow(X)) {
766+
stop("'rfx_basis' must have the same number of rows as 'X'")
767+
}
748768
}
749769
}
750770

@@ -1172,16 +1192,18 @@ compute_bart_posterior_interval <- function(
11721192
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
11731193
)
11741194
}
1175-
if (is.null(rfx_basis)) {
1176-
stop(
1177-
"'rfx_basis' must be provided in order to compute the requested intervals"
1178-
)
1179-
}
1180-
if (!is.matrix(rfx_basis)) {
1181-
stop("'rfx_basis' must be a matrix")
1182-
}
1183-
if (nrow(rfx_basis) != nrow(X)) {
1184-
stop("'rfx_basis' must have the same number of rows as 'X'")
1195+
if (model_object$model_params$rfx_model_spec == "custom") {
1196+
if (is.null(rfx_basis)) {
1197+
stop(
1198+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
1199+
)
1200+
}
1201+
if (!is.matrix(rfx_basis)) {
1202+
stop("'rfx_basis' must be a matrix")
1203+
}
1204+
if (nrow(rfx_basis) != nrow(X)) {
1205+
stop("'rfx_basis' must have the same number of rows as 'X'")
1206+
}
11851207
}
11861208
}
11871209

0 commit comments

Comments
 (0)