Conversation

@penelopeysm penelopeysm commented Dec 14, 2025

Previously, if you were performing MCMC sampling with a LogDensityFunction ldf and you wanted to figure out a set of reasonable initial parameters, you could do:

_, vi = DynamicPPL.init!!(rng, ldf.model, ldf.varinfo, strategy)
logp = DynamicPPL.getlogjoint(vi)
params = vi[:]

You could then check, for example, whether logp was finite, and so on -- see e.g. SliceSampling.jl or Turing's HMC implementation.
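In an external sampler, the pattern above would typically be wrapped in a retry loop. A minimal sketch of that pattern, assuming the pre-change API where the LDF still carried a varinfo (the function name and max_attempts limit here are illustrative, not taken from any particular sampler):

```julia
using DynamicPPL, Random

# Sketch: repeatedly draw initial parameters until the log joint is finite.
# `ldf` is a (pre-change) LogDensityFunction carrying a model and varinfo;
# `strategy` is an initialisation strategy such as InitFromPrior().
function find_initial_params(rng::Random.AbstractRNG, ldf, strategy; max_attempts=100)
    for _ in 1:max_attempts
        _, vi = DynamicPPL.init!!(rng, ldf.model, ldf.varinfo, strategy)
        logp = DynamicPPL.getlogjoint(vi)
        # Accept the draw only if it has finite log density.
        isfinite(logp) && return vi[:], logp
    end
    error("failed to find initial parameters with finite log joint after $max_attempts attempts")
end
```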

Now that LogDensityFunction no longer stores a varinfo, this becomes impossible. Turing can currently handle it because it is responsible for generating the VarInfo and creating the LDF. However, external samplers cannot, because the external sampler is just given the LDF and that's it -- there's no way for it to regenerate the VarInfo.

This unfortunately blocks a lot of things. For example, in MH if you want to generate a new proposal from the prior, you can't do it with just a LogDensityFunction. That's partly why in Turing we are still stuck with carrying an old model + varinfo wrapper struct around. See https://github.com/TuringLang/Turing.jl/blob/4dc7ad096f92fd571de312e7751986484ac6cb50/src/mcmc/mh.jl#L181-L199


There is a cheap way to fix it, which is to just lump a varinfo into LogDensityFunction. I think that is a band-aid, and I don't like that. It also makes my life harder later on, as I REALLY want us to minimise usage of any varinfo that is not OnlyAccsVarInfo.

So, this PR introduces a new function, rand_with_logpdf, that does this correctly for LogDensityFunction using a custom accumulator. It's kind of like ValuesAsInModel, but it is a bit more 'batteries included' since it also calculates logjac.

While this does get things across the line in a way I'm generally happy with (see the docstring for a very nice invariant, which is also checked in the test suite), it is my belief that this rand_with_logpdf function will need to eventually live inside AbstractMCMC. I think it should be a method on AbstractMCMC.LogDensityModel, which would delegate to the inner model.logdensity object.

The problem with that is, unfortunately, that the strategy argument is currently very much DynamicPPL-exclusive. So the InitFrom... strategies would also need to be upstreamed, which I feel less happy about doing. So I think this PR is a good middle ground for now.
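With the new function, an external sampler can obtain initial parameters from nothing but the LDF. A hypothetical usage sketch (the exact signature and return values of rand_with_logpdf are assumptions based on the description above, not the actual docstring):

```julia
using DynamicPPL, LogDensityProblems, Random

# Hypothetical: draw parameters together with the corresponding log density,
# directly from a LogDensityFunction, with no varinfo in sight.
params, logp = DynamicPPL.rand_with_logpdf(Random.default_rng(), ldf, InitFromPrior())

# The invariant mentioned above would then be, up to floating-point error,
# that the returned logp matches evaluating the LDF at the returned params:
#   logp ≈ LogDensityProblems.logdensity(ldf, params)
```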

@penelopeysm penelopeysm changed the title Implement DynamicPPL.rand_with_logpdf Implement DynamicPPL.rand_with_logpdf on LogDensityFunction Dec 14, 2025
github-actions bot commented Dec 14, 2025

Benchmark Report

  • this PR's head: 7d4fe3c1856c92e9d99759c17bf26d59d2162d63
  • base branch: 6266f644ce8caaa3b98fc65c1eb960f4f77243b1

Computer Information

Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬────────────────────────────────┬────────────────────────────┬─────────────────────────────────┐
│                       │       │             │                   │        │        t(eval) / t(ref)        │     t(grad) / t(eval)      │        t(grad) / t(ref)         │
│                       │       │             │                   │        │ ─────────┬───────────┬──────── │ ───────┬─────────┬──────── │ ──────────┬───────────┬──────── │
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │     base │   this PR │ speedup │   base │ this PR │ speedup │      base │   this PR │ speedup │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│               Dynamic │    10 │    mooncake │             typed │   true │   338.55 │    378.61 │    0.89 │  10.55 │   10.83 │    0.97 │   3571.81 │   4101.93 │    0.87 │
│                   LDA │    12 │ reversediff │             typed │   true │  2380.29 │   2616.99 │    0.91 │   4.98 │    4.98 │    1.00 │  11844.95 │  13033.13 │    0.91 │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │ 95313.83 │ 106198.40 │    0.90 │   4.14 │    3.98 │    1.04 │ 394695.57 │ 422712.57 │    0.93 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │  7327.59 │   8109.51 │    0.90 │   4.78 │    4.77 │    1.00 │  34999.00 │  38679.58 │    0.90 │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │ 30567.11 │  33424.21 │    0.91 │  10.06 │    9.90 │    1.02 │ 307380.59 │ 330864.29 │    0.93 │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │  3332.93 │   3687.63 │    0.90 │  12.90 │   12.73 │    1.01 │  42988.32 │  46956.88 │    0.92 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │     2.47 │      2.68 │    0.92 │   3.99 │    3.99 │    1.00 │      9.83 │     10.68 │    0.92 │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │  1123.07 │   1216.84 │    0.92 │ 121.92 │   62.00 │    1.97 │ 136922.33 │  75441.51 │    1.81 │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │      err │       err │     err │    err │     err │     err │       err │       err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │      err │       err │     err │    err │     err │     err │       err │       err │     err │
│           Smorgasbord │   201 │      enzyme │             typed │   true │  1518.77 │   1661.82 │    0.91 │   6.38 │    5.72 │    1.12 │   9688.22 │   9500.64 │    1.02 │
│           Smorgasbord │   201 │    mooncake │             typed │   true │  1523.90 │   1668.54 │    0.91 │   5.34 │    5.42 │    0.99 │   8144.24 │   9045.37 │    0.90 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ reversediff │             typed │   true │  1516.60 │   1681.91 │    0.90 │  90.47 │   92.05 │    0.98 │ 137204.71 │ 154819.11 │    0.89 │
│           Smorgasbord │   201 │ forwarddiff │      typed_vector │   true │  1550.64 │   1698.42 │    0.91 │  62.38 │   57.64 │    1.08 │  96728.33 │  97900.66 │    0.99 │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │  1528.24 │   1678.77 │    0.91 │  61.24 │  119.49 │    0.51 │  93583.56 │ 200602.04 │    0.47 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │    untyped_vector │   true │  1531.13 │   1663.83 │    0.92 │  59.52 │   56.15 │    1.06 │  91130.22 │  93417.29 │    0.98 │
│              Submodel │     1 │    mooncake │             typed │   true │     6.55 │      7.18 │    0.91 │   5.41 │    5.13 │    1.06 │     35.46 │     36.79 │    0.96 │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴──────────┴───────────┴─────────┴────────┴─────────┴─────────┴───────────┴───────────┴─────────┘

@penelopeysm penelopeysm force-pushed the py/ldfinit branch 2 times, most recently from 2cf2f3f to a35d0e2 Compare December 14, 2025 17:35

codecov bot commented Dec 14, 2025

Codecov Report

❌ Patch coverage is 71.42857% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 78.97%. Comparing base (6266f64) to head (7d4fe3c).

Files with missing lines Patch % Lines
src/logdensityfunction.jl 71.42% 12 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1178      +/-   ##
==========================================
+ Coverage   78.95%   78.97%   +0.02%     
==========================================
  Files          41       41              
  Lines        3896     3938      +42     
==========================================
+ Hits         3076     3110      +34     
- Misses        820      828       +8     


github-actions bot commented Dec 14, 2025

DynamicPPL.jl documentation for PR #1178 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1178/
