Skip to content

Commit 461c3d6

Browse files
committed
Added verbosity to MLJ interface fit
1 parent d485c12 commit 461c3d6

File tree

2 files changed

+32
-22
lines changed

2 files changed

+32
-22
lines changed

src/mlj_interface.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,15 @@ mutable struct KMeans <: MMI.Unsupervised
1919
max_iters::Int
2020
copy::Bool
2121
threads::Int
22-
verbosity::Int
2322
init
2423
end
2524

2625

2726
function KMeans(; algo=:Hamerly, k_init="k-means++",
2827
k=3, tol=1e-6, max_iters=300, copy=true,
29-
threads=Threads.nthreads(), verbosity=0, init=nothing)
28+
threads=Threads.nthreads(), init=nothing)
3029

31-
model = KMeans(algo, k_init, k, tol, max_iters, copy, threads, verbosity, init)
30+
model = KMeans(algo, k_init, k, tol, max_iters, copy, threads, init)
3231
message = MMI.clean!(model)
3332
isempty(message) || @warn message
3433
return model
@@ -68,11 +67,6 @@ function MMI.clean!(m::KMeans)
6867
m.threads = Threads.nthreads()
6968
end
7069

71-
if !(m.verbosity ∈ (0, 1))
72-
push!(warning, "Verbosity must be either 0 (no info) or 1 (info requested). Defaulting to 1.")
73-
m.verbosity = 1
74-
end
75-
7670
return join(warning, "\n")
7771
end
7872

@@ -85,7 +79,7 @@ end
8579
8680
See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
8781
"""
88-
function MMI.fit(m::KMeans, X)
82+
function MMI.fit(m::KMeans, verbosity::Int, X)
8983
# convert tabular input data into the matrix model expects. Column assumed as features so input data is permuted
9084
if !m.copy
9185
# permutes dimensions of input table without copying and pass to model
@@ -99,16 +93,22 @@ function MMI.fit(m::KMeans, X)
9993
algo = MLJDICT[m.algo] # select algo
10094

10195
# fit model and get results
102-
verbose = m.verbosity != 0
96+
verbose = verbosity > 0 # Display fitting operations if verbosity > 0
10397
fitresult = ParallelKMeans.kmeans(algo, DMatrix, m.k;
10498
n_threads = m.threads, k_init=m.k_init,
10599
max_iters=m.max_iters, tol=m.tol, init=m.init,
106100
verbose=verbose)
101+
107102
cache = nothing
108103
report = (cluster_centers=fitresult.centers, iterations=fitresult.iterations,
109104
converged=fitresult.converged, totalcost=fitresult.totalcost,
110105
labels=fitresult.assignments)
111-
106+
"""
107+
# TODO: warn users about non convergence
108+
if verbose & (!fitresult.converged)
109+
@warn "Specified model failed to converge."
110+
end
111+
"""
112112
return (fitresult, cache, report)
113113
end
114114

@@ -144,7 +144,7 @@ function MMI.transform(m::KMeans, fitresult, Xnew)
144144

145145
# Warn users if fitresult is from a `non-converged` fit
146146
if !fitresult[end].converged
147-
@warn "Failed to converged. Using last assignments to make transformations."
147+
@warn "Failed to converge. Using last assignments to make transformations."
148148
end
149149

150150
# results from fitted model
@@ -175,7 +175,7 @@ MMI.metadata_pkg.(KMeans,
175175
# Metadata for ParaKMeans model interface
176176
MMI.metadata_model(KMeans,
177177
input = MMI.Table(MMI.Continuous),
178-
output = MMI.Table(MMI.Count),
178+
output = MMI.Table(MMI.Continuous),
179179
weights = false,
180180
descr = ParallelKMeans_Desc,
181181
path = "ParallelKMeans.KMeans")

test/test07_mlj_interface.jl

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ using MLJBase
1919
@test model.copy == true
2020
@test model.threads == Threads.nthreads()
2121
@test model.tol == 1.0e-6
22-
@test model.verbosity == 0
2322
end
2423

2524

@@ -30,15 +29,14 @@ end
3029
@test_logs (:warn, "Tolerance level must be less than 1. Defaulting to tol of 1e-6.") ParallelKMeans.KMeans(tol=2)
3130
@test_logs (:warn, "Number of permitted iterations must be greater than 0. Defaulting to 300 iterations.") ParallelKMeans.KMeans(max_iters=0)
3231
@test_logs (:warn, "Number of threads must be at least 1. Defaulting to all threads available.") ParallelKMeans.KMeans(threads=0)
33-
@test_logs (:warn, "Verbosity must be either 0 (no info) or 1 (info requested). Defaulting to 1.") ParallelKMeans.KMeans(verbosity=100)
3432
end
3533

3634

3735
@testset "Test model fitting verbosity" begin
3836
Random.seed!(2020)
3937
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
40-
model = KMeans(k=2, max_iters=1, verbosity=1)
41-
results = @capture_out fit(model, X)
38+
model = KMeans(k=2, max_iters=1)
39+
results = @capture_out fit(model, 1, X)
4240

4341
@test results == "Iteration 1: Jclust = 28.0\n"
4442
end
@@ -50,7 +48,7 @@ end
5048
X_test = table([10 1])
5149

5250
model = KMeans(algo = :Lloyd, k=2)
53-
results = fit(model, X)
51+
results = fit(model, 0, X)
5452

5553
@test results[2] == nothing
5654
@test results[end].converged == true
@@ -72,7 +70,7 @@ end
7270
X_test = table([10 1])
7371

7472
model = KMeans(algo=:Hamerly, k=2)
75-
results = fit(model, X)
73+
results = fit(model, 0, X)
7674

7775
@test results[2] == nothing
7876
@test results[end].converged == true
@@ -87,13 +85,14 @@ end
8785
@test preds[:x1][1] == 2
8886
end
8987

88+
9089
@testset "Test Elkan model fitting" begin
9190
Random.seed!(2020)
9291
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
9392
X_test = table([10 1])
9493

9594
model = KMeans(algo=:Elkan, k=2)
96-
results = fit(model, X)
95+
results = fit(model, 0, X)
9796

9897
@test results[2] == nothing
9998
@test results[end].converged == true
@@ -108,15 +107,26 @@ end
108107
@test preds[:x1][1] == 2
109108
end
110109

110+
111111
@testset "Testing non convergence warning" begin
112112
Random.seed!(2020)
113113
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
114114
X_test = table([10 1])
115115

116116
model = KMeans(k=2, max_iters=1)
117-
results = fit(model, X)
117+
results = fit(model, 0, X)
118118

119-
@test_logs (:warn, "Failed to converged. Using last assignments to make transformations.") transform(model, results, X_test)
119+
@test_logs (:warn, "Failed to converge. Using last assignments to make transformations.") transform(model, results, X_test)
120120
end
121121

122+
"""
123+
@testset "Testing non convergence warning during model fitting" begin
124+
Random.seed!(2020)
125+
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
126+
X_test = table([10 1])
127+
128+
model = KMeans(k=2, max_iters=1)
129+
@test_logs (:warn, "Specified model failed to converge.") fit(model, 1, X);
130+
end
131+
"""
122132
end # module

0 commit comments

Comments
 (0)