Skip to content
100 changes: 55 additions & 45 deletions ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,37 +43,6 @@ using LinearAlgebra
# other provided input variable.
#--------------------------------------------

# this rule is necessary for now as without it,
# a segfault occurs both on 1.10 and 1.12 -- likely
# a deeper internal bug
# Augmented forward pass for `copy_input`: evaluates the primal call and, when a
# shadow is requested, builds `zero(A.dval)`, which is handed to `AugmentedReturn`
# as both its second and third argument.
function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfigWidth{1},
    func::Const{typeof(copy_input)},
    ::Type{RT},
    f::Annotation,
    A::Annotation
) where {RT}
    copied = func.val(f.val, A.val)
    # only expose the primal result if the caller asked for it
    out = EnzymeRules.needs_primal(config) ? copied : nothing
    dcopied = if EnzymeRules.needs_shadow(config)
        zero(A.dval)
    else
        nothing
    end
    return EnzymeRules.AugmentedReturn(out, dcopied, dcopied)
end

# Reverse pass for `copy_input`: adds the cached value (treated as the shadow of
# the copy) onto `A.dval`, unless `A` is `Const` or nothing was cached.
function EnzymeRules.reverse(
    config::EnzymeRules.RevConfigWidth{1},
    func::Const{typeof(copy_input)},
    ::Type{RT},
    cache,
    f::Annotation,
    A::Annotation
) where {RT}
    # nothing to propagate for a constant input or an absent shadow
    skip = isa(A, Const) || isnothing(cache)
    if !skip
        A.dval .+= cache
    end
    return (nothing, nothing)
end

