Skip to content

Commit 9713f43

Browse files
committed
Added ability to update variance weights
1 parent f6aed65 commit 9713f43

File tree

7 files changed

+65
-1
lines changed

7 files changed

+65
-1
lines changed

R/cpp11.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ forest_dataset_update_basis_cpp <- function(dataset_ptr, basis) {
3636
invisible(.Call(`_stochtree_forest_dataset_update_basis_cpp`, dataset_ptr, basis))
3737
}
3838

39+
forest_dataset_update_var_weights_cpp <- function(dataset_ptr, weights) {
40+
invisible(.Call(`_stochtree_forest_dataset_update_var_weights_cpp`, dataset_ptr, weights))
41+
}
42+
3943
forest_dataset_add_weights_cpp <- function(dataset_ptr, weights) {
4044
invisible(.Call(`_stochtree_forest_dataset_add_weights_cpp`, dataset_ptr, weights))
4145
}

man/ForestModel.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/bcf.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/R_data.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,17 @@ void forest_dataset_update_basis_cpp(cpp11::external_pointer<StochTree::ForestDa
8484
UNPROTECT(1);
8585
}
8686

87+
[[cpp11::register]]
88+
void forest_dataset_update_var_weights_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, cpp11::doubles weights) {
89+
// Add weights
90+
StochTree::data_size_t n = weights.size();
91+
double* weight_data_ptr = REAL(PROTECT(weights));
92+
dataset_ptr->AddVarianceWeights(weight_data_ptr, n);
93+
94+
// Unprotect pointers to R data
95+
UNPROTECT(1);
96+
}
97+
8798
[[cpp11::register]]
8899
void forest_dataset_add_weights_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, cpp11::doubles weights) {
89100
// Add weights

src/cpp11.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ extern "C" SEXP _stochtree_forest_dataset_update_basis_cpp(SEXP dataset_ptr, SEX
7272
END_CPP11
7373
}
7474
// R_data.cpp
75+
void forest_dataset_update_var_weights_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, cpp11::doubles weights);
76+
extern "C" SEXP _stochtree_forest_dataset_update_var_weights_cpp(SEXP dataset_ptr, SEXP weights) {
77+
BEGIN_CPP11
78+
forest_dataset_update_var_weights_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(dataset_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(weights));
79+
return R_NilValue;
80+
END_CPP11
81+
}
82+
// R_data.cpp
7583
void forest_dataset_add_weights_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, cpp11::doubles weights);
7684
extern "C" SEXP _stochtree_forest_dataset_add_weights_cpp(SEXP dataset_ptr, SEXP weights) {
7785
BEGIN_CPP11
@@ -1541,6 +1549,7 @@ static const R_CallMethodDef CallEntries[] = {
15411549
{"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2},
15421550
{"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2},
15431551
{"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2},
1552+
{"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 2},
15441553
{"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2},
15451554
{"_stochtree_forest_multiply_constant_cpp", (DL_FUNC) &_stochtree_forest_multiply_constant_cpp, 2},
15461555
{"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4},

src/py_stochtree.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,14 @@ class ForestDatasetCpp {
6767
dataset_->AddVarianceWeights(data_ptr, num_row);
6868
}
6969

70+
void UpdateVarianceWeights(py::array_t<double> weight_vector, data_size_t num_row) {
71+
// Extract pointer to contiguous block of memory
72+
double* data_ptr = static_cast<double*>(weight_vector.mutable_data());
73+
74+
// Load covariates
75+
dataset_->AddVarianceWeights(data_ptr, num_row);
76+
}
77+
7078
data_size_t NumRows() {
7179
return dataset_->NumObservations();
7280
}

stochtree/data.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,16 @@ def update_basis(self, basis: np.array):
6161
basis_ = np.expand_dims(basis, 1) if np.ndim(basis) == 1 else basis
6262
n, p = basis_.shape
6363
basis_rowmajor = np.ascontiguousarray(basis_)
64+
if not self.has_basis():
65+
raise ValueError("This dataset does not have a basis to update. Please use `add_basis` to create and initialize the values in the Dataset's basis matrix.")
66+
if not isinstance(basis, np.ndarray):
67+
raise ValueError("basis must be a numpy array.")
68+
if basis.ndim != 2:
69+
raise ValueError("basis must be a 2-dimensional numpy array.")
70+
if self.num_basis() != p:
71+
raise ValueError(f"The number of columns in the new basis ({p}) must match the number of columns in the existing basis ({self.num_basis()}).")
72+
if self.num_observations() != n:
73+
raise ValueError(f"The number of rows in the new basis ({n}) must match the number of rows in the existing basis ({self.num_observations()}).")
6474
self.dataset_cpp.UpdateBasis(basis_rowmajor, n, p, True)
6575

6676
def add_variance_weights(self, variance_weights: np.array):
@@ -74,6 +84,27 @@ def add_variance_weights(self, variance_weights: np.array):
7484
"""
7585
n = variance_weights.size
7686
self.dataset_cpp.AddVarianceWeights(variance_weights, n)
87+
88+
def update_variance_weights(self, variance_weights: np.array):
89+
"""
90+
Update variance weights in a dataset. Allows users to build an ensemble that depends on
91+
variance weights that are updated throughout the sampler.
92+
93+
Parameters
94+
----------
95+
variance_weights : np.array
96+
Univariate numpy array of variance weights.
97+
"""
98+
n = variance_weights.size
99+
if not self.has_variance_weights():
100+
raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.")
101+
if not isinstance(variance_weights, np.ndarray):
102+
raise ValueError("variance_weights must be a numpy array.")
103+
if variance_weights.ndim != 1:
104+
raise ValueError("variance_weights must be a 1-dimensional numpy array.")
105+
if self.num_observations() != n:
106+
raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).")
107+
self.dataset_cpp.AddVarianceWeights(variance_weights, n)
77108

78109
def num_observations(self) -> int:
79110
"""

0 commit comments

Comments
 (0)