Skip to content

Commit 1d4b1f6

Browse files
author
Frankie Robertson
committed
Add TrackedLikelihoodIntegrator which uses GriddedAbilityTracker
1 parent 2e28f93 commit 1d4b1f6

File tree

7 files changed

+113
-24
lines changed

7 files changed

+113
-24
lines changed

src/CatConfig.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,34 @@ function _find_ability_estimator_and_tracker(bits...)
6565
(ability_estimator, ability_tracker)
6666
end
6767

68+
function collect_trackers(_)
69+
return NullAbilityTracker()
70+
end
71+
72+
function collect_trackers(tracker::AbilityTracker)
73+
return tracker
74+
end
75+
76+
function collect_trackers(config::CatConfigBase)
77+
acc = NullAbilityTracker()
78+
for fieldname in fieldnames(typeof(config))
79+
tracker = collect_trackers(getfield(config, fieldname))
80+
if !(tracker isa NullAbilityTracker)
81+
acc = ConsAbilityTracker(tracker, acc)
82+
end
83+
end
84+
return acc
85+
end
86+
87+
function collect_trackers(next_item_rule::NextItemRule, ability_tracker::AbilityTracker)
88+
rest = collect_trackers(next_item_rule)
89+
if !(ability_tracker isa NullAbilityTracker)
90+
ConsAbilityTracker(ability_tracker, rest)
91+
else
92+
rest
93+
end
94+
end
95+
6896
function CatRules(bits...)
6997
ability_estimator, ability_tracker = _find_ability_estimator_and_tracker(bits...)
7098
if ability_estimator === nothing
@@ -73,7 +101,7 @@ function CatRules(bits...)
73101
if ability_tracker === nothing
74102
error("Could not find an ability tracker in $(bits)")
75103
end
76-
next_item = NextItemRule(bits..., ability_estimator=ability_estimator)
104+
next_item = NextItemRule(bits..., ability_estimator=ability_estimator, ability_tracker=ability_tracker)
77105
if next_item === nothing
78106
error("Could not find a next item rule in $(bits)")
79107
end
@@ -85,7 +113,7 @@ function CatRules(bits...)
85113
next_item=next_item,
86114
termination_condition=termination_condition,
87115
ability_estimator=ability_estimator,
88-
ability_tracker=ability_tracker
116+
ability_tracker=collect_trackers(next_item, ability_tracker)
89117
)
90118
end
91119

src/aggregators/Aggregators.jl

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ export variance, variance_given_mean, mean_1d
4040
# XXX: Does having a common supertype of DistributionAbilityEstimator and PointAbilityEstimator make sense?
4141
abstract type AbilityEstimator <: CatConfigBase end
4242

43-
function AbilityEstimator(bits...; ability_estimator=nothing)
43+
function AbilityEstimator(bits...; ability_estimator=nothing, ability_tracker=nothing)
4444
@returnsome ability_estimator
4545
@returnsome find1_instance(AbilityEstimator, bits)
4646
item_bank = find1_type_sloppy(AbstractItemBank, bits)
@@ -73,24 +73,51 @@ end
7373

7474
abstract type AbilityTracker <: CatConfigBase end
7575

76-
function AbilityTracker(bits...; ability_estimator=nothing)
76+
function AbilityTracker(bits...; integrator=nothing, ability_estimator=nothing)
7777
@returnsome find1_instance(AbilityTracker, bits)
7878
ability_tracker = find1_type(AbilityTracker, bits)
7979
if (ability_tracker !== nothing)
8080
ability_tracker()
8181
end
82-
NullAbilityTracker()
83-
# TODO: find if ability_estimator is GriddedAbilityEstimator and then propagate stuff to GriddedAbilityTracker
82+
if ability_estimator !== nothing && integrator !== nothing
83+
GriddedAbilityTracker(ability_estimator, integrator)
84+
else
85+
NullAbilityTracker()
86+
end
87+
end
88+
89+
function compatible_tracker(bits...; integrator, ability_estimator, prefer_tracked)
90+
ability_tracker = AbilityTracker(bits...; ability_estimator=ability_estimator)
91+
if ability_tracker isa GriddedAbilityTracker && ability_tracker.integrator === integrator
92+
return ability_tracker
93+
end
94+
if prefer_tracked
95+
return AbilityTracker(bits...; integrator=integrator, ability_estimator=ability_estimator)
96+
end
8497
end
8598

