Skip to content

Commit 8cf4bed

Browse files
committed
Updated python and R data interfaces
1 parent 4d1baa9 commit 8cf4bed

File tree

6 files changed

+124
-6
lines changed

6 files changed

+124
-6
lines changed

include/stochtree/data.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,21 @@ class RandomEffectsDataset {
553553
var_weights_.SetElement(i, temp_value);
554554
}
555555
}
556+
/*!
557+
* \brief Update a RandomEffectsDataset's group indices
558+
*
559+
* \param data_ptr Pointer to first element of a contiguous array of data storing a weight vector
560+
* \param num_row Number of rows in the weight vector
561+
* \param exponentiate Whether or not inputs should be exponentiated before being saved to var weight vector
562+
*/
563+
void UpdateGroupLabels(std::vector<int32_t>& group_labels, data_size_t num_row) {
564+
CHECK(has_group_labels_);
565+
CHECK_EQ(this->NumObservations(), num_row)
566+
// Copy data from R / Python process memory to internal vector
567+
for (data_size_t i = 0; i < num_row; ++i) {
568+
group_labels_[i] = group_labels[i];
569+
}
570+
}
556571
/*!
557572
* \brief Copy / load group indices for random effects
558573
*

src/R_data.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,14 @@ void rfx_dataset_update_var_weights_cpp(cpp11::external_pointer<StochTree::Rando
217217
UNPROTECT(1);
218218
}
219219

220+
[[cpp11::register]]
221+
void rfx_dataset_update_group_labels_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset_ptr, cpp11::integers group_labels) {
222+
// Update group labels
223+
int n = group_labels.size();
224+
std::vector<int32_t> group_labels_vec(group_labels.begin(), group_labels.end());
225+
dataset_ptr->UpdateGroupLabels(group_labels_vec, n);
226+
}
227+
220228
[[cpp11::register]]
221229
int rfx_dataset_num_basis_cpp(cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset) {
222230
return dataset->NumBases();

src/py_stochtree.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class ForestDatasetCpp {
7272
double* data_ptr = static_cast<double*>(weight_vector.mutable_data());
7373

7474
// Load covariates
75-
dataset_->AddVarianceWeights(data_ptr, num_row);
75+
dataset_->UpdateVarWeights(data_ptr, num_row);
7676
}
7777

7878
data_size_t NumRows() {
@@ -1297,6 +1297,22 @@ class RandomEffectsDatasetCpp {
12971297
double* weight_data_ptr = static_cast<double*>(weights.mutable_data());
12981298
rfx_dataset_->AddVarianceWeights(weight_data_ptr, num_row);
12991299
}
1300+
void UpdateBasis(py::array_t<double> basis, data_size_t num_row, int num_col, bool row_major) {
1301+
double* basis_data_ptr = static_cast<double*>(basis.mutable_data());
1302+
rfx_dataset_->UpdateBasis(basis_data_ptr, num_row, num_col, row_major);
1303+
}
1304+
void UpdateVarianceWeights(py::array_t<double> weights, data_size_t num_row, bool exponentiate) {
1305+
double* weight_data_ptr = static_cast<double*>(weights.mutable_data());
1306+
rfx_dataset_->UpdateVarWeights(weight_data_ptr, num_row, exponentiate);
1307+
}
1308+
void UpdateGroupLabels(py::array_t<int> group_labels, data_size_t num_row) {
1309+
std::vector<int> group_labels_vec(num_row);
1310+
auto accessor = group_labels.mutable_unchecked<1>();
1311+
for (py::ssize_t i = 0; i < num_row; i++) {
1312+
group_labels_vec[i] = accessor(i);
1313+
}
1314+
rfx_dataset_->UpdateGroupLabels(group_labels_vec, num_row);
1315+
}
13001316
bool HasGroupLabels() {return rfx_dataset_->HasGroupLabels();}
13011317
bool HasBasis() {return rfx_dataset_->HasBasis();}
13021318
bool HasVarianceWeights() {return rfx_dataset_->HasVarWeights();}

stochtree/data.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,14 @@ def add_variance_weights(self, variance_weights: np.array):
8585
variance_weights : np.array
8686
Univariate numpy array of variance weights.
8787
"""
88-
n = variance_weights.size
89-
self.dataset_cpp.AddVarianceWeights(variance_weights, n)
88+
if not isinstance(variance_weights, np.ndarray):
89+
raise ValueError("variance_weights must be a numpy array.")
90+
variance_weights_ = np.squeeze(variance_weights)
91+
n = variance_weights_.size
92+
if variance_weights_.ndim != 1:
93+
raise ValueError("variance_weights must be a 1-dimensional numpy array.")
94+
95+
self.dataset_cpp.AddVarianceWeights(variance_weights_, n)
9096

9197
def update_variance_weights(self, variance_weights: np.array):
9298
"""
@@ -98,16 +104,17 @@ def update_variance_weights(self, variance_weights: np.array):
98104
variance_weights : np.array
99105
Univariate numpy array of variance weights.
100106
"""
101-
n = variance_weights.size
102107
if not self.has_variance_weights():
103108
raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.")
104109
if not isinstance(variance_weights, np.ndarray):
105110
raise ValueError("variance_weights must be a numpy array.")
106-
if variance_weights.ndim != 1:
111+
variance_weights_ = np.squeeze(variance_weights)
112+
n = variance_weights_.size
113+
if variance_weights_.ndim != 1:
107114
raise ValueError("variance_weights must be a 1-dimensional numpy array.")
108115
if self.num_observations() != n:
109116
raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).")
110-
self.dataset_cpp.AddVarianceWeights(variance_weights, n)
117+
self.dataset_cpp.UpdateVarianceWeights(variance_weights_, n)
111118

112119
def num_observations(self) -> int:
113120
"""

stochtree/random_effects.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,23 @@ def add_group_labels(self, group_labels: np.array):
4040
n = group_labels_.shape[0]
4141
self.rfx_dataset_cpp.AddGroupLabels(group_labels_, n)
4242

43+
def update_group_labels(self, group_labels: np.array):
44+
"""
45+
Update group labels in a dataset
46+
47+
Parameters
48+
----------
49+
group_labels : np.array
50+
One-dimensional numpy array of group labels.
51+
"""
52+
group_labels_ = np.squeeze(group_labels)
53+
if group_labels_.ndim > 1:
54+
raise ValueError(
55+
"group_labels must be a one-dimensional numpy array of group indices"
56+
)
57+
n = group_labels_.shape[0]
58+
self.rfx_dataset_cpp.UpdateGroupLabels(group_labels_, n)
59+
4360
def add_basis(self, basis: np.array):
4461
"""
4562
Add basis matrix to a dataset
@@ -93,6 +110,30 @@ def add_variance_weights(self, variance_weights: np.array):
93110
)
94111
n = variance_weights_.shape[0]
95112
self.rfx_dataset_cpp.AddVarianceWeights(variance_weights_, n)
113+
114+
def update_variance_weights(self, variance_weights: np.array):
115+
"""
116+
Update variance weights in a dataset. Allows users to build an ensemble that depends on
117+
variance weights that are updated throughout the sampler.
118+
119+
Parameters
120+
----------
121+
variance_weights : np.array
122+
Univariate numpy array of variance weights.
123+
"""
124+
if not self.has_variance_weights():
125+
raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.")
126+
if not isinstance(variance_weights, np.ndarray):
127+
raise ValueError("variance_weights must be a numpy array.")
128+
variance_weights_ = np.squeeze(variance_weights)
129+
if variance_weights_.ndim > 1:
130+
raise ValueError(
131+
"variance_weights must be a one-dimensional numpy array of group indices"
132+
)
133+
n = variance_weights_.shape[0]
134+
if self.num_observations() != n:
135+
raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).")
136+
self.rfx_dataset_cpp.UpdateVarianceWeights(variance_weights, n)
96137

97138
def num_observations(self) -> int:
98139
"""

test/python/test_data.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import numpy as np
2+
3+
from stochtree import Dataset
4+
5+
class TestDataset:
6+
def test_dataset_update(self):
7+
# Generate data
8+
n = 20
9+
num_covariates = 10
10+
num_basis = 5
11+
rng = np.random.default_rng()
12+
covariates = rng.uniform(0, 1, size=(n, num_covariates))
13+
basis = rng.uniform(0, 1, size=(n, num_basis))
14+
variance_weights = rng.uniform(0, 1, size=n)
15+
16+
# Construct dataset
17+
forest_dataset = Dataset()
18+
forest_dataset.add_covariates(covariates)
19+
forest_dataset.add_basis(basis)
20+
forest_dataset.add_variance_weights(variance_weights)
21+
assert forest_dataset.num_observations() == n
22+
assert forest_dataset.num_covariates() == num_covariates
23+
assert forest_dataset.num_basis() == num_basis
24+
assert forest_dataset.has_variance_weights()
25+
26+
# Update dataset
27+
new_basis = rng.uniform(0, 1, size=(n, num_basis))
28+
new_variance_weights = rng.uniform(0, 1, size=n)
29+
with np.testing.assert_no_warnings():
30+
forest_dataset.update_basis(new_basis)
31+
forest_dataset.update_variance_weights(new_variance_weights)

0 commit comments

Comments
 (0)