blob: a9a9550d7909800d564e451ef0be6f37d452efd7 [file] [log] [blame]
Tony Mak0be540b2018-11-09 16:58:35 +00001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.view.textclassifier;
18
19import static org.hamcrest.CoreMatchers.not;
20import static org.junit.Assert.assertEquals;
Tony Mak0be540b2018-11-09 16:58:35 +000021import static org.junit.Assert.assertThat;
22import static org.junit.Assert.assertTrue;
23
Abodunrinwa Toki520b2f82019-01-27 07:48:02 +000024import android.app.RemoteAction;
Tony Mak0be540b2018-11-09 16:58:35 +000025import android.content.Context;
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +000026import android.content.Intent;
Tony Makc12035e2019-02-26 17:45:34 +000027import android.net.Uri;
Abodunrinwa Toki520b2f82019-01-27 07:48:02 +000028import android.os.Bundle;
Tony Mak0be540b2018-11-09 16:58:35 +000029import android.os.LocaleList;
Abodunrinwa Tokiadc19402018-11-22 17:10:25 +000030import android.text.Spannable;
31import android.text.SpannableString;
Tony Mak0be540b2018-11-09 16:58:35 +000032
Tadashi G. Takaokab4470f22019-01-15 18:29:15 +090033import androidx.test.InstrumentationRegistry;
34import androidx.test.filters.SmallTest;
Tony Mak72e17972019-03-16 10:28:42 +000035import androidx.test.runner.AndroidJUnit4;
Tadashi G. Takaokab4470f22019-01-15 18:29:15 +090036
Tony Makc12035e2019-02-26 17:45:34 +000037import com.google.common.truth.Truth;
38
Tony Mak0be540b2018-11-09 16:58:35 +000039import org.hamcrest.BaseMatcher;
40import org.hamcrest.Description;
41import org.hamcrest.Matcher;
42import org.junit.Before;
Abodunrinwa Toki0f1d77e2019-05-02 21:01:19 +010043import org.junit.Ignore;
Tony Mak0be540b2018-11-09 16:58:35 +000044import org.junit.Test;
45import org.junit.runner.RunWith;
Tony Mak0be540b2018-11-09 16:58:35 +000046
47import java.util.Arrays;
48import java.util.Collections;
49import java.util.List;
50
51/**
52 * Testing {@link TextClassifierTest} APIs on local and system textclassifier.
53 * <p>
54 * Tests are skipped if such a textclassifier does not exist.
55 */
56@SmallTest
Tony Mak72e17972019-03-16 10:28:42 +000057@RunWith(AndroidJUnit4.class)
Tony Mak0be540b2018-11-09 16:58:35 +000058public class TextClassifierTest {
Tony Mak0be540b2018-11-09 16:58:35 +000059
Tony Mak72e17972019-03-16 10:28:42 +000060 // TODO: Implement TextClassifierService testing.
Tony Mak0be540b2018-11-09 16:58:35 +000061
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +000062 private static final TextClassificationConstants TC_CONSTANTS =
Abodunrinwa Toki0634af32019-04-04 13:10:59 +010063 new TextClassificationConstants(() -> "");
Tony Mak0be540b2018-11-09 16:58:35 +000064 private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US");
65 private static final String NO_TYPE = null;
66
67 private Context mContext;
68 private TextClassificationManager mTcm;
69 private TextClassifier mClassifier;
70
71 @Before
72 public void setup() {
73 mContext = InstrumentationRegistry.getTargetContext();
74 mTcm = mContext.getSystemService(TextClassificationManager.class);
Tony Mak72e17972019-03-16 10:28:42 +000075 mClassifier = mTcm.getTextClassifier(TextClassifier.LOCAL);
Tony Mak0be540b2018-11-09 16:58:35 +000076 }
77
78 @Test
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +000079 public void testSuggestSelection() {
Tony Mak0be540b2018-11-09 16:58:35 +000080 if (isTextClassifierDisabled()) return;
81
82 String text = "Contact me at droid@android.com";
83 String selected = "droid";
84 String suggested = "droid@android.com";
85 int startIndex = text.indexOf(selected);
86 int endIndex = startIndex + selected.length();
87 int smartStartIndex = text.indexOf(suggested);
88 int smartEndIndex = smartStartIndex + suggested.length();
89 TextSelection.Request request = new TextSelection.Request.Builder(
90 text, startIndex, endIndex)
91 .setDefaultLocales(LOCALES)
92 .build();
93
94 TextSelection selection = mClassifier.suggestSelection(request);
95 assertThat(selection,
96 isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL));
97 }
98
99 @Test
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +0000100 public void testSuggestSelection_url() {
Tony Mak0be540b2018-11-09 16:58:35 +0000101 if (isTextClassifierDisabled()) return;
102
103 String text = "Visit http://www.android.com for more information";
104 String selected = "http";
105 String suggested = "http://www.android.com";
106 int startIndex = text.indexOf(selected);
107 int endIndex = startIndex + selected.length();
108 int smartStartIndex = text.indexOf(suggested);
109 int smartEndIndex = smartStartIndex + suggested.length();
110 TextSelection.Request request = new TextSelection.Request.Builder(
111 text, startIndex, endIndex)
112 .setDefaultLocales(LOCALES)
113 .build();
114
115 TextSelection selection = mClassifier.suggestSelection(request);
116 assertThat(selection,
117 isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
118 }
119
120 @Test
121 public void testSmartSelection_withEmoji() {
122 if (isTextClassifierDisabled()) return;
123
124 String text = "\uD83D\uDE02 Hello.";
125 String selected = "Hello";
126 int startIndex = text.indexOf(selected);
127 int endIndex = startIndex + selected.length();
128 TextSelection.Request request = new TextSelection.Request.Builder(
129 text, startIndex, endIndex)
130 .setDefaultLocales(LOCALES)
131 .build();
132
133 TextSelection selection = mClassifier.suggestSelection(request);
134 assertThat(selection,
135 isTextSelection(startIndex, endIndex, NO_TYPE));
136 }
137
138 @Test
139 public void testClassifyText() {
140 if (isTextClassifierDisabled()) return;
141
142 String text = "Contact me at droid@android.com";
143 String classifiedText = "droid@android.com";
144 int startIndex = text.indexOf(classifiedText);
145 int endIndex = startIndex + classifiedText.length();
146 TextClassification.Request request = new TextClassification.Request.Builder(
147 text, startIndex, endIndex)
148 .setDefaultLocales(LOCALES)
149 .build();
150
151 TextClassification classification = mClassifier.classifyText(request);
152 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL));
153 }
154
155 @Test
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +0000156 public void testClassifyText_url() {
Tony Mak0be540b2018-11-09 16:58:35 +0000157 if (isTextClassifierDisabled()) return;
158
159 String text = "Visit www.android.com for more information";
160 String classifiedText = "www.android.com";
161 int startIndex = text.indexOf(classifiedText);
162 int endIndex = startIndex + classifiedText.length();
163 TextClassification.Request request = new TextClassification.Request.Builder(
164 text, startIndex, endIndex)
165 .setDefaultLocales(LOCALES)
166 .build();
167
168 TextClassification classification = mClassifier.classifyText(request);
169 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
Abodunrinwa Tokic33fc772019-02-06 01:17:10 +0000170 assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
Tony Mak0be540b2018-11-09 16:58:35 +0000171 }
172
173 @Test
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +0000174 public void testClassifyText_address() {
Tony Mak0be540b2018-11-09 16:58:35 +0000175 if (isTextClassifierDisabled()) return;
176
177 String text = "Brandschenkestrasse 110, Zürich, Switzerland";
178 TextClassification.Request request = new TextClassification.Request.Builder(
179 text, 0, text.length())
180 .setDefaultLocales(LOCALES)
181 .build();
182
183 TextClassification classification = mClassifier.classifyText(request);
184 assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
185 }
186
187 @Test
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +0000188 public void testClassifyText_url_inCaps() {
Tony Mak0be540b2018-11-09 16:58:35 +0000189 if (isTextClassifierDisabled()) return;
190
191 String text = "Visit HTTP://ANDROID.COM for more information";
192 String classifiedText = "HTTP://ANDROID.COM";
193 int startIndex = text.indexOf(classifiedText);
194 int endIndex = startIndex + classifiedText.length();
195 TextClassification.Request request = new TextClassification.Request.Builder(
196 text, startIndex, endIndex)
197 .setDefaultLocales(LOCALES)
198 .build();
199
200 TextClassification classification = mClassifier.classifyText(request);
201 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
Abodunrinwa Tokic33fc772019-02-06 01:17:10 +0000202 assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
Tony Mak0be540b2018-11-09 16:58:35 +0000203 }
204
205 @Test
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +0000206 public void testClassifyText_date() {
Tony Mak0be540b2018-11-09 16:58:35 +0000207 if (isTextClassifierDisabled()) return;
208
209 String text = "Let's meet on January 9, 2018.";
210 String classifiedText = "January 9, 2018";
211 int startIndex = text.indexOf(classifiedText);
212 int endIndex = startIndex + classifiedText.length();
213 TextClassification.Request request = new TextClassification.Request.Builder(
214 text, startIndex, endIndex)
215 .setDefaultLocales(LOCALES)
216 .build();
217
218 TextClassification classification = mClassifier.classifyText(request);
219 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
Tony Makfdb35542019-03-22 12:01:50 +0000220 Bundle extras = classification.getExtras();
221 List<Bundle> entities = ExtrasUtils.getEntities(extras);
222 Truth.assertThat(entities).hasSize(1);
223 Bundle entity = entities.get(0);
224 Truth.assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_DATE);
Tony Mak0be540b2018-11-09 16:58:35 +0000225 }
226
227 @Test
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +0000228 public void testClassifyText_datetime() {
Tony Mak0be540b2018-11-09 16:58:35 +0000229 if (isTextClassifierDisabled()) return;
230
231 String text = "Let's meet 2018/01/01 10:30:20.";
232 String classifiedText = "2018/01/01 10:30:20";
233 int startIndex = text.indexOf(classifiedText);
234 int endIndex = startIndex + classifiedText.length();
235 TextClassification.Request request = new TextClassification.Request.Builder(
236 text, startIndex, endIndex)
237 .setDefaultLocales(LOCALES)
238 .build();
239
240 TextClassification classification = mClassifier.classifyText(request);
241 assertThat(classification,
242 isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
243 }
244
245 @Test
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +0000246 public void testClassifyText_foreignText() {
247 LocaleList originalLocales = LocaleList.getDefault();
248 LocaleList.setDefault(LocaleList.forLanguageTags("en"));
Abodunrinwa Toki520b2f82019-01-27 07:48:02 +0000249 String japaneseText = "これは日本語のテキストです";
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +0000250
251 Context context = new FakeContextBuilder()
252 .setIntentComponent(Intent.ACTION_TRANSLATE, FakeContextBuilder.DEFAULT_COMPONENT)
253 .build();
254 TextClassifier classifier = new TextClassifierImpl(context, TC_CONSTANTS);
255 TextClassification.Request request = new TextClassification.Request.Builder(
Abodunrinwa Toki520b2f82019-01-27 07:48:02 +0000256 japaneseText, 0, japaneseText.length())
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +0000257 .setDefaultLocales(LOCALES)
258 .build();
259
260 TextClassification classification = classifier.classifyText(request);
Abodunrinwa Toki520b2f82019-01-27 07:48:02 +0000261 RemoteAction translateAction = classification.getActions().get(0);
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +0000262 assertEquals(1, classification.getActions().size());
263 assertEquals(
264 context.getString(com.android.internal.R.string.translate),
Abodunrinwa Toki520b2f82019-01-27 07:48:02 +0000265 translateAction.getTitle());
266
267 assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification));
268 Intent intent = ExtrasUtils.getActionsIntents(classification).get(0);
Abodunrinwa Toki385b10c2019-01-23 18:24:08 +0000269 assertEquals(Intent.ACTION_TRANSLATE, intent.getAction());
Abodunrinwa Toki520b2f82019-01-27 07:48:02 +0000270 Bundle foreignLanguageInfo = ExtrasUtils.getForeignLanguageExtra(classification);
271 assertEquals("ja", ExtrasUtils.getEntityType(foreignLanguageInfo));
272 assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) >= 0);
273 assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) <= 1);
Tony Mak72e17972019-03-16 10:28:42 +0000274 assertTrue(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER));
275 assertEquals("ja", ExtrasUtils.getTopLanguage(intent).getLanguage());
Abodunrinwa Toki229f40a2018-11-27 20:08:00 +0000276
277 LocaleList.setDefault(originalLocales);
278 }
279
280 @Test
Tony Mak0be540b2018-11-09 16:58:35 +0000281 public void testGenerateLinks_phone() {
282 if (isTextClassifierDisabled()) return;
283 String text = "The number is +12122537077. See you tonight!";
284 TextLinks.Request request = new TextLinks.Request.Builder(text).build();
285 assertThat(mClassifier.generateLinks(request),
286 isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE));
287 }
288
289 @Test
290 public void testGenerateLinks_exclude() {
291 if (isTextClassifierDisabled()) return;
292 String text = "You want apple@banana.com. See you tonight!";
293 List<String> hints = Collections.EMPTY_LIST;
294 List<String> included = Collections.EMPTY_LIST;
295 List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
296 TextLinks.Request request = new TextLinks.Request.Builder(text)
297 .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
298 .setDefaultLocales(LOCALES)
299 .build();
300 assertThat(mClassifier.generateLinks(request),
301 not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
302 }
303
304 @Test
305 public void testGenerateLinks_explicit_address() {
306 if (isTextClassifierDisabled()) return;
307 String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
308 List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS);
309 TextLinks.Request request = new TextLinks.Request.Builder(text)
310 .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit))
311 .setDefaultLocales(LOCALES)
312 .build();
313 assertThat(mClassifier.generateLinks(request),
314 isTextLinksContaining(text, "1600 Amphitheater Parkway, Mountain View, CA",
315 TextClassifier.TYPE_ADDRESS));
316 }
317
318 @Test
319 public void testGenerateLinks_exclude_override() {
320 if (isTextClassifierDisabled()) return;
321 String text = "You want apple@banana.com. See you tonight!";
322 List<String> hints = Collections.EMPTY_LIST;
323 List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL);
324 List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
325 TextLinks.Request request = new TextLinks.Request.Builder(text)
326 .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
327 .setDefaultLocales(LOCALES)
328 .build();
329 assertThat(mClassifier.generateLinks(request),
330 not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
331 }
332
333 @Test
334 public void testGenerateLinks_maxLength() {
335 if (isTextClassifierDisabled()) return;
336 char[] manySpaces = new char[mClassifier.getMaxGenerateLinksTextLength()];
337 Arrays.fill(manySpaces, ' ');
338 TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
339 TextLinks links = mClassifier.generateLinks(request);
340 assertTrue(links.getLinks().isEmpty());
341 }
342
Abodunrinwa Tokiadc19402018-11-22 17:10:25 +0000343 @Test
344 public void testApplyLinks_unsupportedCharacter() {
345 if (isTextClassifierDisabled()) return;
346 Spannable url = new SpannableString("\u202Emoc.diordna.com");
347 TextLinks.Request request = new TextLinks.Request.Builder(url).build();
348 assertEquals(
349 TextLinks.STATUS_UNSUPPORTED_CHARACTER,
350 mClassifier.generateLinks(request).apply(url, 0, null));
351 }
352
353
Tony Mak0be540b2018-11-09 16:58:35 +0000354 @Test(expected = IllegalArgumentException.class)
355 public void testGenerateLinks_tooLong() {
356 if (isTextClassifierDisabled()) {
357 throw new IllegalArgumentException("pass if disabled");
358 }
359 char[] manySpaces = new char[mClassifier.getMaxGenerateLinksTextLength() + 1];
360 Arrays.fill(manySpaces, ' ');
361 TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
362 mClassifier.generateLinks(request);
363 }
364
365 @Test
Tony Makb6afd3c2019-04-05 15:45:18 +0100366 public void testGenerateLinks_entityData() {
367 if (isTextClassifierDisabled()) return;
368 String text = "The number is +12122537077.";
369 Bundle extras = new Bundle();
370 ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
371 TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build();
372
373 TextLinks textLinks = mClassifier.generateLinks(request);
374
375 Truth.assertThat(textLinks.getLinks()).hasSize(1);
376 TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
377 List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
378 Truth.assertThat(entities).hasSize(1);
379 Bundle entity = entities.get(0);
380 Truth.assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE);
381 }
382
383 @Test
384 public void testGenerateLinks_entityData_disabled() {
385 if (isTextClassifierDisabled()) return;
386 String text = "The number is +12122537077.";
387 TextLinks.Request request = new TextLinks.Request.Builder(text).build();
388
389 TextLinks textLinks = mClassifier.generateLinks(request);
390
391 Truth.assertThat(textLinks.getLinks()).hasSize(1);
392 TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
393 List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
394 Truth.assertThat(entities).isNull();
395 }
396
397 @Test
Tony Mak0be540b2018-11-09 16:58:35 +0000398 public void testDetectLanguage() {
399 if (isTextClassifierDisabled()) return;
400 String text = "This is English text";
401 TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
402 TextLanguage textLanguage = mClassifier.detectLanguage(request);
403 assertThat(textLanguage, isTextLanguage("en"));
404 }
405
406 @Test
407 public void testDetectLanguage_japanese() {
408 if (isTextClassifierDisabled()) return;
409 String text = "これは日本語のテキストです";
410 TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
411 TextLanguage textLanguage = mClassifier.detectLanguage(request);
412 assertThat(textLanguage, isTextLanguage("ja"));
413 }
414
Abodunrinwa Toki0f1d77e2019-05-02 21:01:19 +0100415 @Ignore // Doesn't work without a language-based model.
Tony Mak0be540b2018-11-09 16:58:35 +0000416 @Test
Tony Makaa496d02019-04-11 17:38:47 +0100417 public void testSuggestConversationActions_textReplyOnly_maxOne() {
Tony Mak0be540b2018-11-09 16:58:35 +0000418 if (isTextClassifierDisabled()) return;
419 ConversationActions.Message message =
Tony Mak82fa8d92018-12-07 17:37:43 +0000420 new ConversationActions.Message.Builder(
Tony Mak91daa152019-01-24 16:00:28 +0000421 ConversationActions.Message.PERSON_USER_OTHERS)
Tony Makc4359bf2018-12-11 19:38:53 +0800422 .setText("Where are you?")
Tony Mak82fa8d92018-12-07 17:37:43 +0000423 .build();
Tony Makae85aae2019-01-09 15:59:56 +0000424 TextClassifier.EntityConfig typeConfig =
425 new TextClassifier.EntityConfig.Builder().includeTypesFromTextClassifier(false)
Tony Mak0be540b2018-11-09 16:58:35 +0000426 .setIncludedTypes(
Tony Makae85aae2019-01-09 15:59:56 +0000427 Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
Tony Mak0be540b2018-11-09 16:58:35 +0000428 .build();
429 ConversationActions.Request request =
430 new ConversationActions.Request.Builder(Collections.singletonList(message))
Tony Makc4359bf2018-12-11 19:38:53 +0800431 .setMaxSuggestions(1)
Tony Mak0be540b2018-11-09 16:58:35 +0000432 .setTypeConfig(typeConfig)
433 .build();
434
435 ConversationActions conversationActions = mClassifier.suggestConversationActions(request);
Tony Makaa496d02019-04-11 17:38:47 +0100436 Truth.assertThat(conversationActions.getConversationActions()).hasSize(1);
437 ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
438 Truth.assertThat(conversationAction.getType()).isEqualTo(
439 ConversationAction.TYPE_TEXT_REPLY);
440 Truth.assertThat(conversationAction.getTextReply()).isNotNull();
Tony Mak128a61d2019-01-28 18:57:26 +0000441 }
Tony Makc4359bf2018-12-11 19:38:53 +0800442
Abodunrinwa Toki0f1d77e2019-05-02 21:01:19 +0100443 @Ignore // Doesn't work without a language-based model.
Tony Makc4359bf2018-12-11 19:38:53 +0800444 @Test
445 public void testSuggestConversationActions_textReplyOnly_noMax() {
446 if (isTextClassifierDisabled()) return;
447 ConversationActions.Message message =
448 new ConversationActions.Message.Builder(
Tony Mak91daa152019-01-24 16:00:28 +0000449 ConversationActions.Message.PERSON_USER_OTHERS)
Tony Makc4359bf2018-12-11 19:38:53 +0800450 .setText("Where are you?")
451 .build();
Tony Makae85aae2019-01-09 15:59:56 +0000452 TextClassifier.EntityConfig typeConfig =
453 new TextClassifier.EntityConfig.Builder().includeTypesFromTextClassifier(false)
Tony Makc4359bf2018-12-11 19:38:53 +0800454 .setIncludedTypes(
Tony Makae85aae2019-01-09 15:59:56 +0000455 Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
Tony Makc4359bf2018-12-11 19:38:53 +0800456 .build();
457 ConversationActions.Request request =
458 new ConversationActions.Request.Builder(Collections.singletonList(message))
459 .setTypeConfig(typeConfig)
460 .build();
461
462 ConversationActions conversationActions = mClassifier.suggestConversationActions(request);
463 assertTrue(conversationActions.getConversationActions().size() > 1);
Tony Makae85aae2019-01-09 15:59:56 +0000464 for (ConversationAction conversationAction :
Tony Makc4359bf2018-12-11 19:38:53 +0800465 conversationActions.getConversationActions()) {
466 assertThat(conversationAction,
Tony Makae85aae2019-01-09 15:59:56 +0000467 isConversationAction(ConversationAction.TYPE_TEXT_REPLY));
Tony Mak0be540b2018-11-09 16:58:35 +0000468 }
469 }
470
Tony Makc12035e2019-02-26 17:45:34 +0000471 @Test
472 public void testSuggestConversationActions_openUrl() {
473 if (isTextClassifierDisabled()) return;
474 ConversationActions.Message message =
475 new ConversationActions.Message.Builder(
476 ConversationActions.Message.PERSON_USER_OTHERS)
477 .setText("Check this out: https://www.android.com")
478 .build();
479 TextClassifier.EntityConfig typeConfig =
480 new TextClassifier.EntityConfig.Builder().includeTypesFromTextClassifier(false)
481 .setIncludedTypes(
482 Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
483 .build();
484 ConversationActions.Request request =
485 new ConversationActions.Request.Builder(Collections.singletonList(message))
486 .setMaxSuggestions(1)
487 .setTypeConfig(typeConfig)
488 .build();
489
490 ConversationActions conversationActions = mClassifier.suggestConversationActions(request);
491 Truth.assertThat(conversationActions.getConversationActions()).hasSize(1);
492 ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
493 Truth.assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
494 Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras());
495 Truth.assertThat(actionIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
496 Truth.assertThat(actionIntent.getData()).isEqualTo(Uri.parse("https://www.android.com"));
497 }
498
Abodunrinwa Toki0f1d77e2019-05-02 21:01:19 +0100499 @Ignore // Doesn't work without a language-based model.
Tony Mak09214422019-03-01 18:25:23 +0000500 @Test
501 public void testSuggestConversationActions_copy() {
502 if (isTextClassifierDisabled()) return;
503 ConversationActions.Message message =
504 new ConversationActions.Message.Builder(
505 ConversationActions.Message.PERSON_USER_OTHERS)
506 .setText("Authentication code: 12345")
507 .build();
508 TextClassifier.EntityConfig typeConfig =
509 new TextClassifier.EntityConfig.Builder().includeTypesFromTextClassifier(false)
510 .setIncludedTypes(
511 Collections.singletonList(ConversationAction.TYPE_COPY))
512 .build();
513 ConversationActions.Request request =
514 new ConversationActions.Request.Builder(Collections.singletonList(message))
515 .setMaxSuggestions(1)
516 .setTypeConfig(typeConfig)
517 .build();
518
519 ConversationActions conversationActions = mClassifier.suggestConversationActions(request);
520 Truth.assertThat(conversationActions.getConversationActions()).hasSize(1);
521 ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
522 Truth.assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_COPY);
523 Truth.assertThat(conversationAction.getTextReply()).isAnyOf(null, "");
524 Truth.assertThat(conversationAction.getAction()).isNull();
525 String code = ExtrasUtils.getCopyText(conversationAction.getExtras());
526 Truth.assertThat(code).isEqualTo("12345");
Tony Makfdb35542019-03-22 12:01:50 +0000527 Truth.assertThat(
528 ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty();
Tony Mak09214422019-03-01 18:25:23 +0000529 }
Tony Mak0be540b2018-11-09 16:58:35 +0000530
Tony Makaa496d02019-04-11 17:38:47 +0100531 @Test
Tony Makfbe85242019-11-27 17:52:46 +0000532 public void testSuggestConversationActions_deduplicate() {
533 Context context = new FakeContextBuilder()
534 .setIntentComponent(Intent.ACTION_SENDTO, FakeContextBuilder.DEFAULT_COMPONENT)
535 .build();
Tony Makaa496d02019-04-11 17:38:47 +0100536 ConversationActions.Message message =
537 new ConversationActions.Message.Builder(
538 ConversationActions.Message.PERSON_USER_OTHERS)
539 .setText("a@android.com b@android.com")
540 .build();
541 ConversationActions.Request request =
542 new ConversationActions.Request.Builder(Collections.singletonList(message))
543 .setMaxSuggestions(3)
544 .build();
545
Tony Makfbe85242019-11-27 17:52:46 +0000546 TextClassifier classifier = new TextClassifierImpl(context, TC_CONSTANTS);
547 ConversationActions conversationActions = classifier.suggestConversationActions(request);
Tony Makaa496d02019-04-11 17:38:47 +0100548
549 Truth.assertThat(conversationActions.getConversationActions()).isEmpty();
550 }
551
Tony Mak0be540b2018-11-09 16:58:35 +0000552 private boolean isTextClassifierDisabled() {
553 return mClassifier == null || mClassifier == TextClassifier.NO_OP;
554 }
555
556 private static Matcher<TextSelection> isTextSelection(
557 final int startIndex, final int endIndex, final String type) {
558 return new BaseMatcher<TextSelection>() {
559 @Override
560 public boolean matches(Object o) {
561 if (o instanceof TextSelection) {
562 TextSelection selection = (TextSelection) o;
563 return startIndex == selection.getSelectionStartIndex()
564 && endIndex == selection.getSelectionEndIndex()
565 && typeMatches(selection, type);
566 }
567 return false;
568 }
569
570 private boolean typeMatches(TextSelection selection, String type) {
571 return type == null
572 || (selection.getEntityCount() > 0
573 && type.trim().equalsIgnoreCase(selection.getEntity(0)));
574 }
575
576 @Override
577 public void describeTo(Description description) {
578 description.appendValue(
579 String.format("%d, %d, %s", startIndex, endIndex, type));
580 }
581 };
582 }
583
584 private static Matcher<TextLinks> isTextLinksContaining(
585 final String text, final String substring, final String type) {
586 return new BaseMatcher<TextLinks>() {
587
588 @Override
589 public void describeTo(Description description) {
590 description.appendText("text=").appendValue(text)
591 .appendText(", substring=").appendValue(substring)
592 .appendText(", type=").appendValue(type);
593 }
594
595 @Override
596 public boolean matches(Object o) {
597 if (o instanceof TextLinks) {
598 for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) {
599 if (text.subSequence(link.getStart(), link.getEnd()).equals(substring)) {
600 return type.equals(link.getEntity(0));
601 }
602 }
603 }
604 return false;
605 }
606 };
607 }
608
609 private static Matcher<TextClassification> isTextClassification(
610 final String text, final String type) {
611 return new BaseMatcher<TextClassification>() {
612 @Override
613 public boolean matches(Object o) {
614 if (o instanceof TextClassification) {
615 TextClassification result = (TextClassification) o;
616 return text.equals(result.getText())
617 && result.getEntityCount() > 0
618 && type.equals(result.getEntity(0));
619 }
620 return false;
621 }
622
623 @Override
624 public void describeTo(Description description) {
625 description.appendText("text=").appendValue(text)
626 .appendText(", type=").appendValue(type);
627 }
628 };
629 }
630
Abodunrinwa Tokic33fc772019-02-06 01:17:10 +0000631 private static Matcher<TextClassification> containsIntentWithAction(final String action) {
632 return new BaseMatcher<TextClassification>() {
633 @Override
634 public boolean matches(Object o) {
635 if (o instanceof TextClassification) {
636 TextClassification result = (TextClassification) o;
637 return ExtrasUtils.findAction(result, action) != null;
638 }
639 return false;
640 }
641
642 @Override
643 public void describeTo(Description description) {
644 description.appendText("intent action=").appendValue(action);
645 }
646 };
647 }
648
Tony Mak0be540b2018-11-09 16:58:35 +0000649 private static Matcher<TextLanguage> isTextLanguage(final String languageTag) {
650 return new BaseMatcher<TextLanguage>() {
651 @Override
652 public boolean matches(Object o) {
653 if (o instanceof TextLanguage) {
654 TextLanguage result = (TextLanguage) o;
655 return result.getLocaleHypothesisCount() > 0
656 && languageTag.equals(result.getLocale(0).toLanguageTag());
657 }
658 return false;
659 }
660
661 @Override
662 public void describeTo(Description description) {
663 description.appendText("locale=").appendValue(languageTag);
664 }
665 };
666 }
Tony Makc4359bf2018-12-11 19:38:53 +0800667
Tony Makae85aae2019-01-09 15:59:56 +0000668 private static Matcher<ConversationAction> isConversationAction(String actionType) {
669 return new BaseMatcher<ConversationAction>() {
Tony Makc4359bf2018-12-11 19:38:53 +0800670 @Override
671 public boolean matches(Object o) {
Tony Makae85aae2019-01-09 15:59:56 +0000672 if (!(o instanceof ConversationAction)) {
Tony Makc4359bf2018-12-11 19:38:53 +0800673 return false;
674 }
Tony Makae85aae2019-01-09 15:59:56 +0000675 ConversationAction conversationAction =
676 (ConversationAction) o;
Tony Makc4359bf2018-12-11 19:38:53 +0800677 if (!actionType.equals(conversationAction.getType())) {
678 return false;
679 }
Tony Makae85aae2019-01-09 15:59:56 +0000680 if (ConversationAction.TYPE_TEXT_REPLY.equals(actionType)) {
Tony Makc4359bf2018-12-11 19:38:53 +0800681 if (conversationAction.getTextReply() == null) {
682 return false;
683 }
684 }
685 if (conversationAction.getConfidenceScore() < 0
686 || conversationAction.getConfidenceScore() > 1) {
687 return false;
688 }
689 return true;
690 }
691
692 @Override
693 public void describeTo(Description description) {
694 description.appendText("actionType=").appendValue(actionType);
695 }
696 };
697 }
Tony Mak0be540b2018-11-09 16:58:35 +0000698}