blob: 12b7b29ad1b7396f6f6629d3be621d4bc4ef53a7 [file] [log] [blame]
/*
 * Copyright (C) 2011 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17
18package android.bordeaux.learning;
Wei Hua1dd8ef52012-03-30 15:15:12 -070019
Wei Hua6b4eebc2012-03-09 10:24:16 -080020import android.util.Log;
Wei Hua1dd8ef52012-03-30 15:15:12 -070021
22import java.io.Serializable;
Wei Hua6b4eebc2012-03-09 10:24:16 -080023import java.util.List;
24import java.util.Arrays;
25import java.util.ArrayList;
26
/**
 * Stochastic Linear Ranker: learns how to rank samples. The learned rank score
 * can be used to compare samples.
 *
 * <p>This Java class wraps the native StochasticLinearRanker class (provided by
 * the "bordeaux" JNI library). To update the ranker, call
 * {@link #updateClassifier} with two samples, the first one having a higher rank
 * than the second. To get the rank score of a sample, call {@link #scoreSample}.
 *
 * <p>TODO: add more interfaces for changing the learning parameters.
 */
public class StochasticLinearRanker {
    static final String TAG = "StochasticLinearRanker";

    /**
     * Serializable snapshot of the ranker's state, as produced by
     * {@link StochasticLinearRanker#getModel()} and consumed by
     * {@link StochasticLinearRanker#loadModel(Model)}.
     *
     * <p>{@code keys} and {@code values} are parallel lists: {@code values.get(i)}
     * belongs to {@code keys.get(i)}. {@code parameters} holds the ranker's
     * parameter vector.
     *
     * <p>NOTE(review): no explicit serialVersionUID is declared, so the
     * JVM-computed default is in effect. Adding one now would break
     * deserialization of previously stored models unless it equals that value.
     */
    public static class Model implements Serializable {
        public ArrayList<String> keys = new ArrayList<String>();
        public ArrayList<Float> values = new ArrayList<Float>();
        public ArrayList<Float> parameters = new ArrayList<Float>();
    }

    /**
     * Number of parameters read back by {@link #getModel()}. It sizes the
     * parameter buffer handed to the native side.
     *
     * <p>NOTE(review): presumably must match the parameter count the native
     * classifier writes — TODO confirm. It is left non-final so the existing
     * package-visible field stays assignable.
     */
    static int VAR_NUM = 15;

    /** Creates a ranker backed by a newly initialized native classifier. */
    public StochasticLinearRanker() {
        mNativeClassifier = initNativeClassifier();
    }

    /**
     * Trains the ranker with a pair of samples. Each sample is a pair of
     * parallel arrays of keys and values. The first (positive) sample should
     * have a higher rank than the second (negative) one.
     *
     * @param keys_positive keys of the higher-ranked sample
     * @param values_positive values for {@code keys_positive}, index-aligned
     * @param keys_negative keys of the lower-ranked sample
     * @param values_negative values for {@code keys_negative}, index-aligned
     * @return the result reported by the native update (presumably {@code true}
     *         on success — TODO confirm against the JNI implementation)
     */
    public boolean updateClassifier(String[] keys_positive,
                                    float[] values_positive,
                                    String[] keys_negative,
                                    float[] values_negative) {
        return nativeUpdateClassifier(keys_positive, values_positive,
                                      keys_negative, values_negative,
                                      mNativeClassifier);
    }

    /**
     * Gets the rank score of a sample, given as parallel arrays of keys and
     * values. Scores of different samples can be compared with each other.
     *
     * @param keys feature keys of the sample
     * @param values values for {@code keys}, index-aligned
     * @return the score computed by the native ranker
     */
    public float scoreSample(String[] keys, float[] values) {
        return nativeScoreSample(keys, values, mNativeClassifier);
    }

    /**
     * Gets the current model and parameters of the ranker.
     *
     * @return a new {@link Model} holding the native classifier's keys and
     *         values (index-aligned) and {@link #VAR_NUM} parameters
     */
    public Model getModel() {
        Model model = new Model();
        int len = nativeGetLengthClassifier(mNativeClassifier);
        String[] keys = new String[len];
        float[] values = new float[len];
        float[] param = new float[VAR_NUM];
        // The native side fills the pre-sized output arrays in place.
        nativeGetClassifier(keys, values, param, mNativeClassifier);

        model.keys.ensureCapacity(len);
        model.values.ensureCapacity(len);
        for (int i = 0; i < len; i++) {
            model.keys.add(keys[i]);
            model.values.add(values[i]);
        }
        model.parameters.ensureCapacity(param.length);
        for (float p : param) {
            model.parameters.add(p);
        }
        return model;
    }

    /**
     * Uses the given model and parameters for the ranker.
     *
     * @param model model to load; {@code keys} and {@code values} must have the
     *        same size and contain no null {@code Float} elements
     * @return {@code false} if {@code model.keys} and {@code model.values} differ
     *         in size; otherwise the result reported by the native load
     * @throws NullPointerException if {@code model} or any value/parameter
     *         element is null
     */
    public boolean loadModel(Model model) {
        int len = model.keys.size();
        // Keys and values are parallel lists; reject an inconsistent model here
        // instead of passing arrays of different lengths across JNI.
        if (model.values.size() != len) {
            return false;
        }
        float[] values = toFloatArray(model.values);
        float[] param = toFloatArray(model.parameters);
        String[] keys = model.keys.toArray(new String[len]);
        return nativeLoadClassifier(keys, values, param, mNativeClassifier);
    }

    /** Unboxes a list of Floats into a new primitive array (NPE on a null element). */
    private static float[] toFloatArray(List<Float> list) {
        float[] out = new float[list.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = list.get(i);
        }
        return out;
    }

    /**
     * Releases the native classifier. The handle is cleared so that it is not
     * deleted twice, and {@code super.finalize()} always runs.
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            if (mNativeClassifier != 0) {
                deleteNativeClassifier(mNativeClassifier);
                mNativeClassifier = 0;
            }
        } finally {
            super.finalize();
        }
    }

    static {
        System.loadLibrary("bordeaux");
    }

    /**
     * Handle of the native classifier, as returned by {@link #initNativeClassifier()}.
     *
     * <p>NOTE(review): the JNI stubs name this {@code classifierPtr}, so it appears
     * to be a native pointer. Stored in an {@code int}, it would be truncated on
     * 64-bit runtimes. Widening to {@code long} requires a matching change in the
     * native implementation.
     */
    private int mNativeClassifier;

    /*
     * The following methods are the Java stubs for the JNI implementations.
     */
    private native int initNativeClassifier();

    private native void deleteNativeClassifier(int classifierPtr);

    private native boolean nativeUpdateClassifier(
            String[] keys_positive,
            float[] values_positive,
            String[] keys_negative,
            float[] values_negative,
            int classifierPtr);

    private native float nativeScoreSample(String[] keys,
                                           float[] values,
                                           int classifierPtr);

    private native void nativeGetClassifier(String[] keys, float[] values, float[] param,
                                            int classifierPtr);

    private native boolean nativeLoadClassifier(String[] keys, float[] values,
                                                float[] param, int classifierPtr);

    private native int nativeGetLengthClassifier(int classifierPtr);
}