@@ -94,15 +94,19 @@ function MMI.fit(m::KMeans, verbosity::Int, X)
9494
9595 # fit model and get results
9696 verbose = verbosity > 0 # Display fitting operations if verbosity > 0
97- fitresult = ParallelKMeans. kmeans (algo, DMatrix, m. k;
97+ result = ParallelKMeans. kmeans (algo, DMatrix, m. k;
9898 n_threads = m. threads, k_init= m. k_init,
9999 max_iters= m. max_iters, tol= m. tol, init= m. init,
100100 verbose= verbose)
101101
102+ cluster_labels = MMI. categorical (1 : m. k)
103+ fitresult = (centers = result. centers, labels = cluster_labels, converged = result. converged)
102104 cache = nothing
103- report = (cluster_centers= fitresult. centers, iterations= fitresult. iterations,
104- converged= fitresult. converged, totalcost= fitresult. totalcost,
105- labels= fitresult. assignments)
105+
106+ report = (cluster_centers= result. centers, iterations= result. iterations,
107+ totalcost= result. totalcost, assignments= result. assignments, labels= cluster_labels)
108+
109+
106110 """
107111 # TODO: warn users about non convergence
108112 if verbose & (!fitresult.converged)
114118
115119
# Learned parameters of a fitted `KMeans` model.
# Returns a one-field NamedTuple whose `cluster_centers` entry is the `centers`
# field stored in `fitresult`.
MMI.fitted_params(model::KMeans, fitresult) = (cluster_centers = fitresult.centers,)
128124
129125
132128# ###
133129
134130function MMI. transform (m:: KMeans , fitresult, Xnew)
135- # make predictions/assignments using the learned centroids
131+ # transform new data using the fitted centroids.
136132
137133 if ! m. copy
138134 # permutes dimensions of input table without copying and pass to model
@@ -143,21 +139,36 @@ function MMI.transform(m::KMeans, fitresult, Xnew)
143139 end
144140
145141 # Warn users if fitresult is from a `non-converged` fit
146- if ! fitresult[ end ] . converged
142+ if ! fitresult. converged
147143 @warn " Failed to converge. Using last assignments to make transformations."
148144 end
149145
150- # results from fitted model
151- results = fitresult[1 ]
152-
153146 # squared Euclidean distances between the columns of DMatrix and each fitted centroid
154- centroids = results. centers
155- distances = Distances. pairwise (Distances. SqEuclidean (), DMatrix, centroids; dims= 2 )
156- preds = argmin .(eachrow (distances))
157- return MMI. table (reshape (preds, :, 1 ), prototype= Xnew)
147+ distances = Distances. pairwise (Distances. SqEuclidean (), DMatrix, fitresult. centers; dims= 2 )
148+ # preds = argmin.(eachrow(distances))
149+ return MMI. table (distances, prototype= Xnew)
158150end
159151

"""
    MMI.predict(m::KMeans, fitresult, Xnew)

Assign each row of `Xnew` to its nearest fitted centroid and return the matching
cluster labels.

# Arguments
- `m`: the `KMeans` model. It is not consulted; the number of clusters is taken from
  the fitted centroids so that changing `m.k` after fitting cannot desynchronise the loop.
- `fitresult`: object with a `centers` field (one centroid per column) and a `labels`
  field indexable by cluster number.
- `Xnew`: data convertible with `MMI.matrix`, one observation per row.

# Returns
`fitresult.labels` indexed by the nearest-centroid index of every row. Ties go to the
lowest-numbered centroid.

# Throws
- `DimensionMismatch` if the number of columns of `Xnew` differs from the number of
  rows of the centroid matrix.
- `ArgumentError` if some row has no centroid at a finite distance (for example because
  it contains `NaN`).
"""
function MMI.predict(m::KMeans, fitresult, Xnew)
    centroids = fitresult.centers
    cluster_labels = fitresult.labels

    Xarray = MMI.matrix(Xnew)
    n, p = size(Xarray)
    if p != size(centroids, 1)
        throw(DimensionMismatch(
            "Xnew has $p features but the centroids have $(size(centroids, 1))"))
    end

    # Squared distances order centroids exactly like Euclidean ones, without the sqrt.
    metric = Distances.SqEuclidean()
    pred = zeros(Int, n)

    # Loop ranges come from `axes` of the arrays being indexed, so `@inbounds` is safe.
    @inbounds for i in axes(Xarray, 1)
        point = view(Xarray, i, :)
        best = Inf
        for j in axes(centroids, 2)
            d = Distances.evaluate(metric, point, view(centroids, :, j))
            # Strict `<` keeps the first (lowest-index) centroid on ties.
            if d < best
                best = d
                pred[i] = j
            end
        end
    end

    # A zero entry means no comparison succeeded (NaN/Inf distances or no centroids).
    unassigned = findfirst(iszero, pred)
    if unassigned !== nothing
        throw(ArgumentError(
            "row $unassigned of Xnew has no centroid at a finite distance"))
    end

    return cluster_labels[pred]
end
171+
161172# ###
162173# ### METADATA
163174# ###
@@ -176,6 +187,7 @@ MMI.metadata_pkg.(KMeans,
# Model-level metadata for `KMeans`: table-of-continuous input and output scitypes,
# a multiclass target scitype, no sample-weight support, plus description and load path.
MMI.metadata_model(
    KMeans,
    input = MMI.Table(MMI.Continuous),
    output = MMI.Table(MMI.Continuous),
    target = AbstractArray{<:MMI.Multiclass},
    weights = false,
    descr = ParallelKMeans_Desc,
    path = "ParallelKMeans.KMeans",
)
0 commit comments