K-Means color clustering
Test: runtest -x tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java
Bug: 37014702
Change-Id: Idabc163df5ded362acbe462ae6b351394a36db10
diff --git a/core/java/android/app/WallpaperColors.java b/core/java/android/app/WallpaperColors.java
index 23e9ca5..8f172ba 100644
--- a/core/java/android/app/WallpaperColors.java
+++ b/core/java/android/app/WallpaperColors.java
@@ -27,6 +27,7 @@
import android.util.Size;
import com.android.internal.graphics.palette.Palette;
+import com.android.internal.graphics.palette.VariationalKMeansQuantizer;
import java.util.ArrayList;
import java.util.Collections;
@@ -142,6 +143,8 @@
final Palette palette = Palette
.from(bitmap)
+ .setQuantizer(new VariationalKMeansQuantizer())
+ .maximumColorCount(5)
.clearFilters()
.resizeBitmapArea(MAX_WALLPAPER_EXTRACTION_AREA)
.generate();
diff --git a/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java b/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java
index 56d60a1..9ac753b 100644
--- a/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java
+++ b/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java
@@ -61,7 +61,7 @@
* This means that the color space is divided into distinct colors, rather than representative
* colors.
*/
-final class ColorCutQuantizer {
+final class ColorCutQuantizer implements Quantizer {
private static final String LOG_TAG = "ColorCutQuantizer";
private static final boolean LOG_TIMINGS = false;
@@ -73,22 +73,22 @@
private static final int QUANTIZE_WORD_WIDTH = 5;
private static final int QUANTIZE_WORD_MASK = (1 << QUANTIZE_WORD_WIDTH) - 1;
- final int[] mColors;
- final int[] mHistogram;
- final List<Swatch> mQuantizedColors;
- final TimingLogger mTimingLogger;
- final Palette.Filter[] mFilters;
+ int[] mColors;
+ int[] mHistogram;
+ List<Swatch> mQuantizedColors;
+ TimingLogger mTimingLogger;
+ Palette.Filter[] mFilters;
private final float[] mTempHsl = new float[3];
/**
- * Constructor.
+ * Execute color quantization.
*
* @param pixels histogram representing an image's pixel data
* @param maxColors The maximum number of colors that should be in the result palette.
* @param filters Set of filters to use in the quantization stage
*/
- ColorCutQuantizer(final int[] pixels, final int maxColors, final Palette.Filter[] filters) {
+ public void quantize(final int[] pixels, final int maxColors, final Palette.Filter[] filters) {
mTimingLogger = LOG_TIMINGS ? new TimingLogger(LOG_TAG, "Creation") : null;
mFilters = filters;
@@ -160,7 +160,7 @@
/**
* @return the list of quantized colors
*/
- List<Swatch> getQuantizedColors() {
+ public List<Swatch> getQuantizedColors() {
return mQuantizedColors;
}
diff --git a/core/java/com/android/internal/graphics/palette/Palette.java b/core/java/com/android/internal/graphics/palette/Palette.java
index 9f1504a..a4f9a59 100644
--- a/core/java/com/android/internal/graphics/palette/Palette.java
+++ b/core/java/com/android/internal/graphics/palette/Palette.java
@@ -613,6 +613,8 @@
private final List<Palette.Filter> mFilters = new ArrayList<>();
private Rect mRegion;
+ private Quantizer mQuantizer;
+
/**
* Construct a new {@link Palette.Builder} using a source {@link Bitmap}
*/
@@ -726,6 +728,18 @@
}
/**
+ * Set a specific quantization algorithm. {@link ColorCutQuantizer} will
+ * be used if unspecified.
+ *
+ * @param quantizer Quantizer implementation.
+ */
+ @NonNull
+ public Palette.Builder setQuantizer(Quantizer quantizer) {
+ mQuantizer = quantizer;
+ return this;
+ }
+
+ /**
* Set a region of the bitmap to be used exclusively when calculating the palette.
* <p>This only works when the original input is a {@link Bitmap}.</p>
*
@@ -818,17 +832,19 @@
}
// Now generate a quantizer from the Bitmap
- final ColorCutQuantizer quantizer = new ColorCutQuantizer(
- getPixelsFromBitmap(bitmap),
- mMaxColors,
- mFilters.isEmpty() ? null : mFilters.toArray(new Palette.Filter[mFilters.size()]));
+ if (mQuantizer == null) {
+ mQuantizer = new ColorCutQuantizer();
+ }
+ mQuantizer.quantize(getPixelsFromBitmap(bitmap),
+ mMaxColors, mFilters.isEmpty() ? null :
+ mFilters.toArray(new Palette.Filter[mFilters.size()]));
// If created a new bitmap, recycle it
if (bitmap != mBitmap) {
bitmap.recycle();
}
- swatches = quantizer.getQuantizedColors();
+ swatches = mQuantizer.getQuantizedColors();
if (logger != null) {
logger.addSplit("Color quantization completed");
diff --git a/core/java/com/android/internal/graphics/palette/Quantizer.java b/core/java/com/android/internal/graphics/palette/Quantizer.java
new file mode 100644
index 0000000..db60f2e
--- /dev/null
+++ b/core/java/com/android/internal/graphics/palette/Quantizer.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.internal.graphics.palette;
+
+import java.util.List;
+
+/**
+ * Definition of an algorithm that receives pixels and outputs a list of colors.
+ */
+public interface Quantizer {
+ void quantize(final int[] pixels, final int maxColors, final Palette.Filter[] filters);
+ List<Palette.Swatch> getQuantizedColors();
+}
diff --git a/core/java/com/android/internal/graphics/palette/VariationalKMeansQuantizer.java b/core/java/com/android/internal/graphics/palette/VariationalKMeansQuantizer.java
new file mode 100644
index 0000000..b035535
--- /dev/null
+++ b/core/java/com/android/internal/graphics/palette/VariationalKMeansQuantizer.java
@@ -0,0 +1,154 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.internal.graphics.palette;
+
+import android.util.Log;
+
+import com.android.internal.graphics.ColorUtils;
+import com.android.internal.ml.clustering.KMeans;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * A quantizer that uses k-means
+ */
+public class VariationalKMeansQuantizer implements Quantizer {
+
+ private static final String TAG = "KMeansQuantizer";
+ private static final boolean DEBUG = false;
+
+ /**
+ * Clusters closer than this value will me merged.
+ */
+ private final float mMinClusterSqDistance;
+
+ /**
+ * K-means can get stuck in local optima, this can be avoided by
+ * repeating it and getting the "best" execution.
+ */
+ private final int mInitializations;
+
+ /**
+ * Initialize KMeans with a fixed random state to have
+ * consistent results across multiple runs.
+ */
+ private final KMeans mKMeans = new KMeans(new Random(0), 30, 0);
+
+ private List<Palette.Swatch> mQuantizedColors;
+
+ public VariationalKMeansQuantizer() {
+ this(0.25f /* cluster distance */);
+ }
+
+ public VariationalKMeansQuantizer(float minClusterDistance) {
+ this(minClusterDistance, 1 /* initializations */);
+ }
+
+ public VariationalKMeansQuantizer(float minClusterDistance, int initializations) {
+ mMinClusterSqDistance = minClusterDistance * minClusterDistance;
+ mInitializations = initializations;
+ }
+
+ /**
+ * K-Means quantizer.
+ *
+ * @param pixels Pixels to quantize.
+ * @param maxColors Maximum number of clusters to extract.
+ * @param filters Colors that should be ignored
+ */
+ @Override
+ public void quantize(int[] pixels, int maxColors, Palette.Filter[] filters) {
+ // Start by converting all colors to HSL.
+ // HLS is way more meaningful for clustering than RGB.
+ final float[] hsl = {0, 0, 0};
+ final float[][] hslPixels = new float[pixels.length][3];
+ for (int i = 0; i < pixels.length; i++) {
+ ColorUtils.colorToHSL(pixels[i], hsl);
+ // Normalize hue so all values go from 0 to 1.
+ hslPixels[i][0] = hsl[0] / 360f;
+ hslPixels[i][1] = hsl[1];
+ hslPixels[i][2] = hsl[2];
+ }
+
+ final List<KMeans.Mean> optimalMeans = getOptimalKMeans(maxColors, hslPixels);
+
+ // Ideally we should run k-means again to merge clusters but it would be too expensive,
+ // instead we just merge all clusters that are closer than a threshold.
+ for (int i = 0; i < optimalMeans.size(); i++) {
+ KMeans.Mean current = optimalMeans.get(i);
+ float[] currentCentroid = current.getCentroid();
+ for (int j = i + 1; j < optimalMeans.size(); j++) {
+ KMeans.Mean compareTo = optimalMeans.get(j);
+ float[] compareToCentroid = compareTo.getCentroid();
+ float sqDistance = KMeans.sqDistance(currentCentroid, compareToCentroid);
+ // Merge them
+ if (sqDistance < mMinClusterSqDistance) {
+ optimalMeans.remove(compareTo);
+ current.getItems().addAll(compareTo.getItems());
+ for (int k = 0; k < currentCentroid.length; k++) {
+ currentCentroid[k] += (compareToCentroid[k] - currentCentroid[k]) / 2.0;
+ }
+ j--;
+ }
+ }
+ }
+
+ // Convert data to final format, de-normalizing the hue.
+ mQuantizedColors = new ArrayList<>();
+ for (KMeans.Mean mean : optimalMeans) {
+ if (mean.getItems().size() == 0) {
+ continue;
+ }
+ float[] centroid = mean.getCentroid();
+ mQuantizedColors.add(new Palette.Swatch(new float[]{
+ centroid[0] * 360f,
+ centroid[1],
+ centroid[2]
+ }, mean.getItems().size()));
+ }
+ }
+
+ private List<KMeans.Mean> getOptimalKMeans(int k, float[][] inputData) {
+ List<KMeans.Mean> optimal = null;
+ double optimalScore = -Double.MAX_VALUE;
+ int runs = mInitializations;
+ while (runs > 0) {
+ if (DEBUG) {
+ Log.d(TAG, "k-means run: " + runs);
+ }
+ List<KMeans.Mean> means = mKMeans.predict(k, inputData);
+ double score = KMeans.score(means);
+ if (optimal == null || score > optimalScore) {
+ if (DEBUG) {
+ Log.d(TAG, "\tnew optimal score: " + score);
+ }
+ optimalScore = score;
+ optimal = means;
+ }
+ runs--;
+ }
+
+ return optimal;
+ }
+
+ @Override
+ public List<Palette.Swatch> getQuantizedColors() {
+ return mQuantizedColors;
+ }
+}
diff --git a/core/java/com/android/internal/ml/clustering/KMeans.java b/core/java/com/android/internal/ml/clustering/KMeans.java
new file mode 100644
index 0000000..4d5b333
--- /dev/null
+++ b/core/java/com/android/internal/ml/clustering/KMeans.java
@@ -0,0 +1,243 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.internal.ml.clustering;
+
+import android.annotation.NonNull;
+import android.util.Log;
+
+import com.android.internal.annotations.VisibleForTesting;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Simple K-Means implementation
+ */
+public class KMeans {
+
+ private static final boolean DEBUG = false;
+ private static final String TAG = "KMeans";
+ private final Random mRandomState;
+ private final int mMaxIterations;
+ private float mSqConvergenceEpsilon;
+
+ public KMeans() {
+ this(new Random());
+ }
+
+ public KMeans(Random random) {
+ this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */);
+ }
+ public KMeans(Random random, int maxIterations, float convergenceEpsilon) {
+ mRandomState = random;
+ mMaxIterations = maxIterations;
+ mSqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon;
+ }
+
+ /**
+ * Runs k-means on the input data (X) trying to find k means.
+ *
+ * K-Means is known for getting stuck into local optima, so you might
+ * want to run it multiple time and argmax on {@link KMeans#score(List)}
+ *
+ * @param k The number of points to return.
+ * @param inputData Input data.
+ * @return An array of k Means, each representing a centroid and data points that belong to it.
+ */
+ public List<Mean> predict(final int k, final float[][] inputData) {
+ checkDataSetSanity(inputData);
+ int dimension = inputData[0].length;
+
+ final ArrayList<Mean> means = new ArrayList<>();
+ for (int i = 0; i < k; i++) {
+ Mean m = new Mean(dimension);
+ for (int j = 0; j < dimension; j++) {
+ m.mCentroid[j] = mRandomState.nextFloat();
+ }
+ means.add(m);
+ }
+
+ // Iterate until we converge or run out of iterations
+ boolean converged = false;
+ for (int i = 0; i < mMaxIterations; i++) {
+ converged = step(means, inputData);
+ if (converged) {
+ if (DEBUG) Log.d(TAG, "Converged at iteration: " + i);
+ break;
+ }
+ }
+ if (!converged && DEBUG) Log.d(TAG, "Did not converge");
+
+ return means;
+ }
+
+ /**
+ * Score calculates the inertia between means.
+ * This can be considered as an E step of an EM algorithm.
+ *
+ * @param means Means to use when calculating score.
+ * @return The score
+ */
+ public static double score(@NonNull List<Mean> means) {
+ double score = 0;
+ final int meansSize = means.size();
+ for (int i = 0; i < meansSize; i++) {
+ Mean mean = means.get(i);
+ for (int j = 0; j < meansSize; j++) {
+ Mean compareTo = means.get(j);
+ if (mean == compareTo) {
+ continue;
+ }
+ double distance = Math.sqrt(sqDistance(mean.mCentroid, compareTo.mCentroid));
+ score += distance;
+ }
+ }
+ return score;
+ }
+
+ @VisibleForTesting
+ public void checkDataSetSanity(float[][] inputData) {
+ if (inputData == null) {
+ throw new IllegalArgumentException("Data set is null.");
+ } else if (inputData.length == 0) {
+ throw new IllegalArgumentException("Data set is empty.");
+ } else if (inputData[0] == null) {
+ throw new IllegalArgumentException("Bad data set format.");
+ }
+
+ final int dimension = inputData[0].length;
+ final int length = inputData.length;
+ for (int i = 1; i < length; i++) {
+ if (inputData[i] == null || inputData[i].length != dimension) {
+ throw new IllegalArgumentException("Bad data set format.");
+ }
+ }
+ }
+
+ /**
+ * K-Means iteration.
+ *
+ * @param means Current means
+ * @param inputData Input data
+ * @return True if data set converged
+ */
+ private boolean step(final ArrayList<Mean> means, final float[][] inputData) {
+
+ // Clean up the previous state because we need to compute
+ // which point belongs to each mean again.
+ for (int i = means.size() - 1; i >= 0; i--) {
+ final Mean mean = means.get(i);
+ mean.mClosestItems.clear();
+ }
+ for (int i = inputData.length - 1; i >= 0; i--) {
+ final float[] current = inputData[i];
+ final Mean nearest = nearestMean(current, means);
+ nearest.mClosestItems.add(current);
+ }
+
+ boolean converged = true;
+ // Move each mean towards the nearest data set points
+ for (int i = means.size() - 1; i >= 0; i--) {
+ final Mean mean = means.get(i);
+ if (mean.mClosestItems.size() == 0) {
+ continue;
+ }
+
+ // Compute the new mean centroid:
+ // 1. Sum all all points
+ // 2. Average them
+ final float[] oldCentroid = mean.mCentroid;
+ mean.mCentroid = new float[oldCentroid.length];
+ for (int j = 0; j < mean.mClosestItems.size(); j++) {
+ // Update each centroid component
+ for (int p = 0; p < mean.mCentroid.length; p++) {
+ mean.mCentroid[p] += mean.mClosestItems.get(j)[p];
+ }
+ }
+ for (int j = 0; j < mean.mCentroid.length; j++) {
+ mean.mCentroid[j] /= mean.mClosestItems.size();
+ }
+
+ // We converged if the centroid didn't move for any of the means.
+ if (sqDistance(oldCentroid, mean.mCentroid) > mSqConvergenceEpsilon) {
+ converged = false;
+ }
+ }
+ return converged;
+ }
+
+ @VisibleForTesting
+ public static Mean nearestMean(float[] point, List<Mean> means) {
+ Mean nearest = null;
+ float nearestDistance = Float.MAX_VALUE;
+
+ final int meanCount = means.size();
+ for (int i = 0; i < meanCount; i++) {
+ Mean next = means.get(i);
+ // We don't need the sqrt when comparing distances in euclidean space
+ // because they exist on both sides of the equation and cancel each other out.
+ float nextDistance = sqDistance(point, next.mCentroid);
+ if (nextDistance < nearestDistance) {
+ nearest = next;
+ nearestDistance = nextDistance;
+ }
+ }
+ return nearest;
+ }
+
+ @VisibleForTesting
+ public static float sqDistance(float[] a, float[] b) {
+ float dist = 0;
+ final int length = a.length;
+ for (int i = 0; i < length; i++) {
+ dist += (a[i] - b[i]) * (a[i] - b[i]);
+ }
+ return dist;
+ }
+
+ /**
+ * Definition of a mean, contains a centroid and points on its cluster.
+ */
+ public static class Mean {
+ float[] mCentroid;
+ final ArrayList<float[]> mClosestItems = new ArrayList<>();
+
+ public Mean(int dimension) {
+ mCentroid = new float[dimension];
+ }
+
+ public Mean(float ...centroid) {
+ mCentroid = centroid;
+ }
+
+ public float[] getCentroid() {
+ return mCentroid;
+ }
+
+ public List<float[]> getItems() {
+ return mClosestItems;
+ }
+
+ @Override
+ public String toString() {
+ return "Mean(centroid: " + Arrays.toString(mCentroid) + ", size: "
+ + mClosestItems.size() + ")";
+ }
+ }
+}
diff --git a/tests/Internal/Android.mk b/tests/Internal/Android.mk
new file mode 100644
index 0000000..f59a624
--- /dev/null
+++ b/tests/Internal/Android.mk
@@ -0,0 +1,20 @@
+LOCAL_PATH:= $(call my-dir)
+include $(CLEAR_VARS)
+
+LOCAL_USE_AAPT2 := true
+LOCAL_MODULE_TAGS := tests
+
+LOCAL_PROTOC_OPTIMIZE_TYPE := nano
+
+# Include some source files directly to be able to access package members
+LOCAL_SRC_FILES := $(call all-java-files-under, src)
+
+LOCAL_JAVA_LIBRARIES := android.test.runner
+LOCAL_STATIC_JAVA_LIBRARIES := junit legacy-android-test android-support-test
+
+LOCAL_CERTIFICATE := platform
+
+LOCAL_PACKAGE_NAME := InternalTests
+LOCAL_COMPATIBILITY_SUITE := device-tests
+
+include $(BUILD_PACKAGE)
diff --git a/tests/Internal/AndroidManifest.xml b/tests/Internal/AndroidManifest.xml
new file mode 100644
index 0000000..a2c95fb
--- /dev/null
+++ b/tests/Internal/AndroidManifest.xml
@@ -0,0 +1,28 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ ~ Copyright (C) 2017 The Android Open Source Project
+ ~
+ ~ Licensed under the Apache License, Version 2.0 (the "License");
+ ~ you may not use this file except in compliance with the License.
+ ~ You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License
+ -->
+
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.android.internal.tests">
+ <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
+ <application>
+ <uses-library android:name="android.test.runner" />
+ </application>
+
+ <instrumentation android:name="android.support.test.runner.AndroidJUnitRunner"
+ android:targetPackage="com.android.internal.tests"
+ android:label="Internal Tests" />
+</manifest>
diff --git a/tests/Internal/AndroidTest.xml b/tests/Internal/AndroidTest.xml
new file mode 100644
index 0000000..6531c93
--- /dev/null
+++ b/tests/Internal/AndroidTest.xml
@@ -0,0 +1,29 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ ~ Copyright (C) 2017 The Android Open Source Project
+ ~
+ ~ Licensed under the Apache License, Version 2.0 (the "License");
+ ~ you may not use this file except in compliance with the License.
+ ~ You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License
+ -->
+<configuration description="Runs tests for internal classes/utilities.">
+ <target_preparer class="com.android.tradefed.targetprep.TestAppInstallSetup">
+ <option name="test-file-name" value="InternalTests.apk" />
+ </target_preparer>
+
+ <option name="test-suite-tag" value="apct" />
+ <option name="test-suite-tag" value="framework-base-presubmit" />
+ <option name="test-tag" value="InternalTests" />
+ <test class="com.android.tradefed.testtype.AndroidJUnitTest" >
+ <option name="package" value="com.android.internal.tests" />
+ <option name="runner" value="android.support.test.runner.AndroidJUnitRunner" />
+ </test>
+</configuration>
\ No newline at end of file
diff --git a/tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java b/tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java
new file mode 100644
index 0000000..a64f8a6
--- /dev/null
+++ b/tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java
@@ -0,0 +1,155 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.internal.ml.clustering;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import android.annotation.SuppressLint;
+import android.support.test.filters.SmallTest;
+import android.support.test.runner.AndroidJUnit4;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class KMeansTest {
+
+ // Error tolerance (epsilon)
+ private static final double EPS = 0.01;
+
+ private KMeans mKMeans;
+
+ @Before
+ public void setUp() {
+ // Setup with a random seed to have predictable results
+ mKMeans = new KMeans(new Random(0), 30, 0);
+ }
+
+ @Test
+ public void getCheckDataSanityTest() {
+ try {
+ mKMeans.checkDataSetSanity(new float[][] {
+ {0, 1, 2},
+ {1, 2, 3}
+ });
+ } catch (IllegalArgumentException e) {
+ Assert.fail("Valid data didn't pass sanity check");
+ }
+
+ try {
+ mKMeans.checkDataSetSanity(new float[][] {
+ null,
+ {1, 2, 3}
+ });
+ Assert.fail("Data has null items and passed");
+ } catch (IllegalArgumentException e) {}
+
+ try {
+ mKMeans.checkDataSetSanity(new float[][] {
+ {0, 1, 2, 4},
+ {1, 2, 3}
+ });
+ Assert.fail("Data has invalid shape and passed");
+ } catch (IllegalArgumentException e) {}
+
+ try {
+ mKMeans.checkDataSetSanity(null);
+ Assert.fail("Null data should throw exception");
+ } catch (IllegalArgumentException e) {}
+ }
+
+ @Test
+ public void sqDistanceTest() {
+ float a[] = {4, 10};
+ float b[] = {5, 2};
+ float sqDist = (float) (Math.pow(a[0] - b[0], 2) + Math.pow(a[1] - b[1], 2));
+
+ assertEquals("Squared distance not valid", mKMeans.sqDistance(a, b), sqDist, EPS);
+ }
+
+ @Test
+ public void nearestMeanTest() {
+ KMeans.Mean meanA = new KMeans.Mean(0, 1);
+ KMeans.Mean meanB = new KMeans.Mean(1, 1);
+ List<KMeans.Mean> means = Arrays.asList(meanA, meanB);
+
+ KMeans.Mean nearest = mKMeans.nearestMean(new float[] {1, 1}, means);
+
+ assertEquals("Unexpected nearest mean for point {1, 1}", nearest, meanB);
+ }
+
+ @SuppressLint("DefaultLocale")
+ @Test
+ public void scoreTest() {
+ List<KMeans.Mean> closeMeans = Arrays.asList(new KMeans.Mean(0, 0.1f, 0.1f),
+ new KMeans.Mean(0, 0.1f, 0.15f),
+ new KMeans.Mean(0.1f, 0.2f, 0.1f));
+ List<KMeans.Mean> farMeans = Arrays.asList(new KMeans.Mean(0, 0, 0),
+ new KMeans.Mean(0, 0.5f, 0.5f),
+ new KMeans.Mean(1, 0.9f, 0.9f));
+
+ double closeScore = KMeans.score(closeMeans);
+ double farScore = KMeans.score(farMeans);
+ assertTrue(String.format("Score of well distributed means should be greater than "
+ + "close means but got: %f, %f", farScore, closeScore), farScore > closeScore);
+ }
+
+ @Test
+ public void predictTest() {
+ float[] expectedCentroid1 = {1, 1, 1};
+ float[] expectedCentroid2 = {0, 0, 0};
+ float[][] X = new float[][] {
+ {1, 1, 1},
+ {1, 1, 1},
+ {1, 1, 1},
+ {0, 0, 0},
+ {0, 0, 0},
+ {0, 0, 0},
+ };
+
+ final int numClusters = 2;
+
+ // Here we assume that we won't get stuck into a local optima.
+ // It's fine because we're seeding a random, we won't ever have
+ // unstable results but in real life we need multiple initialization
+ // and score comparison
+ List<KMeans.Mean> means = mKMeans.predict(numClusters, X);
+
+ assertEquals("Expected number of clusters is invalid", numClusters, means.size());
+
+ boolean exists1 = false, exists2 = false;
+ for (KMeans.Mean mean : means) {
+ if (Arrays.equals(mean.getCentroid(), expectedCentroid1)) {
+ exists1 = true;
+ } else if (Arrays.equals(mean.getCentroid(), expectedCentroid2)) {
+ exists2 = true;
+ } else {
+ throw new AssertionError("Unexpected mean: " + mean);
+ }
+ }
+ assertTrue("Expected means were not predicted, got: " + means,
+ exists1 && exists2);
+ }
+}