@@ -18,8 +18,8 @@ struct YingYang <: AbstractKMeansAlg
1818 divider:: Int
1919end
2020
21- YingYang () = YingYang (true , 10 )
22- YingYang (auto:: Bool ) = YingYang (auto, 10 )
21+ YingYang () = YingYang (true , 7 )
22+ YingYang (auto:: Bool ) = YingYang (auto, 7 )
2323YingYang (divider:: Int ) = YingYang (true , divider)
2424
2525function kmeans! (alg:: YingYang , containers, X, k;
@@ -58,6 +58,7 @@ function kmeans!(alg::YingYang, containers, X, k;
5858 end
5959 J_previous = J
6060
61+ # push!(containers.debug, [0, 0, 0])
6162 # Core calculation of the YingYang, 3.2-3.3 steps of the original paper
6263 @parallelize n_threads ncol chunk_update_centroids (alg, containers, centroids, X)
6364 collect_containers (alg, containers, n_threads)
@@ -77,6 +78,7 @@ function kmeans!(alg::YingYang, containers, X, k;
7778 # TODO empty placeholder vectors should be calculated
7879 # TODO Float64 type definitions is too restrictive, should be relaxed
7980 # especially during GPU related development
81+ # return KmeansResult(centroids, containers.labels, Float64[], Int[], Float64[], totalcost, niters, converged), containers
8082 return KmeansResult (centroids, containers. labels, Float64[], Int[], Float64[], totalcost, niters, converged)
8183end
8284
@@ -123,6 +125,8 @@ function create_containers(alg::YingYang, k, nrow, ncol, n_threads)
123125 # total_sum_calculation
124126 sum_of_squares = Vector {Float64} (undef, n_threads)
125127
128+ # debug = []
129+
126130 return (
127131 centroids_new = centroids_new,
128132 centroids_cnt = centroids_cnt,
@@ -134,7 +138,8 @@ function create_containers(alg::YingYang, k, nrow, ncol, n_threads)
134138 groups = groups,
135139 indices = indices,
136140 gd = gd,
137- mask = mask
141+ mask = mask,
142+ # debug = debug
138143 )
139144end
140145
@@ -215,6 +220,7 @@ function chunk_update_centroids(alg, containers, centroids, X, r, idx)
215220
216221 # Global filtering
217222 ubx <= lbx && continue
223+ # containers.debug[end][1] += 1 # number of misses
218224
219225 # tighten upper bound
220226 label = labels[i]
@@ -232,21 +238,26 @@ function chunk_update_centroids(alg, containers, centroids, X, r, idx)
232238 mask[old_label] = true
233239 ri = groups[orig_group_id]
234240 old_lb = new_lb + gd[orig_group_id] # recovering initial value of lower bound
241+ new_lb2 = Inf
235242 for c in ri
236243 ((c == old_label) | (ubx < old_lb - p[c])) && continue
237244 mask[c] = true
245+ # containers.debug[end][2] += 1 # local filter update
238246 dist = distance (X, centroids, i, c)
239247 if dist < ubx2
240- new_lb = ubx
248+ new_lb2 = ubx2
241249 ubx2 = dist
242250 ubx = sqrt (dist)
243251 label = c
252+ elseif dist < new_lb2
253+ new_lb2 = dist
244254 end
245255 end
246- new_lb2 = new_lb ^ 2
256+ new_lb = sqrt (new_lb2)
247257 for c in ri
248258 mask[c] && continue
249259 new_lb < old_lb - p[c] && continue
260+ # containers.debug[end][3] += 1 # lower bound update
250261 dist = distance (X, centroids, i, c)
251262 if dist < new_lb2
252263 new_lb2 = dist
@@ -264,31 +275,37 @@ function chunk_update_centroids(alg, containers, centroids, X, r, idx)
264275 ubx < lb[gi, i] && continue
265276 new_lb = lb[gi, i]
266277 old_lb = new_lb + gd[gi]
278+ new_lb2 = Inf
267279 ri = groups[gi]
268280 for c in ri
269281 # local filtering
270282 ubx < old_lb - p[c] && continue
283+ # containers.debug[end][2] += 1 # local filter update
271284 mask[c] = true
272285 dist = distance (X, centroids, i, c)
273286 if dist < ubx2
274- # closest canter was in previous cluster
287+ # closest center was in previous cluster
275288 if indices[label] != gi
276289 lb[indices[label], i] = ubx
277290 else
278291 new_lb = ubx
279292 end
293+ new_lb2 = ubx2
280294 ubx2 = dist
281295 ubx = sqrt (dist)
282296 label = c
297+ elseif dist < new_lb2
298+ new_lb2 = dist
283299 end
284300 end
285301
286- new_lb2 = new_lb ^ 2
302+ new_lb = sqrt (new_lb2)
287303 for c in ri
288304 mask[c] && continue
289305 new_lb < old_lb - p[c] && continue
306+ # containers.debug[end][3] += 1 # lower bound update
290307 dist = distance (X, centroids, i, c)
291- if dist < newlb2
308+ if dist < new_lb2
292309 new_lb2 = dist
293310 new_lb = sqrt (new_lb2)
294311 end
0 commit comments