Skip to content
11 changes: 6 additions & 5 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module DynamicPPLMCMCChainsExt

using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random
using BangBang: setindex!!
using MCMCChains: MCMCChains

function getindex_varname(
Expand Down Expand Up @@ -95,11 +96,11 @@ function AbstractMCMC.to_samples(
idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
# Get parameters
params_matrix = map(idxs) do (sample_idx, chain_idx)
d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}()
vnt = DynamicPPL.VarNamedTuple()
for vn in get_varnames(chain)
d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx)
vnt = setindex!!(vnt, getindex_varname(chain, sample_idx, vn, chain_idx), vn)
end
d
vnt
end
# Statistics
stats_matrix = if :internals in MCMCChains.sections(chain)
Expand Down Expand Up @@ -164,8 +165,8 @@ end
fallback=nothing,
)

Re-evaluate `model` for each sample in `chain` using the accumulators provided in `at`,
returning an matrix of `(retval, updated_at)` tuples.
Re-evaluate `model` for each sample in `chain` using the accumulators provided in `accs`,
returning a matrix of `(retval, updated_at)` tuples.

This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the
initialisation strategy when re-evaluating the model. For many usecases the fallback should
Expand Down
2 changes: 1 addition & 1 deletion src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ A struct which contains parameter values extracted from a `VarInfo`, along with
statistics associated with the VarInfo. The statistics are provided as a NamedTuple and are
optional.
"""
struct ParamsWithStats{P<:OrderedDict{<:VarName,<:Any},S<:NamedTuple}
struct ParamsWithStats{P<:Union{OrderedDict{<:VarName,<:Any},VarNamedTuple},S<:NamedTuple}
params::P
stats::S
end
Expand Down
4 changes: 2 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -434,14 +434,14 @@ end

function generate_assign(left, right)
# A statement `x := y` reduces to `x = y`, but if __varinfo__ has an accumulator for
# ValuesAsInModel then in addition we push! the pair of `x` and `y` to the accumulator.
# ValuesAsInModel then in addition we push!! the pair of `x` and `y` to the accumulator.
@gensym acc right_val vn
return quote
$right_val = $right
if $(DynamicPPL.is_extracting_values)(__varinfo__)
$vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left)))
__varinfo__ = $(map_accumulator!!)(
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
$acc -> push!!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
)
end
$left = $right_val
Expand Down
2 changes: 1 addition & 1 deletion src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ InitFromParams(params) = InitFromParams(params, InitFromPrior())

function init(
rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P}
) where {P<:Union{AbstractDict{<:VarName},NamedTuple}}
) where {P<:Union{AbstractDict{<:VarName},NamedTuple,VarNamedTuple}}
# TODO(penelopeysm): It would be nice to do a check to make sure that all
# of the parameters in `p.params` were actually used, and either warn or
# error if they aren't. This is actually quite non-trivial though because
Expand Down
32 changes: 30 additions & 2 deletions src/test_utils/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,26 @@
#
# Some additionally contain an implementation of `rand_prior_true`.

"""
varnames(model::Model)

Return the VarNames defined in `model`, as a Vector.
"""
function varnames end

# TODO(mhauru) The fact that the below function exists is a sign that we are inconsistent in
# how we handle IndexLenses. This should hopefully be resolved once we consistently use
# VarNamedTuple rather than dictionaries everywhere.
"""
varnames_split(model::Model)

Return the VarNames in `model`, with any ranges or colons split into individual indices.

The default implementation is to just return `varnames(model)`. If something else is needed,
this should be defined separately.
"""
varnames_split(model::Model) = varnames(model)

"""
demo_dynamic_constraint()

Expand Down Expand Up @@ -77,6 +97,9 @@ end
function varnames(model::Model{typeof(demo_one_variable_multiple_constraints)})
return [@varname(x[1]), @varname(x[2]), @varname(x[3]), @varname(x[4:5])]
end
function varnames_split(model::Model{typeof(demo_one_variable_multiple_constraints)})
return [@varname(x[1]), @varname(x[2]), @varname(x[3]), @varname(x[4]), @varname(x[5])]
end
function logprior_true_with_logabsdet_jacobian(
model::Model{typeof(demo_one_variable_multiple_constraints)}, x
)
Expand Down Expand Up @@ -624,8 +647,13 @@ function varnames(::Model{typeof(demo_nested_colons)})
AbstractPPL.ConcretizedSlice(Base.Slice(Base.OneTo(2))),
]
),
# @varname(s.params[1].subparams[1,1,1]),
# @varname(s.params[1].subparams[1,1,2]),
@varname(m),
]
end
function varnames_split(::Model{typeof(demo_nested_colons)})
return [
@varname(s.params[1].subparams[1, 1, 1]),
@varname(s.params[1].subparams[1, 1, 2]),
@varname(m),
]
end
Expand Down
20 changes: 14 additions & 6 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ wants to extract the realization of a model in a constrained space.
# Fields
$(TYPEDFIELDS)
"""
struct ValuesAsInModelAccumulator <: AbstractAccumulator
struct ValuesAsInModelAccumulator{VNT<:VarNamedTuple} <: AbstractAccumulator
"values that are extracted from the model"
values::OrderedDict{<:VarName}
values::VNT
"whether to extract variables on the LHS of :="
include_colon_eq::Bool
end
function ValuesAsInModelAccumulator(include_colon_eq)
return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq)
return ValuesAsInModelAccumulator(VarNamedTuple(), include_colon_eq)
end

function Base.:(==)(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
Expand All @@ -30,6 +30,9 @@ end

accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel

# TODO(mhauru) We could start using reset!!, which could call empty!! on the VarNamedTuple.
# This would create VarNamedTuples that share memory with the original one, saving
# allocations but also making them not capable of taking in any arbitrary VarName.
function _zero(acc::ValuesAsInModelAccumulator)
return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq)
end
Expand All @@ -45,8 +48,11 @@ function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumula
)
end

function Base.push!(acc::ValuesAsInModelAccumulator, vn::VarName, val)
setindex!(acc.values, deepcopy(val), vn)
function BangBang.push!!(acc::ValuesAsInModelAccumulator, vn::VarName, val)
# TODO(mhauru) The deepcopy here is quite unfortunate. It is needed so that the model
# body can go mutating the object without that reactively affecting the value in the
# accumulator, which should be as it was at `~` time. Could there be a way around this?
Accessors.@reset acc.values = setindex!!(acc.values, deepcopy(val), vn)
return acc
end

Expand All @@ -56,7 +62,7 @@ function is_extracting_values(vi::AbstractVarInfo)
end

function accumulate_assume!!(acc::ValuesAsInModelAccumulator, val, logjac, vn, right)
return push!(acc, vn, val)
return push!!(acc, vn, val)
end

accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc
Expand All @@ -75,6 +81,8 @@ working in unconstrained space.
Hence this method is a "safe" way of obtaining realizations in constrained
space at the cost of additional model evaluations.

Returns a `VarNamedTuple`.

# Arguments
- `model::Model`: model to extract realizations from.
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
Expand Down
Loading
Loading