|
| 1 | +abstract type MatrixScalarizer end |
| 2 | + |
| 3 | +struct DeterminantScalarizer <: MatrixScalarizer end |
| 4 | +(::DeterminantScalarizer)(mat) = det(mat) |
| 5 | + |
| 6 | +struct TraceScalarizer <: MatrixScalarizer end |
| 7 | +(::TraceScalarizer)(mat) = tr(mat) |
| 8 | + |
| 9 | +abstract type StateCriteria end |
| 10 | +abstract type ItemCriteria end |
| 11 | + |
| 12 | +struct AbilityCovarianceStateCriteria{ |
| 13 | + DistEstT <: DistributionAbilityEstimator, |
| 14 | + IntegratorT <: AbilityIntegrator |
| 15 | +} <: StateCriteria |
| 16 | + dist_est::DistEstT |
| 17 | + integrator::IntegratorT |
| 18 | + skip_zero::Bool |
| 19 | +end |
| 20 | + |
| 21 | +function AbilityCovarianceStateCriteria(bits...) |
| 22 | + skip_zero = false |
| 23 | + @requiresome (dist_est, integrator) = _get_dist_est_and_integrator(bits...) |
| 24 | + return AbilityCovarianceStateCriteria(dist_est, integrator, skip_zero) |
| 25 | +end |
| 26 | + |
| 27 | +# XXX: Should be at type level |
| 28 | +should_minimize(::AbilityCovarianceStateCriteria) = true |
| 29 | + |
| 30 | +function (criteria::AbilityCovarianceStateCriteria)( |
| 31 | + tracked_responses::TrackedResponses, |
| 32 | + denom = normdenom(criteria.integrator, |
| 33 | + criteria.dist_est, |
| 34 | + tracked_responses) |
| 35 | +) |
| 36 | + if denom == 0.0 && criteria.skip_zero |
| 37 | + return Inf |
| 38 | + end |
| 39 | + covariance_matrix( |
| 40 | + criteria.integrator, |
| 41 | + criteria.dist_est, |
| 42 | + tracked_responses, |
| 43 | + denom |
| 44 | + ) |
| 45 | +end |
| 46 | + |
| 47 | +struct ScalarizedStateCriteron{ |
| 48 | + StateCriteriaT <: StateCriteria, |
| 49 | + MatrixScalarizerT <: MatrixScalarizer |
| 50 | +} <: StateCriterion |
| 51 | + criteria::StateCriteriaT |
| 52 | + scalarizer::MatrixScalarizerT |
| 53 | +end |
| 54 | + |
| 55 | +function (ssc::ScalarizedStateCriteron)(tracked_responses) |
| 56 | + res = ssc.criteria(tracked_responses) |> ssc.scalarizer |
| 57 | + if !should_minimize(ssc.criteria) |
| 58 | + res = -res |
| 59 | + end |
| 60 | + res |
| 61 | +end |
| 62 | + |
| 63 | +struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <: ItemCriteria |
| 64 | + ability_estimator::AbilityEstimatorT |
| 65 | + expected_item_information::F |
| 66 | +end |
| 67 | + |
| 68 | +function InformationMatrixCriteria(ability_estimator) |
| 69 | + InformationMatrixCriteria(ability_estimator, expected_item_information) |
| 70 | +end |
| 71 | + |
| 72 | +function init_thread(item_criterion::InformationMatrixCriteria, |
| 73 | + responses::TrackedResponses) |
| 74 | + # TODO: No need to do this one per thread. It just need to be done once per |
| 75 | + # θ update. |
| 76 | + # TODO: Update this to use track!(...) mechanism |
| 77 | + ability = maybe_tracked_ability_estimate(responses, item_criterion.ability_estimator) |
| 78 | + responses_information(responses.item_bank, responses.responses, ability) |
| 79 | +end |
| 80 | + |
| 81 | +function (item_criterion::InformationMatrixCriteria)(acc_info::Matrix{Float64}, |
| 82 | + tracked_responses::TrackedResponses, |
| 83 | + item_idx) |
| 84 | + # TODO: Add in information from the prior |
| 85 | + ability = maybe_tracked_ability_estimate( |
| 86 | + tracked_responses, item_criterion.ability_estimator) |
| 87 | + return acc_info .+ |
| 88 | + item_criterion.expected_item_information( |
| 89 | + ItemResponse(tracked_responses.item_bank, item_idx), ability) |
| 90 | +end |
| 91 | + |
| 92 | +should_minimize(::InformationMatrixCriteria) = false |
| 93 | + |
| 94 | +struct ScalarizedItemCriteron{ |
| 95 | + ItemCriteriaT <: ItemCriteria, |
| 96 | + MatrixScalarizerT <: MatrixScalarizer |
| 97 | +} <: ItemCriterion |
| 98 | + criteria::ItemCriteriaT |
| 99 | + scalarizer::MatrixScalarizerT |
| 100 | +end |
| 101 | + |
| 102 | +function (ssc::ScalarizedItemCriteron)(tracked_responses, item_idx) |
| 103 | + res = ssc.criteria( |
| 104 | + init_thread(ssc.criteria, tracked_responses), tracked_responses, item_idx) |> |
| 105 | + ssc.scalarizer |
| 106 | + if !should_minimize(ssc.criteria) |
| 107 | + res = -res |
| 108 | + end |
| 109 | + res |
| 110 | +end |
| 111 | + |
| 112 | +struct WeightedStateCriteria{InnerT <: StateCriteria} <: StateCriteria |
| 113 | + weights::Vector{Float64} |
| 114 | + criteria::InnerT |
| 115 | +end |
| 116 | + |
| 117 | +function (wsc::WeightedStateCriteria)(tracked_responses, item_idx) |
| 118 | + wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights |
| 119 | +end |
| 120 | + |
| 121 | +struct WeightedItemCriteria{InnerT <: ItemCriteria} <: ItemCriteria |
| 122 | + weights::Vector{Float64} |
| 123 | + criteria::InnerT |
| 124 | +end |
| 125 | + |
| 126 | +function (wsc::WeightedItemCriteria)(tracked_responses, item_idx) |
| 127 | + wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights |
| 128 | +end |
0 commit comments