blob: 8d18e199865f7e6afa610a16e379bb1c491d6bed [file] [log] [blame]
Tony Mak608b1ae2019-08-13 20:02:00 +01001/*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.android.textclassifier;
18
19import android.app.RemoteAction;
20import android.content.Intent;
21import android.icu.util.ULocale;
22import android.os.Bundle;
23import android.view.textclassifier.TextClassification;
24import android.view.textclassifier.TextClassifier;
25import android.view.textclassifier.TextLinks;
26
27import androidx.annotation.Nullable;
28import androidx.annotation.VisibleForTesting;
29
30import com.google.android.textclassifier.AnnotatorModel;
31
32import java.util.ArrayList;
33import java.util.List;
34
35/** Utility class for inserting and retrieving data in TextClassifier request/response extras. */
36// TODO: Make this a TestApi for CTS testing.
37public final class ExtrasUtils {
38
39 // Keys for response objects.
40 private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data";
41 private static final String ENTITIES_EXTRAS = "entities-extras";
42 private static final String ACTION_INTENT = "action-intent";
43 private static final String ACTIONS_INTENTS = "actions-intents";
44 private static final String FOREIGN_LANGUAGE = "foreign-language";
45 private static final String ENTITY_TYPE = "entity-type";
46 private static final String SCORE = "score";
47 private static final String MODEL_VERSION = "model-version";
48 private static final String MODEL_NAME = "model-name";
49 private static final String TEXT_LANGUAGES = "text-languages";
50 private static final String ENTITIES = "entities";
51
52 // Keys for request objects.
53 private static final String IS_SERIALIZED_ENTITY_DATA_ENABLED =
54 "is-serialized-entity-data-enabled";
55
56 private ExtrasUtils() {}
57
58 /** Bundles and returns foreign language detection information for TextClassifier responses. */
59 static Bundle createForeignLanguageExtra(String language, float score, int modelVersion) {
60 final Bundle bundle = new Bundle();
61 bundle.putString(ENTITY_TYPE, language);
62 bundle.putFloat(SCORE, score);
63 bundle.putInt(MODEL_VERSION, modelVersion);
64 bundle.putString(MODEL_NAME, "langId_v" + modelVersion);
65 return bundle;
66 }
67
68 /**
69 * Stores {@code extra} as foreign language information in TextClassifier response object's
70 * extras {@code container}.
71 *
72 * @see #getForeignLanguageExtra(TextClassification)
73 */
74 static void putForeignLanguageExtra(Bundle container, Bundle extra) {
75 container.putParcelable(FOREIGN_LANGUAGE, extra);
76 }
77
78 /**
79 * Returns foreign language detection information contained in the TextClassification object.
80 * responses.
81 *
82 * @see #putForeignLanguageExtra(Bundle, Bundle)
83 */
84 @Nullable
85 @VisibleForTesting
86 public static Bundle getForeignLanguageExtra(@Nullable TextClassification classification) {
87 if (classification == null) {
88 return null;
89 }
90 return classification.getExtras().getBundle(FOREIGN_LANGUAGE);
91 }
92
93 /** @see #getTopLanguage(Intent) */
94 @VisibleForTesting
95 static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) {
96 final int maxSize = Math.min(3, languageScores.getEntities().size());
97 final String[] languages =
98 languageScores.getEntities().subList(0, maxSize).toArray(new String[0]);
99 final float[] scores = new float[languages.length];
100 for (int i = 0; i < languages.length; i++) {
101 scores[i] = languageScores.getConfidenceScore(languages[i]);
102 }
103 container.putStringArray(ENTITY_TYPE, languages);
104 container.putFloatArray(SCORE, scores);
105 }
106
107 /** @see #putTopLanguageScores(Bundle, EntityConfidence) */
108 @Nullable
109 @VisibleForTesting
110 public static ULocale getTopLanguage(@Nullable Intent intent) {
111 if (intent == null) {
112 return null;
113 }
114 final Bundle tcBundle = intent.getBundleExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER);
115 if (tcBundle == null) {
116 return null;
117 }
118 final Bundle textLanguagesExtra = tcBundle.getBundle(TEXT_LANGUAGES);
119 if (textLanguagesExtra == null) {
120 return null;
121 }
122 final String[] languages = textLanguagesExtra.getStringArray(ENTITY_TYPE);
123 final float[] scores = textLanguagesExtra.getFloatArray(SCORE);
124 if (languages == null
125 || scores == null
126 || languages.length == 0
127 || languages.length != scores.length) {
128 return null;
129 }
130 int highestScoringIndex = 0;
131 for (int i = 1; i < languages.length; i++) {
132 if (scores[highestScoringIndex] < scores[i]) {
133 highestScoringIndex = i;
134 }
135 }
136 return ULocale.forLanguageTag(languages[highestScoringIndex]);
137 }
138
139 public static void putTextLanguagesExtra(Bundle container, Bundle extra) {
140 container.putBundle(TEXT_LANGUAGES, extra);
141 }
142
143 /**
144 * Stores {@code actionIntents} information in TextClassifier response object's extras {@code
145 * container}.
146 */
147 static void putActionsIntents(Bundle container, ArrayList<Intent> actionsIntents) {
148 container.putParcelableArrayList(ACTIONS_INTENTS, actionsIntents);
149 }
150
151 /**
152 * Stores {@code actionIntents} information in TextClassifier response object's extras {@code
153 * container}.
154 */
155 public static void putActionIntent(Bundle container, @Nullable Intent actionIntent) {
156 container.putParcelable(ACTION_INTENT, actionIntent);
157 }
158
159 /** Returns {@code actionIntent} information contained in a TextClassifier response object. */
160 @Nullable
161 public static Intent getActionIntent(Bundle container) {
162 return container.getParcelable(ACTION_INTENT);
163 }
164
165 /**
166 * Stores serialized entity data information in TextClassifier response object's extras {@code
167 * container}.
168 */
169 public static void putSerializedEntityData(
170 Bundle container, @Nullable byte[] serializedEntityData) {
171 container.putByteArray(SERIALIZED_ENTITIES_DATA, serializedEntityData);
172 }
173
174 /** Returns serialized entity data information contained in a TextClassifier response object. */
175 @Nullable
176 public static byte[] getSerializedEntityData(Bundle container) {
177 return container.getByteArray(SERIALIZED_ENTITIES_DATA);
178 }
179
180 /**
181 * Stores {@code entities} information in TextClassifier response object's extras {@code
182 * container}.
183 *
184 * @see {@link #getCopyText(Bundle)}
185 */
186 public static void putEntitiesExtras(Bundle container, @Nullable Bundle entitiesExtras) {
187 container.putParcelable(ENTITIES_EXTRAS, entitiesExtras);
188 }
189
190 /**
191 * Returns {@code entities} information contained in a TextClassifier response object.
192 *
193 * @see {@link #putEntitiesExtras(Bundle, Bundle)}
194 */
195 @Nullable
196 public static String getCopyText(Bundle container) {
197 Bundle entitiesExtras = container.getParcelable(ENTITIES_EXTRAS);
198 if (entitiesExtras == null) {
199 return null;
200 }
201 return entitiesExtras.getString("text");
202 }
203
204 /** Returns {@code actionIntents} information contained in the TextClassification object. */
205 @Nullable
206 public static ArrayList<Intent> getActionsIntents(@Nullable TextClassification classification) {
207 if (classification == null) {
208 return null;
209 }
210 return classification.getExtras().getParcelableArrayList(ACTIONS_INTENTS);
211 }
212
213 /**
214 * Returns the first action found in the {@code classification} object with an intent action
215 * string, {@code intentAction}.
216 */
217 @Nullable
218 @VisibleForTesting
219 public static RemoteAction findAction(
220 @Nullable TextClassification classification, @Nullable String intentAction) {
221 if (classification == null || intentAction == null) {
222 return null;
223 }
224 final ArrayList<Intent> actionIntents = getActionsIntents(classification);
225 if (actionIntents != null) {
226 final int size = actionIntents.size();
227 for (int i = 0; i < size; i++) {
228 final Intent intent = actionIntents.get(i);
229 if (intent != null && intentAction.equals(intent.getAction())) {
230 return classification.getActions().get(i);
231 }
232 }
233 }
234 return null;
235 }
236
237 /** Returns the first "translate" action found in the {@code classification} object. */
238 @Nullable
239 @VisibleForTesting
240 public static RemoteAction findTranslateAction(@Nullable TextClassification classification) {
241 return findAction(classification, Intent.ACTION_TRANSLATE);
242 }
243
244 /** Returns the entity type contained in the {@code extra}. */
245 @Nullable
246 @VisibleForTesting
247 public static String getEntityType(@Nullable Bundle extra) {
248 if (extra == null) {
249 return null;
250 }
251 return extra.getString(ENTITY_TYPE);
252 }
253
254 /** Returns the score contained in the {@code extra}. */
255 @VisibleForTesting
256 public static float getScore(Bundle extra) {
257 final int defaultValue = -1;
258 if (extra == null) {
259 return defaultValue;
260 }
261 return extra.getFloat(SCORE, defaultValue);
262 }
263
264 /** Returns the model name contained in the {@code extra}. */
265 @Nullable
266 public static String getModelName(@Nullable Bundle extra) {
267 if (extra == null) {
268 return null;
269 }
270 return extra.getString(MODEL_NAME);
271 }
272
273 /**
274 * Stores the entities from {@link AnnotatorModel.ClassificationResult} in {@code container}.
275 */
276 public static void putEntities(
277 Bundle container, @Nullable AnnotatorModel.ClassificationResult[] classifications) {
278 if (classifications == null || classifications.length == 0) {
279 return;
280 }
281 ArrayList<Bundle> entitiesBundle = new ArrayList<>();
282 for (AnnotatorModel.ClassificationResult classification : classifications) {
283 if (classification == null) {
284 continue;
285 }
286 Bundle entityBundle = new Bundle();
287 entityBundle.putString(ENTITY_TYPE, classification.getCollection());
288 entityBundle.putByteArray(
289 SERIALIZED_ENTITIES_DATA, classification.getSerializedEntityData());
290 entitiesBundle.add(entityBundle);
291 }
292 if (!entitiesBundle.isEmpty()) {
293 container.putParcelableArrayList(ENTITIES, entitiesBundle);
294 }
295 }
296
297 /** Returns a list of entities contained in the {@code extra}. */
298 @Nullable
299 @VisibleForTesting
300 public static List<Bundle> getEntities(Bundle container) {
301 return container.getParcelableArrayList(ENTITIES);
302 }
303
304 /** Whether the annotator should populate serialized entity data into the result object. */
305 public static boolean isSerializedEntityDataEnabled(TextLinks.Request request) {
306 return request.getExtras().getBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED);
307 }
308
309 /**
310 * To indicate whether the annotator should populate serialized entity data in the result
311 * object.
312 */
313 @VisibleForTesting
314 public static void putIsSerializedEntityDataEnabled(Bundle bundle, boolean isEnabled) {
315 bundle.putBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED, isEnabled);
316 }
317}