# two-argument factorizations like LQ, QR, EIG
for (f, pb) in (
(qr_full!, qr_pullback!),
Expand Down Expand Up @@ -101,9 +70,14 @@ for (f, pb) in (
# if arg.val == ret, the annotation must be Duplicated or DuplicatedNoNeed
# if arg isa Const, ret may still be modified further down the call graph so we should
# copy it to protect ourselves
cache_arg = (arg.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing
dret = if EnzymeRules.needs_shadow(config)
(TA == Nothing && TB == Nothing) || isa(arg, Const) ? zero.(ret) : arg.dval
A_is_arg1 = !isa(A, Const) && A.val === arg.val[1]
A_is_arg2 = !isa(A, Const) && A.val === arg.val[2]
A_is_arg = A_is_arg1 || A_is_arg2
cache_arg = (arg.val !== ret && !A_is_arg) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing
dret = if EnzymeRules.needs_shadow(config) && ((TA == Nothing && TB == Nothing) || isa(arg, Const))
make_zero.(ret)
elseif EnzymeRules.needs_shadow(config)
arg.dval
else
nothing
end
Expand All @@ -125,11 +99,23 @@ for (f, pb) in (
# use A (so that whoever does this is forced to handle caching A
# appropriately here)
Aval = nothing
A_is_arg1 = !isa(A, Const) && A.dval === arg.dval[1]
A_is_arg2 = !isa(A, Const) && A.dval === arg.dval[2]
A_is_arg = A_is_arg1 || A_is_arg2
argval = something(cache_arg, arg.val)
if !isa(A, Const)
$pb(A.dval, Aval, argval, darg)
if A_is_arg
ΔA = make_zero(A.dval)
$pb(ΔA, Aval, argval, darg)
A.dval .= ΔA
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
A.dval .= ΔA
A.dval .+= ΔA

I'm guessing this has to be accumulating, but I'm not entirely sure I follow the logic of this branch: why can't we accumulate in place immediately?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't do it in place because that would modify dA, which I think is the same as darg. I think we should not do a `+=` because A is overwritten in the primal step (but this is confusing, and I'm open to being convinced otherwise).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'm following now.
Would it help if instead of allocating a new dA and then doing the copy afterwards, we instead copy the darg if we need to?
Also, does that mean that we could in principle be checking that dA is zero, since nothing is allowed to depend on the value of A after the decomposition?
Or would that be too strict of a check in the case where A is secretly not overwritten by the specific implementation of the decomposition.

else
$pb(A.dval, Aval, argval, darg)
end
end
if !isa(arg, Const)
A.dval === arg.dval[1] || make_zero!(arg.dval[1])
A.dval === arg.dval[2] || make_zero!(arg.dval[2])
end
!isa(arg, Const) && make_zero!(arg.dval)
return (nothing, nothing, nothing)
end
end
Expand Down Expand Up @@ -356,7 +342,13 @@ for (f, trunc_f, full_f, pb) in (
if !isa(A, Const)
$pb(A.dval, Aval, DVval, dDVtrunc, ind)
end
!isa(DV, Const) && make_zero!(DV.dval)
if !isa(DV, Const)
if !(A.dval === DV.dval[1])
make_zero!(DV.dval)
else
make_zero!(DV.dval[2])
end
end
return (nothing, nothing, nothing)
end
end
Expand Down Expand Up @@ -392,21 +384,30 @@ for (f!, f_full!, pb!) in (
func::Const{typeof($f!)},
::Type{RT},
cache,
A::Annotation,
A::Annotation{TA},
D::Annotation,
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
) where {RT}
) where {RT, TA}
cache_D, dD, V = cache
Dval = something(cache_D, D.val)
# A is NOT used in the pullback, so we assign Aval = nothing
# to trigger an error in case the pullback is modified to directly
# use A (so that whoever does this is forced to handle caching A
# appropriately here)
Aval = nothing
A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === D.dval
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code-organization-wise, I am not a huge fan of this hardcoded special case for different types.
Why do you need the specific specialization here? Could it not just be A.dval === D.dval?

That being said, we should probably just add an implementation of isalias(A, arg).
Maybe something like Base.mightalias could be a more generic solution, or even a fallback definition, but that function is technically internal, and similar problems hold for checking Base.dataids.
I might actually be okay with depending on that, though: it has been quite stable and does seem to me like a good way of going about this (and we already secretly use it in TensorOperations!).

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the _vals methods we can't do A.dval === D.dval because D is a Vector, and A is a Diagonal, whose diag field points to D

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, you are right, my bad! In that case I am even more convinced that isalias(A, arg) = Base.mightalias(A, arg) is the right abstraction (I would keep the hook, since Base.mightalias is internal and this allows us to overload without piracy).

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that sounds good to me. We could do it as part of this PR or separately?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a lot of goofy special cases running around here haha

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be reasonable to include it in this PR, as this does seem to be the primary change that is needed to solve the issue? Happy to defer it if you prefer though.

if !isa(A, Const)
$pb!(A.dval, Aval, (Diagonal(Dval), V), dD)
if A_is_arg
ΔA = make_zero(A.dval)
$pb!(ΔA, Aval, (Diagonal(Dval), V), dD)
A.dval .= ΔA
else
$pb!(A.dval, Aval, (Diagonal(Dval), V), dD)
end
end
if !isa(D, Const) && !A_is_arg
make_zero!(D.dval)
end
!isa(D, Const) && make_zero!(D.dval)
return (nothing, nothing, nothing)
end
end
Expand Down Expand Up @@ -438,21 +439,30 @@ function EnzymeRules.reverse(
func::Const{typeof(svd_vals!)},
::Type{RT},
cache,
A::Annotation,
A::Annotation{TA},
S::Annotation,
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
) where {RT}
) where {RT, TA}
cache_S, dS, U, Vᴴ = cache
# A is NOT used in the pullback, so we assign Aval = nothing
# to trigger an error in case the pullback is modified to directly
# use A (so that whoever does this is forced to handle caching A
# appropriately here)
Aval = nothing
Sval = something(cache_S, S.val)
A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === S.dval
if !isa(A, Const)
svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS)
if A_is_arg
ΔA = make_zero(A.dval)
svd_vals_pullback!(ΔA, Aval, (U, Diagonal(Sval), Vᴴ), dS)
A.dval .= ΔA
else
svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS)
end
end
if !isa(S, Const) && !A_is_arg
make_zero!(S.dval)
end
!isa(S, Const) && make_zero!(S.dval)
return (nothing, nothing, nothing)
end

Expand Down
118 changes: 84 additions & 34 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,6 @@ using LinearAlgebra

Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
# Reverse rule for `copy_input`: the copy is wrapped with `Mooncake.zero_fcodual`,
# and the pullback passes that wrapper's tangent, together with the tangent of the
# original `A`, to `Mooncake.increment!!`.
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
    copied = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
    out = Mooncake.zero_fcodual(copied)
    dcopied = Mooncake.tangent(out)
    function copy_input_pb(::NoRData)
        Mooncake.increment!!(Mooncake.tangent(A_dA), dcopied)
        # one `NoRData` per argument: the function itself, `f`, and `A`
        return ntuple(Returns(NoRData()), 3)
    end
    return out, copy_input_pb
end

Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any}
# two-argument in-place factorizations like LQ, QR, EIG
for (f!, f, pb, adj) in (
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),
Expand All @@ -53,13 +40,33 @@ for (f!, f, pb, adj) in (
arg2c = copy(arg2)
$f!(A, args, Mooncake.primal(alg_dalg))
function $adj(::NoRData)
copy!(A, Ac)
$pb(dA, A, (arg1, arg2), (darg1, darg2))
copy!(arg1, arg1c)
copy!(arg2, arg2c)
zero!(darg1)
zero!(darg2)
return NoRData(), NoRData(), NoRData(), NoRData()
# DON'T copy Ac to A if A === one
# of the output args -- this can
# mess up the pullback because
# generally the args are used there
if !(A === arg1 || A === arg2)
copy!(A, Ac)
$pb(dA, A, (arg1, arg2), (darg1, darg2))
else
ΔA = zero(A)
$pb(ΔA, A, (arg1, arg2), (darg1, darg2))
dA .= ΔA
end
Comment on lines +47 to +54
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if !(A === arg1 || A === arg2)
copy!(A, Ac)
$pb(dA, A, (arg1, arg2), (darg1, darg2))
else
ΔA = zero(A)
$pb(ΔA, A, (arg1, arg2), (darg1, darg2))
dA .= ΔA
end
# compute pullback of A and restore input
# order is important - A might alias `args`!
$pb(dA, Ac, (arg1, arg2), (darg1, darg2))
copy!(A, Ac)
# compute pullbacks and restore other inputs if they don't alias
if A !== arg1
copy!(arg1, arg1c)
zero!(darg1)
end
if A !== arg2
copy!(arg2, arg2c)
zero!(darg2)
end

Again mostly a readability suggestion. (Note that I might again be missing something about the use of dA vs Delta A here)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment above about Delta A

if A === arg1
copy!(A, Ac)
zero!(darg2)
copy!(arg2, arg2c)
elseif A === arg2
copy!(A, Ac)
zero!(darg1)
copy!(arg1, arg1c)
else
zero!(darg1)
zero!(darg2)
copy!(arg2, arg2c)
copy!(arg1, arg1c)
end
return ntuple(Returns(NoRData()), 4)
end
return args_dargs, $adj
end
Expand Down Expand Up @@ -140,9 +147,19 @@ for (f!, f, f_full, pb, adj) in (
copy!(D, diagview(DV[1]))
V = DV[2]
# Reverse-pass closure: calls the pullback `$pb` with the tangent `dA` and the
# cotangent `dD`, then restores `D` (or `A`) from the saved copies `Dc` / `Ac`.
# Returns one `NoRData()` per argument.
function $adj(::NoRData)
    # NOTE(review): the next three statements look like the pre-change side of this
    # diff hunk; the branches below repeat them -- confirm before relying on them.
    $pb(dA, A, DV, dD)
    copy!(D, Dc)
    zero!(dD)
    if A !== D
        $pb(dA, A, DV, dD)
    else
        # `A` and `D` are the same object here, so the pullback writes into a
        # fresh buffer first and the result is copied into `dA` afterwards
        ΔA = zero(A)
        $pb(ΔA, A, DV, dD)
        # fix: store the pullback result `ΔA`; the original wrote the primal `A`
        # into the tangent, discarding the pullback
        dA .= ΔA
    end
    if A !== D
        zero!(dD)
        copy!(D, Dc)
    else
        copy!(A, Ac)
    end
    return NoRData(), NoRData(), NoRData(), NoRData()
end
return D_dD, $adj
Expand Down Expand Up @@ -199,15 +216,27 @@ for f in (:eig, :eigh)
# not for nested structs with various fields (like Diagonal{Complex})
output_codual = Mooncake.zero_fcodual(output)
function $f_adjoint!(dy::Tuple{NoRData, NoRData, <:Real})
copy!(A, Ac)
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
_warn_pullback_truncerror(dy[3])
D′, dD′ = arrayify(Dtrunc, dDtrunc_)
V′, dV′ = arrayify(Vtrunc, dVtrunc_)
$f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′))
copy!(DV[1], DVc[1])
copy!(DV[2], DVc[2])
D, dD = arrayify(DV[1], dDV[1])
V, dV = arrayify(DV[2], dDV[2])
copy!(A, Ac)
if !(A === D || A === V)
$f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′))
else
ΔA = zero(A)
$f_trunc_pullback!(ΔA, A, (D′, V′), (dD′, dV′))
dA .= ΔA
end
if A === D
copy!(DV[2], DVc[2])
else
copy!(DV[1], DVc[1])
copy!(DV[2], DVc[2])
end
zero!(dD′)
zero!(dV′)
return NoRData(), NoRData(), NoRData(), NoRData()
Expand Down Expand Up @@ -239,12 +268,22 @@ for f in (:eig, :eigh)
_warn_pullback_truncerror(dϵ)

