Skip to content
Draft
2 changes: 1 addition & 1 deletion .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
actions: write
contents: read
strategy:
fail-fast: false # TODO: toggle
fail-fast: true # TODO: toggle
matrix:
version:
- '1.10'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ using ChainRulesCore:
RuleConfig,
frule_via_ad,
rrule_via_ad,
unthunk
unthunk,
@not_implemented
import DifferentiationInterface as DI

ruleconfig(backend::AutoChainRules) = backend.ruleconfig
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
(; f, backend) = dw
y = f(x)
prep_same = DI.prepare_pullback_same_point_nokwarg(Val(false), f, backend, x, (y,))
function pullbackfunc(dy)
tx = DI.pullback(f, prep_same, backend, x, (dy,))
return (NoTangent(), only(tx))
function ChainRulesCore.rrule(
dw::DI.DifferentiateWith{C}, x, contexts::Vararg{Any, C}
) where {C}
(; f, backend, context_wrappers) = dw
y = f(x, contexts...)
wrapped_contexts = map(DI.call, context_wrappers, contexts)
prep_same = DI.prepare_pullback_same_point_nokwarg(
Val(false), f, backend, x, (y,), wrapped_contexts...
)
function diffwith_pullbackfunc(dy)
dx = DI.pullback(f, prep_same, backend, x, (dy,), wrapped_contexts...) |> only
dc = map(contexts) do c
@not_implemented(
"""
Derivatives with respect to context arguments are not implemented.
"""
)
end
return (NoTangent(), dx, dc...)
end
return y, pullbackfunc
return y, diffwith_pullbackfunc
end
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
function (dw::DI.DifferentiateWith)(x::Dual{T, V, N}) where {T, V, N}
function (dw::DI.DifferentiateWith{0})(x::Dual{T, V, N}) where {T, V, N}
(; f, backend) = dw
xval = myvalue(T, x)
tx = mypartials(T, Val(N), x)
y, ty = DI.value_and_pushforward(f, backend, xval, tx)
return make_dual(T, y, ty)
end

function (dw::DI.DifferentiateWith)(x::AbstractArray{Dual{T, V, N}}) where {T, V, N}
function (dw::DI.DifferentiateWith{0})(x::AbstractArray{Dual{T, V, N}}) where {T, V, N}
(; f, backend) = dw
xval = myvalue(T, x)
tx = mypartials(T, Val(N), x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ using Mooncake:
value_and_pullback!!,
zero_dual,
zero_tangent,
zero_rdata,
rdata_type,
fdata,
rdata,
Expand All @@ -26,11 +27,13 @@ using Mooncake:
@is_primitive,
zero_fcodual,
MinimalCtx,
NoFData,
NoRData,
primal,
_copy_output,
_copy_to_output!!,
tangent_to_primal!!
tangent_to_primal!!,
increment!!

const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith, <:Any}
const NumberOrArray = Union{Number, AbstractArray{<:Number}}

# Mark DifferentiateWith with a range of context arities as primitives.
# Each iteration registers one signature with Mooncake's `@is_primitive`
# under `MinimalCtx`, for C = 0, 1, ..., 16 context arguments.
# For C contexts, the corresponding call tuple type is
# Tuple{DI.DifferentiateWith{C}, Any, Vararg{Any, C}}:
# one slot for the primal input x and C slots for contexts.
# The loop spells this as `Vararg{Any, C + 1}`, folding x and the C contexts
# into a single Vararg. `@eval` splices the literal value of C into the
# signature before `@is_primitive` is expanded.
# NOTE(review): 16 is a hard-coded upper bound; a DifferentiateWith with more
# than 16 contexts is not registered by this loop. Confirm this is intended.
for C in 0:16
@eval @is_primitive MinimalCtx Tuple{DI.DifferentiateWith{$C}, Vararg{Any, $(C + 1)}}
end
struct MooncakeDifferentiateWithError <: Exception
F::Type
X::Type
Expand All @@ -12,72 +19,87 @@ end
function Base.showerror(io::IO, e::MooncakeDifferentiateWithError)
return print(
io,
"MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported.",
"MooncakeDifferentiateWithError: For the function type `$(e.F)` and input types `$(e.X)`, the output type `$(e.Y)` is currently not supported.",
)
end

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
function Mooncake.rrule!!(
dw::CoDual{<:DI.DifferentiateWith{C}},
x::CoDual{<:Number},
contexts::Vararg{CoDual, C}
) where {C}
@assert tangent_type(typeof(dw)) == NoTangent
primal_func = primal(dw)
primal_x = primal(x)
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))
primal_contexts = map(primal, contexts)
(; f, backend, context_wrappers) = primal_func
y = zero_fcodual(f(primal_x, primal_contexts...))
wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts)

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (y.dx,))
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
return NoRData(), rdata(only(tx))
dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), rdata(dx), rc...)
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
return NoRData(), rdata(only(tx))
dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), rdata(dx), rc...)
end

pullback = if primal(y) isa Number
pullback_scalar!!
elseif primal(y) isa AbstractArray
pullback_array!!
else
throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y)))
end

