Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 150 additions & 24 deletions src/singularity_removal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,21 @@ end
level === nothing ? v : (v => level)
end

function structural_singularity_removal!(state::TransformationState;
variable_underconstrained! = force_var_to_zero!, kwargs...)
function structural_singularity_removal!(
state::TransformationState, ::Val{ReturnPivots} = Val{false}();
variable_underconstrained! = force_var_to_zero!, kwargs...
) where {ReturnPivots}
mm = linear_subsys_adjmat!(state; kwargs...)
if size(mm, 1) == 0
return mm # No linear subsystems
if ReturnPivots
return mm, PivotInfo(0, 0, Int[])
else
return mm # No linear subsystems
end
end

(; graph, var_to_diff, solvable_graph) = state.structure
mm = structural_singularity_removal!(state, mm; variable_underconstrained!)
mm, pivotinfo = structural_singularity_removal!(state, mm, Val{true}(); variable_underconstrained!)
s = state.structure
for (ei, e) in enumerate(mm.nzrows)
set_neighbors!(s.graph, e, mm.row_cols[ei])
Expand All @@ -34,7 +40,11 @@ function structural_singularity_removal!(state::TransformationState;
end
end

return mm
if ReturnPivots
return mm, pivotinfo
else
return mm
end
end

# For debug purposes
Expand All @@ -55,7 +65,7 @@ the `constraint`.
@inline function find_first_linear_variable(M::SparseMatrixCLIL,
range,
mask,
constraint)
constraint, ::Nothing = nothing)
eadj = M.row_cols
@inbounds for i in range
vertices = eadj[i]
Expand All @@ -70,10 +80,33 @@ the `constraint`.
return nothing
end

@inline function find_first_linear_variable(
M::SparseMatrixCLIL,
range,
mask,
constraint, var_priorities::AbstractVector{Int}
)
eadj = M.row_cols
@inbounds for i in range
vertices = eadj[i]
constraint(length(vertices)) || continue
candidate_v = 0
candidate_val = 0
for (j, v) in enumerate(vertices)
mask === nothing || mask[v] || continue
iszero(candidate_v) || var_priorities[v] < var_priorities[candidate_v] || continue
candidate_v = v
candidate_val = M.row_vals[i][j]
end
iszero(candidate_v) || return CartesianIndex(i, candidate_v), candidate_val
end
return nothing
end

@inline function find_first_linear_variable(M::AbstractMatrix,
range,
mask,
constraint)
constraint, ::Nothing = nothing)
@inbounds for i in range
row = @view M[i, :]
if constraint(count(!iszero, row))
Expand All @@ -87,12 +120,36 @@ end
return nothing
end

function find_masked_pivot(variables, M, k)
r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(1))
@inline function find_first_linear_variable(
M::AbstractMatrix,
range,
mask,
constraint, var_priorities::AbstractVector{Int}
)
@inbounds for i in range
row = @view M[i, :]
constraint(count(!iszero, row)) || continue
candidate_v = 0
candidate_val = 0
for (v, val) in enumerate(row)
mask === nothing || mask[v] || continue
if iszero(candidate_v) || var_priorities[v] < var_priorities[candidate_v]
candidate_v = v
candidate_val = val
end
end
iszero(candidate_v) && return nothing
return CartesianIndex(i, candidate_v), candidate_val
end
return nothing
end

function find_masked_pivot(variables, M, k, var_priorities)
r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(1), var_priorities)
r !== nothing && return r
r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(2))
r = find_first_linear_variable(M, k:size(M, 1), variables, isequal(2), var_priorities)
r !== nothing && return r
r = find_first_linear_variable(M, k:size(M, 1), variables, _ -> true)
r = find_first_linear_variable(M, k:size(M, 1), variables, _ -> true, var_priorities)
return r
end

Expand Down Expand Up @@ -207,14 +264,15 @@ function aag_bareiss!(structure, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti}
end
end
solvable_variables = findall(is_linear_variables)
var_priorities = has_state_priorities(structure) ? get_state_priorities(structure) : nothing

local bar
try
bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff)
bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff, var_priorities)
catch e
e isa OverflowError || rethrow(e)
mm = convert(SparseMatrixCLIL{BigInt, Ti}, mm_orig)
bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff)
bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff, var_priorities)
end

# This phrasing infers the return type as `Union{Tuple{...}}` instead of
Expand Down Expand Up @@ -243,6 +301,18 @@ end
(s::SyncedSwapRows{Nothing})(M, i::Int, j::Int) = Base.swaprows!(M, i, j)
(s::SyncedSwapRows)(M, i::Int, j::Int) = (Base.swaprows!(s.Mold, i, j); Base.swaprows!(M, i, j))

"""
$TYPEDEF

Lazy `&&` of two boolean masks. Only implements whatever is required for `find_masked_pivot`.
"""
struct LazyMaskAnd{V1 <: AbstractVector{Bool}, V2 <: AbstractVector{Bool}}
mask1::V1
mask2::V2
end

Base.getindex(lma::LazyMaskAnd, i::Integer) = lma.mask1[i] && lma.mask2[i]

