Skip to content

Commit ee95a2f

Browse files
committed
Updated C++ unit tests and standalone program
1 parent e995300 commit ee95a2f

File tree

2 files changed

+25
-12
lines changed

2 files changed

+25
-12
lines changed

debug/api_debug.cpp

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -668,6 +668,7 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia
668668

669669
// Prepare the samplers
670670
LeafModelVariant leaf_model = leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest);
671+
int num_features_subsample = x_cols;
671672

672673
// Initialize vector of sweep update indices
673674
std::vector<int> sweep_indices(num_trees);
@@ -687,13 +688,13 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia
687688

688689
// Sample tree ensemble
689690
if (model_type == ModelType::kConstantLeafGaussian) {
690-
GFRSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true);
691+
GFRSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true, num_features_subsample);
691692
} else if (model_type == ModelType::kUnivariateRegressionLeafGaussian) {
692-
GFRSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true);
693+
GFRSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true, num_features_subsample);
693694
} else if (model_type == ModelType::kMultivariateRegressionLeafGaussian) {
694-
GFRSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int>(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true, omega_cols);
695+
GFRSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int>(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true, num_features_subsample, omega_cols);
695696
} else if (model_type == ModelType::kLogLinearVariance) {
696-
GFRSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, false);
697+
GFRSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, false, num_features_subsample);
697698
}
698699

699700
if (rfx_included) {

test/cpp/test_model.cpp

Lines changed: 20 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,9 @@ TEST(LeafConstantModel, FullEnumeration) {
1515
test_dataset = StochTree::TestUtils::LoadSmallDatasetUnivariateBasis();
1616
std::vector<StochTree::FeatureType> feature_types(test_dataset.x_cols, StochTree::FeatureType::kNumeric);
1717
std::vector<double> variable_weights(test_dataset.x_cols, 1./test_dataset.x_cols);
18+
std::vector<bool> feature_subset(test_dataset.x_cols, true);
19+
std::random_device rd;
20+
std::mt19937 gen(rd());
1821

1922
// Construct datasets
2023
using data_size_t = StochTree::data_size_t;
@@ -51,8 +54,8 @@ TEST(LeafConstantModel, FullEnumeration) {
5154

5255
// Evaluate all possible cutpoints
5356
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
54-
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
55-
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
57+
dataset, tracker, residual, tree_prior, gen, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
58+
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types, feature_subset
5659
);
5760

5861
// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
@@ -74,6 +77,9 @@ TEST(LeafConstantModel, CutpointThinning) {
7477
test_dataset = StochTree::TestUtils::LoadSmallDatasetUnivariateBasis();
7578
std::vector<StochTree::FeatureType> feature_types(test_dataset.x_cols, StochTree::FeatureType::kNumeric);
7679
std::vector<double> variable_weights(test_dataset.x_cols, 1./test_dataset.x_cols);
80+
std::vector<bool> feature_subset(test_dataset.x_cols, true);
81+
std::random_device rd;
82+
std::mt19937 gen(rd());
7783

7884
// Construct datasets
7985
using data_size_t = StochTree::data_size_t;
@@ -110,8 +116,8 @@ TEST(LeafConstantModel, CutpointThinning) {
110116

111117
// Evaluate all possible cutpoints
112118
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
113-
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
114-
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
119+
dataset, tracker, residual, tree_prior, gen, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
120+
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types, feature_subset
115121
);
116122

117123
// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
@@ -132,6 +138,9 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) {
132138
test_dataset = StochTree::TestUtils::LoadSmallDatasetUnivariateBasis();
133139
std::vector<StochTree::FeatureType> feature_types(test_dataset.x_cols, StochTree::FeatureType::kNumeric);
134140
std::vector<double> variable_weights(test_dataset.x_cols, 1./test_dataset.x_cols);
141+
std::vector<bool> feature_subset(test_dataset.x_cols, true);
142+
std::random_device rd;
143+
std::mt19937 gen(rd());
135144

136145
// Construct datasets
137146
using data_size_t = StochTree::data_size_t;
@@ -169,8 +178,8 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) {
169178

170179
// Evaluate all possible cutpoints
171180
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(
172-
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
173-
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
181+
dataset, tracker, residual, tree_prior, gen, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
182+
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types, feature_subset
174183
);
175184

176185
// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
@@ -192,6 +201,9 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) {
192201
test_dataset = StochTree::TestUtils::LoadSmallDatasetUnivariateBasis();
193202
std::vector<StochTree::FeatureType> feature_types(test_dataset.x_cols, StochTree::FeatureType::kNumeric);
194203
std::vector<double> variable_weights(test_dataset.x_cols, 1./test_dataset.x_cols);
204+
std::vector<bool> feature_subset(test_dataset.x_cols, true);
205+
std::random_device rd;
206+
std::mt19937 gen(rd());
195207

196208
// Construct datasets
197209
using data_size_t = StochTree::data_size_t;
@@ -229,8 +241,8 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) {
229241

230242
// Evaluate all possible cutpoints
231243
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(
232-
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
233-
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
244+
dataset, tracker, residual, tree_prior, gen, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
245+
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types, feature_subset
234246
);
235247

236248

0 commit comments

Comments (0)