@@ -4,7 +4,7 @@ module Comparison
44# Should be kept in mind and kept distinct or code reuse
55
66using StatsBase
7- using FittedItemBanks: AbstractItemBank, ResponseType
7+ using FittedItemBanks: AbstractItemBank, ResponseType, subset
88using .. Responses
99using .. CatConfig: CatLoopConfig, CatRules
1010using .. Aggregators: TrackedResponses, add_response!, Speculator, Aggregators, track!,
@@ -14,11 +14,11 @@ using Base: Iterators
1414
1515using HypothesisTests
1616using EffectSizes
17- using DataFrames
17+ using DataFrames: DataFrame
1818using ComputerAdaptiveTesting: Stateful
1919
2020export run_random_comparison, run_comparison
21- export CatComparisonExecutionStrategy# , IncreaseItemBankSizeExecutionStrategy
21+ export CatComparisonExecutionStrategy, IncreaseItemBankSizeExecutionStrategy
2222# export FollowOneExecutionStrategy, RunIndependentlyExecutionStrategy
2323# export DecisionTreeExecutionStrategy
2424export ReplayResponsesExecutionStrategy
8383
8484abstract type CatComparisonExecutionStrategy end
8585
86- Base. @kwdef struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrategy }
86+ struct CatComparisonConfig{
87+ StrategyT <: CatComparisonExecutionStrategy , PhasesT <: NamedTuple }
8788 """
8889 A named tuple with the (named) CatRules (or compatable) to be compared
8990 """
@@ -99,13 +100,42 @@ Base.@kwdef struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrate
99100 measurements::Vector{}
100101 =#
101102 """
102- Which phases to run and/or call the callback on
103+ The phases to run, optionally paired with a callback
103104 """
104- phases:: Set{Symbol} = Set ((:before_next_item , :after_next_item ))
105- """
106- The callback which should take a named tuple with information at different phases
107- """
108- callback:: Any
105+ phases:: PhasesT
106+ end
107+
108+ """
109+ CatComparisonConfig(;
110+ rules::NamedTuple{Symbol, StatefulCat},
111+ strategy::CatComparisonExecutionStrategy,
112+ phases::Union{NamedTuple{Symbol, Callable}, Tuple{Symbol}},
113+ callback::Callable
114+ ) -> CatComparisonConfig
115+
116+ CatComparisonConfig sets up a evaluation-oriented comparison between different CAT systems.
117+
118+ Specify the comparison by listing: CAT systems in `rules`, a `NamedTuple` which gives
119+ identifiers to implementations of the `StatefulCat` interface; the `strategy` to use,
120+ an implementation of `CatComparisonExecutionStrategy`; the `phases` to run listed as
121+ either as a `NamedTuple` with names of phases and corresponding callbacks or `nothing` a
122+ `Tuple` of phases to run; and a `callback` which will be used as a fallback in cases where
123+ no callback is provided.
124+
125+ The exact phases depend on the strategy used. See their individual documentation for more.
126+ """
127+ function CatComparisonConfig (; rules, strategy, phases = nothing , callback = nothing )
128+ if callback === nothing
129+ callback = (info; kwargs... ) -> nothing
130+ end
131+ if phases === nothing
132+ phases = (:before_next_item , :after_next_item )
133+ end
134+ # TODO : normalize phases into named tuple
135+ if ! (phases isa NamedTuple)
136+ phases = NamedTuple ((phase => callback for phase in phases))
137+ end
138+ CatComparisonConfig (rules, strategy, phases)
109139end
110140
111141# Comparison scenarios:
129159
130160# phase_func=nothing;
131161function measure_all (comparison, system, cat, phase; kwargs... )
132- if ! (phase in comparison. phases)
162+ @info " measure_all" phase comparison. phases
163+ if ! (phase in keys (comparison. phases))
133164 return
134165 end
166+ callback = comparison. phases[phase]
135167 strategy = comparison. strategy
136168 #= measurement_results = []
137169 for measurement in comparison.measurements
@@ -145,7 +177,7 @@ function measure_all(comparison, system, cat, phase; kwargs...)
145177 #end
146178 push!(measurement_results, result)
147179 end=#
148- comparison . callback ((;
180+ callback ((;
149181 phase,
150182 system,
151183 cat,
@@ -158,30 +190,56 @@ struct IncreaseItemBankSizeExecutionStrategy <: CatComparisonExecutionStrategy
158190 item_bank:: AbstractItemBank
159191 sizes:: AbstractVector{Int}
160192 starting_responses:: Int
193+ shuffle:: Bool
194+ time_limit:: Float64
195+
196+ function IncreaseItemBankSizeExecutionStrategy (item_bank, sizes, args... )
197+ if any ((size > length (item_bank) for size in sizes))
198+ error (" IncreaseItemBankSizeExecutionStrategy: No subset size can be greater than the number of items available in the item bank" )
199+ end
200+ new (item_bank, sizes, args... )
201+ end
161202end
162203
163204function IncreaseItemBankSizeExecutionStrategy (item_bank, sizes)
164- return IncreaseItemBankSizeExecutionStrategy (item_bank, sizes, 0 )
205+ return IncreaseItemBankSizeExecutionStrategy (item_bank, sizes, 0 , false , Inf )
165206end
166207
167- function run_comparison (strategy:: IncreaseItemBankSizeExecutionStrategy , config)
208+ function run_comparison (comparison:: CatComparisonConfig{IncreaseItemBankSizeExecutionStrategy} )
209+ strategy = comparison. strategy
210+ current_cats = collect (pairs (comparison. rules))
211+ next_current_cats = copy (current_cats)
212+ @info " sizes" strategy. sizes
168213 for size in strategy. sizes
169- subsetted_item_bank = subset (strategy. item_bank, size)
170- responses = TrackedResponses (
171- BareResponses (ResponseType (strategy. item_bank)),
172- subsetted_item_bank,
173- config. ability_tracker
174- )
175- for _ in 1 : (strategy. starting_responses)
176- next_item = config. next_item (responses, subsetted_item_bank)
177- add_response! (responses,
178- Response (ResponseType (subsetted_item_bank), next_item, rand (Bool)))
214+ subsetted_item_bank = subset (strategy. item_bank, 1 : size)
215+ empty! (next_current_cats)
216+ for (name, cat) in current_cats
217+ Stateful. set_item_bank! (cat, subsetted_item_bank)
218+ for _ in 1 : (strategy. starting_responses)
219+ Stateful. next_item (cat)
220+ end
221+ measure_all (
222+ comparison,
223+ name,
224+ cat,
225+ :before_next_item
226+ )
227+ timed_next_item = @timed Stateful. next_item (cat)
228+ next_item = timed_next_item. value
229+ measure_all (
230+ comparison,
231+ name,
232+ cat,
233+ :after_next_item ,
234+ next_item = next_item,
235+ timing = timed_next_item
236+ )
237+ @info " next_item" timed_next_item. time strategy. time_limit
238+ if timed_next_item. time < strategy. time_limit
239+ push! (next_current_cats, name => cat)
240+ end
179241 end
180- measure_all (config, :before_next_item , before_next_item; responses = responses)
181- timed_next_item = @timed config. next_item (responses, item_bank)
182- next_item = timed_next_item. value
183- measure_all (config, :after_next_item , after_next_item;
184- responses = responses, next_item = next_item)
242+ current_cats, next_current_cats = next_current_cats, current_cats
185243 end
186244end
187245
0 commit comments