blob: 440cedfb3f438e7efefee1f2f36a583ac4659b3a [file] [log] [blame]
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#include "text-classifier.h"
18
19#include <fstream>
20#include <iostream>
21#include <memory>
22#include <string>
23
Lukas Zilkab23e2122018-02-09 10:25:19 +010024#include "model_generated.h"
25#include "types-test-util.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010026#include "gmock/gmock.h"
27#include "gtest/gtest.h"
28
29namespace libtextclassifier2 {
30namespace {
31
32using testing::ElementsAreArray;
Lukas Zilkaba849e72018-03-08 14:48:21 +010033using testing::IsEmpty;
Lukas Zilka21d8c982018-01-24 11:11:20 +010034using testing::Pair;
Lukas Zilkab23e2122018-02-09 10:25:19 +010035using testing::Values;
Lukas Zilka21d8c982018-01-24 11:11:20 +010036
Lukas Zilkab23e2122018-02-09 10:25:19 +010037std::string FirstResult(const std::vector<ClassificationResult>& results) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010038 if (results.empty()) {
39 return "<INVALID RESULTS>";
40 }
Lukas Zilkab23e2122018-02-09 10:25:19 +010041 return results[0].collection;
Lukas Zilka21d8c982018-01-24 11:11:20 +010042}
43
44MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") {
45 return testing::Value(arg.span, Pair(start, end)) &&
46 testing::Value(FirstResult(arg.classification), best_class);
47}
48
// Reads the entire contents of `file_name` and returns them as a string.
//
// The stream is opened in binary mode: the callers in this file feed the
// returned bytes to the FlatBuffer unpacker, and a text-mode stream may
// translate line endings or stop early on platforms that distinguish the two
// modes. Returns an empty string when the file cannot be opened.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream), {});
}
53
54std::string GetModelPath() {
55 return LIBTEXTCLASSIFIER_TEST_DATA_DIR;
56}
57
58TEST(TextClassifierTest, EmbeddingExecutorLoadingFails) {
Lukas Zilkab23e2122018-02-09 10:25:19 +010059 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +010060 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +010061 TextClassifier::FromPath(GetModelPath() + "wrong_embeddings.fb", &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +010062 EXPECT_FALSE(classifier);
63}
64
Lukas Zilkab23e2122018-02-09 10:25:19 +010065class TextClassifierTest : public ::testing::TestWithParam<const char*> {};
66
67INSTANTIATE_TEST_CASE_P(ClickContext, TextClassifierTest,
68 Values("test_model_cc.fb"));
69INSTANTIATE_TEST_CASE_P(BoundsSensitive, TextClassifierTest,
70 Values("test_model.fb"));
71
72TEST_P(TextClassifierTest, ClassifyText) {
73 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +010074 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +010075 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +010076 ASSERT_TRUE(classifier);
77
78 EXPECT_EQ("other",
79 FirstResult(classifier->ClassifyText(
80 "this afternoon Barack Obama gave a speech at", {15, 27})));
Lukas Zilka21d8c982018-01-24 11:11:20 +010081 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
82 "Call me at (800) 123-456 today", {11, 24})));
Lukas Zilka21d8c982018-01-24 11:11:20 +010083
84 // More lines.
85 EXPECT_EQ("other",
86 FirstResult(classifier->ClassifyText(
87 "this afternoon Barack Obama gave a speech at|Visit "
88 "www.google.com every today!|Call me at (800) 123-456 today.",
89 {15, 27})));
Lukas Zilka21d8c982018-01-24 11:11:20 +010090 EXPECT_EQ("phone",
91 FirstResult(classifier->ClassifyText(
92 "this afternoon Barack Obama gave a speech at|Visit "
93 "www.google.com every today!|Call me at (800) 123-456 today.",
94 {90, 103})));
95
96 // Single word.
97 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
98 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
99 EXPECT_EQ("<INVALID RESULTS>",
100 FirstResult(classifier->ClassifyText("asdf", {0, 0})));
101
102 // Junk.
103 EXPECT_EQ("<INVALID RESULTS>",
104 FirstResult(classifier->ClassifyText("", {0, 0})));
105 EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
106 "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200107 // Test invalid utf8 input.
108 EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
109 "\xf0\x9f\x98\x8b\x8b", {0, 0})));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100110}
111
Lukas Zilkaba849e72018-03-08 14:48:21 +0100112TEST_P(TextClassifierTest, ClassifyTextDisabledFail) {
113 CREATE_UNILIB_FOR_TESTING;
114 const std::string test_model = ReadFile(GetModelPath() + GetParam());
115 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
116
117 unpacked_model->classification_model.clear();
118 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
119 unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION;
120
121 flatbuffers::FlatBufferBuilder builder;
122 builder.Finish(Model::Pack(builder, unpacked_model.get()));
123
124 std::unique_ptr<TextClassifier> classifier =
125 TextClassifier::FromUnownedBuffer(
126 reinterpret_cast<const char*>(builder.GetBufferPointer()),
127 builder.GetSize(), &unilib);
128
129 // The classification model is still needed for selection scores.
130 ASSERT_FALSE(classifier);
131}
132
133TEST_P(TextClassifierTest, ClassifyTextDisabled) {
134 CREATE_UNILIB_FOR_TESTING;
135 const std::string test_model = ReadFile(GetModelPath() + GetParam());
136 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
137
138 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
139 unpacked_model->triggering_options->enabled_modes =
140 ModeFlag_ANNOTATION_AND_SELECTION;
141
142 flatbuffers::FlatBufferBuilder builder;
143 builder.Finish(Model::Pack(builder, unpacked_model.get()));
144
145 std::unique_ptr<TextClassifier> classifier =
146 TextClassifier::FromUnownedBuffer(
147 reinterpret_cast<const char*>(builder.GetBufferPointer()),
148 builder.GetSize(), &unilib);
149 ASSERT_TRUE(classifier);
150
151 EXPECT_THAT(
152 classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}),
153 IsEmpty());
154}
155
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200156TEST_P(TextClassifierTest, ClassifyTextFilteredCollections) {
157 CREATE_UNILIB_FOR_TESTING;
158 const std::string test_model = ReadFile(GetModelPath() + GetParam());
159
160 std::unique_ptr<TextClassifier> classifier =
161 TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
162 &unilib);
163 ASSERT_TRUE(classifier);
164
165 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
166 "Call me at (800) 123-456 today", {11, 24})));
167
168 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
169 unpacked_model->output_options.reset(new OutputOptionsT);
170
171 // Disable phone classification
172 unpacked_model->output_options->filtered_collections_classification.push_back(
173 "phone");
174
175 flatbuffers::FlatBufferBuilder builder;
176 builder.Finish(Model::Pack(builder, unpacked_model.get()));
177
178 classifier = TextClassifier::FromUnownedBuffer(
179 reinterpret_cast<const char*>(builder.GetBufferPointer()),
180 builder.GetSize(), &unilib);
181 ASSERT_TRUE(classifier);
182
183 EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
184 "Call me at (800) 123-456 today", {11, 24})));
185
186 // Check that the address classification still passes.
187 EXPECT_EQ("address", FirstResult(classifier->ClassifyText(
188 "350 Third Street, Cambridge", {0, 27})));
189}
190
Lukas Zilkab23e2122018-02-09 10:25:19 +0100191std::unique_ptr<RegexModel_::PatternT> MakePattern(
192 const std::string& collection_name, const std::string& pattern,
193 const bool enabled_for_classification, const bool enabled_for_selection,
194 const bool enabled_for_annotation, const float score) {
195 std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
196 result->collection_name = collection_name;
197 result->pattern = pattern;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100198 // We cannot directly operate with |= on the flag, so use an int here.
199 int enabled_modes = ModeFlag_NONE;
200 if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
201 if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
202 if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
203 result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100204 result->target_classification_score = score;
205 result->priority_score = score;
206 return result;
207}
208
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Adds classification-only "person" and "flight" regex patterns to the model
// and checks that they are reported by ClassifyText alongside the collections
// the test expects from the unmodified model (email, url).
TEST_P(TextClassifierTest, ClassifyTextRegularExpression) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  const std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models.
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", "Barack Obama", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));

  // Re-serialize the modified model and load a classifier from it.
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(builder.GetBufferPointer());
  const auto classifier = TextClassifier::FromUnownedBuffer(
      packed_model, builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  EXPECT_EQ("flight",
            FirstResult(classifier->ClassifyText(
                "Your flight LX373 is delayed by 3 hours.", {12, 17})));
  EXPECT_EQ("person",
            FirstResult(classifier->ClassifyText(
                "this afternoon Barack Obama gave a speech at", {15, 27})));
  EXPECT_EQ("email",
            FirstResult(classifier->ClassifyText("you@android.com", {0, 15})));
  EXPECT_EQ("email", FirstResult(classifier->ClassifyText(
                         "Contact me at you@android.com", {14, 29})));

  EXPECT_EQ("url", FirstResult(classifier->ClassifyText(
                       "Visit www.google.com every today!", {6, 20})));

  EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("LX 37", {0, 5})));
  EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd",
                                                           {7, 12})));

  // More lines.
  const char kSegmentedText[] =
      "this afternoon Barack Obama gave a speech at|Visit "
      "www.google.com every today!|Call me at (800) 123-456 today.";
  EXPECT_EQ("url",
            FirstResult(classifier->ClassifyText(kSegmentedText, {51, 65})));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
