Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 54 additions & 34 deletions src/additional_functions/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,23 @@

(2) replace_observed(model::AbstractSemSingle, observed; kwargs...)

(3) replace_observed(model::SemEnsemble; column = :group, weights = nothing, kwargs...)

Return a new model with swaped observed part.

# Arguments
- `model::AbstractSemSingle`: model to swap the observed part of.
- `kwargs`: additional keyword arguments; typically includes `data` and `specification`
- `observed`: Either an object of subtype of `SemObserved` or a subtype of `SemObserved`

# For SemEnsemble models:
- `column`: if a DataFrame is passed as `data = ...`, which column signifies the group?
- `weights`: how to weight the different sub-models,
defaults to number of samples per group in the new data
- `kwargs`: has to be a dict with keys equal to the group names.
For `data` can also be a DataFrame with `column` containing the group information,
and for `specification` can also be an `EnsembleParameterTable`.

# Examples
See the online documentation on [Replace observed data](@ref).
"""
Expand Down Expand Up @@ -37,51 +47,28 @@ function update_observed end
replace_observed(model::AbstractSemSingle; kwargs...) =
replace_observed(model, typeof(observed(model)).name.wrapper; kwargs...)

# construct a new observed type
replace_observed(model::AbstractSemSingle, observed_type; kwargs...) =
replace_observed(model, observed_type(; kwargs...); kwargs...)

replace_observed(model::AbstractSemSingle, new_observed::SemObserved; kwargs...) =
replace_observed(
model,
observed(model),
implied(model),
loss(model),
new_observed;
kwargs...,
)

function replace_observed(
model::AbstractSemSingle,
old_observed,
implied,
loss,
new_observed::SemObserved;
kwargs...,
)
function replace_observed(model::AbstractSemSingle, observed_type; kwargs...)
new_observed = observed_type(;kwargs...)
kwargs = Dict{Symbol, Any}(kwargs...)

# get field types
kwargs[:observed_type] = typeof(new_observed)
kwargs[:old_observed_type] = typeof(old_observed)
kwargs[:implied_type] = typeof(implied)
kwargs[:loss_types] = [typeof(lossfun) for lossfun in loss.functions]
kwargs[:old_observed_type] = typeof(model.observed)
kwargs[:implied_type] = typeof(model.implied)
kwargs[:loss_types] = [typeof(lossfun) for lossfun in model.loss.functions]

# update implied
implied = update_observed(implied, new_observed; kwargs...)
kwargs[:implied] = implied
kwargs[:nparams] = nparams(implied)
new_implied = update_observed(model.implied, new_observed; kwargs...)
kwargs[:implied] = new_implied
kwargs[:nparams] = nparams(new_implied)

# update loss
loss = update_observed(loss, new_observed; kwargs...)
kwargs[:loss] = loss

#new_implied = update_observed(model.implied, new_observed; kwargs...)
new_loss = update_observed(model.loss, new_observed; kwargs...)

return Sem(
new_observed,
update_observed(model.implied, new_observed; kwargs...),
update_observed(model.loss, new_observed; kwargs...),
new_implied,
new_loss
)
end

Expand All @@ -92,6 +79,39 @@ function update_observed(loss::SemLoss, new_observed; kwargs...)
return SemLoss(new_functions, loss.weights)
end


function replace_observed(
emodel::SemEnsemble;
column = :group,
weights = nothing,
kwargs...,
)
kwargs = Dict{Symbol, Any}(kwargs...)
# allow for EnsembleParameterTable to be passed as specification
if haskey(kwargs, :specification) && isa(kwargs[:specification], EnsembleParameterTable)
kwargs[:specification] = convert(Dict{Symbol, RAMMatrices}, kwargs[:specification])
end
# allow for DataFrame with group variable "column" to be passed as new data
if haskey(kwargs, :data) && isa(kwargs[:data], DataFrame)
kwargs[:data] = Dict(
group => select(
filter(
r -> r[column] == group,
kwargs[:data]),
Not(column)) for group in emodel.groups)
end
# update each model for new data
models = emodel.sems
new_models = Tuple(
replace_observed(m; group_kwargs(g, kwargs)...) for (m, g) in zip(models, emodel.groups)
)
return SemEnsemble(new_models...; weights = weights, groups = emodel.groups)
end

function group_kwargs(g, kwargs)
return Dict(k => kwargs[k][g] for k in keys(kwargs))
end

############################################################################################
# simulate data
############################################################################################
Expand Down
68 changes: 58 additions & 10 deletions src/frontend/fit/standard_errors/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,19 @@
se_bootstrap(sem_fit::SemFit; n_boot = 3000, data = nothing, kwargs...)

Return boorstrap standard errors.
Only works for single models.

# Arguments
- `n_boot`: number of boostrap samples
- `data`: data to sample from. Only needed if different than the data from `sem_fit`
- `kwargs...`: passed down to `replace_observed`
"""
function se_bootstrap(
semfit::SemFit;
semfit::SemFit{Mi, So, St, Mo, O};
n_boot = 3000,
data = nothing,
specification = nothing,
kwargs...,
)
if model(semfit) isa AbstractSemCollection
throw(
ArgumentError(
"bootstrap standard errors for ensemble models are not available yet",
),
)
end
) where {Mi, So, St, Mo <: AbstractSemSingle, O}

