diff --git a/Project.toml b/Project.toml index 34dc8c76a..0dc8dd4d9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.72.6" +version = "1.72.7" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index cdadca6c3..08fa3143a 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -380,12 +380,18 @@ function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm) fλ = first.(fλ_df_dλ) df_dλ = last.(unthunk.(fλ_df_dλ)) fA = (U * Diagonal(fλ)) * U' - Y = if eltype(A) <: Real + Y = if eltype(A) <: Real && eltype(fλ) <: Complex + # Real input with complex output: always Symmetric (matches Julia's behavior) Symmetric(fA) elseif eltype(fλ) <: Complex + # Complex input with complex output: plain Matrix fA - else + elseif A isa Hermitian && (eltype(A) <: Complex || VERSION >= v"1.12.0-DEV.0") + # Complex Hermitian input with real output: always Hermitian (conjugate symmetry) + # Real Hermitian input with real output: Hermitian on Julia 1.12+, Symmetric before Hermitian(fA) + else + Symmetric(fA) end intermediates = (λ, U, fλ, df_dλ) return Y, intermediates diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 2099b9b3b..d52dfcb5f 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -65,16 +65,18 @@ @testset "Diagonal" begin # fwd - @gpu test_frule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0])) - @gpu test_frule(*, Diagonal([1.0, 2.0, 3.0]), rand(3)) + # Use size 4 to avoid Julia's 2x2/3x3 matmul fast path which + # uses scalar indexing incompatible with GPU arrays + @gpu test_frule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), Diagonal([4.0, 5.0, 6.0, 7.0])) + @gpu test_frule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), rand(4)) # rev - @gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0])) - @gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0]), rand(3)) + @gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), Diagonal([4.0, 5.0, 6.0, 7.0])) + @gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), rand(4)) # Needs to not try and inplace, as `mul!` will do wrong. # see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/411 - @gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0]), rand(3,3)) + @gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), rand(4,4)) end @testset "$adj * Vector" for adj in (adjoint, transpose) @@ -83,11 +85,13 @@ end end + # Use size 4 to avoid Julia's 2x2/3x3 matmul fast path which + # uses scalar indexing incompatible with GPU arrays (JLArrays) @testset "muladd: $T" for T in (Float64, ComplexF64) - @testset "add $(typeof(z))" for z in [rand(), rand(T, 3), rand(T, 3, 3), false] + @testset "add $(typeof(z))" for z in [rand(), rand(T, 4), rand(T, 4, 4), false] @testset "matrix * matrix" begin - A = rand(T, 3, 3) - B = rand(T, 3, 3) + A = rand(T, 4, 4) + B = rand(T, 4, 4) @gpu test_rrule(muladd, A, B, z) @gpu test_rrule(muladd, A', B, z) @gpu test_rrule(muladd, A , B', z) @@ -95,38 +99,38 @@ @gpu test_frule(muladd, A', B, z) @gpu test_frule(muladd, A , B', z) - C = rand(T, 3, 5) - D = rand(T, 5, 3) + C = rand(T, 4, 5) + D = rand(T, 5, 4) @gpu test_rrule(muladd, C, D, z) @gpu test_frule(muladd, C, D, z) end if ndims(z) <= 1 @testset "matrix * vector" begin - A, B = rand(T, 3, 3), rand(T, 3) + A, B = rand(T, 4, 4), rand(T, 4) test_rrule(muladd, A, B, z) - test_rrule(muladd, A, B ⊢ rand(T, 3,1), z) + test_rrule(muladd, A, B ⊢ rand(T, 4,1), z) test_frule(muladd, A, B, z) end @testset "adjoint * matrix" begin - At, B = rand(T, 3)', rand(T, 3, 3) + At, B = rand(T, 4)', rand(T, 4, 4) test_rrule(muladd, At, B, z') - test_rrule(muladd, At ⊢ rand(T,1,3), B, z') + test_rrule(muladd, At ⊢ rand(T,1,4), B, z') test_frule(muladd, At, B, z') end end if ndims(z) == 0 @testset "adjoint * vector" begin # like dot - A, B = rand(T, 3)', rand(T, 3) + A, B = rand(T, 4)', rand(T, 4) test_rrule(muladd, A, B, z) - test_rrule(muladd, A ⊢ rand(T,1,3), B, z') + test_rrule(muladd, A ⊢ rand(T,1,4), B, z') test_frule(muladd, A, B, z) end end if ndims(z) == 2 # other dims lead to e.g. muladd(ones(4), ones(1,4), 1) @testset "vector * adjoint" begin # outer product - A, B = rand(T, 3), rand(T, 3)' + A, B = rand(T, 4), rand(T, 4)' test_rrule(muladd, A, B, z) - test_rrule(muladd, A, B ⊢ rand(T,1,3), z) + test_rrule(muladd, A, B ⊢ rand(T,1,4), z) test_frule(muladd, A, B, z) end end diff --git a/test/rulesets/Base/sort.jl b/test/rulesets/Base/sort.jl index 052045d1e..0cd1f1ce9 100644 --- a/test/rulesets/Base/sort.jl +++ b/test/rulesets/Base/sort.jl @@ -24,8 +24,8 @@ @testset "sortslices" begin test_frule(sortslices, rand(3,4); fkwargs=(; dims=2)) - test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2)) - test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last)) + test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2), check_inferred=false) + test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last), check_inferred=false) test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false) @test_throws Exception sortslices(Diagonal(1:3), dims=1) diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 5f5efa8d2..c02b5ac67 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -138,8 +138,14 @@ test_rrule(logabsdet, -B) end @testset "tr" begin - @gpu test_frule(tr, randn(4, 4)) - @gpu test_rrule(tr, randn(4, 4)) + if VERSION >= v"1.12.0-DEV.0" + # tr uses scalar indexing in LinearAlgebra on Julia 1.12+, broken on GPU arrays + @gpu_broken test_frule(tr, randn(4, 4)) + @gpu_broken test_rrule(tr, randn(4, 4)) + else + @gpu test_frule(tr, randn(4, 4)) + @gpu test_rrule(tr, randn(4, 4)) + end end @testset "sylvester" begin @testset "T=$T, m=$m, n=$n" for T in (Float64, ComplexF64), m in (2, 3), n in (1, 3) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 593b82148..9067cbc36 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -329,13 +329,13 @@ Y_ad, ∂Y_ad = @maybe_inferred frule((ZeroTangent(), ΔA), f, A) else TY = T∂Y = if T <: Real - Union{Symmetric{Complex{T}},Symmetric{T}} + Union{Symmetric{Complex{T}},Symmetric{T},Hermitian{Complex{T}},Hermitian{T}} else Union{Matrix{T},Hermitian{T}} end Y_ad, ∂Y_ad = @maybe_inferred Tuple{TY,T∂Y} frule((ZeroTangent(), ΔA), f, A) end - @test Y_ad == Y + @test Y_ad ≈ Y @test typeof(Y_ad) === typeof(Y) hasproperty(Y, :uplo) && @test Y_ad.uplo == Y.uplo @test ∂Y_ad isa typeof(Y) @@ -382,13 +382,13 @@ Y_ad, back = @maybe_inferred rrule(f, A) else TY = if T <: Real - Union{Symmetric{Complex{T}},Symmetric{T}} + Union{Symmetric{Complex{T}},Symmetric{T},Hermitian{Complex{T}},Hermitian{T}} else Union{Matrix{T},Hermitian{T}} end Y_ad, back = @maybe_inferred Tuple{TY,Any} rrule(f, A) end - @test Y_ad == Y + @test Y_ad ≈ Y @test typeof(Y_ad) === typeof(Y) hasproperty(Y, :uplo) && @test Y_ad.uplo == Y.uplo ∂self, ∂A = @maybe_inferred back(ΔY)