Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (C) 2011 The Android Open Source Project |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
| 17 | |
| 18 | package android.bordeaux.learning; |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame^] | 19 | |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 20 | import android.util.Log; |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame^] | 21 | |
| 22 | import java.io.Serializable; |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 23 | import java.util.List; |
| 24 | import java.util.Arrays; |
| 25 | import java.util.ArrayList; |
| 26 | |
/**
 * Stochastic Linear Ranker: learns how to rank samples. The learned rank score
 * can be used to compare samples.
 * <p>
 * This Java class wraps the native StochasticLinearRanker class through JNI.
 * To update the ranker, call {@link #updateClassifier} with two samples, the
 * first one having a higher rank than the second one. To get the rank score of
 * a sample, call {@link #scoreSample}.
 * <p>
 * A sample is a pair of parallel arrays: {@code keys[i]} is paired with
 * {@code values[i]}, so both arrays must have the same length.
 * <p>
 * TODO: add more interfaces for changing the learning parameters.
 */
public class StochasticLinearRanker {
    static final String TAG = "StochasticLinearRanker";

    /**
     * Plain, serializable snapshot of a ranker's state, as produced by
     * {@link StochasticLinearRanker#getModel()} and consumed by
     * {@link StochasticLinearRanker#loadModel(Model)}.
     * <p>
     * {@code keys} and {@code values} are parallel lists (same length);
     * {@code parameters} holds the ranker's learning parameters.
     * <p>
     * NOTE(review): there is no explicit serialVersionUID. Do not add one (or
     * change the fields) casually: the default UID is derived from this class's
     * shape, and changing it makes previously serialized models unreadable.
     */
    public static class Model implements Serializable {
        public ArrayList<String> keys = new ArrayList<String>();
        public ArrayList<Float> values = new ArrayList<Float>();
        public ArrayList<Float> parameters = new ArrayList<Float>();
    }

    /** Length of the parameter array exchanged with the native ranker. */
    static final int VAR_NUM = 15;

    /** Creates a new ranker backed by a freshly allocated native classifier. */
    public StochasticLinearRanker() {
        mNativeClassifier = initNativeClassifier();
    }

    /**
     * Trains the ranker with a pair of samples. The first (positive) sample
     * should be ranked higher than the second (negative) one.
     *
     * @param keys_positive   feature keys of the higher-ranked sample
     * @param values_positive feature values, parallel to {@code keys_positive}
     * @param keys_negative   feature keys of the lower-ranked sample
     * @param values_negative feature values, parallel to {@code keys_negative}
     * @return {@code false} if either sample is malformed (null array or
     *         keys/values length mismatch); otherwise the result reported by
     *         the native ranker
     */
    public boolean updateClassifier(String[] keys_positive,
                                    float[] values_positive,
                                    String[] keys_negative,
                                    float[] values_negative) {
        // Reject malformed samples on the Java side rather than handing null or
        // mismatched parallel arrays across JNI to code we cannot see here.
        if (!isWellFormedSample(keys_positive, values_positive)
                || !isWellFormedSample(keys_negative, values_negative)) {
            return false;
        }
        return nativeUpdateClassifier(keys_positive, values_positive,
                                      keys_negative, values_negative,
                                      mNativeClassifier);
    }

    /**
     * Gets the rank score of a sample. A sample is a list of key/value pairs
     * given as two parallel arrays.
     *
     * @param keys   feature keys
     * @param values feature values, parallel to {@code keys}
     * @return the score computed by the native ranker
     */
    public float scoreSample(String[] keys, float[] values) {
        return nativeScoreSample(keys, values, mNativeClassifier);
    }

    /**
     * Gets a snapshot of the ranker's current model and parameters.
     *
     * @return a new {@link Model} whose {@code keys}/{@code values} have the
     *         length reported by the native ranker and whose
     *         {@code parameters} has {@link #VAR_NUM} entries
     */
    public Model getModel() {
        Model model = new Model();
        int len = nativeGetLengthClassifier(mNativeClassifier);
        String[] keys = new String[len];
        float[] values = new float[len];
        float[] param = new float[VAR_NUM];
        // The native side fills the three pre-sized arrays in place.
        nativeGetClassifier(keys, values, param, mNativeClassifier);

        model.keys.addAll(Arrays.asList(keys));
        model.values.ensureCapacity(len);
        for (float value : values) {
            model.values.add(value);
        }
        model.parameters.ensureCapacity(param.length);
        for (float p : param) {
            model.parameters.add(p);
        }
        return model;
    }

    /**
     * Replaces the ranker's model and parameters with the given ones.
     *
     * @param model the model to load, typically obtained from
     *              {@link #getModel()}
     * @return {@code false} if {@code model} is null or its {@code keys} and
     *         {@code values} lists differ in length; otherwise the result
     *         reported by the native ranker
     */
    public boolean loadModel(Model model) {
        // keys/values are parallel (getModel always builds them with one
        // length), so a mismatch is malformed input: reject it before JNI.
        if (model == null || model.keys.size() != model.values.size()) {
            return false;
        }
        // TODO(review): getModel always produces VAR_NUM parameters; confirm
        // whether the native loader also requires exactly VAR_NUM entries.
        String[] keys = model.keys.toArray(new String[model.keys.size()]);
        float[] values = new float[model.values.size()];
        for (int i = 0; i < values.length; ++i) {
            values[i] = model.values.get(i);
        }
        float[] param = new float[model.parameters.size()];
        for (int i = 0; i < param.length; ++i) {
            param[i] = model.parameters.get(i);
        }
        return nativeLoadClassifier(keys, values, param, mNativeClassifier);
    }

    /**
     * Releases the native classifier. The handle is cleared after deletion so
     * a repeated call cannot delete it twice, and the superclass finalizer
     * always runs.
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            // Assumes 0 is never a live native handle (it is the field's
            // default value) — confirm against the JNI implementation.
            if (mNativeClassifier != 0) {
                deleteNativeClassifier(mNativeClassifier);
                mNativeClassifier = 0;
            }
        } finally {
            super.finalize();
        }
    }

    static {
        System.loadLibrary("bordeaux");
    }

    // NOTE(review): presumably a native pointer (the JNI parameters are named
    // classifierPtr). An int truncates pointers on 64-bit builds; widening it
    // to long requires matching changes in the JNI implementation.
    private int mNativeClassifier;

    /** Returns true when both arrays are non-null and of equal length. */
    private static boolean isWellFormedSample(String[] keys, float[] values) {
        return keys != null && values != null && keys.length == values.length;
    }

    /*
     * The following methods are the java stubs for the jni implementations.
     */
    private native int initNativeClassifier();

    private native void deleteNativeClassifier(int classifierPtr);

    private native boolean nativeUpdateClassifier(
        String[] keys_positive,
        float[] values_positive,
        String[] keys_negative,
        float[] values_negative,
        int classifierPtr);

    private native float nativeScoreSample(String[] keys,
                                           float[] values,
                                           int classifierPtr);

    private native void nativeGetClassifier(String[] keys, float[] values, float[] param,
                                            int classifierPtr);

    private native boolean nativeLoadClassifier(String[] keys, float[] values,
                                                float[] param, int classifierPtr);

    private native int nativeGetLengthClassifier(int classifierPtr);
}