|
1 | 1 | import numpy as np |
| 2 | +import pytest |
2 | 3 | from sklearn.model_selection import train_test_split |
3 | 4 |
|
4 | 5 | from stochtree import BARTModel |
@@ -1009,3 +1010,157 @@ def conditional_stddev(X): |
1009 | 1010 | np.testing.assert_allclose( |
1010 | 1011 | rfx_preds_train_3[:, num_mcmc : (2 * num_mcmc)], rfx_preds_train_2 |
1011 | 1012 | ) |
| 1013 | + |
def test_bart_rfx_parameters(self):
    """Smoke-test BART sampling with differently specified random effects priors.

    Simulates a heteroskedastic outcome with a leaf-basis mean term and a
    group-level random effects term, then fits a ``BARTModel`` three times:

    1. with an empty ``general_params`` dict (no rfx prior parameters given),
    2. with every rfx prior parameter given as a scalar,
    3. with the rfx prior means given as vectors and the prior covariances
       given as identity matrices sized by ``num_rfx_basis``.

    The test makes no assertions on the fitted models; it passes as long as
    none of the ``sample`` calls raises.
    """
    # RNG
    random_seed = 101
    rng = np.random.default_rng(random_seed)

    # Generate covariates and basis
    n = 100
    p_X = 10
    p_W = 1
    X = rng.uniform(0, 1, (n, p_X))
    W = rng.uniform(0, 1, (n, p_W))

    # Generate RFX group labels and basis term
    num_rfx_basis = 2
    num_rfx_groups = 4
    group_labels = rng.choice(num_rfx_groups, size=n)
    rfx_basis = np.empty((n, num_rfx_basis))
    rfx_basis[:, 0] = 1.0  # first rfx basis column is an intercept
    if num_rfx_basis > 1:
        rfx_basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1))

    # Define the outcome mean function
    def outcome_mean(X, W):
        # Piecewise-constant slope on W[:, 0], switching on quartiles of X[:, 0]
        return np.where(
            (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
            -7.5 * W[:, 0],
            np.where(
                (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
                -2.5 * W[:, 0],
                np.where(
                    (X[:, 0] >= 0.5) & (X[:, 0] < 0.75),
                    2.5 * W[:, 0],
                    7.5 * W[:, 0],
                ),
            ),
        )

    # Define the group rfx function
    def rfx_term(group_labels, basis):
        # Group 0 gets one intercept/slope pair; every other group shares
        # the mirrored pair. Reads basis[:, 1], so needs >= 2 basis columns.
        return np.where(
            group_labels == 0,
            -5 + 1.0 * basis[:, 1],
            5 - 1.0 * basis[:, 1],
        )

    # Define the conditional standard deviation function
    def conditional_stddev(X):
        # Noise scale steps up across quartiles of X[:, 0]
        return np.where(
            (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
            0.25,
            np.where(
                (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
                0.5,
                np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 1, 2),
            ),
        )

    # Generate outcome
    epsilon = rng.normal(0, 1, n)
    y = (
        outcome_mean(X, W)
        + rfx_term(group_labels, rfx_basis)
        + epsilon * conditional_stddev(X)
    )

    # Test-train split; seeded so the split is reproducible across runs,
    # in line with the seeded data generation above
    sample_inds = np.arange(n)
    train_inds, test_inds = train_test_split(
        sample_inds, test_size=0.5, random_state=random_seed
    )
    X_train = X[train_inds, :]
    X_test = X[test_inds, :]
    basis_train = W[train_inds, :]
    basis_test = W[test_inds, :]
    group_labels_train = group_labels[train_inds]
    group_labels_test = group_labels[test_inds]
    rfx_basis_train = rfx_basis[train_inds, :]
    rfx_basis_test = rfx_basis[test_inds, :]
    y_train = y[train_inds]

    # BART settings
    num_gfr = 10
    num_burnin = 0
    num_mcmc = 10

    # The three rfx prior specifications under test; every other argument
    # passed to ``sample`` is identical across the runs
    rfx_param_configs = [
        # No rfx parameters specified
        {},
        # All rfx parameters specified as scalars
        {
            "rfx_working_parameter_prior_mean": 1.0,
            "rfx_group_parameter_prior_mean": 1.0,
            "rfx_working_parameter_prior_cov": 1.0,
            "rfx_group_parameter_prior_cov": 1.0,
            "rfx_variance_prior_shape": 1,
            "rfx_variance_prior_scale": 1,
        },
        # Prior means as vectors and prior covariances as matrices
        {
            "rfx_working_parameter_prior_mean": np.repeat(1.0, num_rfx_basis),
            "rfx_group_parameter_prior_mean": np.repeat(1.0, num_rfx_basis),
            "rfx_working_parameter_prior_cov": np.identity(num_rfx_basis),
            "rfx_group_parameter_prior_cov": np.identity(num_rfx_basis),
            "rfx_variance_prior_shape": 1,
            "rfx_variance_prior_scale": 1,
        },
    ]

    # Fit a fresh model for each specification
    for general_params in rfx_param_configs:
        bart_model = BARTModel()
        bart_model.sample(
            X_train=X_train,
            y_train=y_train,
            leaf_basis_train=basis_train,
            rfx_group_ids_train=group_labels_train,
            rfx_basis_train=rfx_basis_train,
            X_test=X_test,
            leaf_basis_test=basis_test,
            rfx_group_ids_test=group_labels_test,
            rfx_basis_test=rfx_basis_test,
            num_gfr=num_gfr,
            num_burnin=num_burnin,
            num_mcmc=num_mcmc,
            general_params=general_params,
        )
0 commit comments