Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
# Dependabot configuration with two update entries, one per package ecosystem.
# NOTE(review): nesting indentation is not visible in this view; confirm that the
# keys following each `- package-ecosystem` line are indented under that list entry.
version: 2
updates:
# Entry 1: the "github-actions" ecosystem, checked on a monthly schedule.
- package-ecosystem: "github-actions"
directory: "/" # Location of package manifests
schedule:
interval: "monthly"
# Entry 2: the "julia" ecosystem, checked on a weekly schedule. This entry uses the
# plural `directories` key with a one-element list rather than `directory`.
- package-ecosystem: "julia"
directories: # Location of Julia projects
- "/"
schedule:
interval: "weekly"
16 changes: 0 additions & 16 deletions .github/workflows/CompatHelper.yml

This file was deleted.

2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJDecisionTreeInterface"
uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.4.4"
version = "0.5.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
61 changes: 35 additions & 26 deletions src/MLJDecisionTreeInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Tables
using CategoricalArrays

using Random
import Random.GLOBAL_RNG
import Random.default_rng

const MMI = MLJModelInterface
const DT = DecisionTree
Expand Down Expand Up @@ -36,7 +36,7 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
merge_purity_threshold::Float64 = 1.0::(_ ≤ 1)
display_depth::Int = 5::(_ ≥ 1)
feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
rng::Union{AbstractRNG,Integer} = default_rng()
end

function MMI.fit(
Expand All @@ -48,6 +48,7 @@ function MMI.fit(
classes,
)

rng = copy(m.rng)
integers_seen = unique(yplain)
classes_seen = MMI.decoder(classes)(integers_seen)

Expand All @@ -56,8 +57,8 @@ function MMI.fit(
m.max_depth,
m.min_samples_leaf,
m.min_samples_split,
m.min_purity_increase,
rng=m.rng)
m.min_purity_increase;
rng)
if m.post_prune
tree = DT.prune_tree(tree, m.merge_purity_threshold)
end
Expand Down Expand Up @@ -117,7 +118,7 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic
n_trees::Int = 100::(_ ≥ 0)
sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
rng::Union{AbstractRNG,Integer} = default_rng()
end

function MMI.fit(
Expand All @@ -129,6 +130,7 @@ function MMI.fit(
classes,
)

rng = copy(m.rng)
integers_seen = unique(yplain)
classes_seen = MMI.decoder(classes)(integers_seen)

Expand All @@ -140,8 +142,8 @@ function MMI.fit(
m.min_samples_leaf,
m.min_samples_split,
m.min_purity_increase;
rng=m.rng)
cache = deepcopy(m)
rng)
cache = (deepcopy(m), rng)

report = (features=features,)

Expand All @@ -157,13 +159,15 @@ function MMI.update(
model::RandomForestClassifier,
verbosity::Int,
old_fitresult,
old_model,
cache,
Xmatrix,
yplain,
features,
classes,
)

old_model, rng = cache

only_iterations_have_changed = MMI.is_same_except(model, old_model, :n_trees)

if !only_iterations_have_changed
Expand Down Expand Up @@ -196,12 +200,12 @@ function MMI.update(
model.min_samples_leaf,
model.min_samples_split,
model.min_purity_increase;
rng=model.rng,
rng,
)
end

fitresult = (forest, old_fitresult[2:3]...)
cache = deepcopy(model)
cache = (deepcopy(model), rng)
report = (features=features,)
return fitresult, cache, report

Expand All @@ -223,7 +227,7 @@ MMI.iteration_parameter(::Type{<:RandomForestClassifier}) = :n_trees
MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic
n_iter::Int = 10::(_ ≥ 1)
feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
rng::Union{AbstractRNG,Integer} = default_rng()
end

function MMI.fit(
Expand All @@ -235,11 +239,12 @@ function MMI.fit(
classes,
)

rng = copy(m.rng)
integers_seen = unique(yplain)
classes_seen = MMI.decoder(classes)(integers_seen)

stumps, coefs =
DT.build_adaboost_stumps(yplain, Xmatrix, m.n_iter, rng=m.rng)
DT.build_adaboost_stumps(yplain, Xmatrix, m.n_iter; rng)
cache = nothing

report = (features=features,)
Expand Down Expand Up @@ -275,11 +280,12 @@ MMI.@mlj_model mutable struct DecisionTreeRegressor <: MMI.Deterministic
post_prune::Bool = false
merge_purity_threshold::Float64 = 1.0::(0 ≤ _ ≤ 1)
feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
rng::Union{AbstractRNG,Integer} = default_rng()
end

function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, Xmatrix, y, features)

rng = copy(m.rng)
tree = DT.build_tree(
y,
Xmatrix,
Expand All @@ -288,7 +294,7 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, Xmatrix, y, features)
m.min_samples_leaf,
m.min_samples_split,
m.min_purity_increase;
rng=m.rng
rng
)

if m.post_prune
Expand Down Expand Up @@ -328,11 +334,12 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic
n_trees::Int = 100::(_ ≥ 0)
sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
rng::Union{AbstractRNG,Integer} = default_rng()
end

function MMI.fit(m::RandomForestRegressor, verbosity::Int, Xmatrix, y, features)

rng = copy(m.rng)
forest = DT.build_forest(
y,
Xmatrix,
Expand All @@ -342,11 +349,11 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, Xmatrix, y, features)
m.max_depth,
m.min_samples_leaf,
m.min_samples_split,
m.min_purity_increase,
rng=m.rng
m.min_purity_increase;
rng
)

cache = deepcopy(m)
cache = (deepcopy(m), rng)
report = (features=features,)

return forest, cache, report
Expand All @@ -356,12 +363,14 @@ function MMI.update(
model::RandomForestRegressor,
verbosity::Int,
old_forest,
old_model,
cache,
Xmatrix,
y,
features,
)

old_model, rng = cache

only_iterations_have_changed = MMI.is_same_except(model, old_model, :n_trees)

if !only_iterations_have_changed
Expand Down Expand Up @@ -394,11 +403,11 @@ function MMI.update(
model.min_samples_leaf,
model.min_samples_split,
model.min_purity_increase;
rng=model.rng
rng,
)
end

cache = deepcopy(model)
cache = (deepcopy(model), rng)
report = (features=features,)

return forest, cache, report
Expand Down Expand Up @@ -607,7 +616,7 @@ Train the machine using `fit!(mach, rows=...)`.
- `feature_importance`: method to use for computing feature importances. One of `(:impurity,
:split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed
- `rng=Random.default_rng()`: random number generator or seed


# Operations
Expand Down Expand Up @@ -743,7 +752,7 @@ Train the machine with `fit!(mach, rows=...)`.
- `feature_importance`: method to use for computing feature importances. One of `(:impurity,
:split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed
- `rng=Random.default_rng()`: random number generator or seed


# Operations
Expand Down Expand Up @@ -840,7 +849,7 @@ Train the machine with `fit!(mach, rows=...)`.
- `feature_importance`: method to use for computing feature importances. One of `(:impurity,
:split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed
- `rng=Random.default_rng()`: random number generator or seed

# Operations

Expand Down Expand Up @@ -951,7 +960,7 @@ Train the machine with `fit!(mach, rows=...)`.
- `feature_importance`: method to use for computing feature importances. One of
`(:impurity, :split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed
- `rng=Random.default_rng()`: random number generator or seed


# Operations
Expand Down Expand Up @@ -1067,7 +1076,7 @@ Train the machine with `fit!(mach, rows=...)`.
- `feature_importance`: method to use for computing feature importances. One of
`(:impurity, :split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed
- `rng=Random.default_rng()`: random number generator or seed


# Operations
Expand Down
32 changes: 31 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ function reproducibility(model, X, y, loss)
end
mach = machine(model, X, y)
train, test = partition(eachindex(y), 0.7)
model.rng = stable_rng()
errs = map(1:N) do i
model.rng = stable_rng()
fit!(mach, rows=train, force=true, verbosity=0)
yhat = predict(mach, rows=test)
loss(yhat, y[test]) |> mean
Expand Down Expand Up @@ -201,6 +201,36 @@ end
end
end

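# The comparison in the testset below is marked `@test_broken`: fitting in two stages
# (with a warm restart) does not currently give the same result as fitting in one step,
# and I do not believe a fix is possible without significant changes to DecisionTree.jl.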
# `stat` reduces a fitted machine to a single number so that two fits can be compared.

# Regressor: mean of the predictions on all rows.
function stat(::RandomForestRegressor, mach)
    return mean(predict(mach, rows=:))
end

# Classifier: mean, over all rows, of each prediction's `pdf` evaluated at `1`.
function stat(::RandomForestClassifier, mach)
    yhat = predict(mach, rows=:)
    return mean(pdf.(yhat, 1))
end

# Entry point: dispatch on the model wrapped by the machine.
stat(mach::MLJBase.Machine) = stat(mach.model, mach)
@testset "two-stage fit with warm-restart same as fit-in-one" begin
    # One generator is shared by both data constructors, in this order.
    data_rng = stable_rng()
    cases = [
        RandomForestClassifier => make_blobs(; rng=data_rng),
        RandomForestRegressor => make_regression(; rng=data_rng),
    ]

    for (modeltype, (X, y)) in cases
        # Staged fit: train with the default `n_trees`, then request 5 more trees and
        # refit, which must log the "Updating" / "Adding 5" messages.
        staged_model = modeltype(; rng=stable_rng())
        staged_mach = machine(staged_model, X, y)
        fit!(staged_mach; verbosity=0)
        staged_model.n_trees += 5
        @test_logs (:info, r"^Updating") (:info, r"Adding 5") fit!(staged_mach)
        staged_stat = stat(staged_mach)

        # One-shot fit: same freshly created rng, with the final `n_trees` set up front.
        oneshot_model = modeltype(; rng=stable_rng())
        oneshot_model.n_trees += 5
        oneshot_mach = machine(oneshot_model, X, y)
        fit!(oneshot_mach; verbosity=0)

        # Recorded as a known failure rather than a hard one.
        @test_broken staged_stat ≈ stat(oneshot_mach)
    end
end

@testset "feature importance defined" begin
for model ∈ [
DecisionTreeClassifier(),
Expand Down
Loading