Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.72.6"
version = "1.72.7"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
10 changes: 8 additions & 2 deletions src/rulesets/LinearAlgebra/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,12 +380,18 @@ function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
fλ = first.(fλ_df_dλ)
df_dλ = last.(unthunk.(fλ_df_dλ))
fA = (U * Diagonal(fλ)) * U'
Y = if eltype(A) <: Real
Y = if eltype(A) <: Real && eltype(fλ) <: Complex
# Real input with complex output: always Symmetric (matches Julia's behavior)
Symmetric(fA)
elseif eltype(fλ) <: Complex
# Complex input with complex output: plain Matrix
fA
else
elseif A isa Hermitian && (eltype(A) <: Complex || VERSION >= v"1.12.0-DEV.0")
# Complex Hermitian input with real output: always Hermitian (conjugate symmetry)
# Real Hermitian input with real output: Hermitian on Julia 1.12+, Symmetric before
Hermitian(fA)
else
Symmetric(fA)
end
intermediates = (λ, U, fλ, df_dλ)
return Y, intermediates
Expand Down
40 changes: 22 additions & 18 deletions test/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,18 @@

@testset "Diagonal" begin
# fwd
@gpu test_frule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0]))
@gpu test_frule(*, Diagonal([1.0, 2.0, 3.0]), rand(3))
# Use size 4 to avoid Julia's 2x2/3x3 matmul fast path which
# uses scalar indexing incompatible with GPU arrays
@gpu test_frule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), Diagonal([4.0, 5.0, 6.0, 7.0]))
@gpu test_frule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), rand(4))

# rev
@gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0]))
@gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0]), rand(3))
@gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), Diagonal([4.0, 5.0, 6.0, 7.0]))
@gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), rand(4))

# Needs to not try and inplace, as `mul!` will do wrong.
# see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/411
@gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0]), rand(3,3))
@gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), rand(4,4))
end

@testset "$adj * Vector" for adj in (adjoint, transpose)
Expand All @@ -83,50 +85,52 @@
end
end

