Skip to content

Commit 8cb508d

Browse files
authored
Merge pull request #119 from numlinalg/v0.2-ihs
Solver: Iterative Hessian Sketch
2 parents 4063e02 + 96b0da0 commit 8cb508d

File tree

6 files changed

+798
-7
lines changed

6 files changed

+798
-7
lines changed

docs/src/api/solvers.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ SolverRecipe
1212

1313
## Solver Structures
1414
```@docs
15+
IHS
16+
17+
IHSRecipe
18+
1519
Kaczmarz
1620
1721
KaczmarzRecipe

docs/src/refs.bib

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,29 @@ @misc{martinsson2020randomized
3434
year = {2020},
3535
publisher = {arXiv},
3636
doi = {10.48550/ARXIV.2002.01387},
37-
copyright = {arXiv.org perpetual, non-exclusive license},
38-
keywords = {FOS: Mathematics,Numerical Analysis (math.NA)}
37+
}
38+
39+
@article{needell2014paved,
40+
title = {Paved with Good Intentions: {{Analysis}} of a Randomized Block {{Kaczmarz}} Method},
41+
shorttitle = {Paved with Good Intentions},
42+
author = {Needell, Deanna and Tropp, Joel A.},
43+
year = {2014},
44+
month = jan,
45+
journal = {Linear Algebra and its Applications},
46+
volume = {441},
47+
pages = {199--221},
48+
issn = {00243795},
49+
doi = {10.1016/j.laa.2012.12.022},
50+
langid = {english},
51+
}
52+
53+
@article{pilanci2014iterative,
54+
title = {Iterative {{Hessian}} Sketch: {{Fast}} and Accurate Solution Approximation for Constrained Least-Squares},
55+
shorttitle = {Iterative {{Hessian}} Sketch},
56+
author = {Pilanci, Mert and Wainwright, Martin J.},
57+
year = {2016},
58+
Journal = {Journal of Machine Learning Research},
59+
volume = {17}
3960
}
4061

4162
@article{motzkin1954relaxation,
@@ -128,10 +149,9 @@ @article{strohmer2009randomized
128149
pages = {262--278},
129150
issn = {1069-5869, 1531-5851},
130151
doi = {10.1007/s00041-008-9030-4},
131-
copyright = {http://www.springer.com/tdm},
132-
langid = {english}
133152
}
134153

154+
135155
@article{tropp2011improved,
136156
title = {{{Improved Analysis of the Subsampled Randomized Hadamard Transform}}},
137157
author = {Tropp, Joel A.},

src/RLinearAlgebra.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
module RLinearAlgebra
22
import Base.:*
33
import Base: transpose, adjoint, setproperty!
4-
import LinearAlgebra: Adjoint, axpby!, dot, I, ldiv!, lmul!, lq!, lq, LQ, lu!
5-
import LinearAlgebra: mul!, norm, qr!, svd
6-
import StatsBase: sample, sample!, ProbabilityWeights, wsample!
4+
import LinearAlgebra: Adjoint, axpby!, axpy!, dot, I, ldiv!, lmul!, lq!
5+
import LinearAlgebra: lq, LQ, lu!, mul!, norm, qr!, UpperTriangular, svd
6+
import StatsBase: ProbabilityWeights, sample, sample!, wsample!
77
import Random: bitrand, rand!, randn!
88
import SparseArrays: SparseMatrixCSC, sprandn, sparse
99

@@ -41,6 +41,7 @@ export Uniform, UniformRecipe
4141
export Solver, SolverRecipe
4242
export Kaczmarz, KaczmarzRecipe
4343
export complete_solver, update_solver!, rsolve!
44+
export IHS, IHSRecipe
4445

4546
# Export Logger types and functions
4647
export Logger, LoggerRecipe

src/Solvers.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ include("Solvers/ErrorMethods.jl")
150150
#############################
151151
# The Solver Routine Files
152152
############################
153+
include("Solvers/ihs.jl")
153154
include("Solvers/kaczmarz.jl")
154155
############################
155156
# Helper functions

src/Solvers/ihs.jl

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
"""
2+
IHS <: Solver
3+
4+
An implementation of the Iterative Hessian Sketch solver for solving overdetermined
5+
least squares problems [pilanci2014iterative](@cite).
6+
7+
# Mathematical Description
8+
Let ``A \\in \\mathbb{R}^{m \\times n}, m \\gg n,`` and consider the least squares problem
9+
``\\min_x \\|Ax - b \\|_2^2``. If we let ``S \\in \\mathbb{R}^{s \\times m}`` be a
10+
compression matrix, then Iterative Hessian Sketch iteratively finds a solution to this
11+
problem by repeatedly updating ``x_{k+1} = x_k + \\alpha u_k`` where ``u_k`` is the solution
12+
to the convex optimization problem,
13+
``u_k \\in \\argmin_u \\{\\frac{1}{2}\\|S_k Au\\|_2^2 - \\langle u, A^\\top (b - Ax_k) \\rangle \\}.`` This method
14+
has been shown to converge geometrically at a rate ``\\rho \\in (0, 1/2]``. Typically the
15+
required compression dimension needs to be 4-8 times the size of n for the algorithm to
16+
perform successfully.
17+
18+
# Fields
19+
- `alpha::Float64`, a step size parameter.
20+
- `compressor::Compressor`, a technique for forming the compressed linear system.
21+
- `log::Logger`, a technique for logging the progress of the solver.
22+
- `error::SolverError`, a method for estimating the progress of the solver.
23+
24+
# Constructor
25+
function IHS(;
26+
compressor::Compressor = SparseSign(cardinality = Left()),
27+
log::Logger = BasicLogger(),
28+
error::SolverError = FullResidual(),
29+
alpha::Float64 = 1.0
30+
)
31+
## Keywords
32+
- `compressor::Compressor`, a technique for forming the compressed linear system.
33+
- `log::Logger`, a technique for logging the progress of the solver.
34+
- `error::SolverError`, a method for estimating the progress of the solver.
35+
- `alpha::Float64`, a step size parameter.
36+
37+
# Returns
38+
- A `IHS` object.
39+
"""
40+
mutable struct IHS <: Solver
    # Step size applied to each update direction.
    alpha::Float64
    # Technique for logging the progress of the solver.
    log::Logger
    # Technique for forming the compressed linear system.
    compressor::Compressor
    # Method for estimating the progress of the solver.
    error::SolverError

    # Inner constructor: stores the arguments as given. It only emits warnings (it never
    # throws) when the compressor's cardinality is not exactly `Left` or when the step
    # size is negative.
    function IHS(alpha, log, compressor, error)
        typeof(compressor.cardinality) == Left ||
            @warn "Compressor has cardinality `Right` but IHS compresses from the `Left`."

        alpha < 0 && @warn "Negative step size could lead to divergent iterates."

        return new(alpha, log, compressor, error)
    end
end
58+
59+
# Keyword constructor: every ingredient has a default, and the values are forwarded to the
# positional inner constructor in its (alpha, log, compressor, error) order.
function IHS(;
    compressor::Compressor = SparseSign(cardinality = Left()),
    log::Logger = BasicLogger(),
    error::SolverError = FullResidual(),
    alpha::Float64 = 1.0
)
    return IHS(alpha, log, compressor, error)
end
72+
73+
"""
74+
IHSRecipe{
75+
Type<:Number,
76+
LR<:LoggerRecipe,
77+
CR<:CompressorRecipe,
78+
ER<:SolverErrorRecipe,
79+
M<:AbstractArray,
80+
MV<:SubArray,
81+
V<:AbstractVector
82+
} <: SolverRecipe
83+
84+
A mutable structure containing all information relevant to the Iterative Hessian Sketch
85+
solver. It is formed by calling the function `complete_solver` on a `IHS` solver, which
86+
includes all the user controlled parameters, the linear system `A`, and the constant
87+
vector `b`.
88+
89+
# Fields
90+
- `log::LoggerRecipe`, a technique for logging the progress of the solver.
91+
- `compressor::CompressorRecipe`, a technique for compressing the matrix ``A``.
92+
- `error::SolverErrorRecipe`, a technique for estimating the progress of the solver.
93+
- `alpha::Float64`, a step size parameter, by default is set to 1.
94+
- `compressed_mat::AbstractMatrix`, a buffer for storing the compressed matrix.
95+
- `mat_view::SubArray`, a container for storing a view of the compressed matrix buffer.
96+
- `residual_vec::AbstractVector`, a vector that contains the residual of the linear system
97+
``b-Ax``.
98+
- `gradient_vec::AbstractVector`, a vector that contains the gradient of the least squares
99+
problem, ``A^\\top(b-Ax)``.
100+
- `buffer_vec::AbstractVector`, a buffer vector for storing intermediate linear system solves.
101+
- `solution_vec::AbstractVector`, a vector storing the current IHS solution.
102+
- `R::UpperTriangular`, a container for storing the upper triangular portion of the R
103+
factor from a QR factorization of `mat_view`. This is used to solve the IHS sub-problem.
104+
"""
105+
mutable struct IHSRecipe{
    Type<:Number,
    LR<:LoggerRecipe,
    CR<:CompressorRecipe,
    ER<:SolverErrorRecipe,
    M<:AbstractArray,
    MV<:SubArray,
    V<:AbstractVector
} <: SolverRecipe
    # --- completed ingredients -------------------------------------------------------
    log::LR
    compressor::CR
    error::ER
    # --- user parameter --------------------------------------------------------------
    # Step size applied to each update direction.
    alpha::Float64
    # --- sketch storage --------------------------------------------------------------
    # Buffer that receives the compressed matrix.
    compressed_mat::M
    # View into `compressed_mat`; `rsolve!` re-creates it every iteration.
    mat_view::MV
    # --- work vectors (all share the vector type `V`) --------------------------------
    # Residual b - A x of the current iterate.
    residual_vec::V
    # Gradient A'(b - A x); `rsolve!` overwrites it with the update direction.
    gradient_vec::V
    # Scratch space for the intermediate triangular solve.
    buffer_vec::V
    # Current IHS solution.
    solution_vec::V
    # --- sub-problem factor ----------------------------------------------------------
    # Upper triangular R factor of the QR factorization of `mat_view`.
    R::UpperTriangular{Type, M}
end
126+
127+
# Build an `IHSRecipe` from the `IHS` ingredients and the linear system (A, b).
#
# Steps: complete the compressor, logger and error ingredients; validate that the completed
# recipes expose the fields the IHS loop reads; validate the problem and sketch dimensions;
# then allocate every buffer used by `rsolve!`.
#
# Throws `ArgumentError` when
#   - the error recipe has no `residual` field or the logger recipe has no `converged` field,
#   - the sketch has fewer rows than `A` has columns,
#   - the sketch has more rows than `A`,
#   - `A` does not have strictly more rows than columns.
function complete_solver(
    ingredients::IHS,
    x::AbstractVector,
    A::AbstractMatrix,
    b::AbstractVector
)
    compressor = complete_compressor(ingredients.compressor, x, A, b)
    logger = complete_logger(ingredients.log)
    error = complete_error(ingredients.error, ingredients, A, b)
    sketch_rows::Int64 = compressor.n_rows
    n_rows, n_cols = size(A)
    T = eltype(A)

    # The completed recipes must carry the fields the IHS solver relies on.
    isdefined(error, :residual) || throw(
        ArgumentError(
            "ErrorRecipe $(typeof(error)) does not contain the \
            field 'residual' and is not valid for an IHS solver."
        )
    )

    isdefined(logger, :converged) || throw(
        ArgumentError(
            "LoggerRecipe $(typeof(logger)) does not contain \
            the field 'converged' and is not valid for an IHS solver."
        )
    )

    # A sketch with fewer rows than A has columns is rejected with an error (not a warning).
    n_cols > sketch_rows && throw(
        ArgumentError(
            "Compression dimension not larger than column dimension this will lead to \
            singular QR decompositions, which cannot be inverted."
        )
    )

    # The sketch may not have more rows than A itself.
    n_rows < sketch_rows && throw(
        ArgumentError(
            "Compression dimension larger row dimension."
        )
    )

    # A must have strictly more rows than columns.
    n_cols >= n_rows && throw(
        ArgumentError(
            "Matrix must have more rows than columns."
        )
    )

    # Work buffers, all with the element type of A.
    compressed_mat = zeros(T, sketch_rows, n_cols)
    residual_vec = zeros(T, n_rows)
    gradient_vec = zeros(T, n_cols)
    buffer_vec = zeros(T, n_cols)
    mat_view = view(compressed_mat, 1:sketch_rows, :)
    # Placeholder factor; indexing the view copies its leading n_cols rows.
    R = UpperTriangular(mat_view[1:n_cols, :])

    return IHSRecipe{
        T,
        typeof(logger),
        typeof(compressor),
        typeof(error),
        typeof(compressed_mat),
        typeof(mat_view),
        typeof(buffer_vec)
    }(
        logger,
        compressor,
        error,
        ingredients.alpha,
        compressed_mat,
        mat_view,
        residual_vec,
        gradient_vec,
        buffer_vec,
        x,
        R
    )
end
213+
214+
# Run the Iterative Hessian Sketch iteration on the least squares problem min ‖Ax - b‖²,
# updating `x` in place through `solver.solution_vec`.
#
# Each iteration:
#   1. forms the gradient g = A' r from the cached residual r = b - A x,
#   2. evaluates the error recipe and logs it; returns as soon as the logger reports
#      convergence,
#   3. draws a new compressor, sketches A, and takes the R factor of the sketch's QR,
#   4. solves R'R u = g with two triangular solves,
#   5. updates x ← x + alpha * u and r ← r - alpha * A u.
#
# Returns `nothing`; the result is read from `x` and the logger.
function rsolve!(
    solver::IHSRecipe,
    x::AbstractVector,
    A::AbstractMatrix,
    b::AbstractVector
)
    reset_logger!(solver.log)
    solver.solution_vec = x
    # compute the initial residual r = b - Ax
    copyto!(solver.residual_vec, b)
    mul!(solver.residual_vec, A, solver.solution_vec, -1.0, 1.0)
    for i in 1:solver.log.max_it
        # compute the gradient A'r
        mul!(solver.gradient_vec, A', solver.residual_vec)
        err = compute_error(solver.error, solver, A, b)
        update_logger!(solver.log, err, i)
        if solver.log.converged
            return nothing
        end

        # generate a new compressor
        update_compressor!(solver.compressor, x, A, b)
        # Based on the size of the compressor update views of the matrix
        rows_s, _ = size(solver.compressor)
        solver.mat_view = view(solver.compressed_mat, 1:rows_s, :)
        # Compress the matrix
        mul!(solver.mat_view, solver.compressor, A)
        # Update the subsolver
        # This is the only piece of allocating code
        solver.R = UpperTriangular(qr!(solver.mat_view).R)
        # First triangular solve: R' y = g, with y stored in buffer_vec
        ldiv!(solver.buffer_vec, solver.R', solver.gradient_vec)
        # Second triangular solve: R u = y; the direction u overwrites gradient_vec
        ldiv!(solver.gradient_vec, solver.R, solver.buffer_vec)
        # update the solution: x = x + alpha * u
        axpy!(solver.alpha, solver.gradient_vec, solver.solution_vec)
        # Fast residual update r = r - alpha * A * u. The step size must be applied here
        # as well; otherwise the cached residual (and hence the next gradient) no longer
        # equals b - A x whenever alpha != 1.
        mul!(solver.residual_vec, A, solver.gradient_vec, -solver.alpha, 1.0)
    end

    return nothing
end

0 commit comments

Comments
 (0)