blob: fd645813aec8965326c1c0cf8f051928a59824bd [file] [log] [blame]
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17package com.android.textclassifier;
18
19import android.app.RemoteAction;
20import android.content.Intent;
Tony Mak608b1ae2019-08-13 20:02:00 +010021import android.os.Bundle;
22import android.view.textclassifier.TextClassification;
23import android.view.textclassifier.TextClassifier;
24import android.view.textclassifier.TextLinks;
Tony Mak21460022020-03-12 18:29:35 +000025import androidx.core.util.Pair;
Tony Mak608b1ae2019-08-13 20:02:00 +010026import com.google.android.textclassifier.AnnotatorModel;
Tony Mak8cd7ba62019-10-15 15:29:22 +010027import com.google.common.annotations.VisibleForTesting;
Tony Mak608b1ae2019-08-13 20:02:00 +010028import java.util.ArrayList;
29import java.util.List;
Tony Mak8cd7ba62019-10-15 15:29:22 +010030import javax.annotation.Nullable;
Tony Mak608b1ae2019-08-13 20:02:00 +010031
32/** Utility class for inserting and retrieving data in TextClassifier request/response extras. */
33// TODO: Make this a TestApi for CTS testing.
34public final class ExtrasUtils {
35
Tony Mak8cd7ba62019-10-15 15:29:22 +010036 // Keys for response objects.
37 private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data";
38 private static final String ENTITIES_EXTRAS = "entities-extras";
39 private static final String ACTION_INTENT = "action-intent";
40 private static final String ACTIONS_INTENTS = "actions-intents";
41 private static final String FOREIGN_LANGUAGE = "foreign-language";
42 private static final String ENTITY_TYPE = "entity-type";
43 private static final String SCORE = "score";
44 private static final String MODEL_VERSION = "model-version";
45 private static final String MODEL_NAME = "model-name";
46 private static final String TEXT_LANGUAGES = "text-languages";
47 private static final String ENTITIES = "entities";
Tony Mak608b1ae2019-08-13 20:02:00 +010048
Tony Mak8cd7ba62019-10-15 15:29:22 +010049 // Keys for request objects.
50 private static final String IS_SERIALIZED_ENTITY_DATA_ENABLED =
51 "is-serialized-entity-data-enabled";
Tony Mak608b1ae2019-08-13 20:02:00 +010052
Tony Mak8cd7ba62019-10-15 15:29:22 +010053 private ExtrasUtils() {}
Tony Mak608b1ae2019-08-13 20:02:00 +010054
Tony Mak8cd7ba62019-10-15 15:29:22 +010055 /** Bundles and returns foreign language detection information for TextClassifier responses. */
56 static Bundle createForeignLanguageExtra(String language, float score, int modelVersion) {
57 final Bundle bundle = new Bundle();
58 bundle.putString(ENTITY_TYPE, language);
59 bundle.putFloat(SCORE, score);
60 bundle.putInt(MODEL_VERSION, modelVersion);
61 bundle.putString(MODEL_NAME, "langId_v" + modelVersion);
62 return bundle;
63 }
64
65 /**
66 * Stores {@code extra} as foreign language information in TextClassifier response object's extras
67 * {@code container}.
68 *
69 * @see #getForeignLanguageExtra(TextClassification)
70 */
71 static void putForeignLanguageExtra(Bundle container, Bundle extra) {
72 container.putParcelable(FOREIGN_LANGUAGE, extra);
73 }
74
75 /**
76 * Returns foreign language detection information contained in the TextClassification object.
77 * responses.
78 *
79 * @see #putForeignLanguageExtra(Bundle, Bundle)
80 */
81 @Nullable
82 @VisibleForTesting
83 public static Bundle getForeignLanguageExtra(@Nullable TextClassification classification) {
84 if (classification == null) {
85 return null;
Tony Mak608b1ae2019-08-13 20:02:00 +010086 }
Tony Mak8cd7ba62019-10-15 15:29:22 +010087 return classification.getExtras().getBundle(FOREIGN_LANGUAGE);
88 }
Tony Mak608b1ae2019-08-13 20:02:00 +010089
Tony Mak8cd7ba62019-10-15 15:29:22 +010090 /** @see #getTopLanguage(Intent) */
91 static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) {
92 final int maxSize = Math.min(3, languageScores.getEntities().size());
93 final String[] languages =
94 languageScores.getEntities().subList(0, maxSize).toArray(new String[0]);
95 final float[] scores = new float[languages.length];
96 for (int i = 0; i < languages.length; i++) {
97 scores[i] = languageScores.getConfidenceScore(languages[i]);
Tony Mak608b1ae2019-08-13 20:02:00 +010098 }
Tony Mak8cd7ba62019-10-15 15:29:22 +010099 container.putStringArray(ENTITY_TYPE, languages);
100 container.putFloatArray(SCORE, scores);
101 }
Tony Mak608b1ae2019-08-13 20:02:00 +0100102
Tony Mak8cd7ba62019-10-15 15:29:22 +0100103 /** See {@link #putTopLanguageScores(Bundle, EntityConfidence)}. */
104 @Nullable
Tony Mak21460022020-03-12 18:29:35 +0000105 static Pair<String, Float> getTopLanguage(@Nullable Intent intent) {
Tony Mak8cd7ba62019-10-15 15:29:22 +0100106 if (intent == null) {
107 return null;
108 }
109 final Bundle tcBundle = intent.getBundleExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER);
110 if (tcBundle == null) {
111 return null;
112 }
113 final Bundle textLanguagesExtra = tcBundle.getBundle(TEXT_LANGUAGES);
114 if (textLanguagesExtra == null) {
115 return null;
116 }
117 final String[] languages = textLanguagesExtra.getStringArray(ENTITY_TYPE);
118 final float[] scores = textLanguagesExtra.getFloatArray(SCORE);
119 if (languages == null
120 || scores == null
121 || languages.length == 0
122 || languages.length != scores.length) {
123 return null;
124 }
125 int highestScoringIndex = 0;
126 for (int i = 1; i < languages.length; i++) {
127 if (scores[highestScoringIndex] < scores[i]) {
128 highestScoringIndex = i;
129 }
130 }
Tony Mak21460022020-03-12 18:29:35 +0000131 return Pair.create(languages[highestScoringIndex], scores[highestScoringIndex]);
Tony Mak8cd7ba62019-10-15 15:29:22 +0100132 }
133
134 public static void putTextLanguagesExtra(Bundle container, Bundle extra) {
135 container.putBundle(TEXT_LANGUAGES, extra);
136 }
137
138 /**
139 * Stores {@code actionsIntents} information in TextClassifier response object's extras {@code
140 * container}.
141 */
142 static void putActionsIntents(Bundle container, ArrayList<Intent> actionsIntents) {
143 container.putParcelableArrayList(ACTIONS_INTENTS, actionsIntents);
144 }
145
146 /**
147 * Stores {@code actionIntent} information in TextClassifier response object's extras {@code
148 * container}.
149 */
150 public static void putActionIntent(Bundle container, @Nullable Intent actionIntent) {
151 container.putParcelable(ACTION_INTENT, actionIntent);
152 }
153
154 /** Returns {@code actionIntent} information contained in a TextClassifier response object. */
155 @Nullable
156 public static Intent getActionIntent(Bundle container) {
157 return container.getParcelable(ACTION_INTENT);
158 }
159
160 /**
161 * Stores serialized entity data information in TextClassifier response object's extras {@code
162 * container}.
163 */
164 public static void putSerializedEntityData(
165 Bundle container, @Nullable byte[] serializedEntityData) {
166 container.putByteArray(SERIALIZED_ENTITIES_DATA, serializedEntityData);
167 }
168
169 /** Returns serialized entity data information contained in a TextClassifier response object. */
170 @Nullable
171 public static byte[] getSerializedEntityData(Bundle container) {
172 return container.getByteArray(SERIALIZED_ENTITIES_DATA);
173 }
174
175 /**
176 * Stores {@code entities} information in TextClassifier response object's extras {@code
177 * container}.
178 *
179 * @see {@link #getCopyText(Bundle)}
180 */
181 public static void putEntitiesExtras(Bundle container, @Nullable Bundle entitiesExtras) {
182 container.putParcelable(ENTITIES_EXTRAS, entitiesExtras);
183 }
184
185 /**
186 * Returns {@code entities} information contained in a TextClassifier response object.
187 *
188 * @see {@link #putEntitiesExtras(Bundle, Bundle)}
189 */
190 @Nullable
191 public static String getCopyText(Bundle container) {
192 Bundle entitiesExtras = container.getParcelable(ENTITIES_EXTRAS);
193 if (entitiesExtras == null) {
194 return null;
195 }
196 return entitiesExtras.getString("text");
197 }
198
199 /** Returns {@code actionIntents} information contained in the TextClassification object. */
200 @Nullable
201 public static ArrayList<Intent> getActionsIntents(@Nullable TextClassification classification) {
202 if (classification == null) {
203 return null;
204 }
205 return classification.getExtras().getParcelableArrayList(ACTIONS_INTENTS);
206 }
207
208 /**
209 * Returns the first action found in the {@code classification} object with an intent action
210 * string, {@code intentAction}.
211 */
212 @Nullable
213 @VisibleForTesting
214 public static RemoteAction findAction(
215 @Nullable TextClassification classification, @Nullable String intentAction) {
216 if (classification == null || intentAction == null) {
217 return null;
218 }
219 final ArrayList<Intent> actionIntents = getActionsIntents(classification);
220 if (actionIntents != null) {
221 final int size = actionIntents.size();
222 for (int i = 0; i < size; i++) {
223 final Intent intent = actionIntents.get(i);
224 if (intent != null && intentAction.equals(intent.getAction())) {
225 return classification.getActions().get(i);
Tony Mak608b1ae2019-08-13 20:02:00 +0100226 }
Tony Mak8cd7ba62019-10-15 15:29:22 +0100227 }
Tony Mak608b1ae2019-08-13 20:02:00 +0100228 }
Tony Mak8cd7ba62019-10-15 15:29:22 +0100229 return null;
230 }
Tony Mak608b1ae2019-08-13 20:02:00 +0100231
Tony Mak8cd7ba62019-10-15 15:29:22 +0100232 /** Returns the first "translate" action found in the {@code classification} object. */
233 @Nullable
234 @VisibleForTesting
235 public static RemoteAction findTranslateAction(@Nullable TextClassification classification) {
236 return findAction(classification, Intent.ACTION_TRANSLATE);
237 }
Tony Mak608b1ae2019-08-13 20:02:00 +0100238
Tony Mak8cd7ba62019-10-15 15:29:22 +0100239 /** Returns the entity type contained in the {@code extra}. */
240 @Nullable
241 @VisibleForTesting
242 public static String getEntityType(@Nullable Bundle extra) {
243 if (extra == null) {
244 return null;
Tony Mak608b1ae2019-08-13 20:02:00 +0100245 }
Tony Mak8cd7ba62019-10-15 15:29:22 +0100246 return extra.getString(ENTITY_TYPE);
247 }
Tony Mak608b1ae2019-08-13 20:02:00 +0100248
Tony Mak8cd7ba62019-10-15 15:29:22 +0100249 /** Returns the score contained in the {@code extra}. */
250 @VisibleForTesting
251 public static float getScore(Bundle extra) {
252 final int defaultValue = -1;
253 if (extra == null) {
254 return defaultValue;
Tony Mak608b1ae2019-08-13 20:02:00 +0100255 }
Tony Mak8cd7ba62019-10-15 15:29:22 +0100256 return extra.getFloat(SCORE, defaultValue);
257 }
Tony Mak608b1ae2019-08-13 20:02:00 +0100258
Tony Mak8cd7ba62019-10-15 15:29:22 +0100259 /** Returns the model name contained in the {@code extra}. */
260 @Nullable
261 public static String getModelName(@Nullable Bundle extra) {
262 if (extra == null) {
263 return null;
Tony Mak608b1ae2019-08-13 20:02:00 +0100264 }
Tony Mak8cd7ba62019-10-15 15:29:22 +0100265 return extra.getString(MODEL_NAME);
266 }
Tony Mak608b1ae2019-08-13 20:02:00 +0100267
Tony Mak8cd7ba62019-10-15 15:29:22 +0100268 /** Stores the entities from {@link AnnotatorModel.ClassificationResult} in {@code container}. */
269 public static void putEntities(
270 Bundle container, @Nullable AnnotatorModel.ClassificationResult[] classifications) {
271 if (classifications == null || classifications.length == 0) {
272 return;
Tony Mak608b1ae2019-08-13 20:02:00 +0100273 }
Tony Mak8cd7ba62019-10-15 15:29:22 +0100274 ArrayList<Bundle> entitiesBundle = new ArrayList<>();
275 for (AnnotatorModel.ClassificationResult classification : classifications) {
276 if (classification == null) {
277 continue;
278 }
279 Bundle entityBundle = new Bundle();
280 entityBundle.putString(ENTITY_TYPE, classification.getCollection());
281 entityBundle.putByteArray(SERIALIZED_ENTITIES_DATA, classification.getSerializedEntityData());
282 entitiesBundle.add(entityBundle);
283 }
284 if (!entitiesBundle.isEmpty()) {
285 container.putParcelableArrayList(ENTITIES, entitiesBundle);
286 }
287 }
Tony Mak608b1ae2019-08-13 20:02:00 +0100288
Tony Mak8cd7ba62019-10-15 15:29:22 +0100289 /** Returns a list of entities contained in the {@code extra}. */
290 @Nullable
291 @VisibleForTesting
292 public static List<Bundle> getEntities(Bundle container) {
293 return container.getParcelableArrayList(ENTITIES);
294 }
Tony Mak608b1ae2019-08-13 20:02:00 +0100295
Tony Mak8cd7ba62019-10-15 15:29:22 +0100296 /** Whether the annotator should populate serialized entity data into the result object. */
297 public static boolean isSerializedEntityDataEnabled(TextLinks.Request request) {
298 return request.getExtras().getBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED);
299 }
Tony Mak608b1ae2019-08-13 20:02:00 +0100300
Tony Mak8cd7ba62019-10-15 15:29:22 +0100301 /**
302 * To indicate whether the annotator should populate serialized entity data in the result object.
303 */
304 @VisibleForTesting
305 public static void putIsSerializedEntityDataEnabled(Bundle bundle, boolean isEnabled) {
306 bundle.putBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED, isEnabled);
307 }
Tony Mak608b1ae2019-08-13 20:02:00 +0100308}