Skip to content

Commit cb164de

Browse files
committed
Ran cpp_register() and document() and added profiling script
1 parent 574fb90 commit cb164de

File tree

4 files changed

+83
-0
lines changed

4 files changed

+83
-0
lines changed

R/cpp11.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,10 @@ forest_tracker_cpp <- function(data, feature_types, num_trees, n) {
640640
.Call(`_stochtree_forest_tracker_cpp`, data, feature_types, num_trees, n)
641641
}
642642

643+
get_cached_forest_predictions_cpp <- function(tracker_ptr) {
644+
.Call(`_stochtree_get_cached_forest_predictions_cpp`, tracker_ptr)
645+
}
646+
643647
sample_without_replacement_integer_cpp <- function(population_vector, sampling_probs, sample_size) {
644648
.Call(`_stochtree_sample_without_replacement_integer_cpp`, population_vector, sampling_probs, sample_size)
645649
}

man/ForestModel.Rd

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/cpp11.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,6 +1187,13 @@ extern "C" SEXP _stochtree_forest_tracker_cpp(SEXP data, SEXP feature_types, SEX
11871187
END_CPP11
11881188
}
11891189
// sampler.cpp
1190+
cpp11::writable::doubles get_cached_forest_predictions_cpp(cpp11::external_pointer<StochTree::ForestTracker> tracker_ptr);
1191+
extern "C" SEXP _stochtree_get_cached_forest_predictions_cpp(SEXP tracker_ptr) {
1192+
BEGIN_CPP11
1193+
return cpp11::as_sexp(get_cached_forest_predictions_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestTracker>>>(tracker_ptr)));
1194+
END_CPP11
1195+
}
1196+
// sampler.cpp
11901197
cpp11::writable::integers sample_without_replacement_integer_cpp(cpp11::integers population_vector, cpp11::doubles sampling_probs, int sample_size);
11911198
extern "C" SEXP _stochtree_sample_without_replacement_integer_cpp(SEXP population_vector, SEXP sampling_probs, SEXP sample_size) {
11921199
BEGIN_CPP11
@@ -1539,6 +1546,7 @@ static const R_CallMethodDef CallEntries[] = {
15391546
{"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4},
15401547
{"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1},
15411548
{"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_get_beta_tree_prior_cpp, 1},
1549+
{"_stochtree_get_cached_forest_predictions_cpp", (DL_FUNC) &_stochtree_get_cached_forest_predictions_cpp, 1},
15421550
{"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3},
15431551
{"_stochtree_get_granular_split_count_array_active_forest_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_active_forest_cpp, 2},
15441552
{"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2},

tools/perf/bart_profiling_script.R

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Load libraries
2+
library(stochtree)
3+
4+
# Capture command line arguments
5+
args <- commandArgs(trailingOnly = T)
6+
if (length(args) > 0){
7+
n <- as.integer(args[1])
8+
p <- as.integer(args[2])
9+
num_gfr <- as.integer(args[3])
10+
num_mcmc <- as.integer(args[4])
11+
snr <- as.numeric(args[5])
12+
} else{
13+
# Default arguments
14+
n <- 1000
15+
p <- 5
16+
num_gfr <- 10
17+
num_mcmc <- 100
18+
snr <- 3.0
19+
}
20+
cat("n = ", n, "\np = ", p, "\nnum_gfr = ", num_gfr,
21+
"\nnum_mcmc = ", num_mcmc, "\nsnr = ", snr, "\n", sep = "")
22+
23+
# Generate data needed to train BART model
24+
X <- matrix(runif(n*p), ncol = p)
25+
plm_term <- (
26+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) +
27+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*X[,2]) +
28+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*X[,2]) +
29+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*X[,2])
30+
)
31+
trig_term <- (
32+
2*sin(X[,3]*2*pi) -
33+
1.5*cos(X[,4]*2*pi)
34+
)
35+
f_XW <- plm_term + trig_term
36+
noise_sd <- sd(f_XW)/snr
37+
y <- f_XW + rnorm(n, 0, noise_sd)
38+
39+
# Split into train and test sets
40+
test_set_pct <- 0.2
41+
n_test <- round(test_set_pct*n)
42+
n_train <- n - n_test
43+
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
44+
train_inds <- (1:n)[!((1:n) %in% test_inds)]
45+
X_test <- X[test_inds,]
46+
X_train <- X[train_inds,]
47+
y_test <- y[test_inds]
48+
y_train <- y[train_inds]
49+
50+
system.time({
51+
# Sample BART model
52+
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
53+
num_gfr = num_gfr, num_mcmc = num_mcmc)
54+
55+
# Predict on the test set
56+
test_preds <- predict(bart_model, X = X_test)
57+
})

0 commit comments

Comments
 (0)