Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 22 additions & 50 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.variable import TensorConstant, TensorVariable

from pymc.distributions.custom import CustomDist
from pymc.logprob.abstract import _logprob_helper
from pymc.logprob.basic import TensorLike, icdf
from pymc.pytensorf import normalize_rng_param
Expand Down Expand Up @@ -92,7 +93,7 @@ def polyagamma_cdf(*args, **kwargs):
from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous, SymbolicRandomVariable
from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
from pymc.distributions.transforms import _default_transform
from pymc.math import invlogit, logdiffexp, logit
from pymc.math import invlogit, logdiffexp

__all__ = [
"AsymmetricLaplace",
Expand Down Expand Up @@ -2531,6 +2532,9 @@ def logcdf(value, alpha, beta):
msg="alpha > 0, beta > 0",
)

def icdf(value, alpha, beta):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't work through the math off the top of my head. I presume this is what test_inverse_gamma_icdf below checks?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

return icdf(1 / Gamma.dist(alpha, beta), value)


class ChiSquared:
r"""
Expand Down Expand Up @@ -3603,28 +3607,7 @@ def icdf(value, mu, s):
)


class LogitNormalRV(SymbolicRandomVariable):
    """Symbolic random variable whose draws are ``expit`` of normal draws."""

    name = "logit_normal"
    extended_signature = "[rng],[size],(),()->[rng],()"
    _print_name = ("LogitNormal", "\\operatorname{LogitNormal}")

    @classmethod
    def rv_op(cls, mu, sigma, *, size=None, rng=None):
        """Build the symbolic graph: normal draws with ``loc=mu`` and ``scale=sigma`` passed through ``pt.expit``.

        The op is constructed with ``[rng, size, mu, sigma]`` as inputs and
        ``[next_rng, draws]`` as outputs, then applied to those same inputs.
        """
        mu, sigma = pt.as_tensor(mu), pt.as_tensor(sigma)
        rng, size = normalize_rng_param(rng), normalize_size_param(size)

        # The normal node exposes both the updated rng and the draws as outputs
        normal_node = normal(loc=mu, scale=sigma, size=size, rng=rng).owner
        next_rng, normal_draws = normal_node.outputs

        symbolic_inputs = [rng, size, mu, sigma]
        op = cls(
            inputs=symbolic_inputs,
            outputs=[next_rng, pt.expit(normal_draws)],
        )
        return op(*symbolic_inputs)


class LogitNormal(UnitContinuous):
class LogitNormal:
r"""
Logit-Normal distribution.
Expand Down Expand Up @@ -3672,37 +3655,26 @@ class LogitNormal(UnitContinuous):
Defaults to 1.
"""

rv_type = LogitNormalRV
rv_op = LogitNormalRV.rv_op
@staticmethod
def logitnormal_dist(mu, sigma, size):
return invlogit(Normal.dist(mu=mu, sigma=sigma, size=size))

@classmethod
def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
def __new__(cls, name, mu=0, sigma=None, tau=None, **kwargs):
_, sigma = get_tau_sigma(tau=tau, sigma=sigma)
return super().dist([mu, sigma], **kwargs)

def support_point(rv, size, mu, sigma):
median, _ = pt.broadcast_arrays(invlogit(mu), sigma)
if not rv_size_is_none(size):
median = pt.full(size, median)
return median

def logp(value, mu, sigma):
tau, _ = get_tau_sigma(sigma=sigma)

res = pt.switch(
pt.or_(pt.le(value, 0), pt.ge(value, 1)),
-np.inf,
(
-0.5 * tau * (logit(value) - mu) ** 2
+ 0.5 * pt.log(tau / (2.0 * np.pi))
- pt.log(value * (1 - value))
),
return CustomDist(
name,
mu,
sigma,
dist=cls.logitnormal_dist,
class_name="LogitNormal",
**kwargs,
)

return check_parameters(
res,
tau > 0,
msg="tau > 0",
@classmethod
def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
_, sigma = get_tau_sigma(tau=tau, sigma=sigma)
return CustomDist.dist(
mu, sigma, dist=cls.logitnormal_dist, class_name="LogitNormal", **kwargs
)


Expand Down
12 changes: 11 additions & 1 deletion pymc/distributions/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def rv_op(
size=None,
signature: str,
class_name: str,
rng=None,
):
size = normalize_size_param(size)
# If it's NoneConst, just use that as the dummy
Expand All @@ -270,7 +271,8 @@ def rv_op(
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
dummy_params = [dummy_size_param, *dummy_dist_params]
# RNGs are not passed as explicit inputs (because we usually don't know how many are needed)
# We retrieve them here. This will also raise if the user forgot to specify some update in a Scan Op
# We retrieve them here. This will also raise if the user forgot to specify some update in an InnerGraphOp (e.g., Scan)
# If the user passed an explicit rng we will respect that later when we instantiate the final rv_op
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))

rv_type = type(
Expand Down Expand Up @@ -357,6 +359,14 @@ def change_custom_dist_size(op, rv, new_size, expand):
outputs=outputs,
extended_signature=extended_signature,
)
if rng is not None:
# User passed an RNG, use that if the graph only required one, raise otherwise
if len(rngs) != 1:
raise ValueError(
f"CustomDist received an explicit rng but it actually requires {len(rngs)} rngs."
" Please modify your dist function to only use one rng, or don't pass an explicitly rng."
)
rngs = (rng,)
return rv_op(size, *dist_params, *rngs)

@staticmethod
Expand Down
6 changes: 0 additions & 6 deletions pymc/distributions/moments/means.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
HalfFlatRV,
HalfStudentTRV,
KumaraswamyRV,
LogitNormalRV,
MoyalRV,
PolyaGammaRV,
RiceRV,
Expand Down Expand Up @@ -290,11 +289,6 @@ def logistic_mean(op, rv, rng, size, mu, s):
return maybe_resize(pt.broadcast_arrays(mu, s)[0], size)


@_mean.register(LogitNormalRV)
def logitnormal_mean(op, rv, rng, size, mu, sigma):
    """Always raise: no mean is provided for ``LogitNormalRV``.

    Raises
    ------
    UndefinedMomentException
        Unconditionally, regardless of the arguments.
    """
    message = "The mean of the LogitNormal distribution is undefined"
    raise UndefinedMomentException(message)


@_mean.register(LogNormalRV)
def lognormal_mean(op, rv, rng, size, mu, sigma):
    """Return ``exp(mu + 0.5 * sigma**2)`` passed through ``maybe_resize`` with ``size``."""
    log_of_mean = mu + 0.5 * sigma**2
    return maybe_resize(pt.exp(log_of_mean), size)
Expand Down
16 changes: 14 additions & 2 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class MeasurableTransform(MeasurableElemwise):
Erf,
Erfc,
Erfcx,
Sigmoid,
)

# Cannot use `transform` as name because it would clash with the property added by
Expand Down Expand Up @@ -227,7 +228,7 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)


MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf)
MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf, Sigmoid)
MONOTONICALLY_DECREASING_OPS = (Erfc, Erfcx)


Expand Down Expand Up @@ -300,7 +301,18 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
value = pt.switch(pt.lt(scale, 0), 1 - value, value)
elif isinstance(op.scalar_op, Pow):
if op.transform_elemwise.power < 0:
raise NotImplementedError
# Note: Negative even powers will be rejected below when inverting the transform
# For the remaining negative powers the function is decreasing with a jump around 0
# We adjust the value with the mass below zero.
# For non-negative RVs with cdf(0)=0, it simplifies to 1 - value
cdf_zero = pt.exp(_logcdf_helper(measurable_input, 0))
# Use nan to not mask invalid values accidentally
value = pt.switch((value >= 0) & (value <= 1), value, np.nan)
value = pt.switch(
(cdf_zero > 0) & (value < cdf_zero),
cdf_zero - value,
1 + cdf_zero - value,
)
else:
raise NotImplementedError

Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/moments/test_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,5 +274,5 @@ def test_mean_equal_expected(dist, dist_params, expected):
],
)
def test_no_mean(dist, dist_params):
with pytest.raises(UndefinedMomentException):
with pytest.raises((UndefinedMomentException, NotImplementedError)):
mean(dist.dist(**dist_params))
13 changes: 13 additions & 0 deletions tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,13 @@ def test_inverse_gamma_logcdf(self):
lambda value, alpha, beta: st.invgamma.logcdf(value, alpha, scale=beta),
)

def test_inverse_gamma_icdf(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some distributions have all functionality tests (logp, logcdf, icdf) under a single method (test_half_cauchy), whereas other distributions have one method per functionality. Is there any particular reason for this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No grand plan. I prefer separate methods, but if a function was already testing several functionalities for that distribution, I just append to it.

check_icdf(
pm.InverseGamma,
{"alpha": Rplusbig, "beta": Rplusbig},
lambda q, alpha, beta: st.invgamma.ppf(q, alpha, scale=beta),
)

@pytest.mark.skipif(
condition=(pytensor.config.floatX == "float32"),
reason="Fails on float32 due to scaling issues",
Expand Down Expand Up @@ -872,6 +879,12 @@ def test_logitnormal(self):
),
decimal=select_by_precision(float64=6, float32=1),
)
check_icdf(
pm.LogitNormal,
{"mu": R, "sigma": Rplus},
lambda q, mu, sigma: sp.expit(mu + sigma * st.norm.ppf(q)),
decimal=select_by_precision(float64=12, float32=5),
)

@pytest.mark.skipif(
condition=(pytensor.config.floatX == "float32"),
Expand Down
44 changes: 44 additions & 0 deletions tests/distributions/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,3 +708,47 @@ def normal_shifted(mu, size):
observed_logp.eval({latent_vv: latent_vv_test, observed_vv: observed_vv_test}),
expected_logp,
)

def test_explicit_rng(self):
    """Check how ``CustomDist.dist`` handles an explicitly passed ``rng``.

    The explicit rng must be the one used when the dist graph needs exactly
    one rng, and a ``ValueError`` must be raised when it needs more than one.
    """

    def rng_inputs(rv):
        # RNG inputs of the op that produced ``rv``
        return rv.owner.op.rng_params(rv.owner)

    def custom_dist(mu, size):
        return Normal.dist(mu, size=size)

    default_rv = CustomDist.dist(0, dist=custom_dist)
    assert len(rng_inputs(default_rv)) == 1  # Rng created by default

    explicit_rng = pt.random.type.random_generator_type("rng")
    explicit_rv = CustomDist.dist(0, dist=custom_dist, rng=explicit_rng)
    (used_rng,) = rng_inputs(explicit_rv)
    assert used_rng is explicit_rng

    # API for passing multiple explicit RNGs not supported
    def custom_dist_multi_rng(mu, size):
        first = Normal.dist(mu, size=size)
        second = Normal.dist(0, size=size)
        return first + second

    default_rv = CustomDist.dist(0, dist=custom_dist_multi_rng)
    assert len(rng_inputs(default_rv)) == 2

    expected_error = "CustomDist received an explicit rng but it actually requires 2 rngs"
    with pytest.raises(ValueError, match=expected_error):
        CustomDist.dist(0, dist=custom_dist_multi_rng, rng=explicit_rng)

    # But it can be done if the custom_dist uses only one RNG internally
    def custom_dist_multi_rng_fixed(mu, size):
        # Feed the rng output of the first draw into the second draw
        next_rng, first = Normal.dist(mu, size=size).owner.outputs
        second = Normal.dist(0, size=size, rng=next_rng)
        return first + second

    default_rv = CustomDist.dist(0, dist=custom_dist_multi_rng_fixed)
    assert len(rng_inputs(default_rv)) == 1

    explicit_rv = CustomDist.dist(0, dist=custom_dist_multi_rng_fixed, rng=explicit_rng)
    (used_rng,) = rng_inputs(explicit_rv)
    assert used_rng is explicit_rng
14 changes: 9 additions & 5 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,7 @@ def test_reciprocal_rv_transform(self, numerator):
x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
x_logcdf_fn = pytensor.function([x_vv], logcdf(x_rv, x_vv))

with pytest.raises(NotImplementedError):
icdf(x_rv, x_vv)
x_icdf_fn = pytensor.function([x_vv], icdf(x_rv, x_vv))

x_test_val = np.r_[-0.5, 1.5]
np.testing.assert_allclose(
Expand All @@ -392,6 +390,10 @@ def test_reciprocal_rv_transform(self, numerator):
x_logcdf_fn(x_test_val),
sp.stats.invgamma(shape, scale=scale * numerator).logcdf(x_test_val),
)
# x_test_val holds -0.5 and 1.5, which are not valid quantiles: the reference ppf
# yields NaN for both and assert_allclose treats NaN == NaN as equal by default,
# so this first check only covers the handling of out-of-range inputs.
np.testing.assert_allclose(
    x_icdf_fn(x_test_val),
    sp.stats.invgamma(shape, scale=scale * numerator).ppf(x_test_val),
)
# Also evaluate at valid quantiles so the inverse itself is compared with the reference.
x_test_quantiles = np.r_[0.25, 0.75]
np.testing.assert_allclose(
    x_icdf_fn(x_test_quantiles),
    sp.stats.invgamma(shape, scale=scale * numerator).ppf(x_test_quantiles),
)

def test_reciprocal_real_rv_transform(self):
# 1 / Cauchy(mu, sigma) = Cauchy(mu / (mu^2 + sigma ^2), sigma / (mu ^ 2, sigma ^ 2))
Expand All @@ -406,8 +408,10 @@ def test_reciprocal_real_rv_transform(self):
logcdf(test_rv, test_value).eval(),
sp.stats.cauchy(1 / 5, 2 / 5).logcdf(test_value),
)
with pytest.raises(NotImplementedError):
icdf(test_rv, test_value)
np.testing.assert_allclose(
icdf(test_rv, test_value).eval(),
sp.stats.cauchy(1 / 5, 2 / 5).ppf(test_value),
)

def test_sqr_transform(self):
# The square of a normal with unit variance is a noncentral chi-square with 1 df and nc = mean ** 2
Expand Down