Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
# Changelog

# stochtree 0.2.1.9000

## Bug Fixes

* Fix prediction bug for R BART models with random effects with labels that aren't straightforward `1:num_groups` integers when only `y_hat` is requested ([#256](https://github.com/StochasticTree/stochtree/pull/256))

# stochtree 0.2.1

## Bug Fixes

* Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248))
* Fix prediction bug for Python BART and BCF models with random effects with labels that aren't straightforward `0:(num_groups-1)` integers ([#256](https://github.com/StochasticTree/stochtree/pull/256))

## Other Changes

Expand Down
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
# stochtree 0.2.1.9000

## Bug Fixes

* Fix prediction bug for R BART models with random effects with labels that aren't straightforward `1:num_groups` integers when only `y_hat` is requested ([#256](https://github.com/StochasticTree/stochtree/pull/256))

# stochtree 0.2.1

## Bug Fixes

* Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248))
* Fix prediction bug for Python BART and BCF models with random effects with labels that aren't straightforward `0:(num_groups-1)` integers ([#256](https://github.com/StochasticTree/stochtree/pull/256))

## Other Changes

Expand Down
18 changes: 8 additions & 10 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -2134,17 +2134,15 @@ predict.bartmodel <- function(
X <- preprocessPredictionData(X, train_set_metadata)

# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
if (predict_rfx) {
if (!is.null(rfx_group_ids)) {
rfx_unique_group_ids <- object$rfx_unique_group_ids
group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
if (sum(is.na(group_ids_factor)) > 0) {
stop(
"All random effect group labels provided in rfx_group_ids must have been present in rfx_group_ids_train"
)
}
rfx_group_ids <- as.integer(group_ids_factor)
if (!is.null(rfx_group_ids)) {
rfx_unique_group_ids <- object$rfx_unique_group_ids
group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
if (sum(is.na(group_ids_factor)) > 0) {
stop(
"All random effect group labels provided in rfx_group_ids must have been present in rfx_group_ids_train"
)
}
rfx_group_ids <- as.integer(group_ids_factor)
}

# Handle RFX model specification
Expand Down
2 changes: 0 additions & 2 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -3040,7 +3040,6 @@ predict.bcfmodel <- function(
X <- preprocessPredictionData(X, train_set_metadata)

# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
has_rfx <- FALSE
if (!is.null(rfx_group_ids)) {
rfx_unique_group_ids <- object$rfx_unique_group_ids
group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
Expand All @@ -3050,7 +3049,6 @@ predict.bcfmodel <- function(
)
}
rfx_group_ids <- as.integer(group_ids_factor)
has_rfx <- TRUE
}

# Handle RFX model specification
Expand Down
16 changes: 15 additions & 1 deletion src/py_stochtree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,18 @@ class RandomEffectsLabelMapperCpp {
StochTree::LabelMapper* GetLabelMapper() {
return rfx_label_mapper_.get();
}
int MapGroupIdToArrayIndex(int original_label) {
return rfx_label_mapper_->CategoryNumber(original_label);
}
py::array_t<int> MapMultipleGroupIdsToArrayIndices(py::array_t<int> original_labels) {
int output_size = original_labels.size();
auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({output_size}));
auto accessor = result.mutable_unchecked<1>();
for (int i = 0; i < output_size; i++) {
accessor(i) = rfx_label_mapper_->CategoryNumber(original_labels.at(i));
}
return result;
}

private:
std::unique_ptr<StochTree::LabelMapper> rfx_label_mapper_;
Expand Down Expand Up @@ -2410,7 +2422,9 @@ PYBIND11_MODULE(stochtree_cpp, m) {
.def("DumpJsonString", &RandomEffectsLabelMapperCpp::DumpJsonString)
.def("LoadFromJsonString", &RandomEffectsLabelMapperCpp::LoadFromJsonString)
.def("LoadFromJson", &RandomEffectsLabelMapperCpp::LoadFromJson)
.def("GetLabelMapper", &RandomEffectsLabelMapperCpp::GetLabelMapper);
.def("GetLabelMapper", &RandomEffectsLabelMapperCpp::GetLabelMapper)
.def("MapGroupIdToArrayIndex", &RandomEffectsLabelMapperCpp::MapGroupIdToArrayIndex)
.def("MapMultipleGroupIdsToArrayIndices", &RandomEffectsLabelMapperCpp::MapMultipleGroupIdsToArrayIndices);

py::class_<RandomEffectsModelCpp>(m, "RandomEffectsModelCpp")
.def(py::init<int, int>())
Expand Down
8 changes: 7 additions & 1 deletion stochtree/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1988,6 +1988,12 @@ def predict(
"Random effects basis has a different dimension than the basis used to train this model"
)

# Convert rfx_group_ids to their corresponding array position indices in the random effects parameter sample arrays
if rfx_group_ids is not None:
rfx_group_id_indices = self.rfx_container.map_group_ids_to_array_indices(
rfx_group_ids
)

# Random effects predictions
if predict_rfx or predict_rfx_intermediate:
if rfx_basis is not None:
Expand Down Expand Up @@ -2017,7 +2023,7 @@ def predict(
)
for i in range(n_train):
rfx_predictions_raw[i, 0, :] = rfx_beta_draws[
rfx_group_ids[i], :
rfx_group_id_indices[i], :
]
rfx_predictions = np.squeeze(rfx_predictions_raw[:, 0, :])

Expand Down
10 changes: 8 additions & 2 deletions stochtree/bcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3053,6 +3053,12 @@ def predict(
raise ValueError(
"rfx_basis must have the same number of columns as the random effects basis used to sample this model"
)

# Convert rfx_group_ids to their corresponding array position indices in the random effects parameter sample arrays
if rfx_group_ids is not None:
rfx_group_id_indices = self.rfx_container.map_group_ids_to_array_indices(
rfx_group_ids
)

# Random effects predictions
if predict_rfx or predict_rfx_intermediate:
Expand All @@ -3073,14 +3079,14 @@ def predict(
)
for i in range(X.shape[0]):
rfx_predictions_raw[i, :, :] = rfx_beta_draws[
:, rfx_group_ids[i], :
:, rfx_group_id_indices[i], :
]
elif rfx_beta_draws.ndim == 2:
rfx_predictions_raw = np.empty(
shape=(X.shape[0], 1, rfx_beta_draws.shape[1])
)
for i in range(X.shape[0]):
rfx_predictions_raw[i, 0, :] = rfx_beta_draws[rfx_group_ids[i], :]
rfx_predictions_raw[i, 0, :] = rfx_beta_draws[rfx_group_id_indices[i], :]
else:
raise ValueError(
"Unexpected number of dimensions in extracted random effects samples"
Expand Down
32 changes: 32 additions & 0 deletions stochtree/random_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,38 @@ def extract_parameter_samples(self) -> dict[str, np.ndarray]:
"sigma_samples": sigma_samples,
}
return output

def map_group_id_to_array_index(self, group_id: int) -> int:
"""
Map an integer-valued random effects group ID to its group's corresponding position in the arrays that store random effects parameter samples.

Parameters
----------
group_id : int
Group identifier to be converted to an array position.

Returns
-------
int
The position of `group_id` in the parameter sample arrays underlying the random effects container.
"""
return self.rfx_label_mapper_cpp.MapGroupIdToArrayIndex(group_id)

def map_group_ids_to_array_indices(self, group_ids: np.ndarray) -> np.ndarray:
"""
Map an array of integer-valued random effects group IDs to their groups' corresponding positions in the arrays that store random effects parameter samples.

Parameters
----------
group_ids : np.ndarray
Array of group identifiers (integer-valued) to be converted to an array position.

Returns
-------
np.ndarray
Numpy array of the position of `group_id` in the parameter sample arrays underlying the random effects container.
"""
return self.rfx_label_mapper_cpp.MapMultipleGroupIdsToArrayIndices(group_ids)


class RandomEffectsModel:
Expand Down
Loading
Loading