@@ -41,6 +41,7 @@ namespace Accord.MachineLearning
     /// has approximately the same number of data points. The Balanced k-Means implementation
     /// used in the framework uses the <see cref="Munkres"/> algorithm to solve the assignment
     /// problem thus enforcing balance between the clusters.</para>
+    ///
     /// <para>
     /// Note: the <see cref="Learn(double[][], double[])"/> method of this class will
     /// return the centroids of balanced clusters, but please note that these centroids
@@ -50,6 +51,18 @@ namespace Accord.MachineLearning
     /// contents of the <see cref="Labels"/> property.</para>
     /// </remarks>
     ///
+    /// <para>
+    /// References:
+    /// <list type="bullet">
+    ///   <item><description>
+    ///     M. I. Malinen and P. Fränti, "Balanced K-means for Clustering", Joint Int. Workshop on Structural, Syntactic,
+    ///     and Statistical Pattern Recognition (S+SSPR 2014), LNCS 8621, 32-41, Joensuu, Finland, August 2014.</description></item>
+    ///   <item><description>
+    ///     M. I. Malinen, "New alternatives for k-Means clustering." PhD thesis. Available at:
+    ///     http://cs.uef.fi/sipu/pub/PhD_Thesis_Mikko_Malinen.pdf</description></item>
+    /// </list></para>
+    ///
     /// <example>
     /// How to perform clustering with Balanced K-Means.
     ///
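For reference, here is a minimal usage sketch of the pattern the remarks above describe. It only relies on members visible in this file (the `BalancedKMeans(int k)` constructor, `Learn(double[][], double[])` returning a `KMeansClusterCollection`, and the `Labels` property); the toy data and variable names are illustrative, not part of the diff.

```csharp
using Accord.MachineLearning;

// Six 2-D points forming two natural groups (illustrative toy data).
double[][] observations =
{
    new double[] { -1.0, -1.0 },
    new double[] { -1.1, -0.9 },
    new double[] { -0.9, -1.1 },
    new double[] {  1.0,  1.0 },
    new double[] {  1.1,  0.9 },
    new double[] {  0.9,  1.1 },
};

// Divide the data into k = 2 balanced clusters.
var kmeans = new BalancedKMeans(k: 2);

// Learn returns the centroids of the balanced clusters...
KMeansClusterCollection clusters = kmeans.Learn(observations);

// ...but, as noted in the remarks, the balanced assignment itself must be
// read from the Labels property rather than recomputed from the centroids.
var labels = kmeans.Labels;
```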
@@ -63,6 +76,7 @@ namespace Accord.MachineLearning
     [Serializable]
     public class BalancedKMeans : KMeans
     {
+        internal Munkres munkres;
 
         /// <summary>
         /// Gets the labels assigned for each data point in the last
@@ -93,7 +107,9 @@ public BalancedKMeans(int k, IDistance<double[]> distance)
         /// <param name="k">The number of clusters to divide the input data into.</param>
         ///
         public BalancedKMeans(int k)
-            : base(k) { }
+            : base(k)
+        {
+        }
 
 
         /// <summary>
@@ -142,12 +158,12 @@ public override KMeansClusterCollection Learn(double[][] x, double[] weights = n
             int rows = x.Length;
             int cols = x[0].Length;
 
-            // Perform a random initialization of the clusters
-            // if the algorithm has not been initialized before.
+            // Perform some iterations of the original k-Means
+            // algorithm if the model has not been initialized
             //
             if (this.Clusters.Centroids[0] == null)
             {
-                Randomize(x);
+                base.Learn(x);
             }
 
             // Initial variables
@@ -160,31 +176,26 @@ public override KMeansClusterCollection Learn(double[][] x, double[] weights = n
 
             bool shouldStop = false;
 
-            var m = new Munkres(x.Length, x.Length);
-            double[][] costMatrix = m.CostMatrix;
+            // We will solve the problem of assigning N data points
+            // to K clusters where the cost will be their distance.
+            this.munkres = new Munkres(x.Length, x.Length)
+            {
+                Tolerance = Tolerance
+            };
 
             while (!shouldStop) // Main loop
             {
                 Array.Clear(count, 0, count.Length);
                 for (int i = 0; i < newCentroids.Length; i++)
                     Array.Clear(newCentroids[i], 0, newCentroids[i].Length);
-                for (int i = 0; i < labels.Length; i++)
-                    labels[i] = -1;
 
                 // Set the cost matrix for Munkres algorithm
-                for (int i = 0; i < costMatrix.Length; i++)
-                    for (int j = 0; j < costMatrix[i].Length; j++)
-                        costMatrix[i][j] = Distance.Distance(x[j], centroids[i % k]);
-
-                //string str = costMatrix.ToCSharp();
+                GetDistances(Distance, x, centroids, k, munkres.CostMatrix);
 
-                m.Minimize(); // solve the assignment problem
+                munkres.Minimize(); // solve the assignment problem
 
-                for (int i = 0; i < x.Length; i++)
-                {
-                    if (m.Solution[i] >= 0)
-                        labels[(int)m.Solution[i]] = i % k;
-                }
+                // Get the clustering from the assignment
+                GetLabels(x, k, munkres.Solution, labels);
 
                 // For each point in the data set,
                 for (int i = 0; i < x.Length; i++)
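As a side note, the hunk above is where the balance constraint is actually enforced, so here is a small self-contained sketch (not part of the diff) of that assignment formulation: an N-by-N cost matrix whose row i holds the distances from every point to the centroid of cluster (i + 1) mod k, which the Munkres solver then minimizes by picking one point per row, giving each cluster roughly N/k points. It uses only members already referenced in this file (`Munkres`, `CostMatrix`, `Minimize`, `Solution`); the toy data and the inlined squared Euclidean distance are illustrative assumptions.

```csharp
using Accord.Math.Optimization;

// Toy data: six 1-D points and two centroids, so each cluster must get 3 points.
double[][] points =
{
    new[] { 0.0 }, new[] { 0.1 }, new[] { 0.2 },
    new[] { 9.0 }, new[] { 9.1 }, new[] { 9.2 },
};
double[][] centroids = { new[] { 0.1 }, new[] { 9.1 } };
int k = centroids.Length;
int n = points.Length;

// Build the n-by-n cost matrix: row i is the "slot" for cluster (i + 1) % k,
// column j is data point j, and the cost is the squared Euclidean distance.
var munkres = new Munkres(n, n);
double[][] cost = munkres.CostMatrix;
for (int i = 0; i < n; i++)
{
    for (int j = 0; j < n; j++)
    {
        double sum = 0;
        for (int d = 0; d < points[j].Length; d++)
        {
            double diff = points[j][d] - centroids[(i + 1) % k][d];
            sum += diff * diff;
        }
        cost[i][j] = sum;
    }
}

munkres.Minimize(); // solve the assignment problem

// Solution[i] is the point assigned to slot i; that point gets the cluster
// of slot i, so the two clusters end up with exactly three points each.
int[] labels = new int[n];
for (int i = 0; i < n; i++)
{
    int j = (int)munkres.Solution[i];
    if (j >= 0)
        labels[j] = (i + 1) % k;
}
```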
@@ -229,11 +240,9 @@ public override KMeansClusterCollection Learn(double[][] x, double[] weights = n
                 shouldStop = converged(centroids, newCentroids);
 
                 // go to next generation
-                Parallel.For(0, centroids.Length, ParallelOptions, i =>
-                {
+                for (int i = 0; i < centroids.Length; i++)
                     for (int j = 0; j < centroids[i].Length; j++)
                         centroids[i][j] = newCentroids[i][j];
-                });
             }
 
             for (int i = 0; i < Clusters.Centroids.Length; i++)
@@ -249,5 +258,32 @@ public override KMeansClusterCollection Learn(double[][] x, double[] weights = n
             return Clusters;
         }
 
+        internal static void GetLabels(double[][] points, int clusters, double[] solution, int[] labels)
+        {
+            // Start with every point unassigned; points that do not receive
+            // a slot in the assignment keep the -1 label.
+            for (int i = 0; i < labels.Length; i++)
+                labels[i] = -1;
+
+            for (int i = 0; i < points.Length; i++)
+            {
+                // solution[i] is the data point assigned to slot i; that point
+                // gets the cluster associated with slot i.
+                int j = (int)solution[i];
+                if (j >= 0)
+                    labels[j] = GetIndex(clusters, i);
+            }
+        }
+
+        internal static double[][] GetDistances(IDistance<double[], double[]> distance, double[][] points, double[][] centroids, int k, double[][] result)
+        {
+            for (int i = 0; i < result.Length; i++)
+                for (int j = 0; j < result[i].Length; j++)
+                    result[i][j] = distance.Distance(points[j], centroids[GetIndex(k, i)]);
+            return result;
+        }
+
+        private static int GetIndex(int clusters, int index)
+        {
+            // Equation 6.6 uses ((a mod k) + 1) instead of (a mod k):
+            // http://cs.uef.fi/sipu/pub/PhD_Thesis_Mikko_Malinen.pdf
+
+            return (index + 1) % clusters;
+        }
     }
 }
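Finally, a quick illustration (not from the source) of how the `GetIndex` mapping at the end of the diff spreads the N assignment slots evenly over the k clusters; the numbers are just an example.

```csharp
// With n = 6 slots and k = 3 clusters, (index + 1) % k maps slots 0..5 to
// clusters 1, 2, 0, 1, 2, 0, so every cluster gets exactly n / k = 2 slots.
int n = 6, k = 3;
for (int index = 0; index < n; index++)
    System.Console.WriteLine("slot {0} -> cluster {1}", index, (index + 1) % k);
```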