blob: be71f84f6e5e2298d28e78b57afed3594d0c499c [file] [log] [blame]
Lukas Zilka21d8c982018-01-24 11:11:20 +01001/*
Tony Mak6c4cc672018-09-17 11:48:50 +01002 * Copyright (C) 2018 The Android Open Source Project
Lukas Zilka21d8c982018-01-24 11:11:20 +01003 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Tony Mak6c4cc672018-09-17 11:48:50 +010017#include "annotator/annotator.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010018
19#include <fstream>
20#include <iostream>
21#include <memory>
22#include <string>
23
Tony Mak6c4cc672018-09-17 11:48:50 +010024#include "annotator/model_generated.h"
25#include "annotator/types-test-util.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010026#include "gmock/gmock.h"
27#include "gtest/gtest.h"
28
Tony Mak6c4cc672018-09-17 11:48:50 +010029namespace libtextclassifier3 {
Lukas Zilka21d8c982018-01-24 11:11:20 +010030namespace {
31
32using testing::ElementsAreArray;
Lukas Zilkaba849e72018-03-08 14:48:21 +010033using testing::IsEmpty;
Lukas Zilka21d8c982018-01-24 11:11:20 +010034using testing::Pair;
Lukas Zilkab23e2122018-02-09 10:25:19 +010035using testing::Values;
Lukas Zilka21d8c982018-01-24 11:11:20 +010036
Lukas Zilkab23e2122018-02-09 10:25:19 +010037std::string FirstResult(const std::vector<ClassificationResult>& results) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010038 if (results.empty()) {
39 return "<INVALID RESULTS>";
40 }
Lukas Zilkab23e2122018-02-09 10:25:19 +010041 return results[0].collection;
Lukas Zilka21d8c982018-01-24 11:11:20 +010042}
43
44MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") {
45 return testing::Value(arg.span, Pair(start, end)) &&
46 testing::Value(FirstResult(arg.classification), best_class);
47}
48
// Reads the entire file at `file_name` and returns its raw bytes.
//
// The stream is opened in binary mode: the callers feed the result to the
// model unpacker, so the bytes must not go through any text-mode translation
// (e.g. newline conversion on platforms that perform it).
// Returns an empty string when the file cannot be opened.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::in | std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream),
                     std::istreambuf_iterator<char>());
}
53
54std::string GetModelPath() {
Tony Maka0f598b2018-11-20 20:39:04 +000055 return TC3_TEST_DATA_DIR;
Lukas Zilka21d8c982018-01-24 11:11:20 +010056}
57
Tony Mak6c4cc672018-09-17 11:48:50 +010058class AnnotatorTest : public ::testing::TestWithParam<const char*> {
59 protected:
60 AnnotatorTest()
61 : INIT_UNILIB_FOR_TESTING(unilib_),
62 INIT_CALENDARLIB_FOR_TESTING(calendarlib_) {}
63 UniLib unilib_;
64 CalendarLib calendarlib_;
65};
66
67TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) {
68 std::unique_ptr<Annotator> classifier = Annotator::FromPath(
69 GetModelPath() + "wrong_embeddings.fb", &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +010070 EXPECT_FALSE(classifier);
71}
72
Tony Mak6c4cc672018-09-17 11:48:50 +010073INSTANTIATE_TEST_CASE_P(ClickContext, AnnotatorTest,
Lukas Zilkab23e2122018-02-09 10:25:19 +010074 Values("test_model_cc.fb"));
Tony Mak6c4cc672018-09-17 11:48:50 +010075INSTANTIATE_TEST_CASE_P(BoundsSensitive, AnnotatorTest,
Lukas Zilkab23e2122018-02-09 10:25:19 +010076 Values("test_model.fb"));
77
Tony Mak6c4cc672018-09-17 11:48:50 +010078TEST_P(AnnotatorTest, ClassifyText) {
79 std::unique_ptr<Annotator> classifier =
80 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +010081 ASSERT_TRUE(classifier);
82
83 EXPECT_EQ("other",
84 FirstResult(classifier->ClassifyText(
85 "this afternoon Barack Obama gave a speech at", {15, 27})));
Lukas Zilka21d8c982018-01-24 11:11:20 +010086 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
87 "Call me at (800) 123-456 today", {11, 24})));
Lukas Zilka21d8c982018-01-24 11:11:20 +010088
89 // More lines.
90 EXPECT_EQ("other",
91 FirstResult(classifier->ClassifyText(
92 "this afternoon Barack Obama gave a speech at|Visit "
93 "www.google.com every today!|Call me at (800) 123-456 today.",
94 {15, 27})));
Lukas Zilka21d8c982018-01-24 11:11:20 +010095 EXPECT_EQ("phone",
96 FirstResult(classifier->ClassifyText(
97 "this afternoon Barack Obama gave a speech at|Visit "
98 "www.google.com every today!|Call me at (800) 123-456 today.",
99 {90, 103})));
100
101 // Single word.
102 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
103 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
104 EXPECT_EQ("<INVALID RESULTS>",
105 FirstResult(classifier->ClassifyText("asdf", {0, 0})));
106
107 // Junk.
108 EXPECT_EQ("<INVALID RESULTS>",
109 FirstResult(classifier->ClassifyText("", {0, 0})));
110 EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
111 "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200112 // Test invalid utf8 input.
113 EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
114 "\xf0\x9f\x98\x8b\x8b", {0, 0})));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100115}
116
Tony Mak6c4cc672018-09-17 11:48:50 +0100117TEST_P(AnnotatorTest, ClassifyTextDisabledFail) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100118 const std::string test_model = ReadFile(GetModelPath() + GetParam());
119 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
120
121 unpacked_model->classification_model.clear();
122 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
123 unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION;
124
125 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000126 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100127
Tony Mak6c4cc672018-09-17 11:48:50 +0100128 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
129 reinterpret_cast<const char*>(builder.GetBufferPointer()),
130 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100131
132 // The classification model is still needed for selection scores.
133 ASSERT_FALSE(classifier);
134}
135
Tony Mak6c4cc672018-09-17 11:48:50 +0100136TEST_P(AnnotatorTest, ClassifyTextDisabled) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100137 const std::string test_model = ReadFile(GetModelPath() + GetParam());
138 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
139
140 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
141 unpacked_model->triggering_options->enabled_modes =
142 ModeFlag_ANNOTATION_AND_SELECTION;
143
144 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000145 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100146
Tony Mak6c4cc672018-09-17 11:48:50 +0100147 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
148 reinterpret_cast<const char*>(builder.GetBufferPointer()),
149 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100150 ASSERT_TRUE(classifier);
151
152 EXPECT_THAT(
153 classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}),
154 IsEmpty());
155}
156
Tony Mak6c4cc672018-09-17 11:48:50 +0100157TEST_P(AnnotatorTest, ClassifyTextFilteredCollections) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200158 const std::string test_model = ReadFile(GetModelPath() + GetParam());
159
Tony Mak6c4cc672018-09-17 11:48:50 +0100160 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
161 test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200162 ASSERT_TRUE(classifier);
163
164 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
165 "Call me at (800) 123-456 today", {11, 24})));
166
167 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
168 unpacked_model->output_options.reset(new OutputOptionsT);
169
170 // Disable phone classification
171 unpacked_model->output_options->filtered_collections_classification.push_back(
172 "phone");
173
174 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000175 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200176
Tony Mak6c4cc672018-09-17 11:48:50 +0100177 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200178 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +0100179 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200180 ASSERT_TRUE(classifier);
181
182 EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
183 "Call me at (800) 123-456 today", {11, 24})));
184
185 // Check that the address classification still passes.
186 EXPECT_EQ("address", FirstResult(classifier->ClassifyText(
187 "350 Third Street, Cambridge", {0, 27})));
188}
189
Lukas Zilkab23e2122018-02-09 10:25:19 +0100190std::unique_ptr<RegexModel_::PatternT> MakePattern(
191 const std::string& collection_name, const std::string& pattern,
192 const bool enabled_for_classification, const bool enabled_for_selection,
193 const bool enabled_for_annotation, const float score) {
194 std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
195 result->collection_name = collection_name;
196 result->pattern = pattern;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100197 // We cannot directly operate with |= on the flag, so use an int here.
198 int enabled_modes = ModeFlag_NONE;
199 if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
200 if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
201 if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
202 result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100203 result->target_classification_score = score;
204 result->priority_score = score;
205 return result;
206}
207
#ifdef TC3_UNILIB_ICU
// Adds classification-only regex patterns ("person", "flight" and a
// Luhn-verified "payment_card") to the model and checks the top result of
// ClassifyText for spans that do and do not match them.
TEST_P(AnnotatorTest, ClassifyTextRegularExpression) {
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add test regex models.
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "person", "Barack Obama", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));
  std::unique_ptr<RegexModel_::PatternT> verified_pattern =
      MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}",
                  /*enabled_for_classification=*/true,
                  /*enabled_for_selection=*/false,
                  /*enabled_for_annotation=*/false, 1.0);
  verified_pattern->verification_options.reset(new VerificationOptionsT);
  verified_pattern->verification_options->verify_luhn_checksum = true;
  unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_, &calendarlib_);
  ASSERT_TRUE(classifier);

  EXPECT_EQ("flight",
            FirstResult(classifier->ClassifyText(
                "Your flight LX373 is delayed by 3 hours.", {12, 17})));
  EXPECT_EQ("person",
            FirstResult(classifier->ClassifyText(
                "this afternoon Barack Obama gave a speech at", {15, 27})));
  EXPECT_EQ("email",
            FirstResult(classifier->ClassifyText("you@android.com", {0, 15})));
  EXPECT_EQ("email", FirstResult(classifier->ClassifyText(
                         "Contact me at you@android.com", {14, 29})));

  EXPECT_EQ("url", FirstResult(classifier->ClassifyText(
                       "Visit www.google.com every today!", {6, 20})));

  // The flight pattern allows no whitespace between the letters and the
  // digits, so "LX 37" cannot match it and must not be reported as a flight.
  EXPECT_EQ("other", FirstResult(classifier->ClassifyText("LX 37", {0, 5})));
  EXPECT_EQ("other", FirstResult(classifier->ClassifyText("flight LX 37 abcd",
                                                          {7, 12})));
  EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
                                "cc: 4012 8888 8888 1881", {4, 23})));
  EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
                                "2221 0067 4735 6281", {0, 19})));
  // Luhn check fails.
  EXPECT_EQ("other", FirstResult(classifier->ClassifyText("2221 0067 4735 6282",
                                                          {0, 19})));

  // More lines.
  EXPECT_EQ("url",
            FirstResult(classifier->ClassifyText(
                "this afternoon Barack Obama gave a speech at|Visit "
                "www.google.com every today!|Call me at (800) 123-456 today.",
                {51, 65})));
}
#endif  // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100270
Tony Maka0f598b2018-11-20 20:39:04 +0000271#ifdef TC3_UNILIB_ICU
Tony Mak6c4cc672018-09-17 11:48:50 +0100272TEST_P(AnnotatorTest, SuggestSelectionRegularExpression) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100273 const std::string test_model = ReadFile(GetModelPath() + GetParam());
274 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
275
276 // Add test regex models.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100277 unpacked_model->regex_model->patterns.push_back(MakePattern(
278 "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
279 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
280 unpacked_model->regex_model->patterns.push_back(MakePattern(
281 "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
282 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
283 unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
Tony Mak6c4cc672018-09-17 11:48:50 +0100284 std::unique_ptr<RegexModel_::PatternT> verified_pattern =
285 MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
286 /*enabled_for_classification=*/false,
287 /*enabled_for_selection=*/true,
288 /*enabled_for_annotation=*/false, 1.0);
289 verified_pattern->verification_options.reset(new VerificationOptionsT);
290 verified_pattern->verification_options->verify_luhn_checksum = true;
291 unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100292
293 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000294 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100295
Tony Mak6c4cc672018-09-17 11:48:50 +0100296 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
297 reinterpret_cast<const char*>(builder.GetBufferPointer()),
298 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100299 ASSERT_TRUE(classifier);
300
301 // Check regular expression selection.
302 EXPECT_EQ(classifier->SuggestSelection(
303 "Your flight MA 0123 is delayed by 3 hours.", {12, 14}),
304 std::make_pair(12, 19));
305 EXPECT_EQ(classifier->SuggestSelection(
306 "this afternoon Barack Obama gave a speech at", {15, 21}),
307 std::make_pair(15, 27));
Tony Mak6c4cc672018-09-17 11:48:50 +0100308 EXPECT_EQ(classifier->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}),
309 std::make_pair(4, 23));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100310}
Tony Mak854015a2019-01-16 15:56:48 +0000311
312TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionCustomSelectionBounds) {
313 const std::string test_model = ReadFile(GetModelPath() + GetParam());
314 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
315
316 // Add test regex models.
317 std::unique_ptr<RegexModel_::PatternT> custom_selection_bounds_pattern =
318 MakePattern("date_range",
319 "(?:(?:from )?(\\d{2}\\/\\d{2}\\/\\d{4}) to "
320 "(\\d{2}\\/\\d{2}\\/\\d{4}))|(for ever)",
321 /*enabled_for_classification=*/false,
322 /*enabled_for_selection=*/true,
323 /*enabled_for_annotation=*/false, 1.0);
324 custom_selection_bounds_pattern->capturing_group.emplace_back(
325 new RegexModel_::Pattern_::CapturingGroupT);
326 custom_selection_bounds_pattern->capturing_group.emplace_back(
327 new RegexModel_::Pattern_::CapturingGroupT);
328 custom_selection_bounds_pattern->capturing_group.emplace_back(
329 new RegexModel_::Pattern_::CapturingGroupT);
330 custom_selection_bounds_pattern->capturing_group.emplace_back(
331 new RegexModel_::Pattern_::CapturingGroupT);
332 custom_selection_bounds_pattern->capturing_group[0]->extend_selection = false;
333 custom_selection_bounds_pattern->capturing_group[1]->extend_selection = true;
334 custom_selection_bounds_pattern->capturing_group[2]->extend_selection = true;
335 custom_selection_bounds_pattern->capturing_group[3]->extend_selection = true;
336 unpacked_model->regex_model->patterns.push_back(
337 std::move(custom_selection_bounds_pattern));
338
339 flatbuffers::FlatBufferBuilder builder;
340 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
341
342 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
343 reinterpret_cast<const char*>(builder.GetBufferPointer()),
344 builder.GetSize(), &unilib_, &calendarlib_);
345 ASSERT_TRUE(classifier);
346
347 // Check regular expression selection.
348 EXPECT_EQ(classifier->SuggestSelection("it's from 04/30/1789 to 03/04/1797",
349 {21, 23}),
350 std::make_pair(10, 34));
351 EXPECT_EQ(classifier->SuggestSelection("it takes for ever", {9, 12}),
352 std::make_pair(9, 17));
353}
Tony Maka0f598b2018-11-20 20:39:04 +0000354#endif // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100355
#ifdef TC3_UNILIB_ICU
// Adds selection-only regex patterns where the "flight" pattern has a
// priority_score of 0.5 and checks that, for a selection the flight pattern
// also covers, the longer (26, 62) span is suggested instead of the regex
// match.
TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) {
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add test regex models.
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  unpacked_model->regex_model->patterns.back()->priority_score = 0.5;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  // Pass the fixture's UniLib/CalendarLib like every other test in this file,
  // so the annotator uses the instances initialized for testing.
  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_, &calendarlib_);
  ASSERT_TRUE(classifier);

  // Check conflict resolution.
  EXPECT_EQ(
      classifier->SuggestSelection(
          "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
          {55, 57}),
      std::make_pair(26, 62));
}
#endif  // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100386
#ifdef TC3_UNILIB_ICU
// Adds selection-only regex patterns where the "flight" pattern has a
// priority_score of 1.1 and checks that, for a selection the flight pattern
// covers, the regex match (55, 62) is the suggested span.
TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) {
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add test regex models.
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  unpacked_model->regex_model->patterns.back()->priority_score = 1.1;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  // Pass the fixture's UniLib/CalendarLib like every other test in this file,
  // so the annotator uses the instances initialized for testing.
  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_, &calendarlib_);
  ASSERT_TRUE(classifier);

  // Check conflict resolution.
  EXPECT_EQ(
      classifier->SuggestSelection(
          "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
          {55, 57}),
      std::make_pair(55, 62));
}
#endif  // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100417
#ifdef TC3_UNILIB_ICU
// Adds annotation-only regex patterns ("person", "flight" and a Luhn-verified
// "payment_card") and checks the exact list of spans Annotate returns for a
// multi-line text.
TEST_P(AnnotatorTest, AnnotateRegex) {
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models.
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));

  std::unique_ptr<RegexModel_::PatternT> card_pattern =
      MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
                  /*enabled_for_classification=*/false,
                  /*enabled_for_selection=*/false,
                  /*enabled_for_annotation=*/true, 1.0);
  card_pattern->verification_options.reset(new VerificationOptionsT);
  card_pattern->verification_options->verify_luhn_checksum = true;
  patterns.push_back(std::move(card_pattern));

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const char* const buffer =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      buffer, fbb.GetSize(), &unilib_, &calendarlib_);
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n";
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({IsAnnotatedSpan(6, 18, "person"),
                                IsAnnotatedSpan(28, 55, "address"),
                                IsAnnotatedSpan(79, 91, "phone"),
                                IsAnnotatedSpan(107, 126, "payment_card")}));
}
#endif  // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100456
Tony Mak6c4cc672018-09-17 11:48:50 +0100457TEST_P(AnnotatorTest, PhoneFiltering) {
458 std::unique_ptr<Annotator> classifier =
459 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100460 ASSERT_TRUE(classifier);
461
462 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
463 "phone: (123) 456 789", {7, 20})));
464 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
465 "phone: (123) 456 789,0001112", {7, 25})));
466 EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
467 "phone: (123) 456 789,0001112", {7, 28})));
468}
469
Tony Mak6c4cc672018-09-17 11:48:50 +0100470TEST_P(AnnotatorTest, SuggestSelection) {
471 std::unique_ptr<Annotator> classifier =
472 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100473 ASSERT_TRUE(classifier);
474
475 EXPECT_EQ(classifier->SuggestSelection(
476 "this afternoon Barack Obama gave a speech at", {15, 21}),
477 std::make_pair(15, 21));
478
479 // Try passing whole string.
480 // If more than 1 token is specified, we should return back what entered.
481 EXPECT_EQ(
482 classifier->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
483 std::make_pair(0, 27));
484
485 // Single letter.
486 EXPECT_EQ(classifier->SuggestSelection("a", {0, 1}), std::make_pair(0, 1));
487
488 // Single word.
489 EXPECT_EQ(classifier->SuggestSelection("asdf", {0, 4}), std::make_pair(0, 4));
490
491 EXPECT_EQ(
492 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
493 std::make_pair(11, 23));
494
495 // Unpaired bracket stripping.
496 EXPECT_EQ(
497 classifier->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
498 std::make_pair(11, 25));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100499 EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {11, 15}),
500 std::make_pair(12, 15));
501 EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {11, 16}),
502 std::make_pair(11, 15));
503 EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {11, 16}),
504 std::make_pair(12, 15));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100505
506 // If the resulting selection would be empty, the original span is returned.
507 EXPECT_EQ(classifier->SuggestSelection("call me at )( today", {11, 13}),
508 std::make_pair(11, 13));
509 EXPECT_EQ(classifier->SuggestSelection("call me at ( today", {11, 12}),
510 std::make_pair(11, 12));
511 EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}),
512 std::make_pair(11, 12));
513}
514
Tony Mak6c4cc672018-09-17 11:48:50 +0100515TEST_P(AnnotatorTest, SuggestSelectionDisabledFail) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100516 const std::string test_model = ReadFile(GetModelPath() + GetParam());
517 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
518
519 // Disable the selection model.
520 unpacked_model->selection_model.clear();
521 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
522 unpacked_model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;
523
524 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000525 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100526
Tony Mak6c4cc672018-09-17 11:48:50 +0100527 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
528 reinterpret_cast<const char*>(builder.GetBufferPointer()),
529 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100530 // Selection model needs to be present for annotation.
531 ASSERT_FALSE(classifier);
532}
533
Tony Mak6c4cc672018-09-17 11:48:50 +0100534TEST_P(AnnotatorTest, SuggestSelectionDisabled) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100535 const std::string test_model = ReadFile(GetModelPath() + GetParam());
536 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
537
538 // Disable the selection model.
539 unpacked_model->selection_model.clear();
540 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
541 unpacked_model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
542 unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION;
543
544 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000545 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100546
Tony Mak6c4cc672018-09-17 11:48:50 +0100547 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
548 reinterpret_cast<const char*>(builder.GetBufferPointer()),
549 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100550 ASSERT_TRUE(classifier);
551
552 EXPECT_EQ(
553 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
554 std::make_pair(11, 14));
555
556 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
557 "call me at (800) 123-456 today", {11, 24})));
558
559 EXPECT_THAT(classifier->Annotate("call me at (800) 123-456 today"),
560 IsEmpty());
561}
562
Tony Mak6c4cc672018-09-17 11:48:50 +0100563TEST_P(AnnotatorTest, SuggestSelectionFilteredCollections) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200564 const std::string test_model = ReadFile(GetModelPath() + GetParam());
565
Tony Mak6c4cc672018-09-17 11:48:50 +0100566 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
567 test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200568 ASSERT_TRUE(classifier);
569
570 EXPECT_EQ(
571 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
572 std::make_pair(11, 23));
573
574 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
575 unpacked_model->output_options.reset(new OutputOptionsT);
576
577 // Disable phone selection
578 unpacked_model->output_options->filtered_collections_selection.push_back(
579 "phone");
580 // We need to force this for filtering.
581 unpacked_model->selection_options->always_classify_suggested_selection = true;
582
583 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000584 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200585
Tony Mak6c4cc672018-09-17 11:48:50 +0100586 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200587 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +0100588 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200589 ASSERT_TRUE(classifier);
590
591 EXPECT_EQ(
592 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
593 std::make_pair(11, 14));
594
595 // Address selection should still work.
596 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
597 std::make_pair(0, 27));
598}
599
Tony Mak6c4cc672018-09-17 11:48:50 +0100600TEST_P(AnnotatorTest, SuggestSelectionsAreSymmetric) {
601 std::unique_ptr<Annotator> classifier =
602 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100603 ASSERT_TRUE(classifier);
604
605 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}),
606 std::make_pair(0, 27));
607 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
608 std::make_pair(0, 27));
609 EXPECT_EQ(
610 classifier->SuggestSelection("350 Third Street, Cambridge", {10, 16}),
611 std::make_pair(0, 27));
612 EXPECT_EQ(classifier->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
613 {16, 22}),
614 std::make_pair(6, 33));
615}
616
Tony Mak6c4cc672018-09-17 11:48:50 +0100617TEST_P(AnnotatorTest, SuggestSelectionWithNewLine) {
618 std::unique_ptr<Annotator> classifier =
619 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100620 ASSERT_TRUE(classifier);
621
622 EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}),
623 std::make_pair(4, 16));
624 EXPECT_EQ(classifier->SuggestSelection("857 225 3556\nabc", {0, 3}),
625 std::make_pair(0, 12));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100626
627 SelectionOptions options;
628 EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options),
629 std::make_pair(0, 7));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100630}
631
Tony Mak6c4cc672018-09-17 11:48:50 +0100632TEST_P(AnnotatorTest, SuggestSelectionWithPunctuation) {
633 std::unique_ptr<Annotator> classifier =
634 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100635 ASSERT_TRUE(classifier);
636
637 // From the right.
638 EXPECT_EQ(classifier->SuggestSelection(
639 "this afternoon BarackObama, gave a speech at", {15, 26}),
640 std::make_pair(15, 26));
641
642 // From the right multiple.
643 EXPECT_EQ(classifier->SuggestSelection(
644 "this afternoon BarackObama,.,.,, gave a speech at", {15, 26}),
645 std::make_pair(15, 26));
646
647 // From the left multiple.
648 EXPECT_EQ(classifier->SuggestSelection(
649 "this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}),
650 std::make_pair(21, 32));
651
652 // From both sides.
653 EXPECT_EQ(classifier->SuggestSelection(
654 "this afternoon !BarackObama,- gave a speech at", {16, 27}),
655 std::make_pair(16, 27));
656}
657
Tony Mak6c4cc672018-09-17 11:48:50 +0100658TEST_P(AnnotatorTest, SuggestSelectionNoCrashWithJunk) {
659 std::unique_ptr<Annotator> classifier =
660 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100661 ASSERT_TRUE(classifier);
662
663 // Try passing in bunch of invalid selections.
664 EXPECT_EQ(classifier->SuggestSelection("", {0, 27}), std::make_pair(0, 27));
665 EXPECT_EQ(classifier->SuggestSelection("", {-10, 27}),
666 std::make_pair(-10, 27));
667 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {0, 27}),
668 std::make_pair(0, 27));
669 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-30, 300}),
670 std::make_pair(-30, 300));
671 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-10, -1}),
672 std::make_pair(-10, -1));
673 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}),
674 std::make_pair(100, 17));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200675
676 // Try passing invalid utf8.
677 EXPECT_EQ(classifier->SuggestSelection("\xf0\x9f\x98\x8b\x8b", {-1, -1}),
678 std::make_pair(-1, -1));
679}
680
Tony Mak6c4cc672018-09-17 11:48:50 +0100681TEST_P(AnnotatorTest, SuggestSelectionSelectSpace) {
682 std::unique_ptr<Annotator> classifier =
683 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200684 ASSERT_TRUE(classifier);
685
686 EXPECT_EQ(
687 classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}),
688 std::make_pair(11, 23));
689 EXPECT_EQ(
690 classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}),
691 std::make_pair(10, 11));
692 EXPECT_EQ(
693 classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}),
694 std::make_pair(23, 24));
695 EXPECT_EQ(
696 classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}),
697 std::make_pair(23, 24));
698 EXPECT_EQ(classifier->SuggestSelection("call me at 857 225 3556, today",
699 {14, 17}),
700 std::make_pair(11, 25));
701 EXPECT_EQ(
702 classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}),
703 std::make_pair(11, 23));
704 EXPECT_EQ(
705 classifier->SuggestSelection(
706 "let's meet at 350 Third Street Cambridge and go there", {30, 31}),
707 std::make_pair(14, 40));
708 EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}),
709 std::make_pair(4, 5));
710 EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}),
711 std::make_pair(7, 8));
712
713 // With a punctuation around the selected whitespace.
714 EXPECT_EQ(
715 classifier->SuggestSelection(
716 "let's meet at 350 Third Street, Cambridge and go there", {31, 32}),
717 std::make_pair(14, 41));
718
719 // When all's whitespace, should return the original indices.
720 EXPECT_EQ(classifier->SuggestSelection(" ", {0, 1}),
721 std::make_pair(0, 1));
722 EXPECT_EQ(classifier->SuggestSelection(" ", {0, 3}),
723 std::make_pair(0, 3));
724 EXPECT_EQ(classifier->SuggestSelection(" ", {2, 3}),
725 std::make_pair(2, 3));
726 EXPECT_EQ(classifier->SuggestSelection(" ", {5, 6}),
727 std::make_pair(5, 6));
728}
729
Tony Mak6c4cc672018-09-17 11:48:50 +0100730TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200731 UnicodeText text;
732
733 text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100734 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200735 std::make_pair(3, 4));
736 text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100737 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200738 std::make_pair(3, 4));
739
740 // Nothing on the left.
741 text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100742 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200743 std::make_pair(4, 5));
744 text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100745 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200746 std::make_pair(0, 1));
747
748 // Whitespace only.
749 text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100750 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200751 std::make_pair(2, 3));
752 text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100753 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200754 std::make_pair(4, 5));
755 text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100756 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200757 std::make_pair(0, 1));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100758}
759
Tony Mak6c4cc672018-09-17 11:48:50 +0100760TEST_P(AnnotatorTest, Annotate) {
761 std::unique_ptr<Annotator> classifier =
762 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100763 ASSERT_TRUE(classifier);
764
765 const std::string test_string =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100766 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
767 "number is 853 225 3556";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100768 EXPECT_THAT(classifier->Annotate(test_string),
769 ElementsAreArray({
Lukas Zilkab23e2122018-02-09 10:25:19 +0100770 IsAnnotatedSpan(28, 55, "address"),
771 IsAnnotatedSpan(79, 91, "phone"),
Lukas Zilka21d8c982018-01-24 11:11:20 +0100772 }));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100773
774 AnnotationOptions options;
775 EXPECT_THAT(classifier->Annotate("853 225 3556", options),
776 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
777 EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200778
779 // Try passing invalid utf8.
780 EXPECT_TRUE(
781 classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options)
782 .empty());
Lukas Zilka21d8c982018-01-24 11:11:20 +0100783}
784
Tony Maka0f598b2018-11-20 20:39:04 +0000785
Tony Mak6c4cc672018-09-17 11:48:50 +0100786TEST_P(AnnotatorTest, AnnotateSmallBatches) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100787 const std::string test_model = ReadFile(GetModelPath() + GetParam());
788 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
789
790 // Set the batch size.
791 unpacked_model->selection_options->batch_size = 4;
792 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000793 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100794
Tony Mak6c4cc672018-09-17 11:48:50 +0100795 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
796 reinterpret_cast<const char*>(builder.GetBufferPointer()),
797 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100798 ASSERT_TRUE(classifier);
799
800 const std::string test_string =
801 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
802 "number is 853 225 3556";
803 EXPECT_THAT(classifier->Annotate(test_string),
804 ElementsAreArray({
Lukas Zilkab23e2122018-02-09 10:25:19 +0100805 IsAnnotatedSpan(28, 55, "address"),
806 IsAnnotatedSpan(79, 91, "phone"),
807 }));
808
809 AnnotationOptions options;
810 EXPECT_THAT(classifier->Annotate("853 225 3556", options),
811 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
812 EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
813}
814
#ifdef TC3_UNILIB_ICU
// With min_annotate_confidence set to 2.0, no annotation is expected to be
// returned.
TEST_P(AnnotatorTest, AnnotateFilteringDiscardAll) {
  const std::string model_buffer = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());

  // Install fresh triggering options with the test threshold.
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->min_annotate_confidence = 2.f;  // Discards all.

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* packed = reinterpret_cast<const char*>(fbb.GetBufferPointer());

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed, fbb.GetSize(), &unilib_, &calendarlib_);
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  EXPECT_EQ(annotator->Annotate(text).size(), 0);
}
#endif  // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100839
Tony Mak6c4cc672018-09-17 11:48:50 +0100840TEST_P(AnnotatorTest, AnnotateFilteringKeepAll) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100841 const std::string test_model = ReadFile(GetModelPath() + GetParam());
842 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
843
844 // Add test thresholds.
845 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
846 unpacked_model->triggering_options->min_annotate_confidence =
847 0.f; // Keeps all results.
Lukas Zilkaba849e72018-03-08 14:48:21 +0100848 unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100849 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000850 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100851
Tony Mak6c4cc672018-09-17 11:48:50 +0100852 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
853 reinterpret_cast<const char*>(builder.GetBufferPointer()),
854 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100855 ASSERT_TRUE(classifier);
856
857 const std::string test_string =
858 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
859 "number is 853 225 3556";
Lukas Zilkab23e2122018-02-09 10:25:19 +0100860 EXPECT_EQ(classifier->Annotate(test_string).size(), 2);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100861}
862
Tony Mak6c4cc672018-09-17 11:48:50 +0100863TEST_P(AnnotatorTest, AnnotateDisabled) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100864 const std::string test_model = ReadFile(GetModelPath() + GetParam());
865 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
866
867 // Disable the model for annotation.
868 unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;
869 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000870 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100871
Tony Mak6c4cc672018-09-17 11:48:50 +0100872 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
873 reinterpret_cast<const char*>(builder.GetBufferPointer()),
874 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100875 ASSERT_TRUE(classifier);
876 const std::string test_string =
877 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
878 "number is 853 225 3556";
879 EXPECT_THAT(classifier->Annotate(test_string), IsEmpty());
880}
881
Tony Mak6c4cc672018-09-17 11:48:50 +0100882TEST_P(AnnotatorTest, AnnotateFilteredCollections) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200883 const std::string test_model = ReadFile(GetModelPath() + GetParam());
884
Tony Mak6c4cc672018-09-17 11:48:50 +0100885 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
886 test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200887 ASSERT_TRUE(classifier);
888
889 const std::string test_string =
890 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
891 "number is 853 225 3556";
892
893 EXPECT_THAT(classifier->Annotate(test_string),
894 ElementsAreArray({
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200895 IsAnnotatedSpan(28, 55, "address"),
896 IsAnnotatedSpan(79, 91, "phone"),
897 }));
898
899 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
900 unpacked_model->output_options.reset(new OutputOptionsT);
901
902 // Disable phone annotation
903 unpacked_model->output_options->filtered_collections_annotation.push_back(
904 "phone");
905
906 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000907 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200908
Tony Mak6c4cc672018-09-17 11:48:50 +0100909 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200910 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +0100911 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200912 ASSERT_TRUE(classifier);
913
914 EXPECT_THAT(classifier->Annotate(test_string),
915 ElementsAreArray({
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200916 IsAnnotatedSpan(28, 55, "address"),
917 }));
918}
919
#ifdef TC3_UNILIB_ICU
// Adds a high-scoring regex collection named "suppress" that is also listed
// in filtered_collections_annotation. The expectation is that only the
// address annotation remains afterwards.
TEST_P(AnnotatorTest, AnnotateFilteredCollectionsSuppress) {
  const std::string model_buffer = ReadFile(GetModelPath() + GetParam());
  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // Baseline: the unmodified model annotates both entities.
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      model_buffer.c_str(), model_buffer.size(), &unilib_, &calendarlib_);
  ASSERT_TRUE(annotator);

  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());
  model->output_options.reset(new OutputOptionsT);

  // The custom "suppress" annotator is meant to win against the phone
  // classification and is then filtered from the output.
  model->output_options->filtered_collections_annotation.push_back(
      "suppress");
  model->regex_model->patterns.push_back(MakePattern(
      "suppress", "(\\d{3} ?\\d{4})",
      /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* packed = reinterpret_cast<const char*>(fbb.GetBufferPointer());

  annotator = Annotator::FromUnownedBuffer(packed, fbb.GetSize(), &unilib_,
                                           &calendarlib_);
  ASSERT_TRUE(annotator);

  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
              }));
}
#endif  // TC3_UNILIB_ICU
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200965
#ifdef TC3_CALENDAR_ICU
// Classifies several date/time strings and checks the collection, the parsed
// UTC timestamp and the granularity for each interpretation.
TEST_P(AnnotatorTest, ClassifyTextDate) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetModelPath() + GetParam());
  // Must be fatal: the classifier is dereferenced unconditionally below.
  ASSERT_TRUE(classifier);

  std::vector<ClassificationResult> result;
  ClassificationOptions options;

  // Midnight of Jan 1, 2017 in Zurich (UTC+1).
  options.reference_timezone = "Europe/Zurich";
  result = classifier->ClassifyText("january 1, 2017", {0, 15}, options);
  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);

  // Midnight of Mar 1, 2017 in Los Angeles (UTC-8).
  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("march 1, 2017", {0, 13}, options);
  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1488355200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);

  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("2018/01/01 10:30:20", {0, 19}, options);
  ASSERT_EQ(result.size(), 2);  // Has 2 interpretations - a.m. or p.m.
  // 10:30:20 a.m. local time.
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514831420000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_SECOND);
  // 10:30:20 p.m. local time, i.e. 12 hours later.
  EXPECT_THAT(result[1].collection, "date");
  EXPECT_EQ(result[1].datetime_parse_result.time_ms_utc, 1514874620000);
  EXPECT_EQ(result[1].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_SECOND);

  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("2018/01/01 22:00", {0, 16}, options);
  ASSERT_EQ(result.size(), 1);  // Has only 1 interpretation - 10 p.m.
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514872800000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_MINUTE);

  // Date on another line.
  options.reference_timezone = "Europe/Zurich";
  result = classifier->ClassifyText(
      "hello world this is the first line\n"
      "january 1, 2017",
      {35, 50}, options);
  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
}
#endif  // TC3_CALENDAR_ICU
Lukas Zilkaba849e72018-03-08 14:48:21 +01001028
#ifdef TC3_CALENDAR_ICU
// The ambiguous date "03.05.1970" is expected to be read month-first for the
// en-US locale and day-first for the de locale.
TEST_P(AnnotatorTest, ClassifyTextDatePriorities) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetModelPath() + GetParam());
  // Must be fatal: the classifier is dereferenced unconditionally below.
  ASSERT_TRUE(classifier);

  std::vector<ClassificationResult> result;
  ClassificationOptions options;

  // en-US: March 5, 1970, midnight in Zurich.
  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  result = classifier->ClassifyText("03.05.1970", {0, 10}, options);

  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 5439600000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);

  // de: May 3, 1970, midnight in Zurich.
  options.reference_timezone = "Europe/Zurich";
  options.locales = "de";
  result = classifier->ClassifyText("03.05.1970", {0, 10}, options);

  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 10537200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
}
#endif  // TC3_CALENDAR_ICU
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001061
#ifdef TC3_CALENDAR_ICU
// With every datetime pattern restricted to annotation and classification,
// dates are still classified and annotated, but SuggestSelection is expected
// to return the clicked span unchanged.
TEST_P(AnnotatorTest, SuggestTextDateDisabled) {
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Disable the patterns for selection.
  for (auto& pattern : unpacked_model->datetime_model->patterns) {
    pattern->enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION;
  }
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_, &calendarlib_);
  ASSERT_TRUE(classifier);

  // Classification still recognizes the date.
  EXPECT_EQ("date",
            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
  // Selection does not expand beyond the clicked word.
  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
            std::make_pair(0, 7));
  // Annotation still covers the whole date.
  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
}
#endif  // TC3_CALENDAR_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +01001087
Tony Mak6c4cc672018-09-17 11:48:50 +01001088class TestingAnnotator : public Annotator {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001089 public:
Tony Mak6c4cc672018-09-17 11:48:50 +01001090 TestingAnnotator(const std::string& model, const UniLib* unilib,
1091 const CalendarLib* calendarlib)
Tony Mak854015a2019-01-16 15:56:48 +00001092 : Annotator(libtextclassifier3::ViewModel(model.data(), model.size()),
1093 unilib, calendarlib) {}
Lukas Zilkab23e2122018-02-09 10:25:19 +01001094
Tony Mak6c4cc672018-09-17 11:48:50 +01001095 using Annotator::ResolveConflicts;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001096};
1097
1098AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
1099 const std::string& collection,
1100 const float score) {
1101 AnnotatedSpan result;
1102 result.span = span;
1103 result.classification.push_back({collection, score});
1104 return result;
1105}
1106
Tony Mak6c4cc672018-09-17 11:48:50 +01001107TEST_F(AnnotatorTest, ResolveConflictsTrivial) {
1108 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001109
1110 std::vector<AnnotatedSpan> candidates{
1111 {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
1112
1113 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001114 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001115 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001116 EXPECT_THAT(chosen, ElementsAreArray({0}));
1117}
1118
Tony Mak6c4cc672018-09-17 11:48:50 +01001119TEST_F(AnnotatorTest, ResolveConflictsSequence) {
1120 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001121
1122 std::vector<AnnotatedSpan> candidates{{
1123 MakeAnnotatedSpan({0, 1}, "phone", 1.0),
1124 MakeAnnotatedSpan({1, 2}, "phone", 1.0),
1125 MakeAnnotatedSpan({2, 3}, "phone", 1.0),
1126 MakeAnnotatedSpan({3, 4}, "phone", 1.0),
1127 MakeAnnotatedSpan({4, 5}, "phone", 1.0),
1128 }};
1129
1130 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001131 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001132 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001133 EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
1134}
1135
Tony Mak6c4cc672018-09-17 11:48:50 +01001136TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) {
1137 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001138
1139 std::vector<AnnotatedSpan> candidates{{
1140 MakeAnnotatedSpan({0, 3}, "phone", 1.0),
1141 MakeAnnotatedSpan({1, 5}, "phone", 0.5), // Looser!
1142 MakeAnnotatedSpan({3, 7}, "phone", 1.0),
1143 }};
1144
1145 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001146 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001147 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001148 EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
1149}
1150
Tony Mak6c4cc672018-09-17 11:48:50 +01001151TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) {
1152 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001153
1154 std::vector<AnnotatedSpan> candidates{{
1155 MakeAnnotatedSpan({0, 3}, "phone", 0.5), // Looser!
1156 MakeAnnotatedSpan({1, 5}, "phone", 1.0),
1157 MakeAnnotatedSpan({3, 7}, "phone", 0.6), // Looser!
1158 }};
1159
1160 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001161 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001162 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001163 EXPECT_THAT(chosen, ElementsAreArray({1}));
1164}
1165
Tony Mak6c4cc672018-09-17 11:48:50 +01001166TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) {
1167 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001168
1169 std::vector<AnnotatedSpan> candidates{{
1170 MakeAnnotatedSpan({0, 3}, "phone", 0.5),
1171 MakeAnnotatedSpan({1, 5}, "other", 1.0), // Looser!
1172 MakeAnnotatedSpan({3, 7}, "phone", 0.6),
1173 MakeAnnotatedSpan({8, 12}, "phone", 0.6), // Looser!
1174 MakeAnnotatedSpan({11, 15}, "phone", 0.9),
1175 }};
1176
1177 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001178 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001179 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001180 EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
1181}
Lukas Zilka21d8c982018-01-24 11:11:20 +01001182
#ifdef TC3_UNILIB_ICU
// Embeds one entity of each kind between two runs of 50000 spaces and checks
// that annotation, selection and classification all still find it.
TEST_P(AnnotatorTest, LongInput) {
  std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
  ASSERT_TRUE(annotator);

  // (collection, sample text) pairs to embed in the long input.
  const std::vector<std::pair<std::string, std::string>> samples{
      {"address", "350 Third Street, Cambridge"},
      {"phone", "123 456-7890"},
      {"url", "www.google.com"},
      {"email", "someone@gmail.com"},
      {"flight", "LX 38"},
      {"date", "September 1, 2018"}};

  for (const auto& sample : samples) {
    const std::string& collection = sample.first;
    const std::string& value = sample.second;

    const std::string input_100k =
        std::string(50000, ' ') + value + std::string(50000, ' ');
    const int value_length = value.size();

    EXPECT_THAT(annotator->Annotate(input_100k),
                ElementsAreArray({IsAnnotatedSpan(50000, 50000 + value_length,
                                                  collection)}));
    EXPECT_EQ(annotator->SuggestSelection(input_100k, {50000, 50001}),
              std::make_pair(50000, 50000 + value_length));
    EXPECT_EQ(collection,
              FirstResult(annotator->ClassifyText(
                  input_100k, {50000, 50000 + value_length})));
  }
}
#endif  // TC3_UNILIB_ICU
Lukas Zilkadf710db2018-02-27 12:44:09 +01001213
#ifdef TC3_UNILIB_ICU
// Coarse smoke test: runs all three entry points on a long input without
// checking the results. It exists only to make sure execution finishes in a
// reasonable amount of time.
TEST_P(AnnotatorTest, LongInputNoResultCheck) {
  std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
  ASSERT_TRUE(annotator);

  const std::vector<std::string> values{
      "http://www.aaaaaaaaaaaaaaaaaaaa.com "};

  for (const std::string& value : values) {
    const std::string input_100k =
        std::string(50000, ' ') + value + std::string(50000, ' ');
    const int value_length = value.size();

    // Results are intentionally ignored.
    annotator->Annotate(input_100k);
    annotator->SuggestSelection(input_100k, {50000, 50001});
    annotator->ClassifyText(input_100k, {50000, 50000 + value_length});
  }
}
#endif  // TC3_UNILIB_ICU
Lukas Zilkaba849e72018-03-08 14:48:21 +01001234
Tony Maka0f598b2018-11-20 20:39:04 +00001235#ifdef TC3_UNILIB_ICU
Tony Mak6c4cc672018-09-17 11:48:50 +01001236TEST_P(AnnotatorTest, MaxTokenLength) {
Lukas Zilka434442d2018-04-25 11:38:51 +02001237 const std::string test_model = ReadFile(GetModelPath() + GetParam());
1238 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
1239
Tony Mak6c4cc672018-09-17 11:48:50 +01001240 std::unique_ptr<Annotator> classifier;
Lukas Zilka434442d2018-04-25 11:38:51 +02001241
1242 // With unrestricted number of tokens should behave normally.
1243 unpacked_model->classification_options->max_num_tokens = -1;
1244
1245 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +00001246 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Tony Mak6c4cc672018-09-17 11:48:50 +01001247 classifier = Annotator::FromUnownedBuffer(
Lukas Zilka434442d2018-04-25 11:38:51 +02001248 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +01001249 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilka434442d2018-04-25 11:38:51 +02001250 ASSERT_TRUE(classifier);
1251
1252 EXPECT_EQ(FirstResult(classifier->ClassifyText(
1253 "I live at 350 Third Street, Cambridge.", {10, 37})),
1254 "address");
1255
1256 // Raise the maximum number of tokens to suppress the classification.
1257 unpacked_model->classification_options->max_num_tokens = 3;
1258
1259 flatbuffers::FlatBufferBuilder builder2;
Tony Mak51a9e542018-11-02 13:36:22 +00001260 FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
Tony Mak6c4cc672018-09-17 11:48:50 +01001261 classifier = Annotator::FromUnownedBuffer(
Lukas Zilka434442d2018-04-25 11:38:51 +02001262 reinterpret_cast<const char*>(builder2.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +01001263 builder2.GetSize(), &unilib_, &calendarlib_);
Lukas Zilka434442d2018-04-25 11:38:51 +02001264 ASSERT_TRUE(classifier);
1265
1266 EXPECT_EQ(FirstResult(classifier->ClassifyText(
1267 "I live at 350 Third Street, Cambridge.", {10, 37})),
1268 "other");
1269}
Tony Maka0f598b2018-11-20 20:39:04 +00001270#endif // TC3_UNILIB_ICU
Lukas Zilka434442d2018-04-25 11:38:51 +02001271
Tony Maka0f598b2018-11-20 20:39:04 +00001272#ifdef TC3_UNILIB_ICU
Tony Mak6c4cc672018-09-17 11:48:50 +01001273TEST_P(AnnotatorTest, MinAddressTokenLength) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001274 const std::string test_model = ReadFile(GetModelPath() + GetParam());
1275 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
1276
Tony Mak6c4cc672018-09-17 11:48:50 +01001277 std::unique_ptr<Annotator> classifier;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001278
1279 // With unrestricted number of address tokens should behave normally.
1280 unpacked_model->classification_options->address_min_num_tokens = 0;
1281
1282 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +00001283 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Tony Mak6c4cc672018-09-17 11:48:50 +01001284 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001285 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +01001286 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001287 ASSERT_TRUE(classifier);
1288
1289 EXPECT_EQ(FirstResult(classifier->ClassifyText(
1290 "I live at 350 Third Street, Cambridge.", {10, 37})),
1291 "address");
1292
1293 // Raise number of address tokens to suppress the address classification.
1294 unpacked_model->classification_options->address_min_num_tokens = 5;
1295
1296 flatbuffers::FlatBufferBuilder builder2;
Tony Mak51a9e542018-11-02 13:36:22 +00001297 FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
Tony Mak6c4cc672018-09-17 11:48:50 +01001298 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001299 reinterpret_cast<const char*>(builder2.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +01001300 builder2.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001301 ASSERT_TRUE(classifier);
1302
1303 EXPECT_EQ(FirstResult(classifier->ClassifyText(
1304 "I live at 350 Third Street, Cambridge.", {10, 37})),
1305 "other");
1306}
Tony Maka0f598b2018-11-20 20:39:04 +00001307#endif // TC3_UNILIB_ICU
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001308
Tony Mak854015a2019-01-16 15:56:48 +00001309TEST_F(AnnotatorTest, VisitAnnotatorModel) {
1310 EXPECT_TRUE(VisitAnnotatorModel<bool>(GetModelPath() + "test_model.fb",
1311 [](const Model* model) {
1312 if (model == nullptr) {
1313 return false;
1314 }
1315 return true;
1316 }));
1317 EXPECT_FALSE(VisitAnnotatorModel<bool>(
1318 GetModelPath() + "non_existing_model.fb", [](const Model* model) {
1319 if (model == nullptr) {
1320 return false;
1321 }
1322 return true;
1323 }));
1324}
1325
Lukas Zilka21d8c982018-01-24 11:11:20 +01001326} // namespace
Tony Mak6c4cc672018-09-17 11:48:50 +01001327} // namespace libtextclassifier3