@@ -303,8 +303,8 @@ class InstrumentalVariableRegression(ModelBuilder):
303303 ... "mus": [[-2,4], [0.5, 3]],
304304 ... "sigmas": [1, 1],
305305 ... "eta": 2,
306- ... "lkj_sd": 2 ,
307- ... })
306+ ... "lkj_sd": 1 ,
307+ ... }, None )
308308 Inference data...
309309 """
310310
@@ -340,7 +340,7 @@ def build_model(self, X, Z, y, t, coords, priors):
340340 sigma = priors ["sigmas" ][1 ],
341341 dims = "covariates" ,
342342 )
343- sd_dist = pm .HalfCauchy .dist (beta = priors ["lkj_sd" ], shape = 2 )
343+ sd_dist = pm .Exponential .dist (priors ["lkj_sd" ], shape = 2 )
344344 chol , corr , sigmas = pm .LKJCholeskyCov (
345345 name = "chol_cov" ,
346346 eta = priors ["eta" ],
@@ -366,24 +366,52 @@ def build_model(self, X, Z, y, t, coords, priors):
366366 shape = (X .shape [0 ], 2 ),
367367 )
368368
369- def fit (self , X , Z , y , t , coords , priors ):
370- """Draw samples from posterior, prior predictive, and posterior predictive
371- distributions.
369+ def sample_predictive_distribution (self , ppc_sampler = "jax" ):
370+ """Function to sample the Multivariate Normal posterior predictive
371+ Likelihood term in the IV class. This can be slow without
372+ using the JAX sampler compilation method. If using the
373+ JAX sampler it will sample only the posterior predictive distribution.
374+ If using the PYMC sampler if will sample both the prior
375+ and posterior predictive distributions."""
376+ random_seed = self .sample_kwargs .get ("random_seed" , None )
377+
378+ if ppc_sampler == "jax" :
379+ with self :
380+ self .idata .extend (
381+ pm .sample_posterior_predictive (
382+ self .idata ,
383+ random_seed = random_seed ,
384+ compile_kwargs = {"mode" : "JAX" },
385+ )
386+ )
387+ elif ppc_sampler == "pymc" :
388+ with self :
389+ self .idata .extend (pm .sample_prior_predictive (random_seed = random_seed ))
390+ self .idata .extend (
391+ pm .sample_posterior_predictive (
392+ self .idata ,
393+ random_seed = random_seed ,
394+ )
395+ )
396+
397+ def fit (self , X , Z , y , t , coords , priors , ppc_sampler = None ):
398+ """Draw samples from posterior distribution and potentially
399+ from the prior and posterior predictive distributions. The
400+ fit call can take values for the
401+ ppc_sampler = ['jax', 'pymc', None]
402+ We default to None, so the user can determine if they wish
403+ to spend time sampling the posterior predictive distribution
404+ independently.
372405 """
373406
374407 # Ensure random_seed is used in sample_prior_predictive() and
375408 # sample_posterior_predictive() if provided in sample_kwargs.
376- random_seed = self . sample_kwargs . get ( "random_seed" , None )
409+ # Use JAX for ppc sampling of multivariate likelihood
377410
378411 self .build_model (X , Z , y , t , coords , priors )
379412 with self :
380413 self .idata = pm .sample (** self .sample_kwargs )
381- self .idata .extend (pm .sample_prior_predictive (random_seed = random_seed ))
382- self .idata .extend (
383- pm .sample_posterior_predictive (
384- self .idata , progressbar = False , random_seed = random_seed
385- )
386- )
414+ self .sample_predictive_distribution (ppc_sampler = ppc_sampler )
387415 return self .idata
388416
389417
0 commit comments