From f96b113283badba6b5db94ae414b274a42e5bc44 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 27 Feb 2026 11:24:00 +0100 Subject: [PATCH 01/14] Add Mooncake tests for Diagonal --- test/mooncake/eig.jl | 2 ++ test/mooncake/eigh.jl | 2 ++ test/mooncake/lq.jl | 4 ++++ test/mooncake/orthnull.jl | 4 ++++ test/mooncake/polar.jl | 5 +++++ test/mooncake/qr.jl | 4 ++++ test/mooncake/svd.jl | 4 ++++ 7 files changed, 25 insertions(+) diff --git a/test/mooncake/eig.jl b/test/mooncake/eig.jl index b313f9b2f..213ceca68 100644 --- a/test/mooncake/eig.jl +++ b/test/mooncake/eig.jl @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_eig(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end diff --git a/test/mooncake/eigh.jl b/test/mooncake/eigh.jl index 800dbaa05..e2e52ffe0 100644 --- a/test/mooncake/eigh.jl +++ b/test/mooncake/eigh.jl @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_eigh(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end diff --git a/test/mooncake/lq.jl b/test/mooncake/lq.jl index 0f05f85ab..9ffdc730d 100644 --- a/test/mooncake/lq.jl +++ b/test/mooncake/lq.jl @@ -15,5 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_lq(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/mooncake/orthnull.jl b/test/mooncake/orthnull.jl index 09e3a28cc..370454b55 100644 --- a/test/mooncake/orthnull.jl +++ b/test/mooncake/orthnull.jl @@ -15,5 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_orthnull(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/mooncake/polar.jl b/test/mooncake/polar.jl index 1faf3c104..d6c089098 100644 --- a/test/mooncake/polar.jl +++ b/test/mooncake/polar.jl @@ -17,5 +17,10 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) atol = rtol = m * n * TestSuite.precision(T) m >= n && TestSuite.test_mooncake_left_polar(T, (m, n); atol, rtol) n >= m && TestSuite.test_mooncake_right_polar(T, (m, n); atol, rtol) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_left_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + TestSuite.test_mooncake_right_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/mooncake/qr.jl b/test/mooncake/qr.jl index 17415e8df..ed4db1925 100644 --- a/test/mooncake/qr.jl +++ b/test/mooncake/qr.jl @@ -15,5 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_qr(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/mooncake/svd.jl b/test/mooncake/svd.jl index d2d40df42..f096fdb8e 100644 --- a/test/mooncake/svd.jl +++ b/test/mooncake/svd.jl @@ -15,5 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_mooncake_svd(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end From 44cd02a49840dd11438e863105d8a6e9b4cd056d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 23 Mar 2026 12:44:37 +0100 Subject: [PATCH 02/14] Fix stupid typos --- test/mooncake/eig.jl | 2 +- test/mooncake/eigh.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/mooncake/eig.jl b/test/mooncake/eig.jl index 213ceca68..a87319082 100644 --- a/test/mooncake/eig.jl +++ b/test/mooncake/eig.jl @@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) AT = Diagonal{T, Vector{T}} - TestSuite.test_mooncake_eig(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + TestSuite.test_mooncake_eig(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end diff --git a/test/mooncake/eigh.jl b/test/mooncake/eigh.jl index e2e52ffe0..e39f68316 100644 --- a/test/mooncake/eigh.jl +++ b/test/mooncake/eigh.jl @@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_mooncake_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) AT = Diagonal{T, Vector{T}} - TestSuite.test_mooncake_eigh(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + TestSuite.test_mooncake_eigh(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end From 439a8019cd402b7e126f42124c9ac32878e3258f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 26 Mar 2026 11:26:16 +0100 Subject: [PATCH 03/14] Diagonal + eig(h) working --- .../MatrixAlgebraKitEnzymeExt.jl | 89 ++++++++++--------- test/enzyme/eig.jl | 2 + test/enzyme/eigh.jl | 4 +- 3 files changed, 50 insertions(+), 45 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 85072c949..55dc59474 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -43,37 +43,6 @@ using LinearAlgebra # other provided input variable. #-------------------------------------------- -# this rule is necessary for now as without it, -# a segfault occurs both on 1.10 and 1.12 -- likely -# a deeper internal bug -function EnzymeRules.augmented_primal( - config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(copy_input)}, - ::Type{RT}, - f::Annotation, - A::Annotation - ) where {RT} - ret = func.val(f.val, A.val) - primal = EnzymeRules.needs_primal(config) ? ret : nothing - shadow = EnzymeRules.needs_shadow(config) ? zero(A.dval) : nothing - return EnzymeRules.AugmentedReturn(primal, shadow, shadow) -end - -function EnzymeRules.reverse( - config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(copy_input)}, - ::Type{RT}, - cache, - f::Annotation, - A::Annotation - ) where {RT} - copy_shadow = cache - if !isa(A, Const) && !isnothing(copy_shadow) - A.dval .+= copy_shadow - end - return (nothing, nothing) -end - # two-argument factorizations like LQ, QR, EIG for (f, pb) in ( (qr_full!, qr_pullback!), @@ -101,9 +70,11 @@ for (f, pb) in ( # if arg.val == ret, the annotation must be Duplicated or DuplicatedNoNeed # if arg isa Const, ret may still be modified further down the call graph so we should # copy it to protect ourselves - cache_arg = (arg.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing - dret = if EnzymeRules.needs_shadow(config) - (TA == Nothing && TB == Nothing) || isa(arg, Const) ? zero.(ret) : arg.dval + cache_arg = (arg.val !== ret && arg.val[1] !== A.val) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing + dret = if EnzymeRules.needs_shadow(config) && ((TA == Nothing && TB == Nothing) || isa(arg, Const)) + make_zero.(ret) + elseif EnzymeRules.needs_shadow(config) + arg.dval else nothing end @@ -125,11 +96,24 @@ for (f, pb) in ( # use A (so that whoever does this is forced to handle caching A # appropriately here) Aval = nothing + A_is_arg = !isa(A, Const) && A.dval === arg.dval[1] argval = something(cache_arg, arg.val) if !isa(A, Const) - $pb(A.dval, Aval, argval, darg) + if A_is_arg + ΔA = make_zero(A.dval) + $pb(ΔA, Aval, argval, darg) + A.dval .= ΔA + else + $pb(A.dval, Aval, argval, darg) + end + end + if !isa(arg, Const) + if !A_is_arg + make_zero!(arg.dval) + else + make_zero!(arg.dval[2]) + end end - !isa(arg, Const) && make_zero!(arg.dval) return (nothing, nothing, nothing) end end @@ -356,7 +340,13 @@ for (f, trunc_f, full_f, pb) in ( if !isa(A, Const) $pb(A.dval, Aval, DVval, dDVtrunc, ind) end - !isa(DV, Const) && make_zero!(DV.dval) + if !isa(DV, Const) + if !(A.dval === DV.dval[1]) + make_zero!(DV.dval) + else + make_zero!(DV.dval[2]) + end + end return (nothing, nothing, nothing) end end @@ -392,10 +382,10 @@ for (f!, f_full!, pb!) in ( func::Const{typeof($f!)}, ::Type{RT}, cache, - A::Annotation, + A::Annotation{TA}, D::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT} + ) where {RT, TA} cache_D, dD, V = cache Dval = something(cache_D, D.val) # A is NOT used in the pullback, so we assign Aval = nothing @@ -403,10 +393,19 @@ for (f!, f_full!, pb!) in ( # use A (so that whoever does this is forced to handle caching A # appropriately here) Aval = nothing + A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === D.dval if !isa(A, Const) - $pb!(A.dval, Aval, (Diagonal(Dval), V), dD) + if A_is_arg + ΔA = make_zero(A.dval) + $pb!(ΔA, Aval, (Diagonal(Dval), V), dD) + A.dval .= ΔA + else + $pb!(A.dval, Aval, (Diagonal(Dval), V), dD) + end + end + if !isa(D, Const) && !A_is_arg + make_zero!(D.dval) end - !isa(D, Const) && make_zero!(D.dval) return (nothing, nothing, nothing) end end @@ -438,10 +437,10 @@ function EnzymeRules.reverse( func::Const{typeof(svd_vals!)}, ::Type{RT}, cache, - A::Annotation, + A::Annotation{TA}, S::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT} + ) where {RT, TA} cache_S, dS, U, Vᴴ = cache # A is NOT used in the pullback, so we assign Aval = nothing # to trigger an error in case the pullback is modified to directly @@ -452,7 +451,9 @@ function EnzymeRules.reverse( if !isa(A, Const) svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS) end - !isa(S, Const) && make_zero!(S.dval) + if !isa(S, Const) && !(TA <: Diagonal && (diagview(A.dval) === S.dval)) + make_zero!(S.dval) + end return (nothing, nothing, nothing) end diff --git a/test/enzyme/eig.jl b/test/enzyme/eig.jl index 898d773a8..949129eac 100644 --- a/test/enzyme/eig.jl +++ b/test/enzyme/eig.jl @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + AT = Diagonal{T, Vector{T}} + TestSuite.test_enzyme_eig(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end diff --git a/test/enzyme/eigh.jl b/test/enzyme/eigh.jl index d32db3dd5..64c796fc6 100644 --- a/test/enzyme/eigh.jl +++ b/test/enzyme/eigh.jl @@ -14,6 +14,8 @@ m = 19 for T in (BLASFloats..., GenericFloats...) TestSuite.seed_rng!(1234) if !is_buildkite - TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + #TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + AT = Diagonal{T, Vector{T}} + TestSuite.test_enzyme_eigh(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end From 256642b00381065110a540a1b5d9c6a0827b5f94 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 26 Mar 2026 12:55:49 +0100 Subject: [PATCH 04/14] SVD working for diag --- .../MatrixAlgebraKitEnzymeExt.jl | 11 +++++++++-- test/enzyme/svd.jl | 2 ++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 55dc59474..736c0606b 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -448,10 +448,17 @@ function EnzymeRules.reverse( # appropriately here) Aval = nothing Sval = something(cache_S, S.val) + A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === S.dval if !isa(A, Const) - svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS) + if A_is_arg + ΔA = make_zero(A.dval) + svd_vals_pullback!(ΔA, Aval, (U, Diagonal(Sval), Vᴴ), dS) + A.dval .= ΔA + else + svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS) + end end - if !isa(S, Const) && !(TA <: Diagonal && (diagview(A.dval) === S.dval)) + if !isa(S, Const) && !A_is_arg make_zero!(S.dval) end return (nothing, nothing, nothing) diff --git a/test/enzyme/svd.jl b/test/enzyme/svd.jl index 6143f61e4..e4aaa7aa1 100644 --- a/test/enzyme/svd.jl +++ b/test/enzyme/svd.jl @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + AT = Diagonal{T, Vector{T}} + m == n && TestSuite.test_enzyme_svd(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end From 517706a6d4964917b685e1e885d1251be77c3db7 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 26 Mar 2026 13:31:46 +0100 Subject: [PATCH 05/14] LQ/QR working --- .../MatrixAlgebraKitEnzymeExt.jl | 13 ++++++++++--- test/enzyme/lq.jl | 2 ++ test/enzyme/qr.jl | 2 ++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 736c0606b..820bba904 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -70,7 +70,10 @@ for (f, pb) in ( # if arg.val == ret, the annotation must be Duplicated or DuplicatedNoNeed # if arg isa Const, ret may still be modified further down the call graph so we should # copy it to protect ourselves - cache_arg = (arg.val !== ret && arg.val[1] !== A.val) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing + A_is_arg1 = !isa(A, Const) && A.val === arg.val[1] + A_is_arg2 = !isa(A, Const) && A.val === arg.val[2] + A_is_arg = A_is_arg1 || A_is_arg2 + cache_arg = (arg.val !== ret && !A_is_arg) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing dret = if EnzymeRules.needs_shadow(config) && ((TA == Nothing && TB == Nothing) || isa(arg, Const)) make_zero.(ret) elseif EnzymeRules.needs_shadow(config) @@ -96,7 +99,9 @@ for (f, pb) in ( # use A (so that whoever does this is forced to handle caching A # appropriately here) Aval = nothing - A_is_arg = !isa(A, Const) && A.dval === arg.dval[1] + A_is_arg1 = !isa(A, Const) && A.dval === arg.dval[1] + A_is_arg2 = !isa(A, Const) && A.dval === arg.dval[2] + A_is_arg = A_is_arg1 || A_is_arg2 argval = something(cache_arg, arg.val) if !isa(A, Const) if A_is_arg @@ -110,8 +115,10 @@ for (f, pb) in ( if !isa(arg, Const) if !A_is_arg make_zero!(arg.dval) - else + elseif A_is_arg1 make_zero!(arg.dval[2]) + elseif A_is_arg2 + make_zero!(arg.dval[1]) end end return (nothing, nothing, nothing) diff --git a/test/enzyme/lq.jl b/test/enzyme/lq.jl index f7ae2ebf7..7c747529d 100644 --- a/test/enzyme/lq.jl +++ b/test/enzyme/lq.jl @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + AT = Diagonal{T, Vector{T}} + m == n && TestSuite.test_enzyme_lq(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end diff --git a/test/enzyme/qr.jl b/test/enzyme/qr.jl index 728e267d3..2d8b9e7e1 100644 --- a/test/enzyme/qr.jl +++ b/test/enzyme/qr.jl @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + AT = Diagonal{T, Vector{T}} + m == n && TestSuite.test_enzyme_qr(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end From bad6f5ba1137e486d394bb137aed81b64023c02e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 26 Mar 2026 13:43:44 +0100 Subject: [PATCH 06/14] Polar and orthnull --- test/enzyme/orthnull.jl | 2 ++ test/enzyme/polar.jl | 2 ++ test/testsuite/enzyme/orthnull.jl | 4 ++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/enzyme/orthnull.jl b/test/enzyme/orthnull.jl index eaeae8400..086873d3f 100644 --- a/test/enzyme/orthnull.jl +++ b/test/enzyme/orthnull.jl @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + AT = Diagonal{T, Vector{T}} + m == n && TestSuite.test_enzyme_orthnull(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end diff --git a/test/enzyme/polar.jl b/test/enzyme/polar.jl index 6ab965ac1..183086adb 100644 --- a/test/enzyme/polar.jl +++ b/test/enzyme/polar.jl @@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + AT = Diagonal{T, Vector{T}} + #m == n && TestSuite.test_enzyme_polar(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end diff --git a/test/testsuite/enzyme/orthnull.jl b/test/testsuite/enzyme/orthnull.jl index 842fdc2f4..9fae30e5e 100644 --- a/test/testsuite/enzyme/orthnull.jl +++ b/test/testsuite/enzyme/orthnull.jl @@ -37,7 +37,7 @@ function test_enzyme_left_orth( test_reverse(call_and_zero!, RT, (left_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔVC) end - if m >= n + if m >= n && !(T <: Diagonal) @testset "polar" begin A = instantiate_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(left_orth!, A, :polar) @@ -71,7 +71,7 @@ function test_enzyme_right_orth( test_reverse(call_and_zero!, RT, (right_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔCVᴴ) end - if m <= n + if m <= n && !(T <: Diagonal) @testset "polar" begin A = instantiate_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(right_orth!, A, :polar) From cb3a53caf9e064477dcb6272da6c8396d9987a4f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 26 Mar 2026 17:19:59 +0100 Subject: [PATCH 07/14] Some Mooncake progress --- .../MatrixAlgebraKitMooncakeExt.jl | 108 ++++++++++++------ test/mooncake/lq.jl | 4 +- test/mooncake/polar.jl | 4 +- test/mooncake/qr.jl | 4 +- test/testsuite/mooncake/orthnull.jl | 4 +- 5 files changed, 84 insertions(+), 40 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index ef32d6de4..21613c9fe 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -15,19 +15,6 @@ using LinearAlgebra Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any} -function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual) - Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA)) - Ac_dAc = Mooncake.zero_fcodual(Ac) - dAc = Mooncake.tangent(Ac_dAc) - function copy_input_pb(::NoRData) - Mooncake.increment!!(Mooncake.tangent(A_dA), dAc) - return NoRData(), NoRData(), NoRData() - end - return Ac_dAc, copy_input_pb -end - -Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any} # two-argument in-place factorizations like LQ, QR, EIG for (f!, f, pb, adj) in ( (:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint), @@ -53,12 +40,26 @@ for (f!, f, pb, adj) in ( arg2c = copy(arg2) $f!(A, args, Mooncake.primal(alg_dalg)) function $adj(::NoRData) + if !(A === arg1 || A === arg2) + $pb(dA, A, (arg1, arg2), (darg1, darg2)) + else + ΔA = zero(A) + $pb(ΔA, A, (arg1, arg2), (darg1, darg2)) + dA .= ΔA + end + if A === arg1 + zero!(darg2) + copy!(arg2, arg2c) + elseif A === arg2 + zero!(darg1) + copy!(arg1, arg1c) + else + zero!(darg1) + zero!(darg2) + copy!(arg2, arg2c) + copy!(arg1, arg1c) + end copy!(A, Ac) - $pb(dA, A, (arg1, arg2), (darg1, darg2)) - copy!(arg1, arg1c) - copy!(arg2, arg2c) - zero!(darg1) - zero!(darg2) return NoRData(), NoRData(), NoRData(), NoRData() end return args_dargs, $adj @@ -140,9 +141,19 @@ for (f!, f, f_full, pb, adj) in ( copy!(D, diagview(DV[1])) V = DV[2] function $adj(::NoRData) - $pb(dA, A, DV, dD) - copy!(D, Dc) - zero!(dD) + if A !== D + $pb(dA, A, DV, dD) + else + ΔA = zero(A) + $pb(ΔA, A, DV, dD) + dA .= A + end + if A !== D + zero!(dD) + copy!(D, Dc) + else + copy!(A, Ac) + end return NoRData(), NoRData(), NoRData(), NoRData() end return D_dD, $adj @@ -199,15 +210,27 @@ for f in (:eig, :eigh) # not for nested structs with various fields (like Diagonal{Complex}) output_codual = Mooncake.zero_fcodual(output) function $f_adjoint!(dy::Tuple{NoRData, NoRData, <:Real}) - copy!(A, Ac) Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) _warn_pullback_truncerror(dy[3]) D′, dD′ = arrayify(Dtrunc, dDtrunc_) V′, dV′ = arrayify(Vtrunc, dVtrunc_) - $f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′)) - copy!(DV[1], DVc[1]) - copy!(DV[2], DVc[2]) + D, dD = arrayify(DV[1], dDV[1]) + V, dV = arrayify(DV[2], dDV[2]) + copy!(A, Ac) + if !(A === D || A === V) + $f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′)) + else + ΔA = zero(A) + $f_trunc_pullback!(ΔA, A, (D′, V′), (dD′, dV′)) + dA .= ΔA + end + if A === D + copy!(DV[2], DVc[2]) + else + copy!(DV[1], DVc[1]) + copy!(DV[2], DVc[2]) + end zero!(dD′) zero!(dV′) return NoRData(), NoRData(), NoRData(), NoRData() @@ -239,12 +262,22 @@ for f in (:eig, :eigh) _warn_pullback_truncerror(dϵ) # compute pullbacks - $f_pullback!(dA, Ac, DV, dDVtrunc, ind) - zero!.(dDVtrunc) # since this is allocated in this function this is probably not required - + if !(A === DV[1] || A === DV[2]) + $f_pullback!(dA, Ac, DV, dDVtrunc, ind) + else + ΔA = zero(A) + $f_pullback!(ΔA, Ac, DV, dDVtrunc, ind) + dA .= ΔA + end # restore state copy!(A, Ac) - copy!.(DV, DVc) + if A === DV[1] + copy!(DV[2], DVc[2]) + zero!(dDV[2]) + else + copy!.(DV, DVc) + zero!.(dDV) + end return ntuple(Returns(NoRData()), 4) end @@ -351,12 +384,23 @@ for f in (:eig, :eigh) dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))) function $f_adjoint!(::NoRData) # compute pullbacks - $f_pullback!(dA, Ac, DV, dDVtrunc, ind) - zero!.(dDV) + if !(A === DV[1] || A === DV[2]) + $f_pullback!(dA, Ac, DV, dDVtrunc, ind) + else + ΔA = zero(A) + $f_pullback!(ΔA, Ac, DV, dDVtrunc, ind) + dA .= ΔA + end # restore state copy!(A, Ac) - copy!.(DV, DVc) + if A === DV[1] + copy!(DV[2], DVc[2]) + zero!(dDV[2]) + else + copy!.(DV, DVc) + zero!.(dDV) + end return ntuple(Returns(NoRData()), 4) end diff --git a/test/mooncake/lq.jl b/test/mooncake/lq.jl index 9ffdc730d..fb9b70f78 100644 --- a/test/mooncake/lq.jl +++ b/test/mooncake/lq.jl @@ -15,9 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - if m == n + #=if m == n AT = Diagonal{T, Vector{T}} TestSuite.test_mooncake_lq(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - end + end=# # broken with singular exception end end diff --git a/test/mooncake/polar.jl b/test/mooncake/polar.jl index d6c089098..6442b7b33 100644 --- a/test/mooncake/polar.jl +++ b/test/mooncake/polar.jl @@ -17,10 +17,10 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) atol = rtol = m * n * TestSuite.precision(T) m >= n && TestSuite.test_mooncake_left_polar(T, (m, n); atol, rtol) n >= m && TestSuite.test_mooncake_right_polar(T, (m, n); atol, rtol) - if m == n + #=if m == n AT = Diagonal{T, Vector{T}} TestSuite.test_mooncake_left_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) TestSuite.test_mooncake_right_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - end + end=# # broken due to pullback end end diff --git a/test/mooncake/qr.jl b/test/mooncake/qr.jl index ed4db1925..3ae2ae81f 100644 --- a/test/mooncake/qr.jl +++ b/test/mooncake/qr.jl @@ -15,9 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - if m == n + #=if m == n AT = Diagonal{T, Vector{T}} TestSuite.test_mooncake_qr(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - end + end=# # broken with singular exception end end diff --git a/test/testsuite/mooncake/orthnull.jl b/test/testsuite/mooncake/orthnull.jl index 48e4ad921..e89700868 100644 --- a/test/testsuite/mooncake/orthnull.jl +++ b/test/testsuite/mooncake/orthnull.jl @@ -43,7 +43,7 @@ function test_mooncake_left_orth( ) end - if m >= n + if m >= n && !(T <: Diagonal) @testset "polar" begin alg = MatrixAlgebraKit.select_algorithm(left_orth!, A, :polar) VC = left_orth(A, alg) @@ -91,7 +91,7 @@ function test_mooncake_right_orth( ) end - if m <= n + if m <= n && !(T <: Diagonal) @testset "polar" begin alg = MatrixAlgebraKit.select_algorithm(right_orth!, A, :polar) CVᴴ = right_orth(A, alg) From 1b16db5344fc58948dd8c7e4b6ae6cffcbf52243 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 26 Mar 2026 17:34:28 +0100 Subject: [PATCH 08/14] Formatter --- ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 820bba904..3d1d00142 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -101,7 +101,7 @@ for (f, pb) in ( Aval = nothing A_is_arg1 = !isa(A, Const) && A.dval === arg.dval[1] A_is_arg2 = !isa(A, Const) && A.dval === arg.dval[2] - A_is_arg = A_is_arg1 || A_is_arg2 + A_is_arg = A_is_arg1 || A_is_arg2 argval = something(cache_arg, arg.val) if !isa(A, Const) if A_is_arg From be263e555880a575fd8f81c08c51ed32b4788028 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 27 Mar 2026 13:45:50 +0100 Subject: [PATCH 09/14] Working Mooncake eig --- .../MatrixAlgebraKitMooncakeExt.jl | 2 +- test/testsuite/mooncake/eig.jl | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 21613c9fe..d9c52780a 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -40,6 +40,7 @@ for (f!, f, pb, adj) in ( arg2c = copy(arg2) $f!(A, args, Mooncake.primal(alg_dalg)) function $adj(::NoRData) + copy!(A, Ac) if !(A === arg1 || A === arg2) $pb(dA, A, (arg1, arg2), (darg1, darg2)) else @@ -59,7 +60,6 @@ for (f!, f, pb, adj) in ( copy!(arg2, arg2c) copy!(arg1, arg1c) end - copy!(A, Ac) return NoRData(), NoRData(), NoRData(), NoRData() end return args_dargs, $adj diff --git a/test/testsuite/mooncake/eig.jl b/test/testsuite/mooncake/eig.jl index 88042df28..aad4b4881 100644 --- a/test/testsuite/mooncake/eig.jl +++ b/test/testsuite/mooncake/eig.jl @@ -31,6 +31,12 @@ function test_mooncake_eig_full( rng, eig_full, A, alg; mode = Mooncake.ReverseMode, output_tangent, atol, rtol ) + if T <: Diagonal{<:Complex} + Mooncake.TestUtils.test_rule( + rng, eig_full!, A, (A, DV[2]), alg; + mode = Mooncake.ReverseMode, output_tangent, atol, rtol + ) + end Mooncake.TestUtils.test_rule( rng, call_and_zero!, eig_full!, A, alg; mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false @@ -57,6 +63,12 @@ function test_mooncake_eig_vals( rng, eig_vals, A, alg; mode = Mooncake.ReverseMode, output_tangent, atol, rtol ) + if T <: Diagonal{<:Complex} + Mooncake.TestUtils.test_rule( + rng, eig_vals!, A, A.diag, alg; + mode = Mooncake.ReverseMode, output_tangent, atol, rtol + ) + end Mooncake.TestUtils.test_rule( rng, call_and_zero!, eig_vals!, A, alg; mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false From 00f0726fc4a6e08a46a8733489d66bd71e5654d8 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 27 Mar 2026 13:47:12 +0100 Subject: [PATCH 10/14] Remove comment --- test/mooncake/qr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mooncake/qr.jl b/test/mooncake/qr.jl index 3ae2ae81f..79c05e009 100644 --- a/test/mooncake/qr.jl +++ b/test/mooncake/qr.jl @@ -18,6 +18,6 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) #=if m == n AT = Diagonal{T, Vector{T}} TestSuite.test_mooncake_qr(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - end=# # broken with singular exception + end=# end end From e451036fda7bf9c4599a92c64de5fc542035d807 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 27 Mar 2026 14:36:10 +0100 Subject: [PATCH 11/14] No diag for orthnull --- test/mooncake/orthnull.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/mooncake/orthnull.jl b/test/mooncake/orthnull.jl index 370454b55..7ce6fa562 100644 --- a/test/mooncake/orthnull.jl +++ b/test/mooncake/orthnull.jl @@ -15,9 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - if m == n + #=if m == n AT = Diagonal{T, Vector{T}} TestSuite.test_mooncake_orthnull(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - end + end=# end end From e99090f491b6ce851cace907921ebe3ecb586af0 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 27 Mar 2026 15:15:32 +0100 Subject: [PATCH 12/14] Diag tests working for Mooncake --- .../MatrixAlgebraKitMooncakeExt.jl | 10 ++++++++-- test/mooncake/eig.jl | 2 +- test/mooncake/lq.jl | 6 +++--- test/mooncake/orthnull.jl | 4 ++-- test/mooncake/qr.jl | 6 +++--- test/testsuite/TestSuite.jl | 6 ++++++ 6 files changed, 23 insertions(+), 11 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index d9c52780a..10c167420 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -40,8 +40,12 @@ for (f!, f, pb, adj) in ( arg2c = copy(arg2) $f!(A, args, Mooncake.primal(alg_dalg)) function $adj(::NoRData) - copy!(A, Ac) + # DON'T copy Ac to A if A === one + # of the output args -- this can + # mess up the pullback because + # generally the args are used there if !(A === arg1 || A === arg2) + copy!(A, Ac) $pb(dA, A, (arg1, arg2), (darg1, darg2)) else ΔA = zero(A) @@ -49,9 +53,11 @@ for (f!, f, pb, adj) in ( dA .= ΔA end if A === arg1 + copy!(A, Ac) zero!(darg2) copy!(arg2, arg2c) elseif A === arg2 + copy!(A, Ac) zero!(darg1) copy!(arg1, arg1c) else @@ -60,7 +66,7 @@ for (f!, f, pb, adj) in ( copy!(arg2, arg2c) copy!(arg1, arg1c) end - return NoRData(), NoRData(), NoRData(), NoRData() + return ntuple(Returns(NoRData()), 4) end return args_dargs, $adj end diff --git a/test/mooncake/eig.jl b/test/mooncake/eig.jl index a87319082..a0e606941 100644 --- a/test/mooncake/eig.jl +++ b/test/mooncake/eig.jl @@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...) if !is_buildkite TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) AT = Diagonal{T, Vector{T}} - TestSuite.test_mooncake_eig(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + TestSuite.test_mooncake_eig(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end diff --git a/test/mooncake/lq.jl b/test/mooncake/lq.jl index fb9b70f78..42d0fdb6b 100644 --- a/test/mooncake/lq.jl +++ b/test/mooncake/lq.jl @@ -15,9 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - #=if m == n + if m == n AT = Diagonal{T, Vector{T}} - TestSuite.test_mooncake_lq(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - end=# # broken with singular exception + TestSuite.test_mooncake_lq(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/mooncake/orthnull.jl b/test/mooncake/orthnull.jl index 7ce6fa562..370454b55 100644 --- a/test/mooncake/orthnull.jl +++ b/test/mooncake/orthnull.jl @@ -15,9 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - #=if m == n + if m == n AT = Diagonal{T, Vector{T}} TestSuite.test_mooncake_orthnull(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - end=# + end end end diff --git a/test/mooncake/qr.jl b/test/mooncake/qr.jl index 79c05e009..bbb9a8d17 100644 --- a/test/mooncake/qr.jl +++ b/test/mooncake/qr.jl @@ -15,9 +15,9 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - #=if m == n + if m == n AT = Diagonal{T, Vector{T}} - TestSuite.test_mooncake_qr(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - end=# + TestSuite.test_mooncake_qr(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index b4fb93008..36ae68304 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -96,6 +96,12 @@ function instantiate_rank_deficient_matrix(T, sz; trunc = truncrank(div(min(sz.. return mul!(A, V, C) end +function instantiate_rank_deficient_matrix(::Type{T}, sz; trunc = truncrank(div(min(sz...), 2))) where {T <: Diagonal} + A = instantiate_matrix(eltype(T), sz) + V, C = left_orth!(A; trunc) + return Diagonal(diag(mul!(A, V, C))) +end + include("ad_utils.jl") include("projections.jl") From b878a2fd8716aae11360033ee63f5e65e4adb575 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 27 Mar 2026 15:38:58 +0100 Subject: [PATCH 13/14] Formatter --- ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 10c167420..363356ff4 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -40,7 +40,7 @@ for (f!, f, pb, adj) in ( arg2c = copy(arg2) $f!(A, args, Mooncake.primal(alg_dalg)) function $adj(::NoRData) - # DON'T copy Ac to A if A === one + # DON'T copy Ac to A if A === one # of the output args -- this can # mess up the pullback because # generally the args are used there From 785f9fac274580cd2ca7013986c198d717d5c146 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 30 Mar 2026 15:56:35 +0200 Subject: [PATCH 14/14] Update ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl Co-authored-by: Lukas Devos --- .../MatrixAlgebraKitEnzymeExt.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 3d1d00142..6ef246cb2 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -113,13 +113,8 @@ for (f, pb) in ( end end if !isa(arg, Const) - if !A_is_arg - make_zero!(arg.dval) - elseif A_is_arg1 - make_zero!(arg.dval[2]) - elseif A_is_arg2 - make_zero!(arg.dval[1]) - end + A.dval === arg.dval[1] || make_zero!(arg.dval[1]) + A.dval === arg.dval[2] || make_zero!(arg.dval[2]) end return (nothing, nothing, nothing) end