"""
$(TYPEDEF)

Expand All @@ -253,12 +323,21 @@ Mutable state threaded through the Bareiss factorization callbacks.
- `pivots`: accumulates the column index of every pivot chosen during elimination.
- `is_linear_variables`/`is_highest_diff`: masks used for the tiered pivot search.
"""
mutable struct BareissContext{V1 <: AbstractVector{Bool}, V2 <: AbstractVector{Bool}}
mutable struct BareissContext{V1 <: AbstractVector{Bool}, V2 <: AbstractVector{Bool}, P <: Union{Nothing, AbstractVector{Int}}}
rank1::Union{Nothing, Int}
rank2::Union{Nothing, Int}
pivots::Vector{Int}
is_linear_variables::V1
is_highest_diff::V2
valid_pivot_mask::BitVector
var_priorities::P
end

function BareissContext(is_linear_variables, is_highest_diff, var_priorities = nothing)
return BareissContext(
nothing, nothing, Int[], is_linear_variables, is_highest_diff,
trues(length(is_linear_variables)), var_priorities
)
end

"""
Expand All @@ -273,15 +352,17 @@ The column index of every selected pivot is appended to `ctx.pivots`.
"""
function (ctx::BareissContext)(M, k::Int)
if ctx.rank1 === nothing
r = find_masked_pivot(ctx.is_linear_variables, M, k)
mask = LazyMaskAnd(ctx.is_linear_variables, ctx.valid_pivot_mask)
r = find_masked_pivot(ctx.is_linear_variables, M, k, ctx.var_priorities)
if r !== nothing
push!(ctx.pivots, r[1][2])
return r
end
ctx.rank1 = k - 1
end
if ctx.rank2 === nothing
r = find_masked_pivot(ctx.is_highest_diff, M, k)
mask = LazyMaskAnd(ctx.is_highest_diff, ctx.valid_pivot_mask)
r = find_masked_pivot(ctx.is_highest_diff, M, k, ctx.var_priorities)
if r !== nothing
push!(ctx.pivots, r[1][2])
return r
Expand All @@ -291,16 +372,28 @@ function (ctx::BareissContext)(M, k::Int)
# TODO: It would be better to sort the variables by
# derivative order here to enable more elimination
# opportunities.
r = find_masked_pivot(nothing, M, k)
r = find_masked_pivot(nothing, M, k, ctx.var_priorities)
r !== nothing && push!(ctx.pivots, r[1][2])
return r
end

struct BareissContextUpdate{C <: BareissContext, F}
context::C
inner_update::F
end

function (bcu::BareissContextUpdate)(zero!, M, k, swapto, pivot, last_pivot; kw...)
ctx = bcu.context
col = swapto[2]
ctx.valid_pivot_mask[col] = false
return bcu.inner_update(zero!, M, k, swapto, pivot, last_pivot; kw...)
end

function do_bareiss!(M, Mold, is_linear_variables::AbstractVector{Bool},
is_highest_diff::AbstractVector{Bool})
ctx = BareissContext(nothing, nothing, Int[], is_linear_variables, is_highest_diff)
bareiss_ops = (noop_colswap, SyncedSwapRows(Mold),
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
is_highest_diff::AbstractVector{Bool}, var_priorities = nothing)
ctx = BareissContext(is_linear_variables, is_highest_diff, var_priorities)
update! = BareissContextUpdate(ctx, bareiss_update_virtual_colswap_mtk!)
bareiss_ops = (noop_colswap, SyncedSwapRows(Mold), update!, bareiss_zero!)
rank3, = bareiss!(M, bareiss_ops; find_pivot = ctx)
rank2 = something(ctx.rank2, rank3)
rank1 = something(ctx.rank1, rank2)
Expand All @@ -321,8 +414,37 @@ function force_var_to_zero!(structure::SystemStructure, ils::SparseMatrixCLIL, v
return ils
end

function structural_singularity_removal!(state::TransformationState, ils::SparseMatrixCLIL;
variable_underconstrained! = force_var_to_zero!)
"""
$TYPEDSIGNATURES

Information about the pivots chosen by Bareiss during `structural_singularity_removal!`.
This can be returned from `structural_singularity_removal!` by passing `Val(true)` as the last
positional argument.

$TYPEDFIELDS
"""
struct PivotInfo
"""
The length of the prefix of `pivots` that is variables which _only_ occur in linear
equations of the sort considered by this pass. These variables must be solved for
using the integer coefficient equations considered by this pass.
"""
n_linear_vars::Int
"""
Number of elements in `pivots` after `n_linear_vars` corresponding to highest order
derivative variables.
"""
n_highest_diff_vars::Int
"""
The list of pivots chosen by the Bareiss algorithm.
"""
pivots::Vector{Int}
end

function structural_singularity_removal!(
state::TransformationState, ils::SparseMatrixCLIL, ::Val{ReturnPivots} = Val{false}();
variable_underconstrained! = force_var_to_zero!
) where {ReturnPivots}
(; structure) = state
(; graph, solvable_graph, var_to_diff, eq_to_diff) = state.structure
# Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
Expand All @@ -337,5 +459,9 @@ function structural_singularity_removal!(state::TransformationState, ils::Sparse
ils = variable_underconstrained!(structure, ils, v)
end

return ils
if ReturnPivots
return ils, PivotInfo(rank1, rank2, pivots)
else
return ils
end
end
Loading