# compute pullbacks
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required

if !(A === DV[1] || A === DV[2])
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
else
ΔA = zero(A)
$f_pullback!(ΔA, Ac, DV, dDVtrunc, ind)
dA .= ΔA
end
# restore state
copy!(A, Ac)
copy!.(DV, DVc)
if A === DV[1]
copy!(DV[2], DVc[2])
zero!(dDV[2])
else
copy!.(DV, DVc)
zero!.(dDV)
end

return ntuple(Returns(NoRData()), 4)
end
Expand Down Expand Up @@ -351,12 +390,23 @@ for f in (:eig, :eigh)
dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
# Reverse-pass closure: calls `$f_pullback!` with the tangent `dA`, then restores
# `A` and the entries of `DV` from their saved copies (`Ac`, `DVc`) and zeroes the
# tangents in `dDV`. Returns a 4-tuple of `NoRData()`.
function $f_adjoint!(::NoRData)
# compute pullbacks
# NOTE(review): the next two statements look like the pre-change side of this
# diff hunk; the branch below repeats the pullback call -- confirm.
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
zero!.(dDV)
# when `A` is the same object as one of the entries of `DV`, the pullback writes
# into a fresh buffer `ΔA` and the result is copied into `dA` afterwards
if !(A === DV[1] || A === DV[2])
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
else
ΔA = zero(A)
$f_pullback!(ΔA, Ac, DV, dDVtrunc, ind)
dA .= ΔA
end

