blob: a64f8a60d485cb9aee9981c4a8ac26b7332d958c [file] [log] [blame]
Lucas Dupin1d3c00d52017-06-05 08:40:39 -07001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.android.internal.ml.clustering;
18
19import static org.junit.Assert.assertEquals;
20import static org.junit.Assert.assertTrue;
21
22import android.annotation.SuppressLint;
23import android.support.test.filters.SmallTest;
24import android.support.test.runner.AndroidJUnit4;
25
26import org.junit.Assert;
27import org.junit.Before;
28import org.junit.Test;
29import org.junit.runner.RunWith;
30
31import java.util.Arrays;
32import java.util.List;
33import java.util.Random;
34
35@SmallTest
36@RunWith(AndroidJUnit4.class)
37public class KMeansTest {
38
39 // Error tolerance (epsilon)
40 private static final double EPS = 0.01;
41
42 private KMeans mKMeans;
43
44 @Before
45 public void setUp() {
46 // Setup with a random seed to have predictable results
47 mKMeans = new KMeans(new Random(0), 30, 0);
48 }
49
50 @Test
51 public void getCheckDataSanityTest() {
52 try {
53 mKMeans.checkDataSetSanity(new float[][] {
54 {0, 1, 2},
55 {1, 2, 3}
56 });
57 } catch (IllegalArgumentException e) {
58 Assert.fail("Valid data didn't pass sanity check");
59 }
60
61 try {
62 mKMeans.checkDataSetSanity(new float[][] {
63 null,
64 {1, 2, 3}
65 });
66 Assert.fail("Data has null items and passed");
67 } catch (IllegalArgumentException e) {}
68
69 try {
70 mKMeans.checkDataSetSanity(new float[][] {
71 {0, 1, 2, 4},
72 {1, 2, 3}
73 });
74 Assert.fail("Data has invalid shape and passed");
75 } catch (IllegalArgumentException e) {}
76
77 try {
78 mKMeans.checkDataSetSanity(null);
79 Assert.fail("Null data should throw exception");
80 } catch (IllegalArgumentException e) {}
81 }
82
83 @Test
84 public void sqDistanceTest() {
85 float a[] = {4, 10};
86 float b[] = {5, 2};
87 float sqDist = (float) (Math.pow(a[0] - b[0], 2) + Math.pow(a[1] - b[1], 2));
88
89 assertEquals("Squared distance not valid", mKMeans.sqDistance(a, b), sqDist, EPS);
90 }
91
92 @Test
93 public void nearestMeanTest() {
94 KMeans.Mean meanA = new KMeans.Mean(0, 1);
95 KMeans.Mean meanB = new KMeans.Mean(1, 1);
96 List<KMeans.Mean> means = Arrays.asList(meanA, meanB);
97
98 KMeans.Mean nearest = mKMeans.nearestMean(new float[] {1, 1}, means);
99
100 assertEquals("Unexpected nearest mean for point {1, 1}", nearest, meanB);
101 }
102
103 @SuppressLint("DefaultLocale")
104 @Test
105 public void scoreTest() {
106 List<KMeans.Mean> closeMeans = Arrays.asList(new KMeans.Mean(0, 0.1f, 0.1f),
107 new KMeans.Mean(0, 0.1f, 0.15f),
108 new KMeans.Mean(0.1f, 0.2f, 0.1f));
109 List<KMeans.Mean> farMeans = Arrays.asList(new KMeans.Mean(0, 0, 0),
110 new KMeans.Mean(0, 0.5f, 0.5f),
111 new KMeans.Mean(1, 0.9f, 0.9f));
112
113 double closeScore = KMeans.score(closeMeans);
114 double farScore = KMeans.score(farMeans);
115 assertTrue(String.format("Score of well distributed means should be greater than "
116 + "close means but got: %f, %f", farScore, closeScore), farScore > closeScore);
117 }
118
119 @Test
120 public void predictTest() {
121 float[] expectedCentroid1 = {1, 1, 1};
122 float[] expectedCentroid2 = {0, 0, 0};
123 float[][] X = new float[][] {
124 {1, 1, 1},
125 {1, 1, 1},
126 {1, 1, 1},
127 {0, 0, 0},
128 {0, 0, 0},
129 {0, 0, 0},
130 };
131
132 final int numClusters = 2;
133
134 // Here we assume that we won't get stuck into a local optima.
135 // It's fine because we're seeding a random, we won't ever have
136 // unstable results but in real life we need multiple initialization
137 // and score comparison
138 List<KMeans.Mean> means = mKMeans.predict(numClusters, X);
139
140 assertEquals("Expected number of clusters is invalid", numClusters, means.size());
141
142 boolean exists1 = false, exists2 = false;
143 for (KMeans.Mean mean : means) {
144 if (Arrays.equals(mean.getCentroid(), expectedCentroid1)) {
145 exists1 = true;
146 } else if (Arrays.equals(mean.getCentroid(), expectedCentroid2)) {
147 exists2 = true;
148 } else {
149 throw new AssertionError("Unexpected mean: " + mean);
150 }
151 }
152 assertTrue("Expected means were not predicted, got: " + means,
153 exists1 && exists2);
154 }
155}