K-Means color clustering

Test: runtest -x tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java
Bug: 37014702
Change-Id: Idabc163df5ded362acbe462ae6b351394a36db10
diff --git a/core/java/android/app/WallpaperColors.java b/core/java/android/app/WallpaperColors.java
index 23e9ca5..8f172ba 100644
--- a/core/java/android/app/WallpaperColors.java
+++ b/core/java/android/app/WallpaperColors.java
@@ -27,6 +27,7 @@
 import android.util.Size;
 
 import com.android.internal.graphics.palette.Palette;
+import com.android.internal.graphics.palette.VariationalKMeansQuantizer;
 
 import java.util.ArrayList;
 import java.util.Collections;
@@ -142,6 +143,8 @@
 
         final Palette palette = Palette
                 .from(bitmap)
+                .setQuantizer(new VariationalKMeansQuantizer())
+                .maximumColorCount(5)
                 .clearFilters()
                 .resizeBitmapArea(MAX_WALLPAPER_EXTRACTION_AREA)
                 .generate();
diff --git a/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java b/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java
index 56d60a1..9ac753b 100644
--- a/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java
+++ b/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java
@@ -61,7 +61,7 @@
  * This means that the color space is divided into distinct colors, rather than representative
  * colors.
  */
-final class ColorCutQuantizer {
+final class ColorCutQuantizer implements Quantizer {
 
     private static final String LOG_TAG = "ColorCutQuantizer";
     private static final boolean LOG_TIMINGS = false;
@@ -73,22 +73,22 @@
     private static final int QUANTIZE_WORD_WIDTH = 5;
     private static final int QUANTIZE_WORD_MASK = (1 << QUANTIZE_WORD_WIDTH) - 1;
 
-    final int[] mColors;
-    final int[] mHistogram;
-    final List<Swatch> mQuantizedColors;
-    final TimingLogger mTimingLogger;
-    final Palette.Filter[] mFilters;
+    int[] mColors;
+    int[] mHistogram;
+    List<Swatch> mQuantizedColors;
+    TimingLogger mTimingLogger;
+    Palette.Filter[] mFilters;
 
     private final float[] mTempHsl = new float[3];
 
     /**
-     * Constructor.
+     * Execute color quantization.
      *
      * @param pixels histogram representing an image's pixel data
      * @param maxColors The maximum number of colors that should be in the result palette.
      * @param filters Set of filters to use in the quantization stage
      */
-    ColorCutQuantizer(final int[] pixels, final int maxColors, final Palette.Filter[] filters) {
+    public void quantize(final int[] pixels, final int maxColors, final Palette.Filter[] filters) {
         mTimingLogger = LOG_TIMINGS ? new TimingLogger(LOG_TAG, "Creation") : null;
         mFilters = filters;
 
@@ -160,7 +160,7 @@
     /**
      * @return the list of quantized colors
      */
-    List<Swatch> getQuantizedColors() {
+    public List<Swatch> getQuantizedColors() {
         return mQuantizedColors;
     }
 
diff --git a/core/java/com/android/internal/graphics/palette/Palette.java b/core/java/com/android/internal/graphics/palette/Palette.java
index 9f1504a..a4f9a59 100644
--- a/core/java/com/android/internal/graphics/palette/Palette.java
+++ b/core/java/com/android/internal/graphics/palette/Palette.java
@@ -613,6 +613,8 @@
         private final List<Palette.Filter> mFilters = new ArrayList<>();
         private Rect mRegion;
 
+        private Quantizer mQuantizer;
+
         /**
          * Construct a new {@link Palette.Builder} using a source {@link Bitmap}
          */
@@ -726,6 +728,18 @@
         }
 
         /**
+         * Set a specific quantization algorithm. {@link ColorCutQuantizer} will
+         * be used if unspecified.
+         *
+         * @param quantizer Quantizer implementation.
+         */
+        @NonNull
+        public Palette.Builder setQuantizer(Quantizer quantizer) {
+            mQuantizer = quantizer;
+            return this;
+        }
+
+        /**
          * Set a region of the bitmap to be used exclusively when calculating the palette.
          * <p>This only works when the original input is a {@link Bitmap}.</p>
          *
@@ -818,17 +832,19 @@
                 }
 
                 // Now generate a quantizer from the Bitmap
-                final ColorCutQuantizer quantizer = new ColorCutQuantizer(
-                        getPixelsFromBitmap(bitmap),
-                        mMaxColors,
-                        mFilters.isEmpty() ? null : mFilters.toArray(new Palette.Filter[mFilters.size()]));
+                if (mQuantizer == null) {
+                    mQuantizer = new ColorCutQuantizer();
+                }
+                mQuantizer.quantize(getPixelsFromBitmap(bitmap),
+                            mMaxColors, mFilters.isEmpty() ? null :
+                            mFilters.toArray(new Palette.Filter[mFilters.size()]));
 
                 // If created a new bitmap, recycle it
                 if (bitmap != mBitmap) {
                     bitmap.recycle();
                 }
 
-                swatches = quantizer.getQuantizedColors();
+                swatches = mQuantizer.getQuantizedColors();
 
                 if (logger != null) {
                     logger.addSplit("Color quantization completed");
diff --git a/core/java/com/android/internal/graphics/palette/Quantizer.java b/core/java/com/android/internal/graphics/palette/Quantizer.java
new file mode 100644
index 0000000..db60f2e
--- /dev/null
+++ b/core/java/com/android/internal/graphics/palette/Quantizer.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.internal.graphics.palette;
+
+import java.util.List;
+
+/**
+ * Definition of an algorithm that receives pixels and outputs a list of colors.
+ */
+public interface Quantizer {
+    void quantize(final int[] pixels, final int maxColors, final Palette.Filter[] filters);
+    List<Palette.Swatch> getQuantizedColors();
+}
diff --git a/core/java/com/android/internal/graphics/palette/VariationalKMeansQuantizer.java b/core/java/com/android/internal/graphics/palette/VariationalKMeansQuantizer.java
new file mode 100644
index 0000000..b035535
--- /dev/null
+++ b/core/java/com/android/internal/graphics/palette/VariationalKMeansQuantizer.java
@@ -0,0 +1,154 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.internal.graphics.palette;
+
+import android.util.Log;
+
+import com.android.internal.graphics.ColorUtils;
+import com.android.internal.ml.clustering.KMeans;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * A quantizer that uses k-means
+ */
+public class VariationalKMeansQuantizer implements Quantizer {
+
+    private static final String TAG = "KMeansQuantizer";
+    private static final boolean DEBUG = false;
+
+    /**
+     * Clusters whose centroids are closer than this (squared) distance will be merged.
+     */
+    private final float mMinClusterSqDistance;
+
+    /**
+     * K-means can get stuck in local optima; this can be mitigated by
+     * repeating it and keeping the "best" (highest-scoring) execution.
+     */
+    private final int mInitializations;
+
+    /**
+     * Initialize KMeans with a fixed random state to have
+     * consistent results across multiple runs.
+     */
+    private final KMeans mKMeans = new KMeans(new Random(0), 30, 0);
+
+    private List<Palette.Swatch> mQuantizedColors;
+
+    public VariationalKMeansQuantizer() {
+        this(0.25f /* cluster distance */);
+    }
+
+    public VariationalKMeansQuantizer(float minClusterDistance) {
+        this(minClusterDistance, 1 /* initializations */);
+    }
+
+    public VariationalKMeansQuantizer(float minClusterDistance, int initializations) {
+        mMinClusterSqDistance = minClusterDistance * minClusterDistance;
+        mInitializations = initializations;
+    }
+
+    /**
+     * K-Means quantizer.
+     *
+     * @param pixels Pixels to quantize.
+     * @param maxColors Maximum number of clusters to extract.
+     * @param filters Colors that should be ignored (currently not applied by this quantizer).
+     */
+    @Override
+    public void quantize(int[] pixels, int maxColors, Palette.Filter[] filters) {
+        // Start by converting all colors to HSL.
+        // HSL is far more meaningful for clustering than RGB.
+        final float[] hsl = {0, 0, 0};
+        final float[][] hslPixels = new float[pixels.length][3];
+        for (int i = 0; i < pixels.length; i++) {
+            ColorUtils.colorToHSL(pixels[i], hsl);
+            // Normalize hue so all values go from 0 to 1.
+            hslPixels[i][0] = hsl[0] / 360f;
+            hslPixels[i][1] = hsl[1];
+            hslPixels[i][2] = hsl[2];
+        }
+
+        final List<KMeans.Mean> optimalMeans = getOptimalKMeans(maxColors, hslPixels);
+
+        // Ideally we should run k-means again to merge clusters but it would be too expensive,
+        // instead we just merge all clusters that are closer than a threshold.
+        for (int i = 0; i < optimalMeans.size(); i++) {
+            KMeans.Mean current = optimalMeans.get(i);
+            float[] currentCentroid = current.getCentroid();
+            for (int j = i + 1; j < optimalMeans.size(); j++) {
+                KMeans.Mean compareTo = optimalMeans.get(j);
+                float[] compareToCentroid = compareTo.getCentroid();
+                float sqDistance = KMeans.sqDistance(currentCentroid, compareToCentroid);
+                // Merge them
+                if (sqDistance < mMinClusterSqDistance) {
+                    optimalMeans.remove(compareTo);
+                    current.getItems().addAll(compareTo.getItems());
+                    for (int k = 0; k < currentCentroid.length; k++) {
+                        currentCentroid[k] += (compareToCentroid[k] - currentCentroid[k]) / 2.0;
+                    }
+                    j--;
+                }
+            }
+        }
+
+        // Convert data to final format, de-normalizing the hue.
+        mQuantizedColors = new ArrayList<>();
+        for (KMeans.Mean mean : optimalMeans) {
+            if (mean.getItems().size() == 0) {
+                continue;
+            }
+            float[] centroid = mean.getCentroid();
+            mQuantizedColors.add(new Palette.Swatch(new float[]{
+                    centroid[0] * 360f,
+                    centroid[1],
+                    centroid[2]
+            }, mean.getItems().size()));
+        }
+    }
+
+    private List<KMeans.Mean> getOptimalKMeans(int k, float[][] inputData) {
+        List<KMeans.Mean> optimal = null;
+        double optimalScore = -Double.MAX_VALUE;
+        int runs = mInitializations;
+        while (runs > 0) {
+            if (DEBUG) {
+                Log.d(TAG, "k-means run: " + runs);
+            }
+            List<KMeans.Mean> means = mKMeans.predict(k, inputData);
+            double score = KMeans.score(means);
+            if (optimal == null || score > optimalScore) {
+                if (DEBUG) {
+                    Log.d(TAG, "\tnew optimal score: " + score);
+                }
+                optimalScore = score;
+                optimal = means;
+            }
+            runs--;
+        }
+
+        return optimal;
+    }
+
+    @Override
+    public List<Palette.Swatch> getQuantizedColors() {
+        return mQuantizedColors;
+    }
+}
diff --git a/core/java/com/android/internal/ml/clustering/KMeans.java b/core/java/com/android/internal/ml/clustering/KMeans.java
new file mode 100644
index 0000000..4d5b333
--- /dev/null
+++ b/core/java/com/android/internal/ml/clustering/KMeans.java
@@ -0,0 +1,243 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.internal.ml.clustering;
+
+import android.annotation.NonNull;
+import android.util.Log;
+
+import com.android.internal.annotations.VisibleForTesting;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Simple K-Means implementation
+ */
+public class KMeans {
+
+    private static final boolean DEBUG = false;
+    private static final String TAG = "KMeans";
+    private final Random mRandomState;
+    private final int mMaxIterations;
+    private float mSqConvergenceEpsilon;
+
+    public KMeans() {
+        this(new Random());
+    }
+
+    public KMeans(Random random) {
+        this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */);
+    }
+    public KMeans(Random random, int maxIterations, float convergenceEpsilon) {
+        mRandomState = random;
+        mMaxIterations = maxIterations;
+        mSqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon;
+    }
+
+    /**
+     * Runs k-means on the input data (X) trying to find k means.
+     *
+     * K-Means is known for getting stuck in local optima, so you might
+     * want to run it multiple times and take the argmax of {@link KMeans#score(List)}.
+     *
+     * @param k The number of means (clusters) to find.
+     * @param inputData Input data.
+     * @return A list of k Means, each representing a centroid and data points that belong to it.
+     */
+    public List<Mean> predict(final int k, final float[][] inputData) {
+        checkDataSetSanity(inputData);
+        int dimension = inputData[0].length;
+
+        final ArrayList<Mean> means = new ArrayList<>();
+        for (int i = 0; i < k; i++) {
+            Mean m = new Mean(dimension);
+            for (int j = 0; j < dimension; j++) {
+                m.mCentroid[j] = mRandomState.nextFloat();
+            }
+            means.add(m);
+        }
+
+        // Iterate until we converge or run out of iterations
+        boolean converged = false;
+        for (int i = 0; i < mMaxIterations; i++) {
+            converged = step(means, inputData);
+            if (converged) {
+                if (DEBUG) Log.d(TAG, "Converged at iteration: " + i);
+                break;
+            }
+        }
+        if (!converged && DEBUG) Log.d(TAG, "Did not converge");
+
+        return means;
+    }
+
+    /**
+     * Score is the sum of the Euclidean distances between the centroids of all pairs of means.
+     * This can be considered as an E step of an EM algorithm.
+     *
+     * @param means Means to use when calculating score.
+     * @return The score
+     */
+    public static double score(@NonNull List<Mean> means) {
+        double score = 0;
+        final int meansSize = means.size();
+        for (int i = 0; i < meansSize; i++) {
+            Mean mean = means.get(i);
+            for (int j = 0; j < meansSize; j++) {
+                Mean compareTo = means.get(j);
+                if (mean == compareTo) {
+                    continue;
+                }
+                double distance = Math.sqrt(sqDistance(mean.mCentroid, compareTo.mCentroid));
+                score += distance;
+            }
+        }
+        return score;
+    }
+
+    @VisibleForTesting
+    public void checkDataSetSanity(float[][] inputData) {
+        if (inputData == null) {
+            throw new IllegalArgumentException("Data set is null.");
+        } else if (inputData.length == 0) {
+            throw new IllegalArgumentException("Data set is empty.");
+        } else if (inputData[0] == null) {
+            throw new IllegalArgumentException("Bad data set format.");
+        }
+
+        final int dimension = inputData[0].length;
+        final int length = inputData.length;
+        for (int i = 1; i < length; i++) {
+            if (inputData[i] == null || inputData[i].length != dimension) {
+                throw new IllegalArgumentException("Bad data set format.");
+            }
+        }
+    }
+
+    /**
+     * K-Means iteration.
+     *
+     * @param means Current means
+     * @param inputData Input data
+     * @return True if data set converged
+     */
+    private boolean step(final ArrayList<Mean> means, final float[][] inputData) {
+
+        // Clean up the previous state because we need to compute
+        // which point belongs to each mean again.
+        for (int i = means.size() - 1; i >= 0; i--) {
+            final Mean mean = means.get(i);
+            mean.mClosestItems.clear();
+        }
+        for (int i = inputData.length - 1; i >= 0; i--) {
+            final float[] current = inputData[i];
+            final Mean nearest = nearestMean(current, means);
+            nearest.mClosestItems.add(current);
+        }
+
+        boolean converged = true;
+        // Move each mean towards the nearest data set points
+        for (int i = means.size() - 1; i >= 0; i--) {
+            final Mean mean = means.get(i);
+            if (mean.mClosestItems.size() == 0) {
+                continue;
+            }
+
+            // Compute the new mean centroid:
+            //   1. Sum all points
+            //   2. Average them
+            final float[] oldCentroid = mean.mCentroid;
+            mean.mCentroid = new float[oldCentroid.length];
+            for (int j = 0; j < mean.mClosestItems.size(); j++) {
+                // Update each centroid component
+                for (int p = 0; p < mean.mCentroid.length; p++) {
+                    mean.mCentroid[p] += mean.mClosestItems.get(j)[p];
+                }
+            }
+            for (int j = 0; j < mean.mCentroid.length; j++) {
+                mean.mCentroid[j] /= mean.mClosestItems.size();
+            }
+
+            // We converged only if no centroid moved by more than the convergence epsilon.
+            if (sqDistance(oldCentroid, mean.mCentroid) > mSqConvergenceEpsilon) {
+                converged = false;
+            }
+        }
+        return converged;
+    }
+
+    @VisibleForTesting
+    public static Mean nearestMean(float[] point, List<Mean> means) {
+        Mean nearest = null;
+        float nearestDistance = Float.MAX_VALUE;
+
+        final int meanCount = means.size();
+        for (int i = 0; i < meanCount; i++) {
+            Mean next = means.get(i);
+            // We don't need the sqrt when comparing distances in Euclidean space:
+            // sqrt is monotonic, so comparing squared distances gives the same ordering.
+            float nextDistance = sqDistance(point, next.mCentroid);
+            if (nextDistance < nearestDistance) {
+                nearest = next;
+                nearestDistance = nextDistance;
+            }
+        }
+        return nearest;
+    }
+
+    @VisibleForTesting
+    public static float sqDistance(float[] a, float[] b) {
+        float dist = 0;
+        final int length = a.length;
+        for (int i = 0; i < length; i++) {
+            dist += (a[i] - b[i]) * (a[i] - b[i]);
+        }
+        return dist;
+    }
+
+    /**
+     * Definition of a mean: contains a centroid and the points in its cluster.
+     */
+    public static class Mean {
+        float[] mCentroid;
+        final ArrayList<float[]> mClosestItems = new ArrayList<>();
+
+        public Mean(int dimension) {
+            mCentroid = new float[dimension];
+        }
+
+        public Mean(float ...centroid) {
+            mCentroid = centroid;
+        }
+
+        public float[] getCentroid() {
+            return mCentroid;
+        }
+
+        public List<float[]> getItems() {
+            return mClosestItems;
+        }
+
+        @Override
+        public String toString() {
+            return "Mean(centroid: " + Arrays.toString(mCentroid) + ", size: "
+                    + mClosestItems.size() + ")";
+        }
+    }
+}