# restore state
copy!(A, Ac)
# NOTE(review): this unconditional restore also looks like a pre-change line;
# the branch below handles the same restore -- confirm.
copy!.(DV, DVc)
# if `A` is `DV[1]`, only the second output is restored and zeroed here
# NOTE(review): the case `A === DV[2]` falls into the `else` branch, which
# zeroes every entry of `dDV` -- confirm that this is intended.
if A === DV[1]
copy!(DV[2], DVc[2])
zero!(dDV[2])
else
copy!.(DV, DVc)
zero!.(dDV)
end

return ntuple(Returns(NoRData()), 4)
end
Expand Down
2 changes: 2 additions & 0 deletions test/enzyme/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
AT = Diagonal{T, Vector{T}}
TestSuite.test_enzyme_eig(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
4 changes: 3 additions & 1 deletion test/enzyme/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ m = 19
for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
#TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
AT = Diagonal{T, Vector{T}}
TestSuite.test_enzyme_eigh(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
2 changes: 2 additions & 0 deletions test/enzyme/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
AT = Diagonal{T, Vector{T}}
m == n && TestSuite.test_enzyme_lq(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
2 changes: 2 additions & 0 deletions test/enzyme/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
AT = Diagonal{T, Vector{T}}
m == n && TestSuite.test_enzyme_orthnull(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
2 changes: 2 additions & 0 deletions test/enzyme/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
AT = Diagonal{T, Vector{T}}
#m == n && TestSuite.test_enzyme_polar(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
2 changes: 2 additions & 0 deletions test/enzyme/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
AT = Diagonal{T, Vector{T}}
m == n && TestSuite.test_enzyme_qr(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
2 changes: 2 additions & 0 deletions test/enzyme/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
AT = Diagonal{T, Vector{T}}
m == n && TestSuite.test_enzyme_svd(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
Loading
Loading