You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
raiseValueError("type must either be 'custom', 'intercept_only', 'intercept_plus_treatment'")
418
+
ifself.rfx_model_specnotin [
419
+
"custom",
420
+
"intercept_only",
421
+
"intercept_plus_treatment",
422
+
]:
423
+
raiseValueError(
424
+
"type must either be 'custom', 'intercept_only', 'intercept_plus_treatment'"
425
+
)
418
426
419
427
# Override keep_gfr if there are no MCMC samples
420
428
ifnum_mcmc==0:
@@ -2295,7 +2303,7 @@ def predict(
2295
2303
) ->Union[dict[str, np.array], np.array]:
2296
2304
"""Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation.
2297
2305
Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function.
2298
-
When random effects are present, they are either included in yhat additively if `rfx_model_spec == "custom"`. They are included in mu_x if `rfx_model_spec == "intercept_only"` or
2306
+
When random effects are present, they are either included in yhat additively if `rfx_model_spec == "custom"`. They are included in mu_x if `rfx_model_spec == "intercept_only"` or
2299
2307
partially included in mu_x and partially included in tau_x `rfx_model_spec == "intercept_plus_treatment"`.
2300
2308
2301
2309
Parameters
@@ -2508,9 +2516,13 @@ def predict(
2508
2516
ifrfx_basis.ndim==1:
2509
2517
rfx_basis=np.expand_dims(rfx_basis, 1)
2510
2518
ifrfx_basis.shape[0] !=X.shape[0]:
2511
-
raiseValueError("X and rfx_basis must have the same number of rows")
2519
+
raiseValueError(
2520
+
"X and rfx_basis must have the same number of rows"
2521
+
)
2512
2522
ifrfx_basis.shape[1] !=self.num_rfx_basis:
2513
-
raiseValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model")
2523
+
raiseValueError(
2524
+
"rfx_basis must have the same number of columns as the random effects basis used to sample this model"
2525
+
)
2514
2526
2515
2527
# Random effects predictions
2516
2528
ifpredict_rfxorpredict_rfx_intermediate:
@@ -2522,26 +2534,28 @@ def predict(
2522
2534
ifpredict_rfx_raw:
2523
2535
# Extract the raw RFX samples and scale by train set outcome standard deviation
0 commit comments