return y, pullback
end

function Mooncake.rrule!!(
dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}}
)
dw::CoDual{<:DI.DifferentiateWith{C}},
x::CoDual{<:AbstractArray{<:Number}},
contexts::Vararg{CoDual, C}
) where {C}
@assert tangent_type(typeof(dw)) == NoTangent
primal_func = primal(dw)
primal_x = primal(x)
fdata_arg = x.dx
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))
primal_contexts = map(primal, contexts)
(; f, backend, context_wrappers) = primal_func
y = zero_fcodual(f(primal_x, primal_contexts...))
wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts)

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (y.dx,))
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
fdata_arg .+= only(tx)
return NoRData(), dy
dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
x.dx .+= dx
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), dy, rc...)
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
fdata_arg .+= only(tx)
return NoRData(), NoRData()
dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
x.dx .+= dx
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), NoRData(), rc...)
end

pullback = if primal(y) isa Number
pullback_scalar!!
elseif primal(y) isa AbstractArray
pullback_array!!
else
throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y)))
end

return y, pullback
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,20 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
return zero_tangent(x)
end
end

# `nanify(x)` returns a value shaped like `x` in which every floating-point
# leaf is replaced by NaN of the same float type.

# Leaf case: a NaN of the same concrete float type as `x`.
nanify(x::AbstractFloat) = oftype(x, NaN)

# Container case: recurse into every element. `map` keeps the container kind,
# so arrays stay arrays, tuples stay tuples, and named tuples keep their keys.
function nanify(x::Union{AbstractArray, Tuple, NamedTuple})
    return map(nanify, x)
end
# `NoFData` / `NoRData` inputs are mapped to a freshly constructed marker of
# the same kind: there is no floating-point payload to overwrite with NaN.
# NOTE(review): presumably these are Mooncake's "no data" placeholders for the
# fdata / rdata halves of a tangent — confirm against the Mooncake docs.
nanify(::NoFData) = NoFData()
nanify(::NoRData) = NoRData()

# Poison the tangents of all `contexts` with NaN.
#
# For each context CoDual, `increment!!` is called on its tangent (fdata) with
# the NaN-filled counterpart produced by `nanify`. The function returns a tuple
# holding, for each context, the NaN-filled version of `zero_rdata` of its
# primal, in the same order as `contexts`.
#
# NOTE(review): the return value of `increment!!` is discarded, so the fdata
# side only has an effect if `increment!!` mutates its first argument — confirm
# this holds for every fdata type that can reach this function.
function nanify_fdata_and_rdata!!(contexts::Vararg{CoDual, C}) where {C}
    # Build the rdata templates from the primals before touching any fdata.
    rdata_templates = map(ctx -> zero_rdata(primal(ctx)), contexts)
    foreach(contexts) do ctx
        fd = tangent(ctx)
        increment!!(fd, nanify(fd))
    end
    return map(nanify, rdata_templates)
end
44 changes: 38 additions & 6 deletions DifferentiationInterface/src/misc/differentiate_with.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be

!!! warning

`DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments.
`DifferentiateWith` only supports out-of-place functions `y = f(x, contexts...)`, where the derivatives with respect to `contexts` can be safely ignored in the rest of your code.
It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake](https://github.com/chalk-lab/Mooncake.jl), or if it automatically imports rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper).

Expand All @@ -25,16 +25,17 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be

# Fields

- `f`: the function in question, with signature `f(x)`
- `f`: the function in question, with signature `f(x, contexts...)`
- `backend::AbstractADType`: the substitute backend to use for differentiation
- `context_wrappers::NTuple`: a tuple like `(Constant, Cache)`, meaning that `f(x, a, b)` will be differentiated with `Constant(a)` and `Cache(b)` as contexts.

!!! note

For the substitute AD backend to be called under the hood, its package needs to be loaded in addition to the package of the true AD backend.

# Constructor

DifferentiateWith(f, backend)
DifferentiateWith(f, backend, context_wrappers)

# Example

Expand Down Expand Up @@ -69,22 +70,53 @@ julia> Zygote.gradient(alg, [3.0, 5.0])[1]
70.0
```
"""
struct DifferentiateWith{F, B <: AbstractADType}
struct DifferentiateWith{C, F, B <: AbstractADType, N <: NTuple{C, Any}}
f::F
backend::B
context_wrappers::N

# Inner constructor: validate `context_wrappers`, then build the wrapper.
# The context count `C` is taken from the length of `context_wrappers`.
# Throws an `ArgumentError` naming the first entry that is neither a
# `Function` nor a `Type`.
function DifferentiateWith(
        f::F, backend::B, context_wrappers::NTuple{C, Any}
    ) where {F, B <: AbstractADType, C}
    # Accept typical constructor-like values: functions or types.
    # `offender` is the index of the first entry that is neither, or `nothing`.
    offender = findfirst(w -> !(w isa Function || w isa Type), context_wrappers)
    if offender !== nothing
        bad_type = typeof(context_wrappers[offender])
        throw(
            ArgumentError(
                "Each context wrapper must be a callable object or type " *
                    "(e.g., a wrapper constructor like `Constant` or `Cache`), " *
                    "but element $offender has type $bad_type.",
            ),
        )
    end
    return new{C, F, B, typeof(context_wrappers)}(f, backend, context_wrappers)
end
end

(dw::DifferentiateWith)(x) = dw.f(x)
# Two-argument convenience constructor for a function without context
# arguments: forwards to the three-argument form with an empty wrapper tuple.
function DifferentiateWith(f::F, backend::AbstractADType) where {F}
    return DifferentiateWith(f, backend, ())
end

# Calling the wrapper forwards `x` and exactly `C` extra arguments to the
# wrapped function; the `Vararg{Any, C}` bound ties the number of extra
# arguments to the type parameter `C`.
(dw::DifferentiateWith{C})(x, args::Vararg{Any, C}) where {C} = dw.f(x, args...)

function Base.show(io::IO, dw::DifferentiateWith)
(; f, backend) = dw
(; f, backend, context_wrappers) = dw
return print(
io,
DifferentiateWith,
"(",
repr(f; context = io),
", ",
repr(backend; context = io),
", ",
repr(context_wrappers; context = io),
")",
)
end
2 changes: 2 additions & 0 deletions DifferentiationInterface/src/utils/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,5 @@ Convenience for constructing a [`FixTail`](@ref), with a shortcut when there are
"""
@inline fix_tail(f::F) where {F} = f
fix_tail(f::F, args::Vararg{Any, N}) where {F, N} = FixTail(f, args...)

