blob: b6cf16a56032a0ff5fc549cad31f876c16f17cfa [file] [log] [blame]
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#include "text-classifier.h"
18
19#include <fstream>
20#include <iostream>
21#include <memory>
22#include <string>
23
Lukas Zilkab23e2122018-02-09 10:25:19 +010024#include "model_generated.h"
25#include "types-test-util.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010026#include "gmock/gmock.h"
27#include "gtest/gtest.h"
28
29namespace libtextclassifier2 {
30namespace {
31
32using testing::ElementsAreArray;
33using testing::Pair;
Lukas Zilkab23e2122018-02-09 10:25:19 +010034using testing::Values;
Lukas Zilka21d8c982018-01-24 11:11:20 +010035
Lukas Zilkab23e2122018-02-09 10:25:19 +010036std::string FirstResult(const std::vector<ClassificationResult>& results) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010037 if (results.empty()) {
38 return "<INVALID RESULTS>";
39 }
Lukas Zilkab23e2122018-02-09 10:25:19 +010040 return results[0].collection;
Lukas Zilka21d8c982018-01-24 11:11:20 +010041}
42
43MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") {
44 return testing::Value(arg.span, Pair(start, end)) &&
45 testing::Value(FirstResult(arg.classification), best_class);
46}
47
// Reads the entire contents of `file_name` into a string.
//
// The file is opened in binary mode: callers use this to load serialized
// flatbuffer models, and a text-mode stream may translate line endings or stop
// at an end-of-file marker byte on platforms that distinguish the two modes.
//
// Returns an empty string if the file cannot be opened.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::in | std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream),
                     std::istreambuf_iterator<char>());
}
52
53std::string GetModelPath() {
54 return LIBTEXTCLASSIFIER_TEST_DATA_DIR;
55}
56
57TEST(TextClassifierTest, EmbeddingExecutorLoadingFails) {
Lukas Zilkab23e2122018-02-09 10:25:19 +010058 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +010059 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +010060 TextClassifier::FromPath(GetModelPath() + "wrong_embeddings.fb", &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +010061 EXPECT_FALSE(classifier);
62}
63
Lukas Zilkab23e2122018-02-09 10:25:19 +010064class TextClassifierTest : public ::testing::TestWithParam<const char*> {};
65
66INSTANTIATE_TEST_CASE_P(ClickContext, TextClassifierTest,
67 Values("test_model_cc.fb"));
68INSTANTIATE_TEST_CASE_P(BoundsSensitive, TextClassifierTest,
69 Values("test_model.fb"));
70
71TEST_P(TextClassifierTest, ClassifyText) {
72 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +010073 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +010074 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +010075 ASSERT_TRUE(classifier);
76
77 EXPECT_EQ("other",
78 FirstResult(classifier->ClassifyText(
79 "this afternoon Barack Obama gave a speech at", {15, 27})));
Lukas Zilka21d8c982018-01-24 11:11:20 +010080 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
81 "Call me at (800) 123-456 today", {11, 24})));
Lukas Zilka21d8c982018-01-24 11:11:20 +010082
83 // More lines.
84 EXPECT_EQ("other",
85 FirstResult(classifier->ClassifyText(
86 "this afternoon Barack Obama gave a speech at|Visit "
87 "www.google.com every today!|Call me at (800) 123-456 today.",
88 {15, 27})));
Lukas Zilka21d8c982018-01-24 11:11:20 +010089 EXPECT_EQ("phone",
90 FirstResult(classifier->ClassifyText(
91 "this afternoon Barack Obama gave a speech at|Visit "
92 "www.google.com every today!|Call me at (800) 123-456 today.",
93 {90, 103})));
94
95 // Single word.
96 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
97 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
98 EXPECT_EQ("<INVALID RESULTS>",
99 FirstResult(classifier->ClassifyText("asdf", {0, 0})));
100
101 // Junk.
102 EXPECT_EQ("<INVALID RESULTS>",
103 FirstResult(classifier->ClassifyText("", {0, 0})));
104 EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
105 "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
106}
107
Lukas Zilkab23e2122018-02-09 10:25:19 +0100108std::unique_ptr<RegexModel_::PatternT> MakePattern(
109 const std::string& collection_name, const std::string& pattern,
110 const bool enabled_for_classification, const bool enabled_for_selection,
111 const bool enabled_for_annotation, const float score) {
112 std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
113 result->collection_name = collection_name;
114 result->pattern = pattern;
115 result->enabled_for_selection = enabled_for_selection;
116 result->enabled_for_classification = enabled_for_classification;
117 result->enabled_for_annotation = enabled_for_annotation;
118 result->target_classification_score = score;
119 result->priority_score = score;
120 return result;
121}
122
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Appends classification-only regex patterns ("person", "flight") to the
// model's existing regex patterns and checks the first classification result
// for several selections.
TEST_P(TextClassifierTest, ClassifyTextRegularExpression) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models.
  auto& regex_patterns = model->regex_model->patterns;
  regex_patterns.push_back(MakePattern(
      "person", "Barack Obama", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
  regex_patterns.push_back(MakePattern(
      "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));

  // Re-serialize the modified model and load the classifier from that buffer.
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());

  std::unique_ptr<TextClassifier> text_classifier =
      TextClassifier::FromUnownedBuffer(packed_model, fbb.GetSize(), &unilib);
  ASSERT_TRUE(text_classifier);

  EXPECT_EQ("flight",
            FirstResult(text_classifier->ClassifyText(
                "Your flight LX373 is delayed by 3 hours.", {12, 17})));
  EXPECT_EQ("person",
            FirstResult(text_classifier->ClassifyText(
                "this afternoon Barack Obama gave a speech at", {15, 27})));
  EXPECT_EQ("email", FirstResult(text_classifier->ClassifyText(
                         "you@android.com", {0, 15})));
  EXPECT_EQ("email", FirstResult(text_classifier->ClassifyText(
                         "Contact me at you@android.com", {14, 29})));

  EXPECT_EQ("url", FirstResult(text_classifier->ClassifyText(
                       "Visit www.google.com every today!", {6, 20})));

  EXPECT_EQ("flight",
            FirstResult(text_classifier->ClassifyText("LX 37", {0, 5})));
  EXPECT_EQ("flight", FirstResult(text_classifier->ClassifyText(
                          "flight LX 37 abcd", {7, 12})));

  // More lines.
  EXPECT_EQ("url",
            FirstResult(text_classifier->ClassifyText(
                "this afternoon Barack Obama gave a speech at|Visit "
                "www.google.com every today!|Call me at (800) 123-456 today.",
                {51, 65})));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
172
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU

// Replaces the model's regex patterns with selection-only ones and checks the
// spans suggested for selections inside the matched text.
TEST_P(TextClassifierTest, SuggestSelectionRegularExpression) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models.
  model->regex_model.reset(new RegexModelT);
  auto& regex_patterns = model->regex_model->patterns;
  regex_patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  regex_patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  // Override the priority of the "flight" pattern that was just added.
  regex_patterns.back()->priority_score = 1.1;

  // Re-serialize the modified model and load the classifier from that buffer.
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());

  std::unique_ptr<TextClassifier> text_classifier =
      TextClassifier::FromUnownedBuffer(packed_model, fbb.GetSize(), &unilib);
  ASSERT_TRUE(text_classifier);

  // Check regular expression selection.
  EXPECT_EQ(text_classifier->SuggestSelection(
                "Your flight MA 0123 is delayed by 3 hours.", {12, 14}),
            std::make_pair(12, 19));
  EXPECT_EQ(text_classifier->SuggestSelection(
                "this afternoon Barack Obama gave a speech at", {15, 21}),
            std::make_pair(15, 27));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
208
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// With the "flight" regex priority lowered to 0.5, the suggested selection for
// "MA" is expected to be the longer span (26, 62) rather than the regex match.
TEST_P(TextClassifierTest,
       SuggestSelectionRegularExpressionConflictsModelWins) {
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models.
  model->regex_model.reset(new RegexModelT);
  auto& regex_patterns = model->regex_model->patterns;
  regex_patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  regex_patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  // Override the priority of the "flight" pattern that was just added.
  regex_patterns.back()->priority_score = 0.5;

  // Re-serialize the modified model and load the classifier from that buffer.
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());

  std::unique_ptr<TextClassifier> text_classifier =
      TextClassifier::FromUnownedBuffer(packed_model, fbb.GetSize());
  ASSERT_TRUE(text_classifier);

  // Check conflict resolution.
  EXPECT_EQ(
      text_classifier->SuggestSelection(
          "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
          {55, 57}),
      std::make_pair(26, 62));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
242
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// With the "flight" regex priority raised to 1.1, the suggested selection for
// "MA" is expected to be the regex match (55, 62).
TEST_P(TextClassifierTest,
       SuggestSelectionRegularExpressionConflictsRegexWins) {
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models.
  model->regex_model.reset(new RegexModelT);
  auto& regex_patterns = model->regex_model->patterns;
  regex_patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  regex_patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  // Override the priority of the "flight" pattern that was just added.
  regex_patterns.back()->priority_score = 1.1;

  // Re-serialize the modified model and load the classifier from that buffer.
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());

  std::unique_ptr<TextClassifier> text_classifier =
      TextClassifier::FromUnownedBuffer(packed_model, fbb.GetSize());
  ASSERT_TRUE(text_classifier);

  // Check conflict resolution.
  EXPECT_EQ(
      text_classifier->SuggestSelection(
          "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
          {55, 57}),
      std::make_pair(55, 62));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
276
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Replaces the model's regex patterns with annotation-only ones and checks the
// full list of annotations produced for a two-line text.
TEST_P(TextClassifierTest, AnnotateRegex) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models.
  model->regex_model.reset(new RegexModelT);
  auto& regex_patterns = model->regex_model->patterns;
  regex_patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
  regex_patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));

  // Re-serialize the modified model and load the classifier from that buffer.
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());

  std::unique_ptr<TextClassifier> text_classifier =
      TextClassifier::FromUnownedBuffer(packed_model, fbb.GetSize(), &unilib);
  ASSERT_TRUE(text_classifier);

  const std::string annotated_text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(text_classifier->Annotate(annotated_text),
              ElementsAreArray({
                  IsAnnotatedSpan(6, 18, "person"),
                  IsAnnotatedSpan(19, 24, "date"),
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));
}

