K-Means color clustering
Test: runtest -x tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java
Bug: 37014702
Change-Id: Idabc163df5ded362acbe462ae6b351394a36db10
diff --git a/core/java/android/app/WallpaperColors.java b/core/java/android/app/WallpaperColors.java
index 23e9ca5..8f172ba 100644
--- a/core/java/android/app/WallpaperColors.java
+++ b/core/java/android/app/WallpaperColors.java
@@ -27,6 +27,7 @@
import android.util.Size;
import com.android.internal.graphics.palette.Palette;
+import com.android.internal.graphics.palette.VariationalKMeansQuantizer;
import java.util.ArrayList;
import java.util.Collections;
@@ -142,6 +143,8 @@
final Palette palette = Palette
.from(bitmap)
+ .setQuantizer(new VariationalKMeansQuantizer())
+ .maximumColorCount(5)
.clearFilters()
.resizeBitmapArea(MAX_WALLPAPER_EXTRACTION_AREA)
.generate();
diff --git a/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java b/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java
index 56d60a1..9ac753b 100644
--- a/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java
+++ b/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java
@@ -61,7 +61,7 @@
* This means that the color space is divided into distinct colors, rather than representative
* colors.
*/
-final class ColorCutQuantizer {
+final class ColorCutQuantizer implements Quantizer {
private static final String LOG_TAG = "ColorCutQuantizer";
private static final boolean LOG_TIMINGS = false;
@@ -73,22 +73,22 @@
private static final int QUANTIZE_WORD_WIDTH = 5;
private static final int QUANTIZE_WORD_MASK = (1 << QUANTIZE_WORD_WIDTH) - 1;
- final int[] mColors;
- final int[] mHistogram;
- final List<Swatch> mQuantizedColors;
- final TimingLogger mTimingLogger;
- final Palette.Filter[] mFilters;
+ int[] mColors;
+ int[] mHistogram;
+ List<Swatch> mQuantizedColors;
+ TimingLogger mTimingLogger;
+ Palette.Filter[] mFilters;
private final float[] mTempHsl = new float[3];
/**
- * Constructor.
+ * Execute color quantization.
*
* @param pixels histogram representing an image's pixel data
* @param maxColors The maximum number of colors that should be in the result palette.
* @param filters Set of filters to use in the quantization stage
*/
- ColorCutQuantizer(final int[] pixels, final int maxColors, final Palette.Filter[] filters) {
+ public void quantize(final int[] pixels, final int maxColors, final Palette.Filter[] filters) {
mTimingLogger = LOG_TIMINGS ? new TimingLogger(LOG_TAG, "Creation") : null;
mFilters = filters;
@@ -160,7 +160,7 @@
/**
* @return the list of quantized colors
*/
- List<Swatch> getQuantizedColors() {
+ public List<Swatch> getQuantizedColors() {
return mQuantizedColors;
}
diff --git a/core/java/com/android/internal/graphics/palette/Palette.java b/core/java/com/android/internal/graphics/palette/Palette.java
index 9f1504a..a4f9a59 100644
--- a/core/java/com/android/internal/graphics/palette/Palette.java
+++ b/core/java/com/android/internal/graphics/palette/Palette.java
@@ -613,6 +613,8 @@
private final List<Palette.Filter> mFilters = new ArrayList<>();
private Rect mRegion;
+ private Quantizer mQuantizer;
+
/**
* Construct a new {@link Palette.Builder} using a source {@link Bitmap}
*/
@@ -726,6 +728,18 @@
}
/**
+ * Set a specific quantization algorithm. {@link ColorCutQuantizer} will
+ * be used if unspecified or if {@code quantizer} is {@code null}.
+ *
+ * @param quantizer Quantizer implementation, or {@code null} for the default.
+ */
+ @NonNull
+ public Palette.Builder setQuantizer(Quantizer quantizer) {
+ mQuantizer = quantizer;
+ return this;
+ }
+
+ /**
* Set a region of the bitmap to be used exclusively when calculating the palette.
* <p>This only works when the original input is a {@link Bitmap}.</p>
*
@@ -818,17 +832,19 @@
}
// Now generate a quantizer from the Bitmap
- final ColorCutQuantizer quantizer = new ColorCutQuantizer(
- getPixelsFromBitmap(bitmap),
- mMaxColors,
- mFilters.isEmpty() ? null : mFilters.toArray(new Palette.Filter[mFilters.size()]));
+ if (mQuantizer == null) {
+ mQuantizer = new ColorCutQuantizer();
+ }
+ mQuantizer.quantize(getPixelsFromBitmap(bitmap),
+ mMaxColors, mFilters.isEmpty() ? null :
+ mFilters.toArray(new Palette.Filter[mFilters.size()]));
// If created a new bitmap, recycle it
if (bitmap != mBitmap) {
bitmap.recycle();
}
- swatches = quantizer.getQuantizedColors();
+ swatches = mQuantizer.getQuantizedColors();
if (logger != null) {
logger.addSplit("Color quantization completed");
diff --git a/core/java/com/android/internal/graphics/palette/Quantizer.java b/core/java/com/android/internal/graphics/palette/Quantizer.java
new file mode 100644
index 0000000..db60f2e
--- /dev/null
+++ b/core/java/com/android/internal/graphics/palette/Quantizer.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.internal.graphics.palette;
+
+import java.util.List;
+
+/**
+ * Definition of an algorithm that receives pixels and outputs a list of colors.
+ */
+public interface Quantizer {
+
+    /**
+     * Quantizes {@code pixels} into at most {@code maxColors} colors. The result is read
+     * back with {@link #getQuantizedColors()}.
+     *
+     * @param pixels Pixels to quantize, as color ints.
+     * @param maxColors Maximum number of colors to produce.
+     * @param filters Filters used to exclude colors, or {@code null} when there are none.
+     */
+    void quantize(int[] pixels, int maxColors, Palette.Filter[] filters);
+
+    /**
+     * @return The colors produced by the most recent {@link #quantize} call.
+     */
+    List<Palette.Swatch> getQuantizedColors();
+}
diff --git a/core/java/com/android/internal/graphics/palette/VariationalKMeansQuantizer.java b/core/java/com/android/internal/graphics/palette/VariationalKMeansQuantizer.java
new file mode 100644
index 0000000..b035535
--- /dev/null
+++ b/core/java/com/android/internal/graphics/palette/VariationalKMeansQuantizer.java
@@ -0,0 +1,229 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.internal.graphics.palette;
+
+import android.util.Log;
+
+import com.android.internal.graphics.ColorUtils;
+import com.android.internal.ml.clustering.KMeans;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * A {@link Quantizer} that clusters pixels in HSL space with k-means and reports one
+ * {@link Palette.Swatch} per non-empty cluster.
+ * <p>
+ * Clusters whose centroids end up closer than a configurable distance are merged after
+ * clustering, so fewer than {@code maxColors} swatches may be produced.
+ */
+public class VariationalKMeansQuantizer implements Quantizer {
+
+    private static final String TAG = "KMeansQuantizer";
+    private static final boolean DEBUG = false;
+
+    /** Fixed seed so that identical input always produces identical output. */
+    private static final long RANDOM_SEED = 0;
+    /** Maximum number of k-means iterations per initialization. */
+    private static final int MAX_ITERATIONS = 30;
+    /** Convergence threshold; 0 means iterating until no centroid moves at all. */
+    private static final float CONVERGENCE_EPSILON = 0;
+
+    /**
+     * Clusters whose centroids are closer than this squared distance (in normalized HSL
+     * space) will be merged.
+     */
+    private final float mMinClusterSqDistance;
+
+    /**
+     * K-means can get stuck in local optima; this is mitigated by running it several
+     * times and keeping the run with the highest {@link KMeans#score(List)}.
+     */
+    private final int mInitializations;
+
+    private List<Palette.Swatch> mQuantizedColors;
+
+    /** Creates a quantizer with a cluster merge distance of 0.25 and one initialization. */
+    public VariationalKMeansQuantizer() {
+        this(0.25f /* cluster distance */);
+    }
+
+    /**
+     * Creates a quantizer that runs k-means once.
+     *
+     * @param minClusterDistance Clusters closer than this distance are merged.
+     */
+    public VariationalKMeansQuantizer(float minClusterDistance) {
+        this(minClusterDistance, 1 /* initializations */);
+    }
+
+    /**
+     * @param minClusterDistance Clusters closer than this distance are merged.
+     * @param initializations Number of k-means runs; the best-scoring run is kept. Values
+     *                        below 1 are treated as 1.
+     */
+    public VariationalKMeansQuantizer(float minClusterDistance, int initializations) {
+        mMinClusterSqDistance = minClusterDistance * minClusterDistance;
+        mInitializations = initializations;
+    }
+
+    /**
+     * Clusters {@code pixels} with k-means and stores one swatch per resulting cluster,
+     * retrievable through {@link #getQuantizedColors()}.
+     *
+     * @param pixels Pixels to quantize, as color ints.
+     * @param maxColors Maximum number of clusters to extract. Values below 1 produce no
+     *                  swatches.
+     * @param filters Colors rejected by any of these filters are ignored; {@code null}
+     *                means every pixel is used.
+     */
+    @Override
+    public void quantize(int[] pixels, int maxColors, Palette.Filter[] filters) {
+        final float[][] hslPixels = toNormalizedHsl(pixels, filters);
+        if (hslPixels.length == 0 || maxColors <= 0) {
+            // KMeans rejects empty data sets and cannot cluster into zero means.
+            mQuantizedColors = new ArrayList<>();
+            return;
+        }
+
+        // A freshly seeded KMeans per call keeps results independent of earlier calls.
+        final KMeans kMeans = new KMeans(new Random(RANDOM_SEED), MAX_ITERATIONS,
+                CONVERGENCE_EPSILON);
+        final List<KMeans.Mean> means = getOptimalKMeans(kMeans, maxColors, hslPixels);
+        removeEmptyMeans(means);
+        mergeCloseMeans(means);
+
+        // Convert data to final format, de-normalizing the hue.
+        mQuantizedColors = new ArrayList<>(means.size());
+        for (KMeans.Mean mean : means) {
+            final float[] centroid = mean.getCentroid();
+            mQuantizedColors.add(new Palette.Swatch(new float[]{
+                    centroid[0] * 360f,
+                    centroid[1],
+                    centroid[2]
+            }, mean.getItems().size()));
+        }
+    }
+
+    /**
+     * Converts the pixels accepted by {@code filters} to HSL, dividing the hue by 360 so
+     * that all three components go from 0 to 1.
+     * <p>
+     * NOTE: hue is circular, but k-means uses plain Euclidean distance, so colors on either
+     * side of hue 0 (reds) are treated as far apart.
+     */
+    private static float[][] toNormalizedHsl(int[] pixels, Palette.Filter[] filters) {
+        final List<float[]> result = new ArrayList<>(pixels.length);
+        for (int pixel : pixels) {
+            final float[] hsl = new float[3];
+            ColorUtils.colorToHSL(pixel, hsl);
+            // Filters are handed standard HSL values (hue in degrees), so check first.
+            if (isAllowed(pixel, hsl, filters)) {
+                hsl[0] /= 360f;
+                result.add(hsl);
+            }
+        }
+        return result.toArray(new float[result.size()][]);
+    }
+
+    /** Returns whether every filter accepts the color; a {@code null} array accepts all. */
+    private static boolean isAllowed(int rgb, float[] hsl, Palette.Filter[] filters) {
+        if (filters != null) {
+            for (Palette.Filter filter : filters) {
+                if (!filter.isAllowed(rgb, hsl)) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    /** Drops clusters that attracted no pixels; their centroids are not backed by data. */
+    private static void removeEmptyMeans(List<KMeans.Mean> means) {
+        for (int i = means.size() - 1; i >= 0; i--) {
+            if (means.get(i).getItems().isEmpty()) {
+                means.remove(i);
+            }
+        }
+    }
+
+    /**
+     * Greedily merges clusters whose centroids are closer than the configured distance.
+     * <p>
+     * Ideally k-means would be re-run to merge clusters, but that is too expensive. Instead
+     * each merged centroid is the population-weighted average of the two centroids, which
+     * is the centroid of the combined pixels.
+     */
+    private void mergeCloseMeans(List<KMeans.Mean> means) {
+        for (int i = 0; i < means.size(); i++) {
+            final KMeans.Mean current = means.get(i);
+            final float[] currentCentroid = current.getCentroid();
+            for (int j = i + 1; j < means.size(); j++) {
+                final KMeans.Mean other = means.get(j);
+                final float[] otherCentroid = other.getCentroid();
+                final float sqDistance = KMeans.sqDistance(currentCentroid, otherCentroid);
+                if (sqDistance < mMinClusterSqDistance) {
+                    final int currentSize = current.getItems().size();
+                    final int otherSize = other.getItems().size();
+                    final float totalSize = currentSize + otherSize;
+                    for (int k = 0; k < currentCentroid.length; k++) {
+                        currentCentroid[k] = (currentCentroid[k] * currentSize
+                                + otherCentroid[k] * otherSize) / totalSize;
+                    }
+                    current.getItems().addAll(other.getItems());
+                    means.remove(j);
+                    // Index j now holds the next unvisited mean; examine it next.
+                    j--;
+                }
+            }
+        }
+    }
+
+    /**
+     * Runs k-means {@code max(1, mInitializations)} times and returns the run with the
+     * highest {@link KMeans#score(List)}.
+     */
+    private List<KMeans.Mean> getOptimalKMeans(KMeans kMeans, int k, float[][] inputData) {
+        List<KMeans.Mean> optimal = null;
+        double optimalScore = -Double.MAX_VALUE;
+        for (int runs = Math.max(1, mInitializations); runs > 0; runs--) {
+            if (DEBUG) {
+                Log.d(TAG, "k-means run: " + runs);
+            }
+            final List<KMeans.Mean> means = kMeans.predict(k, inputData);
+            final double score = KMeans.score(means);
+            if (optimal == null || score > optimalScore) {
+                if (DEBUG) {
+                    Log.d(TAG, "\tnew optimal score: " + score);
+                }
+                optimalScore = score;
+                optimal = means;
+            }
+        }
+        return optimal;
+    }
+
+    /**
+     * @return The swatches produced by the most recent {@link #quantize} call, or
+     *         {@code null} if {@link #quantize} has not been called yet.
+     */
+    @Override
+    public List<Palette.Swatch> getQuantizedColors() {
+        return mQuantizedColors;
+    }
+}
diff --git a/core/java/com/android/internal/ml/clustering/KMeans.java b/core/java/com/android/internal/ml/clustering/KMeans.java
new file mode 100644
index 0000000..4d5b333
--- /dev/null
+++ b/core/java/com/android/internal/ml/clustering/KMeans.java
@@ -0,0 +1,292 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.internal.ml.clustering;
+
+import android.annotation.NonNull;
+import android.util.Log;
+
+import com.android.internal.annotations.VisibleForTesting;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Simple K-Means implementation.
+ * <p>
+ * Means are initialized with uniformly random coordinates in [0, 1), so clustering works
+ * best when the input data lies in that same range.
+ */
+public class KMeans {
+
+    private static final boolean DEBUG = false;
+    private static final String TAG = "KMeans";
+    private final Random mRandomState;
+    private final int mMaxIterations;
+    private final float mSqConvergenceEpsilon;
+
+    /** Creates an instance with an unseeded random source and default parameters. */
+    public KMeans() {
+        this(new Random());
+    }
+
+    /**
+     * Creates an instance with at most 30 iterations and a convergence epsilon of 0.005.
+     *
+     * @param random Source of the initial mean coordinates.
+     */
+    public KMeans(Random random) {
+        this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */);
+    }
+
+    /**
+     * @param random Source of the initial mean coordinates.
+     * @param maxIterations Maximum number of k-means steps per {@link #predict} call.
+     * @param convergenceEpsilon Iteration stops early once no centroid moves farther than
+     *                           this (Euclidean distance) during a step.
+     */
+    public KMeans(Random random, int maxIterations, float convergenceEpsilon) {
+        mRandomState = random;
+        mMaxIterations = maxIterations;
+        mSqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon;
+    }
+
+    /**
+     * Runs k-means on the input data (X) trying to find k means.
+     * <p>
+     * K-Means is known for getting stuck into local optima, so you might
+     * want to run it multiple times and argmax on {@link KMeans#score(List)}.
+     * <p>
+     * After each step, a non-empty mean's centroid is the average of its items, which were
+     * assigned using the previous centroids. If the iteration limit is reached before
+     * convergence, some items may be closer to a different mean than their own.
+     *
+     * @param k The number of means to return; must be positive.
+     * @param inputData Input data; every row must be non-null and have the same length.
+     * @return A list of k means, each holding a centroid and the data points assigned to it.
+     *         A mean that attracted no points keeps its previous (possibly random initial)
+     *         centroid.
+     * @throws IllegalArgumentException if {@code k} is not positive or the data set is
+     *         malformed (see {@link #checkDataSetSanity(float[][])}).
+     */
+    public List<Mean> predict(final int k, final float[][] inputData) {
+        if (k <= 0) {
+            // Without any means, step() would have no nearest mean to assign points to.
+            throw new IllegalArgumentException("k must be positive, got: " + k);
+        }
+        checkDataSetSanity(inputData);
+        final int dimension = inputData[0].length;
+
+        final ArrayList<Mean> means = new ArrayList<>(k);
+        for (int i = 0; i < k; i++) {
+            final Mean m = new Mean(dimension);
+            for (int j = 0; j < dimension; j++) {
+                m.mCentroid[j] = mRandomState.nextFloat();
+            }
+            means.add(m);
+        }
+
+        // Iterate until we converge or run out of iterations
+        boolean converged = false;
+        for (int i = 0; i < mMaxIterations; i++) {
+            converged = step(means, inputData);
+            if (converged) {
+                if (DEBUG) Log.d(TAG, "Converged at iteration: " + i);
+                break;
+            }
+        }
+        if (!converged && DEBUG) Log.d(TAG, "Did not converge");
+
+        return means;
+    }
+
+    /**
+     * Scores how well separated the centroids are: the sum of the Euclidean distances
+     * between every ordered pair of distinct means, so each pair is counted twice. Higher
+     * scores mean more spread-out centroids.
+     * <p>
+     * This only looks at centroids; it is not the within-cluster inertia.
+     *
+     * @param means Means to use when calculating score.
+     * @return The score
+     */
+    public static double score(@NonNull List<Mean> means) {
+        double score = 0;
+        final int meansSize = means.size();
+        for (int i = 0; i < meansSize; i++) {
+            final Mean mean = means.get(i);
+            for (int j = 0; j < meansSize; j++) {
+                final Mean compareTo = means.get(j);
+                if (mean == compareTo) {
+                    continue;
+                }
+                score += Math.sqrt(sqDistance(mean.mCentroid, compareTo.mCentroid));
+            }
+        }
+        return score;
+    }
+
+    /**
+     * Validates the shape of {@code inputData}.
+     *
+     * @throws IllegalArgumentException if the data set is null or empty, or if any row is
+     *         null or has a different length than the first row.
+     */
+    @VisibleForTesting
+    public void checkDataSetSanity(float[][] inputData) {
+        if (inputData == null) {
+            throw new IllegalArgumentException("Data set is null.");
+        } else if (inputData.length == 0) {
+            throw new IllegalArgumentException("Data set is empty.");
+        } else if (inputData[0] == null) {
+            throw new IllegalArgumentException("Bad data set format.");
+        }
+
+        final int dimension = inputData[0].length;
+        final int length = inputData.length;
+        for (int i = 1; i < length; i++) {
+            if (inputData[i] == null || inputData[i].length != dimension) {
+                throw new IllegalArgumentException("Bad data set format.");
+            }
+        }
+    }
+
+    /**
+     * K-Means iteration: assigns every point to its nearest mean, then moves each
+     * non-empty mean to the average of its assigned points.
+     *
+     * @param means Current means
+     * @param inputData Input data
+     * @return True if no mean moved farther than the convergence epsilon
+     */
+    private boolean step(final ArrayList<Mean> means, final float[][] inputData) {
+
+        // Clean up the previous state because we need to compute
+        // which point belongs to each mean again.
+        for (int i = means.size() - 1; i >= 0; i--) {
+            final Mean mean = means.get(i);
+            mean.mClosestItems.clear();
+        }
+        for (int i = inputData.length - 1; i >= 0; i--) {
+            final float[] current = inputData[i];
+            final Mean nearest = nearestMean(current, means);
+            nearest.mClosestItems.add(current);
+        }
+
+        boolean converged = true;
+        // Move each mean towards the nearest data set points
+        for (int i = means.size() - 1; i >= 0; i--) {
+            final Mean mean = means.get(i);
+            if (mean.mClosestItems.isEmpty()) {
+                // Nothing to average; keep the current centroid.
+                continue;
+            }
+
+            // Compute the new mean centroid:
+            // 1. Sum all points
+            // 2. Average them
+            final float[] oldCentroid = mean.mCentroid;
+            mean.mCentroid = new float[oldCentroid.length];
+            for (int j = 0; j < mean.mClosestItems.size(); j++) {
+                // Update each centroid component
+                for (int p = 0; p < mean.mCentroid.length; p++) {
+                    mean.mCentroid[p] += mean.mClosestItems.get(j)[p];
+                }
+            }
+            for (int j = 0; j < mean.mCentroid.length; j++) {
+                mean.mCentroid[j] /= mean.mClosestItems.size();
+            }
+
+            // Converged only if no centroid moved farther than the epsilon.
+            if (sqDistance(oldCentroid, mean.mCentroid) > mSqConvergenceEpsilon) {
+                converged = false;
+            }
+        }
+        return converged;
+    }
+
+    /**
+     * Finds the mean whose centroid is closest to {@code point}.
+     *
+     * @return The nearest mean, or {@code null} if {@code means} is empty.
+     */
+    @VisibleForTesting
+    public static Mean nearestMean(float[] point, List<Mean> means) {
+        Mean nearest = null;
+        float nearestDistance = Float.MAX_VALUE;
+
+        final int meanCount = means.size();
+        for (int i = 0; i < meanCount; i++) {
+            final Mean next = means.get(i);
+            // sqrt is monotonic, so comparing squared distances picks the same nearest mean.
+            final float nextDistance = sqDistance(point, next.mCentroid);
+            if (nextDistance < nearestDistance) {
+                nearest = next;
+                nearestDistance = nextDistance;
+            }
+        }
+        return nearest;
+    }
+
+    /**
+     * Squared Euclidean distance between {@code a} and {@code b}, summed over the first
+     * {@code a.length} components.
+     */
+    public static float sqDistance(float[] a, float[] b) {
+        float dist = 0;
+        final int length = a.length;
+        for (int i = 0; i < length; i++) {
+            dist += (a[i] - b[i]) * (a[i] - b[i]);
+        }
+        return dist;
+    }
+
+    /**
+     * Definition of a mean, contains a centroid and points on its cluster.
+     */
+    public static class Mean {
+        float[] mCentroid;
+        final ArrayList<float[]> mClosestItems = new ArrayList<>();
+
+        /** Creates a mean of the given dimension with its centroid at the origin. */
+        public Mean(int dimension) {
+            mCentroid = new float[dimension];
+        }
+
+        /** Creates a mean at {@code centroid}; the array is used directly, not copied. */
+        public Mean(float ...centroid) {
+            mCentroid = centroid;
+        }
+
+        /** Returns the internal centroid array (not a copy). */
+        public float[] getCentroid() {
+            return mCentroid;
+        }
+
+        /** Returns the mutable list of points currently assigned to this mean. */
+        public List<float[]> getItems() {
+            return mClosestItems;
+        }
+
+        @Override
+        public String toString() {
+            return "Mean(centroid: " + Arrays.toString(mCentroid) + ", size: "
+                    + mClosestItems.size() + ")";
+        }
+    }
+}