From 8cc8c7247555663293a2b3ae9729f30b0c287c6f Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 13 May 2026 20:29:36 +0000 Subject: [PATCH] Add tagdepth fast path to tag comparison Fixes https://github.com/SciML/OrdinaryDiffEq.jl/issues/3381 and superscedes https://github.com/SciML/OrdinaryDiffEq.jl/pull/3587 . Also fixes NonlinearSolve.jl master and superscedes https://github.com/SciML/NonlinearSolve.jl/pull/932 Superscedes https://github.com/JuliaDiff/ForwardDiff.jl/pull/724 and is a better solution to https://github.com/JuliaDiff/ForwardDiff.jl/issues/714. The crux of the issue is that ForwardDiff.jl's tagging system is somewhat designed around the tag only being used once, i.e. the function is created, the derivative function is called, the tag is set for that derivative as a type of the function being differentiated, and therefore it's unique. Then this ends up working with nested differentiation because you call the inner function first, usually, before the outer function, or only do the combination, and so the tag ordering is set correctly. Mixing tagging with precompilation then leads to this issue where it's possible for the outer tag to be precompiled before the inner tag. This makes the tag ordering the opposite, and what happens is then that the type promotion mechanism gets confused because it is tied to the tag ordering. This seems pretty fundamental because it's a useful property, it's the core property used to prevent perturbation confusion, but it means that this interaction between nested differentiation and precompilation ends up having odd bugs. I tried working around this downstream (https://github.com/SciML/OrdinaryDiffEq.jl/pull/3587) but it was very nasty. Basically, you had to make sure you didn't have dual numbers automatically converting Float64s, as then sometimes it could convert to the inner type instead of the outer type, and it wouldn't do the normal conversion of first to the inner to then the wrapped outer type because doing so required the outer type to postdate the inner type. But, this really then showcases that the bug truly only manifests with nested types. And if you have nested types, you know you don't have perturbation confusion if one tag is nested deeper than the other tag, because there are not the same number of partials. So in the case where the tag depths are not the same, you can do an alternative tag ordering since you will have already proven perturbations aren't confusing. And in that case, you can choose the deeper nested tags to just always be `<` the less deeper tags. So added that and poof, tag nesting worked out in these cases with precompilation. So I think this captures the true crux of the problem and solves it at its core. --- src/config.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/config.jl b/src/config.jl index 3c6c97e3..154806d4 100644 --- a/src/config.jl +++ b/src/config.jl @@ -20,10 +20,16 @@ end Tag(::Nothing, ::Type{V}) where {V} = nothing +@inline tagdepth(::Type) = 0 +@inline tagdepth(::Type{<:Dual{T,V,N}}) where {T,V,N} = 1 + tagdepth(V) +@inline tagdepth(::Type{<:Tag{F,V}}) where {F,V} = 1 + tagdepth(V) @inline function ≺(::Type{Tag{F1,V1}}, ::Type{Tag{F2,V2}}) where {F1,V1,F2,V2} - tagcount(Tag{F1,V1}) < tagcount(Tag{F2,V2}) -end + d1 = tagdepth(Tag{F1,V1}) + d2 = tagdepth(Tag{F2,V2}) + d1 != d2 && return d1 < d2 + return tagcount(Tag{F1,V1}) < tagcount(Tag{F2,V2}) + end struct InvalidTagException{E,O} <: Exception end