@@ -15,7 +15,7 @@ function kmeans!(alg::MiniBatch, X, k;
                  k_init = "k-means++", init = nothing, max_iters = 300,
                  tol = eltype(X)(1e-6), max_no_improvement = 10, verbose = false, rng = Random.GLOBAL_RNG)

-    # Get the type and dimensions of design matrix, X
+    # Get the type and dimensions of design matrix, X - (Step 1)
     T = eltype(X)
     nrow, ncol = size(X)

@@ -25,9 +25,8 @@ function kmeans!(alg::MiniBatch, X, k;
     # Initialize counter for the no. of data in each cluster - (Step 3) in paper
     N = zeros(T, k)

-    # Initialize nearest centers
-    labels = Vector{Int}(undef, alg.b)
-    final_labels = Vector{Int}(undef, ncol)
+    # Initialize nearest centers for both batch and whole dataset labels
+    final_labels = Vector{Int}(undef, ncol) # dataset labels

     converged = false
     niters = 0
@@ -36,7 +35,7 @@ function kmeans!(alg::MiniBatch, X, k;
     J = zero(T)

     # TODO: Main Steps. Batch update centroids until convergence
-    while niters <= max_iters
+    while niters <= max_iters # Step 4 in paper

         # b examples picked randomly from X (Step 5 in paper)
         batch_rand_idx = isnothing(weights) ? rand(rng, 1:ncol, alg.b) : wsample(rng, 1:ncol, weights, alg.b)
@@ -53,14 +52,14 @@ function kmeans!(alg::MiniBatch, X, k;
                 min_dist = dist < min_dist ? dist : min_dist
             end

-            labels[i] = label
+            final_labels[batch_rand_idx[i]] = label
         end

         # TODO: Batch gradient step
-        for j in axes(batch_sample, 2) # iterate over examples (Step 9)
+        @inbounds for j in axes(batch_sample, 2) # iterate over examples (Step 9)

-            # Get cached center/label for this x => labels[j] (Step 10)
-            label = labels[j]
+            # Get cached center/label for this x => final_labels[batch_rand_idx[j]] (Step 10)
+            label = final_labels[batch_rand_idx[j]]
             # Update per-center counts
             N[label] += isnothing(weights) ? 1 : weights[j] # verify (Step 11)

@@ -71,8 +70,11 @@ function kmeans!(alg::MiniBatch, X, k;
             centroids[:, label] .= (1 - lr) .* centroids[:, label] .+ (lr .* batch_sample[:, j])
         end

-        # TODO: Calculate cost and check for convergence
-        J = sum_of_squares(batch_sample, labels, centroids) # just a placeholder for now
+        # TODO: Reassign all labels based on new centres generated from the latest sample
+        final_labels = reassign_labels(X, metric, final_labels, centroids)
+
+        # TODO: Calculate cost on whole dataset after reassignment and check for convergence
+        J = sum_of_squares(X, final_labels, centroids) # just a placeholder for now

         if verbose
             # Show progress and terminate if J stopped decreasing.
@@ -87,18 +89,8 @@ function kmeans!(alg::MiniBatch, X, k;
             if counter >= max_no_improvement
                 converged = true
                 # TODO: Compute label assignment for the complete dataset
-                @inbounds for i in axes(X, 2)
-                    min_dist = distance(metric, X, centroids, i, 1)
-                    label = 1
-
-                    for j in 2:size(centroids, 2)
-                        dist = distance(metric, X, centroids, i, j)
-                        label = dist < min_dist ? j : label
-                        min_dist = dist < min_dist ? dist : min_dist
-                    end
-
-                    final_labels[i] = label
-                end
+                final_labels = reassign_labels(X, metric, final_labels, centroids)
+
                 # TODO: Compute totalcost for the complete dataset
                 J = sum_of_squares(X, final_labels, centroids) # just a placeholder for now
                 break
@@ -127,3 +119,19 @@ function sum_of_squares(x, labels, centre)
         end
     end
     return s
 end
+
+function reassign_labels(DMatrix, metric, labels, centres)
+    @inbounds for i in axes(DMatrix, 2)
+        min_dist = distance(metric, DMatrix, centres, i, 1)
+        label = 1
+
+        for j in 2:size(centres, 2)
+            dist = distance(metric, DMatrix, centres, i, j)
+            label = dist < min_dist ? j : label
+            min_dist = dist < min_dist ? dist : min_dist
+        end
+
+        labels[i] = label
+    end
+    return labels
+end
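For reference, below is a minimal standalone sketch of the per-centre learning-rate update that the batch gradient step above performs (Steps 9-12 in the mini-batch k-means paper). It is not the package's API: the `toy_minibatch_step!` helper and the hand-rolled squared-Euclidean distance are illustrative stand-ins for the internal `distance` helper and the full `kmeans!(alg::MiniBatch, ...)` loop in the diff.

```julia
# Illustrative sketch only: mirrors the N[label] += 1, lr = 1 / N[label],
# and convex centroid update seen in the diff, on a toy dataset.
using Random

function toy_minibatch_step!(centroids, N, X, batch_idx)
    for j in batch_idx
        # Step 10: find the nearest centre for this example (squared Euclidean)
        dists = [sum(abs2, X[:, j] .- centroids[:, c]) for c in axes(centroids, 2)]
        label = argmin(dists)

        # Step 11: update the per-centre count
        N[label] += 1

        # Step 12: per-centre learning rate and convex update toward the example
        lr = 1 / N[label]
        centroids[:, label] .= (1 - lr) .* centroids[:, label] .+ lr .* X[:, j]
    end
    return centroids
end

rng = Random.MersenneTwister(42)
X = rand(rng, 2, 100)                      # 2 features x 100 examples (columns are points)
centroids = X[:, rand(rng, 1:100, 3)]      # 3 centres seeded from random columns
N = zeros(3)                               # per-centre counts, as in the diff
toy_minibatch_step!(centroids, N, X, rand(rng, 1:100, 10))  # one batch of b = 10
```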