Skip to content

Commit 9bdfae2

Browse files
author
Andrey Oskin
committed
fast YingYang implementation
1 parent c4fb1f5 commit 9bdfae2

File tree

1 file changed

+25
-8
lines changed

1 file changed

+25
-8
lines changed

src/yingyang.jl

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ struct YingYang <: AbstractKMeansAlg
1818
divider::Int
1919
end
2020

21-
YingYang() = YingYang(true, 10)
22-
YingYang(auto::Bool) = YingYang(auto, 10)
21+
YingYang() = YingYang(true, 7)
22+
YingYang(auto::Bool) = YingYang(auto, 7)
2323
YingYang(divider::Int) = YingYang(true, divider)
2424

2525
function kmeans!(alg::YingYang, containers, X, k;
@@ -58,6 +58,7 @@ function kmeans!(alg::YingYang, containers, X, k;
5858
end
5959
J_previous = J
6060

61+
# push!(containers.debug, [0, 0, 0])
6162
# Core calculation of the YingYang, 3.2-3.3 steps of the original paper
6263
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X)
6364
collect_containers(alg, containers, n_threads)
@@ -77,6 +78,7 @@ function kmeans!(alg::YingYang, containers, X, k;
7778
# TODO empty placeholder vectors should be calculated
7879
# TODO Float64 type definitions is too restrictive, should be relaxed
7980
# especially during GPU related development
81+
# return KmeansResult(centroids, containers.labels, Float64[], Int[], Float64[], totalcost, niters, converged), containers
8082
return KmeansResult(centroids, containers.labels, Float64[], Int[], Float64[], totalcost, niters, converged)
8183
end
8284

@@ -123,6 +125,8 @@ function create_containers(alg::YingYang, k, nrow, ncol, n_threads)
123125
# total_sum_calculation
124126
sum_of_squares = Vector{Float64}(undef, n_threads)
125127

128+
# debug = []
129+
126130
return (
127131
centroids_new = centroids_new,
128132
centroids_cnt = centroids_cnt,
@@ -134,7 +138,8 @@ function create_containers(alg::YingYang, k, nrow, ncol, n_threads)
134138
groups = groups,
135139
indices = indices,
136140
gd = gd,
137-
mask = mask
141+
mask = mask,
142+
# debug = debug
138143
)
139144
end
140145

@@ -215,6 +220,7 @@ function chunk_update_centroids(alg, containers, centroids, X, r, idx)
215220

216221
# Global filtering
217222
ubx <= lbx && continue
223+
# containers.debug[end][1] += 1 # number of misses
218224

219225
# tighten upper bound
220226
label = labels[i]
@@ -232,21 +238,26 @@ function chunk_update_centroids(alg, containers, centroids, X, r, idx)
232238
mask[old_label] = true
233239
ri = groups[orig_group_id]
234240
old_lb = new_lb + gd[orig_group_id] # recovering initial value of lower bound
241+
new_lb2 = Inf
235242
for c in ri
236243
((c == old_label) | (ubx < old_lb - p[c])) && continue
237244
mask[c] = true
245+
# containers.debug[end][2] += 1 # local filter update
238246
dist = distance(X, centroids, i, c)
239247
if dist < ubx2
240-
new_lb = ubx
248+
new_lb2 = ubx2
241249
ubx2 = dist
242250
ubx = sqrt(dist)
243251
label = c
252+
elseif dist < new_lb2
253+
new_lb2 = dist
244254
end
245255
end
246-
new_lb2 = new_lb^2
256+
new_lb = sqrt(new_lb2)
247257
for c in ri
248258
mask[c] && continue
249259
new_lb < old_lb - p[c] && continue
260+
# containers.debug[end][3] += 1 # lower bound update
250261
dist = distance(X, centroids, i, c)
251262
if dist < new_lb2
252263
new_lb2 = dist
@@ -264,31 +275,37 @@ function chunk_update_centroids(alg, containers, centroids, X, r, idx)
264275
ubx < lb[gi, i] && continue
265276
new_lb = lb[gi, i]
266277
old_lb = new_lb + gd[gi]
278+
new_lb2 = Inf
267279
ri = groups[gi]
268280
for c in ri
269281
# local filtering
270282
ubx < old_lb - p[c] && continue
283+
# containers.debug[end][2] += 1 # local filter update
271284
mask[c] = true
272285
dist = distance(X, centroids, i, c)
273286
if dist < ubx2
274-
# closest canter was in previous cluster
287+
# closest center was in previous cluster
275288
if indices[label] != gi
276289
lb[indices[label], i] = ubx
277290
else
278291
new_lb = ubx
279292
end
293+
new_lb2 = ubx2
280294
ubx2 = dist
281295
ubx = sqrt(dist)
282296
label = c
297+
elseif dist < new_lb2
298+
new_lb2 = dist
283299
end
284300
end
285301

286-
new_lb2 = new_lb^2
302+
new_lb = sqrt(new_lb2)
287303
for c in ri
288304
mask[c] && continue
289305
new_lb < old_lb - p[c] && continue
306+
# containers.debug[end][3] += 1 # lower bound update
290307
dist = distance(X, centroids, i, c)
291-
if dist < newlb2
308+
if dist < new_lb2
292309
new_lb2 = dist
293310
new_lb = sqrt(new_lb2)
294311
end

0 commit comments

Comments
 (0)