
Commit b59947a

Merge remote-tracking branch 'origin/main' into breaking
2 parents ea9bb54 + 0299d7d

File tree

14 files changed: +546 −43 lines changed

HISTORY.md

Lines changed: 20 additions & 4 deletions
@@ -8,16 +8,32 @@ This is because the interface functions required have been shifted upstream to A
 
 In particular, you now only need to define the following functions:
 
-- AbstractMCMC.step(rng::Random.AbstractRNG, model::AbstractMCMC.LogDensityModel, ::MySampler; kwargs...) (and also a method with `state`, and the corresponding `step_warmup` methods if needed)
-- AbstractMCMC.getparams(::MySamplerState) -> Vector{<:Real}
-- AbstractMCMC.getstats(::MySamplerState) -> NamedTuple
-- AbstractMCMC.requires_unconstrained_space(::MySampler) -> Bool (default `true`)
+- `AbstractMCMC.step(rng::Random.AbstractRNG, model::AbstractMCMC.LogDensityModel, ::MySampler; kwargs...)` (and also a method with `state`, and the corresponding `step_warmup` methods if needed)
+- `AbstractMCMC.getparams(::MySamplerState)` -> Vector{<:Real}
+- `AbstractMCMC.getstats(::MySamplerState)` -> NamedTuple
+- `AbstractMCMC.requires_unconstrained_space(::MySampler)` -> Bool (default `true`)
 
 This means that you only need to depend on AbstractMCMC.jl.
 As long as the above functions are defined correctly, Turing will be able to use your external sampler.
 
 The `Turing.Inference.isgibbscomponent(::MySampler)` interface function still exists, but in this version the default has been changed to `true`, so you should not need to overload this.
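To illustrate, here is a minimal sketch of a hypothetical external sampler implementing the interface listed above. `MySampler` and `MySamplerState` are placeholder names, and the random-walk "proposal" is a stand-in rather than a correct MCMC kernel (its draws do not target the model's posterior):

```julia
using AbstractMCMC, LogDensityProblems, Random

struct MySampler <: AbstractMCMC.AbstractSampler end
struct MySamplerState
    params::Vector{Float64}
end

# First step: no previous state is available yet.
function AbstractMCMC.step(
    rng::Random.AbstractRNG, model::AbstractMCMC.LogDensityModel, ::MySampler; kwargs...
)
    d = LogDensityProblems.dimension(model.logdensity)
    state = MySamplerState(randn(rng, d))
    return state, state
end

# Subsequent steps: receive and update the previous state. A real sampler would
# evaluate `LogDensityProblems.logdensity(model.logdensity, ...)` here; this
# placeholder merely jitters the parameters.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    ::MySampler,
    state::MySamplerState;
    kwargs...,
)
    newstate = MySamplerState(state.params .+ 0.1 .* randn(rng, length(state.params)))
    return newstate, newstate
end

AbstractMCMC.getparams(state::MySamplerState) = state.params
AbstractMCMC.getstats(::MySamplerState) = (;)
AbstractMCMC.requires_unconstrained_space(::MySampler) = true
```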

+# 0.41.4
+
+Fixed a bug where the `check_model=false` keyword argument would not be respected when sampling with multiple threads or cores.
+
+# 0.41.3
+
+Fixed NUTS not correctly specifying the number of adaptation steps when calling `AdvancedHMC.initialize!` (this bug led to mass matrix adaptation not actually happening).
+
+# 0.41.2
+
+Add `GibbsConditional`, a "sampler" that can be used to provide analytically known conditional posteriors in a Gibbs sampler.
+
+In Gibbs sampling, some variables are sampled with a component sampler while holding the other variables conditioned to their current values. Usually one takes turns sampling, e.g., one variable with HMC and another with a particle sampler. However, sometimes the posterior distribution of one variable is known analytically, given the conditioned values of the other variables. `GibbsConditional` provides a way to implement these analytically known conditional posteriors and use them as component samplers for Gibbs. See the docstring of `GibbsConditional` for details.
+
+Note that `GibbsConditional` used to exist in Turing.jl until v0.36, at which point it was removed when the whole Gibbs sampler was rewritten. This reintroduces the same functionality, though with a slightly different interface.
+
 # 0.41.1
 
 The `ModeResult` struct returned by `maximum_a_posteriori` and `maximum_likelihood` can now be wrapped in `InitFromParams()`.

README.md

Lines changed: 26 additions & 17 deletions
@@ -1,6 +1,10 @@
-<p align="center"><img src="https://raw.githubusercontent.com/TuringLang/turinglang.github.io/refs/heads/main/assets/logo/turing-logo.svg" alt="Turing.jl logo" width="200" /></p>
-<h1 align="center">Turing.jl</h1>
-<p align="center"><i>Probabilistic programming and Bayesian inference in Julia</i></p>
+<p align="center">
+  <picture>
+    <source media="(prefers-color-scheme: dark)" srcset="https://turinglang.org/assets/logo/turing-logo-dark.svg">
+    <img src="https://turinglang.org/assets/logo/turing-logo-light.svg" alt="Turing.jl logo" width="300">
+  </picture>
+</p>
+<p align="center"><i>Bayesian inference with probabilistic programming</i></p>
 <p align="center">
 <a href="https://turinglang.org/"><img src="https://img.shields.io/badge/docs-tutorials-blue.svg" alt="Tutorials" /></a>
 <a href="https://turinglang.org/Turing.jl/stable"><img src="https://img.shields.io/badge/docs-API-blue.svg" alt="API docs" /></a>
@@ -9,9 +13,9 @@
 <a href="https://github.com/SciML/ColPrac"><img src="https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet" alt="ColPrac: Contributor's Guide on Collaborative Practices for Community Packages" /></a>
 </p>
 
-## 🚀 Get started
+## Get started
 
-Install Julia (see [the official Julia website](https://julialang.org/install/); you will need at least Julia 1.10 for the latest version of Turing.jl).
+Install Julia (see [the official Julia website](https://julialang.org/install/); you will need at least Julia 1.10.8 for the latest version of Turing.jl).
 Then, launch a Julia REPL and run:
 
 ```julia
@@ -23,22 +27,29 @@ You can define models using the `@model` macro, and then perform Markov chain Mo
 ```julia
 julia> using Turing
 
-julia> @model function my_first_model(data)
-           mean ~ Normal(0, 1)
-           sd ~ truncated(Cauchy(0, 3); lower=0)
-           data ~ Normal(mean, sd)
+julia> @model function linear_regression(x)
+           # Priors
+           α ~ Normal(0, 1)
+           β ~ Normal(0, 1)
+           σ² ~ truncated(Cauchy(0, 3); lower=0)
+
+           # Likelihood
+           μ = α .+ β .* x
+           y ~ MvNormal(μ, σ² * I)
        end
 
-julia> model = my_first_model(randn())
+julia> x, y = rand(10), rand(10)
 
-julia> chain = sample(model, NUTS(), 1000)
+julia> posterior = linear_regression(x) | (; y = y)
+
+julia> chain = sample(posterior, NUTS(), 1000)
 ```
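A hypothetical follow-up to the snippet above, assuming MCMCChains' documented accessors (`summarystats` and symbol indexing), which `using Turing` normally makes available:

```julia
julia> summarystats(chain)  # mean, std, ESS, and R-hat for α, β, σ²

julia> mean(chain[:α])  # posterior mean of a single parameter
```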

 You can find the main TuringLang documentation at [**https://turinglang.org**](https://turinglang.org), which contains general information about Turing.jl's features, as well as a variety of tutorials with examples of Turing.jl models.
 
 API documentation for Turing.jl is specifically available at [**https://turinglang.org/Turing.jl/stable**](https://turinglang.org/Turing.jl/stable/).
 
-## 🛠️ Contributing
+## Contributing
 
 ### Issues

@@ -55,20 +66,20 @@ Breaking releases (minor version) should target the `breaking` branch.
 
 If you have not received any feedback on an issue or PR for a while, please feel free to ping `@TuringLang/maintainers` in a comment.
 
-## 💬 Other channels
+## Other channels
 
 The Turing.jl userbase tends to be most active on the [`#turing` channel of Julia Slack](https://julialang.slack.com/archives/CCYDC34A0).
 If you do not have an invitation to Julia's Slack, you can get one from [the official Julia website](https://julialang.org/slack/).
 
 There are also often threads on [Julia Discourse](https://discourse.julialang.org) (you can search using, e.g., [the `turing` tag](https://discourse.julialang.org/tag/turing)).
 
-## 🔄 What's changed recently?
+## What's changed recently?
 
 We publish a fortnightly newsletter summarising recent updates in the TuringLang ecosystem, which you can view on [our website](https://turinglang.org/news/), [GitHub](https://github.com/TuringLang/Turing.jl/issues/2498), or [Julia Slack](https://julialang.slack.com/archives/CCYDC34A0).
 
 For Turing.jl specifically, you can see a full changelog in [`HISTORY.md`](https://github.com/TuringLang/Turing.jl/blob/main/HISTORY.md) or [our GitHub releases](https://github.com/TuringLang/Turing.jl/releases).
 
-## 🧩 Where does Turing.jl sit in the TuringLang ecosystem?
+## Where does Turing.jl sit in the TuringLang ecosystem?
 
 Turing.jl is the main entry point for users, and seeks to provide a unified, convenient interface to all of the functionality in the TuringLang (and broader Julia) ecosystem.
 
@@ -125,5 +136,3 @@ month = feb,
 ```
 
 </details>
-
-You can see the full list of publications that have cited Turing.jl on [Google Scholar](https://scholar.google.com/scholar?cites=11803241473159708991).

docs/src/api.md

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
 | `Emcee` | [`Turing.Inference.Emcee`](@ref) | Affine-invariant ensemble sampler |
 | `ESS` | [`Turing.Inference.ESS`](@ref) | Elliptical slice sampling |
 | `Gibbs` | [`Turing.Inference.Gibbs`](@ref) | Gibbs sampling |
+| `GibbsConditional` | [`Turing.Inference.GibbsConditional`](@ref) | Gibbs sampling with analytical conditional posterior distributions |
 | `HMC` | [`Turing.Inference.HMC`](@ref) | Hamiltonian Monte Carlo |
 | `SGLD` | [`Turing.Inference.SGLD`](@ref) | Stochastic gradient Langevin dynamics |
 | `SGHMC` | [`Turing.Inference.SGHMC`](@ref) | Stochastic gradient Hamiltonian Monte Carlo |

src/Turing.jl

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ export
     Emcee,
     ESS,
     Gibbs,
+    GibbsConditional,
     HMC,
     SGLD,
     SGHMC,

src/mcmc/Inference.jl

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,7 @@ export Hamiltonian,
     ESS,
     Emcee,
     Gibbs, # classic sampling
+    GibbsConditional, # conditional sampling
     HMC,
     SGLD,
     PolynomialStepsize,
@@ -433,6 +434,7 @@ include("sghmc.jl")
 include("emcee.jl")
 include("prior.jl")
 include("gibbs.jl")
+include("gibbs_conditional.jl")
 
 ################
 # Typing tools #

src/mcmc/abstractmcmc.jl

Lines changed: 1 addition & 0 deletions
@@ -131,6 +131,7 @@ function AbstractMCMC.sample(
         N,
         n_chains;
         chain_type,
+        check_model=false, # no need to check again
         initial_params=map(_convert_initial_params, initial_params),
         kwargs...,
     )
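This one-line change implements the 0.41.4 fix noted in HISTORY.md: the inner per-chain `sample` calls no longer re-run the model check, so a user-supplied `check_model=false` is now respected. A minimal sketch of the user-facing call this affects (the model here is a hypothetical placeholder):

```julia
using Turing

@model function demo()
    x ~ Normal()
end

# Previously, `check_model=false` was ignored when sampling multiple chains;
# with this change the model check is skipped as requested.
chain = sample(demo(), NUTS(), MCMCThreads(), 1000, 4; check_model=false)
```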

src/mcmc/gibbs_conditional.jl

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
+"""
+    GibbsConditional(get_cond_dists)
+
+A Gibbs component sampler that samples variables according to user-provided analytical
+conditional posterior distributions.
+
+When using Gibbs sampling, sometimes one may know the analytical form of the posterior for
+a given variable, given the conditioned values of the other variables. In such cases one
+can use `GibbsConditional` as a component sampler to sample from these known conditionals
+directly, avoiding any MCMC methods. One does so with
+
+```julia
+sampler = Gibbs(
+    (@varname(var1), @varname(var2)) => GibbsConditional(get_cond_dists),
+    # other samplers go here...
+)
+```
+
+Here `get_cond_dists(c::Dict{<:VarName})` should be a function that takes a `Dict` mapping
+the conditioned variables (anything other than `var1` and `var2`) to their values, and
+returns the conditional posterior distributions for `var1` and `var2`. You may, of course,
+have any number of variables being sampled as a block in this manner; we only use two as an
+example. The return value of `get_cond_dists` should be one of the following:
+- A single `Distribution`, if only one variable is being sampled.
+- An `AbstractDict{<:VarName,<:Distribution}` that maps the variables being sampled to
+  their conditional posteriors, e.g. `Dict(@varname(var1) => dist1, @varname(var2) => dist2)`.
+- A `NamedTuple` of `Distribution`s, which is like the `AbstractDict` case but can be used
+  if all the variable names are single `Symbol`s, and may be more performant, e.g.
+  `(; var1=dist1, var2=dist2)`.
+
+# Examples
+
+```julia
+# Define a model
+@model function inverse_gdemo(x)
+    precision ~ Gamma(2, inv(3))
+    std = sqrt(1 / precision)
+    m ~ Normal(0, std)
+    for i in eachindex(x)
+        x[i] ~ Normal(m, std)
+    end
+end
+
+# Define analytical conditionals. See
+# https://en.wikipedia.org/wiki/Conjugate_prior#When_likelihood_function_is_a_continuous_distribution
+function cond_precision(c)
+    a = 2.0
+    b = 3.0
+    # We use AbstractPPL.getvalue instead of indexing into `c` directly to guard against
+    # issues where e.g. you try to get `c[@varname(x[1])]` but only `@varname(x)` is present
+    # in `c`. `getvalue` handles that gracefully, `getindex` doesn't. In this case
+    # `getindex` would suffice, but `getvalue` is good practice.
+    m = AbstractPPL.getvalue(c, @varname(m))
+    x = AbstractPPL.getvalue(c, @varname(x))
+    n = length(x)
+    a_new = a + (n + 1) / 2
+    b_new = b + sum(abs2, x .- m) / 2 + m^2 / 2
+    return Gamma(a_new, 1 / b_new)
+end
+
+function cond_m(c)
+    precision = AbstractPPL.getvalue(c, @varname(precision))
+    x = AbstractPPL.getvalue(c, @varname(x))
+    n = length(x)
+    m_mean = sum(x) / (n + 1)
+    m_var = 1 / (precision * (n + 1))
+    return Normal(m_mean, sqrt(m_var))
+end
+
+# Sample using GibbsConditional
+model = inverse_gdemo([1.0, 2.0, 3.0])
+chain = sample(model, Gibbs(
+    :precision => GibbsConditional(cond_precision),
+    :m => GibbsConditional(cond_m)
+), 1000)
+```
+"""
+struct GibbsConditional{C} <: AbstractSampler
+    get_cond_dists::C
+end
+
+isgibbscomponent(::GibbsConditional) = true
+
+"""
+    build_variable_dict(model::DynamicPPL.Model)
+
+Traverse the context stack of `model` and build a `Dict` of all the variable values that are
+set in GibbsContext, ConditionContext, or FixedContext.
+"""
+function build_variable_dict(model::DynamicPPL.Model)
+    context = model.context
+    cond_vals = DynamicPPL.conditioned(context)
+    fixed_vals = DynamicPPL.fixed(context)
+    # TODO(mhauru) Can we avoid invlinking all the time?
+    global_vi = DynamicPPL.invlink(get_gibbs_global_varinfo(context), model)
+    # TODO(mhauru) This creates a lot of Dicts, which are then immediately merged into one.
+    # Also, DynamicPPL.to_varname_dict is known to be inefficient. Make a more efficient
+    # implementation.
+    return merge(
+        DynamicPPL.values_as(global_vi, Dict),
+        DynamicPPL.to_varname_dict(cond_vals),
+        DynamicPPL.to_varname_dict(fixed_vals),
+        DynamicPPL.to_varname_dict(model.args),
+    )
+end
+
+function get_gibbs_global_varinfo(context::DynamicPPL.AbstractContext)
+    return if context isa GibbsContext
+        get_global_varinfo(context)
+    elseif DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent
+        get_gibbs_global_varinfo(DynamicPPL.childcontext(context))
+    else
+        msg = """No GibbsContext found in context stack. Are you trying to use \
+            GibbsConditional outside of Gibbs?
+            """
+        throw(ArgumentError(msg))
+    end
+end
+
+function initialstep(
+    ::Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    ::GibbsConditional,
+    vi::DynamicPPL.AbstractVarInfo;
+    kwargs...,
+)
+    state = DynamicPPL.is_transformed(vi) ? DynamicPPL.invlink(vi, model) : vi
+    # Since GibbsConditional is only used within Gibbs, it does not need to return a
+    # transition.
+    return nothing, state
+end
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    sampler::GibbsConditional,
+    state::DynamicPPL.AbstractVarInfo;
+    kwargs...,
+)
+    # Get all the conditioned variable values from the model context. This is assumed to
+    # include a GibbsContext as part of the context stack.
+    condvals = build_variable_dict(model)
+    conddists = sampler.get_cond_dists(condvals)
+
+    # We support three different kinds of return values for `sampler.get_cond_dists`, to
+    # make life easier for the user.
+    if conddists isa AbstractDict
+        for (vn, dist) in conddists
+            state = setindex!!(state, rand(rng, dist), vn)
+        end
+    elseif conddists isa NamedTuple
+        for (vn_sym, dist) in pairs(conddists)
+            vn = VarName{vn_sym}()
+            state = setindex!!(state, rand(rng, dist), vn)
+        end
+    else
+        # Single variable case.
+        vn = only(keys(state))
+        state = setindex!!(state, rand(rng, conddists), vn)
+    end
+
+    # Since GibbsConditional is only used within Gibbs, it does not need to return a
+    # transition.
+    return nothing, state
+end
+
+function setparams_varinfo!!(
+    ::DynamicPPL.Model, ::GibbsConditional, ::Any, params::DynamicPPL.AbstractVarInfo
+)
+    return params
+end
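For reference (a sketch of the derivation, not part of the source file), the conjugate updates used in `cond_precision` and `cond_m` above follow from standard Normal–Gamma conjugacy. Writing $\tau$ for `precision`, with prior $\tau \sim \mathrm{Gamma}(a, \text{rate } b)$, $m \mid \tau \sim \mathcal{N}(0, \tau^{-1})$ (variance $\tau^{-1}$), and $x_i \mid m, \tau \sim \mathcal{N}(m, \tau^{-1})$ for $i = 1, \dots, n$:

$$
\begin{aligned}
\tau \mid m, x &\sim \mathrm{Gamma}\!\left(a + \frac{n+1}{2},\ \text{rate } b + \frac{1}{2}\sum_{i=1}^{n}(x_i - m)^2 + \frac{m^2}{2}\right), \\
m \mid \tau, x &\sim \mathcal{N}\!\left(\frac{\sum_{i=1}^{n} x_i}{n+1},\ \frac{1}{(n+1)\tau}\right),
\end{aligned}
$$

which with $a = 2$ and $b = 3$ matches the code (Julia's `Gamma` takes a scale parameter, hence `1 / b_new`).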

src/mcmc/hmc.jl

Lines changed: 6 additions & 4 deletions
@@ -223,7 +223,7 @@ function Turing.Inference.initialstep(
     end
     # Generate a kernel and adaptor.
     kernel = make_ahmc_kernel(spl, ϵ)
-    adaptor = AHMCAdaptor(spl, hamiltonian.metric; ϵ=ϵ)
+    adaptor = AHMCAdaptor(spl, hamiltonian.metric, nadapts; ϵ=ϵ)
 
     transition = Transition(model, vi, NamedTuple())
     state = HMCState(vi, 1, kernel, hamiltonian, z, adaptor)
@@ -480,7 +480,9 @@ end
 #### Default HMC stepsize and mass matrix adaptor
 ####
 
-function AHMCAdaptor(alg::AdaptiveHamiltonian, metric::AHMC.AbstractMetric; ϵ=alg.ϵ)
+function AHMCAdaptor(
+    alg::AdaptiveHamiltonian, metric::AHMC.AbstractMetric, nadapts::Int; ϵ=alg.ϵ
+)
     pc = AHMC.MassMatrixAdaptor(metric)
     da = AHMC.StepSizeAdaptor(alg.δ, ϵ)
 
@@ -491,13 +493,13 @@ function AHMCAdaptor(alg::AdaptiveHamiltonian, metric::AHMC.AbstractMetric; ϵ=a
             adaptor = AHMC.NaiveHMCAdaptor(pc, da) # there is actually no adaptation for mass matrix
         else
             adaptor = AHMC.StanHMCAdaptor(pc, da)
-            AHMC.initialize!(adaptor, alg.n_adapts)
+            AHMC.initialize!(adaptor, nadapts)
         end
     end
 
     return adaptor
 end
 
-function AHMCAdaptor(::Hamiltonian, ::AHMC.AbstractMetric; kwargs...)
+function AHMCAdaptor(::Hamiltonian, ::AHMC.AbstractMetric, nadapts::Int; kwargs...)
     return AHMC.Adaptation.NoAdaptation()
 end
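Threading `nadapts` through to `AHMC.initialize!` is the 0.41.3 fix from HISTORY.md: previously the sampler's stored `alg.n_adapts` was used, which could disagree with the number of adaptation steps actually resolved at sampling time, so mass matrix adaptation did not actually happen. A brief sketch of a call that exercises this path, assuming the standard `NUTS(n_adapts, target_acceptance)` constructor (the model is a hypothetical placeholder):

```julia
using Turing

@model function demo()
    x ~ Normal()
end

# NUTS(n_adapts, δ): 500 adaptation steps, 0.65 target acceptance rate.
# With this fix, the 500 steps are forwarded to AHMC.initialize!, so the
# mass matrix is adapted during warm-up as intended.
chain = sample(demo(), NUTS(500, 0.65), 1500)
```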

test/Project.toml

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