258
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Adds selection-only "person" and "flight" regex patterns and checks that
// SuggestSelection expands a click inside a match to the span the test expects.
TEST_P(TextClassifierTest, SuggestSelectionRegularExpression) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  const std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models.
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  // Override the priority score of the "flight" pattern that was just added.
  patterns.back()->priority_score = 1.1;

  // Re-serialize the modified model and load a classifier from it.
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(builder.GetBufferPointer());
  const auto classifier = TextClassifier::FromUnownedBuffer(
      packed_model, builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  // Check regular expression selection.
  EXPECT_EQ(classifier->SuggestSelection(
                "Your flight MA 0123 is delayed by 3 hours.", {12, 14}),
            std::make_pair(12, 19));
  EXPECT_EQ(classifier->SuggestSelection(
                "this afternoon Barack Obama gave a speech at", {15, 21}),
            std::make_pair(15, 27));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
292
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// When the "flight" selection regex (priority score 0.5) conflicts with the
// model's own suggestion, the test expects the model's span (26, 62) to win.
//
// The UniLib is created with CREATE_UNILIB_FOR_TESTING and passed explicitly,
// matching the other tests in this file, instead of depending on what
// FromUnownedBuffer does when no UniLib is supplied.
TEST_P(TextClassifierTest,
       SuggestSelectionRegularExpressionConflictsModelWins) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add test regex models.
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  // Lower the priority of the "flight" pattern that was just added.
  unpacked_model->regex_model->patterns.back()->priority_score = 0.5;

  // Re-serialize the modified model and load a classifier from it.
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  // Check conflict resolution.
  EXPECT_EQ(
      classifier->SuggestSelection(
          "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
          {55, 57}),
      std::make_pair(26, 62));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
325
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// When the "flight" selection regex (priority score 1.1) conflicts with the
// model's own suggestion, the test expects the regex span (55, 62) to win.
//
// The UniLib is created with CREATE_UNILIB_FOR_TESTING and passed explicitly,
// matching the other tests in this file, instead of depending on what
// FromUnownedBuffer does when no UniLib is supplied.
TEST_P(TextClassifierTest,
       SuggestSelectionRegularExpressionConflictsRegexWins) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add test regex models.
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  // Raise the priority of the "flight" pattern that was just added.
  unpacked_model->regex_model->patterns.back()->priority_score = 1.1;

  // Re-serialize the modified model and load a classifier from it.
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  // Check conflict resolution.
  EXPECT_EQ(
      classifier->SuggestSelection(
          "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
          {55, 57}),
      std::make_pair(55, 62));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
358
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Adds annotation-only "person" and "flight" regex patterns and checks the
// exact, ordered list of spans Annotate reports for a two-line text.
TEST_P(TextClassifierTest, AnnotateRegex) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  const std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models.
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));

  // Re-serialize the modified model and load a classifier from it.
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(builder.GetBufferPointer());
  const auto classifier = TextClassifier::FromUnownedBuffer(
      packed_model, builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(6, 18, "person"),
                  IsAnnotatedSpan(19, 24, "date"),
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
393
394TEST_P(TextClassifierTest, PhoneFiltering) {
395 CREATE_UNILIB_FOR_TESTING;
396 std::unique_ptr<TextClassifier> classifier =
397 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100398 ASSERT_TRUE(classifier);
399
400 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
401 "phone: (123) 456 789", {7, 20})));
402 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
403 "phone: (123) 456 789,0001112", {7, 25})));
404 EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
405 "phone: (123) 456 789,0001112", {7, 28})));
406}
407
Lukas Zilkab23e2122018-02-09 10:25:19 +0100408TEST_P(TextClassifierTest, SuggestSelection) {
409 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100410 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100411 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100412 ASSERT_TRUE(classifier);
413
414 EXPECT_EQ(classifier->SuggestSelection(
415 "this afternoon Barack Obama gave a speech at", {15, 21}),
416 std::make_pair(15, 21));
417
418 // Try passing whole string.
419 // If more than 1 token is specified, we should return back what entered.
420 EXPECT_EQ(
421 classifier->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
422 std::make_pair(0, 27));
423
424 // Single letter.
425 EXPECT_EQ(classifier->SuggestSelection("a", {0, 1}), std::make_pair(0, 1));
426
427 // Single word.
428 EXPECT_EQ(classifier->SuggestSelection("asdf", {0, 4}), std::make_pair(0, 4));
429
430 EXPECT_EQ(
431 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
432 std::make_pair(11, 23));
433
434 // Unpaired bracket stripping.
435 EXPECT_EQ(
436 classifier->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
437 std::make_pair(11, 25));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100438 EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {11, 15}),
439 std::make_pair(12, 15));
440 EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {11, 16}),
441 std::make_pair(11, 15));
442 EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {11, 16}),
443 std::make_pair(12, 15));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100444
445 // If the resulting selection would be empty, the original span is returned.
446 EXPECT_EQ(classifier->SuggestSelection("call me at )( today", {11, 13}),
447 std::make_pair(11, 13));
448 EXPECT_EQ(classifier->SuggestSelection("call me at ( today", {11, 12}),
449 std::make_pair(11, 12));
450 EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}),
451 std::make_pair(11, 12));
452}
453
Lukas Zilkaba849e72018-03-08 14:48:21 +0100454TEST_P(TextClassifierTest, SuggestSelectionDisabledFail) {
455 CREATE_UNILIB_FOR_TESTING;
456 const std::string test_model = ReadFile(GetModelPath() + GetParam());
457 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
458
459 // Disable the selection model.
460 unpacked_model->selection_model.clear();
461 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
462 unpacked_model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;
463
464 flatbuffers::FlatBufferBuilder builder;
465 builder.Finish(Model::Pack(builder, unpacked_model.get()));
466
467 std::unique_ptr<TextClassifier> classifier =
468 TextClassifier::FromUnownedBuffer(
469 reinterpret_cast<const char*>(builder.GetBufferPointer()),
470 builder.GetSize(), &unilib);
471 // Selection model needs to be present for annotation.
472 ASSERT_FALSE(classifier);
473}
474
475TEST_P(TextClassifierTest, SuggestSelectionDisabled) {
476 CREATE_UNILIB_FOR_TESTING;
477 const std::string test_model = ReadFile(GetModelPath() + GetParam());
478 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
479
480 // Disable the selection model.
481 unpacked_model->selection_model.clear();
482 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
483 unpacked_model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
484 unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION;
485
486 flatbuffers::FlatBufferBuilder builder;
487 builder.Finish(Model::Pack(builder, unpacked_model.get()));
488
489 std::unique_ptr<TextClassifier> classifier =
490 TextClassifier::FromUnownedBuffer(
491 reinterpret_cast<const char*>(builder.GetBufferPointer()),
492 builder.GetSize(), &unilib);
493 ASSERT_TRUE(classifier);
494
495 EXPECT_EQ(
496 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
497 std::make_pair(11, 14));
498
499 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
500 "call me at (800) 123-456 today", {11, 24})));
501
502 EXPECT_THAT(classifier->Annotate("call me at (800) 123-456 today"),
503 IsEmpty());
504}
505
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200506TEST_P(TextClassifierTest, SuggestSelectionFilteredCollections) {
507 CREATE_UNILIB_FOR_TESTING;
508 const std::string test_model = ReadFile(GetModelPath() + GetParam());
509
510 std::unique_ptr<TextClassifier> classifier =
511 TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
512 &unilib);
513 ASSERT_TRUE(classifier);
514
515 EXPECT_EQ(
516 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
517 std::make_pair(11, 23));
518
519 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
520 unpacked_model->output_options.reset(new OutputOptionsT);
521
522 // Disable phone selection
523 unpacked_model->output_options->filtered_collections_selection.push_back(
524 "phone");
525 // We need to force this for filtering.
526 unpacked_model->selection_options->always_classify_suggested_selection = true;
527
528 flatbuffers::FlatBufferBuilder builder;
529 builder.Finish(Model::Pack(builder, unpacked_model.get()));
530
531 classifier = TextClassifier::FromUnownedBuffer(
532 reinterpret_cast<const char*>(builder.GetBufferPointer()),
533 builder.GetSize(), &unilib);
534 ASSERT_TRUE(classifier);
535
536 EXPECT_EQ(
537 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
538 std::make_pair(11, 14));
539
540 // Address selection should still work.
541 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
542 std::make_pair(0, 27));
543}
544
Lukas Zilkab23e2122018-02-09 10:25:19 +0100545TEST_P(TextClassifierTest, SuggestSelectionsAreSymmetric) {
546 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100547 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100548 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100549 ASSERT_TRUE(classifier);
550
551 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}),
552 std::make_pair(0, 27));
553 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
554 std::make_pair(0, 27));
555 EXPECT_EQ(
556 classifier->SuggestSelection("350 Third Street, Cambridge", {10, 16}),
557 std::make_pair(0, 27));
558 EXPECT_EQ(classifier->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
559 {16, 22}),
560 std::make_pair(6, 33));
561}
562
Lukas Zilkab23e2122018-02-09 10:25:19 +0100563TEST_P(TextClassifierTest, SuggestSelectionWithNewLine) {
564 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100565 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100566 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100567 ASSERT_TRUE(classifier);
568
569 EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}),
570 std::make_pair(4, 16));
571 EXPECT_EQ(classifier->SuggestSelection("857 225 3556\nabc", {0, 3}),
572 std::make_pair(0, 12));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100573
574 SelectionOptions options;
575 EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options),
576 std::make_pair(0, 7));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100577}
578
Lukas Zilkab23e2122018-02-09 10:25:19 +0100579TEST_P(TextClassifierTest, SuggestSelectionWithPunctuation) {
580 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100581 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100582 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100583 ASSERT_TRUE(classifier);
584
585 // From the right.
586 EXPECT_EQ(classifier->SuggestSelection(
587 "this afternoon BarackObama, gave a speech at", {15, 26}),
588 std::make_pair(15, 26));
589
590 // From the right multiple.
591 EXPECT_EQ(classifier->SuggestSelection(
592 "this afternoon BarackObama,.,.,, gave a speech at", {15, 26}),
593 std::make_pair(15, 26));
594
595 // From the left multiple.
596 EXPECT_EQ(classifier->SuggestSelection(
597 "this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}),
598 std::make_pair(21, 32));
599
600 // From both sides.
601 EXPECT_EQ(classifier->SuggestSelection(
602 "this afternoon !BarackObama,- gave a speech at", {16, 27}),
603 std::make_pair(16, 27));
604}
605
Lukas Zilkab23e2122018-02-09 10:25:19 +0100606TEST_P(TextClassifierTest, SuggestSelectionNoCrashWithJunk) {
607 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100608 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100609 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100610 ASSERT_TRUE(classifier);
611
612 // Try passing in bunch of invalid selections.
613 EXPECT_EQ(classifier->SuggestSelection("", {0, 27}), std::make_pair(0, 27));
614 EXPECT_EQ(classifier->SuggestSelection("", {-10, 27}),
615 std::make_pair(-10, 27));
616 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {0, 27}),
617 std::make_pair(0, 27));
618 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-30, 300}),
619 std::make_pair(-30, 300));
620 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-10, -1}),
621 std::make_pair(-10, -1));
622 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}),
623 std::make_pair(100, 17));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200624
625 // Try passing invalid utf8.
626 EXPECT_EQ(classifier->SuggestSelection("\xf0\x9f\x98\x8b\x8b", {-1, -1}),
627 std::make_pair(-1, -1));
628}
629
630TEST_P(TextClassifierTest, SuggestSelectionSelectSpace) {
631 CREATE_UNILIB_FOR_TESTING;
632 std::unique_ptr<TextClassifier> classifier =
633 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
634 ASSERT_TRUE(classifier);
635
636 EXPECT_EQ(
637 classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}),
638 std::make_pair(11, 23));
639 EXPECT_EQ(
640 classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}),
641 std::make_pair(10, 11));
642 EXPECT_EQ(
643 classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}),
644 std::make_pair(23, 24));
645 EXPECT_EQ(
646 classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}),
647 std::make_pair(23, 24));
648 EXPECT_EQ(classifier->SuggestSelection("call me at 857 225 3556, today",
649 {14, 17}),
650 std::make_pair(11, 25));
651 EXPECT_EQ(
652 classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}),
653 std::make_pair(11, 23));
654 EXPECT_EQ(
655 classifier->SuggestSelection(
656 "let's meet at 350 Third Street Cambridge and go there", {30, 31}),
657 std::make_pair(14, 40));
658 EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}),
659 std::make_pair(4, 5));
660 EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}),
661 std::make_pair(7, 8));
662
663 // With a punctuation around the selected whitespace.
664 EXPECT_EQ(
665 classifier->SuggestSelection(
666 "let's meet at 350 Third Street, Cambridge and go there", {31, 32}),
667 std::make_pair(14, 41));
668
669 // When all's whitespace, should return the original indices.
670 EXPECT_EQ(classifier->SuggestSelection(" ", {0, 1}),
671 std::make_pair(0, 1));
672 EXPECT_EQ(classifier->SuggestSelection(" ", {0, 3}),
673 std::make_pair(0, 3));
674 EXPECT_EQ(classifier->SuggestSelection(" ", {2, 3}),
675 std::make_pair(2, 3));
676 EXPECT_EQ(classifier->SuggestSelection(" ", {5, 6}),
677 std::make_pair(5, 6));
678}
679
680TEST(TextClassifierTest, SnapLeftIfWhitespaceSelection) {
681 CREATE_UNILIB_FOR_TESTING;
682 UnicodeText text;
683
684 text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false);
685 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib),
686 std::make_pair(3, 4));
687 text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false);
688 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib),
689 std::make_pair(3, 4));
690
691 // Nothing on the left.
692 text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
693 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib),
694 std::make_pair(4, 5));
695 text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
696 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib),
697 std::make_pair(0, 1));
698
699 // Whitespace only.
700 text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
701 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, unilib),
702 std::make_pair(2, 3));
703 text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
704 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib),
705 std::make_pair(4, 5));
706 text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
707 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib),
708 std::make_pair(0, 1));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100709}
710
Lukas Zilkab23e2122018-02-09 10:25:19 +0100711TEST_P(TextClassifierTest, Annotate) {
712 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100713 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100714 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100715 ASSERT_TRUE(classifier);
716
717 const std::string test_string =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100718 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
719 "number is 853 225 3556";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100720 EXPECT_THAT(classifier->Annotate(test_string),
721 ElementsAreArray({
Lukas Zilkab23e2122018-02-09 10:25:19 +0100722#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
723 IsAnnotatedSpan(19, 24, "date"),
724#endif
725 IsAnnotatedSpan(28, 55, "address"),
726 IsAnnotatedSpan(79, 91, "phone"),
Lukas Zilka21d8c982018-01-24 11:11:20 +0100727 }));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100728
729 AnnotationOptions options;
730 EXPECT_THAT(classifier->Annotate("853 225 3556", options),
731 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
732 EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200733
734 // Try passing invalid utf8.
735 EXPECT_TRUE(
736 classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options)
737 .empty());
Lukas Zilka21d8c982018-01-24 11:11:20 +0100738}
739
Lukas Zilkab23e2122018-02-09 10:25:19 +0100740TEST_P(TextClassifierTest, AnnotateSmallBatches) {
741 CREATE_UNILIB_FOR_TESTING;
742 const std::string test_model = ReadFile(GetModelPath() + GetParam());
743 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
744
745 // Set the batch size.
746 unpacked_model->selection_options->batch_size = 4;
747 flatbuffers::FlatBufferBuilder builder;
748 builder.Finish(Model::Pack(builder, unpacked_model.get()));
749
750 std::unique_ptr<TextClassifier> classifier =
751 TextClassifier::FromUnownedBuffer(
752 reinterpret_cast<const char*>(builder.GetBufferPointer()),
753 builder.GetSize(), &unilib);
754 ASSERT_TRUE(classifier);
755
756 const std::string test_string =
757 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
758 "number is 853 225 3556";
759 EXPECT_THAT(classifier->Annotate(test_string),
760 ElementsAreArray({
761#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
762 IsAnnotatedSpan(19, 24, "date"),
763#endif
764 IsAnnotatedSpan(28, 55, "address"),
765 IsAnnotatedSpan(79, 91, "phone"),
766 }));
767
768 AnnotationOptions options;
769 EXPECT_THAT(classifier->Annotate("853 225 3556", options),
770 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
771 EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
772}
773
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Repacks the model with min_annotate_confidence raised to 2.0 and checks how
// many annotations remain for the reference sentence.
TEST_P(TextClassifierTest, AnnotateFilteringDiscardAll) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string model_bytes = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());

  // Install fresh triggering options that carry only the test threshold.
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->min_annotate_confidence =
      2.f;  // Discards all results.

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  const char* const packed =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());

  // The classifier does not own the buffer; `fbb` must outlive it.
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(packed, fbb.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string input =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // NOTE(review): one annotation is still expected despite the threshold;
  // presumably the ICU-only "date" result is not subject to it - confirm.
  EXPECT_EQ(classifier->Annotate(input).size(), 1);
}
#endif
Lukas Zilkab23e2122018-02-09 10:25:19 +0100800
801TEST_P(TextClassifierTest, AnnotateFilteringKeepAll) {
802 CREATE_UNILIB_FOR_TESTING;
803 const std::string test_model = ReadFile(GetModelPath() + GetParam());
804 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
805
806 // Add test thresholds.
807 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
808 unpacked_model->triggering_options->min_annotate_confidence =
809 0.f; // Keeps all results.
Lukas Zilkaba849e72018-03-08 14:48:21 +0100810 unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100811 flatbuffers::FlatBufferBuilder builder;
812 builder.Finish(Model::Pack(builder, unpacked_model.get()));
813
814 std::unique_ptr<TextClassifier> classifier =
815 TextClassifier::FromUnownedBuffer(
816 reinterpret_cast<const char*>(builder.GetBufferPointer()),
817 builder.GetSize(), &unilib);
818 ASSERT_TRUE(classifier);
819
820 const std::string test_string =
821 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
822 "number is 853 225 3556";
823#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
824 EXPECT_EQ(classifier->Annotate(test_string).size(), 3);
825#else
826 // In non-ICU mode there is no "date" result.
827 EXPECT_EQ(classifier->Annotate(test_string).size(), 2);
828#endif
829}
830
Lukas Zilkaba849e72018-03-08 14:48:21 +0100831TEST_P(TextClassifierTest, AnnotateDisabled) {
832 CREATE_UNILIB_FOR_TESTING;
833 const std::string test_model = ReadFile(GetModelPath() + GetParam());
834 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
835
836 // Disable the model for annotation.
837 unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;
838 flatbuffers::FlatBufferBuilder builder;
839 builder.Finish(Model::Pack(builder, unpacked_model.get()));
840
841 std::unique_ptr<TextClassifier> classifier =
842 TextClassifier::FromUnownedBuffer(
843 reinterpret_cast<const char*>(builder.GetBufferPointer()),
844 builder.GetSize(), &unilib);
845 ASSERT_TRUE(classifier);
846 const std::string test_string =
847 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
848 "number is 853 225 3556";
849 EXPECT_THAT(classifier->Annotate(test_string), IsEmpty());
850}
851
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200852TEST_P(TextClassifierTest, AnnotateFilteredCollections) {
853 CREATE_UNILIB_FOR_TESTING;
854 const std::string test_model = ReadFile(GetModelPath() + GetParam());
855
856 std::unique_ptr<TextClassifier> classifier =
857 TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
858 &unilib);
859 ASSERT_TRUE(classifier);
860
861 const std::string test_string =
862 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
863 "number is 853 225 3556";
864
865 EXPECT_THAT(classifier->Annotate(test_string),
866 ElementsAreArray({
867#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
868 IsAnnotatedSpan(19, 24, "date"),
869#endif
870 IsAnnotatedSpan(28, 55, "address"),
871 IsAnnotatedSpan(79, 91, "phone"),
872 }));
873
874 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
875 unpacked_model->output_options.reset(new OutputOptionsT);
876
877 // Disable phone annotation
878 unpacked_model->output_options->filtered_collections_annotation.push_back(
879 "phone");
880
881 flatbuffers::FlatBufferBuilder builder;
882 builder.Finish(Model::Pack(builder, unpacked_model.get()));
883
884 classifier = TextClassifier::FromUnownedBuffer(
885 reinterpret_cast<const char*>(builder.GetBufferPointer()),
886 builder.GetSize(), &unilib);
887 ASSERT_TRUE(classifier);
888
889 EXPECT_THAT(classifier->Annotate(test_string),
890 ElementsAreArray({
891#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
892 IsAnnotatedSpan(19, 24, "date"),
893#endif
894 IsAnnotatedSpan(28, 55, "address"),
895 }));
896}
897
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Adds a high-scoring custom "suppress" regex annotator whose collection is
// also listed in filtered_collections_annotation, and checks that the phone
// annotation it competes with no longer appears in the output.
TEST_P(TextClassifierTest, AnnotateFilteredCollectionsSuppress) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());

  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
                                        &unilib);
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // Baseline with the unmodified model. The whole test is already compiled
  // only for ICU builds, so the "date" expectation needs no guard of its own.
  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(19, 24, "date"),
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  unpacked_model->output_options.reset(new OutputOptionsT);

  // We add a custom annotator that wins against the phone classification
  // below and that we subsequently suppress.
  unpacked_model->output_options->filtered_collections_annotation.push_back(
      "suppress");

  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "suppress", "(\\d{3} ?\\d{4})",
      /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  // The classifier does not own the buffer; `builder` must outlive it.
  classifier = TextClassifier::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(19, 24, "date"),
                  IsAnnotatedSpan(28, 55, "address"),
              }));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
