@mhauru mhauru commented Dec 17, 2025

This PR starts using VarNamedTuple (VNT) for ValuesAsInModelAccumulator (VAIMAcc) and to_samples. It also adds new features and fixes to VNT that were needed or useful in the process.

github-actions bot commented Dec 17, 2025

Benchmark Report

  • this PR's head: 4a585adb3cbb01193a54ed45fb7eab31a4289d39
  • base branch: 753ca81b85af88adb0970dff88670dda2445fa4d

Computer Information

Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬───────────────────────────────┬────────────────────────────┬─────────────────────────────────┐
│                       │       │             │                   │        │       t(eval) / t(ref)        │     t(grad) / t(eval)      │        t(grad) / t(ref)         │
│                       │       │             │                   │        │ ─────────┬──────────┬──────── │ ───────┬─────────┬──────── │ ──────────┬───────────┬──────── │
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │     base │  this PR │ speedup │   base │ this PR │ speedup │      base │   this PR │ speedup │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│               Dynamic │    10 │    mooncake │             typed │   true │   375.88 │   364.35 │    1.03 │   9.59 │    9.99 │    0.96 │   3606.17 │   3640.01 │    0.99 │
│                   LDA │    12 │ reversediff │             typed │   true │  2747.63 │  2582.00 │    1.06 │   4.90 │    5.11 │    0.96 │  13458.31 │  13181.30 │    1.02 │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │ 60231.14 │ 59744.40 │    1.01 │   6.12 │    6.20 │    0.99 │ 368533.56 │ 370712.46 │    0.99 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │  6002.04 │  5804.86 │    1.03 │   5.67 │    5.76 │    0.98 │  34019.31 │  33427.19 │    1.02 │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │ 31239.31 │ 32437.82 │    0.96 │  10.43 │   10.07 │    1.04 │ 325760.56 │ 326566.16 │    1.00 │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │  3664.61 │  3689.31 │    0.99 │  12.48 │   12.55 │    0.99 │  45738.41 │  46316.61 │    0.99 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │     2.60 │     2.58 │    1.01 │   3.85 │    3.92 │    0.98 │     10.02 │     10.10 │    0.99 │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │  1092.90 │  1081.57 │    1.01 │ 140.61 │  136.34 │    1.03 │ 153677.72 │ 147459.09 │    1.04 │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │      err │      err │     err │    err │     err │     err │       err │       err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │      err │      err │     err │    err │     err │     err │       err │       err │     err │
│           Smorgasbord │   201 │      enzyme │             typed │   true │  1496.06 │  1488.47 │    1.01 │   6.38 │    6.70 │    0.95 │   9543.46 │   9978.80 │    0.96 │
│           Smorgasbord │   201 │    mooncake │             typed │   true │  1498.76 │  1479.06 │    1.01 │   5.86 │    5.90 │    0.99 │   8782.91 │   8729.14 │    1.01 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ reversediff │             typed │   true │  1494.59 │  1481.56 │    1.01 │ 102.85 │  103.90 │    0.99 │ 153712.14 │ 153938.43 │    1.00 │
│           Smorgasbord │   201 │ forwarddiff │      typed_vector │   true │  1497.69 │  1488.99 │    1.01 │  60.79 │   62.34 │    0.98 │  91040.47 │  92818.56 │    0.98 │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │  1497.74 │  1469.18 │    1.02 │  63.35 │   62.74 │    1.01 │  94885.77 │  92171.38 │    1.03 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │    untyped_vector │   true │  1496.22 │  1481.71 │    1.01 │  59.67 │   61.20 │    0.98 │  89281.65 │  90676.33 │    0.98 │
│              Submodel │     1 │    mooncake │             typed │   true │     3.34 │     3.27 │    1.02 │  11.12 │   11.09 │    1.00 │     37.15 │     36.30 │    1.02 │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴──────────┴──────────┴─────────┴────────┴─────────┴─────────┴───────────┴───────────┴─────────┘

codecov bot commented Dec 17, 2025

Codecov Report

❌ Patch coverage is 25.17483% with 107 lines in your changes missing coverage. Please review.
✅ Project coverage is 36.76%. Comparing base (753ca81) to head (4a585ad).

Files with missing lines      Patch %   Lines
src/varnamedtuple.jl          22.65%    99 Missing ⚠️
src/test_utils/models.jl      0.00%     8 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (753ca81) and HEAD (4a585ad).

HEAD has 6 uploads less than BASE

Flag    BASE (753ca81)    HEAD (4a585ad)
        11                5
Additional details and impacted files
@@                        Coverage Diff                         @@
##           mhauru/vnt-concretized-slices    #1182       +/-   ##
==================================================================
- Coverage                          80.16%   36.76%   -43.41%     
==================================================================
  Files                                 42       41        -1     
  Lines                               4356     4464      +108     
==================================================================
- Hits                                3492     1641     -1851     
- Misses                               864     2823     +1959     

@github-actions

DynamicPPL.jl documentation for PR #1182 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1182/

@mhauru mhauru left a comment

In the process of using VNT for VAIMAcc, I also had to implement values, length, empty, and isempty for VNT, so they are all bundled in this PR.
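
For readers unfamiliar with what implementing these Base functions involves, here is a minimal sketch on a toy container. ToyStore is a hypothetical stand-in backed by a Dict and is not how VarNamedTuple is implemented; it only shows the shape of the method definitions.

```julia
# Hypothetical stand-in for a key-value container; NOT DynamicPPL's VarNamedTuple.
struct ToyStore{K,V}
    data::Dict{K,V}
end

# Convenience constructor for an empty store.
ToyStore{K,V}() where {K,V} = ToyStore{K,V}(Dict{K,V}())

# Forward the collection interface to the wrapped Dict.
Base.keys(s::ToyStore) = keys(s.data)
Base.values(s::ToyStore) = values(s.data)
Base.length(s::ToyStore) = length(s.data)
Base.isempty(s::ToyStore) = isempty(s.data)

# `empty` returns a fresh, empty container and leaves the argument untouched.
Base.empty(::ToyStore{K,V}) where {K,V} = ToyStore{K,V}()

s = ToyStore(Dict(:x => 1.0, :y => 2.0))
e = empty(s)
println((length(s), isempty(s), isempty(e)))
```

Note that empty is non-mutating here, which matches the one-line definition of Base.empty for VarNamedTuple quoted further down in this review.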

This would be ready for review if not for an annoying issue with LKJCholesky, which probably requires some special treatment that I haven't figured out yet. In particular, I think the problem stems from an interaction with MCMCChains, which splits a Cholesky variable into the component elements of the .L field. I'll need to come back to this. Maybe FlexiChains will come in time to save me from having to solve this?
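
To illustrate what "splitting into the component elements of the .L field" means, here is a small sketch using only the LinearAlgebra standard library. The flattened key names are made up for this example and are only an assumption about what such a flattening might look like; they are not taken from MCMCChains.

```julia
using LinearAlgebra

A = [4.0 2.0; 2.0 3.0]
C = cholesky(A)   # a Cholesky factorisation object, not a plain array
L = C.L           # the lower-triangular factor

# Flatten the factor into one scalar per (row, column) entry.
# The key format "x.L[i, j]" is hypothetical.
flat = Dict("x.L[$i, $j]" => L[i, j] for i in axes(L, 1), j in axes(L, 2))

println(sort(collect(keys(flat))))
```

After such a flattening the original value is no longer a single Cholesky object but a set of scalars, so anything that stores or looks up the variable as a whole has to account for that.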

Benchmarks

Code:

module VAIMBench

using DynamicPPL, Distributions, Chairmarks
using StableRNGs: StableRNG
include("benchmarks/src/Models.jl")
using .Models: Models

function run()
    rng = StableRNG(23)

    smorgasbord_instance = Models.smorgasbord(randn(rng, 100), randn(rng, 100))

    loop_univariate1k, multivariate1k = begin
        data_1k = randn(rng, 1_000)
        loop = Models.loop_univariate(length(data_1k)) | (; o=data_1k)
        multi = Models.multivariate(length(data_1k)) | (; o=data_1k)
        loop, multi
    end

    loop_univariate10k, multivariate10k = begin
        data_10k = randn(rng, 10_000)
        loop = Models.loop_univariate(length(data_10k)) | (; o=data_10k)
        multi = Models.multivariate(length(data_10k)) | (; o=data_10k)
        loop, multi
    end

    lda_instance = begin
        w = [1, 2, 3, 2, 1, 1]
        d = [1, 1, 1, 2, 2, 2]
        Models.lda(2, d, w)
    end

    models = [
        ("simple_assume_observe", Models.simple_assume_observe(randn(rng))),
        ("smorgasbord", smorgasbord_instance),
        ("loop_univariate1k", loop_univariate1k),
        ("multivariate1k", multivariate1k),
        ("loop_univariate10k", loop_univariate10k),
        ("multivariate10k", multivariate10k),
        ("dynamic", Models.dynamic()),
        ("parent", Models.parent(randn(rng))),
        # ("lda", lda_instance),
    ]

    function print_diff(r, ref)
        diff = r.time - ref.time
        units = if diff < 1e-6
            "ns"
        elseif diff < 1e-3
            "µs"
        else
            "ms"
        end
        diff = if units == "ns"
            round(diff / 1e-9; digits=1)
        elseif units == "µs"
            round(diff / 1e-6; digits=1)
        else
            round(diff / 1e-3; digits=1)
        end
        sign = diff < 0 ? "" : "+"
        return println(" ($(sign)$(diff) $units)")
    end

    for (name, m) in models
        println()
        println(name)
        vi = VarInfo(m)
        ranges = DynamicPPL.get_ranges_and_linked(vi)
        if !(ranges isa Tuple)
            ranges = (ranges,)
        end
        x = vi[:]
        strategy = InitFromParams(DynamicPPL.VectorWithRanges{false}(ranges..., x), nothing)

        print("Without VAIMAcc: ")
        oavi = OnlyAccsVarInfo(
            (DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()),
        )
        wo = @b DynamicPPL.init!!($m, $oavi, $strategy)
        display(wo)

        print("With VAIMAcc:    ")
        oavi = OnlyAccsVarInfo(
            (
                DynamicPPL.LogPriorAccumulator(),
                DynamicPPL.LogLikelihoodAccumulator(),
                DynamicPPL.ValuesAsInModelAccumulator(false),
            ),
        )
        w = @b DynamicPPL.init!!($m, $oavi, $strategy)
        show(stdout, MIME"text/plain"(), w)
        print_diff(w, wo)

        print("Only VAIMAcc:    ")
        oavi = OnlyAccsVarInfo((DynamicPPL.ValuesAsInModelAccumulator(false),))
        o = @b DynamicPPL.init!!($m, $oavi, $strategy)
        show(stdout, MIME"text/plain"(), o)
        print_diff(o, wo)
    end
end

run()

end

This evaluates each model from our benchmark suite 1) with logprob accumulators only, 2) with logprob and VAIM accumulators, 3) with VAIMAcc only. For 2) and 3) I also print the time difference compared to 1). The evaluations are done using the fancy new machinery that FastLDF brought, i.e. VectorWithRanges.
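
One detail worth knowing when reading the results: print_diff in the script picks its units from the signed difference, so any negative difference falls into the first branch and is printed in ns (hence entries like "-8208.0 ns"). A possible variant that chooses units by magnitude is sketched here. format_diff is a hypothetical name, not something in the PR, and it returns a string instead of printing.

```julia
# Hypothetical helper, not part of the PR: format a time difference given in seconds.
function format_diff(diff::Real)
    mag = abs(diff)  # choose units by magnitude so negative values are scaled too
    scale, units = if mag < 1e-6
        (1e-9, "ns")
    elseif mag < 1e-3
        (1e-6, "µs")
    else
        (1e-3, "ms")
    end
    value = round(diff / scale; digits=1)
    sign = diff < 0 ? "" : "+"  # negative numbers already carry their own sign
    return "($(sign)$(value) $(units))"
end

println(format_diff(-8.208e-6))
println(format_diff(1.4e-9))
```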

Results on the current release:

simple_assume_observe
Without VAIMAcc: 12.153 ns
With VAIMAcc:    188.554 ns (9 allocs: 384 bytes) (+176.4 ns)
Only VAIMAcc:    186.769 ns (9 allocs: 384 bytes) (+174.6 ns)

smorgasbord
Without VAIMAcc: 5.688 μs (12 allocs: 6.156 KiB)
With VAIMAcc:    77.333 μs (563 allocs: 22.328 KiB) (+71.6 µs)
Only VAIMAcc:    64.000 μs (560 allocs: 21.391 KiB) (+58.3 µs)

loop_univariate1k
Without VAIMAcc: 21.000 μs (8 allocs: 16.172 KiB)
With VAIMAcc:    634.916 μs (7269 allocs: 193.406 KiB) (+613.9 µs)
Only VAIMAcc:    626.167 μs (7269 allocs: 193.406 KiB) (+605.2 µs)

multivariate1k
Without VAIMAcc: 11.250 μs (24 allocs: 80.500 KiB)
With VAIMAcc:    12.541 μs (49 allocs: 89.766 KiB) (+1.3 µs)
Only VAIMAcc:    3.042 μs (36 allocs: 73.359 KiB) (-8208.0 ns)

loop_univariate10k
Without VAIMAcc: 280.500 μs (102 allocs: 194.375 KiB)
With VAIMAcc:    10.346 ms (72680 allocs: 1.999 MiB) (+10.1 ms)
Only VAIMAcc:    10.119 ms (72680 allocs: 1.999 MiB) (+9.8 ms)

multivariate10k
Without VAIMAcc: 110.167 μs (24 allocs: 896.500 KiB)
With VAIMAcc:    111.125 μs (49 allocs: 993.766 KiB) (+958.0 ns)
Only VAIMAcc:    23.167 μs (36 allocs: 801.359 KiB) (-87000.0 ns)

dynamic
Without VAIMAcc: 1.195 μs (14 allocs: 880 bytes)
With VAIMAcc:    2.713 μs (46 allocs: 2.938 KiB) (+1.5 µs)
Only VAIMAcc:    1.901 μs (40 allocs: 2.609 KiB) (+705.5 ns)

parent
Without VAIMAcc: 15.874 ns
With VAIMAcc:    219.672 ns (9 allocs: 384 bytes) (+203.8 ns)
Only VAIMAcc:    205.300 ns (9 allocs: 384 bytes) (+189.4 ns)
Main.VAIMBench

Results on this PR:

simple_assume_observe
Without VAIMAcc: 10.905 ns
With VAIMAcc:    12.288 ns (+1.4 ns)
Only VAIMAcc:    3.226 ns (-7.7 ns)

smorgasbord
Without VAIMAcc: 5.525 μs (12 allocs: 6.156 KiB)
With VAIMAcc:    8.070 μs (236 allocs: 31.297 KiB) (+2.5 µs)
Only VAIMAcc:    3.381 μs (233 allocs: 27.188 KiB) (-2144.0 ns)

loop_univariate1k
Without VAIMAcc: 9.625 μs (6 allocs: 16.125 KiB)
With VAIMAcc:    248.041 μs (9015 allocs: 372.641 KiB) (+238.4 µs)
Only VAIMAcc:    11.791 μs (1023 allocs: 60.406 KiB) (+2.2 µs)

multivariate1k
Without VAIMAcc: 11.500 μs (24 allocs: 80.500 KiB)
With VAIMAcc:    12.667 μs (41 allocs: 89.453 KiB) (+1.2 µs)
Only VAIMAcc:    3.292 μs (28 allocs: 72.969 KiB) (-8208.0 ns)

loop_univariate10k
Without VAIMAcc: 95.958 μs (6 allocs: 192.125 KiB)
With VAIMAcc:    2.514 ms (90023 allocs: 3.733 MiB) (+2.4 ms)
Only VAIMAcc:    110.959 μs (10031 allocs: 697.781 KiB) (+15.0 µs)

multivariate10k
Without VAIMAcc: 109.292 μs (24 allocs: 896.500 KiB)
With VAIMAcc:    110.584 μs (41 allocs: 993.453 KiB) (+1.3 µs)
Only VAIMAcc:    23.083 μs (28 allocs: 800.969 KiB) (-86209.0 ns)

dynamic
Without VAIMAcc: 1.159 μs (14 allocs: 880 bytes)
With VAIMAcc:    2.109 μs (35 allocs: 2.469 KiB) (+949.6 ns)
Only VAIMAcc:    1.393 μs (29 allocs: 2.141 KiB) (+233.1 ns)

parent
Without VAIMAcc: 10.901 ns
With VAIMAcc:    12.397 ns (+1.5 ns)
Only VAIMAcc:    3.203 ns (-7.7 ns)
Main.VAIMBench

The TL;DR is that this improves the performance of VAIMAcc a lot, roughly 10-100x, for

  • Small models.
  • Models with IndexLenses.

For small models this makes using a VAIMAcc go from being the dominant cost of evaluation to being negligible. For big models with heavy likelihood computations this does nothing, since it only affects overheads.
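
As a rough cross-check, the overhead that VAIMAcc adds on top of the logprob accumulators can be compared between the two listings. The numbers below are copied by hand from the differences printed on the "With VAIMAcc" lines above and converted to nanoseconds; the variable names are made up for this sketch.

```julia
# (model, overhead on current release, overhead on this PR), all in nanoseconds.
# Values are the "With VAIMAcc" differences reported above.
overhead_ns = [
    ("simple_assume_observe", 176.4, 1.4),
    ("smorgasbord", 71.6e3, 2.5e3),
    ("loop_univariate1k", 613.9e3, 238.4e3),
    ("multivariate1k", 1.3e3, 1.2e3),
    ("loop_univariate10k", 10.1e6, 2.4e6),
    ("multivariate10k", 958.0, 1.3e3),
    ("dynamic", 1.5e3, 949.6),
    ("parent", 203.8, 1.5),
]

# Ratio of old overhead to new overhead; larger means a bigger improvement.
ratios = [(name, release / pr) for (name, release, pr) in overhead_ns]

for (name, ratio) in ratios
    println(rpad(name, 24), round(ratio; digits=1), "x")
end
```

The inputs are rounded as printed by the benchmark script, so the ratios are only indicative, especially for the models where the overhead is close to the measurement noise.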

Base.empty(::VarNamedTuple) = VarNamedTuple()

"""
empty!!(vnt::VarNamedTuple)

I implemented empty!! for VNT with the idea that in VAIMAcc we could use it to save allocations, but I actually haven't started using it. Could come back to this at some point as an optimisation.
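
For context, the "!!" suffix conventionally means "mutate in place if possible, otherwise return a new value", so the caller must always use the return value. A minimal sketch of that idea on Base types follows; toy_empty!! is a hypothetical name and says nothing about how empty!! is implemented for VNT.

```julia
# Hypothetical illustration of the "!!" convention; not DynamicPPL's implementation.
toy_empty!!(x::Vector) = empty!(x)          # mutable: clear in place, reuse the buffer
toy_empty!!(x::Tuple) = ()                  # immutable: return a new empty value
toy_empty!!(x::NamedTuple) = NamedTuple()   # immutable: return a new empty value

buf = [1.0, 2.0, 3.0]
buf2 = toy_empty!!(buf)   # same object as `buf`, now empty

t = (1, 2, 3)
t2 = toy_empty!!(t)       # `t` itself is unchanged

println((buf2 === buf, isempty(buf), t, t2))
```

The potential allocation saving comes from the mutable branch: the existing storage is reused rather than a fresh container being created on every evaluation.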

Comment on lines +1195 to +1199
function to_dict(::Type{T}, vnt::VarNamedTuple) where {T<:AbstractDict{<:VarName}}
    pairs = splat(Pair).(zip(keys(vnt), values(vnt)))
    return T(pairs...)
end
to_dict(vnt::VarNamedTuple) = to_dict(Dict{VarName,Any}, vnt)

At one point I thought I would need this function, so I made it, but in the end didn't use it. Kinda inclined to keep it though, I think it'll have a use at some point.
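
The pattern used in to_dict can be tried out on plain vectors. In this sketch ks and vs are illustrative stand-ins for keys(vnt) and values(vnt), and Dict{Symbol,Any} stands in for the target type T; none of these names come from the PR.

```julia
# Plain vectors stand in for keys(vnt) and values(vnt); illustrative only.
ks = [:a, :b, :c]
vs = Any[1, 2.5, "three"]

# splat(Pair) turns each (key, value) tuple produced by zip into key => value.
# `splat` is exported from Base in Julia 1.9 and later.
kv_pairs = splat(Pair).(zip(ks, vs))
d = Dict{Symbol,Any}(kv_pairs...)

# For Dict specifically, the constructor also accepts an iterator of
# (key, value) tuples directly, which avoids splatting a long argument list.
d2 = Dict{Symbol,Any}(zip(ks, vs))

println(d == d2)
```

Whether the non-splatting form would work for every AbstractDict subtype that to_dict accepts is not something this sketch checks.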

Comment on lines +85 to +86
expected_length = sum(prod ∘ DynamicPPL.varnamesize, keys(vi))
@test length(ps.params) == expected_length

Our notion of length has changed: both [@varname(x[1]), @varname(x[2])] and [@varname(x[1:2])] now have length 2.
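
The expected length in the test is a sum, over variable names, of the number of elements each name covers. Here is a toy version of that computation. It assumes that DynamicPPL.varnamesize returns a size tuple; toy_varnamesize, nelements, and total_length are hypothetical names, and the lookup table is invented for illustration.

```julia
# Hypothetical stand-in for DynamicPPL.varnamesize: maps a name to a size tuple.
# Scalars such as x[1] get the empty size (), the slice x[1:2] gets (2,).
toy_sizes = Dict("x[1]" => (), "x[2]" => (), "x[1:2]" => (2,))
toy_varnamesize(name) = toy_sizes[name]

# Number of elements for a size tuple; init=1 makes () count as one element.
nelements(sz) = prod(sz; init=1)

# Compose the two steps and sum over all names.
total_length(names) = sum(nelements ∘ toy_varnamesize, names)

println(total_length(["x[1]", "x[2]"]))
println(total_length(["x[1:2]"]))
```

Under this notion, two scalar entries and one two-element slice contribute the same total.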
