|
| 1 | +""" |
| 2 | + IHS <: Solver |
| 3 | +
|
| 4 | +An implementation of the Iterative Hessian Sketch solver for solving over determined |
| 5 | +least squares problems [pilanci2014iterative](@cite). |
| 6 | + |
| 7 | +# Mathematical Description |
| 8 | +Let ``A \\in \\mathbb{R}^{m \\times n}, m \\gg n,`` and consider the least square problem |
| 9 | +``\\min_x \\|Ax - b \\|_2^2``. If we let ``S \\in \\mathbb{R}^{s \\times m}`` be a |
| 10 | +compression matrix, then Iterative Hessian Sketch iteratively finds a solution to this |
| 11 | +problem by repeatedly updating ``x_{k+1} = x_k + \\alpha u_k``where ``u_k`` is the solution |
| 12 | +to the convex optimization problem, |
| 13 | +``u_k \\in \\argmin_u \\{\\|S_k Au\\|_2^2 - \\langle A, b - Ax_k \\rangle \\}.`` This method |
| 14 | +has been to shown to converge geometrically at a rate ``\\rho \\in (0, 1/2]``. Typically the |
| 15 | +required compression dimension needs to be 4-8 times the size of n for the algorithm to |
| 16 | +perform successfully. |
| 17 | +
|
| 18 | +# Fields |
| 19 | +- `alpha::Float64`, a step size parameter. |
| 20 | +- `compressor::Compressor`, a technique for forming the compressed linear system. |
| 21 | +- `log::Logger`, a technique for logging the progress of the solver. |
| 22 | +- `error::SolverError`, a method for estimating the progress of the solver. |
| 23 | +
|
| 24 | +# Constructor |
| 25 | + function IHS(; |
| 26 | + compressor::Compressor = SparseSign(cardinality = Left()), |
| 27 | + log::Logger = BasicLogger(), |
| 28 | + error::SolverError = FullResidual(), |
| 29 | + alpha::Float64 = 1.0 |
| 30 | + ) |
| 31 | +## Keywords |
| 32 | +- `compressor::Compressor`, a technique for forming the compressed linear system. |
| 33 | +- `log::Logger`, a technique for logging the progress of the solver. |
| 34 | +- `error::SolverError`, a method for estimating the progress of the solver. |
| 35 | +- `alpha::Float64`, a step size parameter. |
| 36 | +
|
| 37 | +# Returns |
| 38 | +- A `IHS` object. |
| 39 | +""" |
| 40 | +mutable struct IHS <: Solver |
| 41 | + alpha::Float64 |
| 42 | + log::Logger |
| 43 | + compressor::Compressor |
| 44 | + error::SolverError |
| 45 | + function IHS(alpha, log, compressor, error) |
| 46 | + if typeof(compressor.cardinality) != Left |
| 47 | + @warn "Compressor has cardinality `Right` but IHS compresses from the `Left`." |
| 48 | + end |
| 49 | + |
| 50 | + if alpha < 0 |
| 51 | + @warn "Negative step size could lead to divergent iterates." |
| 52 | + end |
| 53 | + |
| 54 | + new(alpha, log, compressor, error) |
| 55 | + end |
| 56 | + |
| 57 | +end |
| 58 | + |
| 59 | +function IHS(; |
| 60 | + compressor::Compressor = SparseSign(cardinality = Left()), |
| 61 | + log::Logger = BasicLogger(), |
| 62 | + error::SolverError = FullResidual(), |
| 63 | + alpha::Float64 = 1.0 |
| 64 | +) |
| 65 | + return IHS( |
| 66 | + alpha, |
| 67 | + log, |
| 68 | + compressor, |
| 69 | + error |
| 70 | + ) |
| 71 | +end |
| 72 | + |
| 73 | +""" |
| 74 | + IHSRecipe{ |
| 75 | + Type<:Number, |
| 76 | + LR<:LoggerRecipe, |
| 77 | + CR<:CompressorRecipe, |
| 78 | + ER<:ErrorRecipe, |
| 79 | + M<:AbstractArray, |
| 80 | + MV<:SubArray, |
| 81 | + V<:AbstractVector |
| 82 | + } <: SolverRecipe |
| 83 | +
|
| 84 | +A mutable structure containing all information relevant to the Iterative Hessian Sketch |
| 85 | +solver. It is formed by calling the function `complete_solver` on a `IHS` solver, which |
| 86 | +includes all the user controlled parameters, the linear system `A`, and the constant |
| 87 | +vector `b`. |
| 88 | +
|
| 89 | +# Fields |
| 90 | +- `log::LoggerRecipe`, a technique for logging the progress of the solver. |
| 91 | +- `compressor::CompressorRecipe`, a technique for compressing the matrix ``A``. |
| 92 | +- `error::SolverErrorRecipe`, a technique for estimating the progress of the solver. |
| 93 | +- `alpha::Float64`, a step size parameter, by default is set to 1. |
| 94 | +- `compressed_mat::AbstractMatrix`, a buffer for storing the compressed matrix. |
| 95 | +- `mat_view::SubArray`, a container for storing a view of the compressed matrix buffer. |
| 96 | +- `residual_vec::AbstractVector`, a vector that contains the residual of the linear system |
| 97 | + ``Ax-b``. |
| 98 | +- `gradient_vec::AbstractVector`, a vector that contains the gradient of the least squares |
| 99 | + problem, ``A^\\top(b-Ax)``. |
| 100 | +- `buffer_vec::AbstractVector`, a buffer vector for storing intermediate linear system solves. |
| 101 | +- `solution_vec::AbstractVector`, a vector storing the current IHS solution. |
| 102 | +- `R::UpperTriangular`, a container for storing the upper triangular portion of the R |
| 103 | + factor from a QR factorization of `mat_view`. This is used to solve the IHS sub-problem. |
| 104 | +""" |
| 105 | +mutable struct IHSRecipe{ |
| 106 | + Type<:Number, |
| 107 | + LR<:LoggerRecipe, |
| 108 | + CR<:CompressorRecipe, |
| 109 | + ER<:SolverErrorRecipe, |
| 110 | + M<:AbstractArray, |
| 111 | + MV<:SubArray, |
| 112 | + V<:AbstractVector |
| 113 | +} <: SolverRecipe |
| 114 | + log::LR |
| 115 | + compressor::CR |
| 116 | + error::ER |
| 117 | + alpha::Float64 |
| 118 | + compressed_mat::M |
| 119 | + mat_view::MV |
| 120 | + residual_vec::V |
| 121 | + gradient_vec::V |
| 122 | + buffer_vec::V |
| 123 | + solution_vec::V |
| 124 | + R::UpperTriangular{Type, M} |
| 125 | +end |
| 126 | + |
| 127 | +function complete_solver( |
| 128 | + ingredients::IHS, |
| 129 | + x::AbstractVector, |
| 130 | + A::AbstractMatrix, |
| 131 | + b::AbstractVector |
| 132 | +) |
| 133 | + compressor = complete_compressor(ingredients.compressor, x, A, b) |
| 134 | + logger = complete_logger(ingredients.log) |
| 135 | + error = complete_error(ingredients.error, ingredients, A, b) |
| 136 | + sample_size::Int64 = compressor.n_rows |
| 137 | + rows_a, cols_a = size(A) |
| 138 | + # Check that required fields are in the types |
| 139 | + if !isdefined(error, :residual) |
| 140 | + throw( |
| 141 | + ArgumentError( |
| 142 | + "ErrorRecipe $(typeof(error)) does not contain the \ |
| 143 | + field 'residual' and is not valid for an IHS solver." |
| 144 | + ) |
| 145 | + ) |
| 146 | + end |
| 147 | + |
| 148 | + if !isdefined(logger, :converged) |
| 149 | + throw( |
| 150 | + ArgumentError( |
| 151 | + "LoggerRecipe $(typeof(logger)) does not contain \ |
| 152 | + the field 'converged' and is not valid for an IHS solver." |
| 153 | + ) |
| 154 | + ) |
| 155 | + end |
| 156 | + |
| 157 | + # Check that the sketch size is larger than the column dimension and return a warning |
| 158 | + # otherwise |
| 159 | + if cols_a > sample_size |
| 160 | + throw( |
| 161 | + ArgumentError( |
| 162 | + "Compression dimension not larger than column dimension this will lead to \ |
| 163 | + singular QR decompositions, which cannot be inverted." |
| 164 | + ) |
| 165 | + ) |
| 166 | + end |
| 167 | + |
| 168 | + if rows_a < sample_size |
| 169 | + throw( |
| 170 | + ArgumentError( |
| 171 | + "Compression dimension larger row dimension." |
| 172 | + ) |
| 173 | + ) |
| 174 | + end |
| 175 | + |
| 176 | + if cols_a >= rows_a |
| 177 | + throw( |
| 178 | + ArgumentError( |
| 179 | + "Matrix must have more rows than columns." |
| 180 | + ) |
| 181 | + ) |
| 182 | + end |
| 183 | + compressed_mat = zeros(eltype(A), sample_size, cols_a) |
| 184 | + res = zeros(eltype(A), rows_a) |
| 185 | + grad = zeros(eltype(A), cols_a) |
| 186 | + buffer_vec = zeros(eltype(A), cols_a) |
| 187 | + solution_vec = x |
| 188 | + mat_view = view(compressed_mat, 1:sample_size, :) |
| 189 | + R = UpperTriangular(mat_view[1:cols_a, :]) |
| 190 | + |
| 191 | + return IHSRecipe{ |
| 192 | + eltype(compressed_mat), |
| 193 | + typeof(logger), |
| 194 | + typeof(compressor), |
| 195 | + typeof(error), |
| 196 | + typeof(compressed_mat), |
| 197 | + typeof(mat_view), |
| 198 | + typeof(buffer_vec) |
| 199 | + }( |
| 200 | + logger, |
| 201 | + compressor, |
| 202 | + error, |
| 203 | + ingredients.alpha, |
| 204 | + compressed_mat, |
| 205 | + mat_view, |
| 206 | + res, |
| 207 | + grad, |
| 208 | + buffer_vec, |
| 209 | + solution_vec, |
| 210 | + R |
| 211 | + ) |
| 212 | +end |
| 213 | + |
| 214 | +function rsolve!(solver::IHSRecipe, x::AbstractVector, A::AbstractMatrix, b::AbstractVector) |
| 215 | + reset_logger!(solver.log) |
| 216 | + solver.solution_vec = x |
| 217 | + err = 0.0 |
| 218 | + copyto!(solver.residual_vec, b) |
| 219 | + # compute the initial residual r = b - Ax |
| 220 | + mul!(solver.residual_vec, A, solver.solution_vec, -1.0, 1.0) |
| 221 | + for i in 1:solver.log.max_it |
| 222 | + # compute the gradient A'r |
| 223 | + mul!(solver.gradient_vec, A', solver.residual_vec) |
| 224 | + err = compute_error(solver.error, solver, A, b) |
| 225 | + update_logger!(solver.log, err, i) |
| 226 | + if solver.log.converged |
| 227 | + return nothing |
| 228 | + end |
| 229 | + |
| 230 | + # generate a new compressor |
| 231 | + update_compressor!(solver.compressor, x, A, b) |
| 232 | + # Based on the size of the compressor update views of the matrix |
| 233 | + rows_s, cols_s = size(solver.compressor) |
| 234 | + solver.mat_view = view(solver.compressed_mat, 1:rows_s, :) |
| 235 | + # Compress the matrix |
| 236 | + mul!(solver.mat_view, solver.compressor, A) |
| 237 | + # Update the subsolver |
| 238 | + # This is the only piece of allocating code |
| 239 | + solver.R = UpperTriangular(qr!(solver.mat_view).R) |
| 240 | + # Compute first R' solver R'R x = g |
| 241 | + ldiv!(solver.buffer_vec, solver.R', solver.gradient_vec) |
| 242 | + # Compute second R Solve Rx = (R')^(-1)g will be stored in gradient_vec |
| 243 | + ldiv!(solver.gradient_vec, solver.R, solver.buffer_vec) |
| 244 | + # update the solution |
| 245 | + # solver.solution_vec = solver.solution_vec + alpha * solver.gradient_vec |
| 246 | + axpy!(solver.alpha, solver.gradient_vec, solver.solution_vec) |
| 247 | + # compute the fast update of r = r - A * gradient_vec |
| 248 | + # note: in this case gradient vec stores the update |
| 249 | + mul!(solver.residual_vec, A, solver.gradient_vec, -1.0, 1.0) |
| 250 | + end |
| 251 | + |
| 252 | + return nothing |
| 253 | + |
| 254 | +end |
0 commit comments