Skip to content

Commit a35d0e2

Browse files
committed
Implement DynamicPPL.rand_with_logpdf
1 parent 6266f64 commit a35d0e2

File tree

6 files changed

+199
-1
lines changed

6 files changed

+199
-1
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.39.5
4+
5+
Introduces a new function, `DynamicPPL.rand_with_logpdf([rng,] ldf[, strategy])`, which allows generating new trial parameter values from a `LogDensityFunction` (previously this would have been accomplished using the `ldf.varinfo` object, but since v0.39 there is no longer a `VarInfo` inside a `LogDensityFunction`, so this function is a direct replacement).
6+
37
## 0.39.4
48

59
Removed the internal functions `DynamicPPL.getranges`, `DynamicPPL.vector_getrange`, and `DynamicPPL.vector_getranges` (the new LogDensityFunction construction does exactly the same thing, so this specialised function was not needed).

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.39.4"
3+
version = "0.39.5"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ Internally, this is accomplished using [`init!!`](@ref) on:
8080
OnlyAccsVarInfo
8181
```
8282

83+
When given a `LogDensityFunction` (and only a `LogDensityFunction`!) it is often useful to be able to sample new parameters from the prior of the model, for example, when searching for initial points for MCMC sampling.
84+
This can be done with:
85+
86+
```@docs
87+
rand_with_logpdf
88+
```
89+
8390
## Condition and decondition
8491

8592
A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ export AbstractVarInfo,
100100
# LogDensityFunction
101101
LogDensityFunction,
102102
OnlyAccsVarInfo,
103+
rand_with_logpdf,
103104
# Leaf contexts
104105
AbstractContext,
105106
contextualize,

src/logdensityfunction.jl

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,166 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int)
388388
end
389389
return all_iden_ranges, all_ranges, offset
390390
end
391+
392+
####################################################
393+
# Generate new parameters for a LogDensityFunction #
394+
####################################################
395+
# Previously, when LogDensityFunction contained a full VarInfo, it was easy to generate
396+
# new 'trial' parameters for a LogDensityFunction by doing
397+
#
398+
# new_vi = last(DynamicPPL.init!!(rng, ldf.model, ldf.varinfo, strategy))
399+
#
400+
# This is useful e.g. when initialising MCMC sampling.
401+
#
402+
# However, now that LogDensityFunction only contains ranges and link status, we need to
403+
# provide some functionality to generate new parameter vectors (and also return their
404+
# logp).
405+
406+
struct LDFValuesAccumulator{T<:Real,N<:NamedTuple} <: AbstractAccumulator
407+
# These are copied over from the LogDensityFunction
408+
iden_varname_ranges::N
409+
varname_ranges::Dict{VarName,RangeAndLinked}
410+
# These are the actual outputs
411+
values::Dict{UnitRange{Int},Vector{T}}
412+
# This is the forward log-Jacobian term
413+
fwd_logjac::T
414+
end
415+
function LDFValuesAccumulator(ldf::LogDensityFunction)
416+
nt = ldf._iden_varname_ranges
417+
T = eltype(_get_input_vector_type(ldf))
418+
return LDFValuesAccumulator{T,typeof(nt)}(
419+
nt, ldf._varname_ranges, Dict{UnitRange{Int},Vector{T}}(), zero(T)
420+
)
421+
end
422+
423+
const _LDFValuesAccName = :LDFValues
424+
accumulator_name(::Type{<:LDFValuesAccumulator}) = _LDFValuesAccName
425+
accumulate_observe!!(acc::LDFValuesAccumulator, dist, val, vn) = acc
426+
function accumulate_assume!!(acc::LDFValuesAccumulator, val, logjac, vn::VarName, dist)
427+
ral = if DynamicPPL.getoptic(vn) === identity
428+
acc.iden_varname_ranges[DynamicPPL.getsym(vn)]
429+
else
430+
acc.varname_ranges[vn]
431+
end
432+
range = ral.range
433+
# Since `val` is always unlinked, we need to regenerate the vectorised value. This is a
434+
# bit wasteful if `tilde_assume!!` also did the invlinking, but in general, this is not
435+
# guaranteed: indeed, `tilde_assume!!` may never have seen a linked vector at all (e.g.
436+
# if it was called with `InitContext{rng,<:Union{InitFromPrior,InitFromUniform}}`; which
437+
# is the most likely situation where this accumulator will be used).
438+
y, fwd_logjac = if ral.is_linked
439+
with_logabsdet_jacobian(DynamicPPL.to_linked_vec_transform(dist), val)
440+
else
441+
with_logabsdet_jacobian(DynamicPPL.to_vec_transform(dist), val)
442+
end
443+
acc.values[range] = y
444+
return LDFValuesAccumulator(
445+
acc.iden_varname_ranges, acc.varname_ranges, acc.values, acc.fwd_logjac + fwd_logjac
446+
)
447+
end
448+
function reset(acc::LDFValuesAccumulator{T}) where {T}
449+
return LDFValuesAccumulator(
450+
acc.iden_varname_ranges,
451+
acc.varname_ranges,
452+
Dict{UnitRange{Int},Vector{T}}(),
453+
zero(T),
454+
)
455+
end
456+
function Base.copy(acc::LDFValuesAccumulator)
457+
return LDFValuesAccumulator(
458+
acc.iden_varname_ranges, copy(acc.varname_ranges), copy(acc.values), acc.fwd_logjac
459+
)
460+
end
461+
function split(acc::LDFValuesAccumulator{T}) where {T}
462+
return LDFValuesAccumulator(
463+
acc.iden_varname_ranges,
464+
acc.varname_ranges,
465+
Dict{UnitRange{Int},Vector{T}}(),
466+
zero(T),
467+
)
468+
end
469+
function combine(acc::LDFValuesAccumulator{T}, acc2::LDFValuesAccumulator{T}) where {T}
470+
if acc.iden_varname_ranges != acc2.iden_varname_ranges ||
471+
acc.varname_ranges != acc2.varname_ranges
472+
throw(
473+
ArgumentError(
474+
"cannot combine LDFValuesAccumulators with different varname ranges"
475+
),
476+
)
477+
end
478+
combined_values = merge(acc.values, acc2.values)
479+
combined_logjac = acc.fwd_logjac + acc2.fwd_logjac
480+
return LDFValuesAccumulator(
481+
acc.iden_varname_ranges, acc.varname_ranges, combined_values, combined_logjac
482+
)
483+
end
484+
485+
"""
486+
DynamicPPL.rand_with_logpdf(
487+
[rng::Random.AbstractRNG,]
488+
ldf::LogDensityFunction,
489+
strategy::AbstractInitStrategy=InitFromPrior(),
490+
)
491+
492+
Given a LogDensityFunction, generate a new parameter vector by sampling from the model using
493+
the given strategy. Returns a tuple of (new parameters, logpdf).
494+
495+
If `ldf` was generated using the call `LogDensityFunction(model, getlogdensity, vi)`, then
496+
the outputs of
497+
498+
```julia
499+
new_params, new_logp = rand_with_logpdf(rng, ldf, strategy)
500+
```
501+
502+
and
503+
504+
```julia
505+
_, new_vi = DynamicPPL.init!!(rng, model, vi, strategy)
506+
```
507+
508+
are guaranteed to be related in that
509+
510+
```julia
511+
new_params ≈ new_vi[:]
512+
new_logp = getlogdensity(new_vi)
513+
```
514+
515+
Furthermore, it is guaranteed that
516+
517+
```julia
518+
LogDensityProblems.logdensity(ldf, new_params) ≈ new_logp
519+
```
520+
521+
(but this function is more efficient, as it only performs one model evaluation to generate
522+
both parameters and log-density).
523+
524+
This function therefore provides an interface to sample from the model, even though the
525+
LogDensityFunction no longer carries a full VarInfo with it which would ordinarily be
526+
required for such sampling.
527+
"""
528+
function rand_with_logpdf(
529+
rng::Random.AbstractRNG,
530+
ldf::LogDensityFunction,
531+
strategy::AbstractInitStrategy=InitFromPrior(),
532+
)
533+
# Create a VarInfo with the necessary accumulators to generate both parameters and logp
534+
accs = (ldf_accs(ldf._getlogdensity)..., LDFValuesAccumulator(ldf))
535+
vi = OnlyAccsVarInfo(accs)
536+
# Initialise the model with the given strategy
537+
_, new_vi = DynamicPPL.init!!(rng, ldf.model, vi, strategy)
538+
# Extract the new parameters into a vector
539+
x = Vector{eltype(_get_input_vector_type(ldf))}(
540+
undef, LogDensityProblems.dimension(ldf)
541+
)
542+
values_acc = DynamicPPL.getacc(new_vi, Val(_LDFValuesAccName))
543+
for (range, val) in values_acc.values
544+
x[range] = val
545+
end
546+
lp = ldf._getlogdensity(new_vi) - values_acc.fwd_logjac
547+
return x, lp
548+
end
549+
function rand_with_logpdf(
550+
ldf::LogDensityFunction, strategy::AbstractInitStrategy=InitFromPrior()
551+
)
552+
return rand_with_logpdf(Random.default_rng(), ldf, strategy)
553+
end

test/logdensityfunction.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using DistributionsAD: filldist
88
using ADTypes
99
using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
1010
using LinearAlgebra: I
11+
using Random: Xoshiro
1112
using Test
1213
using LogDensityProblems: LogDensityProblems
1314

@@ -205,6 +206,28 @@ end
205206
end
206207
end
207208

209+
@testset "rand_with_logpdf" begin
210+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
211+
@testset for linked in (false, true)
212+
vi = if linked
213+
DynamicPPL.link!!(VarInfo(m), m)
214+
else
215+
VarInfo(m)
216+
end
217+
ldf = LogDensityFunction(m, getlogjoint_internal, vi)
218+
@testset for strategy in (InitFromPrior(), InitFromUniform())
219+
new_params, new_logp = DynamicPPL.rand_with_logpdf(
220+
Xoshiro(468), ldf, strategy
221+
)
222+
_, new_vi = DynamicPPL.init!!(Xoshiro(468), m, vi, strategy)
223+
@test new_params new_vi[:]
224+
@test new_logp getlogjoint_internal(new_vi)
225+
@test LogDensityProblems.logdensity(ldf, new_params) new_logp
226+
end
227+
end
228+
end
229+
end
230+
208231
# Test that various different ways of specifying array types as arguments work with all
209232
# ADTypes.
210233
@testset "Array argument types" begin

0 commit comments

Comments
 (0)