@@ -716,18 +716,28 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore
716716 // Subsample features (if requested)
717717 std::vector<bool > feature_subset (p, true );
718718 if (num_features_subsample < p) {
719- std::vector<int > feature_indices (p);
720- std::iota (feature_indices.begin (), feature_indices.end (), 0 );
721- std::vector<int > features_selected (num_features_subsample);
722- sample_without_replacement<int , double >(
723- features_selected.data (), variable_weights.data (), feature_indices.data (),
724- p, num_features_subsample, gen
725- );
726- for (int i = 0 ; i < p; i++) {
727- feature_subset.at (i) = false ;
719+ // Check if the number of (meaningfully) nonzero selection probabilities is greater than num_features_subsample
720+ int number_nonzero_weights = 0 ;
721+ for (int j = 0 ; j < p; j++) {
722+ if (std::abs (variable_weights.at (j)) > kEpsilon ) {
723+ number_nonzero_weights++;
724+ }
728725 }
729- for (const auto & feat : features_selected) {
730- feature_subset.at (feat) = true ;
726+ if (number_nonzero_weights > num_features_subsample) {
727+ // Sample with replacement according to variable_weights
728+ std::vector<int > feature_indices (p);
729+ std::iota (feature_indices.begin (), feature_indices.end (), 0 );
730+ std::vector<int > features_selected (num_features_subsample);
731+ sample_without_replacement<int , double >(
732+ features_selected.data (), variable_weights.data (), feature_indices.data (),
733+ p, num_features_subsample, gen
734+ );
735+ for (int i = 0 ; i < p; i++) {
736+ feature_subset.at (i) = false ;
737+ }
738+ for (const auto & feat : features_selected) {
739+ feature_subset.at (feat) = true ;
740+ }
731741 }
732742 }
733743
@@ -782,6 +792,7 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore
782792 * \param pre_initialized Whether or not `active_forest` has already been initialized (note: this parameter will be refactored out soon).
783793 * \param backfitting Whether or not the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via
784794 * their effect on the residual) or the more general "blocked MCMC" (wherein the state of other trees must be more explicitly considered).
795+ * \param num_features_subsample How many features to subsample when running the GFR algorithm.
785796 * \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object.
786797 */
787798template <typename LeafModel, typename LeafSuffStat, typename ... LeafSuffStatConstructorArgs>
0 commit comments