8699
abstract type AbilityIntegrator <: CatConfigBase end
87-
function AbilityIntegrator(bits...; ability_estimator=nothing)
100+
function AbilityIntegrator(bits...; ability_estimator=nothing, prefer_tracked=false)
88101
@returnsome find1_instance(AbilityIntegrator, bits)
89102
zero_arg_intergrators = find1_type(RiemannEnumerationIntegrator, bits)
90103
if (zero_arg_intergrators !== nothing)
91104
return RiemannEnumerationIntegrator()
92105
end
93-
@returnsome Integrator(bits...) integrator -> FunctionIntegrator(integrator)
106+
integrator = Integrator(bits...)
107+
if integrator === nothing
108+
return nothing
109+
end
110+
tracker = compatible_tracker(
111+
bits...;
112+
integrator=integrator,
113+
ability_estimator=ability_estimator,
114+
prefer_tracked=prefer_tracked
115+
)
116+
if tracker !== nothing
117+
TrackedLikelihoodIntegrator(integrator, tracker)
118+
else
119+
FunctionIntegrator(integrator)
120+
end
94121
end
95122

96123
abstract type AbilityOptimizer end

src/aggregators/ability_tracker.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ struct NullAbilityTracker <: AbilityTracker end
5959

6060
function track!(_, ::NullAbilityTracker) end
6161

62+
struct ConsAbilityTracker{H <: AbilityTracker, T <: AbilityTracker} <: AbilityTracker
63+
head::H
64+
tail::T
65+
end
66+
67+
function track!(responses, cons::ConsAbilityTracker)
68+
track!(responses, cons.head)
69+
track!(responses, cons.tail)
70+
end
71+
6272
struct VarNormal{T <: Real}
6373
mean::T
6474
var::T
Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
struct GriddedAbilityTracker{AbilityEstimatorT <: DistributionAbilityEstimator, GridT <: AbstractVector{Float64}} <: AbilityTracker
1+
struct GriddedAbilityTracker{AbilityEstimatorT <: DistributionAbilityEstimator, IntegratorT} <: AbilityTracker
22
ability_estimator::AbilityEstimatorT
3-
grid::GridT
3+
integrator::IntegratorT
44
cur_ability::Vector{Float64}
55
end
66

7-
GriddedAbilityTracker(ability_estimator, grid) = GriddedAbilityTracker(ability_estimator, grid, fill(NaN, length(grid)))
7+
function GriddedAbilityTracker(ability_estimator::DistributionAbilityEstimator, integrator::FixedGridIntegrator)
8+
GriddedAbilityTracker(ability_estimator, integrator, fill(NaN, length(integrator.grid)))
9+
end
810

911
function track!(responses, ability_tracker::GriddedAbilityTracker)
1012
ability_pdf = pdf(ability_tracker.ability_estimator, responses)
11-
ability_tracker.cur_ability .= ability_pdf.(ability_tracker.grid)
13+
ability_tracker.cur_ability .= ability_pdf.(ability_tracker.integrator.grid)
1214
end

src/aggregators/integrators.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@ function (product::FunctionProduct)(x::T) where {T}
77
product.f(x) * product.lh_function(x)
88
end
99