if isnothing(data)
data = samples(observed(model(semfit)))
Expand Down Expand Up @@ -69,6 +61,62 @@ function se_bootstrap(
return sd
end

function se_bootstrap(
semfit::SemFit{Mi, So, St, Mo, O};
n_boot = 3000,
data = nothing,
specification = nothing,
kwargs...,
) where {Mi, So, St, Mo <: SemEnsemble, O}

models = semfit.model.sems
groups = semfit.model.groups

if isnothing(data)
data = Dict(g => samples(observed(m)) for (g, m) in zip(groups, models))
end

data = Dict(k => prepare_data_bootstrap(data[k]) for k in keys(data))

start = solution(semfit)

new_solution = zero(start)
sum = zero(start)
squared_sum = zero(start)

n_failed = 0.0

converged = true

for _ in 1:n_boot
sample_data = Dict(k => bootstrap_sample(data[k]) for k in keys(data))
new_model = replace_observed(
semfit.model;
data = sample_data,
specification = specification,
kwargs...,
)

new_solution .= 0.0

try
new_solution = solution(fit(new_model; start_val = start))
catch
n_failed += 1
end

@. sum += new_solution
@. squared_sum += new_solution^2

converged = true
end

n_conv = n_boot - n_failed
sd = sqrt.(squared_sum / n_conv - (sum / n_conv) .^ 2)
print("Number of nonconverged models: ", n_failed, "\n")
return sd
end

function prepare_data_bootstrap(data)
return Matrix(data)
end
Expand Down
16 changes: 8 additions & 8 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ end
# ensemble models
############################################################################################
"""
(1) SemEnsemble(models...; weights = nothing, kwargs...)
(1) SemEnsemble(models...; weights = nothing, groups = nothing, kwargs...)

(2) SemEnsemble(;specification, data, groups, column = :group, kwargs...)

Expand All @@ -192,24 +192,24 @@ Returns a SemEnsemble with fields

For instructions on multigroup models, see the online documentation.
"""
struct SemEnsemble{N, T <: Tuple, V <: AbstractVector, I} <: AbstractSemCollection
struct SemEnsemble{N, T <: Tuple, V <: AbstractVector, I, G <: Vector{Symbol}} <: AbstractSemCollection
n::N
sems::T
weights::V
param_labels::I
groups::G
end

# constructor from multiple models
function SemEnsemble(models...; weights = nothing, kwargs...)
function SemEnsemble(models...; weights = nothing, groups = nothing, kwargs...)
n = length(models)

# default weights

if isnothing(weights)
nsamples_total = sum(nsamples, models)
weights = [nsamples(model) / nsamples_total for model in models]
end

# default group labels
groups = isnothing(groups) ? Symbol.(:g, 1:n) : groups
# check parameters equality
param_labels = SEM.param_labels(models[1])
for model in models
Expand All @@ -220,7 +220,7 @@ function SemEnsemble(models...; weights = nothing, kwargs...)
end
end

return SemEnsemble(n, models, weights, param_labels)
return SemEnsemble(n, models, weights, param_labels, groups)
end

# constructor from EnsembleParameterTable and data set
Expand All @@ -238,7 +238,7 @@ function SemEnsemble(; specification, data, groups, column = :group, kwargs...)
model = Sem(; specification = ram_matrices, data = data_group, kwargs...)
push!(models, model)
end
return SemEnsemble(models...; weights = nothing, kwargs...)
return SemEnsemble(models...; weights = nothing, groups = groups, kwargs...)
end

param_labels(ensemble::SemEnsemble) = ensemble.param_labels
Expand Down
13 changes: 13 additions & 0 deletions test/examples/multigroup/build_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ model_ml_multigroup2 = SemEnsemble(
loss = SemML,
)

model_ml_multigroup3 = replace_observed(
model_ml_multigroup2,
column = :school,
specification = partable,
data = dat,
)

# gradients
@testset "ml_gradients_multigroup" begin
test_gradient(model_ml_multigroup, start_test; atol = 1e-9)
Expand All @@ -46,6 +53,12 @@ end
)
end

@testset "replace_observed_multigroup" begin
sem_fit_1 = fit(semoptimizer, model_ml_multigroup)
sem_fit_2 = fit(semoptimizer, model_ml_multigroup3)
@test sem_fit_1.solution ≈ sem_fit_2.solution
end

@testset "fitmeasures/se_ml" begin
solution_ml = fit(model_ml_multigroup)
test_fitmeasures(
Expand Down
Loading