blob: 540a1ec2bd5a573c9d185cbde49c0e077d3ede6d [file] [log] [blame]
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17package com.android.internal.ml.clustering;
18
19import static org.junit.Assert.assertEquals;
20import static org.junit.Assert.assertTrue;
21
22import android.annotation.SuppressLint;
Brett Chabot502ec7a2019-03-01 14:43:20 -080023
24import androidx.test.filters.SmallTest;
25import androidx.test.runner.AndroidJUnit4;
Lucas Dupin1d3c00d52017-06-05 08:40:39 -070026
27import org.junit.Assert;
28import org.junit.Before;
29import org.junit.Test;
30import org.junit.runner.RunWith;
31
32import java.util.Arrays;
33import java.util.List;
34import java.util.Random;
35
36@SmallTest
37@RunWith(AndroidJUnit4.class)
38public class KMeansTest {
39
40 // Error tolerance (epsilon)
41 private static final double EPS = 0.01;
42
43 private KMeans mKMeans;
44
45 @Before
46 public void setUp() {
47 // Setup with a random seed to have predictable results
48 mKMeans = new KMeans(new Random(0), 30, 0);
49 }
50
51 @Test
52 public void getCheckDataSanityTest() {
53 try {
54 mKMeans.checkDataSetSanity(new float[][] {
55 {0, 1, 2},
56 {1, 2, 3}
57 });
58 } catch (IllegalArgumentException e) {
59 Assert.fail("Valid data didn't pass sanity check");
60 }
61
62 try {
63 mKMeans.checkDataSetSanity(new float[][] {
64 null,
65 {1, 2, 3}
66 });
67 Assert.fail("Data has null items and passed");
68 } catch (IllegalArgumentException e) {}
69
70 try {
71 mKMeans.checkDataSetSanity(new float[][] {
72 {0, 1, 2, 4},
73 {1, 2, 3}
74 });
75 Assert.fail("Data has invalid shape and passed");
76 } catch (IllegalArgumentException e) {}
77
78 try {
79 mKMeans.checkDataSetSanity(null);
80 Assert.fail("Null data should throw exception");
81 } catch (IllegalArgumentException e) {}
82 }
83
84 @Test
85 public void sqDistanceTest() {
86 float a[] = {4, 10};
87 float b[] = {5, 2};
88 float sqDist = (float) (Math.pow(a[0] - b[0], 2) + Math.pow(a[1] - b[1], 2));
89
90 assertEquals("Squared distance not valid", mKMeans.sqDistance(a, b), sqDist, EPS);
91 }
92
93 @Test
94 public void nearestMeanTest() {
95 KMeans.Mean meanA = new KMeans.Mean(0, 1);
96 KMeans.Mean meanB = new KMeans.Mean(1, 1);
97 List<KMeans.Mean> means = Arrays.asList(meanA, meanB);
98
99 KMeans.Mean nearest = mKMeans.nearestMean(new float[] {1, 1}, means);
100
101 assertEquals("Unexpected nearest mean for point {1, 1}", nearest, meanB);
102 }
103
104 @SuppressLint("DefaultLocale")
105 @Test
106 public void scoreTest() {
107 List<KMeans.Mean> closeMeans = Arrays.asList(new KMeans.Mean(0, 0.1f, 0.1f),
108 new KMeans.Mean(0, 0.1f, 0.15f),
109 new KMeans.Mean(0.1f, 0.2f, 0.1f));
110 List<KMeans.Mean> farMeans = Arrays.asList(new KMeans.Mean(0, 0, 0),
111 new KMeans.Mean(0, 0.5f, 0.5f),
112 new KMeans.Mean(1, 0.9f, 0.9f));
113
114 double closeScore = KMeans.score(closeMeans);
115 double farScore = KMeans.score(farMeans);
116 assertTrue(String.format("Score of well distributed means should be greater than "
117 + "close means but got: %f, %f", farScore, closeScore), farScore > closeScore);
118 }
119
120 @Test
121 public void predictTest() {
122 float[] expectedCentroid1 = {1, 1, 1};
123 float[] expectedCentroid2 = {0, 0, 0};
124 float[][] X = new float[][] {
125 {1, 1, 1},
126 {1, 1, 1},
127 {1, 1, 1},
128 {0, 0, 0},
129 {0, 0, 0},
130 {0, 0, 0},
131 };
132
133 final int numClusters = 2;
134
135 // Here we assume that we won't get stuck into a local optima.
136 // It's fine because we're seeding a random, we won't ever have
137 // unstable results but in real life we need multiple initialization
138 // and score comparison
139 List<KMeans.Mean> means = mKMeans.predict(numClusters, X);
140
141 assertEquals("Expected number of clusters is invalid", numClusters, means.size());
142
143 boolean exists1 = false, exists2 = false;
144 for (KMeans.Mean mean : means) {
145 if (Arrays.equals(mean.getCentroid(), expectedCentroid1)) {
146 exists1 = true;
147 } else if (Arrays.equals(mean.getCentroid(), expectedCentroid2)) {
148 exists2 = true;
149 } else {
150 throw new AssertionError("Unexpected mean: " + mean);
151 }
152 }
153 assertTrue("Expected means were not predicted, got: " + means,
154 exists1 && exists2);
155 }
156}