Skip to content

Commit e96268f

Browse files
committed
Deprecate predict_tau and predict_variance and expand predict method
1 parent 7dc71ce commit e96268f

File tree

1 file changed

+15
-217
lines changed

1 file changed

+15
-217
lines changed

stochtree/bcf.py

Lines changed: 15 additions & 217 deletions
Original file line numberDiff line numberDiff line change
@@ -2232,222 +2232,6 @@ def sample(
22322232
sigma2_x_test_raw * self.sigma2_init * self.y_std * self.y_std
22332233
)
22342234

2235-
def predict_tau(
2236-
self, X: np.array, Z: np.array, propensity: np.array = None
2237-
) -> np.array:
2238-
"""Predict CATE function for every provided observation.
2239-
2240-
Parameters
2241-
----------
2242-
X : np.array or pd.DataFrame
2243-
Test set covariates.
2244-
Z : np.array
2245-
Test set treatment indicators.
2246-
propensity : np.array, optional
2247-
Optional test set propensities. Must be provided if propensities were provided when the model was sampled.
2248-
2249-
Returns
2250-
-------
2251-
np.array
2252-
Array with as many rows as in `X` and as many columns as retained samples of the algorithm.
2253-
"""
2254-
if not self.is_sampled():
2255-
msg = (
2256-
"This BCFModel instance is not fitted yet. Call 'fit' with "
2257-
"appropriate arguments before using this model."
2258-
)
2259-
raise NotSampledError(msg)
2260-
2261-
# Convert everything to standard shape (2-dimensional)
2262-
if X.ndim == 1:
2263-
X = np.expand_dims(X, 1)
2264-
if Z.ndim == 1:
2265-
Z = np.expand_dims(Z, 1)
2266-
else:
2267-
if Z.ndim != 2:
2268-
raise ValueError("treatment must have 1 or 2 dimensions")
2269-
if propensity is not None:
2270-
if propensity.ndim == 1:
2271-
propensity = np.expand_dims(propensity, 1)
2272-
2273-
# Data checks
2274-
if Z.shape[0] != X.shape[0]:
2275-
raise ValueError("X and Z must have the same number of rows")
2276-
2277-
# Covariate preprocessing
2278-
if not self._covariate_preprocessor._check_is_fitted():
2279-
if not isinstance(X, np.ndarray):
2280-
raise ValueError(
2281-
"Prediction cannot proceed on a pandas dataframe, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe."
2282-
)
2283-
else:
2284-
warnings.warn(
2285-
"This BCF model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.",
2286-
RuntimeWarning,
2287-
)
2288-
if not np.issubdtype(X.dtype, np.floating) and not np.issubdtype(
2289-
X.dtype, np.integer
2290-
):
2291-
raise ValueError(
2292-
"Prediction cannot proceed on a non-numeric numpy array, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe."
2293-
)
2294-
covariates_processed = X
2295-
else:
2296-
covariates_processed = self._covariate_preprocessor.transform(X)
2297-
2298-
# Handle propensities
2299-
if propensity is not None:
2300-
if propensity.shape[0] != X.shape[0]:
2301-
raise ValueError("X and propensity must have the same number of rows")
2302-
else:
2303-
if self.propensity_covariate != "none":
2304-
if not self.internal_propensity_model:
2305-
raise ValueError(
2306-
"Propensity scores not provided, but no propensity model was trained during sampling"
2307-
)
2308-
else:
2309-
internal_propensity_preds = self.bart_propensity_model.predict(
2310-
covariates_processed
2311-
)
2312-
propensity = np.mean(
2313-
internal_propensity_preds["y_hat"], axis=1, keepdims=True
2314-
)
2315-
2316-
# Update covariates to include propensities if requested
2317-
if self.propensity_covariate == "none":
2318-
X_combined = covariates_processed
2319-
else:
2320-
X_combined = np.c_[covariates_processed, propensity]
2321-
2322-
# Forest dataset
2323-
forest_dataset_test = Dataset()
2324-
forest_dataset_test.add_covariates(X_combined)
2325-
forest_dataset_test.add_basis(Z)
2326-
2327-
# Estimate treatment effect
2328-
tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw(
2329-
forest_dataset_test.dataset_cpp
2330-
)
2331-
tau_raw = tau_raw
2332-
if self.adaptive_coding:
2333-
adaptive_coding_weights = np.expand_dims(
2334-
self.b1_samples - self.b0_samples, axis=(0, 2)
2335-
)
2336-
tau_raw = tau_raw * adaptive_coding_weights
2337-
tau_x = np.squeeze(tau_raw * self.y_std)
2338-
2339-
# Return result matrix
2340-
return tau_x
2341-
2342-
def predict_variance(
2343-
self, covariates: np.array, propensity: np.array = None
2344-
) -> np.array:
2345-
"""Predict expected conditional variance from a BART model.
2346-
2347-
Parameters
2348-
----------
2349-
covariates : np.array
2350-
Test set covariates.
2351-
propensity : np.array, optional
2352-
Test set propensity scores. Optional (not currently used in variance forests).
2353-
2354-
Returns
2355-
-------
2356-
np.array
2357-
Array of predictions corresponding to the variance forest. Each array will contain as many rows as in `covariates` and as many columns as retained samples of the algorithm.
2358-
"""
2359-
if not self.is_sampled():
2360-
msg = (
2361-
"This BARTModel instance is not fitted yet. Call 'fit' with "
2362-
"appropriate arguments before using this model."
2363-
)
2364-
raise NotSampledError(msg)
2365-
2366-
if not self.include_variance_forest:
2367-
msg = (
2368-
"This BARTModel instance was not sampled with a variance forest. "
2369-
"Call 'fit' with appropriate arguments before using this model."
2370-
)
2371-
raise NotSampledError(msg)
2372-
2373-
# Convert everything to standard shape (2-dimensional)
2374-
if covariates.ndim == 1:
2375-
covariates = np.expand_dims(covariates, 1)
2376-
if propensity is not None:
2377-
if propensity.ndim == 1:
2378-
propensity = np.expand_dims(propensity, 1)
2379-
2380-
# Covariate preprocessing
2381-
if not self._covariate_preprocessor._check_is_fitted():
2382-
if not isinstance(covariates, np.ndarray):
2383-
raise ValueError(
2384-
"Prediction cannot proceed on a pandas dataframe, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe."
2385-
)
2386-
else:
2387-
warnings.warn(
2388-
"This BCF model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.",
2389-
RuntimeWarning,
2390-
)
2391-
if not np.issubdtype(
2392-
covariates.dtype, np.floating
2393-
) and not np.issubdtype(covariates.dtype, np.integer):
2394-
raise ValueError(
2395-
"Prediction cannot proceed on a non-numeric numpy array, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe."
2396-
)
2397-
covariates_processed = covariates
2398-
else:
2399-
covariates_processed = self._covariate_preprocessor.transform(covariates)
2400-
2401-
# Handle propensities
2402-
if propensity is not None:
2403-
if propensity.shape[0] != covariates.shape[0]:
2404-
raise ValueError("X and propensity must have the same number of rows")
2405-
else:
2406-
if self.propensity_covariate != "none":
2407-
if not self.internal_propensity_model:
2408-
raise ValueError(
2409-
"Propensity scores not provided, but no propensity model was trained during sampling"
2410-
)
2411-
else:
2412-
internal_propensity_preds = self.bart_propensity_model.predict(
2413-
covariates_processed
2414-
)
2415-
propensity = np.mean(
2416-
internal_propensity_preds["y_hat"], axis=1, keepdims=True
2417-
)
2418-
2419-
# Update covariates to include propensities if requested
2420-
if self.propensity_covariate == "none":
2421-
X_combined = covariates_processed
2422-
else:
2423-
if propensity is not None:
2424-
X_combined = np.c_[covariates_processed, propensity]
2425-
else:
2426-
# Dummy propensities if not provided but also not needed
2427-
propensity = np.ones(covariates_processed.shape[0])
2428-
propensity = np.expand_dims(propensity, 1)
2429-
X_combined = np.c_[covariates_processed, propensity]
2430-
2431-
# Forest dataset
2432-
pred_dataset = Dataset()
2433-
pred_dataset.add_covariates(X_combined)
2434-
2435-
variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict(
2436-
pred_dataset.dataset_cpp
2437-
)
2438-
if self.sample_sigma2_global:
2439-
variance_pred = np.empty_like(variance_pred_raw)
2440-
for i in range(self.num_samples):
2441-
variance_pred[:, i] = (
2442-
variance_pred_raw[:, i] * self.global_var_samples[i]
2443-
)
2444-
else:
2445-
variance_pred = (
2446-
variance_pred_raw * self.sigma2_init * self.y_std * self.y_std
2447-
)
2448-
2449-
return variance_pred
2450-
24512235
def predict(
24522236
self,
24532237
X: np.array,
@@ -2457,6 +2241,7 @@ def predict(
24572241
rfx_basis: np.array = None,
24582242
type: str = "posterior",
24592243
terms: Union[list[str], str] = "all",
2244+
scale: str = "linear"
24602245
) -> dict:
24612246
"""Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation.
24622247
Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function.
@@ -2473,16 +2258,29 @@ def predict(
24732258
Optional group labels used for an additive random effects model.
24742259
rfx_basis : np.array, optional
24752260
Optional basis for "random-slope" regression in an additive random effects model.
2476-
24772261
type : str, optional
24782262
Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".
24792263
terms : str, optional
24802264
Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all".
2265+
scale : str, optional
2266+
Scale on which to return predictions. Options are "linear" (the default), which returns predictions on the original outcome scale, and "probit", which returns predictions on the probit (latent) scale. Only applicable for models fit with `probit_outcome_model=True`.
24812267
24822268
Returns
24832269
-------
24842270
Dict of numpy arrays for each prediction term, or a simple numpy array if a single term is requested.
24852271
"""
2272+
# Handle mean function scale
2273+
if not isinstance(scale, str):
2274+
raise ValueError("scale must be a string")
2275+
if scale not in ["linear", "probability"]:
2276+
raise ValueError("scale must either be 'linear' or 'probability'")
2277+
is_probit = self.probit_outcome_model
2278+
if (scale == "probability") and (not is_probit):
2279+
raise ValueError(
2280+
"scale cannot be 'probability' for models not fit with a probit outcome model"
2281+
)
2282+
probability_scale = scale == "probability"
2283+
24862284
# Handle prediction type
24872285
if not isinstance(type, str):
24882286
raise ValueError("type must be a string")

0 commit comments

Comments
 (0)