/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| 16 | |
| 17 | package android.bordeaux.services; |
| 18 | |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 19 | import android.bordeaux.learning.StochasticLinearRanker; |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame] | 20 | import android.bordeaux.services.IBordeauxLearner.ModelChangeCallback; |
| 21 | import android.os.IBinder; |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 22 | import android.util.Log; |
saberian | b019e89 | 2012-04-19 11:33:44 -0700 | [diff] [blame] | 23 | import java.util.List; |
| 24 | import java.util.ArrayList; |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 25 | import java.io.*; |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame] | 26 | import java.lang.ClassNotFoundException; |
| 27 | import java.util.Arrays; |
| 28 | import java.util.ArrayList; |
| 29 | import java.util.List; |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 30 | import java.util.Scanner; |
saberian | b019e89 | 2012-04-19 11:33:44 -0700 | [diff] [blame] | 31 | import java.io.ByteArrayOutputStream; |
| 32 | import java.util.HashMap; |
| 33 | import java.util.Map; |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 34 | |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame] | 35 | public class Learning_StochasticLinearRanker extends ILearning_StochasticLinearRanker.Stub |
| 36 | implements IBordeauxLearner { |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 37 | |
saberian | b019e89 | 2012-04-19 11:33:44 -0700 | [diff] [blame] | 38 | private final String TAG = "ILearning_StochasticLinearRanker"; |
| 39 | private StochasticLinearRankerWithPrior mLearningSlRanker = null; |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame] | 40 | private ModelChangeCallback modelChangeCallback = null; |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 41 | |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame] | 42 | public Learning_StochasticLinearRanker(){ |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 43 | } |
| 44 | |
saberian | b019e89 | 2012-04-19 11:33:44 -0700 | [diff] [blame] | 45 | public void ResetRanker(){ |
| 46 | if (mLearningSlRanker == null) |
| 47 | mLearningSlRanker = new StochasticLinearRankerWithPrior(); |
| 48 | mLearningSlRanker.resetRanker(); |
| 49 | } |
| 50 | |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 51 | public boolean UpdateClassifier(List<StringFloat> sample_1, List<StringFloat> sample_2){ |
| 52 | ArrayList<StringFloat> temp_1 = (ArrayList<StringFloat>)sample_1; |
| 53 | String[] keys_1 = new String[temp_1.size()]; |
| 54 | float[] values_1 = new float[temp_1.size()]; |
| 55 | for (int i = 0; i < temp_1.size(); i++){ |
| 56 | keys_1[i] = temp_1.get(i).key; |
| 57 | values_1[i] = temp_1.get(i).value; |
| 58 | } |
| 59 | ArrayList<StringFloat> temp_2 = (ArrayList<StringFloat>)sample_2; |
| 60 | String[] keys_2 = new String[temp_2.size()]; |
| 61 | float[] values_2 = new float[temp_2.size()]; |
| 62 | for (int i = 0; i < temp_2.size(); i++){ |
| 63 | keys_2[i] = temp_2.get(i).key; |
| 64 | values_2[i] = temp_2.get(i).value; |
| 65 | } |
saberian | b019e89 | 2012-04-19 11:33:44 -0700 | [diff] [blame] | 66 | if (mLearningSlRanker == null) |
| 67 | mLearningSlRanker = new StochasticLinearRankerWithPrior(); |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 68 | boolean res = mLearningSlRanker.updateClassifier(keys_1,values_1,keys_2,values_2); |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame] | 69 | if (res && modelChangeCallback != null) { |
| 70 | modelChangeCallback.modelChanged(this); |
| 71 | } |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 72 | return res; |
| 73 | } |
| 74 | |
| 75 | public float ScoreSample(List<StringFloat> sample) { |
| 76 | ArrayList<StringFloat> temp = (ArrayList<StringFloat>)sample; |
| 77 | String[] keys = new String[temp.size()]; |
| 78 | float[] values = new float[temp.size()]; |
| 79 | for (int i = 0; i < temp.size(); i++){ |
| 80 | keys[i] = temp.get(i).key; |
| 81 | values[i] = temp.get(i).value; |
| 82 | } |
saberian | b019e89 | 2012-04-19 11:33:44 -0700 | [diff] [blame] | 83 | if (mLearningSlRanker == null) |
| 84 | mLearningSlRanker = new StochasticLinearRankerWithPrior(); |
| 85 | return mLearningSlRanker.scoreSample(keys,values); |
| 86 | } |
| 87 | |
| 88 | public boolean SetModelPriorWeight(List<StringFloat> sample) { |
| 89 | ArrayList<StringFloat> temp = (ArrayList<StringFloat>)sample; |
| 90 | HashMap<String, Float> weights = new HashMap<String, Float>(); |
| 91 | for (int i = 0; i < temp.size(); i++) |
| 92 | weights.put(temp.get(i).key, temp.get(i).value); |
| 93 | if (mLearningSlRanker == null) |
| 94 | mLearningSlRanker = new StochasticLinearRankerWithPrior(); |
| 95 | return mLearningSlRanker.setModelPriorWeights(weights); |
| 96 | } |
| 97 | |
| 98 | public boolean SetModelParameter(String key, String value) { |
| 99 | if (mLearningSlRanker == null) |
| 100 | mLearningSlRanker = new StochasticLinearRankerWithPrior(); |
| 101 | return mLearningSlRanker.setModelParameter(key,value); |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 102 | } |
| 103 | |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame] | 104 | // Beginning of the IBordeauxLearner Interface implementation |
| 105 | public byte [] getModel() { |
saberian | b019e89 | 2012-04-19 11:33:44 -0700 | [diff] [blame] | 106 | if (mLearningSlRanker == null) |
| 107 | mLearningSlRanker = new StochasticLinearRankerWithPrior(); |
| 108 | StochasticLinearRankerWithPrior.Model model = mLearningSlRanker.getModel(); |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame] | 109 | try { |
| 110 | ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); |
| 111 | ObjectOutputStream objStream = new ObjectOutputStream(byteStream); |
| 112 | objStream.writeObject(model); |
| 113 | //return byteStream.toByteArray(); |
| 114 | byte[] bytes = byteStream.toByteArray(); |
Mohammad Saberian | cb4a196 | 2012-06-05 10:40:41 -0700 | [diff] [blame^] | 115 | Log.i(TAG, "getModel: " + bytes); |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame] | 116 | return bytes; |
| 117 | } catch (IOException e) { |
| 118 | throw new RuntimeException("Can't get model"); |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 119 | } |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 120 | } |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame] | 121 | |
| 122 | public boolean setModel(final byte [] modelData) { |
| 123 | try { |
| 124 | ByteArrayInputStream input = new ByteArrayInputStream(modelData); |
| 125 | ObjectInputStream objStream = new ObjectInputStream(input); |
saberian | b019e89 | 2012-04-19 11:33:44 -0700 | [diff] [blame] | 126 | StochasticLinearRankerWithPrior.Model model = |
| 127 | (StochasticLinearRankerWithPrior.Model) objStream.readObject(); |
| 128 | if (mLearningSlRanker == null) |
| 129 | mLearningSlRanker = new StochasticLinearRankerWithPrior(); |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame] | 130 | boolean res = mLearningSlRanker.loadModel(model); |
Mohammad Saberian | cb4a196 | 2012-06-05 10:40:41 -0700 | [diff] [blame^] | 131 | Log.i(TAG, "LoadModel: " + modelData); |
Wei Hua | 1dd8ef5 | 2012-03-30 15:15:12 -0700 | [diff] [blame] | 132 | return res; |
| 133 | } catch (IOException e) { |
| 134 | throw new RuntimeException("Can't load model"); |
| 135 | } catch (ClassNotFoundException e) { |
| 136 | throw new RuntimeException("Learning class not found"); |
| 137 | } |
| 138 | } |
| 139 | |
| 140 | public IBinder getBinder() { |
| 141 | return this; |
| 142 | } |
| 143 | |
| 144 | public void setModelChangeCallback(ModelChangeCallback callback) { |
| 145 | modelChangeCallback = callback; |
| 146 | } |
| 147 | // End of IBordeauxLearner Interface implemenation |
Wei Hua | 6b4eebc | 2012-03-09 10:24:16 -0800 | [diff] [blame] | 148 | } |