10+
struct TrackedLikelihoodIntegrator{IntegratorT <: Integrator} <: AbilityIntegrator
11+
integrator::IntegratorT
12+
tracker::GriddedAbilityTracker
13+
end
14+
15+
function(integrator::TrackedLikelihoodIntegrator{IntegratorT})(
16+
f::F,
17+
ncomp
18+
) where {F, IntegratorT}
19+
integrator.integrator((x, y) -> f(x) * y, integrator.tracker.cur_ability, ncomp)
20+
end
21+
1022
struct FunctionIntegrator{IntegratorT <: Integrator} <: AbilityIntegrator
1123
integrator::IntegratorT
1224
end
@@ -45,7 +57,7 @@ function (integrator::RiemannEnumerationIntegrator)(
4557
result[]
4658
end
4759

48-
function (integrator::AbilityIntegrator)(
60+
function (integrator::Union{RiemannEnumerationIntegrator, FunctionIntegrator})(
4961
f::F,
5062
ncomp,
5163
est,
@@ -54,3 +66,13 @@ function (integrator::AbilityIntegrator)(
5466
) where {F}
5567
integrator(maybe_apply_prior(f, est), ncomp, AbilityLikelihood(tracked_responses); kwargs...)
5668
end
69+
70+
function (integrator::TrackedLikelihoodIntegrator)(
71+
f::F,
72+
ncomp,
73+
est,
74+
tracked_responses::TrackedResponses;
75+
kwargs...
76+
) where {F}
77+
integrator(maybe_apply_prior(f, est), ncomp; kwargs...)
78+
end

src/next_item_rules/NextItemRules.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ delegates to `ItemStrategyNextItemRule`.
5353
"""
5454
abstract type NextItemRule <: CatConfigBase end
5555

56-
function NextItemRule(bits...; ability_estimator=nothing, parallel=true)
56+
function NextItemRule(bits...; ability_estimator=nothing, ability_tracker=nothing, parallel=true)
5757
@returnsome find1_instance(NextItemRule, bits)
58-
@returnsome ItemStrategyNextItemRule(bits..., ability_estimator=ability_estimator, parallel=parallel)
58+
@returnsome ItemStrategyNextItemRule(bits..., ability_estimator=ability_estimator, ability_tracker=ability_tracker, parallel=parallel)
5959
end
6060

6161
include("./random.jl")
@@ -135,9 +135,9 @@ struct ItemStrategyNextItemRule{NextItemStrategyT <: NextItemStrategy, ItemCrite
135135
criterion::ItemCriterionT
136136
end
137137

138-
function ItemStrategyNextItemRule(bits...; parallel=true, ability_estimator=nothing)
138+
function ItemStrategyNextItemRule(bits...; parallel=true, ability_estimator=nothing, ability_tracker=nothing)
139139
strategy = NextItemStrategy(bits...; parallel=parallel)
140-
criterion = ItemCriterion(bits...; ability_estimator=ability_estimator)
140+
criterion = ItemCriterion(bits...; ability_estimator=ability_estimator, ability_tracker=ability_tracker)
141141
if strategy !== nothing && criterion !== nothing
142142
return ItemStrategyNextItemRule(strategy, criterion)
143143
end

src/next_item_rules/objective_function.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,18 @@ $(TYPEDEF)
33
"""
44
abstract type ItemCriterion <: CatConfigBase end
55

6-
function ItemCriterion(bits...; ability_estimator=nothing)
6+
function ItemCriterion(bits...; ability_estimator=nothing, ability_tracker=nothing)
77
@returnsome find1_instance(ItemCriterion, bits)
8-
@returnsome find1_type(ItemCriterion, bits) typ -> typ(ability_estimator=ability_estimator)
9-
@returnsome ExpectationBasedItemCriterion(bits...; ability_estimator=ability_estimator)
8+
@returnsome find1_type(ItemCriterion, bits) typ -> typ(ability_estimator=ability_estimator, ability_tracker=ability_tracker)
9+
@returnsome ExpectationBasedItemCriterion(bits...; ability_estimator=ability_estimator, ability_tracker=ability_tracker)
1010
end
1111

1212
"""
1313
$(TYPEDEF)
1414
"""
1515
abstract type StateCriterion <: CatConfigBase end
1616

17-
function StateCriterion(bits...; ability_estimator=nothing)
17+
function StateCriterion(bits...; ability_estimator=nothing, ability_tracker=nothing)
1818
@returnsome find1_instance(StateCriterion, bits)
1919
@returnsome find1_type(StateCriterion, bits) typ -> typ()
2020
end
@@ -100,12 +100,12 @@ end
100100

101101
abstract type ExpectationBasedItemCriterion <: ItemCriterion end
102102

103-
function ExpectationBasedItemCriterion(bits...; ability_estimator=nothing)
104-
criterion = StateCriterion(bits...; ability_estimator=ability_estimator)
103+
function ExpectationBasedItemCriterion(bits...; ability_estimator=nothing, ability_tracker=nothing)
104+
criterion = StateCriterion(bits...; ability_estimator=ability_estimator, ability_tracker=ability_tracker)
105105
if criterion === nothing
106106
return nothing
107107
end
108-
ability_estimator = AbilityEstimator(bits..., ability_estimator=ability_estimator)
108+
ability_estimator = AbilityEstimator(bits..., ability_estimator=ability_estimator, ability_tracker=ability_tracker)
109109
if ability_estimator === nothing
110110
return nothing
111111
end

0 commit comments

Comments
 (0)