// blob: a01af54e3bbf8a8260a898233f3c5ad39b60f17c [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.textclassifier.ulp.kmeans;
import android.util.Log;
import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
/**
 * Simple k-means clustering, modeled on the implementation found in the Android internal ML
 * library.
 *
 * <p>Initial centroids are filled component-wise from {@link Random#nextFloat()}. NOTE(review):
 * that presumably means callers are expected to pass data scaled to roughly {@code [0, 1)} —
 * confirm against the call sites.
 */
public class KMeans {
  private static final boolean DEBUG = false;
  private static final String TAG = "KMeans";

  /** Source of randomness for the initial centroids. */
  private final Random randomState;

  /** Upper bound on the number of {@link #step} calls made by a single {@link #predict}. */
  private final int maxIterations;

  /** Convergence threshold, stored squared so it can be compared with squared distances. */
  private final float sqConvergenceEpsilon;

  /** Creates an instance backed by a new {@link Random} and the default limits. */
  public KMeans() {
    this(new Random());
  }

  /**
   * Creates an instance with the default limits: at most 30 iterations and a convergence epsilon
   * of 0.005.
   *
   * @param random Source of randomness for the initial centroids.
   */
  public KMeans(Random random) {
    this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */);
  }

  /**
   * Creates a fully configured instance.
   *
   * @param random Source of randomness for the initial centroids.
   * @param maxIterations Maximum number of iterations a single {@link #predict} call may run.
   * @param convergenceEpsilon Largest distance a centroid may move during one iteration while the
   *     run is still considered converged.
   */
  public KMeans(Random random, int maxIterations, float convergenceEpsilon) {
    randomState = random;
    this.maxIterations = maxIterations;
    sqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon;
  }

  /**
   * Runs k-means on the input data (X) trying to find k means.
   *
   * <p>K-Means is known for getting stuck in local optima, so you might want to run it multiple
   * times and keep the result with the highest {@link KMeans#score(List)}.
   *
   * <p>A mean that ends an iteration with no points assigned to it keeps its previous centroid.
   * The items attached to each returned mean are the ones assigned during the last iteration
   * that ran; if no iteration runs (non-positive {@code maxIterations}) they are empty.
   *
   * @param k The number of means (clusters) to compute; must be positive.
   * @param inputData Input data: one row per point, all rows of the same length.
   * @return A list of k means, each holding a centroid and the data points that belong to it.
   * @throws IllegalArgumentException if {@code inputData} fails {@link #checkDataSetSanity} or
   *     {@code k} is not positive.
   */
  public List<Mean> predict(final int k, final float[][] inputData) {
    checkDataSetSanity(inputData);
    if (k <= 0) {
      // With no means there is nothing for nearestMean() to return, and the first step() would
      // fail with a NullPointerException. Reject the request up front instead.
      throw new IllegalArgumentException("Number of means must be positive: " + k);
    }

    final int dimension = inputData[0].length;
    final ArrayList<Mean> means = new ArrayList<>(k);
    for (int i = 0; i < k; i++) {
      final Mean mean = new Mean(dimension);
      for (int j = 0; j < dimension; j++) {
        mean.centroid[j] = randomState.nextFloat();
      }
      means.add(mean);
    }

    // Iterate until we converge or run out of iterations.
    boolean converged = false;
    for (int i = 0; i < maxIterations; i++) {
      converged = step(means, inputData);
      if (converged) {
        if (DEBUG) {
          Log.d(TAG, "Converged at iteration: " + i);
        }
        break;
      }
    }
    if (!converged && DEBUG) {
      Log.d(TAG, "Did not converge");
    }
    return means;
  }

  /**
   * Scores a set of means by how far apart their centroids are.
   *
   * <p>The score is the sum of the Euclidean distances between the centroids of every ordered
   * pair of distinct {@link Mean} objects, so each unordered pair contributes twice. Entries are
   * compared by identity: the same object appearing twice in the list is not paired with itself.
   *
   * @param means Means to use when calculating the score.
   * @return The score; 0 when fewer than two means are given.
   */
  public static double score(List<Mean> means) {
    double score = 0;
    final int meansSize = means.size();
    for (int i = 0; i < meansSize; i++) {
      final Mean mean = means.get(i);
      for (int j = 0; j < meansSize; j++) {
        final Mean other = means.get(j);
        if (mean != other) {
          score += Math.sqrt(sqDistance(mean.centroid, other.centroid));
        }
      }
    }
    return score;
  }

  /**
   * Verifies that the data set is a non-empty, rectangular matrix.
   *
   * @param inputData Data set to validate.
   * @throws IllegalArgumentException if the data set is null or empty, if any row is null, or if
   *     the rows do not all have the same length.
   */
  @VisibleForTesting
  public void checkDataSetSanity(float[][] inputData) {
    if (inputData == null) {
      throw new IllegalArgumentException("Data set is null.");
    }
    if (inputData.length == 0) {
      throw new IllegalArgumentException("Data set is empty.");
    }
    if (inputData[0] == null) {
      throw new IllegalArgumentException("Bad data set format.");
    }

    // The first row defines the dimension every other row has to match.
    final int dimension = inputData[0].length;
    for (int i = 1; i < inputData.length; i++) {
      final float[] row = inputData[i];
      if (row == null || row.length != dimension) {
        throw new IllegalArgumentException("Bad data set format.");
      }
    }
  }

  /**
   * K-Means iteration: reassigns every point to its nearest mean, then moves each mean to the
   * average of the points assigned to it.
   *
   * @param means Current means; their centroids and items are updated in place.
   * @param inputData Input data
   * @return True if no centroid moved by more than the convergence epsilon.
   */
  private boolean step(final List<Mean> means, final float[][] inputData) {
    // Clean up the previous state because we need to compute
    // which point belongs to each mean again.
    for (Mean mean : means) {
      mean.closestItems.clear();
    }

    // Points are visited last to first, so each mean lists its items in reverse input order.
    for (int i = inputData.length - 1; i >= 0; i--) {
      final float[] point = inputData[i];
      nearestMean(point, means).closestItems.add(point);
    }

    boolean converged = true;
    // Move each mean towards the points that were assigned to it.
    for (Mean mean : means) {
      if (mean.closestItems.isEmpty()) {
        // Nothing to average: the mean stays where it is.
        continue;
      }
      final float[] oldCentroid = mean.centroid;
      mean.centroid = averageOf(mean.closestItems, oldCentroid.length);
      // We converged if the centroid didn't move for any of the means.
      if (sqDistance(oldCentroid, mean.centroid) > sqConvergenceEpsilon) {
        converged = false;
      }
    }
    return converged;
  }

  /** Returns the component-wise average of {@code points}, which must not be empty. */
  private static float[] averageOf(List<float[]> points, int dimension) {
    final float[] average = new float[dimension];
    final int count = points.size();
    // 1. Sum all points, component by component.
    for (int j = 0; j < count; j++) {
      final float[] point = points.get(j);
      for (int p = 0; p < dimension; p++) {
        average[p] += point[p];
      }
    }
    // 2. Divide by the number of points.
    for (int p = 0; p < dimension; p++) {
      average[p] /= count;
    }
    return average;
  }

  /**
   * Finds the mean whose centroid is closest to {@code point}.
   *
   * <p>When several means are equally close, the one that comes first in the list wins. If no
   * squared distance compares below {@link Float#MAX_VALUE} (for example because the float
   * arithmetic overflowed to infinity or produced NaN), the first mean in the list is returned
   * so that a non-empty list always yields a result.
   *
   * @param point Point to look up.
   * @param means Candidate means.
   * @return The nearest mean, or {@code null} only if {@code means} is empty.
   */
  @VisibleForTesting
  public static Mean nearestMean(float[] point, List<Mean> means) {
    Mean nearest = null;
    float nearestDistance = Float.MAX_VALUE;
    final int meanCount = means.size();
    for (int i = 0; i < meanCount; i++) {
      final Mean candidate = means.get(i);
      // We don't need the sqrt when comparing distances in euclidean space
      // because they exist on both sides of the equation and cancel each other out.
      final float candidateDistance = sqDistance(point, candidate.centroid);
      if (candidateDistance < nearestDistance) {
        nearest = candidate;
        nearestDistance = candidateDistance;
      }
    }
    if (nearest == null && meanCount > 0) {
      // Every comparison above failed (infinite, NaN or MAX_VALUE distances). Callers such as
      // step() dereference the result, so fall back to the first mean instead of null.
      nearest = means.get(0);
    }
    return nearest;
  }

  /**
   * Computes the squared Euclidean distance between two points.
   *
   * <p>Only the first {@code a.length} components are read, so {@code b} must be at least as long
   * as {@code a}.
   *
   * @param a First point.
   * @param b Second point.
   * @return The sum of the squared component differences.
   */
  @VisibleForTesting
  public static float sqDistance(float[] a, float[] b) {
    float dist = 0;
    final int length = a.length;
    for (int i = 0; i < length; i++) {
      final float delta = a[i] - b[i];
      dist += delta * delta;
    }
    return dist;
  }

  /** Definition of a mean, contains a centroid and points on its cluster. */
  public static class Mean {
    /** Current centroid; replaced with a new array whenever the mean is moved. */
    float[] centroid;

    /** Points assigned to this mean during the most recent iteration. */
    final ArrayList<float[]> closestItems = new ArrayList<>();

    /**
     * Creates a mean with an all-zero centroid.
     *
     * @param dimension Number of components of the centroid.
     */
    public Mean(int dimension) {
      centroid = new float[dimension];
    }

    /**
     * Creates a mean located at the given centroid.
     *
     * @param centroid Centroid components; the array is stored as-is, not copied.
     */
    public Mean(float... centroid) {
      this.centroid = centroid;
    }

    /** Returns the centroid array held by this mean (not a copy). */
    public float[] getCentroid() {
      return centroid;
    }

    /** Returns the list of points assigned to this mean (the internal list, not a copy). */
    public List<float[]> getItems() {
      return closestItems;
    }

    @Override
    public String toString() {
      return "Mean(centroid: " + Arrays.toString(centroid) + ", size: " + closestItems.size() + ")";
    }
  }
}