Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/StructuralEquationModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ const SEM = StructuralEquationModels
# type hierarchy
include("types.jl")
include("objective_gradient_hessian.jl")

# helper objects and functions
include("additional_functions/commutation_matrix.jl")

# fitted objects
include("frontend/fit/SemFit.jl")
# specification of models
Expand Down
75 changes: 75 additions & 0 deletions src/additional_functions/commutation_matrix.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""

transpose_linear_indices(n, [m])

Put each linear index of the *n×m* matrix to the position of the
corresponding element in the transposed matrix.

## Example
`
1 4
2 5 => 1 2 3
3 6 4 5 6
`
"""
transpose_linear_indices(n::Integer, m::Integer = n) =
repeat(1:n, inner = m) .+ repeat((0:(m-1)) * n, outer = n)

"""
CommutationMatrix(n::Integer) <: AbstractMatrix{Int}

A *commutation matrix* *C* is a n²×n² matrix of 0s and 1s.
If *vec(A)* is a vectorized form of a n×n matrix *A*,
then ``C * vec(A) = vec(Aᵀ)``.
"""
struct CommutationMatrix <: AbstractMatrix{Int}
n::Int
n²::Int
transpose_inds::Vector{Int} # maps the linear indices of n×n matrix *B* to the indices of matrix *B'*

CommutationMatrix(n::Integer) = new(n, n^2, transpose_linear_indices(n))
end

Base.size(A::CommutationMatrix) = (A.n², A.n²)
Base.size(A::CommutationMatrix, dim::Integer) =
1 <= dim <= 2 ? A.n² : throw(ArgumentError("invalid matrix dimension $dim"))
Base.length(A::CommutationMatrix) = A.n²^2
Base.getindex(A::CommutationMatrix, i::Int, j::Int) = j == A.transpose_inds[i] ? 1 : 0

function Base.:(*)(A::CommutationMatrix, B::AbstractVector)
size(A, 2) == size(B, 1) || throw(
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) elements"),
)
return B[A.transpose_inds]
end

function Base.:(*)(A::CommutationMatrix, B::AbstractMatrix)
size(A, 2) == size(B, 1) || throw(
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
)
return B[A.transpose_inds, :]
end

function Base.:(*)(A::CommutationMatrix, B::SparseMatrixCSC)
size(A, 2) == size(B, 1) || throw(
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
)
return SparseMatrixCSC(
size(B, 1),
size(B, 2),
copy(B.colptr),
A.transpose_inds[B.rowval],
copy(B.nzval),
)
end

function LinearAlgebra.lmul!(A::CommutationMatrix, B::SparseMatrixCSC)
size(A, 2) == size(B, 1) || throw(
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
)

@inbounds for (i, rowind) in enumerate(B.rowval)
B.rowval[i] = A.transpose_inds[rowind]
end
return B
end
155 changes: 23 additions & 132 deletions src/additional_functions/helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function get_observed(rowind, data, semobserved; args = (), kwargs = NamedTuple(
return observed_vec
end

skipmissing_mean(mat::AbstractMatrix) =
skipmissing_mean(mat::AbstractMatrix) =
[mean(skipmissing(coldata)) for coldata in eachcol(mat)]

function F_one_person(imp_mean, meandiff, inverse, data, logdet)
Expand Down Expand Up @@ -111,143 +111,34 @@ function cov_and_mean(rows; corrected = false)
return obs_cov, vec(obs_mean)
end

function duplication_matrix(nobs)
nobs = Int(nobs)
n1 = Int(nobs * (nobs + 1) * 0.5)
n2 = Int(nobs^2)
Dt = zeros(n1, n2)

for j in 1:nobs
for i in j:nobs
u = zeros(n1)
u[Int((j - 1) * nobs + i - 0.5 * j * (j - 1))] = 1
T = zeros(nobs, nobs)
T[j, i] = 1
T[i, j] = 1
Dt += u * transpose(vec(T))
# n²×(n(n+1)/2) matrix to transform a vector of lower
# triangular entries into a vectorized form of a n×n symmetric matrix,
# opposite of elimination_matrix()
function duplication_matrix(n::Integer)
ntri = div(n * (n + 1), 2)
D = zeros(n^2, ntri)
for j in 1:n
for i in j:n
tri_ix = (j - 1) * n + i - div(j * (j - 1), 2)
D[j+n*(i-1), tri_ix] = 1
D[i+n*(j-1), tri_ix] = 1
end
end
D = transpose(Dt)
return D
end

function elimination_matrix(nobs)
nobs = Int(nobs)
n1 = Int(nobs * (nobs + 1) * 0.5)
n2 = Int(nobs^2)
L = zeros(n1, n2)

for j in 1:nobs
for i in j:nobs
u = zeros(n1)
u[Int((j - 1) * nobs + i - 0.5 * j * (j - 1))] = 1
T = zeros(nobs, nobs)
T[i, j] = 1
L += u * transpose(vec(T))
# (n(n+1)/2)×n² matrix to transform a
# vectorized form of a n×n symmetric matrix
# into vector of its lower triangular entries,
# opposite of duplication_matrix()
function elimination_matrix(n::Integer)
ntri = div(n * (n + 1), 2)
L = zeros(ntri, n^2)
for j in 1:n
for i in j:n
tri_ix = (j - 1) * n + i - div(j * (j - 1), 2)
L[tri_ix, i+n*(j-1)] = 1
end
end
return L
end

function commutation_matrix(n; tosparse = false)
M = zeros(n^2, n^2)

for i in 1:n
for j in 1:n
M[i+n*(j-1), j+n*(i-1)] = 1.0
end
end

if tosparse
M = sparse(M)
end

return M
end

function commutation_matrix_pre_square(A)
n2 = size(A, 1)
n = Int(sqrt(n2))

ind = repeat(1:n, inner = n)
indadd = (0:(n-1)) * n
for i in 1:n
ind[((i-1)*n+1):i*n] .+= indadd
end

A_post = A[ind, :]

return A_post
end

function commutation_matrix_pre_square_add!(B, A) # comuptes B + KₙA
n2 = size(A, 1)
n = Int(sqrt(n2))

ind = repeat(1:n, inner = n)
indadd = (0:(n-1)) * n
for i in 1:n
ind[((i-1)*n+1):i*n] .+= indadd
end

@views @inbounds B .+= A[ind, :]

return B
end

function get_commutation_lookup(n2::Int64)
n = Int(sqrt(n2))
ind = repeat(1:n, inner = n)
indadd = (0:(n-1)) * n
for i in 1:n
ind[((i-1)*n+1):i*n] .+= indadd
end

lookup = Dict{Int64, Int64}()

for i in 1:n2
j = findall(x -> (x == i), ind)[1]
push!(lookup, i => j)
end

return lookup
end

function commutation_matrix_pre_square!(A::SparseMatrixCSC, lookup) # comuptes B + KₙA
for (i, rowind) in enumerate(A.rowval)
A.rowval[i] = lookup[rowind]
end
end

function commutation_matrix_pre_square!(A::SparseMatrixCSC) # computes KₙA
lookup = get_commutation_lookup(size(A, 2))
commutation_matrix_pre_square!(A, lookup)
end

function commutation_matrix_pre_square(A::SparseMatrixCSC)
B = copy(A)
commutation_matrix_pre_square!(B)
return B
end

function commutation_matrix_pre_square(A::SparseMatrixCSC, lookup)
B = copy(A)
commutation_matrix_pre_square!(B, lookup)
return B
end

function commutation_matrix_pre_square_add_mt!(B, A) # comuptes B + KₙA # 0 allocations but slower
n2 = size(A, 1)
n = Int(sqrt(n2))

indadd = (0:(n-1)) * n

Threads.@threads for i in 1:n
for j in 1:n
row = i + indadd[j]
@views @inbounds B[row, :] .+= A[row, :]
end
end

return B
end
13 changes: 5 additions & 8 deletions src/loss/ML/FIML.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Analytic gradients are available.
## Implementation
Subtype of `SemLossFunction`.
"""
mutable struct SemFIML{INV, C, L, O, M, IM, I, T, U, W} <: SemLossFunction
mutable struct SemFIML{INV, C, L, O, M, IM, I, T, W} <: SemLossFunction
inverses::INV #preallocated inverses of imp_cov
choleskys::C #preallocated choleskys
logdets::L #logdets of implied covmats
Expand All @@ -37,7 +37,7 @@ mutable struct SemFIML{INV, C, L, O, M, IM, I, T, U, W} <: SemLossFunction

