@@ -1093,14 +1093,6 @@ def sample(
10931093 else :
10941094 variable_subset_variance = [i for i in range (X_train .shape [1 ])]
10951095
1096- # Set num_features_subsample to default, ncol(X_train), if not already set
1097- if num_features_subsample_mu is None :
1098- num_features_subsample_mu = X_train .shape [1 ]
1099- if num_features_subsample_tau is None :
1100- num_features_subsample_tau = X_train .shape [1 ]
1101- if num_features_subsample_variance is None :
1102- num_features_subsample_variance = X_train .shape [1 ]
1103-
11041096 # Determine whether a test set is provided
11051097 self .has_test = X_test is not None
11061098
@@ -1498,6 +1490,14 @@ def sample(
14981490 # Store propensity score requirements of the BCF forests
14991491 self .propensity_covariate = propensity_covariate
15001492
1493+ # Set num_features_subsample to default, ncol(X_train), if not already set
1494+ if num_features_subsample_mu is None :
1495+ num_features_subsample_mu = X_train_processed .shape [1 ]
1496+ if num_features_subsample_tau is None :
1497+ num_features_subsample_tau = X_train_processed .shape [1 ]
1498+ if num_features_subsample_variance is None :
1499+ num_features_subsample_variance = X_train_processed .shape [1 ]
1500+
15011501 # Container of variance parameter samples
15021502 self .num_gfr = num_gfr
15031503 self .num_burnin = num_burnin
0 commit comments