Raymond | dee0849 | 2015-04-02 10:43:13 -0700 | [diff] [blame] | 1 | /* |
| 2 | * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | * contributor license agreements. See the NOTICE file distributed with |
| 4 | * this work for additional information regarding copyright ownership. |
| 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | * (the "License"); you may not use this file except in compliance with |
| 7 | * the License. You may obtain a copy of the License at |
| 8 | * |
| 9 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | * |
| 11 | * Unless required by applicable law or agreed to in writing, software |
| 12 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | * See the License for the specific language governing permissions and |
| 15 | * limitations under the License. |
| 16 | */ |
| 17 | |
| 18 | package org.apache.commons.math.stat.clustering; |
| 19 | |
| 20 | import java.util.ArrayList; |
| 21 | import java.util.Collection; |
| 22 | import java.util.List; |
| 23 | import java.util.Random; |
| 24 | |
| 25 | import org.apache.commons.math.exception.ConvergenceException; |
| 26 | import org.apache.commons.math.exception.util.LocalizedFormats; |
| 27 | import org.apache.commons.math.stat.descriptive.moment.Variance; |
| 28 | |
| 29 | /** |
| 30 | * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm. |
| 31 | * @param <T> type of the points to cluster |
| 32 | * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a> |
| 33 | * @version $Revision: 1054333 $ $Date: 2011-01-02 01:34:58 +0100 (dim. 02 janv. 2011) $ |
| 34 | * @since 2.0 |
| 35 | */ |
| 36 | public class KMeansPlusPlusClusterer<T extends Clusterable<T>> { |
| 37 | |
| 38 | /** Strategies to use for replacing an empty cluster. */ |
| 39 | public static enum EmptyClusterStrategy { |
| 40 | |
| 41 | /** Split the cluster with largest distance variance. */ |
| 42 | LARGEST_VARIANCE, |
| 43 | |
| 44 | /** Split the cluster with largest number of points. */ |
| 45 | LARGEST_POINTS_NUMBER, |
| 46 | |
| 47 | /** Create a cluster around the point farthest from its centroid. */ |
| 48 | FARTHEST_POINT, |
| 49 | |
| 50 | /** Generate an error. */ |
| 51 | ERROR |
| 52 | |
| 53 | } |
| 54 | |
| 55 | /** Random generator for choosing initial centers. */ |
| 56 | private final Random random; |
| 57 | |
| 58 | /** Selected strategy for empty clusters. */ |
| 59 | private final EmptyClusterStrategy emptyStrategy; |
| 60 | |
| 61 | /** Build a clusterer. |
| 62 | * <p> |
| 63 | * The default strategy for handling empty clusters that may appear during |
| 64 | * algorithm iterations is to split the cluster with largest distance variance. |
| 65 | * </p> |
| 66 | * @param random random generator to use for choosing initial centers |
| 67 | */ |
| 68 | public KMeansPlusPlusClusterer(final Random random) { |
| 69 | this(random, EmptyClusterStrategy.LARGEST_VARIANCE); |
| 70 | } |
| 71 | |
| 72 | /** Build a clusterer. |
| 73 | * @param random random generator to use for choosing initial centers |
| 74 | * @param emptyStrategy strategy to use for handling empty clusters that |
| 75 | * may appear during algorithm iterations |
| 76 | * @since 2.2 |
| 77 | */ |
| 78 | public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) { |
| 79 | this.random = random; |
| 80 | this.emptyStrategy = emptyStrategy; |
| 81 | } |
| 82 | |
| 83 | /** |
| 84 | * Runs the K-means++ clustering algorithm. |
| 85 | * |
| 86 | * @param points the points to cluster |
| 87 | * @param k the number of clusters to split the data into |
| 88 | * @param maxIterations the maximum number of iterations to run the algorithm |
| 89 | * for. If negative, no maximum will be used |
| 90 | * @return a list of clusters containing the points |
| 91 | */ |
| 92 | public List<Cluster<T>> cluster(final Collection<T> points, |
| 93 | final int k, final int maxIterations) { |
| 94 | // create the initial clusters |
| 95 | List<Cluster<T>> clusters = chooseInitialCenters(points, k, random); |
| 96 | assignPointsToClusters(clusters, points); |
| 97 | |
| 98 | // iterate through updating the centers until we're done |
| 99 | final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; |
| 100 | for (int count = 0; count < max; count++) { |
| 101 | boolean clusteringChanged = false; |
| 102 | List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>(); |
| 103 | for (final Cluster<T> cluster : clusters) { |
| 104 | final T newCenter; |
| 105 | if (cluster.getPoints().isEmpty()) { |
| 106 | switch (emptyStrategy) { |
| 107 | case LARGEST_VARIANCE : |
| 108 | newCenter = getPointFromLargestVarianceCluster(clusters); |
| 109 | break; |
| 110 | case LARGEST_POINTS_NUMBER : |
| 111 | newCenter = getPointFromLargestNumberCluster(clusters); |
| 112 | break; |
| 113 | case FARTHEST_POINT : |
| 114 | newCenter = getFarthestPoint(clusters); |
| 115 | break; |
| 116 | default : |
| 117 | throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); |
| 118 | } |
| 119 | clusteringChanged = true; |
| 120 | } else { |
| 121 | newCenter = cluster.getCenter().centroidOf(cluster.getPoints()); |
| 122 | if (!newCenter.equals(cluster.getCenter())) { |
| 123 | clusteringChanged = true; |
| 124 | } |
| 125 | } |
| 126 | newClusters.add(new Cluster<T>(newCenter)); |
| 127 | } |
| 128 | if (!clusteringChanged) { |
| 129 | return clusters; |
| 130 | } |
| 131 | assignPointsToClusters(newClusters, points); |
| 132 | clusters = newClusters; |
| 133 | } |
| 134 | return clusters; |
| 135 | } |
| 136 | |
| 137 | /** |
| 138 | * Adds the given points to the closest {@link Cluster}. |
| 139 | * |
| 140 | * @param <T> type of the points to cluster |
| 141 | * @param clusters the {@link Cluster}s to add the points to |
| 142 | * @param points the points to add to the given {@link Cluster}s |
| 143 | */ |
| 144 | private static <T extends Clusterable<T>> void |
| 145 | assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) { |
| 146 | for (final T p : points) { |
| 147 | Cluster<T> cluster = getNearestCluster(clusters, p); |
| 148 | cluster.addPoint(p); |
| 149 | } |
| 150 | } |
| 151 | |
| 152 | /** |
| 153 | * Use K-means++ to choose the initial centers. |
| 154 | * |
| 155 | * @param <T> type of the points to cluster |
| 156 | * @param points the points to choose the initial centers from |
| 157 | * @param k the number of centers to choose |
| 158 | * @param random random generator to use |
| 159 | * @return the initial centers |
| 160 | */ |
| 161 | private static <T extends Clusterable<T>> List<Cluster<T>> |
| 162 | chooseInitialCenters(final Collection<T> points, final int k, final Random random) { |
| 163 | |
| 164 | final List<T> pointSet = new ArrayList<T>(points); |
| 165 | final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>(); |
| 166 | |
| 167 | // Choose one center uniformly at random from among the data points. |
| 168 | final T firstPoint = pointSet.remove(random.nextInt(pointSet.size())); |
| 169 | resultSet.add(new Cluster<T>(firstPoint)); |
| 170 | |
| 171 | final double[] dx2 = new double[pointSet.size()]; |
| 172 | while (resultSet.size() < k) { |
| 173 | // For each data point x, compute D(x), the distance between x and |
| 174 | // the nearest center that has already been chosen. |
| 175 | int sum = 0; |
| 176 | for (int i = 0; i < pointSet.size(); i++) { |
| 177 | final T p = pointSet.get(i); |
| 178 | final Cluster<T> nearest = getNearestCluster(resultSet, p); |
| 179 | final double d = p.distanceFrom(nearest.getCenter()); |
| 180 | sum += d * d; |
| 181 | dx2[i] = sum; |
| 182 | } |
| 183 | |
| 184 | // Add one new data point as a center. Each point x is chosen with |
| 185 | // probability proportional to D(x)2 |
| 186 | final double r = random.nextDouble() * sum; |
| 187 | for (int i = 0 ; i < dx2.length; i++) { |
| 188 | if (dx2[i] >= r) { |
| 189 | final T p = pointSet.remove(i); |
| 190 | resultSet.add(new Cluster<T>(p)); |
| 191 | break; |
| 192 | } |
| 193 | } |
| 194 | } |
| 195 | |
| 196 | return resultSet; |
| 197 | |
| 198 | } |
| 199 | |
| 200 | /** |
| 201 | * Get a random point from the {@link Cluster} with the largest distance variance. |
| 202 | * |
| 203 | * @param clusters the {@link Cluster}s to search |
| 204 | * @return a random point from the selected cluster |
| 205 | */ |
| 206 | private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters) { |
| 207 | |
| 208 | double maxVariance = Double.NEGATIVE_INFINITY; |
| 209 | Cluster<T> selected = null; |
| 210 | for (final Cluster<T> cluster : clusters) { |
| 211 | if (!cluster.getPoints().isEmpty()) { |
| 212 | |
| 213 | // compute the distance variance of the current cluster |
| 214 | final T center = cluster.getCenter(); |
| 215 | final Variance stat = new Variance(); |
| 216 | for (final T point : cluster.getPoints()) { |
| 217 | stat.increment(point.distanceFrom(center)); |
| 218 | } |
| 219 | final double variance = stat.getResult(); |
| 220 | |
| 221 | // select the cluster with the largest variance |
| 222 | if (variance > maxVariance) { |
| 223 | maxVariance = variance; |
| 224 | selected = cluster; |
| 225 | } |
| 226 | |
| 227 | } |
| 228 | } |
| 229 | |
| 230 | // did we find at least one non-empty cluster ? |
| 231 | if (selected == null) { |
| 232 | throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); |
| 233 | } |
| 234 | |
| 235 | // extract a random point from the cluster |
| 236 | final List<T> selectedPoints = selected.getPoints(); |
| 237 | return selectedPoints.remove(random.nextInt(selectedPoints.size())); |
| 238 | |
| 239 | } |
| 240 | |
| 241 | /** |
| 242 | * Get a random point from the {@link Cluster} with the largest number of points |
| 243 | * |
| 244 | * @param clusters the {@link Cluster}s to search |
| 245 | * @return a random point from the selected cluster |
| 246 | */ |
| 247 | private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) { |
| 248 | |
| 249 | int maxNumber = 0; |
| 250 | Cluster<T> selected = null; |
| 251 | for (final Cluster<T> cluster : clusters) { |
| 252 | |
| 253 | // get the number of points of the current cluster |
| 254 | final int number = cluster.getPoints().size(); |
| 255 | |
| 256 | // select the cluster with the largest number of points |
| 257 | if (number > maxNumber) { |
| 258 | maxNumber = number; |
| 259 | selected = cluster; |
| 260 | } |
| 261 | |
| 262 | } |
| 263 | |
| 264 | // did we find at least one non-empty cluster ? |
| 265 | if (selected == null) { |
| 266 | throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); |
| 267 | } |
| 268 | |
| 269 | // extract a random point from the cluster |
| 270 | final List<T> selectedPoints = selected.getPoints(); |
| 271 | return selectedPoints.remove(random.nextInt(selectedPoints.size())); |
| 272 | |
| 273 | } |
| 274 | |
| 275 | /** |
| 276 | * Get the point farthest to its cluster center |
| 277 | * |
| 278 | * @param clusters the {@link Cluster}s to search |
| 279 | * @return point farthest to its cluster center |
| 280 | */ |
| 281 | private T getFarthestPoint(final Collection<Cluster<T>> clusters) { |
| 282 | |
| 283 | double maxDistance = Double.NEGATIVE_INFINITY; |
| 284 | Cluster<T> selectedCluster = null; |
| 285 | int selectedPoint = -1; |
| 286 | for (final Cluster<T> cluster : clusters) { |
| 287 | |
| 288 | // get the farthest point |
| 289 | final T center = cluster.getCenter(); |
| 290 | final List<T> points = cluster.getPoints(); |
| 291 | for (int i = 0; i < points.size(); ++i) { |
| 292 | final double distance = points.get(i).distanceFrom(center); |
| 293 | if (distance > maxDistance) { |
| 294 | maxDistance = distance; |
| 295 | selectedCluster = cluster; |
| 296 | selectedPoint = i; |
| 297 | } |
| 298 | } |
| 299 | |
| 300 | } |
| 301 | |
| 302 | // did we find at least one non-empty cluster ? |
| 303 | if (selectedCluster == null) { |
| 304 | throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); |
| 305 | } |
| 306 | |
| 307 | return selectedCluster.getPoints().remove(selectedPoint); |
| 308 | |
| 309 | } |
| 310 | |
| 311 | /** |
| 312 | * Returns the nearest {@link Cluster} to the given point |
| 313 | * |
| 314 | * @param <T> type of the points to cluster |
| 315 | * @param clusters the {@link Cluster}s to search |
| 316 | * @param point the point to find the nearest {@link Cluster} for |
| 317 | * @return the nearest {@link Cluster} to the given point |
| 318 | */ |
| 319 | private static <T extends Clusterable<T>> Cluster<T> |
| 320 | getNearestCluster(final Collection<Cluster<T>> clusters, final T point) { |
| 321 | double minDistance = Double.MAX_VALUE; |
| 322 | Cluster<T> minCluster = null; |
| 323 | for (final Cluster<T> c : clusters) { |
| 324 | final double distance = point.distanceFrom(c.getCenter()); |
| 325 | if (distance < minDistance) { |
| 326 | minDistance = distance; |
| 327 | minCluster = c; |
| 328 | } |
| 329 | } |
| 330 | return minCluster; |
| 331 | } |
| 332 | |
| 333 | } |