Skip to content

Commit 6550be9

Browse files
authored
Merge pull request #141 from numlinalg/v0.2-kaczmarz_issue
Fix: Kaczmarz + Sampling on Sparse Matrix
2 parents 8cb508d + 9510516 commit 6550be9

File tree

6 files changed

+451
-33
lines changed

6 files changed

+451
-33
lines changed

src/Compressors/sampling.jl

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,4 +215,44 @@ function mul!(
215215
mul!(C_sub, A, I, alpha, 1)
216216

217217
return nothing
218-
end
218+
end
219+
220+
###############################################################################
221+
# Binary Operator Compressor-Array Multiplications for sparse matrices/vectors
222+
###############################################################################
223+
# S * A
224+
function (*)(S::SamplingRecipe, A::Union{SparseMatrixCSC, SparseVector})
225+
s_rows = size(S, 1)
226+
a_cols = size(A, 2)
227+
C = a_cols == 1 ? spzeros(eltype(A), s_rows) : spzeros(eltype(A), s_rows, a_cols)
228+
mul!(C, S, A)
229+
return C
230+
end
231+
232+
# A * S
233+
function (*)(A::Union{SparseMatrixCSC, SparseVector}, S::SamplingRecipe)
234+
s_cols = size(S, 2)
235+
a_rows = size(A, 1)
236+
C = a_rows == 1 ? spzeros(eltype(A), s_cols)' : spzeros(eltype(A), a_rows, s_cols)
237+
mul!(C, A, S)
238+
return C
239+
end
240+
241+
# S' * A
242+
function (*)(S::CompressorAdjoint{<:SamplingRecipe}, A::Union{SparseMatrixCSC, SparseVector})
243+
s_rows = size(S, 1)
244+
a_cols = size(A, 2)
245+
C = a_cols == 1 ? spzeros(eltype(A), s_rows) : spzeros(eltype(A), s_rows, a_cols)
246+
mul!(C, S, A)
247+
return C
248+
end
249+
250+
# A * S'
251+
function (*)(A::Union{SparseMatrixCSC, SparseVector}, S::CompressorAdjoint{<:SamplingRecipe})
252+
s_cols = size(S, 2)
253+
a_rows = size(A, 1)
254+
C = a_rows == 1 ? spzeros(eltype(A), s_cols)' : spzeros(eltype(A), a_rows, s_cols)
255+
mul!(C, A, S)
256+
return C
257+
end
258+

src/RLinearAlgebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import LinearAlgebra: Adjoint, axpby!, axpy!, dot, I, ldiv!, lmul!, lq!
55
import LinearAlgebra: lq, LQ, lu!, mul!, norm, qr!, UpperTriangular, svd
66
import StatsBase: ProbabilityWeights, sample, sample!, wsample!
77
import Random: bitrand, rand!, randn!
8-
import SparseArrays: SparseMatrixCSC, sprandn, sparse
8+
import SparseArrays: SparseMatrixCSC, SparseVector, spzeros, sprandn, sparse
99

1010
# Include the files correspoding to the top-level techniques
1111
include("Compressors.jl")

src/Solvers/SubSolvers/lq.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ function ldiv!(
3939
)
4040
fill!(x, zero(eltype(b)))
4141
# this will modify B in place so you cannot use it again
42-
ldiv!(x, lq!(solver.A), b)
42+
# using qr here on the transpose of the matrix will work for sparse and dense matrices
43+
# while the lq would have only worked for dense matrices
44+
ldiv!(x, qr!(solver.A')', b)
4345
return nothing
4446
end

src/Solvers/kaczmarz.jl

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ Let ``A`` be an ``m \\times n`` matrix and consider the consistent linear system
1515
this interesection is to iteratively project some abritrary point, ``x`` from one
1616
hyperplane to the next, through
1717
``
18-
x_{+} = x + \\alpha \\frac{b_i - \\lange A_{i\\cdot}, x\\rangle}{\\| A_{i\\cdot}.
18+
x_{+} = x + \\alpha \\frac{b_i - A_{i\\cdot}^\\top x}{\\| A_{i\\cdot} \\|_2^2}
1919
``
20-
Doing this with random permutation of ``i`` can lead to a geometric convergence
20+
Doing this with random permutation of ``i`` can lead to a geometric convergence
2121
[strohmer2009randomized](@cite).
2222
Here ``\\alpha`` is viewed as an over-relaxation parameter and can improve convergence.
2323
One can also generalize this procedure to blocks by considering the ``S`` being a
@@ -158,8 +158,8 @@ mutable struct KaczmarzRecipe{
158158
alpha::Float64
159159
compressed_mat::M
160160
compressed_vec::V
161-
solution_vec::V
162-
update_vec::V
161+
solution_vec::Vector{T}
162+
update_vec::Vector{T}
163163
mat_view::MV
164164
vec_view::VV
165165
end
@@ -198,9 +198,22 @@ function complete_solver(
198198
# We assume the user is using compressors to only decrease dimension
199199
sample_size::Int64 = compressor.n_rows
200200
cols_a = size(A, 2)
201-
# Allocate the information in the buffer using the types of A and b
202-
compressed_mat = zeros(eltype(A), sample_size, cols_a)
203-
compressed_vec = zeros(eltype(b), sample_size)
201+
# because sampling recipe use subsets sparse matrices will be problematic
202+
if typeof(compressor) <: SamplingRecipe && typeof(A) <: SparseMatrixCSC
203+
compressed_mat = spzeros(eltype(A), sample_size, cols_a)
204+
else
205+
# Allocate the information in the buffer using the types of A
206+
compressed_mat = zeros(eltype(A), sample_size, cols_a)
207+
end
208+
209+
# because sampling recipe use subsets sparse vectors will be problematic
210+
if typeof(compressor) <: SamplingRecipe && typeof(b) <: SparseVector
211+
compressed_vec = spzeros(eltype(b), sample_size)
212+
else
213+
# Allocate the information in the buffer using the type of b
214+
compressed_vec = zeros(eltype(b), sample_size)
215+
end
216+
204217
# Since sub_solver is applied to compressed matrices use here
205218
sub_solver = complete_sub_solver(ingredients.sub_solver, compressed_mat, compressed_vec)
206219
mat_view = view(compressed_mat, 1:sample_size, :)
@@ -250,12 +263,15 @@ function kaczmarz_update!(solver::KaczmarzRecipe)
250263
# when the constant vector is a zero dimensional subArray we know that we should perform
251264
# the one dimension kaczmarz update
252265

253-
# Compute the projection scaling (bi - dot(ai,x)) / ||ai||^2
254-
scaling = solver.alpha * (dotu(solver.mat_view, solver.solution_vec)
255-
- solver.vec_view[1])
256-
scaling /= dot(solver.mat_view, solver.mat_view)
257-
# udpate the solution computes solution_vec = solution_vec - scaling * mat_view'
258-
axpby!(-scaling, solver.mat_view', 1.0, solver.solution_vec)
266+
# check that the row is non-zero
267+
if !iszero(solver.mat_view)
268+
# Compute the projection scaling (bi - dot(ai,x)) / ||ai||^2
269+
scaling = solver.alpha * (dotu(solver.mat_view, solver.solution_vec)
270+
- solver.vec_view[1])
271+
scaling /= dot(solver.mat_view, solver.mat_view)
272+
# udpate the solution computes solution_vec = solution_vec - scaling * mat_view'
273+
axpby!(-scaling, solver.mat_view', 1.0, solver.solution_vec)
274+
end
259275
return nothing
260276
end
261277

@@ -278,16 +294,19 @@ A function that performs the kaczmarz update when the compression dim is greater
278294
function kaczmarz_update_block!(solver::KaczmarzRecipe)
279295
# when the constant vector is a one dimensional subArray we know that we should perform
280296
# the one dimension kaczmarz update
281-
# sub-solver needs to designed for new compressed matrix
282-
update_sub_solver!(solver.sub_solver, solver.mat_view)
283-
# Compute the block residual
284-
# (computes solver.vec_view - solver.mat_view * solver.solution_vec)
285-
mul!(solver.vec_view, solver.mat_view, solver.solution_vec, -1.0, 1.0)
286-
# use sub-solver to find update the solution (solves min ||tilde A - tilde b|| and
287-
# stores in update_vec)
288-
ldiv!(solver.update_vec, solver.sub_solver, solver.vec_view)
289-
# computes solver.solution_vec = solver.solution_vec + alpha * solver.update_vec
290-
axpby!(solver.alpha, solver.update_vec, 1.0, solver.solution_vec)
297+
if !iszero(solver.mat_view)
298+
# sub-solver needs to designed for new compressed matrix
299+
update_sub_solver!(solver.sub_solver, solver.mat_view)
300+
# Compute the block residual
301+
# (computes solver.vec_view - solver.mat_view * solver.solution_vec)
302+
mul!(solver.vec_view, solver.mat_view, solver.solution_vec, -1.0, 1.0)
303+
# use sub-solver to find update the solution (solves min ||tilde A - tilde b|| and
304+
# stores in update_vec)
305+
ldiv!(solver.update_vec, solver.sub_solver, solver.vec_view)
306+
# computes solver.solution_vec = solver.solution_vec + alpha * solver.update_vec
307+
axpby!(solver.alpha, solver.update_vec, 1.0, solver.solution_vec)
308+
end
309+
291310
return nothing
292311
end
293312

test/Compressors/sampling.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module Sampling_compressor
22
using Test, RLinearAlgebra, Random
33
using StatsBase: ProbabilityWeights, sample
44
import LinearAlgebra: mul!, Adjoint
5+
import SparseArrays: sprandn
56
using ..FieldTest
67
using ..ApproxTol
78

@@ -499,8 +500,139 @@ Random.seed!(2131)
499500
mul!(x, S', yc, alpha, beta)
500501
@test x alpha * Sty_exact + beta * xc
501502
end
503+
504+
end
505+
506+
@testset "Left Cardinality Sparse" begin
507+
let a_matrix_rows = 20,
508+
a_matrix_cols = 12,
509+
comp_dim = 7,
510+
alpha = 2.5,
511+
beta = 1.5
512+
513+
# Setup matrices and vectors
514+
A = sprandn(a_matrix_rows, a_matrix_cols, .8)
515+
B = sprandn(comp_dim, a_matrix_cols, .8)
516+
C1 = sprandn(comp_dim, a_matrix_cols, .8)
517+
C2 = sprandn(a_matrix_rows, a_matrix_cols, .8)
518+
x = sprandn(a_matrix_rows, .8)
519+
y = sprandn(comp_dim, .8)
520+
521+
# Keep copies for 5-argument mul! verification
522+
C1c = deepcopy(C1)
523+
C2c = deepcopy(C2)
524+
yc = deepcopy(y)
525+
xc = deepcopy(x)
526+
527+
# Setup the Sampling compressor recipe
528+
S_info = Sampling(
529+
cardinality=Left(),
530+
compression_dim=comp_dim,
531+
distribution=Uniform(cardinality=Left(), replace=false)
532+
)
533+
S = complete_compressor(S_info, A)
534+
535+
# Calculate all ground truth results
536+
SA_exact = A[S.idx, :]
537+
StB_exact = zeros(a_matrix_rows, a_matrix_cols); for i in 1:comp_dim; StB_exact[S.idx[i], :] = B[i, :]; end
538+
Sx_exact = x[S.idx]
539+
Sty_exact = zeros(a_matrix_rows); for i in 1:comp_dim; Sty_exact[S.idx[i]] = y[i]; end
540+
541+
# Test '*' operations by comparing to ground truths
542+
@test S * A SA_exact
543+
@test S' * B StB_exact
544+
@test A' * S' SA_exact'
545+
@test B' * S StB_exact'
546+
@test S * x Sx_exact
547+
@test x' * S' Sx_exact'
548+
@test S' * y Sty_exact
549+
@test y' * S Sty_exact'
550+
551+
# Test the 5-argument mul!
552+
mul!(C1, S, A, alpha, beta)
553+
@test C1 alpha * SA_exact + beta * C1c
554+
555+
mul!(C2, S', B, alpha, beta)
556+
@test C2 alpha * StB_exact + beta * C2c
557+
558+
mul!(y, S, xc, alpha, beta)
559+
@test y alpha * Sx_exact + beta * yc
560+
561+
mul!(x, S', yc, alpha, beta)
562+
@test x alpha * Sty_exact + beta * xc
563+
end
564+
end
565+
566+
# Test multiplications with right compressors
567+
@testset "Right Cardinality" begin
568+
let n = 20,
569+
comp_dim = 10,
570+
alpha = 2.0,
571+
beta = 2.0
572+
573+
# Setup matrices and vectors with dimensions
574+
A = sprandn(n, comp_dim, .8)
575+
B = sprandn(n, n, .8)
576+
# C1 is for S'*A, C2 is for B*S
577+
C1 = sprandn(n, n, .8)
578+
C2 = sprandn(n, comp_dim, .8)
579+
x = sprandn(comp_dim, .8)
580+
y = sprandn(n, .8)
581+
582+
# Keep copies for 5-argument mul! verification
583+
C1c = deepcopy(C1)
584+
C2c = deepcopy(C2)
585+
yc = deepcopy(y)
586+
xc = deepcopy(x)
587+
588+
# Setup the Sampling compressor recipe. It's created from B, an n x n matrix.
589+
# The operator S will have conceptual dimensions (n x comp_dim).
590+
S_info = Sampling(
591+
cardinality=Right(),
592+
compression_dim=comp_dim,
593+
distribution=Uniform(cardinality=Right(), replace=false)
594+
)
595+
S = complete_compressor(S_info, B)
596+
597+
# Calculate all ground truth results based on direct indexing/operations
598+
StA_exact = A[S.idx, :]
599+
BS_exact = B[:, S.idx]
600+
BtS_exact = B'[:, S.idx]
601+
ASt_exact = zeros(n, n); for i in 1:comp_dim; ASt_exact[:, S.idx[i]] = A[:, i]; end
602+
Sx_exact = zeros(n); for i in 1:comp_dim; Sx_exact[S.idx[i]] = x[i]; end
603+
Sty_exact = y[S.idx]
604+
605+
# Test '*' operations by comparing to ground truths
606+
@test S' * A StA_exact
607+
@test A' * S StA_exact'
608+
@test B * S BS_exact
609+
@test S' * B' BS_exact'
610+
@test B' * S BtS_exact
611+
@test S' * B BtS_exact'
612+
@test A * S' ASt_exact
613+
@test S * A' ASt_exact'
614+
@test S * x Sx_exact
615+
@test x' * S' Sx_exact'
616+
@test y' * S Sty_exact'
617+
@test S' * y Sty_exact
618+
619+
# Test the 5-argument mul!
620+
mul!(C1, A, S', alpha, beta)
621+
@test C1 alpha * ASt_exact + beta * C1c
622+
623+
mul!(C2, B, S, alpha, beta)
624+
@test C2 alpha * BS_exact + beta * C2c
625+
626+
mul!(y, S, xc, alpha, beta)
627+
@test y alpha * Sx_exact + beta * yc
628+
629+
mul!(x, S', yc, alpha, beta)
630+
@test x alpha * Sty_exact + beta * xc
631+
end
632+
502633
end
503634
end
635+
504636
end
505637

506638
end

0 commit comments

Comments
 (0)