949
#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
// Checks that date expressions are classified as "date" and that the parsed
// UTC timestamp and granularity match the expectation for the reference
// timezone set in the options.
TEST_P(TextClassifierTest, ClassifyTextDate) {
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam());
  // Fatal on failure: the classifier is dereferenced unconditionally below.
  ASSERT_TRUE(classifier);

  std::vector<ClassificationResult> result;
  ClassificationOptions options;

  options.reference_timezone = "Europe/Zurich";
  result = classifier->ClassifyText("january 1, 2017", {0, 15}, options);

  ASSERT_EQ(result.size(), 1);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);

  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("march 1, 2017", {0, 13}, options);
  ASSERT_EQ(result.size(), 1);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1488355200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);

  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("2018/01/01 10:30:20", {0, 19}, options);
  ASSERT_EQ(result.size(), 1);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514831420000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_SECOND);

  // Date on another line.
  options.reference_timezone = "Europe/Zurich";
  result = classifier->ClassifyText(
      "hello world this is the first line\n"
      "january 1, 2017",
      {35, 50}, options);
  ASSERT_EQ(result.size(), 1);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
}
#endif  // LIBTEXTCLASSIFIER_CALENDAR_ICU
Lukas Zilkaba849e72018-03-08 14:48:21 +01001000
#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
// Checks that the ambiguous date "03/05" resolves to a different timestamp
// depending on the locale list passed in the classification options.
TEST_P(TextClassifierTest, ClassifyTextDatePriorities) {
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam());
  // Fatal on failure: the classifier is dereferenced unconditionally below.
  ASSERT_TRUE(classifier);

  std::vector<ClassificationResult> result;
  ClassificationOptions options;

  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  result = classifier->ClassifyText("03/05", {0, 5}, options);

  ASSERT_EQ(result.size(), 1);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 5439600000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);

  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-GB,en-US";
  result = classifier->ClassifyText("03/05", {0, 5}, options);

  ASSERT_EQ(result.size(), 1);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 10537200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
}
#endif  // LIBTEXTCLASSIFIER_CALENDAR_ICU
1033
#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
// Restricts every datetime pattern to annotation and classification, then
// checks that classification and annotation still report "date" while
// SuggestSelection() returns the input span unchanged.
TEST_P(TextClassifierTest, SuggestTextDateDisabled) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Disable the patterns for selection.
  for (auto& pattern : unpacked_model->datetime_model->patterns) {
    pattern->enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION;
  }
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  // The classifier does not own the buffer; `builder` must outlive it.
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);
  EXPECT_EQ("date",
            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
            std::make_pair(0, 7));
  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
}
#endif  // LIBTEXTCLASSIFIER_CALENDAR_ICU
1061
1062class TestingTextClassifier : public TextClassifier {
1063 public:
1064 TestingTextClassifier(const std::string& model, const UniLib* unilib)
1065 : TextClassifier(ViewModel(model.data(), model.size()), unilib) {}
1066
1067 using TextClassifier::ResolveConflicts;
1068};
1069
1070AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
1071 const std::string& collection,
1072 const float score) {
1073 AnnotatedSpan result;
1074 result.span = span;
1075 result.classification.push_back({collection, score});
1076 return result;
1077}
1078
1079TEST(TextClassifierTest, ResolveConflictsTrivial) {
1080 CREATE_UNILIB_FOR_TESTING;
1081 TestingTextClassifier classifier("", &unilib);
1082
1083 std::vector<AnnotatedSpan> candidates{
1084 {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
1085
1086 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001087 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001088 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001089 EXPECT_THAT(chosen, ElementsAreArray({0}));
1090}
1091
1092TEST(TextClassifierTest, ResolveConflictsSequence) {
1093 CREATE_UNILIB_FOR_TESTING;
1094 TestingTextClassifier classifier("", &unilib);
1095
1096 std::vector<AnnotatedSpan> candidates{{
1097 MakeAnnotatedSpan({0, 1}, "phone", 1.0),
1098 MakeAnnotatedSpan({1, 2}, "phone", 1.0),
1099 MakeAnnotatedSpan({2, 3}, "phone", 1.0),
1100 MakeAnnotatedSpan({3, 4}, "phone", 1.0),
1101 MakeAnnotatedSpan({4, 5}, "phone", 1.0),
1102 }};
1103
1104 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001105 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001106 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001107 EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
1108}
1109
1110TEST(TextClassifierTest, ResolveConflictsThreeSpans) {
1111 CREATE_UNILIB_FOR_TESTING;
1112 TestingTextClassifier classifier("", &unilib);
1113
1114 std::vector<AnnotatedSpan> candidates{{
1115 MakeAnnotatedSpan({0, 3}, "phone", 1.0),
1116 MakeAnnotatedSpan({1, 5}, "phone", 0.5), // Looser!
1117 MakeAnnotatedSpan({3, 7}, "phone", 1.0),
1118 }};
1119
1120 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001121 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001122 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001123 EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
1124}
1125
1126TEST(TextClassifierTest, ResolveConflictsThreeSpansReversed) {
1127 CREATE_UNILIB_FOR_TESTING;
1128 TestingTextClassifier classifier("", &unilib);
1129
1130 std::vector<AnnotatedSpan> candidates{{
1131 MakeAnnotatedSpan({0, 3}, "phone", 0.5), // Looser!
1132 MakeAnnotatedSpan({1, 5}, "phone", 1.0),
1133 MakeAnnotatedSpan({3, 7}, "phone", 0.6), // Looser!
1134 }};
1135
1136 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001137 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001138 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001139 EXPECT_THAT(chosen, ElementsAreArray({1}));
1140}
1141
1142TEST(TextClassifierTest, ResolveConflictsFiveSpans) {
1143 CREATE_UNILIB_FOR_TESTING;
1144 TestingTextClassifier classifier("", &unilib);
1145
1146 std::vector<AnnotatedSpan> candidates{{
1147 MakeAnnotatedSpan({0, 3}, "phone", 0.5),
1148 MakeAnnotatedSpan({1, 5}, "other", 1.0), // Looser!
1149 MakeAnnotatedSpan({3, 7}, "phone", 0.6),
1150 MakeAnnotatedSpan({8, 12}, "phone", 0.6), // Looser!
1151 MakeAnnotatedSpan({11, 15}, "phone", 0.9),
1152 }};
1153
1154 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001155 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001156 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001157 EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
1158}
Lukas Zilka21d8c982018-01-24 11:11:20 +01001159
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Embeds one entity of each kind between two runs of 50000 spaces and checks
// that annotation, selection and classification all locate and label it.
TEST_P(TextClassifierTest, LongInput) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  const int kPadding = 50000;
  const std::string padding(kPadding, ' ');

  // {expected collection, entity text}
  const std::vector<std::pair<std::string, std::string>> cases = {
      {"address", "350 Third Street, Cambridge"},
      {"phone", "123 456-7890"},
      {"url", "www.google.com"},
      {"email", "someone@gmail.com"},
      {"flight", "LX 38"},
      {"date", "September 1, 2018"},
  };

  for (const std::pair<std::string, std::string>& entry : cases) {
    const std::string& collection = entry.first;
    const std::string& value = entry.second;

    const std::string input_100k = padding + value + padding;
    // End of the entity, in the same units as the spans passed below.
    const int value_end = kPadding + static_cast<int>(value.size());

    EXPECT_THAT(
        classifier->Annotate(input_100k),
        ElementsAreArray({IsAnnotatedSpan(kPadding, value_end, collection)}));
    EXPECT_EQ(classifier->SuggestSelection(input_100k,
                                           {kPadding, kPadding + 1}),
              std::make_pair(kPadding, value_end));
    EXPECT_EQ(collection, FirstResult(classifier->ClassifyText(
                              input_100k, {kPadding, value_end})));
  }
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
1191
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Coarse smoke test: runs all three entry points on a ~100k character input
// only to make sure execution finishes in a reasonable amount of time. The
// returned values are deliberately not inspected.
TEST_P(TextClassifierTest, LongInputNoResultCheck) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  const int kPadding = 50000;
  const std::string padding(kPadding, ' ');
  const std::vector<std::string> values = {
      "http://www.aaaaaaaaaaaaaaaaaaaa.com "};

  for (const std::string& value : values) {
    const std::string input_100k = padding + value + padding;
    const int value_end = kPadding + static_cast<int>(value.size());

    classifier->Annotate(input_100k);
    classifier->SuggestSelection(input_100k, {kPadding, kPadding + 1});
    classifier->ClassifyText(input_100k, {kPadding, value_end});
  }
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
1213
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Checks that classification_options->address_min_num_tokens controls whether
// the same span is classified as "address" or falls back to "other".
TEST_P(TextClassifierTest, MinAddressTokenLength) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string model_bytes = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());

  const std::string sentence = "I live at 350 Third Street, Cambridge.";
  std::unique_ptr<TextClassifier> classifier;

  // With unrestricted number of address tokens should behave normally.
  model->classification_options->address_min_num_tokens = 0;

  flatbuffers::FlatBufferBuilder unrestricted_fbb;
  unrestricted_fbb.Finish(Model::Pack(unrestricted_fbb, model.get()));
  // The classifier does not own the buffer; the builder must outlive it.
  classifier = TextClassifier::FromUnownedBuffer(
      reinterpret_cast<const char*>(unrestricted_fbb.GetBufferPointer()),
      unrestricted_fbb.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  EXPECT_EQ(FirstResult(classifier->ClassifyText(sentence, {10, 37})),
            "address");

  // Raise number of address tokens to suppress the address classification.
  model->classification_options->address_min_num_tokens = 5;

  flatbuffers::FlatBufferBuilder restricted_fbb;
  restricted_fbb.Finish(Model::Pack(restricted_fbb, model.get()));
  classifier = TextClassifier::FromUnownedBuffer(
      reinterpret_cast<const char*>(restricted_fbb.GetBufferPointer()),
      restricted_fbb.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  EXPECT_EQ(FirstResult(classifier->ClassifyText(sentence, {10, 37})),
            "other");
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
1251
Lukas Zilka21d8c982018-01-24 11:11:20 +01001252} // namespace
1253} // namespace libtextclassifier2