diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 7042514d0..3015da956 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -25,7 +25,7 @@ jobs: actions: write contents: read strategy: - fail-fast: false # TODO: toggle + fail-fast: true # TODO: toggle matrix: version: - '1.10' diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl index bacb7baa4..46ea2487f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl @@ -9,7 +9,8 @@ using ChainRulesCore: RuleConfig, frule_via_ad, rrule_via_ad, - unthunk + unthunk, + @not_implemented import DifferentiationInterface as DI ruleconfig(backend::AutoChainRules) = backend.ruleconfig diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl index 292372b81..9a14a7625 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl @@ -1,10 +1,22 @@ -function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x) - (; f, backend) = dw - y = f(x) - prep_same = DI.prepare_pullback_same_point_nokwarg(Val(false), f, backend, x, (y,)) - function pullbackfunc(dy) - tx = DI.pullback(f, prep_same, backend, x, (dy,)) - return (NoTangent(), only(tx)) +function ChainRulesCore.rrule( + dw::DI.DifferentiateWith{C}, x, contexts::Vararg{Any, C} + ) where {C} + (; f, backend, context_wrappers) = dw + y = f(x, contexts...) + wrapped_contexts = map(DI.call, context_wrappers, contexts) + prep_same = DI.prepare_pullback_same_point_nokwarg( + Val(false), f, backend, x, (y,), wrapped_contexts... + ) + function diffwith_pullbackfunc(dy) + dx = DI.pullback(f, prep_same, backend, x, (dy,), wrapped_contexts...) |> only + dc = map(contexts) do c + @not_implemented( + """ + Derivatives with respect to context arguments are not implemented. + """ + ) + end + return (NoTangent(), dx, dc...) end - return y, pullbackfunc + return y, diffwith_pullbackfunc end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl index 96316f5b6..4d46bc613 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl @@ -1,4 +1,4 @@ -function (dw::DI.DifferentiateWith)(x::Dual{T, V, N}) where {T, V, N} +function (dw::DI.DifferentiateWith{0})(x::Dual{T, V, N}) where {T, V, N} (; f, backend) = dw xval = myvalue(T, x) tx = mypartials(T, Val(N), x) @@ -6,7 +6,7 @@ function (dw::DI.DifferentiateWith)(x::Dual{T, V, N}) where {T, V, N} return make_dual(T, y, ty) end -function (dw::DI.DifferentiateWith)(x::AbstractArray{Dual{T, V, N}}) where {T, V, N} +function (dw::DI.DifferentiateWith{0})(x::AbstractArray{Dual{T, V, N}}) where {T, V, N} (; f, backend) = dw xval = myvalue(T, x) tx = mypartials(T, Val(N), x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 3513d548c..626e3660c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -18,6 +18,7 @@ using Mooncake: value_and_pullback!!, zero_dual, zero_tangent, + zero_rdata, rdata_type, fdata, rdata, @@ -26,11 +27,13 @@ using Mooncake: @is_primitive, zero_fcodual, MinimalCtx, + NoFData, NoRData, primal, _copy_output, _copy_to_output!!, - tangent_to_primal!! + tangent_to_primal!!, + increment!! const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index ad2d9f7c7..9775a6952 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -1,5 +1,12 @@ -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith, <:Any} +const NumberOrArray = Union{Number, AbstractArray{<:Number}} +# Mark DifferentiateWith with a range of context arities as primitives. +# For C contexts, the corresponding call tuple type is +# Tuple{DI.DifferentiateWith{C}, Any, Vararg{Any, C}}: +# one slot for the primal input x and C slots for contexts. +for C in 0:16 + @eval @is_primitive MinimalCtx Tuple{DI.DifferentiateWith{$C}, Vararg{Any, $(C + 1)}} +end struct MooncakeDifferentiateWithError <: Exception F::Type X::Type @@ -12,28 +19,37 @@ end function Base.showerror(io::IO, e::MooncakeDifferentiateWithError) return print( io, - "MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported.", + "MooncakeDifferentiateWithError: For the function type `$(e.F)` and input types `$(e.X)`, the output type `$(e.Y)` is currently not supported.", ) end -function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) +function Mooncake.rrule!!( + dw::CoDual{<:DI.DifferentiateWith{C}}, + x::CoDual{<:Number}, + contexts::Vararg{CoDual, C} + ) where {C} + @assert tangent_type(typeof(dw)) == NoTangent primal_func = primal(dw) primal_x = primal(x) - (; f, backend) = primal_func - y = zero_fcodual(f(primal_x)) + primal_contexts = map(primal, contexts) + (; f, backend, context_wrappers) = primal_func + y = zero_fcodual(f(primal_x, primal_contexts...)) + wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts) # output is a vector, so we need to use the vector pullback function pullback_array!!(dy::NoRData) - tx = DI.pullback(f, backend, primal_x, (y.dx,)) - @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x))) - return NoRData(), rdata(only(tx)) + dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only + @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x))) + rc = nanify_fdata_and_rdata!!(contexts...) + return (NoRData(), rdata(dx), rc...) end # output is a scalar, so we can use the scalar pullback function pullback_scalar!!(dy::Number) - tx = DI.pullback(f, backend, primal_x, (dy,)) - @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x))) - return NoRData(), rdata(only(tx)) + dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only + @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x))) + rc = nanify_fdata_and_rdata!!(contexts...) + return (NoRData(), rdata(dx), rc...) end pullback = if primal(y) isa Number @@ -41,35 +57,41 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number elseif primal(y) isa AbstractArray pullback_array!! else - throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y))) + throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y))) end return y, pullback end function Mooncake.rrule!!( - dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}} - ) + dw::CoDual{<:DI.DifferentiateWith{C}}, + x::CoDual{<:AbstractArray{<:Number}}, + contexts::Vararg{CoDual, C} + ) where {C} + @assert tangent_type(typeof(dw)) == NoTangent primal_func = primal(dw) primal_x = primal(x) - fdata_arg = x.dx - (; f, backend) = primal_func - y = zero_fcodual(f(primal_x)) + primal_contexts = map(primal, contexts) + (; f, backend, context_wrappers) = primal_func + y = zero_fcodual(f(primal_x, primal_contexts...)) + wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts) # output is a vector, so we need to use the vector pullback function pullback_array!!(dy::NoRData) - tx = DI.pullback(f, backend, primal_x, (y.dx,)) - @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x)))) - fdata_arg .+= only(tx) - return NoRData(), dy + dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only + @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x))) + x.dx .+= dx + rc = nanify_fdata_and_rdata!!(contexts...) + return (NoRData(), dy, rc...) end # output is a scalar, so we can use the scalar pullback function pullback_scalar!!(dy::Number) - tx = DI.pullback(f, backend, primal_x, (dy,)) - @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x)))) - fdata_arg .+= only(tx) - return NoRData(), NoRData() + dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only + @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x))) + x.dx .+= dx + rc = nanify_fdata_and_rdata!!(contexts...) + return (NoRData(), NoRData(), rc...) end pullback = if primal(y) isa Number @@ -77,7 +99,7 @@ function Mooncake.rrule!!( elseif primal(y) isa AbstractArray pullback_array!! else - throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y))) + throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y))) end return y, pullback diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index b22d8d49b..edda7bdb2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -17,3 +17,20 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake) return zero_tangent(x) end end + +nanify(x::AbstractFloat) = convert(typeof(x), NaN) +nanify(x::AbstractArray) = map(nanify, x) +nanify(x::NamedTuple) = NamedTuple{keys(x)}(map(nanify, values(x))) +nanify(x::Tuple) = map(nanify, x) +nanify(::NoFData) = NoFData() +nanify(::NoRData) = NoRData() + +function nanify_fdata_and_rdata!!(contexts::Vararg{CoDual, C}) where {C} + primal_contexts = map(primal, contexts) + fdata_contexts = map(tangent, contexts) + zero_rdata_contexts = map(zero_rdata, primal_contexts) + foreach(fdata_contexts) do fc + increment!!(fc, nanify(fc)) + end + return map(nanify, zero_rdata_contexts) +end diff --git a/DifferentiationInterface/src/misc/differentiate_with.jl b/DifferentiationInterface/src/misc/differentiate_with.jl index f0c2ecf38..399be23b8 100644 --- a/DifferentiationInterface/src/misc/differentiate_with.jl +++ b/DifferentiationInterface/src/misc/differentiate_with.jl @@ -14,7 +14,7 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be !!! warning - `DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments. + `DifferentiateWith` only supports out-of-place functions `y = f(x, contexts...)`, where the derivatives with respect to `contexts` can be safely ignored in the rest of your code. It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake](https://github.com/chalk-lab/Mooncake.jl), or if it automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules. For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper). @@ -25,8 +25,9 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be # Fields - - `f`: the function in question, with signature `f(x)` + - `f`: the function in question, with signature `f(x, contexts...)` - `backend::AbstractADType`: the substitute backend to use for differentiation + - `context_wrappers::NTuple`: a tuple like `(Constant, Cache)`, meaning that `f(x, a, b)` will be differentiated with `Constant(a)` and `Cache(b)` as contexts. !!! note @@ -34,7 +35,7 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be # Constructor - DifferentiateWith(f, backend) + DifferentiateWith(f, backend, context_wrappers) # Example @@ -69,15 +70,44 @@ julia> Zygote.gradient(alg, [3.0, 5.0])[1] 70.0 ``` """ -struct DifferentiateWith{F, B <: AbstractADType} +struct DifferentiateWith{C, F, B <: AbstractADType, N <: NTuple{C, Any}} f::F backend::B + context_wrappers::N + + function DifferentiateWith( + f::F, + backend::B, + context_wrappers::NTuple{C, Any}, + ) where {F, B <: AbstractADType, C} + for (i, wrapper) in pairs(context_wrappers) + # Accept typical constructor-like values: functions or types. + if !(wrapper isa Function || wrapper isa Type) + throw( + ArgumentError( + "Each context wrapper must be a callable object or type " * + "(e.g., a wrapper constructor like `Constant` or `Cache`), " * + "but element $i has type $(typeof(wrapper)).", + ), + ) + end + end + return new{C, F, B, typeof(context_wrappers)}( + f, + backend, + context_wrappers, + ) + end end -(dw::DifferentiateWith)(x) = dw.f(x) +DifferentiateWith(f::F, backend::AbstractADType) where {F} = DifferentiateWith(f, backend, ()) + +function (dw::DifferentiateWith{C})(x, args::Vararg{Any, C}) where {C} + return dw.f(x, args...) +end function Base.show(io::IO, dw::DifferentiateWith) - (; f, backend) = dw + (; f, backend, context_wrappers) = dw return print( io, DifferentiateWith, @@ -85,6 +115,8 @@ function Base.show(io::IO, dw::DifferentiateWith) repr(f; context = io), ", ", repr(backend; context = io), + ", ", + repr(context_wrappers; context = io), ")", ) end diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 2d2575d01..3058a4c63 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -179,3 +179,5 @@ Convenience for constructing a [`FixTail`](@ref), with a shortcut when there are """ @inline fix_tail(f::F) where {F} = f fix_tail(f::F, args::Vararg{Any, N}) where {F, N} = FixTail(f, args...) + +@inline call(f::F, x) where {F} = f(x) diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index 860d5e85f..b7676b4cd 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -24,9 +24,14 @@ function (adb::ADBreaker)(x::AbstractArray) return adb.f(x) end -function differentiatewith_scenarios() - outofplace_scens = filter(DIT.default_scenarios()) do scen - DIT.function_place(scen) == :out +# TODO: break Mooncake with overlay? + +function differentiatewith_scenarios(; kwargs...) + outofplace_scens = filter(DIT.default_scenarios(; kwargs...)) do scen + DIT.function_place(scen) == :out && + # save some time + !isa(scen.x, AbstractMatrix) && + !isa(scen.y, AbstractMatrix) end # with bad_scens, everything would break bad_scens = map(outofplace_scens) do scen @@ -44,7 +49,26 @@ test_differentiation( differentiatewith_scenarios(); excluded = SECOND_ORDER, logging = LOGGING, - testset_name = "DI tests", + testset_name = "DI tests - normal", +) + +test_differentiation( + [AutoZygote(), AutoMooncake(; config = nothing)], + map(DIT.constantify, differentiatewith_scenarios()); + excluded = SECOND_ORDER, + logging = LOGGING, + testset_name = "DI tests - Constant", +) + +test_differentiation( + [AutoMooncake(; config = nothing)], + map(differentiatewith_scenarios()) do s + s = DIT.cachify(s; use_tuples = true) + DIT.change_function(s, DifferentiateWith(s.f, AutoFiniteDiff(), (Cache,))) + end; + excluded = SECOND_ORDER, + logging = LOGGING, + testset_name = "DI tests - Cache", ) @testset "ChainRules tests" begin @@ -69,9 +93,9 @@ end; MooncakeDifferentiateWithError = Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError - e = MooncakeDifferentiateWithError(identity, 1.0, 2.0) + e = MooncakeDifferentiateWithError(identity, (1.0,), 2.0) @test sprint(showerror, e) == - "MooncakeDifferentiateWithError: For the function type typeof(identity) and input type Float64, the output type Float64 is currently not supported." + "MooncakeDifferentiateWithError: For the function type `typeof(identity)` and input types `Tuple{Float64}`, the output type `Float64` is currently not supported." f_num2tup(x::Number) = (x,) f_vec2tup(x::Vector) = (first(x),) diff --git a/DifferentiationInterface/test/Back/run_backend.jl b/DifferentiationInterface/test/Back/run_backend.jl index 637feb4c4..841019fdb 100644 --- a/DifferentiationInterface/test/Back/run_backend.jl +++ b/DifferentiationInterface/test/Back/run_backend.jl @@ -1,5 +1,5 @@ using Test group = ENV["JULIA_DI_TEST_GROUP"] -@testset "$group" begin +@testset verbose = true "$group" begin include(joinpath(@__DIR__, group, "test.jl")) end diff --git a/DifferentiationInterface/test/Core/Internals/display.jl b/DifferentiationInterface/test/Core/Internals/display.jl index 316fa6921..f1200fcfe 100644 --- a/DifferentiationInterface/test/Core/Internals/display.jl +++ b/DifferentiationInterface/test/Core/Internals/display.jl @@ -11,7 +11,7 @@ detector = DenseSparsityDetector(AutoForwardDiff(); atol = 1.0e-23) "DenseSparsityDetector(AutoForwardDiff(); atol=1.0e-23, method=:iterative)" diffwith = DifferentiateWith(exp, AutoForwardDiff()) -@test string(diffwith) == "DifferentiateWith(exp, AutoForwardDiff())" +@test string(diffwith) == "DifferentiateWith(exp, AutoForwardDiff(), ())" @test required_packages(AutoForwardDiff()) == ["ForwardDiff"] @test required_packages(AutoZygote()) == ["Zygote"]