Skip to content

Commit 4d1baa9

Browse files
committed
Added update methods to both ForestDataset and RandomEffectsDataset in R
1 parent 50bcc41 commit 4d1baa9

File tree

8 files changed

+271
-10
lines changed

8 files changed

+271
-10
lines changed

R/cpp11.R

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ forest_dataset_update_basis_cpp <- function(dataset_ptr, basis) {
3636
invisible(.Call(`_stochtree_forest_dataset_update_basis_cpp`, dataset_ptr, basis))
3737
}
3838

39-
forest_dataset_update_var_weights_cpp <- function(dataset_ptr, weights) {
40-
invisible(.Call(`_stochtree_forest_dataset_update_var_weights_cpp`, dataset_ptr, weights))
39+
forest_dataset_update_var_weights_cpp <- function(dataset_ptr, weights, exponentiate) {
40+
invisible(.Call(`_stochtree_forest_dataset_update_var_weights_cpp`, dataset_ptr, weights, exponentiate))
4141
}
4242

4343
forest_dataset_add_weights_cpp <- function(dataset_ptr, weights) {
@@ -72,6 +72,18 @@ create_rfx_dataset_cpp <- function() {
7272
.Call(`_stochtree_create_rfx_dataset_cpp`)
7373
}
7474

75+
rfx_dataset_update_basis_cpp <- function(dataset_ptr, basis) {
76+
invisible(.Call(`_stochtree_rfx_dataset_update_basis_cpp`, dataset_ptr, basis))
77+
}
78+
79+
rfx_dataset_update_var_weights_cpp <- function(dataset_ptr, weights, exponentiate) {
80+
invisible(.Call(`_stochtree_rfx_dataset_update_var_weights_cpp`, dataset_ptr, weights, exponentiate))
81+
}
82+
83+
rfx_dataset_num_basis_cpp <- function(dataset) {
84+
.Call(`_stochtree_rfx_dataset_num_basis_cpp`, dataset)
85+
}
86+
7587
rfx_dataset_num_rows_cpp <- function(dataset) {
7688
.Call(`_stochtree_rfx_dataset_num_rows_cpp`, dataset)
7789
}

R/data.R

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,16 @@ ForestDataset <- R6::R6Class(
3636
update_basis = function(basis) {
3737
stopifnot(self$has_basis())
3838
forest_dataset_update_basis_cpp(self$data_ptr, basis)
39-
},
39+
},
40+
41+
#' @description
42+
#' Update the variance weights of a `ForestDataset` that already has variance weights (errors otherwise)
43+
#' @param variance_weights Updated vector of variance weights used to define individual variance / case weights
44+
#' @param exponentiate Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: `FALSE`.
45+
update_variance_weights = function(variance_weights, exponentiate = FALSE) {
46+
stopifnot(self$has_variance_weights())
47+
forest_dataset_update_var_weights_cpp(self$data_ptr, variance_weights, exponentiate)
48+
},
4049

4150
#' @description
4251
#' Return number of observations in a `ForestDataset` object
@@ -190,12 +199,36 @@ RandomEffectsDataset <- R6::R6Class(
190199
}
191200
},
192201

202+
#' @description
203+
#' Update the basis matrix of a `RandomEffectsDataset` that already has a basis (errors otherwise)
204+
#' @param basis Updated matrix of bases used to define random slopes / intercepts
205+
update_basis = function(basis) {
206+
stopifnot(self$has_basis())
207+
rfx_dataset_update_basis_cpp(self$data_ptr, basis)
208+
},
209+
210+
#' @description
211+
#' Update the variance weights of a `RandomEffectsDataset` that already has variance weights (errors otherwise)
212+
#' @param variance_weights Updated vector of variance weights used to define individual variance / case weights
213+
#' @param exponentiate Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: `FALSE`.
214+
update_variance_weights = function(variance_weights, exponentiate = FALSE) {
215+
stopifnot(self$has_variance_weights())
216+
rfx_dataset_update_var_weights_cpp(self$data_ptr, variance_weights, exponentiate)
217+
},
218+
193219
#' @description
194220
#' Return number of observations in a `RandomEffectsDataset` object
195221
#' @return Observation count
196222
num_observations = function() {
197223
return(rfx_dataset_num_rows_cpp(self$data_ptr))
198-
},
224+
},
225+
226+
#' @description
227+
#' Return the number of bases in a `RandomEffectsDataset` object, as reported by the underlying C++ dataset
228+
#' @return Basis vector count (presumably the number of columns of the basis matrix; confirm against the C++ accessor)
229+
num_basis = function() {
230+
return(rfx_dataset_num_basis_cpp(self$data_ptr))
231+
},
199232

