Skip to content

Commit e255039

Browse files
author
Frankie Robertson
committed
Add extra constructors to CatRecorder for CATServe
1 parent fcc22b3 commit e255039

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

CATPlots/src/CATPlots.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,23 @@ function CatRecorder(xs::AbstractVector{Float64}, responses, integrator, raw_est
173173
end
174174

175175
function CatRecorder(xs::AbstractMatrix{Float64}, responses, integrator, raw_estimator, ability_estimator, actual_abilities=nothing)
176+
dims = size(xs, 1)
176177
points = size(xs, 2)
177178
num_questions = size(responses, 1)
178179
num_respondents = size(responses, 2)
179180
num_values = num_questions * num_respondents
180-
CatRecorder(xs, points, zeros(size(xs, 1), num_values), num_questions, num_respondents, integrator, raw_estimator, ability_estimator, actual_abilities)
181+
CatRecorder(xs, points, zeros(dims, num_values), num_questions, num_respondents, integrator, raw_estimator, ability_estimator, actual_abilities)
182+
end
183+
184+
function CatRecorder(xs::AbstractVector{Float64}, max_responses::Int, integrator, raw_estimator, ability_estimator, actual_abilities=nothing)
185+
points = size(xs, 1)
186+
CatRecorder(xs, points, zeros(max_responses), max_responses, 1, integrator, raw_estimator, ability_estimator, actual_abilities)
187+
end
188+
189+
function CatRecorder(xs::AbstractMatrix{Float64}, max_responses::Int, integrator, raw_estimator, ability_estimator, actual_abilities=nothing)
190+
dims = size(xs, 1)
191+
points = size(xs, 2)
192+
CatRecorder(xs, points, zeros(dims, max_responses), max_responses, 1, integrator, raw_estimator, ability_estimator, actual_abilities)
181193
end
182194

183195
function push_ability_est!(ability_ests::AbstractMatrix{Float64}, col_idx, ability_est)
@@ -223,7 +235,10 @@ function (recorder::CatRecorder)(tracked_responses, resp_idx, terminating)
223235
recorder.item_responses[:, recorder.col_idx] = resp.(Ref(ir), item_correct, eachmatcol(recorder.xs))
224236

225237
# Save item parameters
226-
recorder.item_difficulties[recorder.step, resp_idx] = item_params(tracked_responses.item_bank, item_index).difficulty
238+
params = item_params(tracked_responses.item_bank, item_index)
239+
if hasproperty(params, :difficulty)
240+
recorder.item_difficulties[recorder.step, resp_idx] = params.difficulty
241+
end
227242
recorder.item_correctness[recorder.step, resp_idx] = item_correct
228243

229244
recorder.col_idx += 1

0 commit comments

Comments
 (0)