-
-
Notifications
You must be signed in to change notification settings - Fork 167
Added benchmarking FAQ #689
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,15 +1,65 @@ | ||
| # FAQ | ||
|
|
||
| ### Compilation is taking a long time. | ||
|
|
||
| - Set `dt0=<not None>`, e.g. `diffeqsolve(..., dt0=0.01)`. In contrast `dt0=None` will determine the initial step size automatically, but will increase compilation time. | ||
| - Prefer `SaveAt(t0=True, t1=True)` over `SaveAt(ts=[t0, t1])`, if possible. | ||
| - It's an internal (subject-to-change) API, but you can also try adding `equinox.internal.noinline` to your vector field (s), e.g. `ODETerm(noinline(...))`. This stages the vector field out into a separate compilation graph. This can greatly decrease compilation time whilst greatly increasing runtime. | ||
|
|
||
| ### The solve is taking loads of steps / I'm getting NaN gradients / other weird behaviour. | ||
|
|
||
| Try switching to 64-bit precision. (Instead of the 32-bit that is the default in JAX.) [See here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). | ||
|
|
||
| ### Diffrax seem to be slower than <some other library\>? | ||
|
|
||
| Questions of this form are a fairly common source of issues in the Diffrax issue tracker! In practice, Diffrax is consistently amongst the fastest ODE solvers, and these usually stem from incorrect usage (e.g. recompiling your JAX program on each invocation) or comparisons (e.g. using different solvers/tolerances in each implementation). | ||
|
|
||
| Here's a list of some of the things to keep in mind when performing such comparisons: | ||
|
|
||
| 1. First of all, the usual list of JAX profiling concerns: | ||
|
|
||
| a. Make sure that your JAX program is compiled only once, and not repeatedly on each invocation (for example by passing in different raw Python floats each time). Use [`equinox.debug.assert_max_traces(max_traces=1)`](https://docs.kidger.site/equinox/api/debug/#equinox.debug.assert_max_traces) to debug this. | ||
|
|
||
| b. Your entire computation should be wrapped in a single `jax.jit`'d function (or equivalently `equinox.filter_jit`). | ||
|
|
||
| c. Run this function in advance (to JIT-compile it), before running it again to measure its speed. | ||
|
|
||
| d. Make sure not to include any code that is ran outside of the JIT'd function in your timings. | ||
|
|
||
| e. Make sure to call `jax.block_until_ready` on the output of the the function. | ||
|
|
||
| Typically your code should follow this template: | ||
| ```python | ||
| import equinox as eqx | ||
| import jax | ||
| import timeit | ||
|
|
||
| @jax.jit | ||
| @eqx.debug.assert_max_traces(max_traces=1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this have any overhead? I.e. should it be used to debug, then excluded from the final timing analysis (which is this code block)?
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, good catch. Only overhead at compile time, I think! So we'll be fine for the use-case here. |
||
| def run(x): | ||
| ... | ||
|
|
||
| x = ... | ||
| run(x) # compile | ||
| execution_time = min(timeit.repeat(lambda: jax.block_until_ready(run(x)), number=1, repeat=20)) | ||
| ``` | ||
|
|
||
| 2. Use the same ODE solver in both implementations to get an apples-to-apples comparison. It's not surprising that different solvers give different performance characteristics. (And if one implementation does not provide a solver that the other does, then no comparison can be made.) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "same ODE solver" -> "same solver"?
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So here (and for your comment below) I'm intentionally only focusing on ODE solvers. The reason for this is that I already know that we could do more to improve performance of SDE solvers, and I've just never really found the time to figure that one out. (And on this note you've been very gracious around not merging the stateful-controls PR on this topic!)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess timing has been mostly ODE focused, but for SDEs is there anything of note (e.g. in how brownian noise is computed between implementations)? |
||
|
|
||
| 3. Use the same step size control in both implementations. | ||
|
|
||
| a. If using adaptive step sizes then note that tolerances (the `rtol`, `atol` in `diffeqsolve(..., stepsize_controller=PIDController(rtol=..., atol=...))`) have solver- and implementation-specific meanings, so having these be the same is not enough. Aim to have roughly the same number of steps instead. You can check the number of steps taken in Diffrax via `diffeqsolve(...).stats['num_steps']`. | ||
|
|
||
| b. If using an automatic initial step size (`diffeqsolve(..., dt0=None)`) then use this (or disable this) in both implementations. | ||
|
|
||
| 4. If comparing to other JAX implementations, then make sure to set `import os; os.environ["EQX_ON_ERROR"] = "nan"` at the top of your script (before you import Diffrax or Equinox). This will disable various runtime correctness checks performed by Diffrax that are are typically not performed by other JAX frameworks. These add a few milliseconds of overhead that typically does not matter in real-word usage but may be large enough to appear in microbenchmarks. | ||
|
|
||
| a. If comparing to a loop-over-steps using `jax.lax.scan`, then the equivalent step size control in Diffrax is `diffeqsolve(..., stepsize_controller=StepTo(...))`. | ||
|
|
||
| 5. If you'd like to be really precise, then the best way to benchmark competing implementations is with a work-precision diagram: solve your ODE once with very tight tolerances and a very accurate solver (in any implementation). Then for each implementation: vary the tolerances or step sizes, and plot the time for the solve against and the numerical difference between the solution and the very accurate solution. This isn't required but is the gold-standard for benchmark comparisons. | ||
|
|
||
| 6. Both implementations should use the same precision (`float32` vs `float64`). Note that JAX defaults to 32-bit precision and requires a flag to enable 64-bit precision. | ||
|
|
||
| 7. The problem being solved should be large enough (ideally at least 100 milliseconds to solve) that you are not simply measuring various small overheads in different frameworks. | ||
|
|
||
| Take a look at [Diffrax issue #82](https://github.com/patrick-kidger/diffrax/issues/82) for a good example of how seemingly-reasonable benchmarks can hide a few pitfalls! | ||
|
|
||
| If you think you have a performance issue – after checking all of the above! – then feel free to open an issue on the Diffrax issue page. You should include a code snippet that demonstrates the issue; typically this should not be more than around 50 lines long if we are going to be able to volunteer to help you debug it :-). | ||
|
|
||
| ### How does this compare to `jax.experimental.ode.odeint`? | ||
|
|
||
| The equivalent solver in Diffrax is: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are the first two of these points worth keeping anywhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They feel niche enough (i.e. not meeting the 'F' in 'FAQ') that I was feeling inclined to cut them.
If we ever find ourselves with a longer list of tips-and-tricks to tackle compilation time then I'd be happy to give them a home there, however.