#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
313
314TEST_P(TextClassifierTest, PhoneFiltering) {
315 CREATE_UNILIB_FOR_TESTING;
316 std::unique_ptr<TextClassifier> classifier =
317 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100318 ASSERT_TRUE(classifier);
319
320 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
321 "phone: (123) 456 789", {7, 20})));
322 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
323 "phone: (123) 456 789,0001112", {7, 25})));
324 EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
325 "phone: (123) 456 789,0001112", {7, 28})));
326}
327
Lukas Zilkab23e2122018-02-09 10:25:19 +0100328TEST_P(TextClassifierTest, SuggestSelection) {
329 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100330 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100331 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100332 ASSERT_TRUE(classifier);
333
334 EXPECT_EQ(classifier->SuggestSelection(
335 "this afternoon Barack Obama gave a speech at", {15, 21}),
336 std::make_pair(15, 21));
337
338 // Try passing whole string.
339 // If more than 1 token is specified, we should return back what entered.
340 EXPECT_EQ(
341 classifier->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
342 std::make_pair(0, 27));
343
344 // Single letter.
345 EXPECT_EQ(classifier->SuggestSelection("a", {0, 1}), std::make_pair(0, 1));
346
347 // Single word.
348 EXPECT_EQ(classifier->SuggestSelection("asdf", {0, 4}), std::make_pair(0, 4));
349
350 EXPECT_EQ(
351 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
352 std::make_pair(11, 23));
353
354 // Unpaired bracket stripping.
355 EXPECT_EQ(
356 classifier->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
357 std::make_pair(11, 25));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100358 EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {11, 15}),
359 std::make_pair(12, 15));
360 EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {11, 16}),
361 std::make_pair(11, 15));
362 EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {11, 16}),
363 std::make_pair(12, 15));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100364
365 // If the resulting selection would be empty, the original span is returned.
366 EXPECT_EQ(classifier->SuggestSelection("call me at )( today", {11, 13}),
367 std::make_pair(11, 13));
368 EXPECT_EQ(classifier->SuggestSelection("call me at ( today", {11, 12}),
369 std::make_pair(11, 12));
370 EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}),
371 std::make_pair(11, 12));
372}
373
Lukas Zilkab23e2122018-02-09 10:25:19 +0100374TEST_P(TextClassifierTest, SuggestSelectionsAreSymmetric) {
375 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100376 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100377 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100378 ASSERT_TRUE(classifier);
379
380 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}),
381 std::make_pair(0, 27));
382 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
383 std::make_pair(0, 27));
384 EXPECT_EQ(
385 classifier->SuggestSelection("350 Third Street, Cambridge", {10, 16}),
386 std::make_pair(0, 27));
387 EXPECT_EQ(classifier->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
388 {16, 22}),
389 std::make_pair(6, 33));
390}
391
Lukas Zilkab23e2122018-02-09 10:25:19 +0100392TEST_P(TextClassifierTest, SuggestSelectionWithNewLine) {
393 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100394 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100395 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100396 ASSERT_TRUE(classifier);
397
398 EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}),
399 std::make_pair(4, 16));
400 EXPECT_EQ(classifier->SuggestSelection("857 225 3556\nabc", {0, 3}),
401 std::make_pair(0, 12));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100402
403 SelectionOptions options;
404 EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options),
405 std::make_pair(0, 7));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100406}
407
Lukas Zilkab23e2122018-02-09 10:25:19 +0100408TEST_P(TextClassifierTest, SuggestSelectionWithPunctuation) {
409 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100410 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100411 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100412 ASSERT_TRUE(classifier);
413
414 // From the right.
415 EXPECT_EQ(classifier->SuggestSelection(
416 "this afternoon BarackObama, gave a speech at", {15, 26}),
417 std::make_pair(15, 26));
418
419 // From the right multiple.
420 EXPECT_EQ(classifier->SuggestSelection(
421 "this afternoon BarackObama,.,.,, gave a speech at", {15, 26}),
422 std::make_pair(15, 26));
423
424 // From the left multiple.
425 EXPECT_EQ(classifier->SuggestSelection(
426 "this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}),
427 std::make_pair(21, 32));
428
429 // From both sides.
430 EXPECT_EQ(classifier->SuggestSelection(
431 "this afternoon !BarackObama,- gave a speech at", {16, 27}),
432 std::make_pair(16, 27));
433}
434
Lukas Zilkab23e2122018-02-09 10:25:19 +0100435TEST_P(TextClassifierTest, SuggestSelectionNoCrashWithJunk) {
436 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100437 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100438 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100439 ASSERT_TRUE(classifier);
440
441 // Try passing in bunch of invalid selections.
442 EXPECT_EQ(classifier->SuggestSelection("", {0, 27}), std::make_pair(0, 27));
443 EXPECT_EQ(classifier->SuggestSelection("", {-10, 27}),
444 std::make_pair(-10, 27));
445 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {0, 27}),
446 std::make_pair(0, 27));
447 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-30, 300}),
448 std::make_pair(-30, 300));
449 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-10, -1}),
450 std::make_pair(-10, -1));
451 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}),
452 std::make_pair(100, 17));
453}
454
Lukas Zilkab23e2122018-02-09 10:25:19 +0100455TEST_P(TextClassifierTest, Annotate) {
456 CREATE_UNILIB_FOR_TESTING;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100457 std::unique_ptr<TextClassifier> classifier =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100458 TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100459 ASSERT_TRUE(classifier);
460
461 const std::string test_string =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100462 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
463 "number is 853 225 3556";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100464 EXPECT_THAT(classifier->Annotate(test_string),
465 ElementsAreArray({
Lukas Zilkab23e2122018-02-09 10:25:19 +0100466#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
467 IsAnnotatedSpan(19, 24, "date"),
468#endif
469 IsAnnotatedSpan(28, 55, "address"),
470 IsAnnotatedSpan(79, 91, "phone"),
Lukas Zilka21d8c982018-01-24 11:11:20 +0100471 }));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100472
473 AnnotationOptions options;
474 EXPECT_THAT(classifier->Annotate("853 225 3556", options),
475 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
476 EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
Lukas Zilka21d8c982018-01-24 11:11:20 +0100477}
478
Lukas Zilkab23e2122018-02-09 10:25:19 +0100479TEST_P(TextClassifierTest, AnnotateSmallBatches) {
480 CREATE_UNILIB_FOR_TESTING;
481 const std::string test_model = ReadFile(GetModelPath() + GetParam());
482 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
483
484 // Set the batch size.
485 unpacked_model->selection_options->batch_size = 4;
486 flatbuffers::FlatBufferBuilder builder;
487 builder.Finish(Model::Pack(builder, unpacked_model.get()));
488
489 std::unique_ptr<TextClassifier> classifier =
490 TextClassifier::FromUnownedBuffer(
491 reinterpret_cast<const char*>(builder.GetBufferPointer()),
492 builder.GetSize(), &unilib);
493 ASSERT_TRUE(classifier);
494
495 const std::string test_string =
496 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
497 "number is 853 225 3556";
498 EXPECT_THAT(classifier->Annotate(test_string),
499 ElementsAreArray({
500#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
501 IsAnnotatedSpan(19, 24, "date"),
502#endif
503 IsAnnotatedSpan(28, 55, "address"),
504 IsAnnotatedSpan(79, 91, "phone"),
505 }));
506
507 AnnotationOptions options;
508 EXPECT_THAT(classifier->Annotate("853 225 3556", options),
509 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
510 EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
511}
512
513TEST_P(TextClassifierTest, AnnotateFilteringDiscardAll) {
514 CREATE_UNILIB_FOR_TESTING;
515 const std::string test_model = ReadFile(GetModelPath() + GetParam());
516 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
517
518 // Add test thresholds.
519 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
520 unpacked_model->triggering_options->min_annotate_confidence =
521 2.f; // Discards all results.
522 flatbuffers::FlatBufferBuilder builder;
523 builder.Finish(Model::Pack(builder, unpacked_model.get()));
524
525 std::unique_ptr<TextClassifier> classifier =
526 TextClassifier::FromUnownedBuffer(
527 reinterpret_cast<const char*>(builder.GetBufferPointer()),
528 builder.GetSize(), &unilib);
529 ASSERT_TRUE(classifier);
530
531 const std::string test_string =
532 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
533 "number is 853 225 3556";
534 EXPECT_TRUE(classifier->Annotate(test_string).empty());
535}
536
537TEST_P(TextClassifierTest, AnnotateFilteringKeepAll) {
538 CREATE_UNILIB_FOR_TESTING;
539 const std::string test_model = ReadFile(GetModelPath() + GetParam());
540 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
541
542 // Add test thresholds.
543 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
544 unpacked_model->triggering_options->min_annotate_confidence =
545 0.f; // Keeps all results.
546 flatbuffers::FlatBufferBuilder builder;
547 builder.Finish(Model::Pack(builder, unpacked_model.get()));
548
549 std::unique_ptr<TextClassifier> classifier =
550 TextClassifier::FromUnownedBuffer(
551 reinterpret_cast<const char*>(builder.GetBufferPointer()),
552 builder.GetSize(), &unilib);
553 ASSERT_TRUE(classifier);
554
555 const std::string test_string =
556 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
557 "number is 853 225 3556";
558#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
559 EXPECT_EQ(classifier->Annotate(test_string).size(), 3);
560#else
561 // In non-ICU mode there is no "date" result.
562 EXPECT_EQ(classifier->Annotate(test_string).size(), 2);
563#endif
564}
565
#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
// Classifies date/time strings and checks the collection, the parsed UTC
// timestamp (interpreted in the reference timezone given in the options) and
// the reported granularity.
TEST_P(TextClassifierTest, ClassifyTextDate) {
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam());
  // Must be fatal: everything below dereferences the classifier, so a failed
  // model load has to stop the test here instead of crashing on a null pointer.
  ASSERT_TRUE(classifier);

  std::vector<ClassificationResult> result;
  ClassificationOptions options;

  // Midnight of 2017-01-01 in Zurich (UTC+1) is 2016-12-31T23:00:00Z.
  options.reference_timezone = "Europe/Zurich";
  result = classifier->ClassifyText("january 1, 2017", {0, 15}, options);

  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
  result.clear();

  // Midnight of 2017-03-01 in Los Angeles (UTC-8) is 2017-03-01T08:00:00Z.
  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("march 1, 2017", {0, 13}, options);
  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1488355200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
  result.clear();

  // 2018-01-01 10:30:20 in Los Angeles (UTC-8) is 2018-01-01T18:30:20Z.
  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("2018/01/01 10:30:20", {0, 19}, options);
  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514831420000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_SECOND);
  result.clear();

  // Date on another line.
  options.reference_timezone = "Europe/Zurich";
  result = classifier->ClassifyText(
      "hello world this is the first line\n"
      "january 1, 2017",
      {35, 50}, options);
  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
  result.clear();
}
#endif  // LIBTEXTCLASSIFIER_CALENDAR_ICU
617
618class TestingTextClassifier : public TextClassifier {
619 public:
620 TestingTextClassifier(const std::string& model, const UniLib* unilib)
621 : TextClassifier(ViewModel(model.data(), model.size()), unilib) {}
622
623 using TextClassifier::ResolveConflicts;
624};
625
626AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
627 const std::string& collection,
628 const float score) {
629 AnnotatedSpan result;
630 result.span = span;
631 result.classification.push_back({collection, score});
632 return result;
633}
634
635TEST(TextClassifierTest, ResolveConflictsTrivial) {
636 CREATE_UNILIB_FOR_TESTING;
637 TestingTextClassifier classifier("", &unilib);
638
639 std::vector<AnnotatedSpan> candidates{
640 {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
641
642 std::vector<int> chosen;
643 classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
644 EXPECT_THAT(chosen, ElementsAreArray({0}));
645}
646
647TEST(TextClassifierTest, ResolveConflictsSequence) {
648 CREATE_UNILIB_FOR_TESTING;
649 TestingTextClassifier classifier("", &unilib);
650
651 std::vector<AnnotatedSpan> candidates{{
652 MakeAnnotatedSpan({0, 1}, "phone", 1.0),
653 MakeAnnotatedSpan({1, 2}, "phone", 1.0),
654 MakeAnnotatedSpan({2, 3}, "phone", 1.0),
655 MakeAnnotatedSpan({3, 4}, "phone", 1.0),
656 MakeAnnotatedSpan({4, 5}, "phone", 1.0),
657 }};
658
659 std::vector<int> chosen;
660 classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
661 EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
662}
663
664TEST(TextClassifierTest, ResolveConflictsThreeSpans) {
665 CREATE_UNILIB_FOR_TESTING;
666 TestingTextClassifier classifier("", &unilib);
667
668 std::vector<AnnotatedSpan> candidates{{
669 MakeAnnotatedSpan({0, 3}, "phone", 1.0),
670 MakeAnnotatedSpan({1, 5}, "phone", 0.5), // Looser!
671 MakeAnnotatedSpan({3, 7}, "phone", 1.0),
672 }};
673
674 std::vector<int> chosen;
675 classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
676 EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
677}
678
679TEST(TextClassifierTest, ResolveConflictsThreeSpansReversed) {
680 CREATE_UNILIB_FOR_TESTING;
681 TestingTextClassifier classifier("", &unilib);
682
683 std::vector<AnnotatedSpan> candidates{{
684 MakeAnnotatedSpan({0, 3}, "phone", 0.5), // Looser!
685 MakeAnnotatedSpan({1, 5}, "phone", 1.0),
686 MakeAnnotatedSpan({3, 7}, "phone", 0.6), // Looser!
687 }};
688
689 std::vector<int> chosen;
690 classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
691 EXPECT_THAT(chosen, ElementsAreArray({1}));
692}
693
694TEST(TextClassifierTest, ResolveConflictsFiveSpans) {
695 CREATE_UNILIB_FOR_TESTING;
696 TestingTextClassifier classifier("", &unilib);
697
698 std::vector<AnnotatedSpan> candidates{{
699 MakeAnnotatedSpan({0, 3}, "phone", 0.5),
700 MakeAnnotatedSpan({1, 5}, "other", 1.0), // Looser!
701 MakeAnnotatedSpan({3, 7}, "phone", 0.6),
702 MakeAnnotatedSpan({8, 12}, "phone", 0.6), // Looser!
703 MakeAnnotatedSpan({11, 15}, "phone", 0.9),
704 }};
705
706 std::vector<int> chosen;
707 classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
708 EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
709}
Lukas Zilka21d8c982018-01-24 11:11:20 +0100710
711} // namespace
712} // namespace libtextclassifier2