200233
#' @description
201234
#' Whether or not a dataset has group label indices

include/stochtree/data.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ class RandomEffectsDataset {
497497
*/
498498
void AddBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) {
499499
basis_ = ColumnMatrix(data_ptr, num_row, num_col, is_row_major);
500+
num_basis_ = num_col;  // remember the column count so UpdateBasis can reject updates with a different width
500501
has_basis_ = true;  // flag checked by UpdateBasis before it overwrites basis_
501502
}
502503
/*!
@@ -509,6 +510,49 @@ class RandomEffectsDataset {
509510
var_weights_ = ColumnVector(data_ptr, num_row);
510511
has_var_weights_ = true;
511512
}
513+
/*!
514+
* \brief Overwrite the elements of the existing internal basis matrix with new values stored in a raw double array. Requires that a basis has already been added and that `num_col` equals the stored basis count.
515+
*
516+
* \param data_ptr Pointer to first element of a contiguous array of data storing a basis matrix
517+
* \param num_row Number of rows in the basis matrix. NOTE(review): not validated here against the row count of the stored basis; confirm callers guarantee a match.
518+
* \param num_col Number of columns in the basis matrix (must equal the column count recorded when the basis was added)
519+
* \param is_row_major Whether or not the data in `data_ptr` are organized in a row-major or column-major fashion
520+
*/
521+
void UpdateBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) {
522+
CHECK(has_basis_);  // there must be an existing basis to overwrite
523+
CHECK_EQ(num_col, num_basis_);  // the update may not change the number of basis columns
524+
// Copy data element-by-element from the caller's (R / Python) buffer into basis_
525+
double temp_value;
526+
for (data_size_t i = 0; i < num_row; ++i) {
527+
for (int j = 0; j < num_col; ++j) {
528+
if (is_row_major){
529+
// Numpy 2-d arrays are stored in "row major" order: element (i, j) sits at offset num_col * i + j
530+
temp_value = static_cast<double>(*(data_ptr + static_cast<data_size_t>(num_col) * i + j));
531+
} else {
532+
// R matrices are stored in "column major" order: element (i, j) sits at offset num_row * j + i
533+
temp_value = static_cast<double>(*(data_ptr + static_cast<data_size_t>(num_row) * j + i));
534+
}
535+
basis_.SetElement(i, j, temp_value);
536+
}
537+
}
538+
}
539+
/*!
540+
* \brief Overwrite the elements of the existing internal variance weight vector with new values stored in a raw double array. Requires that variance weights have already been added.
541+
*
542+
* \param data_ptr Pointer to first element of a contiguous array of data storing a weight vector
543+
* \param num_row Number of rows in the weight vector. NOTE(review): not validated here against the length of the stored weight vector; confirm callers guarantee a match.
544+
* \param exponentiate Whether or not inputs should be exponentiated (`std::exp`) before being saved to var weight vector. Defaults to `true` here. NOTE(review): the R6 wrappers default this to false and always pass it explicitly; confirm the differing defaults are intended.
545+
*/
546+
void UpdateVarWeights(double* data_ptr, data_size_t num_row, bool exponentiate = true) {
547+
CHECK(has_var_weights_);  // there must be an existing weight vector to overwrite
548+
// Copy data element-by-element from the caller's (R / Python) buffer into var_weights_
549+
double temp_value;
550+
for (data_size_t i = 0; i < num_row; ++i) {
551+
if (exponentiate) temp_value = std::exp(static_cast<double>(*(data_ptr + i)));
552+
else temp_value = static_cast<double>(*(data_ptr + i));
553+
var_weights_.SetElement(i, temp_value);
554+
}
555+
}
512556
/*!
513557
* \brief Copy / load group indices for random effects
514558
*
@@ -570,6 +614,7 @@ class RandomEffectsDataset {
570614
ColumnMatrix basis_;
571615
ColumnVector var_weights_;
572616
std::vector<int32_t> group_labels_;
617+
int num_basis_{0};
573618
bool has_basis_{false};
574619
bool has_var_weights_{false};
575620
bool has_group_labels_{false};

man/ForestDataset.Rd

Lines changed: 20 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/RandomEffectsDataset.Rd

Lines changed: 55 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/R_data.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,11 @@ void forest_dataset_update_basis_cpp(cpp11::external_pointer<StochTree::ForestDa
8585
}
8686

8787
[[cpp11::register]]
88-
void forest_dataset_update_var_weights_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, cpp11::doubles weights) {
88+
void forest_dataset_update_var_weights_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, cpp11::doubles weights, bool exponentiate) {
8989
// Add weights
9090
StochTree::data_size_t n = weights.size();
9191
double* weight_data_ptr = REAL(PROTECT(weights));
92-
dataset_ptr->AddVarianceWeights(weight_data_ptr, n);
92+
dataset_ptr->UpdateVarWeights(weight_data_ptr, n, exponentiate);
9393

9494
// Unprotect pointers to R data
9595
UNPROTECT(1);
@@ -191,6 +191,37 @@ cpp11::external_pointer<StochTree::RandomEffectsDataset> create_rfx_dataset_cpp(
191191
return cpp11::external_pointer<StochTree::RandomEffectsDataset>(dataset_ptr_.release());
192192
}
193193

194+
[[cpp11::register]]
195+
void rfx_dataset_update_basis_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset_ptr, cpp11::doubles_matrix<> basis) {
196+
// TODO: add handling code on the R side to ensure matrices are column-major
197+
bool row_major{false};
198+
199+
// Overwrite the dataset's existing basis with the values in `basis` (dimensions are taken from the R matrix)
200+
StochTree::data_size_t n = basis.nrow();
201+
int num_basis = basis.ncol();
202+
double* basis_data_ptr = REAL(PROTECT(basis));
203+
dataset_ptr->UpdateBasis(basis_data_ptr, n, num_basis, row_major);
204+
205+
// Unprotect pointers to R data (balances the single PROTECT above)
206+
UNPROTECT(1);
207+
}
208+
209+
[[cpp11::register]]
210+
void rfx_dataset_update_var_weights_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset_ptr, cpp11::doubles weights, bool exponentiate) {
211+
// Overwrite the dataset's existing variance weights with `weights`, optionally exponentiating each value
212+
StochTree::data_size_t n = weights.size();
213+
double* weight_data_ptr = REAL(PROTECT(weights));
214+
dataset_ptr->UpdateVarWeights(weight_data_ptr, n, exponentiate);
215+
216+
// Unprotect pointers to R data (balances the single PROTECT above)
217+
UNPROTECT(1);
218+
}
219+
220+
[[cpp11::register]]
221+
int rfx_dataset_num_basis_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset) {
222+
return dataset->NumBases();  // NOTE(review): NumBases() is not defined in the data.h hunks shown in this commit; confirm it exists and what it reports
223+
}
224+
194225
[[cpp11::register]]
195226
int rfx_dataset_num_rows_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset) {
196227
return dataset->NumObservations();

src/cpp11.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ extern "C" SEXP _stochtree_forest_dataset_update_basis_cpp(SEXP dataset_ptr, SEX
7272
END_CPP11
7373
}
7474
// R_data.cpp
75-
void forest_dataset_update_var_weights_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, cpp11::doubles weights);
76-
extern "C" SEXP _stochtree_forest_dataset_update_var_weights_cpp(SEXP dataset_ptr, SEXP weights) {
75+
void forest_dataset_update_var_weights_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, cpp11::doubles weights, bool exponentiate);
76+
extern "C" SEXP _stochtree_forest_dataset_update_var_weights_cpp(SEXP dataset_ptr, SEXP weights, SEXP exponentiate) {
7777
BEGIN_CPP11
78-
forest_dataset_update_var_weights_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(dataset_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(weights));
78+
forest_dataset_update_var_weights_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(dataset_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(weights), cpp11::as_cpp<cpp11::decay_t<bool>>(exponentiate));
7979
return R_NilValue;
8080
END_CPP11
8181
}
@@ -141,6 +141,29 @@ extern "C" SEXP _stochtree_create_rfx_dataset_cpp() {
141141
END_CPP11
142142
}
143143
// R_data.cpp
144+
void rfx_dataset_update_basis_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset_ptr, cpp11::doubles_matrix<> basis);
145+
extern "C" SEXP _stochtree_rfx_dataset_update_basis_cpp(SEXP dataset_ptr, SEXP basis) {
146+
BEGIN_CPP11
147+
rfx_dataset_update_basis_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::RandomEffectsDataset>>>(dataset_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles_matrix<>>>(basis));
148+
return R_NilValue;
149+
END_CPP11
150+
}
151+
// R_data.cpp
152+
void rfx_dataset_update_var_weights_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset_ptr, cpp11::doubles weights, bool exponentiate);
153+
extern "C" SEXP _stochtree_rfx_dataset_update_var_weights_cpp(SEXP dataset_ptr, SEXP weights, SEXP exponentiate) {
154+
BEGIN_CPP11
155+
rfx_dataset_update_var_weights_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::RandomEffectsDataset>>>(dataset_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(weights), cpp11::as_cpp<cpp11::decay_t<bool>>(exponentiate));
156+
return R_NilValue;
157+
END_CPP11
158+
}
159+
// R_data.cpp
160+
int rfx_dataset_num_basis_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset);
161+
extern "C" SEXP _stochtree_rfx_dataset_num_basis_cpp(SEXP dataset) {
162+
BEGIN_CPP11
163+
return cpp11::as_sexp(rfx_dataset_num_basis_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::RandomEffectsDataset>>>(dataset)));
164+
END_CPP11
165+
}
166+
// R_data.cpp
144167
int rfx_dataset_num_rows_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset);
145168
extern "C" SEXP _stochtree_rfx_dataset_num_rows_cpp(SEXP dataset) {
146169
BEGIN_CPP11
@@ -1549,7 +1572,7 @@ static const R_CallMethodDef CallEntries[] = {
15491572
{"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2},
15501573
{"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2},
15511574
{"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2},
1552-
{"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 2},
1575+
{"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 3},
15531576
{"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2},
15541577
{"_stochtree_forest_multiply_constant_cpp", (DL_FUNC) &_stochtree_forest_multiply_constant_cpp, 2},
15551578
{"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4},
@@ -1671,7 +1694,10 @@ static const R_CallMethodDef CallEntries[] = {
16711694
{"_stochtree_rfx_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_basis_cpp, 1},
16721695
{"_stochtree_rfx_dataset_has_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_group_labels_cpp, 1},
16731696
{"_stochtree_rfx_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_variance_weights_cpp, 1},
1697+
{"_stochtree_rfx_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_basis_cpp, 1},
16741698
{"_stochtree_rfx_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_rows_cpp, 1},
1699+
{"_stochtree_rfx_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_basis_cpp, 2},
1700+
{"_stochtree_rfx_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_var_weights_cpp, 3},
16751701
{"_stochtree_rfx_group_ids_from_json_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_cpp, 2},
16761702
{"_stochtree_rfx_group_ids_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_string_cpp, 2},
16771703
{"_stochtree_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_cpp, 1},

0 commit comments

Comments
 (0)