From 4840dadcdbfd0eb875733a6924ca1811532cf35c Mon Sep 17 00:00:00 2001 From: Adomas Baliuka Date: Fri, 9 Jun 2023 23:03:12 +0200 Subject: [PATCH 1/2] Adds operator Base.:*(::ToeplitzFactorization, ::StridedVector) Tests contain some superfluous stuff for illustrating what I'm doing! DO not merge this! Should be changed and rebased. --- src/linearalgebra.jl | 66 +++++++++++++++++++++++++++++++++++++-- test/runtests.jl | 74 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 2 deletions(-) diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index 8a945d6..8aa3509 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -39,8 +39,64 @@ function mul!( return y end + + +# option 1 +# Requires a change to `mul!`. +# could also duplicate version mul! (with changed name or signature) that doesn't modify factorization +# this shortens the code overall but +# may have a small performance overhead due to the additional argument and the if statement in `mul!`. +# function Base.:*(A::ToeplitzFactorization, x::StridedVector) +# T = promote_type(eltype(A), eltype(x)) +# result = similar(x, T) +# return mul!(result, A, x, 1.0, 1.0; dont_modify_A=true) +# end + +# option 2 +# this makes the code longer (and have duplicates) but might be more efficient overall +# If going for this option, mul! is unchanged. +# However, it has a dirty hack to get output size, which might be a very bad idea... +# Maybe there is a better way? +function Base.:*(A::ToeplitzFactorization, x::StridedVector) + # adapted from + # mul!(::StridedVector, A::ToeplitzFactorization, x::StridedVector, α::Number, β::Number) + vcvr_dft = A.vcvr_dft + tmp = copy(A.tmp) # avoid changing (using `mul!`, `copyto!` and `*=` ) + dft = A.dft + + N = length(vcvr_dft) + n = length(x) + m = N - n + 1 # dirty hack to get output size. 
+ if m > N || n > N + throw(DimensionMismatch( + "Toeplitz factorization size incompatible (max $(N-1)) with input vector (size $n)" + )) + end + + T = Base.promote_eltype(A, x) + y = Vector{T}(undef, m) + + @inbounds begin + copyto!(tmp, 1, x, 1, n) + for i in (n+1):N + tmp[i] = zero(eltype(tmp)) + end + mul!(tmp, dft, tmp) + for i in eachindex(tmp) + tmp[i] *= vcvr_dft[i] + end + dft \ tmp + + for i in eachindex(y) + y[i] = maybereal(T, tmp[i]) + end + end + + return y +end + function mul!( - y::StridedVector, A::ToeplitzFactorization, x::StridedVector, α::Number, β::Number + y::StridedVector, A::ToeplitzFactorization, x::StridedVector, α::Number, β::Number #; dont_modify_A=false, ) n = length(x) m = length(y) @@ -53,7 +109,13 @@ function mul!( end T = Base.promote_eltype(y, A, x, α, β) - tmp = A.tmp + + #if dont_modify_A + # tmp = copy(A.tmp) # to make function thread-safe + #else + tmp = A.tmp + #end + dft = A.dft @inbounds begin copyto!(tmp, 1, x, 1, n) diff --git a/test/runtests.jl b/test/runtests.jl index b87e9da..8972917 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -553,3 +553,77 @@ end @test cholesky(T).U ≈ cholesky(Matrix(T)).U @test cholesky(T).L ≈ cholesky(Matrix(T)).L end + + +@testset "NEW TODO MODIFY" begin + function getdata(m=1760, n=4097) + + T = Toeplitz{Float64}( + [1.0; rand([0.0, 1.0], m-1)], + [1.0; rand([0.0, 1.0], n-1)], + ) + + x = rand(Float64[0, 1], n) + + return T, x + end + + @testset "adsf" begin + T, x = getdata() + + correct_result = T * x + + Tfac = factorize(T) + + result = zeros(Float64, 1760) + @test Tfac * x == T * x == correct_result + @test Tfac * x == correct_result # can(!) reuse factorization + @test Tfac * x == correct_result == mul!(result, Tfac, x, 1.0, 0.0) # modifies Tfac.tmp + @test correct_result == mul!(result, Tfac, x, 1.0, 0.0) # but we still can reuse factorization! 
+ + result = zeros(Float64, 1760) + + xcopy = copy(x) + + @test mul!(result, factorize(T), x, 1.0, 0.0) == correct_result + @test result == correct_result + @test x == xcopy + + @test mul!(result, factorize(T), x, 1.0, 0.0) == correct_result + # @test mul!(result, factorize(T), x, 1.0, 1.0) == correct_result + + end + + @testset "threaded mul!" begin + + T, x = getdata() + Tfac = factorize(T) + result = zeros(Float64, 1760) + + correct_result = T * x + + valid = Bool[] + Base.Threads.@threads for i = 1:100 + push!(valid, mul!(result, factorize(T), x, 1.0, 0.0) == correct_result) + end + @test all(valid) + end + + + @testset "threaded mul! (Broken!!!!)" begin + + T, x = getdata() + Tfac = factorize(T) + result = zeros(Float64, 1760) + + correct_result = T * x + + Tfac = factorize(T) + valid = Bool[] + Base.Threads.@threads for i = 1:100 + push!(valid, mul!(result, Tfac, x, 1.0, 0.0) == correct_result) + end + @test all(valid) broken=true + end + +end From e008b1d06f31ed1eaebab1cc64aba74806356e05 Mon Sep 17 00:00:00 2001 From: Adomas Baliuka Date: Fri, 9 Jun 2023 23:51:59 +0200 Subject: [PATCH 2/2] Comment out test that needs multithreading Because Github CI doesn't have multithreading --- test/runtests.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 8972917..09aa066 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -610,20 +610,20 @@ end end - @testset "threaded mul! (Broken!!!!)" begin + # @testset "threaded mul! 
(Broken!!!!)" begin - T, x = getdata() - Tfac = factorize(T) - result = zeros(Float64, 1760) - - correct_result = T * x - - Tfac = factorize(T) - valid = Bool[] - Base.Threads.@threads for i = 1:100 - push!(valid, mul!(result, Tfac, x, 1.0, 0.0) == correct_result) - end - @test all(valid) broken=true - end + # T, x = getdata() + # Tfac = factorize(T) + # result = zeros(Float64, 1760) + + # correct_result = T * x + + # Tfac = factorize(T) + # valid = Bool[] + # Base.Threads.@threads for i = 1:100 + # push!(valid, mul!(result, Tfac, x, 1.0, 0.0) == correct_result) + # end + # @test all(valid) broken=true + # end end