Skip to content

Commit 5f7f48a

Browse files
committed
Converging appears to work now. More eyes needed
1 parent d09e5cc commit 5f7f48a

File tree

1 file changed

+31
-23
lines changed

1 file changed

+31
-23
lines changed

src/mini_batch.jl

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function kmeans!(alg::MiniBatch, X, k;
1515
k_init = "k-means++", init = nothing, max_iters = 300,
1616
tol = eltype(X)(1e-6), max_no_improvement = 10, verbose = false, rng = Random.GLOBAL_RNG)
1717

18-
# Get the type and dimensions of design matrix, X
18+
# Get the type and dimensions of design matrix, X - (Step 1)
1919
T = eltype(X)
2020
nrow, ncol = size(X)
2121

@@ -25,9 +25,8 @@ function kmeans!(alg::MiniBatch, X, k;
2525
# Initialize counter for the no. of data in each cluster - (Step 3) in paper
2626
N = zeros(T, k)
2727

28-
# Initialize nearest centers
29-
labels = Vector{Int}(undef, alg.b)
30-
final_labels = Vector{Int}(undef, ncol)
28+
# Initialize nearest centers for both batch and whole dataset labels
29+
final_labels = Vector{Int}(undef, ncol) # dataset labels
3130

3231
converged = false
3332
niters = 0
@@ -36,7 +35,7 @@ function kmeans!(alg::MiniBatch, X, k;
3635
J = zero(T)
3736

3837
# TODO: Main Steps. Batch update centroids until convergence
39-
while niters <= max_iters
38+
while niters <= max_iters # Step 4 in paper
4039

4140
# b examples picked randomly from X (Step 5 in paper)
4241
batch_rand_idx = isnothing(weights) ? rand(rng, 1:ncol, alg.b) : wsample(rng, 1:ncol, weights, alg.b)
@@ -53,14 +52,14 @@ function kmeans!(alg::MiniBatch, X, k;
5352
min_dist = dist < min_dist ? dist : min_dist
5453
end
5554

56-
labels[i] = label
55+
final_labels[batch_rand_idx[i]] = label
5756
end
5857

5958
# TODO: Batch gradient step
60-
for j in axes(batch_sample, 2) # iterate over examples (Step 9)
59+
@inbounds for j in axes(batch_sample, 2) # iterate over examples (Step 9)
6160

62-
# Get cached center/label for this x => labels[j] (Step 10)
63-
label = labels[j]
61+
# Get cached center/label for this x => labels[batch_rand_idx[j]] (Step 10)
62+
label = final_labels[batch_rand_idx[j]]
6463
# Update per-center counts
6564
N[label] += isnothing(weights) ? 1 : weights[j] # verify (Step 11)
6665

@@ -71,8 +70,11 @@ function kmeans!(alg::MiniBatch, X, k;
7170
centroids[:, label] .= (1 - lr) .* centroids[:, label] .+ (lr .* batch_sample[:, j])
7271
end
7372

74-
# TODO: Calculate cost and check for convergence
75-
J = sum_of_squares(batch_sample, labels, centroids) # just a placeholder for now
73+
# TODO: Reassign all labels based on new centres generated from the latest sample
74+
final_labels = reassign_labels(X, metric, final_labels, centroids)
75+
76+
# TODO: Calculate cost on whole dataset after reassignment and check for convergence
77+
J = sum_of_squares(X, final_labels, centroids) # just a placeholder for now
7678

7779
if verbose
7880
# Show progress and terminate if J stopped decreasing.
@@ -87,18 +89,8 @@ function kmeans!(alg::MiniBatch, X, k;
8789
if counter >= max_no_improvement
8890
converged = true
8991
# TODO: Compute label assignment for the complete dataset
90-
@inbounds for i in axes(X, 2)
91-
min_dist = distance(metric, X, centroids, i, 1)
92-
label = 1
93-
94-
for j in 2:size(centroids, 2)
95-
dist = distance(metric, X, centroids, i, j)
96-
label = dist < min_dist ? j : label
97-
min_dist = dist < min_dist ? dist : min_dist
98-
end
99-
100-
final_labels[i] = label
101-
end
92+
final_labels = reassign_labels(X, metric, final_labels, centroids)
93+
10294
# TODO: Compute totalcost for the complete dataset
10395
J = sum_of_squares(X, final_labels, centroids) # just a placeholder for now
10496
break
@@ -127,3 +119,19 @@ function sum_of_squares(x, labels, centre)
127119
end
128120
return s
129121
end
122+
123+
function reassign_labels(DMatrix, metric, labels, centres)
124+
@inbounds for i in axes(DMatrix, 2)
125+
min_dist = distance(metric, DMatrix, centres, i, 1)
126+
label = 1
127+
128+
for j in 2:size(centres, 2)
129+
dist = distance(metric, DMatrix, centres, i, j)
130+
label = dist < min_dist ? j : label
131+
min_dist = dist < min_dist ? dist : min_dist
132+
end
133+
134+
labels[i] = label
135+
end
136+
return labels
137+
end

0 commit comments

Comments
 (0)