diff --git a/src/additional_functions/simulation.jl b/src/additional_functions/simulation.jl index da3e6a90..a787516b 100644 --- a/src/additional_functions/simulation.jl +++ b/src/additional_functions/simulation.jl @@ -3,6 +3,8 @@ (2) replace_observed(model::AbstractSemSingle, observed; kwargs...) + (3) replace_observed(model::SemEnsemble; column = :group, weights = nothing, kwargs...) + Return a new model with swaped observed part. # Arguments @@ -10,6 +12,14 @@ Return a new model with swaped observed part. - `kwargs`: additional keyword arguments; typically includes `data` and `specification` - `observed`: Either an object of subtype of `SemObserved` or a subtype of `SemObserved` +# For SemEnsemble models: +- `column`: if a DataFrame is passed as `data = ...`, which column signifies the group? +- `weights`: how to weight the different sub-models, + defaults to number of samples per group in the new data +- `kwargs`: has to be a dict with keys equal to the group names. + For `data` can also be a DataFrame with `column` containing the group information, + and for `specification` can also be an `EnsembleParameterTable`. + # Examples See the online documentation on [Replace observed data](@ref). """ @@ -37,51 +47,28 @@ function update_observed end replace_observed(model::AbstractSemSingle; kwargs...) = replace_observed(model, typeof(observed(model)).name.wrapper; kwargs...) -# construct a new observed type -replace_observed(model::AbstractSemSingle, observed_type; kwargs...) = - replace_observed(model, observed_type(; kwargs...); kwargs...) - -replace_observed(model::AbstractSemSingle, new_observed::SemObserved; kwargs...) = - replace_observed( - model, - observed(model), - implied(model), - loss(model), - new_observed; - kwargs..., - ) - -function replace_observed( - model::AbstractSemSingle, - old_observed, - implied, - loss, - new_observed::SemObserved; - kwargs..., -) +function replace_observed(model::AbstractSemSingle, observed_type; kwargs...) + new_observed = observed_type(;kwargs...) kwargs = Dict{Symbol, Any}(kwargs...) # get field types kwargs[:observed_type] = typeof(new_observed) - kwargs[:old_observed_type] = typeof(old_observed) - kwargs[:implied_type] = typeof(implied) - kwargs[:loss_types] = [typeof(lossfun) for lossfun in loss.functions] + kwargs[:old_observed_type] = typeof(model.observed) + kwargs[:implied_type] = typeof(model.implied) + kwargs[:loss_types] = [typeof(lossfun) for lossfun in model.loss.functions] # update implied - implied = update_observed(implied, new_observed; kwargs...) - kwargs[:implied] = implied - kwargs[:nparams] = nparams(implied) + new_implied = update_observed(model.implied, new_observed; kwargs...) + kwargs[:implied] = new_implied + kwargs[:nparams] = nparams(new_implied) # update loss - loss = update_observed(loss, new_observed; kwargs...) - kwargs[:loss] = loss - - #new_implied = update_observed(model.implied, new_observed; kwargs...) + new_loss = update_observed(model.loss, new_observed; kwargs...) return Sem( new_observed, - update_observed(model.implied, new_observed; kwargs...), - update_observed(model.loss, new_observed; kwargs...), + new_implied, + new_loss ) end @@ -92,6 +79,39 @@ function update_observed(loss::SemLoss, new_observed; kwargs...) return SemLoss(new_functions, loss.weights) end + +function replace_observed( + emodel::SemEnsemble; + column = :group, + weights = nothing, + kwargs..., +) + kwargs = Dict{Symbol, Any}(kwargs...) + # allow for EnsembleParameterTable to be passed as specification + if haskey(kwargs, :specification) && isa(kwargs[:specification], EnsembleParameterTable) + kwargs[:specification] = convert(Dict{Symbol, RAMMatrices}, kwargs[:specification]) + end + # allow for DataFrame with group variable "column" to be passed as new data + if haskey(kwargs, :data) && isa(kwargs[:data], DataFrame) + kwargs[:data] = Dict( + group => select( + filter( + r -> r[column] == group, + kwargs[:data]), + Not(column)) for group in emodel.groups) + end + # update each model for new data + models = emodel.sems + new_models = Tuple( + replace_observed(m; group_kwargs(g, kwargs)...) for (m, g) in zip(models, emodel.groups) + ) + return SemEnsemble(new_models...; weights = weights, groups = emodel.groups) +end + +function group_kwargs(g, kwargs) + return Dict(k => kwargs[k][g] for k in keys(kwargs)) +end + ############################################################################################ # simulate data ############################################################################################ diff --git a/src/frontend/fit/standard_errors/bootstrap.jl b/src/frontend/fit/standard_errors/bootstrap.jl index 4589dc02..4b3e302b 100644 --- a/src/frontend/fit/standard_errors/bootstrap.jl +++ b/src/frontend/fit/standard_errors/bootstrap.jl @@ -2,7 +2,6 @@ se_bootstrap(sem_fit::SemFit; n_boot = 3000, data = nothing, kwargs...) Return boorstrap standard errors. -Only works for single models. # Arguments - `n_boot`: number of boostrap samples @@ -10,19 +9,12 @@ Only works for single models. - `kwargs...`: passed down to `replace_observed` """ function se_bootstrap( - semfit::SemFit; + semfit::SemFit{Mi, So, St, Mo, O}; n_boot = 3000, data = nothing, specification = nothing, kwargs..., -) - if model(semfit) isa AbstractSemCollection - throw( - ArgumentError( - "bootstrap standard errors for ensemble models are not available yet", - ), - ) - end + ) where {Mi, So, St, Mo <: AbstractSemSingle, O} if isnothing(data) data = samples(observed(model(semfit))) @@ -69,6 +61,62 @@ function se_bootstrap( return sd end +function se_bootstrap( + semfit::SemFit{Mi, So, St, Mo, O}; + n_boot = 3000, + data = nothing, + specification = nothing, + kwargs..., + ) where {Mi, So, St, Mo <: SemEnsemble, O} + + models = semfit.model.sems + groups = semfit.model.groups + + if isnothing(data) + data = Dict(g => samples(observed(m)) for (g, m) in zip(groups, models)) + end + + data = Dict(k => prepare_data_bootstrap(data[k]) for k in keys(data)) + + start = solution(semfit) + + new_solution = zero(start) + sum = zero(start) + squared_sum = zero(start) + + n_failed = 0.0 + + converged = true + + for _ in 1:n_boot + sample_data = Dict(k => bootstrap_sample(data[k]) for k in keys(data)) + new_model = replace_observed( + semfit.model; + data = sample_data, + specification = specification, + kwargs..., + ) + + new_solution .= 0.0 + + try + new_solution = solution(fit(new_model; start_val = start)) + catch + n_failed += 1 + end + + @. sum += new_solution + @. squared_sum += new_solution^2 + + converged = true + end + + n_conv = n_boot - n_failed + sd = sqrt.(squared_sum / n_conv - (sum / n_conv) .^ 2) + print("Number of nonconverged models: ", n_failed, "\n") + return sd +end + function prepare_data_bootstrap(data) return Matrix(data) end diff --git a/src/types.jl b/src/types.jl index 44d472eb..660c1c43 100644 --- a/src/types.jl +++ b/src/types.jl @@ -168,7 +168,7 @@ end # ensemble models ############################################################################################ """ - (1) SemEnsemble(models...; weights = nothing, kwargs...) + (1) SemEnsemble(models...; weights = nothing, groups = nothing, kwargs...) (2) SemEnsemble(;specification, data, groups, column = :group, kwargs...) @@ -192,24 +192,24 @@ Returns a SemEnsemble with fields For instructions on multigroup models, see the online documentation. """ -struct SemEnsemble{N, T <: Tuple, V <: AbstractVector, I} <: AbstractSemCollection +struct SemEnsemble{N, T <: Tuple, V <: AbstractVector, I, G <: Vector{Symbol}} <: AbstractSemCollection n::N sems::T weights::V param_labels::I + groups::G end # constructor from multiple models -function SemEnsemble(models...; weights = nothing, kwargs...) +function SemEnsemble(models...; weights = nothing, groups = nothing, kwargs...) n = length(models) - # default weights - if isnothing(weights) nsamples_total = sum(nsamples, models) weights = [nsamples(model) / nsamples_total for model in models] end - + # default group labels + groups = isnothing(groups) ? Symbol.(:g, 1:n) : groups # check parameters equality param_labels = SEM.param_labels(models[1]) for model in models @@ -220,7 +220,7 @@ function SemEnsemble(models...; weights = nothing, kwargs...) end end - return SemEnsemble(n, models, weights, param_labels) + return SemEnsemble(n, models, weights, param_labels, groups) end # constructor from EnsembleParameterTable and data set @@ -238,7 +238,7 @@ function SemEnsemble(; specification, data, groups, column = :group, kwargs...) model = Sem(; specification = ram_matrices, data = data_group, kwargs...) push!(models, model) end - return SemEnsemble(models...; weights = nothing, kwargs...) + return SemEnsemble(models...; weights = nothing, groups = groups, kwargs...) end param_labels(ensemble::SemEnsemble) = ensemble.param_labels diff --git a/test/examples/multigroup/build_models.jl b/test/examples/multigroup/build_models.jl index f6a7a230..f5ea0b5d 100644 --- a/test/examples/multigroup/build_models.jl +++ b/test/examples/multigroup/build_models.jl @@ -20,6 +20,13 @@ model_ml_multigroup2 = SemEnsemble( loss = SemML, ) +model_ml_multigroup3 = replace_observed( + model_ml_multigroup2, + column = :school, + specification = partable, + data = dat, +) + # gradients @testset "ml_gradients_multigroup" begin test_gradient(model_ml_multigroup, start_test; atol = 1e-9) @@ -46,6 +53,12 @@ end ) end +@testset "replace_observed_multigroup" begin + sem_fit_1 = fit(semoptimizer, model_ml_multigroup) + sem_fit_2 = fit(semoptimizer, model_ml_multigroup3) + @test sem_fit_1.solution ≈ sem_fit_2.solution +end + @testset "fitmeasures/se_ml" begin solution_ml = fit(model_ml_multigroup) test_fitmeasures(