Add tagdepth fast path to tag comparison #807
Conversation
Fixes SciML/OrdinaryDiffEq.jl#3381 and supersedes SciML/OrdinaryDiffEq.jl#3587. Also fixes NonlinearSolve.jl master and supersedes SciML/NonlinearSolve.jl#932. Supersedes #724 and is a better solution to #714.

The crux of the issue is that ForwardDiff.jl's tagging system is designed around each tag being used once: the function is created, the derivative function is called, and the tag for that derivative is set from the type of the function being differentiated, so it is unique. This normally works with nested differentiation because you usually call the inner function before the outer function, or only call the combination, so the tag ordering comes out correct.

Mixing tagging with precompilation then leads to this issue, where it's possible for the outer tag to be precompiled before the inner tag. That reverses the tag ordering, and the type promotion mechanism gets confused because it is tied to that ordering. This seems pretty fundamental, since the ordering is the core property used to prevent perturbation confusion, but it means this interaction between nested differentiation and precompilation produces odd bugs.

I tried working around this downstream (SciML/OrdinaryDiffEq.jl#3587), but it was very nasty. Basically, you had to make sure dual numbers never automatically converted Float64s, because the conversion could sometimes target the inner type instead of the outer type, and the normal two-step conversion (first to the inner type, then to the wrapped outer type) required the outer tag to postdate the inner tag.

But this really showcases that the bug only manifests with nested types. And if you have nested types, you know you don't have perturbation confusion when one tag is nested deeper than the other tag, because the two duals do not carry the same number of partials.

So in the case where the tag depths are not the same, you can use an alternative tag ordering, since you will have already proven the perturbations aren't confused. In that case, you can choose the deeper-nested tags to just always be `<` the less deep tags. So I added that and, poof, tag nesting worked out in these cases with precompilation. I think this captures the true crux of the problem and solves it at its core.
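The fast path described above can be sketched with toy stand-ins for ForwardDiff's `Tag` and `Dual` types (these structs and the helper `tag_lt` are simplified models for illustration, not ForwardDiff's actual definitions or its `≺` operator):

```julia
# Toy stand-ins for ForwardDiff's Tag and Dual types (illustration only).
struct Tag{F,V} end
struct Dual{T,V,N} end

# Each Tag or Dual layer in the value type adds one level of depth.
tagdepth(::Type) = 0
tagdepth(::Type{<:Dual{T,V,N}}) where {T,V,N} = 1 + tagdepth(V)
tagdepth(::Type{<:Tag{F,V}}) where {F,V} = 1 + tagdepth(V)

# Depth-first comparison: when depths differ, nesting depth alone decides
# the ordering, so the creation-order counter (which precompilation can
# scramble) is never consulted.
function tag_lt(::Type{Tag{F1,V1}}, ::Type{Tag{F2,V2}}) where {F1,V1,F2,V2}
    d1 = tagdepth(Tag{F1,V1})
    d2 = tagdepth(Tag{F2,V2})
    d1 != d2 && return d1 < d2
    return nothing  # equal depths: would fall back to the tagcount comparison
end

f(x) = x^2   # inner function being differentiated
g(x) = x + 1 # outer function

const InnerTag = Tag{typeof(f),Float64}
const OuterTag = Tag{typeof(g),Dual{InnerTag,Float64,1}}

# The inner tag compares below the outer tag that wraps it, no matter
# which of the two was created (or precompiled) first.
tag_lt(InnerTag, OuterTag)
```

The key property is that the result depends only on the type structure, not on the order in which the tags were instantiated.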
Codecov Report ✅ All modified and coverable lines are covered by tests.

@@            Coverage Diff             @@
##           master     #807      +/-   ##
==========================================
+ Coverage   90.75%   90.80%   +0.05%
==========================================
  Files          11       11
  Lines        1071     1077       +6
==========================================
+ Hits          972      978       +6
  Misses         99       99
@inline tagdepth(::Type) = 0
@inline tagdepth(::Type{<:Dual{T,V,N}}) where {T,V,N} = 1 + tagdepth(V)
@inline tagdepth(::Type{<:Tag{F,V}}) where {F,V} = 1 + tagdepth(V)
Why not just define this function on `::Type` and `::Type{<:Dual}`? Why `::Type{<:Tag}` as well?
Because when you nest, it's a tag of duals.
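To make that reply concrete, here is a minimal sketch with toy stand-ins for `Tag` and `Dual` (illustration only, not ForwardDiff's real types): the outer tag in a nested derivative is a tag whose value type is itself a `Dual`, so `tagdepth` is entered through a `Tag` type; without the `::Type{<:Tag}` method, every tag would hit the generic `::Type` fallback and report depth 0.

```julia
# Toy stand-ins (illustration only).
struct Tag{F,V} end
struct Dual{T,V,N} end

tagdepth(::Type) = 0
tagdepth(::Type{<:Dual{T,V,N}}) where {T,V,N} = 1 + tagdepth(V)
tagdepth(::Type{<:Tag{F,V}}) where {F,V} = 1 + tagdepth(V)

inner(x) = x^2
outer(x) = x + 1

# Plain (non-nested) tag: the value type is Float64.
const InnerTag = Tag{typeof(inner),Float64}

# Nested tag: the value type is itself a Dual carrying the inner tag,
# i.e. a "tag of duals". The Tag method is the recursion's entry point,
# and the Dual method peels the nested layers underneath it.
const OuterTag = Tag{typeof(outer),Dual{InnerTag,Float64,1}}

tagdepth(InnerTag), tagdepth(OuterTag)
```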
d1 = tagdepth(Tag{F1,V1})
d2 = tagdepth(Tag{F2,V2})
d1 != d2 && return d1 < d2
You can define Tags basically arbitrarily, so why would this be safe?
If the tags are defined according to the interfaces in the package, T is the type being differentiated. If T is the type being differentiated, then this guarantees tag ordering by differentiation hierarchy. We previously had PRs closed about documenting this, saying that it is non-public API, so since this package always obeys that invariant internally and it's purposefully non-public, it would be non-breaking to enforce it.
d1 = tagdepth(Tag{F1,V1})
d2 = tagdepth(Tag{F2,V2})
d1 != d2 && return d1 < d2
return tagcount(Tag{F1,V1}) < tagcount(Tag{F2,V2})
It seems one could still run into the same precompilation-caused problems here, e.g., if V1 === V2 (e.g. both Float64)?
Can you build an example? In all of the examples from before, that cannot happen, because it's not Float64 in both cases but a Dual of Float64, and that tag nesting is exactly the part you're missing from the earlier discussion.
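Both sides of this exchange can be illustrated with the same toy `Tag`/`Dual` stand-ins (illustration only, not ForwardDiff's real types): two sibling tags over plain `Float64` do have equal depth, so for them the fast path never fires and the `tagcount` fallback still decides, while the nested case from the precompilation bug reports always produces unequal depths.

```julia
# Toy stand-ins (illustration only).
struct Tag{F,V} end
struct Dual{T,V,N} end

tagdepth(::Type) = 0
tagdepth(::Type{<:Dual{T,V,N}}) where {T,V,N} = 1 + tagdepth(V)
tagdepth(::Type{<:Tag{F,V}}) where {F,V} = 1 + tagdepth(V)

f(x) = x^2
g(x) = x + 1

const TagA = Tag{typeof(f),Float64}              # sibling tags over Float64:
const TagB = Tag{typeof(g),Float64}              # equal depth, fast path skipped
const TagC = Tag{typeof(g),Dual{TagA,Float64,1}} # nested tag: strictly deeper
```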
@devmotion From your comments I think you missed the core part of this. The tag by design encodes f and the eltype (V). That eltype, when nested, is itself a tagged Dual. In a nested differentiation context that establishes its own natural ordering for the duals, since in the nested context promoting duals to the outer dual, by appending zero partials to the inner dual, is the natural action. As a safety check, what I can do is additionally add a proof that V1 nests V2 exactly, i.e. the stripped V1 matches V2. In that case the definition is very clear, and this is the actual case that is triggered by the precompilation issue.
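A rough sketch of what that proposed safety check could look like (hypothetical helper name `nests` and toy `Tag`/`Dual` types, not ForwardDiff API): `nests(V1, V2)` holds when peeling the `Dual` layer off `V1` recovers `V2` exactly.

```julia
# Toy stand-ins (illustration only).
struct Tag{F,V} end
struct Dual{T,V,N} end

# Hypothetical check for the proposed invariant: V1 nests V2 exactly
# when V1 is a Dual whose value type is V2.
nests(::Type{Dual{T,V,N}}, ::Type{W}) where {T,V,N,W} = V === W
nests(::Type, ::Type) = false

inner(x) = x^2
const InnerTag = Tag{typeof(inner),Float64}

const V2 = Float64                   # inner tag's value type
const V1 = Dual{InnerTag,Float64,1}  # outer tag's value type wraps it

nests(V1, V2)
```

Guarding the depth-based ordering with such a check would restrict the fast path to exactly the nested-promotion case that precompilation triggers.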