Skip to content

Commit 743df4b

Browse files
committed
Coreset & Yinyang exclusive euclidean support
1 parent ceccba0 commit 743df4b

File tree

4 files changed

+11
-42
lines changed

4 files changed

+11
-42
lines changed

src/coreset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Coreset(; m = 100, alg = Lloyd()) = Coreset(m, alg)
3535
Coreset(m::Int) = Coreset(m, Lloyd())
3636
Coreset(alg::AbstractKMeansAlg) = Coreset(100, alg)
3737

38-
function kmeans!(alg::Coreset, containers, X, k, weights, metric=Euclidean();
38+
function kmeans!(alg::Coreset, containers, X, k, weights, metric::Euclidean = Euclidean();
3939
n_threads = Threads.nthreads(),
4040
k_init = "k-means++", max_iters = 300,
4141
tol = eltype(design_matrix)(1e-6), verbose = false,

src/yinyang.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,17 @@ Yinyang(; group_size = 7, auto = true) = Yinyang(auto, group_size)
4747
阴阳(group_size::Int) = Yinyang(true, group_size)
4848
阴阳(; group_size = 7, auto = true) = Yinyang(auto, group_size)
4949

50-
function kmeans!(alg::Yinyang, containers, X, k, weights, metric=Euclidean();
50+
metric_checker(metric::Euclidean) = Euclidean()
51+
metric_checker(metric::Metric) = throw(error("Euclidean() is the only supported distance metric."))
52+
53+
54+
function kmeans!(alg::Yinyang, containers, X, k, weights, metric::Euclidean = Euclidean();
5155
n_threads = Threads.nthreads(),
5256
k_init = "k-means++", max_iters = 300,
5357
tol = 1e-6, verbose = false,
5458
init = nothing, rng = Random.GLOBAL_RNG)
5559

60+
#metric = metric_checker(metric)
5661
nrow, ncol = size(X)
5762

5863
centroids = init == nothing ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)

test/test06_yinyang.jl

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -198,29 +198,11 @@ end
198198
@test !alg.auto
199199
end
200200

201-
@testset "Yinyang metric support" begin
201+
@testset "Yinyang non-euclidean metric error" begin
202202
rng = StableRNG(2020)
203203
X = [1. 2. 4.;]
204204

205-
res = kmeans(Yinyang(), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)
206-
207-
@test res.assignments == [2, 2, 1]
208-
@test res.centers == [4.0 1.5]
209-
@test res.totalcost == 1.0
210-
@test res.converged
211-
212-
rng = StableRNG(2020)
213-
X = rand(3, 100)
214-
rng_orig = deepcopy(rng)
215-
216-
baseline = kmeans(Lloyd(), X, 2, tol = 1e-16, metric=Cityblock(), rng = rng)
217-
218-
rng = deepcopy(rng_orig)
219-
res = kmeans(Yinyang(), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)
220-
221-
@test res.totalcost baseline.totalcost
222-
@test res.converged == baseline.converged
223-
@test res.iterations == baseline.iterations
205+
@test_throws MethodError res = kmeans(Yinyang(), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)
224206
end
225207

226208
end # module

test/test07_coreset.jl

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -47,29 +47,11 @@ end
4747
@test alg.alg == Hamerly()
4848
end
4949

50-
@testset "Coreset metric support" begin
50+
@testset "Coreset non-euclidean metric error" begin
5151
rng = StableRNG(2020)
5252
X = [1. 2. 4.;]
5353

54-
res = kmeans(Coreset(), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)
55-
56-
@test res.assignments == [2, 2, 1]
57-
@test res.centers [4.0 1.4865168535972686]
58-
@test res.totalcost == 1.0
59-
@test res.converged
60-
61-
62-
rng = StableRNG(2020)
63-
X = rand(rng, 3, 100)
64-
rng_orig = deepcopy(rng)
65-
66-
baseline = kmeans(Lloyd(), X, 10, tol = 1e-10, metric=Cityblock(), rng = rng, n_threads = 1)
67-
rng = deepcopy(rng_orig)
68-
res = kmeans(Coreset(), X, 10; tol = 1e-10, metric = Cityblock(), rng = rng, n_threads = 1)
69-
70-
@test res.converged == baseline.converged
71-
@test res.iterations == baseline.iterations
72-
@test floor(res.totalcost - baseline.totalcost) 1
54+
@test_throws MethodError res = kmeans(Coreset(), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)
7355

7456
end
7557

0 commit comments

Comments
 (0)