# Use size 4 to avoid Julia's 2x2/3x3 matmul fast path which
# uses scalar indexing incompatible with GPU arrays (JLArrays)
@testset "muladd: $T" for T in (Float64, ComplexF64)
@testset "add $(typeof(z))" for z in [rand(), rand(T, 3), rand(T, 3, 3), false]
@testset "add $(typeof(z))" for z in [rand(), rand(T, 4), rand(T, 4, 4), false]
@testset "matrix * matrix" begin
A = rand(T, 3, 3)
B = rand(T, 3, 3)
A = rand(T, 4, 4)
B = rand(T, 4, 4)
@gpu test_rrule(muladd, A, B, z)
@gpu test_rrule(muladd, A', B, z)
@gpu test_rrule(muladd, A , B', z)
@gpu test_frule(muladd, A, B, z)
@gpu test_frule(muladd, A', B, z)
@gpu test_frule(muladd, A , B', z)

C = rand(T, 3, 5)
D = rand(T, 5, 3)
C = rand(T, 4, 5)
D = rand(T, 5, 4)
@gpu test_rrule(muladd, C, D, z)
@gpu test_frule(muladd, C, D, z)
end
if ndims(z) <= 1
@testset "matrix * vector" begin
A, B = rand(T, 3, 3), rand(T, 3)
A, B = rand(T, 4, 4), rand(T, 4)
test_rrule(muladd, A, B, z)
test_rrule(muladd, A, B ⊢ rand(T, 3,1), z)
test_rrule(muladd, A, B ⊢ rand(T, 4,1), z)
test_frule(muladd, A, B, z)
end
@testset "adjoint * matrix" begin
At, B = rand(T, 3)', rand(T, 3, 3)
At, B = rand(T, 4)', rand(T, 4, 4)
test_rrule(muladd, At, B, z')
test_rrule(muladd, At ⊢ rand(T,1,3), B, z')
test_rrule(muladd, At ⊢ rand(T,1,4), B, z')
test_frule(muladd, At, B, z')
end
end
if ndims(z) == 0
@testset "adjoint * vector" begin # like dot
A, B = rand(T, 3)', rand(T, 3)
A, B = rand(T, 4)', rand(T, 4)
test_rrule(muladd, A, B, z)
test_rrule(muladd, A ⊢ rand(T,1,3), B, z')
test_rrule(muladd, A ⊢ rand(T,1,4), B, z')
test_frule(muladd, A, B, z)
end
end
if ndims(z) == 2 # other dims lead to e.g. muladd(ones(4), ones(1,4), 1)
@testset "vector * adjoint" begin # outer product
A, B = rand(T, 3), rand(T, 3)'
A, B = rand(T, 4), rand(T, 4)'
test_rrule(muladd, A, B, z)
test_rrule(muladd, A, B ⊢ rand(T,1,3), z)
test_rrule(muladd, A, B ⊢ rand(T,1,4), z)
test_frule(muladd, A, B, z)
end
end
Expand Down
4 changes: 2 additions & 2 deletions test/rulesets/Base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
@testset "sortslices" begin
test_frule(sortslices, rand(3,4); fkwargs=(; dims=2))

test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2))
test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last))
test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2), check_inferred=false)
test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last), check_inferred=false)
test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false)

@test_throws Exception sortslices(Diagonal(1:3), dims=1)
Expand Down
10 changes: 8 additions & 2 deletions test/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,14 @@
test_rrule(logabsdet, -B)
end
@testset "tr" begin
@gpu test_frule(tr, randn(4, 4))
@gpu test_rrule(tr, randn(4, 4))
if VERSION >= v"1.12.0-DEV.0"
# tr uses scalar indexing in LinearAlgebra on Julia 1.12+, broken on GPU arrays
@gpu_broken test_frule(tr, randn(4, 4))
@gpu_broken test_rrule(tr, randn(4, 4))
else
@gpu test_frule(tr, randn(4, 4))
@gpu test_rrule(tr, randn(4, 4))
end
end
@testset "sylvester" begin
@testset "T=$T, m=$m, n=$n" for T in (Float64, ComplexF64), m in (2, 3), n in (1, 3)
Expand Down
8 changes: 4 additions & 4 deletions test/rulesets/LinearAlgebra/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,13 +329,13 @@
Y_ad, ∂Y_ad = @maybe_inferred frule((ZeroTangent(), ΔA), f, A)
else
TY = T∂Y = if T <: Real
Union{Symmetric{Complex{T}},Symmetric{T}}
Union{Symmetric{Complex{T}},Symmetric{T},Hermitian{Complex{T}},Hermitian{T}}
else
Union{Matrix{T},Hermitian{T}}
end
Y_ad, ∂Y_ad = @maybe_inferred Tuple{TY,T∂Y} frule((ZeroTangent(), ΔA), f, A)
end
@test Y_ad == Y
@test Y_ad Y
@test typeof(Y_ad) === typeof(Y)
hasproperty(Y, :uplo) && @test Y_ad.uplo == Y.uplo
@test ∂Y_ad isa typeof(Y)
Expand Down Expand Up @@ -382,13 +382,13 @@
Y_ad, back = @maybe_inferred rrule(f, A)
else
TY = if T <: Real
Union{Symmetric{Complex{T}},Symmetric{T}}
Union{Symmetric{Complex{T}},Symmetric{T},Hermitian{Complex{T}},Hermitian{T}}
else
Union{Matrix{T},Hermitian{T}}
end
Y_ad, back = @maybe_inferred Tuple{TY,Any} rrule(f, A)
end
@test Y_ad == Y
@test Y_ad Y
@test typeof(Y_ad) === typeof(Y)
hasproperty(Y, :uplo) && @test Y_ad.uplo == Y.uplo
∂self, ∂A = @maybe_inferred back(ΔY)
Expand Down
Loading