From 502949d810fc75731561ecd06d890c477ae33e4c Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Thu, 16 Oct 2025 14:28:01 +0100 Subject: [PATCH 01/25] Generalise Hamiltonian derivative calls for RHMC case --- src/AdvancedHMC.jl | 4 + src/hamiltonian.jl | 4 +- src/riemannian/hamiltonian.jl | 205 ++-------------------------------- src/riemannian/integrator.jl | 3 + src/riemannian/metric.jl | 66 +++++++++++ src/trajectory.jl | 8 +- 6 files changed, 89 insertions(+), 201 deletions(-) create mode 100644 src/riemannian/metric.jl diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index b25710d5f..dce7b497a 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -47,8 +47,12 @@ export Hamiltonian include("integrator.jl") export Leapfrog, JitteredLeapfrog, TemperedLeapfrog + +include("riemannian/metric.jl") +export AbstractRiemannianMetric, DenseRiemannianMetric, IdentityMap, SoftAbsMap include("riemannian/integrator.jl") export GeneralizedLeapfrog +include("riemannian/hamiltonian.jl") include("trajectory.jl") export Trajectory, diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index c782e1a24..24f744ca6 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -101,7 +101,7 @@ function Base.similar(z::PhasePoint{<:AbstractVecOrMat{T}}) where {T<:AbstractFl end function phasepoint( - h::Hamiltonian, θ::T, r::T; ℓπ=∂H∂θ(h, θ), ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, r)) + h::Hamiltonian, θ::T, r::T; ℓπ=∂H∂θ(h, θ), ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)) ) where {T<:AbstractVecOrMat} return PhasePoint(θ, r, ℓπ, ℓκ) end @@ -115,7 +115,7 @@ function phasepoint( _r::T2; r=safe_rsimilar(θ, _r), ℓπ=∂H∂θ(h, θ), - ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, r)), + ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), ) where {T1<:AbstractVecOrMat,T2<:AbstractVecOrMat} return PhasePoint(θ, r, ℓπ, ℓκ) end diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index 6f051ffbc..56fc01cc9 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -1,118 +1,8 @@ -using Random +import AdvancedHMC: refresh, phasepoint, neg_energy, ∂H∂θ, ∂H∂r +using AdvancedHMC: FullMomentumRefreshment, PartialMomentumRefreshment, DualValue, PhasePoint +using LinearAlgebra: logabsdet, tr, diagm -### integrator.jl - -import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step -using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size - -""" -$(TYPEDEF) - -Generalized leapfrog integrator with fixed step size `ϵ`. - -# Fields - -$(TYPEDFIELDS) -""" -struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T} - "Step size." - ϵ::T - n::Int -end -function Base.show(io::IO, l::GeneralizedLeapfrog) - return print(io, "GeneralizedLeapfrog(ϵ=", round.(l.ϵ; sigdigits=3), ", n=", l.n, ")") -end - -# Fallback to ignore return_cache & cache kwargs for other ∂H∂θ -function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing) where {T} - dv = ∂H∂θ(h, θ, r) - return return_cache ? (dv, nothing) : dv -end - -# TODO Make sure vectorization works -# TODO Check if tempering is valid -function step( - lf::GeneralizedLeapfrog{T}, - h::Hamiltonian, - z::P, - n_steps::Int=1; - fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0 - full_trajectory::Val{FullTraj}=Val(false), -) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj} - n_steps = abs(n_steps) # to support `n_steps < 0` cases - - ϵ = fwd ? 
step_size(lf) : -step_size(lf) - ϵ = ϵ' - - res = if FullTraj - Vector{P}(undef, n_steps) - else - z - end - - for i in 1:n_steps - θ_init, r_init = z.θ, z.r - # Tempering - #r = temper(lf, r, (i=i, is_half=true), n_steps) - #! Eq (16) of Girolami & Calderhead (2011) - r_half = copy(r_init) - local cache - for j in 1:(lf.n) - # Reuse cache for the first iteration - if j == 1 - (; value, gradient) = z.ℓπ - elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged) - retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true) - (; value, gradient) = retval - else # reuse cache - (; value, gradient) = ∂H∂θ_cache(h, θ_init, r_half; cache=cache) - end - r_half = r_init - ϵ / 2 * gradient - # println("r_half: ", r_half) - end - #! Eq (17) of Girolami & Calderhead (2011) - θ_full = copy(θ_init) - term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop - for j in 1:(lf.n) - θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)) - # println("θ_full :", θ_full) - end - #! Eq (18) of Girolami & Calderhead (2011) - (; value, gradient) = ∂H∂θ(h, θ_full, r_half) - r_full = r_half - ϵ / 2 * gradient - # println("r_full: ", r_full) - # Tempering - #r = temper(lf, r, (i=i, is_half=false), n_steps) - # Create a new phase point by caching the logdensity and gradient - z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) - # Update result - if FullTraj - res[i] = z - else - res = z - end - if !isfinite(z) - # Remove undef - if FullTraj - res = res[isassigned.(Ref(res), 1:n_steps)] - end - break - end - # @assert false - end - return res -end - -# TODO Make the order of θ and r consistent with neg_energy -∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ) -∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r) - -### hamiltonian.jl - -import AdvancedHMC: refresh, phasepoint -using AdvancedHMC: FullMomentumRefreshment, PartialMomentumRefreshment, AbstractMetric - -# To change L180 of hamiltonian.jl +# Specialized phasepoint for Riemannian metrics that need θ for momentum gradient function phasepoint( rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, θ::AbstractVecOrMat{T}, @@ -145,87 +35,11 @@ function refresh( ) end -### metric.jl - -import AdvancedHMC: _rand -using AdvancedHMC: AbstractMetric -using LinearAlgebra: eigen, cholesky, Symmetric - -abstract type AbstractRiemannianMetric <: AbstractMetric end - -abstract type AbstractHessianMap end - -struct IdentityMap <: AbstractHessianMap end - -(::IdentityMap)(x) = x +### +### DenseRiemannianMetric-specific Hamiltonian methods +### -struct SoftAbsMap{T} <: AbstractHessianMap - α::T -end - -# TODO Register softabs with ReverseDiff -#! 
The definition of SoftAbs from Page 3 of Betancourt (2012) -function softabs(X, α=20.0) - F = eigen(X) # ReverseDiff cannot diff through `eigen` - Q = hcat(F.vectors) - λ = F.values - softabsλ = λ .* coth.(α * λ) - return Q * diagm(softabsλ) * Q', Q, λ, softabsλ -end - -(map::SoftAbsMap)(x) = softabs(x, map.α)[1] - -struct DenseRiemannianMetric{ - T, - TM<:AbstractHessianMap, - A<:Union{Tuple{Int},Tuple{Int,Int}}, - AV<:AbstractVecOrMat{T}, - TG, - T∂G∂θ, -} <: AbstractRiemannianMetric - size::A - G::TG # TODO store G⁻¹ here instead - ∂G∂θ::T∂G∂θ - map::TM - _temp::AV -end - -# TODO Make dense mass matrix support matrix-mode parallel -function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) where {T<:AbstractFloat} - _temp = Vector{Float64}(undef, size[1]) - return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) -end -# DenseEuclideanMetric(::Type{T}, D::Int) where {T} = DenseEuclideanMetric(Matrix{T}(I, D, D)) -# DenseEuclideanMetric(D::Int) = DenseEuclideanMetric(Float64, D) -# DenseEuclideanMetric(::Type{T}, sz::Tuple{Int}) where {T} = DenseEuclideanMetric(Matrix{T}(I, first(sz), first(sz))) -# DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz) - -# renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹) - -Base.size(e::DenseRiemannianMetric) = e.size -Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] -Base.show(io::IO, dem::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric(...)") - -function rand_momentum( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - metric::DenseRiemannianMetric{T}, - kinetic, - θ::AbstractVecOrMat, -) where {T} - r = _randn(rng, T, size(metric)...) - G⁻¹ = inv(metric.map(metric.G(θ))) - chol = cholesky(Symmetric(G⁻¹)) - ldiv!(chol.U, r) - return r -end - -### hamiltonian.jl - -import AdvancedHMC: phasepoint, neg_energy, ∂H∂θ, ∂H∂r -using LinearAlgebra: logabsdet, tr - -# QUES Do we want to change everything to position dependent by default? 
-# Add θ to ∂H∂r for DenseRiemannianMetric +# Specialized phasepoint for DenseRiemannianMetric that passes θ to ∂H∂r function phasepoint( h::Hamiltonian{<:DenseRiemannianMetric}, θ::T, @@ -249,7 +63,7 @@ function neg_energy( return -logZ - dot(r, h.metric._temp) / 2 end -# QUES L31 of hamiltonian.jl now reads a bit weird (semantically) +# Position gradient with Riemannian correction terms function ∂H∂θ( h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap}}, θ::AbstractVecOrMat{T}, @@ -299,6 +113,7 @@ function ∂H∂θ( ) where {T} return ∂H∂θ_cache(h, θ, r) end + function ∂H∂θ_cache( h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, θ::AbstractVecOrMat{T}, diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index 6ce594768..3d818e9f1 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -1,3 +1,6 @@ +import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step +using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size + """ $(TYPEDEF) diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl new file mode 100644 index 000000000..451a08ec7 --- /dev/null +++ b/src/riemannian/metric.jl @@ -0,0 +1,66 @@ +using AdvancedHMC: AbstractMetric +using LinearAlgebra: eigen, cholesky, Symmetric + +# _randn is defined in utilities.jl which is included before this file + +abstract type AbstractRiemannianMetric <: AbstractMetric end + +abstract type AbstractHessianMap end + +struct IdentityMap <: AbstractHessianMap end + +(::IdentityMap)(x) = x + +struct SoftAbsMap{T} <: AbstractHessianMap + α::T +end + +# TODO Register softabs with ReverseDiff +#! The definition of SoftAbs from Page 3 of Betancourt (2012) +function softabs(X, α=20.0) + F = eigen(X) # ReverseDiff cannot diff through `eigen` + Q = hcat(F.vectors) + λ = F.values + softabsλ = λ .* coth.(α * λ) + return Q * diagm(softabsλ) * Q', Q, λ, softabsλ +end + +(map::SoftAbsMap)(x) = softabs(x, map.α)[1] + +struct DenseRiemannianMetric{ + T, + TM<:AbstractHessianMap, + A<:Union{Tuple{Int},Tuple{Int,Int}}, + AV<:AbstractVecOrMat{T}, + TG, + T∂G∂θ, +} <: AbstractRiemannianMetric + size::A + G::TG # TODO store G⁻¹ here instead + ∂G∂θ::T∂G∂θ + map::TM + _temp::AV +end + +# TODO Make dense mass matrix support matrix-mode parallel +function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) where {T<:AbstractFloat} + _temp = Vector{Float64}(undef, size[1]) + return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) +end + +Base.size(e::DenseRiemannianMetric) = e.size +Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] +Base.show(io::IO, dem::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric(...)") + +function rand_momentum( + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, + metric::DenseRiemannianMetric{T}, + kinetic, + θ::AbstractVecOrMat, +) where {T} + r = _randn(rng, T, size(metric)...) 
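+    # NOTE r must end up distributed as N(0, G(θ)), the RMHMC momentum
+    # conditional. With cholesky(Symmetric(G⁻¹)) = U'U, the solve U \ r below
+    # maps r ~ N(0, I) to a draw with covariance U⁻¹U⁻ᵀ = (U'U)⁻¹ = G(θ);
+    # `ldiv!` performs that solve in place.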
+ G⁻¹ = inv(metric.map(metric.G(θ))) + chol = cholesky(Symmetric(G⁻¹)) + ldiv!(chol.U, r) + return r +end diff --git a/src/trajectory.jl b/src/trajectory.jl index 66246e74a..7020d4f84 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -552,7 +552,7 @@ function isterminated(::ClassicNoUTurn, h::Hamiltonian, t::BinaryTree) # z0 is starting point and z1 is ending point z0, z1 = t.zleft, t.zright Δθ = z1.θ - z0.θ - s = (dot(Δθ, ∂H∂r(h, -z0.r)) >= 0) || (dot(-Δθ, ∂H∂r(h, z1.r)) >= 0) + s = (dot(Δθ, ∂H∂r(h, z0.θ, -z0.r)) >= 0) || (dot(-Δθ, ∂H∂r(h, z1.θ, z1.r)) >= 0) return Termination(s, false) end @@ -565,7 +565,7 @@ Ref: https://arxiv.org/abs/1701.02434 """ function isterminated(::GeneralisedNoUTurn, h::Hamiltonian, t::BinaryTree) rho = t.ts.rho - s = generalised_uturn_criterion(rho, ∂H∂r(h, t.zleft.r), ∂H∂r(h, t.zright.r)) + s = generalised_uturn_criterion(rho, ∂H∂r(h, t.zleft.θ, t.zleft.r), ∂H∂r(h, t.zright.θ, t.zright.r)) return Termination(s, false) end @@ -595,7 +595,7 @@ phase point of `tright`, the right subtree. """ function check_left_subtree(h::Hamiltonian, t::T, tleft::T, tright::T) where {T<:BinaryTree} rho = tleft.ts.rho + tright.zleft.r - s = generalised_uturn_criterion(rho, ∂H∂r(h, t.zleft.r), ∂H∂r(h, tright.zleft.r)) + s = generalised_uturn_criterion(rho, ∂H∂r(h, t.zleft.θ, t.zleft.r), ∂H∂r(h, tright.zleft.θ, tright.zleft.r)) return Termination(s, false) end @@ -608,7 +608,7 @@ function check_right_subtree( h::Hamiltonian, t::T, tleft::T, tright::T ) where {T<:BinaryTree} rho = tleft.zright.r + tright.ts.rho - s = generalised_uturn_criterion(rho, ∂H∂r(h, tleft.zright.r), ∂H∂r(h, t.zright.r)) + s = generalised_uturn_criterion(rho, ∂H∂r(h, tleft.zright.θ, tleft.zright.r), ∂H∂r(h, t.zright.θ, t.zright.r)) return Termination(s, false) end From 312b2d930536a70c7a41d210237209cb0cb3bb1c Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 6 Nov 2025 12:19:13 +0000 Subject: [PATCH 02/25] Update riemannian metric to remove unnecessary matrix inversion --- src/riemannian/hamiltonian.jl | 2 +- src/riemannian/metric.jl | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index 56fc01cc9..5d83566c4 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -1,6 +1,6 @@ import AdvancedHMC: refresh, phasepoint, neg_energy, ∂H∂θ, ∂H∂r using AdvancedHMC: FullMomentumRefreshment, PartialMomentumRefreshment, DualValue, PhasePoint -using LinearAlgebra: logabsdet, tr, diagm +using LinearAlgebra: logabsdet, tr, diagm, logdet # Specialized phasepoint for Riemannian metrics that need θ for momentum gradient function phasepoint( diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl index 451a08ec7..70f74228b 100644 --- a/src/riemannian/metric.jl +++ b/src/riemannian/metric.jl @@ -59,8 +59,7 @@ function rand_momentum( θ::AbstractVecOrMat, ) where {T} r = _randn(rng, T, size(metric)...) 
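     # The replacement below samples r ~ N(0, G) directly: for G = LL' and
     # z ~ N(0, I), the product L * z has covariance LL' = G, so a single
     # Cholesky factorisation of the metric replaces the explicit inverse above.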
- G⁻¹ = inv(metric.map(metric.G(θ))) - chol = cholesky(Symmetric(G⁻¹)) - ldiv!(chol.U, r) + chol = cholesky(Symmetric(metric.map(metric.G(θ)))) + r = chol.L * r return r end From bd583e9b5114e60860608e33ceaa6e399160f95a Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Mon, 10 Nov 2025 15:25:09 +0000 Subject: [PATCH 03/25] Fix bug that prevents eltype working for DenseRiemannianMetric --- src/riemannian/metric.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl index 70f74228b..71ce757f0 100644 --- a/src/riemannian/metric.jl +++ b/src/riemannian/metric.jl @@ -1,5 +1,6 @@ using AdvancedHMC: AbstractMetric using LinearAlgebra: eigen, cholesky, Symmetric +import Base: eltype # _randn is defined in utilities.jl which is included before this file @@ -52,6 +53,10 @@ Base.size(e::DenseRiemannianMetric) = e.size Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] Base.show(io::IO, dem::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric(...)") +function eltype(m::DenseRiemannianMetric) + return eltype(m._temp) +end + function rand_momentum( rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, metric::DenseRiemannianMetric{T}, From 2020f2db7639f29cddf9f3b68dbf522a84d91a89 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Mon, 10 Nov 2025 15:26:24 +0000 Subject: [PATCH 04/25] Divergence protection for generalized leapfrog integrator. --- src/integrator.jl | 4 +-- src/riemannian/integrator.jl | 53 +++++++++++++++++++++++++++++++++++- src/trajectory.jl | 26 +++++++++++++----- 3 files changed, 73 insertions(+), 10 deletions(-) diff --git a/src/integrator.jl b/src/integrator.jl index 004028e1e..24344a502 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -260,8 +260,8 @@ function step( end end return if FullTraj - res + res, false else - z + z, false end end diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index 3d818e9f1..4c6325aec 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -1,4 +1,5 @@ import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step +import LinearAlgebra: norm using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size """ @@ -56,6 +57,7 @@ function step( z end + diverged = false for i in 1:n_steps θ_init, r_init = z.θ, z.r # Tempering @@ -75,12 +77,55 @@ function step( end r_half = r_init - ϵ / 2 * gradient end + + retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true) + (; value, gradient) = retval + + r_diff = r_half - (r_init - ϵ / 2 * gradient) + if (norm(r_diff) > 1e-1) || any(isnan, r_half) || !isfinite(sum(r_half)) + diverged = true + + # Reset to last valid values + r_half = r_init + θ_full = θ_init + + # Recompute valid logprob/gradient for consistency + retval, cache = ∂H∂θ_cache(h, θ_init, r_init; return_cache=true) + (; value, gradient) = retval + + z = phasepoint(h, θ_init, r_init; ℓπ=DualValue(value, gradient)) + + # If full trajectory, truncate so no undef elements remain + if FullTraj + res = res[1:(i-1)] + push!(res, z) + else + res = z + end + + break + end # eq (17) of Girolami & Calderhead (2011) θ_full = θ_init term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop for j in 1:(lf.n) θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)) end + θ_diff = norm(θ_full - (θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)))) + if !isfinite(sum(θ_full)) || θ_diff > 1e-1 || any(isnan, θ_full) + diverged = true + θ_full = θ_init + r_full = r_init + + z = phasepoint(h, θ_init, 
r_init; ℓπ=DualValue(value, gradient)) + if FullTraj + res = res[1:(i-1)] + push!(res, z) + else + res = z + end + break + end # eq (18) of Girolami & Calderhead (2011) (; value, gradient) = ∂H∂θ(h, θ_full, r_half) r_full = r_half - ϵ / 2 * gradient @@ -94,6 +139,12 @@ function step( else res = z end + + if any(!isfinite, z.θ) || any(!isfinite, z.r) + diverged = true + z = phasepoint(h, θ_init, r_init; ℓπ=z.ℓπ) + end + if !isfinite(z) # Remove undef if FullTraj @@ -102,5 +153,5 @@ function step( break end end - return res + return res, diverged end diff --git a/src/trajectory.jl b/src/trajectory.jl index 7020d4f84..ca7a1a19e 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -334,9 +334,12 @@ end ### Use end-point from the trajectory as a proposal and apply MH correction function sample_phasepoint(rng, τ::Trajectory{EndPointTS}, h, z) - z′ = step(τ.integrator, h, z, nsteps(τ)) + z′, diverged = step(τ.integrator, h, z, nsteps(τ)) is_accept, α = mh_accept_ratio(rng, energy(z), energy(z′)) - return z′, is_accept, α + if diverged + α = zero(α) + end + return z′, (is_accept && !diverged), α end ### Multinomial sampling from trajectory @@ -371,9 +374,9 @@ function sample_phasepoint(rng, τ::Trajectory{MultinomialTS}, h, z) # TODO: Deal with vectorized-mode generically. # Currently the direction of multiple chains are always coupled n_steps_fwd = rand_coupled(rng, 0:n_steps) - zs_fwd = step(τ.integrator, h, z, n_steps_fwd; fwd=true, full_trajectory=Val(true)) + zs_fwd, diverged_fwd = step(τ.integrator, h, z, n_steps_fwd; fwd=true, full_trajectory=Val(true)) n_steps_bwd = n_steps - n_steps_fwd - zs_bwd = step(τ.integrator, h, z, n_steps_bwd; fwd=false, full_trajectory=Val(true)) + zs_bwd, diverged_bwd = step(τ.integrator, h, z, n_steps_bwd; fwd=false, full_trajectory=Val(true)) zs = vcat(reverse(zs_bwd)..., z, zs_fwd...) ℓweights = -energy.(zs) if eltype(ℓweights) <: AbstractVector @@ -386,7 +389,10 @@ function sample_phasepoint(rng, τ::Trajectory{MultinomialTS}, h, z) ΔH = Hs .- energy(z) α = exp.(min.(0, -ΔH)) # this is a matrix for vectorized mode and a vector otherwise α = typeof(α) <: AbstractVector ? mean(α) : vec(mean(α; dims=2)) - return z′, true, α + if (diverged_bwd || diverged_fwd) + α = zero(α) + end + return z′, !(diverged_bwd || diverged_fwd), α end ### @@ -637,10 +643,13 @@ function build_tree( } if j == 0 # Base case - take one leapfrog step in the direction v. - z′ = step(nt.integrator, h, z, v) + z′, diverged = step(nt.integrator, h, z, v) H′ = energy(z′) ΔH = H′ - H0 α′ = exp(min(0, -ΔH)) + if diverged + α′ = zero(α′) + end sampler′ = TS(sampler, H0, z′) return BinaryTree(z′, z′, TurnStatistic(nt.termination_criterion, z′), α′, 1, ΔH), sampler′, @@ -751,8 +760,11 @@ A single Hamiltonian integration step. NOTE: this function is intended to be used in `find_good_stepsize` only. """ function A(h, z, ϵ) - z′ = step(Leapfrog(ϵ), h, z) + z′, diverged = step(Leapfrog(ϵ), h, z) H′ = energy(z′) + if diverged + H′ *= 100000.0 # penalize diverged proposals + end return z′, H′ end From 4fe83b9d20c5d51827dbbd8241ddb9420f97e188 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 13 Nov 2025 15:20:24 +0000 Subject: [PATCH 05/25] Revert "Divergence protection for generalized leapfrog integrator." This reverts commit 2020f2db7639f29cddf9f3b68dbf522a84d91a89. 
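
The reverted commit threaded a `diverged` flag through every `step` call
site. A lighter convention, sketched below for illustration only (it is not
part of this series), is to classify divergence from the energy error of the
proposal, as the NUTS `build_tree` path already does via `ΔH`:

    # Hedged sketch: flag a divergence from the energy error alone.
    # `z0`/`z1` and the 1000.0 cutoff are illustrative assumptions.
    ΔH = energy(z1) - energy(z0)
    diverged = !isfinite(ΔH) || ΔH > 1000.0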
--- src/integrator.jl | 4 +-- src/riemannian/integrator.jl | 53 +----------------------------------- src/trajectory.jl | 26 +++++------------- 3 files changed, 10 insertions(+), 73 deletions(-) diff --git a/src/integrator.jl b/src/integrator.jl index 24344a502..004028e1e 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -260,8 +260,8 @@ function step( end end return if FullTraj - res, false + res else - z, false + z end end diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index 4c6325aec..3d818e9f1 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -1,5 +1,4 @@ import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step -import LinearAlgebra: norm using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size """ @@ -57,7 +56,6 @@ function step( z end - diverged = false for i in 1:n_steps θ_init, r_init = z.θ, z.r # Tempering @@ -77,55 +75,12 @@ function step( end r_half = r_init - ϵ / 2 * gradient end - - retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true) - (; value, gradient) = retval - - r_diff = r_half - (r_init - ϵ / 2 * gradient) - if (norm(r_diff) > 1e-1) || any(isnan, r_half) || !isfinite(sum(r_half)) - diverged = true - - # Reset to last valid values - r_half = r_init - θ_full = θ_init - - # Recompute valid logprob/gradient for consistency - retval, cache = ∂H∂θ_cache(h, θ_init, r_init; return_cache=true) - (; value, gradient) = retval - - z = phasepoint(h, θ_init, r_init; ℓπ=DualValue(value, gradient)) - - # If full trajectory, truncate so no undef elements remain - if FullTraj - res = res[1:(i-1)] - push!(res, z) - else - res = z - end - - break - end # eq (17) of Girolami & Calderhead (2011) θ_full = θ_init term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop for j in 1:(lf.n) θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)) end - θ_diff = norm(θ_full - (θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)))) - if !isfinite(sum(θ_full)) || θ_diff > 1e-1 || any(isnan, θ_full) - diverged = true - θ_full = θ_init - r_full = r_init - - z = phasepoint(h, θ_init, r_init; ℓπ=DualValue(value, gradient)) - if FullTraj - res = res[1:(i-1)] - push!(res, z) - else - res = z - end - break - end # eq (18) of Girolami & Calderhead (2011) (; value, gradient) = ∂H∂θ(h, θ_full, r_half) r_full = r_half - ϵ / 2 * gradient @@ -139,12 +94,6 @@ function step( else res = z end - - if any(!isfinite, z.θ) || any(!isfinite, z.r) - diverged = true - z = phasepoint(h, θ_init, r_init; ℓπ=z.ℓπ) - end - if !isfinite(z) # Remove undef if FullTraj @@ -153,5 +102,5 @@ function step( break end end - return res, diverged + return res end diff --git a/src/trajectory.jl b/src/trajectory.jl index ca7a1a19e..7020d4f84 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -334,12 +334,9 @@ end ### Use end-point from the trajectory as a proposal and apply MH correction function sample_phasepoint(rng, τ::Trajectory{EndPointTS}, h, z) - z′, diverged = step(τ.integrator, h, z, nsteps(τ)) + z′ = step(τ.integrator, h, z, nsteps(τ)) is_accept, α = mh_accept_ratio(rng, energy(z), energy(z′)) - if diverged - α = zero(α) - end - return z′, (is_accept && !diverged), α + return z′, is_accept, α end ### Multinomial sampling from trajectory @@ -374,9 +371,9 @@ function sample_phasepoint(rng, τ::Trajectory{MultinomialTS}, h, z) # TODO: Deal with vectorized-mode generically. 
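    # NOTE The uniform forward/backward split below keeps the composite
    # trajectory reversible; the proposal is then drawn from every point on
    # it with weight ∝ exp(-energy), i.e. multinomial trajectory sampling.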
# Currently the direction of multiple chains are always coupled n_steps_fwd = rand_coupled(rng, 0:n_steps) - zs_fwd, diverged_fwd = step(τ.integrator, h, z, n_steps_fwd; fwd=true, full_trajectory=Val(true)) + zs_fwd = step(τ.integrator, h, z, n_steps_fwd; fwd=true, full_trajectory=Val(true)) n_steps_bwd = n_steps - n_steps_fwd - zs_bwd, diverged_bwd = step(τ.integrator, h, z, n_steps_bwd; fwd=false, full_trajectory=Val(true)) + zs_bwd = step(τ.integrator, h, z, n_steps_bwd; fwd=false, full_trajectory=Val(true)) zs = vcat(reverse(zs_bwd)..., z, zs_fwd...) ℓweights = -energy.(zs) if eltype(ℓweights) <: AbstractVector @@ -389,10 +386,7 @@ function sample_phasepoint(rng, τ::Trajectory{MultinomialTS}, h, z) ΔH = Hs .- energy(z) α = exp.(min.(0, -ΔH)) # this is a matrix for vectorized mode and a vector otherwise α = typeof(α) <: AbstractVector ? mean(α) : vec(mean(α; dims=2)) - if (diverged_bwd || diverged_fwd) - α = zero(α) - end - return z′, !(diverged_bwd || diverged_fwd), α + return z′, true, α end ### @@ -643,13 +637,10 @@ function build_tree( } if j == 0 # Base case - take one leapfrog step in the direction v. - z′, diverged = step(nt.integrator, h, z, v) + z′ = step(nt.integrator, h, z, v) H′ = energy(z′) ΔH = H′ - H0 α′ = exp(min(0, -ΔH)) - if diverged - α′ = zero(α′) - end sampler′ = TS(sampler, H0, z′) return BinaryTree(z′, z′, TurnStatistic(nt.termination_criterion, z′), α′, 1, ΔH), sampler′, @@ -760,11 +751,8 @@ A single Hamiltonian integration step. NOTE: this function is intended to be used in `find_good_stepsize` only. """ function A(h, z, ϵ) - z′, diverged = step(Leapfrog(ϵ), h, z) + z′ = step(Leapfrog(ϵ), h, z) H′ = energy(z′) - if diverged - H′ *= 100000.0 # penalize diverged proposals - end return z′, H′ end From dfcbb5353bdf20e381ee4aee9d6d470f1703ab38 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 13 Nov 2025 15:22:59 +0000 Subject: [PATCH 06/25] Change eltype for DenseRiemannianMetric --- src/riemannian/metric.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl index 71ce757f0..86662b79b 100644 --- a/src/riemannian/metric.jl +++ b/src/riemannian/metric.jl @@ -53,9 +53,11 @@ Base.size(e::DenseRiemannianMetric) = e.size Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] Base.show(io::IO, dem::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric(...)") -function eltype(m::DenseRiemannianMetric) - return eltype(m._temp) -end +#function eltype(m::DenseRiemannianMetric) +# return eltype(m._temp) +#end + +eltype(::DenseRiemannianMetric{T}) where {T} = T function rand_momentum( rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, From ce20e4b8509e1ada06674f5ca5557c1784b77c34 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 20 Nov 2025 17:45:25 +0000 Subject: [PATCH 07/25] Optimise riemannian hmc utility functions. 
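
The hot path here is repeated Hessian evaluation for the metric, so the
ForwardDiff work buffers are now allocated once and reused. A minimal
standalone sketch of the pattern (toy potential; all names below are
illustrative, not part of the package):

    using ForwardDiff

    V(x) = 0.5 * sum(abs2, x)              # toy potential energy
    x0 = randn(3)
    cfg = ForwardDiff.HessianConfig(V, x0) # work buffers built once
    H = Matrix{Float64}(undef, 3, 3)       # reused output buffer
    ForwardDiff.hessian!(H, V, x0, cfg)    # in-place, allocation-light

`gen_hess_fwd_precompute_cfg` applies the same idea to the target's
negative log-density.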
--- research/src/riemannian_hmc_utility.jl | 55 ++++++++++++++++++-------- src/AdvancedHMC.jl | 3 ++ 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/research/src/riemannian_hmc_utility.jl b/research/src/riemannian_hmc_utility.jl index 8ceab303c..f19c225bb 100644 --- a/research/src/riemannian_hmc_utility.jl +++ b/research/src/riemannian_hmc_utility.jl @@ -2,47 +2,70 @@ using Random, LinearAlgebra, ReverseDiff, ForwardDiff, MCMCLogDensityProblems # Fisher information metric function gen_∂G∂θ_rev(Vfunc, x; f=identity) - _Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, ReverseDiff.track.(x)) - Hfunc = x -> _Hfunc(x)[3] + Hfunc = gen_hess_fwd(Vfunc, ReverseDiff.track.(x)) + # QUES What's the best output format of this function? return x -> ReverseDiff.jacobian(x -> f(Hfunc(x)), x) # default output shape [∂H∂x₁; ∂H∂x₂; ...] end # TODO Refactor this using https://juliadiff.org/ForwardDiff.jl/stable/user/api/#Preallocating/Configuring-Work-Buffers +function gen_hess_fwd_precompute_cfg(func, x::AbstractVector) + cfg = ForwardDiff.HessianConfig(func, x) + H = Matrix{eltype(x)}(undef, length(x), length(x)) + + function hess(x::AbstractVector) + ForwardDiff.hessian!(H, func, x, cfg) + return H + end + return hess +end + function gen_hess_fwd(func, x::AbstractVector) + cfg = nothing + H = nothing + function hess(x::AbstractVector) - return nothing, nothing, ForwardDiff.hessian(func, x) + if cfg === nothing + cfg = ForwardDiff.HessianConfig(func, x) + H = Matrix{eltype(x)}(undef, length(x), length(x)) + end + ForwardDiff.hessian!(H, func, x, cfg) + return H end return hess end function gen_∂G∂θ_fwd(Vfunc, x; f=identity) - _Hfunc = gen_hess_fwd(Vfunc, x) - Hfunc = x -> _Hfunc(x)[3] - # QUES What's the best output format of this function? + Hfunc = gen_hess_fwd(Vfunc, x) + cfg = ForwardDiff.JacobianConfig(Hfunc, x) d = length(x) out = zeros(eltype(x), d^2, d) - return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg) - return out # default output shape [∂H∂x₁; ∂H∂x₂; ...] + + function ∂G∂θ_fwd(y) + ForwardDiff.jacobian!(out, Hfunc, y, cfg) + return out + end + return ∂G∂θ_fwd end -# 1.764 ms -# fwd -> 5.338 μs -# cfg -> 3.651 μs function reshape_∂G∂θ(H) d = size(H, 2) - return cat((H[((i - 1) * d + 1):(i * d), :] for i in 1:d)...; dims=3) + return reshape(H, d, d, :) end function prepare_sample_target(hps, θ₀, ℓπ) Vfunc = x -> -ℓπ(x) # potential energy is the negative log-probability - _Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, θ₀) # x -> (value, gradient, hessian) - Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug + Hfunc = gen_hess_fwd_precompute_cfg(Vfunc, θ₀) # x -> (value, gradient, hessian) - fstabilize = H -> H + hps.λ * I + fstabilize = H -> begin + @inbounds for i in 1:size(H,1) + H[i,i] += hps.λ + end + H + end Gfunc = x -> begin - H = fstabilize(Hfunc(x)[3]) + H = fstabilize(Hfunc(x)) all(isfinite, H) ? H : diagm(ones(length(x))) end _∂G∂θfunc = gen_∂G∂θ_fwd(Vfunc, θ₀; f=fstabilize) # size==(4, 2) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index dce7b497a..23d38bac2 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -67,6 +67,9 @@ export Trajectory, MultinomialTS, find_good_stepsize +include("../research/src/riemannian_hmc_utility.jl") +export prepare_sample_target + # Useful defaults @deprecate find_good_eps find_good_stepsize From a5373d9bafab790dc241158430523d22053856b1 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Fri, 28 Nov 2025 13:32:51 +0000 Subject: [PATCH 08/25] Fix type stability for jacobian of hessian. 
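
Constructing fresh ForwardDiff configs inside the returned closure gave each
call a new tag type, so the nested jacobian-of-hessian was type unstable.
Sharing one `Chunk`/`Tag` pair and skipping the runtime tag check with
`Val{false}()` keeps every dual type concrete. A sketch of the shared setup
(toy function, illustrative names):

    using ForwardDiff

    f(x) = sum(abs2, x)
    x0 = randn(2)
    chunk = ForwardDiff.Chunk(x0)
    tag = ForwardDiff.Tag(f, eltype(x0))
    jcfg = ForwardDiff.JacobianConfig(f, x0, chunk, tag)
    # jcfg.duals feeds dual-typed inputs to the inner Hessian config:
    hcfg = ForwardDiff.HessianConfig(f, jcfg.duals, chunk, tag)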
--- research/src/riemannian_hmc_utility.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/research/src/riemannian_hmc_utility.jl b/research/src/riemannian_hmc_utility.jl index f19c225bb..92937864e 100644 --- a/research/src/riemannian_hmc_utility.jl +++ b/research/src/riemannian_hmc_utility.jl @@ -36,16 +36,20 @@ function gen_hess_fwd(func, x::AbstractVector) end function gen_∂G∂θ_fwd(Vfunc, x; f=identity) - Hfunc = gen_hess_fwd(Vfunc, x) + chunk = ForwardDiff.Chunk(x) + tag = ForwardDiff.Tag(Vfunc, eltype(x)) + jac_cfg = ForwardDiff.JacobianConfig(Vfunc, x, chunk, tag) + hess_cfg = ForwardDiff.HessianConfig(Vfunc, jac_cfg.duals, chunk, tag) - cfg = ForwardDiff.JacobianConfig(Hfunc, x) d = length(x) out = zeros(eltype(x), d^2, d) function ∂G∂θ_fwd(y) - ForwardDiff.jacobian!(out, Hfunc, y, cfg) + hess = z -> ForwardDiff.hessian(Vfunc, z, hess_cfg, Val{false}()) + ForwardDiff.jacobian!(out, hess, y, jac_cfg, Val{false}()) return out end + return ∂G∂θ_fwd end From d96c29c120bed65850771e4d585e8dfd3e95bf27 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Fri, 28 Nov 2025 15:04:09 +0000 Subject: [PATCH 09/25] Fix DenseRiemannianMetric type instability. --- src/riemannian/metric.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl index 86662b79b..5df4dd11d 100644 --- a/src/riemannian/metric.jl +++ b/src/riemannian/metric.jl @@ -18,8 +18,9 @@ end # TODO Register softabs with ReverseDiff #! The definition of SoftAbs from Page 3 of Betancourt (2012) -function softabs(X, α=20.0) - F = eigen(X) # ReverseDiff cannot diff through `eigen` +function softabs(X::AbstractMatrix{T}, α=20.0) where {T<:Real} + # Enforce symmetry for type stability + F = eigen(Symmetric(X)) # ReverseDiff cannot diff through `eigen` Q = hcat(F.vectors) λ = F.values softabsλ = λ .* coth.(α * λ) From 7d4a86ff88e2bc3210c622be447d0e16a1e4c8d7 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Fri, 28 Nov 2025 15:21:26 +0000 Subject: [PATCH 10/25] Small optimisations for hamiltonian.jl --- src/riemannian/hamiltonian.jl | 67 +++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 15 deletions(-) diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index 5d83566c4..fbbc791f8 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -129,7 +129,7 @@ function ∂H∂θ_cache( G, Q, λ, softabsλ = softabs(H, h.metric.map.α) - R = diagm(1 ./ softabsλ) + R = Diagonal(1 ./ softabsλ) # softabsΛ = diagm(softabsλ) # M = inv(softabsΛ) * Q' * r @@ -137,30 +137,67 @@ function ∂H∂θ_cache( J = make_J(λ, h.metric.map.α) + tmp1 = similar(H) + tmp2 = similar(H) + tmp3 = similar(H) + tmp4 = similar(softabsλ) + #! 
Based on the two equations from the right column of Page 3 of Betancourt (2012) - term_1_cached = Q * (R .* J) * Q' + tmp1 = R .* J + # tmp2 = Q * tmp1 + mul!(tmp2, Q, tmp1) + + # tmp1 = tmp2 * Q' + mul!(tmp1, tmp2, Q') + + term_1_cached = tmp1 + + # Cache first part of the equation + term_1_prod = similar(∂ℓπ∂θ) + @inbounds for i in 1:length(∂ℓπ∂θ) + ∂H∂θᵢ = ∂H∂θ[:, :, i] + term_1_prod[i] = ∂ℓπ∂θ[i] - 1/2 * tr(term_1_cached * ∂H∂θᵢ) + end + else - ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached = cache + ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_prod, tmp1, tmp2, tmp3, tmp4 = cache end d = length(∂ℓπ∂θ) - D = diagm((Q' * r) ./ softabsλ) - term_2_cached = Q * D * J * D * Q' - g = - -mapreduce(vcat, 1:d) do i - ∂H∂θᵢ = ∂H∂θ[:, :, i] - # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1) - # NOTE Some further optimization can be done here: cache the 1st product all together - ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly - end + mul!(tmp4, Q', r) + D = Diagonal(tmp4 ./ softabsλ) + + # tmp1 = D * J + mul!(tmp1, D, J) + # tmp2 = tmp1 * D + mul!(tmp2, tmp1, D) + # tmp1 = Q * tmp2 + mul!(tmp1, Q, tmp2) + # tmp2 = tmp1 * Q' + mul!(tmp2, tmp1, Q') + term_2_cached = tmp2 + + # g = + # -mapreduce(vcat, 1:d) do i + # ∂H∂θᵢ = ∂H∂θ[:, :, i] + # # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1) + # # NOTE Some further optimization can be done here: cache the 1st product all together + # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly + # end + g = similar(∂ℓπ∂θ) + @inbounds for i in 1:d + ∂H∂θᵢ = ∂H∂θ[:, :, i] + g[i] = term_1_prod[i] + 1/2 * tr(term_2_cached * ∂H∂θᵢ) + end + g .*= -1 dv = DualValue(ℓπ, g) - return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv + return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_prod, tmp1, tmp2, tmp3, tmp4)) : dv end #! 
Eq (14) of Girolami & Calderhead (2011) function ∂H∂r( - h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat, r::AbstractVecOrMat -) + h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat{T}, r::AbstractVecOrMat{T} +) where {T} H = h.metric.G(θ) # if !all(isfinite, H) # println("θ: ", θ) From 8f1ebc581c1614c5efcfceca84bb5f5c3e7d5243 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Sat, 29 Nov 2025 20:01:10 +0000 Subject: [PATCH 11/25] Use hessian symmetry for jacobian of hessian --- research/src/riemannian_hmc_utility.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/research/src/riemannian_hmc_utility.jl b/research/src/riemannian_hmc_utility.jl index 92937864e..a0738d188 100644 --- a/research/src/riemannian_hmc_utility.jl +++ b/research/src/riemannian_hmc_utility.jl @@ -45,7 +45,7 @@ function gen_∂G∂θ_fwd(Vfunc, x; f=identity) out = zeros(eltype(x), d^2, d) function ∂G∂θ_fwd(y) - hess = z -> ForwardDiff.hessian(Vfunc, z, hess_cfg, Val{false}()) + hess = z -> Symmetric(ForwardDiff.hessian(Vfunc, z, hess_cfg, Val{false}())) ForwardDiff.jacobian!(out, hess, y, jac_cfg, Val{false}()) return out end From 981932ee6544d4612764c27504805dec451c9128 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Sat, 29 Nov 2025 20:01:23 +0000 Subject: [PATCH 12/25] Implement Implicit Midpoint integrator --- src/AdvancedHMC.jl | 2 +- src/riemannian/integrator.jl | 104 +++++++++++++++++++++++++++++++++-- 2 files changed, 101 insertions(+), 5 deletions(-) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 23d38bac2..90edea783 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -51,7 +51,7 @@ export Leapfrog, JitteredLeapfrog, TemperedLeapfrog include("riemannian/metric.jl") export AbstractRiemannianMetric, DenseRiemannianMetric, IdentityMap, SoftAbsMap include("riemannian/integrator.jl") -export GeneralizedLeapfrog +export GeneralizedLeapfrog, ImplicitMidpoint include("riemannian/hamiltonian.jl") include("trajectory.jl") diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index 3d818e9f1..e7d5b4a31 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -24,11 +24,47 @@ function Base.show(io::IO, l::GeneralizedLeapfrog) return print(io, "GeneralizedLeapfrog(ϵ=", round.(l.ϵ; sigdigits=3), ", n=", l.n, ")") end -# fallback to ignore return_cache & cache kwargs for other ∂H∂θ -function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing) - dv = ∂H∂θ(h, θ, r) - return return_cache ? (dv, nothing) : dv +abstract type AbstractImplicitMidpoint{T} <: AbstractIntegrator end + +step_size(lf::AbstractImplicitMidpoint) = lf.ϵ +jitter(::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, lf::AbstractImplicitMidpoint) = lf +function temper( + lf::AbstractImplicitMidpoint, r, ::NamedTuple{(:i, :is_half),<:Tuple{Integer,Bool}}, ::Int +) + return r end +stat(lf::AbstractImplicitMidpoint) = (step_size=step_size(lf), nom_step_size=nom_step_size(lf)) +update_nom_step_size(lf::AbstractImplicitMidpoint, ϵ) = @set lf.ϵ = ϵ + +""" +$(TYPEDEF) + +Implicit midpoint integrator with fixed step size `ϵ`. + +# Fields + +$(TYPEDFIELDS) + + +## References + +1. James A. Brofos, Roy R. Lederman. "Evaluating the Implicit Midpoint +Integrator for Riemannian Manifold Hamiltonian Monte Carlo" +""" +struct ImplicitMidpoint{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T} + "Step size." 
+ ϵ::T + n::Int +end +function Base.show(io::IO, l::ImplicitMidpoint) + return print(io, "ImplicitMidpoint(ϵ=", round.(l.ϵ; sigdigits=3), ", n=", l.n, ")") +end + +# fallback to ignore return_cache & cache kwargs for other ∂H∂θ +# function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing) +# dv = ∂H∂θ(h, θ, r) +# return return_cache ? (dv, nothing) : dv +# end # TODO(Kai) make sure vectorization works # TODO(Kai) check if tempering is valid @@ -104,3 +140,63 @@ function step( end return res end + +function step( + lf::ImplicitMidpoint{T}, + h::Hamiltonian, + z::P, + n_steps::Int=1; + fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0 + full_trajectory::Val{FullTraj}=Val(false), +) where {T<:AbstractScalarOrVec{<:AbstractFloat},TP,P<:PhasePoint{TP},FullTraj} + n_steps = abs(n_steps) # to support `n_steps < 0` cases + + ϵ = fwd ? step_size(lf) : -step_size(lf) + ϵ = ϵ' + + if !(T <: AbstractFloat) || !(TP <: AbstractVector) + @warn "Vectorization is not tested for ImplicitMidpoint." + end + + res = if FullTraj + Vector{P}(undef, n_steps) + else + z + end + + for i in 1:n_steps + θ_init, r_init = z.θ, z.r + + + θ_full = θ_init + r_full = r_init + for j in 1:(lf.n) + θ_bar = (θ_full + θ_init) / 2 + r_bar = (r_full + r_init) / 2 + + dHdr = ∂H∂r(h, θ_bar, r_bar) + (; value, gradient) = ∂H∂θ(h, θ_bar, r_bar) + + θ_full = θ_init + ϵ * dHdr + r_full = r_init - ϵ * gradient + end + + (; value, gradient) = ∂H∂θ(h, θ_full, r_full) + z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) + + if FullTraj + res[i] = z + else + res = z + end + if !isfinite(z) + # Remove undef + if FullTraj + res = res[isassigned.(Ref(res), 1:n_steps)] + end + break + end + end + + return res +end \ No newline at end of file From d96d9395065c5af60e9ea780cd8c5c388e8ca5fd Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 4 Dec 2025 11:12:11 +0000 Subject: [PATCH 13/25] Remove riemannian utils from build --- src/AdvancedHMC.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 26e50b3ca..30350a483 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -84,9 +84,6 @@ export Trajectory, MultinomialTS, find_good_stepsize -include("../research/src/riemannian_hmc_utility.jl") -export prepare_sample_target - # Useful defaults @deprecate find_good_eps find_good_stepsize From ca4d2bd25c75f95a1ae6fdcac7db4c49f17935fd Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 4 Dec 2025 11:31:25 +0000 Subject: [PATCH 14/25] Remove unbound type parameter from DenseRiemannianMetric --- src/riemannian/metric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl index 5df4dd11d..d81383806 100644 --- a/src/riemannian/metric.jl +++ b/src/riemannian/metric.jl @@ -45,7 +45,7 @@ struct DenseRiemannianMetric{ end # TODO Make dense mass matrix support matrix-mode parallel -function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) where {T<:AbstractFloat} +function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) _temp = Vector{Float64}(undef, size[1]) return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) end From 346a1299805f7e5d78b1d6baa7448b26c4d7b798 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 4 Dec 2025 11:46:47 +0000 Subject: [PATCH 15/25] Fix method overwriting --- src/riemannian/hamiltonian.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index 5860a9693..62526a57e 100644 --- 
a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -7,7 +7,7 @@ using LinearAlgebra: logabsdet, tr, diagm, logdet function phasepoint( rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, θ::AbstractVecOrMat{T}, - h::Hamiltonian, + h::Hamiltonian{<:DenseRiemannianMetric}, ) where {T<:Real} return phasepoint(h, θ, rand_momentum(rng, h.metric, h.kinetic, θ)) end @@ -16,7 +16,7 @@ end function refresh( rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, ::FullMomentumRefreshment, - h::Hamiltonian, + h::Hamiltonian{<:DenseRiemannianMetric}, z::PhasePoint, ) return phasepoint(h, z.θ, rand_momentum(rng, h.metric, h.kinetic, z.θ)) @@ -26,7 +26,7 @@ end function refresh( rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, ref::PartialMomentumRefreshment, - h::Hamiltonian, + h::Hamiltonian{<:DenseRiemannianMetric}, z::PhasePoint, ) return phasepoint( From ae057b56d9629dbc7cc0634ac20948569a987f76 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 4 Dec 2025 11:51:42 +0000 Subject: [PATCH 16/25] Remove duplicate include --- src/AdvancedHMC.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 30350a483..97349a501 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -66,9 +66,6 @@ include("riemannian/integrator.jl") export GeneralizedLeapfrog, ImplicitMidpoint include("riemannian/hamiltonian.jl") -include("riemannian/metric.jl") -export IdentityMap, SoftAbsMap, DenseRiemannianMetric - include("riemannian/hamiltonian.jl") include("trajectory.jl") From 703ed09f78e1815f4a807849da76b814b6d1f783 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 4 Dec 2025 11:51:42 +0000 Subject: [PATCH 17/25] Remove duplicate include --- src/AdvancedHMC.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 30350a483..e8b10bf3b 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -66,11 +66,6 @@ include("riemannian/integrator.jl") export GeneralizedLeapfrog, ImplicitMidpoint include("riemannian/hamiltonian.jl") -include("riemannian/metric.jl") -export IdentityMap, SoftAbsMap, DenseRiemannianMetric - -include("riemannian/hamiltonian.jl") - include("trajectory.jl") export Trajectory, HMCKernel, From 187aa58b67f9c5e076cf7800df1ae5486415d5f7 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 4 Dec 2025 12:32:28 +0000 Subject: [PATCH 18/25] Fixes for tests --- src/integrator.jl | 8 ++++---- src/riemannian/hamiltonian.jl | 4 ++-- src/riemannian/integrator.jl | 16 ++++++++-------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/integrator.jl b/src/integrator.jl index 004028e1e..f41c33c56 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -226,7 +226,7 @@ function step( ϵ = fwd ? 
step_size(lf) : -step_size(lf) ϵ = ϵ' - if FullTraj + if FullTraj === true res = Vector{P}(undef, n_steps) end @@ -248,18 +248,18 @@ function step( # Create a new phase point by caching the logdensity and gradient z = phasepoint(h, θ, r; ℓπ=DualValue(value, gradient)) # Update result - if FullTraj + if FullTraj === true res[i] = z end if !isfinite(z) # Remove undef - if FullTraj + if FullTraj === true resize!(res, i) end break end end - return if FullTraj + return if FullTraj === true res else z diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index e82fb31df..aaa9174a3 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -175,7 +175,7 @@ function ∂H∂θ_cache( mul!(tmp1, Q, tmp2) # tmp2 = tmp1 * Q' mul!(tmp2, tmp1, Q') - term_2_cached = tmp2 + # term_2_cached = tmp2 # g = # -mapreduce(vcat, 1:d) do i @@ -187,7 +187,7 @@ function ∂H∂θ_cache( g = similar(∂ℓπ∂θ) @inbounds for i in 1:d ∂H∂θᵢ = ∂H∂θ[:, :, i] - g[i] = term_1_prod[i] + 1/2 * tr(term_2_cached * ∂H∂θᵢ) + g[i] = term_1_prod[i] + 1/2 * tr(tmp2 * ∂H∂θᵢ) end g .*= -1 diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index 4dc3c0d6d..d5cb103e2 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -80,7 +80,7 @@ end # TODO(Kai) abstract out the 3 main steps and merge with `step` in `integrator.jl` function step( lf::GeneralizedLeapfrog{T}, - h::Hamiltonian, + h::Hamiltonian{<:DenseRiemannianMetric}, z::P, n_steps::Int=1; fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0 @@ -95,7 +95,7 @@ function step( @warn "Vectorization is not tested for GeneralizedLeapfrog." end - res = if FullTraj + res = if FullTraj === true Vector{P}(undef, n_steps) else z @@ -134,14 +134,14 @@ function step( # Create a new phase point by caching the logdensity and gradient z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) # Update result - if FullTraj + if FullTraj === true res[i] = z else res = z end if !isfinite(z) # Remove undef - if FullTraj + if FullTraj === true res = res[isassigned.(Ref(res), 1:n_steps)] end break @@ -152,7 +152,7 @@ end function step( lf::ImplicitMidpoint{T}, - h::Hamiltonian, + h::Hamiltonian{<:DenseRiemannianMetric}, z::P, n_steps::Int=1; fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0 @@ -167,7 +167,7 @@ function step( @warn "Vectorization is not tested for ImplicitMidpoint." end - res = if FullTraj + res = if FullTraj === true Vector{P}(undef, n_steps) else z @@ -192,14 +192,14 @@ function step( (; value, gradient) = ∂H∂θ(h, θ_full, r_full) z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) - if FullTraj + if FullTraj === true res[i] = z else res = z end if !isfinite(z) # Remove undef - if FullTraj + if FullTraj === true res = res[isassigned.(Ref(res), 1:n_steps)] end break From 4dbe9feabf89d5c6eb146ba9b88b63de172efcef Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 4 Dec 2025 12:43:44 +0000 Subject: [PATCH 19/25] Test fixes --- src/integrator.jl | 1 + src/riemannian/integrator.jl | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/integrator.jl b/src/integrator.jl index f41c33c56..4b615486f 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -226,6 +226,7 @@ function step( ϵ = fwd ? 
step_size(lf) : -step_size(lf) ϵ = ϵ' + res = nothing if FullTraj === true res = Vector{P}(undef, n_steps) end diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index d5cb103e2..d6c91a101 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -107,7 +107,7 @@ function step( #r = temper(lf, r, (i=i, is_half=true), n_steps) # eq (16) of Girolami & Calderhead (2011) r_half = r_init - local cache + local cache = nothing for j in 1:(lf.n) # Reuse cache for the first iteration if j == 1 From 3de6002161e32a5c7186a301c7edb63a67116e8f Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 4 Dec 2025 12:52:43 +0000 Subject: [PATCH 20/25] Fixes for tests --- src/integrator.jl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/integrator.jl b/src/integrator.jl index 4b615486f..c012f7fa6 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -226,9 +226,10 @@ function step( ϵ = fwd ? step_size(lf) : -step_size(lf) ϵ = ϵ' - res = nothing - if FullTraj === true - res = Vector{P}(undef, n_steps) + res = if FullTraj === true + Vector{P}(undef, n_steps) + else + z end (; θ, r) = z @@ -251,6 +252,8 @@ function step( # Update result if FullTraj === true res[i] = z + else + res = z end if !isfinite(z) # Remove undef @@ -260,9 +263,5 @@ function step( break end end - return if FullTraj === true - res - else - z - end + return res end From 069c6a8170fca3e1e2bd72d40f2b34d07f6f6703 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 4 Dec 2025 13:18:24 +0000 Subject: [PATCH 21/25] Fixes for tests --- src/integrator.jl | 6 +++--- src/riemannian/integrator.jl | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/integrator.jl b/src/integrator.jl index c012f7fa6..5fbfe7fd6 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -226,7 +226,7 @@ function step( ϵ = fwd ? step_size(lf) : -step_size(lf) ϵ = ϵ' - res = if FullTraj === true + res = if full_trajectory Vector{P}(undef, n_steps) else z @@ -250,14 +250,14 @@ function step( # Create a new phase point by caching the logdensity and gradient z = phasepoint(h, θ, r; ℓπ=DualValue(value, gradient)) # Update result - if FullTraj === true + if full_trajectory res[i] = z else res = z end if !isfinite(z) # Remove undef - if FullTraj === true + if full_trajectory resize!(res, i) end break diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index d6c91a101..a3a5d7b72 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -95,7 +95,7 @@ function step( @warn "Vectorization is not tested for GeneralizedLeapfrog." end - res = if FullTraj === true + res = if full_trajectory Vector{P}(undef, n_steps) else z @@ -134,14 +134,14 @@ function step( # Create a new phase point by caching the logdensity and gradient z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) # Update result - if FullTraj === true + if full_trajectory res[i] = z else res = z end if !isfinite(z) # Remove undef - if FullTraj === true + if full_trajectory res = res[isassigned.(Ref(res), 1:n_steps)] end break @@ -167,7 +167,7 @@ function step( @warn "Vectorization is not tested for ImplicitMidpoint." 
end - res = if FullTraj === true + res = if full_trajectory Vector{P}(undef, n_steps) else z @@ -192,14 +192,14 @@ function step( (; value, gradient) = ∂H∂θ(h, θ_full, r_full) z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) - if FullTraj === true + if full_trajectory res[i] = z else res = z end if !isfinite(z) # Remove undef - if FullTraj === true + if full_trajectory res = res[isassigned.(Ref(res), 1:n_steps)] end break From b61ea01f20950a499179c29f8884256130e6ac7b Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 4 Dec 2025 13:25:20 +0000 Subject: [PATCH 22/25] Fixes for tests --- src/integrator.jl | 15 ++++++--------- src/riemannian/integrator.jl | 12 ++++++------ 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/integrator.jl b/src/integrator.jl index 5fbfe7fd6..bd79eec47 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -226,10 +226,9 @@ function step( ϵ = fwd ? step_size(lf) : -step_size(lf) ϵ = ϵ' - res = if full_trajectory - Vector{P}(undef, n_steps) - else - z + res = nothing + if FullTraj + res = Vector{P}(undef, n_steps) end (; θ, r) = z @@ -250,18 +249,16 @@ function step( # Create a new phase point by caching the logdensity and gradient z = phasepoint(h, θ, r; ℓπ=DualValue(value, gradient)) # Update result - if full_trajectory + if FullTraj res[i] = z - else - res = z end if !isfinite(z) # Remove undef - if full_trajectory + if FullTraj resize!(res, i) end break end end - return res + return FullTraj === true ? res : z end diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index a3a5d7b72..3a096fa9d 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -95,7 +95,7 @@ function step( @warn "Vectorization is not tested for GeneralizedLeapfrog." end - res = if full_trajectory + res = if FullTraj Vector{P}(undef, n_steps) else z @@ -134,14 +134,14 @@ function step( # Create a new phase point by caching the logdensity and gradient z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) # Update result - if full_trajectory + if FullTraj res[i] = z else res = z end if !isfinite(z) # Remove undef - if full_trajectory + if FullTraj res = res[isassigned.(Ref(res), 1:n_steps)] end break @@ -167,7 +167,7 @@ function step( @warn "Vectorization is not tested for ImplicitMidpoint." 
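     # NOTE The fixed-point loop in this integrator solves the implicit
     # midpoint system
     #   θ′ = θ + ϵ ⋅ ∂H∂r((θ + θ′)/2, (r + r′)/2)
     #   r′ = r − ϵ ⋅ ∂H∂θ((θ + θ′)/2, (r + r′)/2)
     # with `lf.n` sweeps and no residual check, so `lf.n` must be chosen
     # large enough for the target geometry.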
end - res = if full_trajectory + res = if FullTraj Vector{P}(undef, n_steps) else z @@ -192,14 +192,14 @@ function step( (; value, gradient) = ∂H∂θ(h, θ_full, r_full) z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) - if full_trajectory + if FullTraj res[i] = z else res = z end if !isfinite(z) # Remove undef - if full_trajectory + if FullTraj res = res[isassigned.(Ref(res), 1:n_steps)] end break From 4d847e2f4bacd1ffc0cfd03e1dafc29a7f0fb656 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 4 Dec 2025 13:55:50 +0000 Subject: [PATCH 23/25] Efficient rand_momentum for softabs --- src/riemannian/metric.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl index d81383806..e3beb0441 100644 --- a/src/riemannian/metric.jl +++ b/src/riemannian/metric.jl @@ -27,6 +27,15 @@ function softabs(X::AbstractMatrix{T}, α=20.0) where {T<:Real} return Q * diagm(softabsλ) * Q', Q, λ, softabsλ end +function softabs_decomp(X::AbstractMatrix{T}, α=20.0) where {T<:Real} + # Enforce symmetry for type stability + F = eigen(Symmetric(X)) # ReverseDiff cannot diff through `eigen` + Q = hcat(F.vectors) + λ = F.values + softabsλ = λ .* coth.(α * λ) + return Q, softabsλ +end + (map::SoftAbsMap)(x) = softabs(x, map.α)[1] struct DenseRiemannianMetric{ @@ -71,3 +80,15 @@ function rand_momentum( r = chol.L * r return r end + +function rand_momentum( + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, + metric::DenseRiemannianMetric{T,<:SoftAbsMap}, + kinetic, + θ::AbstractVecOrMat, +) where {T} + r = _randn(rng, T, size(metric)...) + Q, softabsλ = softabs_decomp(metric.G(θ), metric.map.α) + r = Q * Diagonal(sqrt.(softabsλ)) * r + return r +end From 473cc934c4afd9eaba31fa3be88b59eb9eca4b5c Mon Sep 17 00:00:00 2001 From: nsiccha Date: Mon, 1 Dec 2025 15:37:36 +0100 Subject: [PATCH 24/25] Implements a simple Nutpie style adaptation (using both positions and gradients, but not changing the schedule). (#473) * initial changes to get a working demo * fix tests, add new ones, and add documentation * fix type * fix some stray tests * delete tmp folder with demo * address review comments * reference interface refactor issue * refactor and fix test rng handling * improve docstring for NutpieVar * remove superfluous white space * fix JET tests * add entry to history, bump version * fix NutpieVar docstring * increase number of tests for mass matrix adaptation --------- Co-authored-by: Markus Hauru --- HISTORY.md | 7 +++ Project.toml | 2 +- docs/src/api.md | 12 ++-- src/AdvancedHMC.jl | 3 +- src/abstractmcmc.jl | 2 +- src/adaptation/Adaptation.jl | 23 ++++---- src/adaptation/massmatrix.jl | 105 +++++++++++++++++++++++++++++---- src/adaptation/stan_adaptor.jl | 8 +-- src/adaptation/stepsize.jl | 4 +- src/sampler.jl | 25 ++++---- test/adaptation.jl | 88 ++++++++++++++++++++++++--- 11 files changed, 224 insertions(+), 55 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 038968ef1..a9daf473d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,12 @@ # AdvancedHMC Changelog +## 0.8.4 + + - Introduces an experimental way to improve the *diagonal* mass matrix adaptation using gradient information (similar to [nutpie](https://github.com/pymc-devs/nutpie)), + currently to be initialized for a `metric` of type `DiagEuclideanMetric` + via `mma = AdvancedHMC.NutpieVar(size(metric); var=copy(metric.M⁻¹))` + until a new interface is introduced in an upcoming breaking release to specify the method of adaptation. 
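+
+    A minimal usage sketch (assuming a `D`-dimensional target; the
+    surrounding `Leapfrog`/`StanHMCAdaptor` setup is the standard one, with
+    `ϵ` e.g. from `find_good_stepsize`):
+
+    ```julia
+    metric = DiagEuclideanMetric(D)
+    mma = AdvancedHMC.NutpieVar(size(metric); var=copy(metric.M⁻¹))
+    adaptor = StanHMCAdaptor(mma, StepSizeAdaptor(0.8, Leapfrog(ϵ)))
+    ```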
+
 ## 0.8.0

 - To make an MCMC transition from phasepoint `z` using trajectory `τ` (or HMCKernel `κ`) under Hamiltonian `h`, use `transition(h, τ, z)` or `transition(rng, h, τ, z)` (if using HMCKernel, use `transition(h, κ, z)` or `transition(rng, h, κ, z)`).
diff --git a/Project.toml b/Project.toml
index 8c32e813a..f38dda02a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "AdvancedHMC"
 uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
-version = "0.8.3"
+version = "0.8.4"

 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
diff --git a/docs/src/api.md b/docs/src/api.md
index a1c488fb8..e7caf2d0c 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -32,11 +32,15 @@ where `ϵ` is the step size of leapfrog integration.
 ### Adaptor (`adaptor`)

 - Adapt the mass matrix `metric` of the Hamiltonian dynamics: `mma = MassMatrixAdaptor(metric)`
-
+  
   + This is lowered to `UnitMassMatrix`, `WelfordVar` or `WelfordCov` based on the type of the mass matrix `metric`
+  + There is an experimental way to improve the *diagonal* mass matrix adaptation using gradient information (similar to [nutpie](https://github.com/pymc-devs/nutpie)),
+    currently to be initialized for a `metric` of type `DiagEuclideanMetric`
+    via `mma = AdvancedHMC.NutpieVar(size(metric); var=copy(metric.M⁻¹))`
+    until a new interface is introduced in an upcoming breaking release to specify the method of adaptation.
 - Adapt the step size of the leapfrog integrator `integrator`: `ssa = StepSizeAdaptor(δ, integrator)`
-
+  
   + It uses Nesterov's dual averaging with `δ` as the target acceptance rate.
 - Combine the two above *naively*: `NaiveHMCAdaptor(mma, ssa)`
 - Combine the first two using Stan's windowed adaptation: `StanHMCAdaptor(mma, ssa)`
@@ -61,12 +65,12 @@ sample(
 Draw `n_samples` samples using the kernel `κ` under the Hamiltonian system `h`

 - The randomness is controlled by `rng`.
-
+  
   + If `rng` is not provided, the default random number generator (`Random.default_rng()`) will be used.
 - The initial point is given by `θ`.
 - The adaptor is set by `adaptor`, for which the default is no adaptation.
-
+  
   + It will perform `n_adapts` steps of adaptation, for which the default is `1_000` or 10% of `n_samples`, whichever is lower.
 - `drop_warmup` specifies whether to drop samples.
 - `verbose` controls the verbosity.
diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl
index e8b10bf3b..699196885 100644
--- a/src/AdvancedHMC.jl
+++ b/src/AdvancedHMC.jl
@@ -88,7 +88,7 @@ export find_good_eps
 include("adaptation/Adaptation.jl")
 using .Adaptation
 import .Adaptation:
-    StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging, NoAdaptation
+    StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging, NoAdaptation, PositionOrPhasePoint

 # Helpers for initializing adaptors via AHMC structs
@@ -130,6 +130,7 @@ export StepSizeAdaptor,
     MassMatrixAdaptor,
     UnitMassMatrix,
     WelfordVar,
+    NutpieVar,
     WelfordCov,
     NaiveHMCAdaptor,
     StanHMCAdaptor,
diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl
index 413e9de6f..f6ce40009 100644
--- a/src/abstractmcmc.jl
+++ b/src/abstractmcmc.jl
@@ -196,7 +196,7 @@ function AbstractMCMC.step(
     # Adapt h and spl.
     tstat = stat(t)
-    h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
+    h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z, tstat.acceptance_rate)
     tstat = merge(tstat, (is_adapt=isadapted,))

     # Compute next transition and state.
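
The changelog and `docs/src/api.md` entries above give only the constructor call. Below is a minimal end-to-end sketch of how the experimental Nutpie-style adaptor slots into a standard AdvancedHMC sampling setup; the standard-normal target, dimension, and sample counts are illustrative assumptions for the sketch, not part of the patch:

```julia
using AdvancedHMC, ForwardDiff, LogDensityProblems

# Illustrative standard-normal target implementing the LogDensityProblems
# interface (an assumption for this sketch, not patch code).
struct StdNormal
    dim::Int
end
LogDensityProblems.logdensity(::StdNormal, θ) = -sum(abs2, θ) / 2
LogDensityProblems.dimension(p::StdNormal) = p.dim
LogDensityProblems.capabilities(::Type{StdNormal}) = LogDensityProblems.LogDensityOrder{0}()

D = 10
ℓπ = StdNormal(D)
θ_init = randn(D)

metric = DiagEuclideanMetric(D)
h = Hamiltonian(metric, ℓπ, ForwardDiff)
integrator = Leapfrog(find_good_stepsize(h, θ_init))
κ = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))

# Experimental Nutpie-style diagonal adaptation, seeded with the metric's
# current M⁻¹ and combined with step size adaptation as in the patch above.
adaptor = StanHMCAdaptor(
    AdvancedHMC.NutpieVar(size(metric); var=copy(metric.M⁻¹)),
    StepSizeAdaptor(0.8, integrator),
)

samples, stats = sample(h, κ, θ_init, 2_000, adaptor, 1_000)
```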
diff --git a/src/adaptation/Adaptation.jl b/src/adaptation/Adaptation.jl
index 4f2fde83c..10a9a9805 100644
--- a/src/adaptation/Adaptation.jl
+++ b/src/adaptation/Adaptation.jl
@@ -4,13 +4,13 @@ export Adaptation
 using LinearAlgebra: LinearAlgebra
 using Statistics: Statistics

-using ..AdvancedHMC: AbstractScalarOrVec
+using ..AdvancedHMC: AbstractScalarOrVec, PhasePoint
 using DocStringExtensions

 """
 $(TYPEDEF)

-Abstract type for HMC adaptors. 
+Abstract type for HMC adaptors.
 """
 abstract type AbstractAdaptor end
 function getM⁻¹ end
@@ -21,12 +21,17 @@ function initialize! end
 function finalize! end
 export AbstractAdaptor, adapt!, initialize!, finalize!, reset!, getϵ, getM⁻¹

+get_position(x::PhasePoint) = x.θ
+get_position(x::AbstractVecOrMat{<:AbstractFloat}) = x
+const PositionOrPhasePoint = Union{AbstractVecOrMat{<:AbstractFloat}, PhasePoint}
+
 struct NoAdaptation <: AbstractAdaptor end
 export NoAdaptation

 include("stepsize.jl")
 export StepSizeAdaptor, NesterovDualAveraging
+
 include("massmatrix.jl")
-export MassMatrixAdaptor, UnitMassMatrix, WelfordVar, WelfordCov
+export MassMatrixAdaptor, UnitMassMatrix, WelfordVar, NutpieVar, WelfordCov

 ##
 ## Composite adaptors
@@ -47,18 +52,14 @@ getϵ(ca::NaiveHMCAdaptor) = getϵ(ca.ssa)
 # TODO: implement consensus adaptor
 function adapt!(
     nca::NaiveHMCAdaptor,
-    θ::AbstractVecOrMat{<:AbstractFloat},
+    z_or_theta::PositionOrPhasePoint,
     α::AbstractScalarOrVec{<:AbstractFloat},
 )
-    adapt!(nca.ssa, θ, α)
-    adapt!(nca.pc, θ, α)
-    return nothing
-end
-function reset!(aca::NaiveHMCAdaptor)
-    reset!(aca.ssa)
-    reset!(aca.pc)
+    adapt!(nca.ssa, z_or_theta, α)
+    adapt!(nca.pc, z_or_theta, α)
     return nothing
 end
+
 initialize!(adaptor::NaiveHMCAdaptor, n_adapts::Int) = nothing
 finalize!(aca::NaiveHMCAdaptor) = finalize!(aca.ssa)
diff --git a/src/adaptation/massmatrix.jl b/src/adaptation/massmatrix.jl
index 105d3baeb..13f360e32 100644
--- a/src/adaptation/massmatrix.jl
+++ b/src/adaptation/massmatrix.jl
@@ -9,16 +9,18 @@ finalize!(::MassMatrixAdaptor) = nothing

 function adapt!(
     adaptor::MassMatrixAdaptor,
-    θ::AbstractVecOrMat{<:AbstractFloat},
-    α::AbstractScalarOrVec{<:AbstractFloat},
+    z_or_theta::PositionOrPhasePoint,
+    ::AbstractScalarOrVec{<:AbstractFloat},
     is_update::Bool=true,
 )
-    resize_adaptor!(adaptor, size(θ))
-    push!(adaptor, θ)
+    resize_adaptor!(adaptor, size(get_position(z_or_theta)))
+    push!(adaptor, z_or_theta)
     is_update && update!(adaptor)
     return nothing
 end

+Base.push!(a::MassMatrixAdaptor, z_or_theta::PositionOrPhasePoint) = push!(a, get_position(z_or_theta))
+
 ## Unit mass matrix adaptor

 struct UnitMassMatrix{T<:AbstractFloat} <: MassMatrixAdaptor end
@@ -39,7 +41,7 @@ getM⁻¹(::UnitMassMatrix{T}) where {T} = LinearAlgebra.UniformScaling{T}(one(T
 function adapt!(
     ::UnitMassMatrix,
-    ::AbstractVecOrMat{<:AbstractFloat},
+    ::PositionOrPhasePoint,
     ::AbstractScalarOrVec{<:AbstractFloat},
     is_update::Bool=true,
 )
@@ -47,7 +49,6 @@ function adapt!(
 end

 ## Diagonal mass matrix adaptor
-
 abstract type DiagMatrixEstimator{T} <: MassMatrixAdaptor end

 getM⁻¹(ve::DiagMatrixEstimator) = ve.var
@@ -70,7 +71,7 @@ NaiveVar{T}(sz::Tuple{Int,Int}) where {T<:AbstractFloat} = NaiveVar(Vector{Matri

 NaiveVar(sz::Union{Tuple{Int},Tuple{Int,Int}}) = NaiveVar{Float64}(sz)

-Base.push!(nv::NaiveVar, s::AbstractVecOrMat) = push!(nv.S, s)
+Base.push!(nv::NaiveVar, s::AbstractVecOrMat{<:AbstractFloat}) = push!(nv.S, s)

 reset!(nv::NaiveVar) = resize!(nv.S, 0)
@@ -135,7 +136,7 @@ function reset!(wv::WelfordVar{T}) where {T<:AbstractFloat}
     return nothing
 end

-function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T}
+function Base.push!(wv::WelfordVar, s::AbstractVecOrMat{T}) where {T<:AbstractFloat}
     wv.n += 1
     (; δ, μ, M, n) = wv
     n = T(n)
@@ -153,6 +154,90 @@ function get_estimation(wv::WelfordVar{T}) where {T<:AbstractFloat}
     return n / ((n + 5) * (n - 1)) * M .+ ϵ * (5 / (n + 5))
 end

+"""
+    NutpieVar
+
+Nutpie-style diagonal mass matrix estimator (using positions and gradients).
+
+Expected to converge faster and to a better mass matrix than [`WelfordVar`](@ref), for which it is a drop-in replacement.
+
+Can be initialized via `NutpieVar(sz)` where `sz` is either a `Tuple{Int}` or a `Tuple{Int,Int}`.
+
+# Fields
+
+$(FIELDS)
+"""
+mutable struct NutpieVar{T<:AbstractFloat,E<:AbstractVecOrMat{T},V<:AbstractVecOrMat{T}} <: DiagMatrixEstimator{T}
+    "Online variance estimator of the posterior positions."
+    position_estimator::WelfordVar{T,E,V}
+    "Online variance estimator of the posterior gradients."
+    gradient_estimator::WelfordVar{T,E,V}
+    "The number of observations collected so far."
+    n::Int
+    "The minimal number of observations after which the estimate of the variances can be updated."
+    n_min::Int
+    "The estimated variances - initialized to ones, updated after calling [`update!`](@ref) if `n > n_min`."
+    var::V
+    function NutpieVar(n::Int, n_min::Int, μ::E, M::E, δ::E, var::V) where {E,V}
+        return new{eltype(E),E,V}(
+            WelfordVar(n, n_min, copy(μ), copy(M), copy(δ), copy(var)),
+            WelfordVar(n, n_min, copy(μ), copy(M), copy(δ), copy(var)),
+            n, n_min, var
+        )
+    end
+end
+
+function Base.show(io::IO, ::NutpieVar{T}) where {T}
+    return print(io, "NutpieVar{", T, "} adaptor")
+end
+
+function NutpieVar{T}(
+    sz::Union{Tuple{Int},Tuple{Int,Int}}=(2,); n_min::Int=10, var=ones(T, sz)
+) where {T<:AbstractFloat}
+    return NutpieVar(0, n_min, zeros(T, sz), zeros(T, sz), zeros(T, sz), var)
+end
+
+function NutpieVar(sz::Union{Tuple{Int},Tuple{Int,Int}}; kwargs...)
+    return NutpieVar{Float64}(sz; kwargs...)
+end
+
+function resize_adaptor!(nv::NutpieVar{T}, size_θ::Tuple{Int,Int}) where {T<:AbstractFloat}
+    if size_θ != size(nv.var)
+        @assert nv.n == 0 "Cannot resize a var estimator when it contains samples."
+        resize_adaptor!(nv.position_estimator, size_θ)
+        resize_adaptor!(nv.gradient_estimator, size_θ)
+        nv.var = ones(T, size_θ)
+    end
+end
+
+function resize_adaptor!(nv::NutpieVar{T}, size_θ::Tuple{Int}) where {T<:AbstractFloat}
+    length_θ = first(size_θ)
+    if length_θ != size(nv.var, 1)
+        @assert nv.n == 0 "Cannot resize a var estimator when it contains samples."
+        resize_adaptor!(nv.position_estimator, size_θ)
+        resize_adaptor!(nv.gradient_estimator, size_θ)
+        fill!(resize!(nv.var, length_θ), T(1))
+    end
+end
+
+function reset!(nv::NutpieVar)
+    nv.n = 0
+    reset!(nv.position_estimator)
+    reset!(nv.gradient_estimator)
+end
+
+Base.push!(::NutpieVar, x::AbstractVecOrMat{<:AbstractFloat}) = error("`NutpieVar` adaptation requires position and gradient information!")
+
+function Base.push!(nv::NutpieVar, z::PhasePoint)
+    nv.n += 1
+    push!(nv.position_estimator, z.θ)
+    push!(nv.gradient_estimator, z.ℓπ.gradient)
+    return nothing
+end
+
+# Ref: https://github.com/pymc-devs/nutpie
+get_estimation(nv::NutpieVar) = sqrt.(get_estimation(nv.position_estimator) ./ get_estimation(nv.gradient_estimator))
+
 ## Dense mass matrix adaptor

 abstract type DenseMatrixEstimator{T} <: MassMatrixAdaptor end
@@ -175,7 +260,7 @@ end

 NaiveCov{T}(sz::Tuple{Int}) where {T<:AbstractFloat} = NaiveCov(Vector{Vector{T}}())

-Base.push!(nc::NaiveCov, s::AbstractVector) = push!(nc.S, s)
+Base.push!(nc::NaiveCov, s::AbstractVector{<:AbstractFloat}) = push!(nc.S, s)

 reset!(nc::NaiveCov{T}) where {T} = resize!(nc.S, 0)
@@ -225,7 +310,7 @@ function reset!(wc::WelfordCov{T}) where {T<:AbstractFloat}
     return nothing
 end

-function Base.push!(wc::WelfordCov, s::AbstractVector{T}) where {T}
+function Base.push!(wc::WelfordCov, s::AbstractVector{T}) where {T<:AbstractFloat}
     wc.n += 1
     (; δ, μ, n, M) = wc
     n = T(n)
diff --git a/src/adaptation/stan_adaptor.jl b/src/adaptation/stan_adaptor.jl
index b36a22597..931e741a0 100644
--- a/src/adaptation/stan_adaptor.jl
+++ b/src/adaptation/stan_adaptor.jl
@@ -136,20 +136,20 @@ is_window_end(a::StanHMCAdaptor) = a.state.i in a.state.window_splits

 function adapt!(
     tp::StanHMCAdaptor,
-    θ::AbstractVecOrMat{<:AbstractFloat},
+    z_or_theta::PositionOrPhasePoint,
     α::AbstractScalarOrVec{<:AbstractFloat},
 )
     tp.state.i += 1

-    adapt!(tp.ssa, θ, α)
+    adapt!(tp.ssa, z_or_theta, α)

-    resize_adaptor!(tp.pc, size(θ)) # Resize pre-conditioner if necessary.
+    resize_adaptor!(tp.pc, size(get_position(z_or_theta))) # Resize pre-conditioner if necessary.

     # Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
     if is_in_window(tp)
         # We accumulate stats from θ online and only trigger the update of M⁻¹ at the end of the window.
         is_update_M⁻¹ = is_window_end(tp)
-        adapt!(tp.pc, θ, α, is_update_M⁻¹)
+        adapt!(tp.pc, z_or_theta, α, is_update_M⁻¹)
     end

     if is_window_end(tp)
diff --git a/src/adaptation/stepsize.jl b/src/adaptation/stepsize.jl
index 2afbb651e..cacb463db 100644
--- a/src/adaptation/stepsize.jl
+++ b/src/adaptation/stepsize.jl
@@ -174,7 +174,7 @@ end
 # Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/stepsize_adaptation.hpp
 # Note: This function is not merged with `adapt!` to emphasize the fact that
 # step size adaptation is not dependent on `θ`.
-# Note 2: `da.state` and `α` support vectorised HMC but should do so together. 
+# Note 2: `da.state` and `α` support vectorised HMC but should do so together.
 function adapt_stepsize!(
     da::NesterovDualAveraging{T}, α::AbstractScalarOrVec{T}
 ) where {T<:AbstractFloat}
@@ -211,7 +211,7 @@ end

 function adapt!(
     da::NesterovDualAveraging,
-    θ::AbstractVecOrMat{<:AbstractFloat},
+    ::PositionOrPhasePoint,
     α::AbstractScalarOrVec{<:AbstractFloat},
 )
     adapt_stepsize!(da, α)
diff --git a/src/sampler.jl b/src/sampler.jl
index 3e477ba3a..1b282383b 100644
--- a/src/sampler.jl
+++ b/src/sampler.jl
@@ -60,11 +60,11 @@ end
 function Adaptation.adapt!(
     h::Hamiltonian,
     κ::AbstractMCMCKernel,
-    adaptor::Adaptation.NoAdaptation,
-    i::Int,
-    n_adapts::Int,
-    θ::AbstractVecOrMat{<:AbstractFloat},
-    α::AbstractScalarOrVec{<:AbstractFloat},
+    ::Adaptation.NoAdaptation,
+    ::Int,
+    ::Int,
+    ::PositionOrPhasePoint,
+    ::AbstractScalarOrVec{<:AbstractFloat},
 )
     return h, κ, false
 end
@@ -75,19 +75,18 @@ function Adaptation.adapt!(
     adaptor::AbstractAdaptor,
     i::Int,
     n_adapts::Int,
-    θ::AbstractVecOrMat{<:AbstractFloat},
+    z_or_theta::PositionOrPhasePoint,
     α::AbstractScalarOrVec{<:AbstractFloat},
 )
-    isadapted = false
-    if i <= n_adapts
+    adapt = i <= n_adapts
+    if adapt
         i == 1 && Adaptation.initialize!(adaptor, n_adapts)
-        adapt!(adaptor, θ, α)
+        adapt!(adaptor, z_or_theta, α)
         i == n_adapts && finalize!(adaptor)
         h = update(h, adaptor)
         κ = update(κ, adaptor)
-        isadapted = true
     end
-    return h, κ, isadapted
+    return h, κ, adapt
 end

 """
@@ -148,7 +147,7 @@ end
         progress::Bool=false
     )
 Sample `n_samples` samples using the proposal `κ` under Hamiltonian `h`.
-- The randomness is controlled by `rng`. 
+- The randomness is controlled by `rng`.
 - If `rng` is not provided, the default random number generator (`Random.default_rng()`) will be used.
 - The initial point is given by `θ`.
 - The adaptor is set by `adaptor`, for which the default is no adaptation.
@@ -185,7 +184,7 @@ function sample(
         t = transition(rng, h, κ, t.z)
         # Adapt h and κ; what is mutable is the adaptor
         tstat = stat(t)
-        h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
+        h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z, tstat.acceptance_rate)
         if isadapted
             num_divergent_transitions_during_adaption += tstat.numerical_error
         else
diff --git a/test/adaptation.jl b/test/adaptation.jl
index 346423eaa..df72c159e 100644
--- a/test/adaptation.jl
+++ b/test/adaptation.jl
@@ -1,6 +1,8 @@
 using ReTest, LinearAlgebra, Distributions, AdvancedHMC, Random, ForwardDiff
+using AdvancedHMC:
+    PhasePoint, DualValue
 using AdvancedHMC.Adaptation:
-    WelfordVar, NaiveVar, WelfordCov, NaiveCov, get_estimation, get_estimation, reset!
+    DiagMatrixEstimator, WelfordVar, NutpieVar, NaiveVar, WelfordCov, NaiveCov, get_estimation, get_estimation, reset!
 function runnuts(ℓπ, metric; n_samples=10_000)
     D = size(metric, 1)
@@ -18,7 +20,37 @@ function runnuts(ℓπ, metric; n_samples=10_000)
     return (samples=samples, stats=stats, adaptor=adaptor)
 end

+# Temporary function until we've settled on a different interface
+function runnuts_nutpie(ℓπ, metric::DiagEuclideanMetric; n_samples=10_000)
+    D = size(metric, 1)
+    n_adapts = 5_000
+    θ_init = rand(D)
+    rng = MersenneTwister(0)
+
+    nuts = NUTS(0.8)
+    h = Hamiltonian(metric, ℓπ, ForwardDiff)
+    step_size = AdvancedHMC.make_step_size(rng, nuts, h, θ_init)
+    integrator = AdvancedHMC.make_integrator(nuts, step_size)
+    κ = AdvancedHMC.make_kernel(nuts, integrator)
+    # Constructing like this until we've settled on a different interface
+    adaptor = AdvancedHMC.StanHMCAdaptor(
+        AdvancedHMC.Adaptation.NutpieVar(size(metric); var=copy(metric.M⁻¹)),
+        AdvancedHMC.StepSizeAdaptor(nuts.δ, integrator)
+    )
+    samples, stats = sample(h, κ, θ_init, n_samples, adaptor, n_adapts; verbose=false)
+    return (samples=samples, stats=stats, adaptor=adaptor)
+end
+
+"""
+Computes the condition number of a covariance matrix `cov::AbstractMatrix` after preconditioning with the (diagonal) mass matrix estimated in `a::DiagMatrixEstimator`.
+
+This is a simple but serviceable proxy for eventual sampling efficiency, but see also https://arxiv.org/abs/1905.09813 for a more involved estimate.
+
+(A lower number generally means that the estimated mass matrix is better).
+"""
+preconditioned_cond(a::DiagMatrixEstimator, cov::AbstractMatrix) = cond(sqrt(Diagonal(a.var)) \ cov / sqrt(Diagonal(a.var)))
+
 @testset "Adaptation" begin
+    Random.seed!(1)
     # Check that the estimated variance is approximately correct.
     @testset "Online v.s. naive v.s. true var/cov estimation" begin
         D = 10
@@ -60,15 +92,32 @@ end

 @testset "MassMatrixAdaptor constructors" begin
     θ = [0.0, 0.0, 0.0, 0.0]
+    z = PhasePoint(
+        θ, θ, DualValue(0., θ), DualValue(0., θ)
+    )
     pc1 = MassMatrixAdaptor(UnitEuclideanMetric) # default dim = 2
     pc2 = MassMatrixAdaptor(DiagEuclideanMetric)
+    # Constructing like this until we've settled on a different interface
+    pc2_nutpie = NutpieVar{Float64}((2, ))
     pc3 = MassMatrixAdaptor(DenseEuclideanMetric)

-    # Var adaptor dimention should be increased to length(θ) from 2
+    # Var adaptor dimension should be increased to length(θ) from 2
     AdvancedHMC.adapt!(pc1, θ, 1.0)
     AdvancedHMC.adapt!(pc2, θ, 1.0)
+    AdvancedHMC.adapt!(pc2_nutpie, z, 1.0)
     AdvancedHMC.adapt!(pc3, θ, 1.0)
     @test AdvancedHMC.Adaptation.getM⁻¹(pc2) == ones(length(θ))
+    @test AdvancedHMC.Adaptation.getM⁻¹(pc2_nutpie) == ones(length(θ))
+    @test AdvancedHMC.Adaptation.getM⁻¹(pc3) ==
+        LinearAlgebra.diagm(0 => ones(length(θ)))
+
+    # Making sure "all" MassMatrixAdaptors support getting a PhasePoint instead of a Vector
+    AdvancedHMC.adapt!(pc1, z, 1.0)
+    AdvancedHMC.adapt!(pc2, z, 1.0)
+    AdvancedHMC.adapt!(pc2_nutpie, z, 1.0)
+    AdvancedHMC.adapt!(pc3, z, 1.0)
+    @test AdvancedHMC.Adaptation.getM⁻¹(pc2) == ones(length(θ))
+    @test AdvancedHMC.Adaptation.getM⁻¹(pc2_nutpie) == ones(length(θ))
     @test AdvancedHMC.Adaptation.getM⁻¹(pc3) ==
         LinearAlgebra.diagm(0 => ones(length(θ)))
 end
@@ -82,10 +131,14 @@ end
     adaptor2 = StanHMCAdaptor(
         MassMatrixAdaptor(DiagEuclideanMetric), NesterovDualAveraging(0.8, 0.5)
     )
+    # Constructing like this until we've settled on a different interface
+    adaptor2_nutpie = StanHMCAdaptor(
+        NutpieVar{Float64}((2, )), NesterovDualAveraging(0.8, 0.5)
+    )
     adaptor3 = StanHMCAdaptor(
         MassMatrixAdaptor(DenseEuclideanMetric), NesterovDualAveraging(0.8, 0.5)
     )
-    for a in [adaptor1, adaptor2, adaptor3]
+    for a in [adaptor1, adaptor2, adaptor2_nutpie, adaptor3]
         AdvancedHMC.initialize!(a, 1_000)
         @test a.state.window_start == 76
         @test a.state.window_end == 950
@@ -93,6 +146,7 @@ end
         AdvancedHMC.adapt!(a, θ, 1.0)
     end
     @test AdvancedHMC.Adaptation.getM⁻¹(adaptor2) == ones(length(θ))
+    @test AdvancedHMC.Adaptation.getM⁻¹(adaptor2_nutpie) == ones(length(θ))
     @test AdvancedHMC.Adaptation.getM⁻¹(adaptor3) ==
         LinearAlgebra.diagm(0 => ones(length(θ)))

@@ -112,26 +166,32 @@ end
 @testset "Adapted mass v.s. true variance" begin
     D = 10
-    n_tests = 5
-    @testset "DiagEuclideanMetric" begin
+    n_tests = 10
+    @testset "'Diagonal' MvNormal target" begin
         for _ in 1:n_tests
-            Random.seed!(1)
             # Random variance
             σ² = 1 .+ abs.(randn(D))
+            Σ = Diagonal(σ²)

             # Diagonal Gaussian
-            ℓπ = LogDensityDistribution(MvNormal(Diagonal(σ²)))
+            ℓπ = LogDensityDistribution(MvNormal(Σ))

             res = runnuts(ℓπ, DiagEuclideanMetric(D))
             @test res.adaptor.pc.var ≈ σ² rtol = 0.2

+            # For this target, Nutpie (without regularization) will arrive at the true variances after two draws.
+            res_nutpie = runnuts_nutpie(ℓπ, DiagEuclideanMetric(D))
+            @test res_nutpie.adaptor.pc.var ≈ σ² rtol = 0.2
+            @test preconditioned_cond(res_nutpie.adaptor.pc, Σ) < preconditioned_cond(res.adaptor.pc, Σ)
+
             res = runnuts(ℓπ, DenseEuclideanMetric(D))
             @test res.adaptor.pc.cov ≈ Diagonal(σ²) rtol = 0.25
         end
     end

-    @testset "DenseEuclideanMetric" begin
+    @testset "'Dense' MvNormal target" begin
+        n_nutpie_superior = 0
         for _ in 1:n_tests
             # Random covariance
             m = randn(D, D)
@@ -143,9 +203,17 @@ end
             res = runnuts(ℓπ, DiagEuclideanMetric(D))
             @test res.adaptor.pc.var ≈ diag(Σ) rtol = 0.2

+            # For this target, Nutpie will NOT converge towards the true variances, even after infinite draws.
+            # HOWEVER, it will asymptotically (but also generally more quickly than Stan)
+            # find the best preconditioner for the target.
+            # As these are statistical algorithms, superiority is not always guaranteed, hence this way of testing.
+            res_nutpie = runnuts_nutpie(ℓπ, DiagEuclideanMetric(D))
+            n_nutpie_superior += preconditioned_cond(res_nutpie.adaptor.pc, Σ) < preconditioned_cond(res.adaptor.pc, Σ)
+
             res = runnuts(ℓπ, DenseEuclideanMetric(D))
             @test res.adaptor.pc.cov ≈ Σ rtol = 0.25
         end
+        @test n_nutpie_superior > 1 + n_tests / 2
     end
 end

@@ -156,6 +224,10 @@ end
     res = runnuts(ℓπ, DiagEuclideanMetric(mass_init); n_samples=1)
     @test res.adaptor.pc.var == mass_init

+    mass_init = fill(0.5, D)
+    res = runnuts_nutpie(ℓπ, DiagEuclideanMetric(mass_init); n_samples=1)
+    @test res.adaptor.pc.var == mass_init
+
     mass_init = diagm(0 => fill(0.5, D))
     res = runnuts(ℓπ, DenseEuclideanMetric(mass_init); n_samples=1)
     @test res.adaptor.pc.cov == mass_init

From c7090ba6d80eb8c1e4cadc7b7a227c0750804a89 Mon Sep 17 00:00:00 2001
From: nsiccha
Date: Thu, 27 Nov 2025 11:57:06 +0100
Subject: [PATCH 25/25] fix JET tests (#479)

---
 src/integrator.jl            | 21 +++++++++------------
 src/riemannian/integrator.jl |  4 ++--
 2 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/src/integrator.jl b/src/integrator.jl
index bd79eec47..bacc1dcdc 100644
--- a/src/integrator.jl
+++ b/src/integrator.jl
@@ -89,14 +89,14 @@ Leapfrog integrator with randomly "jittered" step size `ϵ` for every trajectory

 $(TYPEDFIELDS)

 # Description
-This is the same as `Leapfrog`(@ref) but with a "jittered" step size. This means
-that at the beginning of each trajectory we sample a step size `ϵ` by adding or
-subtracting from the nominal/base step size `ϵ0` some random proportion of `ϵ0`,
+This is the same as `Leapfrog`(@ref) but with a "jittered" step size. This means
+that at the beginning of each trajectory we sample a step size `ϵ` by adding or
+subtracting from the nominal/base step size `ϵ0` some random proportion of `ϵ0`,
 with the proportion specified by `jitter`, i.e. `ϵ = ϵ0 - jitter * ϵ0 * rand()`.

 Jittering might help alleviate issues related to poor interactions with a fixed step size:
-- In regions with high "curvature" the current choice of step size might mean over-shoot
-  leading to almost all steps being rejected. Randomly sampling the step size at the
+- In regions with high "curvature" the current choice of step size might mean over-shoot
+  leading to almost all steps being rejected. Randomly sampling the step size at the
   beginning of the trajectories can therefore increase the probability of escaping such
   high-curvature regions.
 - Exact periodicity of the simulated trajectories might occur, i.e. you might be so
@@ -168,7 +168,7 @@ $(TYPEDFIELDS)

 # Description

-Tempering can potentially allow greater exploration of the posterior, e.g. 
+Tempering can potentially allow greater exploration of the posterior, e.g.
 in a multi-modal posterior jumps between the modes can be more likely to occur.
 """
 struct TemperedLeapfrog{FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}} <: AbstractLeapfrog{T}
@@ -226,10 +226,7 @@ function step(
     ϵ = fwd ? step_size(lf) : -step_size(lf)
     ϵ = ϵ'

-    res = nothing
-    if FullTraj
-        res = Vector{P}(undef, n_steps)
-    end
+    res = FullTraj ? Vector{P}(undef, n_steps) : nothing

     (; θ, r) = z
     (; value, gradient) = z.ℓπ
@@ -249,12 +246,12 @@ function step(
         # Create a new phase point by caching the logdensity and gradient
         z = phasepoint(h, θ, r; ℓπ=DualValue(value, gradient))
         # Update result
-        if FullTraj
+        if !isnothing(res)
             res[i] = z
         end
         if !isfinite(z)
             # Remove undef
-            if FullTraj
+            if !isnothing(res)
                 resize!(res, i)
             end
             break
diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl
index 3a096fa9d..94269cfc1 100644
--- a/src/riemannian/integrator.jl
+++ b/src/riemannian/integrator.jl
@@ -11,7 +11,7 @@ Generalized leapfrog integrator with fixed step size `ϵ`.

 $(TYPEDFIELDS)

-## References 
+## References

 1. Girolami, Mark, and Ben Calderhead. "Riemann manifold Langevin and Hamiltonian Monte Carlo methods." Journal of the Royal Statistical Society Series B: Statistical Methodology 73, no. 2 (2011): 123-214.
@@ -77,7 +77,7 @@ end
 # TODO(Kai) make sure vectorization works
 # TODO(Kai) check if tempering is valid
-# TODO(Kai) abstract out the 3 main steps and merge with `step` in `integrator.jl` 
+# TODO(Kai) abstract out the 3 main steps and merge with `step` in `integrator.jl`
 function step(
     lf::GeneralizedLeapfrog{T},
     h::Hamiltonian{<:DenseRiemannianMetric},
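
Why the specialized `rand_momentum` in PATCH 23 is a valid replacement for the generic Cholesky path: `softabs_decomp` returns `Q` and `softabsλ` with `softabs(H, α) == Q * diagm(softabsλ) * Q'`, so the factor `L = Q * Diagonal(sqrt.(softabsλ))` satisfies `L * L' == softabs(H, α)`, and `L * randn(D)` has exactly the SoftAbs-metric covariance while skipping the explicit metric reconstruction and Cholesky factorization. A quick numerical check, assuming the `softabs`/`softabs_decomp` helpers from `src/riemannian/metric.jl` above are in scope (the test matrix is an arbitrary stand-in for a Hessian):

```julia
using LinearAlgebra

α = 20.0
H = let A = randn(5, 5)
    (A + A') / 2  # arbitrary symmetric matrix standing in for a Hessian
end

G, Q, λ, softabsλ = softabs(H, α)      # G == Q * diagm(softabsλ) * Q'
Qd, softabsλd = softabs_decomp(H, α)   # same eigendecomposition, no reconstruction

# L * L' recovers the SoftAbs metric, so r = L * randn(5) is distributed
# identically to cholesky(Symmetric(G)).L * randn(5) from the generic fallback.
L = Qd * Diagonal(sqrt.(softabsλd))
@assert L * L' ≈ G
```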
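
Similarly, the estimator behind `get_estimation(::NutpieVar)` in PATCH 24 can be sanity-checked in a few lines. For a diagonal Gaussian target, positions have variance `σ²` and gradients `-(θ .- μ) ./ σ²` have variance `1 ./ σ²`, so `sqrt.(var(positions) ./ var(gradients))` recovers `σ²`; this is the "true variances after two draws" behaviour the new tests refer to. A sketch with an assumed illustrative target (not patch code):

```julia
using Statistics

σ² = [0.5, 2.0, 9.0]
θs = sqrt.(σ²) .* randn(3, 100_000)  # columns are draws from N(0, Diagonal(σ²))
grads = -θs ./ σ²                    # ∇θ log p(θ) for this Gaussian target

# Nutpie-style diagonal mass matrix estimate, mirroring get_estimation(::NutpieVar)
est = vec(sqrt.(var(θs; dims=2) ./ var(grads; dims=2)))
# est ≈ σ² up to Monte Carlo error
```

For correlated targets the same ratio no longer matches `diag(Σ)`, which is why the "'Dense' MvNormal target" testset compares adaptors via `preconditioned_cond` rather than against the raw variances.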