mult::T

commutation_indices::U
commutator::CommutationMatrix

interaction::W
end
Expand All @@ -64,8 +64,6 @@ function SemFIML(; observed, specification, kwargs...)
∇ind =
[findall(x -> !(x[1] ∈ ind || x[2] ∈ ind), ∇ind) for ind in patterns_not(observed)]

commutation_indices = get_commutation_lookup(get_n_nodes(specification)^2)

return SemFIML(
inverses,
choleskys,
Expand All @@ -75,7 +73,7 @@ function SemFIML(; observed, specification, kwargs...)
meandiff,
imp_inv,
mult,
commutation_indices,
CommutationMatrix(get_n_nodes(specification)),
nothing,
)
end
Expand Down Expand Up @@ -163,10 +161,9 @@ function ∇F_fiml_outer(JΣ, Jμ, imply, model, semfiml)
Iₙ = sparse(1.0I, size(A(imply))...)
P = kron(F⨉I_A⁻¹(imply), F⨉I_A⁻¹(imply))
Q = kron(S(imply) * I_A⁻¹(imply)', Iₙ)
#commutation_matrix_pre_square_add!(Q, Q)
Q2 = commutation_matrix_pre_square(Q, semfiml.commutation_indices)
Q .+= semfiml.commutator * Q

∇Σ = P * (∇S(imply) + (Q + Q2) * ∇A(imply))
∇Σ = P * (∇S(imply) + Q * ∇A(imply))

∇μ =
F⨉I_A⁻¹(imply) * ∇M(imply) +
Expand Down
49 changes: 49 additions & 0 deletions test/unit_tests/matrix_helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using StructuralEquationModels, Test, Random, SparseArrays, LinearAlgebra
using StructuralEquationModels:
CommutationMatrix, transpose_linear_indices, duplication_matrix, elimination_matrix

Random.seed!(73721)

n = 4
m = 5

@testset "Commutation matrix" begin
# transpose linear indices
A = rand(n, m)
@test reshape(A[transpose_linear_indices(n, m)], m, n) == A'
# commutation matrix multiplication
K = CommutationMatrix(n)
# test K array interface methods
@test size(K) == (n^2, n^2)
@test size(K, 1) == n^2
@test length(K) == n^4
nn_linind = LinearIndices((n, n))
@test K[nn_linind[3, 2], nn_linind[2, 3]] == 1
@test K[nn_linind[3, 2], nn_linind[3, 2]] == 0

B = rand(n, n)
@test_throws DimensionMismatch K * rand(n, m)
@test K * vec(B) == vec(B')
C = sprand(n, n, 0.5)
@test K * vec(C) == vec(C')
# lmul!
D = sprand(n^2, n^2, 0.1)
E = copy(D)
F = Matrix(E)
lmul!(K, D)
@test D == K * E
@test Matrix(D) == K * F
end

@testset "Duplication / elimination matrix" begin
A = rand(m, m)
A = A * A'

# dupication
D = duplication_matrix(m)
@test D * A[tril(trues(size(A)))] == vec(A)

# elimination
E = elimination_matrix(m)
@test E * vec(A) == A[tril(trues(size(A)))]
end
4 changes: 4 additions & 0 deletions test/unit_tests/unit_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ end
@safetestset "SemObs" begin
include("data_input_formats.jl")
end

@safetestset "Matrix algebra helper functions" begin
include("matrix_helpers.jl")
end