@@ -15,6 +15,9 @@ TEST(LeafConstantModel, FullEnumeration) {
1515 test_dataset = StochTree::TestUtils::LoadSmallDatasetUnivariateBasis ();
1616 std::vector<StochTree::FeatureType> feature_types (test_dataset.x_cols , StochTree::FeatureType::kNumeric );
1717 std::vector<double > variable_weights (test_dataset.x_cols , 1 ./test_dataset.x_cols );
18+ std::vector<bool > feature_subset (test_dataset.x_cols , true );
19+ std::random_device rd;
20+ std::mt19937 gen (rd ());
1821
1922 // Construct datasets
2023 using data_size_t = StochTree::data_size_t ;
@@ -51,8 +54,8 @@ TEST(LeafConstantModel, FullEnumeration) {
5154
5255 // Evaluate all possible cutpoints
5356 StochTree::EvaluateAllPossibleSplits<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
54- dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0 , 0 , log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
55- cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0 , n, variable_weights, feature_types
57+ dataset, tracker, residual, tree_prior, gen, leaf_model, global_variance, 0 , 0 , log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
58+ cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0 , n, variable_weights, feature_types, feature_subset
5659 );
5760
5861 // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
@@ -74,6 +77,9 @@ TEST(LeafConstantModel, CutpointThinning) {
7477 test_dataset = StochTree::TestUtils::LoadSmallDatasetUnivariateBasis ();
7578 std::vector<StochTree::FeatureType> feature_types (test_dataset.x_cols , StochTree::FeatureType::kNumeric );
7679 std::vector<double > variable_weights (test_dataset.x_cols , 1 ./test_dataset.x_cols );
80+ std::vector<bool > feature_subset (test_dataset.x_cols , true );
81+ std::random_device rd;
82+ std::mt19937 gen (rd ());
7783
7884 // Construct datasets
7985 using data_size_t = StochTree::data_size_t ;
@@ -110,8 +116,8 @@ TEST(LeafConstantModel, CutpointThinning) {
110116
111117 // Evaluate all possible cutpoints
112118 StochTree::EvaluateAllPossibleSplits<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
113- dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0 , 0 , log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
114- cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0 , n, variable_weights, feature_types
119+ dataset, tracker, residual, tree_prior, gen, leaf_model, global_variance, 0 , 0 , log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
120+ cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0 , n, variable_weights, feature_types, feature_subset
115121 );
116122
117123 // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
@@ -132,6 +138,9 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) {
132138 test_dataset = StochTree::TestUtils::LoadSmallDatasetUnivariateBasis ();
133139 std::vector<StochTree::FeatureType> feature_types (test_dataset.x_cols , StochTree::FeatureType::kNumeric );
134140 std::vector<double > variable_weights (test_dataset.x_cols , 1 ./test_dataset.x_cols );
141+ std::vector<bool > feature_subset (test_dataset.x_cols , true );
142+ std::random_device rd;
143+ std::mt19937 gen (rd ());
135144
136145 // Construct datasets
137146 using data_size_t = StochTree::data_size_t ;
@@ -169,8 +178,8 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) {
169178
170179 // Evaluate all possible cutpoints
171180 StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(
172- dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0 , 0 , log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
173- cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0 , n, variable_weights, feature_types
181+ dataset, tracker, residual, tree_prior, gen, leaf_model, global_variance, 0 , 0 , log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
182+ cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0 , n, variable_weights, feature_types, feature_subset
174183 );
175184
176185 // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
@@ -192,6 +201,9 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) {
192201 test_dataset = StochTree::TestUtils::LoadSmallDatasetUnivariateBasis ();
193202 std::vector<StochTree::FeatureType> feature_types (test_dataset.x_cols , StochTree::FeatureType::kNumeric );
194203 std::vector<double > variable_weights (test_dataset.x_cols , 1 ./test_dataset.x_cols );
204+ std::vector<bool > feature_subset (test_dataset.x_cols , true );
205+ std::random_device rd;
206+ std::mt19937 gen (rd ());
195207
196208 // Construct datasets
197209 using data_size_t = StochTree::data_size_t ;
@@ -229,8 +241,8 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) {
229241
230242 // Evaluate all possible cutpoints
231243 StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(
232- dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0 , 0 , log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
233- cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0 , n, variable_weights, feature_types
244+ dataset, tracker, residual, tree_prior, gen, leaf_model, global_variance, 0 , 0 , log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
245+ cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0 , n, variable_weights, feature_types, feature_subset
234246 );
235247
236248
0 commit comments