# Apply the callable `f` to the single argument `x`.
# Having this as a named two-argument function lets it be mapped over a tuple
# of callables paired with a tuple of arguments, e.g. `map(call, fs, xs)`.
# The `where {F}` annotation makes Julia specialize on the type of `f`.
@inline function call(f::F, x) where {F}
    return f(x)
end
36 changes: 30 additions & 6 deletions DifferentiationInterface/test/Back/DifferentiateWith/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ function (adb::ADBreaker)(x::AbstractArray)
return adb.f(x)
end

function differentiatewith_scenarios()
outofplace_scens = filter(DIT.default_scenarios()) do scen
DIT.function_place(scen) == :out
# TODO: break Mooncake with overlay?

function differentiatewith_scenarios(; kwargs...)
outofplace_scens = filter(DIT.default_scenarios(; kwargs...)) do scen
DIT.function_place(scen) == :out &&
# save some time
!isa(scen.x, AbstractMatrix) &&
!isa(scen.y, AbstractMatrix)
end
# with bad_scens, everything would break
bad_scens = map(outofplace_scens) do scen
Expand All @@ -44,7 +49,26 @@ test_differentiation(
differentiatewith_scenarios();
excluded = SECOND_ORDER,
logging = LOGGING,
testset_name = "DI tests",
testset_name = "DI tests - normal",
)

# Same scenarios as above, each passed through `DIT.constantify`, tested with
# Zygote and Mooncake as the outer backends. Second-order operators are
# excluded.
# NOTE(review): presumably `constantify` adds a `Constant` context argument to
# each scenario — confirm against DifferentiationInterfaceTest.
test_differentiation(
[AutoZygote(), AutoMooncake(; config = nothing)],
map(DIT.constantify, differentiatewith_scenarios());
excluded = SECOND_ORDER,
logging = LOGGING,
testset_name = "DI tests - Constant",
)

# Cache-context variant, tested with Mooncake only as the outer backend.
# Each scenario is passed through `DIT.cachify` with tuple caches, and its
# function is then replaced by a `DifferentiateWith` wrapper that uses
# AutoFiniteDiff as the substitute backend and `Cache` as the single wrapper.
# NOTE(review): Zygote is left out here, unlike the Constant test above; the
# reason is not visible in this file — confirm and document it.
test_differentiation(
[AutoMooncake(; config = nothing)],
map(differentiatewith_scenarios()) do s
s = DIT.cachify(s; use_tuples = true)
DIT.change_function(s, DifferentiateWith(s.f, AutoFiniteDiff(), (Cache,)))
end;
excluded = SECOND_ORDER,
logging = LOGGING,
testset_name = "DI tests - Cache",
)

@testset "ChainRules tests" begin
Expand All @@ -69,9 +93,9 @@ end;
MooncakeDifferentiateWithError =
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError

e = MooncakeDifferentiateWithError(identity, 1.0, 2.0)
e = MooncakeDifferentiateWithError(identity, (1.0,), 2.0)
@test sprint(showerror, e) ==
"MooncakeDifferentiateWithError: For the function type typeof(identity) and input type Float64, the output type Float64 is currently not supported."
"MooncakeDifferentiateWithError: For the function type `typeof(identity)` and input types `Tuple{Float64}`, the output type `Float64` is currently not supported."

f_num2tup(x::Number) = (x,)
f_vec2tup(x::Vector) = (first(x),)
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/Back/run_backend.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Test
group = ENV["JULIA_DI_TEST_GROUP"]
@testset "$group" begin
@testset verbose = true "$group" begin
include(joinpath(@__DIR__, group, "test.jl"))
end
Loading
Loading