blob: eb61866b09ac8b4061c2818f823cd18e4987ec01 [file] [log] [blame]
Raymonddee08492015-04-02 10:43:13 -07001/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.commons.math.stat.clustering;
19
20import java.util.ArrayList;
21import java.util.Collection;
22import java.util.List;
23import java.util.Random;
24
25import org.apache.commons.math.exception.ConvergenceException;
26import org.apache.commons.math.exception.util.LocalizedFormats;
27import org.apache.commons.math.stat.descriptive.moment.Variance;
28
29/**
30 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
31 * @param <T> type of the points to cluster
32 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
33 * @version $Revision: 1054333 $ $Date: 2011-01-02 01:34:58 +0100 (dim. 02 janv. 2011) $
34 * @since 2.0
35 */
36public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
37
38 /** Strategies to use for replacing an empty cluster. */
39 public static enum EmptyClusterStrategy {
40
41 /** Split the cluster with largest distance variance. */
42 LARGEST_VARIANCE,
43
44 /** Split the cluster with largest number of points. */
45 LARGEST_POINTS_NUMBER,
46
47 /** Create a cluster around the point farthest from its centroid. */
48 FARTHEST_POINT,
49
50 /** Generate an error. */
51 ERROR
52
53 }
54
55 /** Random generator for choosing initial centers. */
56 private final Random random;
57
58 /** Selected strategy for empty clusters. */
59 private final EmptyClusterStrategy emptyStrategy;
60
61 /** Build a clusterer.
62 * <p>
63 * The default strategy for handling empty clusters that may appear during
64 * algorithm iterations is to split the cluster with largest distance variance.
65 * </p>
66 * @param random random generator to use for choosing initial centers
67 */
68 public KMeansPlusPlusClusterer(final Random random) {
69 this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
70 }
71
72 /** Build a clusterer.
73 * @param random random generator to use for choosing initial centers
74 * @param emptyStrategy strategy to use for handling empty clusters that
75 * may appear during algorithm iterations
76 * @since 2.2
77 */
78 public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
79 this.random = random;
80 this.emptyStrategy = emptyStrategy;
81 }
82
83 /**
84 * Runs the K-means++ clustering algorithm.
85 *
86 * @param points the points to cluster
87 * @param k the number of clusters to split the data into
88 * @param maxIterations the maximum number of iterations to run the algorithm
89 * for. If negative, no maximum will be used
90 * @return a list of clusters containing the points
91 */
92 public List<Cluster<T>> cluster(final Collection<T> points,
93 final int k, final int maxIterations) {
94 // create the initial clusters
95 List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
96 assignPointsToClusters(clusters, points);
97
98 // iterate through updating the centers until we're done
99 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
100 for (int count = 0; count < max; count++) {
101 boolean clusteringChanged = false;
102 List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
103 for (final Cluster<T> cluster : clusters) {
104 final T newCenter;
105 if (cluster.getPoints().isEmpty()) {
106 switch (emptyStrategy) {
107 case LARGEST_VARIANCE :
108 newCenter = getPointFromLargestVarianceCluster(clusters);
109 break;
110 case LARGEST_POINTS_NUMBER :
111 newCenter = getPointFromLargestNumberCluster(clusters);
112 break;
113 case FARTHEST_POINT :
114 newCenter = getFarthestPoint(clusters);
115 break;
116 default :
117 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
118 }
119 clusteringChanged = true;
120 } else {
121 newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
122 if (!newCenter.equals(cluster.getCenter())) {
123 clusteringChanged = true;
124 }
125 }
126 newClusters.add(new Cluster<T>(newCenter));
127 }
128 if (!clusteringChanged) {
129 return clusters;
130 }
131 assignPointsToClusters(newClusters, points);
132 clusters = newClusters;
133 }
134 return clusters;
135 }
136
137 /**
138 * Adds the given points to the closest {@link Cluster}.
139 *
140 * @param <T> type of the points to cluster
141 * @param clusters the {@link Cluster}s to add the points to
142 * @param points the points to add to the given {@link Cluster}s
143 */
144 private static <T extends Clusterable<T>> void
145 assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
146 for (final T p : points) {
147 Cluster<T> cluster = getNearestCluster(clusters, p);
148 cluster.addPoint(p);
149 }
150 }
151
152 /**
153 * Use K-means++ to choose the initial centers.
154 *
155 * @param <T> type of the points to cluster
156 * @param points the points to choose the initial centers from
157 * @param k the number of centers to choose
158 * @param random random generator to use
159 * @return the initial centers
160 */
161 private static <T extends Clusterable<T>> List<Cluster<T>>
162 chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
163
164 final List<T> pointSet = new ArrayList<T>(points);
165 final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
166
167 // Choose one center uniformly at random from among the data points.
168 final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
169 resultSet.add(new Cluster<T>(firstPoint));
170
171 final double[] dx2 = new double[pointSet.size()];
172 while (resultSet.size() < k) {
173 // For each data point x, compute D(x), the distance between x and
174 // the nearest center that has already been chosen.
175 int sum = 0;
176 for (int i = 0; i < pointSet.size(); i++) {
177 final T p = pointSet.get(i);
178 final Cluster<T> nearest = getNearestCluster(resultSet, p);
179 final double d = p.distanceFrom(nearest.getCenter());
180 sum += d * d;
181 dx2[i] = sum;
182 }
183
184 // Add one new data point as a center. Each point x is chosen with
185 // probability proportional to D(x)2
186 final double r = random.nextDouble() * sum;
187 for (int i = 0 ; i < dx2.length; i++) {
188 if (dx2[i] >= r) {
189 final T p = pointSet.remove(i);
190 resultSet.add(new Cluster<T>(p));
191 break;
192 }
193 }
194 }
195
196 return resultSet;
197
198 }
199
200 /**
201 * Get a random point from the {@link Cluster} with the largest distance variance.
202 *
203 * @param clusters the {@link Cluster}s to search
204 * @return a random point from the selected cluster
205 */
206 private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters) {
207
208 double maxVariance = Double.NEGATIVE_INFINITY;
209 Cluster<T> selected = null;
210 for (final Cluster<T> cluster : clusters) {
211 if (!cluster.getPoints().isEmpty()) {
212
213 // compute the distance variance of the current cluster
214 final T center = cluster.getCenter();
215 final Variance stat = new Variance();
216 for (final T point : cluster.getPoints()) {
217 stat.increment(point.distanceFrom(center));
218 }
219 final double variance = stat.getResult();
220
221 // select the cluster with the largest variance
222 if (variance > maxVariance) {
223 maxVariance = variance;
224 selected = cluster;
225 }
226
227 }
228 }
229
230 // did we find at least one non-empty cluster ?
231 if (selected == null) {
232 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
233 }
234
235 // extract a random point from the cluster
236 final List<T> selectedPoints = selected.getPoints();
237 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
238
239 }
240
241 /**
242 * Get a random point from the {@link Cluster} with the largest number of points
243 *
244 * @param clusters the {@link Cluster}s to search
245 * @return a random point from the selected cluster
246 */
247 private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) {
248
249 int maxNumber = 0;
250 Cluster<T> selected = null;
251 for (final Cluster<T> cluster : clusters) {
252
253 // get the number of points of the current cluster
254 final int number = cluster.getPoints().size();
255
256 // select the cluster with the largest number of points
257 if (number > maxNumber) {
258 maxNumber = number;
259 selected = cluster;
260 }
261
262 }
263
264 // did we find at least one non-empty cluster ?
265 if (selected == null) {
266 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
267 }
268
269 // extract a random point from the cluster
270 final List<T> selectedPoints = selected.getPoints();
271 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
272
273 }
274
275 /**
276 * Get the point farthest to its cluster center
277 *
278 * @param clusters the {@link Cluster}s to search
279 * @return point farthest to its cluster center
280 */
281 private T getFarthestPoint(final Collection<Cluster<T>> clusters) {
282
283 double maxDistance = Double.NEGATIVE_INFINITY;
284 Cluster<T> selectedCluster = null;
285 int selectedPoint = -1;
286 for (final Cluster<T> cluster : clusters) {
287
288 // get the farthest point
289 final T center = cluster.getCenter();
290 final List<T> points = cluster.getPoints();
291 for (int i = 0; i < points.size(); ++i) {
292 final double distance = points.get(i).distanceFrom(center);
293 if (distance > maxDistance) {
294 maxDistance = distance;
295 selectedCluster = cluster;
296 selectedPoint = i;
297 }
298 }
299
300 }
301
302 // did we find at least one non-empty cluster ?
303 if (selectedCluster == null) {
304 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
305 }
306
307 return selectedCluster.getPoints().remove(selectedPoint);
308
309 }
310
311 /**
312 * Returns the nearest {@link Cluster} to the given point
313 *
314 * @param <T> type of the points to cluster
315 * @param clusters the {@link Cluster}s to search
316 * @param point the point to find the nearest {@link Cluster} for
317 * @return the nearest {@link Cluster} to the given point
318 */
319 private static <T extends Clusterable<T>> Cluster<T>
320 getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
321 double minDistance = Double.MAX_VALUE;
322 Cluster<T> minCluster = null;
323 for (final Cluster<T> c : clusters) {
324 final double distance = point.distanceFrom(c.getCenter());
325 if (distance < minDistance) {
326 minDistance = distance;
327 minCluster = c;
328 }
329 }
330 return minCluster;
331 }
332
333}