Skip to content

Commit a53e58d

Browse files
committed
Updated predict and demo script
1 parent 691eed7 commit a53e58d

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

R/bcf.R

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2808,15 +2808,16 @@ predict.bcfmodel <- function(
28082808
)
28092809
}
28102810
if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis))) {
2811-
stop("Random effects basis (rfx_basis) must be provided for this model")
2811+
if (object$model_params$rfx_model_spec == "custom") {
2812+
stop("Random effects basis (rfx_basis) must be provided for this model")
2813+
}
28122814
}
2813-
if (
2814-
(object$model_params$num_rfx_basis > 0) &&
2815-
(ncol(rfx_basis) != object$model_params$num_rfx_basis)
2816-
) {
2817-
stop(
2818-
"Random effects basis has a different dimension than the basis used to train this model"
2819-
)
2815+
if ((object$model_params$num_rfx_basis > 0) && (!is.null(rfx_basis))) {
2816+
if (ncol(rfx_basis) != object$model_params$num_rfx_basis) {
2817+
stop(
2818+
"Random effects basis has a different dimension than the basis used to train this model"
2819+
)
2820+
}
28202821
}
28212822

28222823
# Preprocess covariates

tools/debug/bcf_cate_debug.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,13 @@ tau_hat_posterior_test <- compute_contrast_bcf_model(
229229
scale = "linear"
230230
)
231231

232-
# Compute the same quantity via predict
232+
# Compute the same quantity directly via predict
233233
tau_hat_posterior_test_comparison <- predict(
234234
bcf_model,
235235
X = X_test,
236236
Z = Z_test,
237237
propensity = pi_test,
238238
rfx_group_ids = group_ids_test,
239-
rfx_basis = rfx_basis_test,
240239
type = "posterior",
241240
terms = "cate",
242241
scale = "linear"

0 commit comments

Comments
 (0)