Skip to content
This repository was archived by the owner on Nov 19, 2020. It is now read-only.

Commit 8532311

Browse files
committed
GH-451: BalancedKMeans does not find a solution for this case.
1 parent 97124d6 commit 8532311

File tree

7 files changed

+857
-300
lines changed

7 files changed

+857
-300
lines changed

Sources/Accord.MachineLearning/Clustering/KMeans/BalancedKMeans.cs

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ namespace Accord.MachineLearning
4141
/// has approximately the same number of data points. The Balanced k-Means implementation
4242
/// used in the framework uses the <see cref="Munkres"/> algorithm to solve the assignment
4343
/// problem thus enforcing balance between the clusters.</para>
44+
///
4445
/// <para>
4546
/// Note: the <see cref="Learn(double[][], double[])"/> method of this class will
4647
/// return the centroids of balanced clusters, but please note that these centroids
@@ -50,6 +51,18 @@ namespace Accord.MachineLearning
5051
/// contents of the <see cref="Labels"/> property.</para>
5152
/// </remarks>
5253
///
54+
/// <para>
55+
/// References:
56+
/// <list type="bullet">
57+
/// <item><description>
58+
/// M. I. Malinen and P. Fränti, "Balanced K-means for Clustering", Joint Int. Workshop on Structural, Syntactic,
59+
/// and Statistical Pattern Recognition (S+SSPR 2014), LNCS 8621, 32-41, Joensuu, Finland, August 2014. </description></item>
60+
/// <item><description>
61+
/// M. I. Malinen, "New alternatives for k-Means clustering." PhD thesis. Available in:
62+
/// http://cs.uef.fi/sipu/pub/PhD_Thesis_Mikko_Malinen.pdf </description></item>
63+
/// </list></para>
64+
///
65+
///
5366
/// <example>
5467
/// How to perform clustering with Balanced K-Means.
5568
///
@@ -63,6 +76,7 @@ namespace Accord.MachineLearning
6376
[Serializable]
6477
public class BalancedKMeans : KMeans
6578
{
79+
internal Munkres munkres;
6680

6781
/// <summary>
6882
/// Gets the labels assigned for each data point in the last
@@ -93,7 +107,9 @@ public BalancedKMeans(int k, IDistance<double[]> distance)
93107
/// <param name="k">The number of clusters to divide the input data into.</param>
94108
///
95109
public BalancedKMeans(int k)
96-
: base(k) { }
110+
: base(k)
111+
{
112+
}
97113

98114

99115
/// <summary>
@@ -142,12 +158,12 @@ public override KMeansClusterCollection Learn(double[][] x, double[] weights = n
142158
int rows = x.Length;
143159
int cols = x[0].Length;
144160

145-
// Perform a random initialization of the clusters
146-
// if the algorithm has not been initialized before.
161+
// Perform some iterations of the original k-Means
162+
// algorithm if the model has not been initialized
147163
//
148164
if (this.Clusters.Centroids[0] == null)
149165
{
150-
Randomize(x);
166+
base.Learn(x);
151167
}
152168

153169
// Initial variables
@@ -160,31 +176,26 @@ public override KMeansClusterCollection Learn(double[][] x, double[] weights = n
160176

161177
bool shouldStop = false;
162178

163-
var m = new Munkres(x.Length, x.Length);
164-
double[][] costMatrix = m.CostMatrix;
179+
// We will solve the problem of assigning N data points
180+
// to K clusters where the cost will be their distance.
181+
this.munkres = new Munkres(x.Length, x.Length)
182+
{
183+
Tolerance = Tolerance
184+
};
165185

166186
while (!shouldStop) // Main loop
167187
{
168188
Array.Clear(count, 0, count.Length);
169189
for (int i = 0; i < newCentroids.Length; i++)
170190
Array.Clear(newCentroids[i], 0, newCentroids[i].Length);
171-
for (int i = 0; i < labels.Length; i++)
172-
labels[i] = -1;
173191

174192
// Set the cost matrix for Munkres algorithm
175-
for (int i = 0; i < costMatrix.Length; i++)
176-
for (int j = 0; j < costMatrix[i].Length; j++)
177-
costMatrix[i][j] = Distance.Distance(x[j], centroids[i % k]);
178-
179-
//string str = costMatrix.ToCSharp();
193+
GetDistances(Distance, x, centroids, k, munkres.CostMatrix);
180194

181-
m.Minimize(); // solve the assignment problem
195+
munkres.Minimize(); // solve the assignment problem
182196

183-
for (int i = 0; i < x.Length; i++)
184-
{
185-
if (m.Solution[i] >= 0)
186-
labels[(int)m.Solution[i]] = i % k;
187-
}
197+
// Get the clustering from the assignment
198+
GetLabels(x, k, munkres.Solution, labels);
188199

189200
// For each point in the data set,
190201
for (int i = 0; i < x.Length; i++)
@@ -229,11 +240,9 @@ public override KMeansClusterCollection Learn(double[][] x, double[] weights = n
229240
shouldStop = converged(centroids, newCentroids);
230241

231242
// go to next generation
232-
Parallel.For(0, centroids.Length, ParallelOptions, i =>
233-
{
243+
for (int i = 0; i < centroids.Length; i++)
234244
for (int j = 0; j < centroids[i].Length; j++)
235245
centroids[i][j] = newCentroids[i][j];
236-
});
237246
}
238247

239248
for (int i = 0; i < Clusters.Centroids.Length; i++)
@@ -249,5 +258,32 @@ public override KMeansClusterCollection Learn(double[][] x, double[] weights = n
249258
return Clusters;
250259
}
251260

261+
/// <summary>
///   Converts a solution of the assignment problem into one cluster label per data point.
/// </summary>
///
/// <param name="points">The data points. Only the number of points is used, as the
///   number of entries of <paramref name="solution"/> to read.</param>
/// <param name="clusters">The number of clusters, forwarded to <c>GetIndex</c> to map
///   a row of the assignment to a cluster index.</param>
/// <param name="solution">For each row <c>i</c> of the assignment, the index of the point
///   assigned to that row. A negative (or NaN) entry means the row has no point.</param>
/// <param name="labels">Output array, indexed by point. On return, every point claimed by
///   a row holds that row's cluster index; every other entry holds -1.</param>
///
internal static void GetLabels(double[][] points, int clusters, double[] solution, int[] labels)
{
    // Start with every point unassigned. An unassigned row carries no point
    // index, so "-1" can only be recorded up front, never through labels[j].
    // This also stops labels from a previous call leaking into this result.
    for (int i = 0; i < labels.Length; i++)
        labels[i] = -1;

    for (int i = 0; i < points.Length; i++)
    {
        double assigned = solution[i];

        // Converting NaN to int yields an unspecified value in an unchecked
        // context, so reject it before the cast instead of trusting the sign.
        if (double.IsNaN(assigned))
            continue;

        int j = (int)assigned;

        // Negative index: this row was left without a point; nothing to label.
        // (The previous code wrote labels[j] = -1 here, which indexes the array
        // with a negative value and throws IndexOutOfRangeException.)
        if (j >= 0)
            labels[j] = GetIndex(clusters, i);
    }
}
272+
273+
/// <summary>
///   Fills the cost matrix of the assignment problem with point-to-centroid distances.
/// </summary>
///
/// <param name="distance">The distance function used as the assignment cost.</param>
/// <param name="points">The data points; column <c>j</c> of the matrix refers to <c>points[j]</c>.</param>
/// <param name="centroids">The current centroids; row <c>i</c> of the matrix uses the
///   centroid selected by <c>GetIndex(k, i)</c>.</param>
/// <param name="k">The number of clusters, forwarded to <c>GetIndex</c>.</param>
/// <param name="result">The matrix to fill in place; its own dimensions drive the loops.</param>
///
/// <returns>The same <paramref name="result"/> instance, after being filled.</returns>
///
internal static double[][] GetDistances(IDistance<double[], double[]> distance, double[][] points, double[][] centroids, int k, double[][] result)
{
    for (int row = 0; row < result.Length; row++)
    {
        double[] costs = result[row];

        for (int col = 0; col < costs.Length; col++)
        {
            // Cost of putting point "col" into the slot represented by "row"
            double[] point = points[col];
            double[] centroid = centroids[GetIndex(k, row)];
            costs[col] = distance.Distance(point, centroid);
        }
    }

    return result;
}
280+
281+
/// <summary>
///   Maps a row of the assignment problem to the index of the cluster it stands for.
/// </summary>
///
/// <param name="clusters">The number of clusters.</param>
/// <param name="index">The zero-based row index in the assignment problem.</param>
///
/// <returns>The value <c>(index + 1) % clusters</c>.</returns>
///
private static int GetIndex(int clusters, int index)
{
    // Equation 6.6 uses ((a mod k) + 1) instead of (a mod k):
    // http://cs.uef.fi/sipu/pub/PhD_Thesis_Mikko_Malinen.pdf
    //
    // NOTE(review): presumably the equation is written with one-based indices,
    // which is why the shift is applied before the modulo here - confirm
    // against the thesis.

    int shifted = index + 1;
    return shifted % clusters;
}
252288
}
253289
}

Sources/Accord.MachineLearning/Clustering/KMeans/KMeans.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,18 @@ public KMeansClusterCollection Clusters
194194
get { return clusters; }
195195
}
196196

197+
/// <summary>
///   Gets or sets the centroids of the clusters. This property reads from and writes to
///   <see cref="ClusterCollection{TData, TCentroids, TCluster}.Centroids">
///   KMeans.Clusters.Centroids</see>, so both always refer to the same centroids.
/// </summary>
///
public double[][] Centroids
{
    get
    {
        return this.clusters.Centroids;
    }
    set
    {
        this.clusters.Centroids = value;
    }
}
208+
197209
/// <summary>
198210
/// Gets the number of clusters.
199211
/// </summary>
@@ -569,7 +581,7 @@ protected bool converged(double[][] centroids, double[][] newCentroids)
569581
{
570582
Iterations++;
571583

572-
if (MaxIterations > 0 && Iterations > MaxIterations)
584+
if (MaxIterations > 0 && Iterations >= MaxIterations)
573585
return true;
574586

575587
if (Token.IsCancellationRequested)

0 commit comments

Comments
 (0)