@@ -669,6 +669,10 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia
669669 // Prepare the samplers
670670 LeafModelVariant leaf_model = leafModelFactory (model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest);
671671
672+ // Initialize vector of sweep update indices
673+ std::vector<int > sweep_indices (num_trees);
674+ std::iota (sweep_indices.begin (), sweep_indices.end (), 0 );
675+
672676 // Run the GFR sampler
673677 if (num_gfr > 0 ) {
674678 for (int i = 0 ; i < num_gfr; i++) {
@@ -683,13 +687,13 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia
683687
684688 // Sample tree ensemble
685689 if (model_type == ModelType::kConstantLeafGaussian ) {
686- GFRSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true , true , true );
690+ GFRSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true , true , true );
687691 } else if (model_type == ModelType::kUnivariateRegressionLeafGaussian ) {
688- GFRSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true , true , true );
692+ GFRSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true , true , true );
689693 } else if (model_type == ModelType::kMultivariateRegressionLeafGaussian ) {
690- GFRSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int >(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true , true , true , omega_cols);
694+ GFRSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int >(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true , true , true , omega_cols);
691695 } else if (model_type == ModelType::kLogLinearVariance ) {
692- GFRSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true , true , false );
696+ GFRSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true , true , false );
693697 }
694698
695699 if (rfx_included) {
@@ -720,13 +724,13 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia
720724
721725 // Sample tree ensemble
722726 if (model_type == ModelType::kConstantLeafGaussian ) {
723- MCMCSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true , true , true );
727+ MCMCSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true , true , true );
724728 } else if (model_type == ModelType::kUnivariateRegressionLeafGaussian ) {
725- MCMCSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true , true , true );
729+ MCMCSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true , true , true );
726730 } else if (model_type == ModelType::kMultivariateRegressionLeafGaussian ) {
727- MCMCSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int >(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true , true , true , omega_cols);
731+ MCMCSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int >(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true , true , true , omega_cols);
728732 } else if (model_type == ModelType::kLogLinearVariance ) {
729- MCMCSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true , true , false );
733+ MCMCSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true , true , false );
730734 }
731735
732736 if (rfx_included) {
0 commit comments