blob: 7224dedc17fd27eb0cd93d747052d4c066031b36 [file] [log] [blame]
Yang Li35aa84b2009-05-18 18:29:05 -07001/*
2 * Copyright (C) 2008-2009 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Romain Guydb567c32009-05-21 16:23:21 -070017package android.gesture;
Yang Li35aa84b2009-05-18 18:29:05 -070018
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Map;
import java.util.TreeMap;
23
24/**
25 * An implementation of an instance-based learner
26 */
27
28class InstanceLearner extends Learner {
Romain Guy0e1ca572009-06-09 12:56:34 -070029 private static final Comparator<Prediction> sComparator = new Comparator<Prediction>() {
30 public int compare(Prediction object1, Prediction object2) {
31 double score1 = object1.score;
32 double score2 = object2.score;
33 if (score1 > score2) {
34 return -1;
35 } else if (score1 < score2) {
36 return 1;
37 } else {
38 return 0;
39 }
40 }
41 };
42
Yang Li35aa84b2009-05-18 18:29:05 -070043 @Override
Yang Li4758f122009-12-14 15:41:07 -080044 ArrayList<Prediction> classify(int sequenceType, int orientationType, float[] vector) {
Yang Li35aa84b2009-05-18 18:29:05 -070045 ArrayList<Prediction> predictions = new ArrayList<Prediction>();
46 ArrayList<Instance> instances = getInstances();
47 int count = instances.size();
48 TreeMap<String, Double> label2score = new TreeMap<String, Double>();
49 for (int i = 0; i < count; i++) {
50 Instance sample = instances.get(i);
Yang Lie6ea0032009-05-21 14:47:59 -070051 if (sample.vector.length != vector.length) {
Yang Li35aa84b2009-05-18 18:29:05 -070052 continue;
53 }
54 double distance;
Romain Guy0a637162009-05-29 14:43:54 -070055 if (sequenceType == GestureStore.SEQUENCE_SENSITIVE) {
Romain Guy46c53122010-02-04 14:19:50 -080056 distance = GestureUtils.minimumCosineDistance(sample.vector, vector, orientationType);
Yang Li35aa84b2009-05-18 18:29:05 -070057 } else {
Romain Guy46c53122010-02-04 14:19:50 -080058 distance = GestureUtils.squaredEuclideanDistance(sample.vector, vector);
Yang Li35aa84b2009-05-18 18:29:05 -070059 }
60 double weight;
61 if (distance == 0) {
62 weight = Double.MAX_VALUE;
63 } else {
64 weight = 1 / distance;
65 }
66 Double score = label2score.get(sample.label);
67 if (score == null || weight > score) {
68 label2score.put(sample.label, weight);
69 }
70 }
71
Romain Guy0e1ca572009-06-09 12:56:34 -070072// double sum = 0;
Romain Guydb567c32009-05-21 16:23:21 -070073 for (String name : label2score.keySet()) {
Yang Li35aa84b2009-05-18 18:29:05 -070074 double score = label2score.get(name);
Romain Guy0e1ca572009-06-09 12:56:34 -070075// sum += score;
Yang Li35aa84b2009-05-18 18:29:05 -070076 predictions.add(new Prediction(name, score));
77 }
78
79 // normalize
Romain Guy0e1ca572009-06-09 12:56:34 -070080// for (Prediction prediction : predictions) {
81// prediction.score /= sum;
82// }
Yang Li35aa84b2009-05-18 18:29:05 -070083
Romain Guy0e1ca572009-06-09 12:56:34 -070084 Collections.sort(predictions, sComparator);
Yang Li35aa84b2009-05-18 18:29:05 -070085
Yang Li35aa84b2009-05-18 18:29:05 -070086 return predictions;
87 }
88}