-
Notifications
You must be signed in to change notification settings - Fork 6
Add Mooncake and Enzyme tests for Diagonal #179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f96b113
44cd02a
439a801
256642b
517706a
bad6f5b
cb3a53c
1b16db5
be263e5
00f0726
e451036
e99090f
b878a2f
785f9fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,37 +43,6 @@ using LinearAlgebra | |
| # other provided input variable. | ||
| #-------------------------------------------- | ||
|
|
||
| # this rule is necessary for now as without it, | ||
| # a segfault occurs both on 1.10 and 1.12 -- likely | ||
| # a deeper internal bug | ||
| function EnzymeRules.augmented_primal( | ||
| config::EnzymeRules.RevConfigWidth{1}, | ||
| func::Const{typeof(copy_input)}, | ||
| ::Type{RT}, | ||
| f::Annotation, | ||
| A::Annotation | ||
| ) where {RT} | ||
| ret = func.val(f.val, A.val) | ||
| primal = EnzymeRules.needs_primal(config) ? ret : nothing | ||
| shadow = EnzymeRules.needs_shadow(config) ? zero(A.dval) : nothing | ||
| return EnzymeRules.AugmentedReturn(primal, shadow, shadow) | ||
| end | ||
|
|
||
| function EnzymeRules.reverse( | ||
| config::EnzymeRules.RevConfigWidth{1}, | ||
| func::Const{typeof(copy_input)}, | ||
| ::Type{RT}, | ||
| cache, | ||
| f::Annotation, | ||
| A::Annotation | ||
| ) where {RT} | ||
| copy_shadow = cache | ||
| if !isa(A, Const) && !isnothing(copy_shadow) | ||
| A.dval .+= copy_shadow | ||
| end | ||
| return (nothing, nothing) | ||
| end | ||
|
|
||
| # two-argument factorizations like LQ, QR, EIG | ||
| for (f, pb) in ( | ||
| (qr_full!, qr_pullback!), | ||
|
|
@@ -101,9 +70,14 @@ for (f, pb) in ( | |
| # if arg.val == ret, the annotation must be Duplicated or DuplicatedNoNeed | ||
| # if arg isa Const, ret may still be modified further down the call graph so we should | ||
| # copy it to protect ourselves | ||
| cache_arg = (arg.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing | ||
| dret = if EnzymeRules.needs_shadow(config) | ||
| (TA == Nothing && TB == Nothing) || isa(arg, Const) ? zero.(ret) : arg.dval | ||
| A_is_arg1 = !isa(A, Const) && A.val === arg.val[1] | ||
| A_is_arg2 = !isa(A, Const) && A.val === arg.val[2] | ||
| A_is_arg = A_is_arg1 || A_is_arg2 | ||
| cache_arg = (arg.val !== ret && !A_is_arg) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing | ||
| dret = if EnzymeRules.needs_shadow(config) && ((TA == Nothing && TB == Nothing) || isa(arg, Const)) | ||
| make_zero.(ret) | ||
| elseif EnzymeRules.needs_shadow(config) | ||
| arg.dval | ||
| else | ||
| nothing | ||
| end | ||
|
|
@@ -125,11 +99,23 @@ for (f, pb) in ( | |
| # use A (so that whoever does this is forced to handle caching A | ||
| # appropriately here) | ||
| Aval = nothing | ||
| A_is_arg1 = !isa(A, Const) && A.dval === arg.dval[1] | ||
| A_is_arg2 = !isa(A, Const) && A.dval === arg.dval[2] | ||
| A_is_arg = A_is_arg1 || A_is_arg2 | ||
| argval = something(cache_arg, arg.val) | ||
| if !isa(A, Const) | ||
| $pb(A.dval, Aval, argval, darg) | ||
| if A_is_arg | ||
| ΔA = make_zero(A.dval) | ||
| $pb(ΔA, Aval, argval, darg) | ||
| A.dval .= ΔA | ||
| else | ||
| $pb(A.dval, Aval, argval, darg) | ||
| end | ||
| end | ||
| if !isa(arg, Const) | ||
| A.dval === arg.dval[1] || make_zero!(arg.dval[1]) | ||
| A.dval === arg.dval[2] || make_zero!(arg.dval[2]) | ||
| end | ||
| !isa(arg, Const) && make_zero!(arg.dval) | ||
| return (nothing, nothing, nothing) | ||
| end | ||
| end | ||
|
|
@@ -356,7 +342,13 @@ for (f, trunc_f, full_f, pb) in ( | |
| if !isa(A, Const) | ||
| $pb(A.dval, Aval, DVval, dDVtrunc, ind) | ||
| end | ||
| !isa(DV, Const) && make_zero!(DV.dval) | ||
| if !isa(DV, Const) | ||
| if !(A.dval === DV.dval[1]) | ||
| make_zero!(DV.dval) | ||
| else | ||
| make_zero!(DV.dval[2]) | ||
| end | ||
| end | ||
| return (nothing, nothing, nothing) | ||
| end | ||
| end | ||
|
|
@@ -392,21 +384,30 @@ for (f!, f_full!, pb!) in ( | |
| func::Const{typeof($f!)}, | ||
| ::Type{RT}, | ||
| cache, | ||
| A::Annotation, | ||
| A::Annotation{TA}, | ||
| D::Annotation, | ||
| alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, | ||
| ) where {RT} | ||
| ) where {RT, TA} | ||
| cache_D, dD, V = cache | ||
| Dval = something(cache_D, D.val) | ||
| # A is NOT used in the pullback, so we assign Aval = nothing | ||
| # to trigger an error in case the pullback is modified to directly | ||
| # use A (so that whoever does this is forced to handle caching A | ||
| # appropriately here) | ||
| Aval = nothing | ||
| A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === D.dval | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Code-organization-wise, I am not a huge fan of this hardcoded special case for different types. That being said, we should probably just add an implementation of
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah you are right, my bad! I think then I would even more think an
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, that sounds good to me. We could do it as part of this PR or separately?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are a lot of goofy special cases running around here haha
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be reasonable to include it in this PR, as this does seem to be the primary change that is needed to solve the issue? Happy to defer it if you prefer though. |
||
| if !isa(A, Const) | ||
| $pb!(A.dval, Aval, (Diagonal(Dval), V), dD) | ||
| if A_is_arg | ||
| ΔA = make_zero(A.dval) | ||
| $pb!(ΔA, Aval, (Diagonal(Dval), V), dD) | ||
| A.dval .= ΔA | ||
| else | ||
| $pb!(A.dval, Aval, (Diagonal(Dval), V), dD) | ||
| end | ||
| end | ||
| if !isa(D, Const) && !A_is_arg | ||
| make_zero!(D.dval) | ||
| end | ||
| !isa(D, Const) && make_zero!(D.dval) | ||
| return (nothing, nothing, nothing) | ||
| end | ||
| end | ||
|
|
@@ -438,21 +439,30 @@ function EnzymeRules.reverse( | |
| func::Const{typeof(svd_vals!)}, | ||
| ::Type{RT}, | ||
| cache, | ||
| A::Annotation, | ||
| A::Annotation{TA}, | ||
| S::Annotation, | ||
| alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, | ||
| ) where {RT} | ||
| ) where {RT, TA} | ||
| cache_S, dS, U, Vᴴ = cache | ||
| # A is NOT used in the pullback, so we assign Aval = nothing | ||
| # to trigger an error in case the pullback is modified to directly | ||
| # use A (so that whoever does this is forced to handle caching A | ||
| # appropriately here) | ||
| Aval = nothing | ||
| Sval = something(cache_S, S.val) | ||
| A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === S.dval | ||
| if !isa(A, Const) | ||
| svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS) | ||
| if A_is_arg | ||
| ΔA = make_zero(A.dval) | ||
| svd_vals_pullback!(ΔA, Aval, (U, Diagonal(Sval), Vᴴ), dS) | ||
| A.dval .= ΔA | ||
| else | ||
| svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS) | ||
| end | ||
| end | ||
| if !isa(S, Const) && !A_is_arg | ||
| make_zero!(S.dval) | ||
| end | ||
| !isa(S, Const) && make_zero!(S.dval) | ||
| return (nothing, nothing, nothing) | ||
| end | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -15,19 +15,6 @@ using LinearAlgebra | |||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any} | ||||||||||||||||||||||||||||||||||||||||||||||||
| function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual) | ||||||||||||||||||||||||||||||||||||||||||||||||
| Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| Ac_dAc = Mooncake.zero_fcodual(Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dAc = Mooncake.tangent(Ac_dAc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| function copy_input_pb(::NoRData) | ||||||||||||||||||||||||||||||||||||||||||||||||
| Mooncake.increment!!(Mooncake.tangent(A_dA), dAc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return NoRData(), NoRData(), NoRData() | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| return Ac_dAc, copy_input_pb | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any} | ||||||||||||||||||||||||||||||||||||||||||||||||
| # two-argument in-place factorizations like LQ, QR, EIG | ||||||||||||||||||||||||||||||||||||||||||||||||
| for (f!, f, pb, adj) in ( | ||||||||||||||||||||||||||||||||||||||||||||||||
| (:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint), | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -53,13 +40,33 @@ for (f!, f, pb, adj) in ( | |||||||||||||||||||||||||||||||||||||||||||||||
| arg2c = copy(arg2) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f!(A, args, Mooncake.primal(alg_dalg)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| function $adj(::NoRData) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $pb(dA, A, (arg1, arg2), (darg1, darg2)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(arg1, arg1c) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(arg2, arg2c) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(darg1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(darg2) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return NoRData(), NoRData(), NoRData(), NoRData() | ||||||||||||||||||||||||||||||||||||||||||||||||
| # DON'T copy Ac to A if A === one | ||||||||||||||||||||||||||||||||||||||||||||||||
| # of the output args -- this can | ||||||||||||||||||||||||||||||||||||||||||||||||
| # mess up the pullback because | ||||||||||||||||||||||||||||||||||||||||||||||||
| # generally the args are used there | ||||||||||||||||||||||||||||||||||||||||||||||||
| if !(A === arg1 || A === arg2) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $pb(dA, A, (arg1, arg2), (darg1, darg2)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| ΔA = zero(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $pb(ΔA, A, (arg1, arg2), (darg1, darg2)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dA .= ΔA | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+47
to
+54
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Again mostly a readability suggestion. (Note that I might again be missing something about the use of
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See comment above about |
||||||||||||||||||||||||||||||||||||||||||||||||
| if A === arg1 | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(darg2) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(arg2, arg2c) | ||||||||||||||||||||||||||||||||||||||||||||||||
| elseif A === arg2 | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(darg1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(arg1, arg1c) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(darg1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(darg2) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(arg2, arg2c) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(arg1, arg1c) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| return ntuple(Returns(NoRData()), 4) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| return args_dargs, $adj | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -140,9 +147,19 @@ for (f!, f, f_full, pb, adj) in ( | |||||||||||||||||||||||||||||||||||||||||||||||
| copy!(D, diagview(DV[1])) | ||||||||||||||||||||||||||||||||||||||||||||||||
| V = DV[2] | ||||||||||||||||||||||||||||||||||||||||||||||||
| function $adj(::NoRData) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $pb(dA, A, DV, dD) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(D, Dc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(dD) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if A !== D | ||||||||||||||||||||||||||||||||||||||||||||||||
| $pb(dA, A, DV, dD) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| ΔA = zero(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $pb(ΔA, A, DV, dD) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dA .= A | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| if A !== D | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(dD) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(D, Dc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| return NoRData(), NoRData(), NoRData(), NoRData() | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| return D_dD, $adj | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -199,15 +216,27 @@ for f in (:eig, :eigh) | |||||||||||||||||||||||||||||||||||||||||||||||
| # not for nested structs with various fields (like Diagonal{Complex}) | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_codual = Mooncake.zero_fcodual(output) | ||||||||||||||||||||||||||||||||||||||||||||||||
| function $f_adjoint!(dy::Tuple{NoRData, NoRData, <:Real}) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) | ||||||||||||||||||||||||||||||||||||||||||||||||
| _warn_pullback_truncerror(dy[3]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| D′, dD′ = arrayify(Dtrunc, dDtrunc_) | ||||||||||||||||||||||||||||||||||||||||||||||||
| V′, dV′ = arrayify(Vtrunc, dVtrunc_) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[1], DVc[1]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[2], DVc[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| D, dD = arrayify(DV[1], dDV[1]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| V, dV = arrayify(DV[2], dDV[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if !(A === D || A === V) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| ΔA = zero(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_trunc_pullback!(ΔA, A, (D′, V′), (dD′, dV′)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dA .= ΔA | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| if A === D | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[2], DVc[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[1], DVc[1]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[2], DVc[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(dD′) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(dV′) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return NoRData(), NoRData(), NoRData(), NoRData() | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -239,12 +268,22 @@ for f in (:eig, :eigh) | |||||||||||||||||||||||||||||||||||||||||||||||
| _warn_pullback_truncerror(dϵ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # compute pullbacks | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_pullback!(dA, Ac, DV, dDVtrunc, ind) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!.(dDVtrunc) # since this is allocated in this function this is probably not required | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if !(A === DV[1] || A === DV[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_pullback!(dA, Ac, DV, dDVtrunc, ind) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| ΔA = zero(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_pullback!(ΔA, Ac, DV, dDVtrunc, ind) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dA .= ΔA | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| # restore state | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!.(DV, DVc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if A === DV[1] | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[2], DVc[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(dDV[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!.(DV, DVc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!.(dDV) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| return ntuple(Returns(NoRData()), 4) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -351,12 +390,23 @@ for f in (:eig, :eigh) | |||||||||||||||||||||||||||||||||||||||||||||||
| dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))) | ||||||||||||||||||||||||||||||||||||||||||||||||
| function $f_adjoint!(::NoRData) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # compute pullbacks | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_pullback!(dA, Ac, DV, dDVtrunc, ind) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!.(dDV) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if !(A === DV[1] || A === DV[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_pullback!(dA, Ac, DV, dDVtrunc, ind) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| ΔA = zero(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_pullback!(ΔA, Ac, DV, dDVtrunc, ind) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dA .= ΔA | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # restore state | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!.(DV, DVc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if A === DV[1] | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[2], DVc[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(dDV[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!.(DV, DVc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!.(dDV) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| return ntuple(Returns(NoRData()), 4) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm guessing this has to be accumulating, but I'm not entirely sure I follow the logic of this branch, why can't we immediately in-place accumulate?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can't do it inplace because you'll modify dA which is the same as darg, I think. I think we should not do a + because A is overwritten in the primal step (but this is confusing, I'm open to being convinced otherwise)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I'm following now.
Would it help if instead of allocating a new dA and then doing the copy afterwards, we instead copy the
dargif we need to?Also, does that mean that we could in principle be checking that
dAis zero, since nothing is allowed to depend on the value ofAafter the decomposition?Or would that be too strict of a check in the case where
Ais secretly not overwritten by the specific implementation of the decomposition.