Updated LetterRecognizer & related gesture recognition code
- added personalization for letter recognizer
diff --git a/tests/sketch/src/com/android/gesture/GestureLibrary.java b/tests/sketch/src/com/android/gesture/GestureLibrary.java
index 3e753e7..915b840 100644
--- a/tests/sketch/src/com/android/gesture/GestureLibrary.java
+++ b/tests/sketch/src/com/android/gesture/GestureLibrary.java
@@ -49,11 +49,11 @@
private static final String NAMESPACE = "";
public static final int SEQUENCE_INVARIANT = 1;
- // when SEQUENCE_SENSITIVE is used, only single stroke gestures are allowed
+ // when SEQUENCE_SENSITIVE is used, only single stroke gestures are currently allowed
public static final int SEQUENCE_SENSITIVE = 2;
+ // ORIENTATION_SENSITIVE and ORIENTATION_INVARIANT are only for SEQUENCE_SENSITIVE gestures
public static final int ORIENTATION_INVARIANT = 1;
- // ORIENTATION_SENSITIVE is only available for single stroke gestures
public static final int ORIENTATION_SENSITIVE = 2;
private int mSequenceType = SEQUENCE_SENSITIVE;
@@ -77,8 +77,8 @@
}
/**
- * Specify whether the gesture library will handle orientation sensitive
- * gestures. Use ORIENTATION_INVARIANT or ORIENTATION_SENSITIVE
+ * Specify how the gesture library will handle orientation.
+ * Use ORIENTATION_INVARIANT or ORIENTATION_SENSITIVE
*
* @param style
*/
@@ -114,8 +114,8 @@
* @return a list of predictions of possible entries for a given gesture
*/
public ArrayList<Prediction> recognize(Gesture gesture) {
- Instance instance = Instance.createInstance(this, gesture, null);
- return mClassifier.classify(this, instance);
+ Instance instance = Instance.createInstance(mSequenceType, gesture, null);
+ return mClassifier.classify(mSequenceType, instance.vector);
}
/**
@@ -134,7 +134,7 @@
mEntryName2gestures.put(entryName, gestures);
}
gestures.add(gesture);
- mClassifier.addInstance(Instance.createInstance(this, gesture, entryName));
+ mClassifier.addInstance(Instance.createInstance(mSequenceType, gesture, entryName));
mChanged = true;
}
@@ -300,7 +300,7 @@
mGestures = null;
} else if (localName.equals(GestureConstants.XML_TAG_GESTURE)) {
mGestures.add(mCurrentGesture);
- mClassifier.addInstance(Instance.createInstance(GestureLibrary.this,
+ mClassifier.addInstance(Instance.createInstance(mSequenceType,
mCurrentGesture, mEntryName));
mCurrentGesture = null;
} else if (localName.equals(GestureConstants.XML_TAG_STROKE)) {
diff --git a/tests/sketch/src/com/android/gesture/GestureStroke.java b/tests/sketch/src/com/android/gesture/GestureStroke.java
index 3555010..c2ebc17 100644
--- a/tests/sketch/src/com/android/gesture/GestureStroke.java
+++ b/tests/sketch/src/com/android/gesture/GestureStroke.java
@@ -244,4 +244,12 @@
public void invalidate() {
mCachedPath = null;
}
+
+ /**
+ * Compute an oriented bounding box of the stroke
+ * @return OrientedBoundingBox
+ */
+ public OrientedBoundingBox computeOrientedBoundingBox() {
+ return GestureUtilities.computeOrientedBoundingBox(points);
+ }
}
diff --git a/tests/sketch/src/com/android/gesture/GestureUtilities.java b/tests/sketch/src/com/android/gesture/GestureUtilities.java
index 2798616..92de987 100755
--- a/tests/sketch/src/com/android/gesture/GestureUtilities.java
+++ b/tests/sketch/src/com/android/gesture/GestureUtilities.java
@@ -26,7 +26,7 @@
import static com.android.gesture.GestureConstants.*;
-public final class GestureUtilities {
+final class GestureUtilities {
private static final int TEMPORAL_SAMPLING_RATE = 16;
private GestureUtilities() {
@@ -348,33 +348,31 @@
/**
* Calculate the cosine distance between two instances
*
- * @param in1
- * @param in2
+ * @param vector1
+ * @param vector2
* @return the distance between 0 and Math.PI
*/
- static double cosineDistance(Instance in1, Instance in2) {
+ static double cosineDistance(float[] vector1, float[] vector2) {
float sum = 0;
- float[] vector1 = in1.vector;
- float[] vector2 = in2.vector;
int len = vector1.length;
for (int i = 0; i < len; i++) {
sum += vector1[i] * vector2[i];
}
- return Math.acos(sum / (in1.magnitude * in2.magnitude));
+ return Math.acos(sum);
}
- public static OrientedBoundingBox computeOrientedBoundingBox(ArrayList<GesturePoint> pts) {
+ static OrientedBoundingBox computeOrientedBoundingBox(ArrayList<GesturePoint> pts) {
GestureStroke stroke = new GestureStroke(pts);
float[] points = temporalSampling(stroke, TEMPORAL_SAMPLING_RATE);
return computeOrientedBoundingBox(points);
}
- public static OrientedBoundingBox computeOrientedBoundingBox(float[] points) {
+ static OrientedBoundingBox computeOrientedBoundingBox(float[] points) {
float[] meanVector = computeCentroid(points);
return computeOrientedBoundingBox(points, meanVector);
}
- public static OrientedBoundingBox computeOrientedBoundingBox(float[] points, float[] centroid) {
+ static OrientedBoundingBox computeOrientedBoundingBox(float[] points, float[] centroid) {
Matrix tr = new Matrix();
tr.setTranslate(-centroid[0], -centroid[1]);
tr.mapPoints(points);
diff --git a/tests/sketch/src/com/android/gesture/Instance.java b/tests/sketch/src/com/android/gesture/Instance.java
index 011d1fc..b2e030e 100755
--- a/tests/sketch/src/com/android/gesture/Instance.java
+++ b/tests/sketch/src/com/android/gesture/Instance.java
@@ -23,7 +23,7 @@
class Instance {
private static final int SEQUENCE_SAMPLE_SIZE = 16;
- private static final int PATCH_SAMPLE_SIZE = 8;
+ private static final int PATCH_SAMPLE_SIZE = 16;
private final static float[] ORIENTATIONS = {
0, 45, 90, 135, 180, -0, -45, -90, -135, -180
@@ -35,22 +35,26 @@
// the label can be null
final String label;
- // the length of the vector
- final float magnitude;
-
// the id of the instance
final long id;
-
+
private Instance(long id, float[] sample, String sampleName) {
this.id = id;
vector = sample;
label = sampleName;
+ }
+
+ private void normalize() {
+ float[] sample = vector;
float sum = 0;
int size = sample.length;
for (int i = 0; i < size; i++) {
sum += sample[i] * sample[i];
}
- magnitude = (float) Math.sqrt(sum);
+ float magnitude = (float) Math.sqrt(sum);
+ for (int i = 0; i < size; i++) {
+ sample[i] /= magnitude;
+ }
}
/**
@@ -60,21 +64,25 @@
* @param label
* @return the instance
*/
- static Instance createInstance(GestureLibrary gesturelib, Gesture gesture, String label) {
+ static Instance createInstance(int samplingType, Gesture gesture, String label) {
float[] pts;
- if (gesturelib.getGestureType() == GestureLibrary.SEQUENCE_SENSITIVE) {
- pts = temporalSampler(gesturelib, gesture);
+ Instance instance;
+ if (samplingType == GestureLibrary.SEQUENCE_SENSITIVE) {
+ pts = temporalSampler(samplingType, gesture);
+ instance = new Instance(gesture.getID(), pts, label);
+ instance.normalize();
} else {
pts = spatialSampler(gesture);
+ instance = new Instance(gesture.getID(), pts, label);
}
- return new Instance(gesture.getID(), pts, label);
+ return instance;
}
-
+
private static float[] spatialSampler(Gesture gesture) {
return GestureUtilities.spatialSampling(gesture, PATCH_SAMPLE_SIZE);
}
- private static float[] temporalSampler(GestureLibrary gesturelib, Gesture gesture) {
+ private static float[] temporalSampler(int samplingType, Gesture gesture) {
float[] pts = GestureUtilities.temporalSampling(gesture.getStrokes().get(0),
SEQUENCE_SAMPLE_SIZE);
float[] center = GestureUtilities.computeCentroid(pts);
@@ -82,7 +90,7 @@
orientation *= 180 / Math.PI;
float adjustment = -orientation;
- if (gesturelib.getOrientationStyle() == GestureLibrary.ORIENTATION_SENSITIVE) {
+ if (samplingType == GestureLibrary.ORIENTATION_SENSITIVE) {
int count = ORIENTATIONS.length;
for (int i = 0; i < count; i++) {
float delta = ORIENTATIONS[i] - orientation;
diff --git a/tests/sketch/src/com/android/gesture/InstanceLearner.java b/tests/sketch/src/com/android/gesture/InstanceLearner.java
index 335719a..4495256 100644
--- a/tests/sketch/src/com/android/gesture/InstanceLearner.java
+++ b/tests/sketch/src/com/android/gesture/InstanceLearner.java
@@ -34,21 +34,21 @@
private static final String LOGTAG = "InstanceLearner";
@Override
- ArrayList<Prediction> classify(GestureLibrary lib, Instance instance) {
+ ArrayList<Prediction> classify(int gestureType, float[] vector) {
ArrayList<Prediction> predictions = new ArrayList<Prediction>();
ArrayList<Instance> instances = getInstances();
int count = instances.size();
TreeMap<String, Double> label2score = new TreeMap<String, Double>();
for (int i = 0; i < count; i++) {
Instance sample = instances.get(i);
- if (sample.vector.length != instance.vector.length) {
+ if (sample.vector.length != vector.length) {
continue;
}
double distance;
- if (lib.getGestureType() == GestureLibrary.SEQUENCE_SENSITIVE) {
- distance = GestureUtilities.cosineDistance(sample, instance);
+ if (gestureType == GestureLibrary.SEQUENCE_SENSITIVE) {
+ distance = GestureUtilities.cosineDistance(sample.vector, vector);
} else {
- distance = GestureUtilities.squaredEuclideanDistance(sample.vector, instance.vector);
+ distance = GestureUtilities.squaredEuclideanDistance(sample.vector, vector);
}
double weight;
if (distance == 0) {
diff --git a/tests/sketch/src/com/android/gesture/Learner.java b/tests/sketch/src/com/android/gesture/Learner.java
index b4183d2..15b2053 100755
--- a/tests/sketch/src/com/android/gesture/Learner.java
+++ b/tests/sketch/src/com/android/gesture/Learner.java
@@ -79,5 +79,5 @@
instances.removeAll(toDelete);
}
- abstract ArrayList<Prediction> classify(GestureLibrary library, Instance instance);
+ abstract ArrayList<Prediction> classify(int gestureType, float[] vector);
}
diff --git a/tests/sketch/src/com/android/gesture/LetterRecognizer.java b/tests/sketch/src/com/android/gesture/LetterRecognizer.java
index 73151de5..086aedf 100644
--- a/tests/sketch/src/com/android/gesture/LetterRecognizer.java
+++ b/tests/sketch/src/com/android/gesture/LetterRecognizer.java
@@ -20,12 +20,14 @@
import android.content.res.Resources;
import android.util.Log;
-import java.io.IOException;
-import java.io.DataInputStream;
import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
+import java.util.HashMap;
public class LetterRecognizer {
private static final String LOG_TAG = "LetterRecognizer";
@@ -37,8 +39,13 @@
private final String[] mClasses;
- private final int mInputCount;
+ private final int mPatchSize;
+
+ static final String GESTURE_FILE_NAME = "letters.xml";
+ private GestureLibrary mGestureLibrary;
+ private final static int ADJUST_RANGE = 3;
+
private static class SigmoidUnit {
final float[] mWeights;
@@ -62,11 +69,15 @@
}
private LetterRecognizer(int numOfInput, int numOfHidden, String[] classes) {
- mInputCount = (int)Math.sqrt(numOfInput);
+ mPatchSize = (int)Math.sqrt(numOfInput);
mHiddenLayer = new SigmoidUnit[numOfHidden];
mClasses = classes;
mOutputLayer = new SigmoidUnit[classes.length];
}
+
+ public void save() {
+ mGestureLibrary.save();
+ }
public static LetterRecognizer getLetterRecognizer(Context context, int type) {
switch (type) {
@@ -78,7 +89,12 @@
}
public ArrayList<Prediction> recognize(Gesture gesture) {
- return classify(GestureUtilities.spatialSampling(gesture, mInputCount));
+ float[] query = GestureUtilities.spatialSampling(gesture, mPatchSize);
+ ArrayList<Prediction> predictions = classify(query);
+ if (mGestureLibrary != null) {
+ adjustPrediction(gesture, predictions);
+ }
+ return predictions;
}
private ArrayList<Prediction> classify(float[] vector) {
@@ -151,16 +167,16 @@
SigmoidUnit[] outputLayer = new SigmoidUnit[oCount];
for (int i = 0; i < hCount; i++) {
- float[] weights = new float[iCount];
- for (int j = 0; j < iCount; j++) {
+ float[] weights = new float[iCount + 1];
+ for (int j = 0; j <= iCount; j++) {
weights[j] = in.readFloat();
}
hiddenLayer[i] = new SigmoidUnit(weights);
}
for (int i = 0; i < oCount; i++) {
- float[] weights = new float[hCount];
- for (int j = 0; j < hCount; j++) {
+ float[] weights = new float[hCount + 1];
+ for (int j = 0; j <= hCount; j++) {
weights[j] = in.readFloat();
}
outputLayer[i] = new SigmoidUnit(weights);
@@ -170,11 +186,43 @@
classifier.mOutputLayer = outputLayer;
} catch (IOException e) {
- Log.d(LOG_TAG, "Failed to load gestures:", e);
+ Log.d(LOG_TAG, "Failed to load handwriting data:", e);
} finally {
GestureUtilities.closeStream(in);
}
return classifier;
}
+
+ public void enablePersonalization(boolean enable) {
+ if (enable) {
+ mGestureLibrary = new GestureLibrary(GESTURE_FILE_NAME);
+ mGestureLibrary.setGestureType(GestureLibrary.SEQUENCE_INVARIANT);
+ mGestureLibrary.load();
+ } else {
+ mGestureLibrary = null;
+ }
+ }
+
+ public void addExample(String letter, Gesture example) {
+ mGestureLibrary.addGesture(letter, example);
+ }
+
+ private void adjustPrediction(Gesture query, ArrayList<Prediction> predictions) {
+ ArrayList<Prediction> results = mGestureLibrary.recognize(query);
+ HashMap<String, Prediction> topNList = new HashMap<String, Prediction>();
+ for (int j = 0; j < ADJUST_RANGE; j++) {
+ Prediction prediction = predictions.remove(0);
+ topNList.put(prediction.name, prediction);
+ }
+ int count = results.size();
+ for (int j = count - 1; j >= 0 && !topNList.isEmpty(); j--) {
+ Prediction item = results.get(j);
+ Prediction original = topNList.get(item.name);
+ if (original != null) {
+ predictions.add(0, original);
+ topNList.remove(item.name);
+ }
+ }
+ }
}