Skip to content

Commit 2e28f93

Browse files
author
Frankie Robertson
committed
Save dts
1 parent 288436a commit 2e28f93

File tree

1 file changed

+20
-27
lines changed

1 file changed

+20
-27
lines changed

experiments/dt.jl

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@ using ComputerAdaptiveTesting.Aggregators
22
using PsychometricsBazaarBase.Integrators
33
using FittedItemBanks.DummyData: dummy_full, std_normal, std_mv_normal
44
using ComputerAdaptiveTesting.Sim: run_random_comparison
5-
using ComputerAdaptiveTesting.NextItemRules
65
using FittedItemBanks
76
using Base.Filesystem
87
using ComputerAdaptiveTesting
9-
using ComputerAdaptiveTesting.DecisionTree: DecisionTreeGenerationConfig, generate_dt_cat
8+
using ComputerAdaptiveTesting.DecisionTree: DecisionTreeGenerationConfig, generate_dt_cat, save_mmap
109
using ComputerAdaptiveTesting.Sim
1110
using ComputerAdaptiveTesting.NextItemRules
1211
using ComputerAdaptiveTesting.TerminationConditions
@@ -19,29 +18,15 @@ using ItemResponseDatasets.VocabIQ
1918
using RIrtWrappers.Mirt
2019
using Random
2120
using Distributions
21+
using Profile.Allocs: @profile
22+
using PProf
2223

23-
const next_item_aliases = [keys(catr_next_item_aliases)..., "drule"]
2424

25-
pclamp(x) = clamp.(x, 0.0, 1.0)
26-
abs_rand(rng, dist, dims...) = abs.(rand(rng, dist, dims...))
27-
clamp_rand(rng, dist, dims...) = pclamp.(rand(rng, dist, dims...))
25+
include("./utils/RandomItemBanks.jl")
2826

29-
function clumpy_4pl_item_bank(rng, num_clumps, num_questions)
30-
clump_dist_mat = hcat(
31-
Normal.(rand(rng, Normal(), num_clumps), 0.1), # Difficulty
32-
Normal.(abs_rand(rng, Normal(1.0, 0.2), num_clumps), 0.1), # Discrimination
33-
Normal.(clamp_rand(rng, Normal(0.1, 0.2), num_clumps), 0.02), # Guess
34-
Normal.(clamp_rand(rng, Normal(0.1, 0.2), num_clumps), 0.02) # Slip
35-
)
36-
params_clumps = mapslices(Product, clump_dist_mat; dims=[2])[:, 1]
37-
# TODO: Resample the clumps to create a correlated distribution
38-
params = Array{Float64, 2}(undef, num_questions, 4)
39-
for (question_idx, clump) in enumerate(sample(rng, params_clumps, num_questions; replace=true))
40-
(difficulty, discrimination, guess, slip) = rand(rng, clump)
41-
params[question_idx, :] = [difficulty, abs(discrimination), pclamp(guess), pclamp(slip)]
42-
end
43-
ItemBank4PL(params[:, 1], params[:, 2], params[:, 3], params[:, 4])
44-
end
27+
using .RandomItemBanks: clumpy_4pl_item_bank
28+
29+
const next_item_aliases = [keys(catr_next_item_aliases)..., "drule"]
4530

4631
# copy-pasted
4732
function get_next_item_rule(rule_name)::Tuple{AbilityEstimator, NextItemRule}
@@ -62,19 +47,27 @@ function get_next_item_rule(rule_name)::Tuple{AbilityEstimator, NextItemRule}
6247
(ability_estimator, next_item_rule)
6348
end
6449

65-
function main(rule_name)
50+
51+
function main(rule_name, out_dir)
6652
rng = Xoshiro(42)
6753
params = clumpy_4pl_item_bank(rng, 3, 1000)
6854
ability_estimator, next_item_rule = get_next_item_rule(rule_name)
69-
dt = generate_dt_cat(
55+
dt = @time generate_dt_cat(
7056
DecisionTreeGenerationConfig(;
71-
max_depth=UInt(1),
57+
max_depth=UInt(2),
7258
next_item=next_item_rule,
7359
ability_estimator=ability_estimator,
7460
), params
7561
)
62+
config = @time DecisionTreeGenerationConfig(;
63+
max_depth=UInt(4),
64+
next_item=next_item_rule,
65+
ability_estimator=ability_estimator,
66+
)
67+
dt = @time generate_dt_cat(config, params)
68+
save_mmap(out_dir, dt)
7669
end
7770

7871
if abspath(PROGRAM_FILE) == @__FILE__
79-
main(ARGS[1])
80-
end
72+
main(ARGS[1], ARGS[2])
73+
end

0 commit comments

Comments
 (0)