blob: e9084842f300bde2811169b8b0ecf37a9b619fe9 [file] [log] [blame]
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
Tony Mak6c4cc672018-09-17 11:48:50 +010017#include "annotator/annotator.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010018
19#include <fstream>
20#include <iostream>
21#include <memory>
22#include <string>
23
Tony Mak6c4cc672018-09-17 11:48:50 +010024#include "annotator/model_generated.h"
25#include "annotator/types-test-util.h"
Tony Mak378c1f52019-03-04 15:58:11 +000026#include "utils/testing/annotator.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010027#include "gmock/gmock.h"
28#include "gtest/gtest.h"
29
Tony Mak6c4cc672018-09-17 11:48:50 +010030namespace libtextclassifier3 {
Lukas Zilka21d8c982018-01-24 11:11:20 +010031namespace {
32
Tony Mak378c1f52019-03-04 15:58:11 +000033using testing::ElementsAre;
Lukas Zilka21d8c982018-01-24 11:11:20 +010034using testing::ElementsAreArray;
Lukas Zilkaba849e72018-03-08 14:48:21 +010035using testing::IsEmpty;
Lukas Zilka21d8c982018-01-24 11:11:20 +010036using testing::Pair;
Lukas Zilkab23e2122018-02-09 10:25:19 +010037using testing::Values;
Lukas Zilka21d8c982018-01-24 11:11:20 +010038
Lukas Zilkab23e2122018-02-09 10:25:19 +010039std::string FirstResult(const std::vector<ClassificationResult>& results) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010040 if (results.empty()) {
41 return "<INVALID RESULTS>";
42 }
Lukas Zilkab23e2122018-02-09 10:25:19 +010043 return results[0].collection;
Lukas Zilka21d8c982018-01-24 11:11:20 +010044}
45
46MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") {
47 return testing::Value(arg.span, Pair(start, end)) &&
48 testing::Value(FirstResult(arg.classification), best_class);
49}
50
Tony Mak378c1f52019-03-04 15:58:11 +000051MATCHER_P2(IsDateResult, time_ms_utc, granularity, "") {
52 return testing::Value(arg.collection, "date") &&
53 testing::Value(arg.datetime_parse_result.time_ms_utc, time_ms_utc) &&
54 testing::Value(arg.datetime_parse_result.granularity, granularity);
55}
56
57MATCHER_P2(IsDatetimeResult, time_ms_utc, granularity, "") {
58 return testing::Value(arg.collection, "datetime") &&
59 testing::Value(arg.datetime_parse_result.time_ms_utc, time_ms_utc) &&
60 testing::Value(arg.datetime_parse_result.granularity, granularity);
61}
62
// Reads the whole file at `file_name` and returns its bytes as a string.
//
// The stream is opened in binary mode because the tests use this helper to
// load serialized model buffers; text mode may translate line endings on some
// platforms and would corrupt such data. If the file cannot be opened, the
// returned string is empty.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::in | std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream),
                     std::istreambuf_iterator<char>());
}
67
68std::string GetModelPath() {
Tony Maka0f598b2018-11-20 20:39:04 +000069 return TC3_TEST_DATA_DIR;
Lukas Zilka21d8c982018-01-24 11:11:20 +010070}
71
Tony Mak378c1f52019-03-04 15:58:11 +000072std::string GetTestModelPath() { return GetModelPath() + "test_model.fb"; }
73
Tony Makd9446602019-02-20 18:25:39 +000074// Create fake entity data schema meta data.
75void AddTestEntitySchemaData(ModelT* unpacked_model) {
76 // Cannot use object oriented API here as that is not available for the
77 // reflection schema.
78 flatbuffers::FlatBufferBuilder schema_builder;
79 std::vector<flatbuffers::Offset<reflection::Field>> fields = {
80 reflection::CreateField(
81 schema_builder,
82 /*name=*/schema_builder.CreateString("first_name"),
83 /*type=*/
84 reflection::CreateType(schema_builder,
85 /*base_type=*/reflection::String),
86 /*id=*/0,
87 /*offset=*/4),
88 reflection::CreateField(
89 schema_builder,
90 /*name=*/schema_builder.CreateString("is_alive"),
91 /*type=*/
92 reflection::CreateType(schema_builder,
93 /*base_type=*/reflection::Bool),
94 /*id=*/1,
95 /*offset=*/6),
96 reflection::CreateField(
97 schema_builder,
98 /*name=*/schema_builder.CreateString("last_name"),
99 /*type=*/
100 reflection::CreateType(schema_builder,
101 /*base_type=*/reflection::String),
102 /*id=*/2,
103 /*offset=*/8),
Tony Mak378c1f52019-03-04 15:58:11 +0000104 reflection::CreateField(
105 schema_builder,
106 /*name=*/schema_builder.CreateString("age"),
107 /*type=*/
108 reflection::CreateType(schema_builder,
109 /*base_type=*/reflection::Int),
110 /*id=*/3,
111 /*offset=*/10),
Tony Makd9446602019-02-20 18:25:39 +0000112 };
113 std::vector<flatbuffers::Offset<reflection::Enum>> enums;
114 std::vector<flatbuffers::Offset<reflection::Object>> objects = {
115 reflection::CreateObject(
116 schema_builder,
117 /*name=*/schema_builder.CreateString("EntityData"),
118 /*fields=*/
119 schema_builder.CreateVectorOfSortedTables(&fields))};
120 schema_builder.Finish(reflection::CreateSchema(
121 schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
122 schema_builder.CreateVectorOfSortedTables(&enums),
123 /*(unused) file_ident=*/0,
124 /*(unused) file_ext=*/0,
125 /*root_table*/ objects[0]));
126
127 unpacked_model->entity_data_schema.assign(
128 schema_builder.GetBufferPointer(),
129 schema_builder.GetBufferPointer() + schema_builder.GetSize());
130}
131
Tony Mak6c4cc672018-09-17 11:48:50 +0100132class AnnotatorTest : public ::testing::TestWithParam<const char*> {
133 protected:
134 AnnotatorTest()
135 : INIT_UNILIB_FOR_TESTING(unilib_),
136 INIT_CALENDARLIB_FOR_TESTING(calendarlib_) {}
137 UniLib unilib_;
138 CalendarLib calendarlib_;
139};
140
141TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) {
142 std::unique_ptr<Annotator> classifier = Annotator::FromPath(
143 GetModelPath() + "wrong_embeddings.fb", &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100144 EXPECT_FALSE(classifier);
145}
146
Tony Mak378c1f52019-03-04 15:58:11 +0000147#ifdef TC3_UNILIB_ICU
148TEST_F(AnnotatorTest, ClassifyText) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100149 std::unique_ptr<Annotator> classifier =
Tony Mak378c1f52019-03-04 15:58:11 +0000150 Annotator::FromPath(GetTestModelPath(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100151 ASSERT_TRUE(classifier);
152
153 EXPECT_EQ("other",
154 FirstResult(classifier->ClassifyText(
155 "this afternoon Barack Obama gave a speech at", {15, 27})));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100156 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
157 "Call me at (800) 123-456 today", {11, 24})));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100158
159 // More lines.
160 EXPECT_EQ("other",
161 FirstResult(classifier->ClassifyText(
162 "this afternoon Barack Obama gave a speech at|Visit "
163 "www.google.com every today!|Call me at (800) 123-456 today.",
164 {15, 27})));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100165 EXPECT_EQ("phone",
166 FirstResult(classifier->ClassifyText(
167 "this afternoon Barack Obama gave a speech at|Visit "
168 "www.google.com every today!|Call me at (800) 123-456 today.",
169 {90, 103})));
170
171 // Single word.
172 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
173 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100174
Tony Mak378c1f52019-03-04 15:58:11 +0000175 // Junk. These should not crash the test.
176 classifier->ClassifyText("", {0, 0});
177 classifier->ClassifyText("asdf", {0, 0});
178 classifier->ClassifyText("a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5});
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200179 // Test invalid utf8 input.
180 EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
181 "\xf0\x9f\x98\x8b\x8b", {0, 0})));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100182}
Tony Mak378c1f52019-03-04 15:58:11 +0000183#endif
Lukas Zilka21d8c982018-01-24 11:11:20 +0100184
Tony Mak378c1f52019-03-04 15:58:11 +0000185#ifdef TC3_UNILIB_ICU
186TEST_F(AnnotatorTest, ClassifyTextLocalesAndDictionary) {
187 std::unique_ptr<Annotator> classifier =
188 Annotator::FromPath(GetTestModelPath(), &unilib_, &calendarlib_);
189 ASSERT_TRUE(classifier);
190
191 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("isotope", {0, 6})));
192
193 ClassificationOptions classification_options;
194 classification_options.detected_text_language_tags = "en";
195 EXPECT_EQ("dictionary", FirstResult(classifier->ClassifyText(
196 "isotope", {0, 6}, classification_options)));
197
198 classification_options.detected_text_language_tags = "uz";
199 EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
200 "isotope", {0, 6}, classification_options)));
201}
202#endif
203
204#ifdef TC3_UNILIB_ICU
205TEST_F(AnnotatorTest, ClassifyTextDisabledFail) {
206 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100207 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
208
209 unpacked_model->classification_model.clear();
210 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
211 unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION;
212
213 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000214 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100215
Tony Mak6c4cc672018-09-17 11:48:50 +0100216 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
217 reinterpret_cast<const char*>(builder.GetBufferPointer()),
218 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100219
220 // The classification model is still needed for selection scores.
221 ASSERT_FALSE(classifier);
222}
Tony Mak378c1f52019-03-04 15:58:11 +0000223#endif
Lukas Zilkaba849e72018-03-08 14:48:21 +0100224
Tony Mak378c1f52019-03-04 15:58:11 +0000225#ifdef TC3_UNILIB_ICU
226TEST_F(AnnotatorTest, ClassifyTextDisabled) {
227 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100228 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
229
Tony Mak378c1f52019-03-04 15:58:11 +0000230 unpacked_model->enabled_modes = ModeFlag_ANNOTATION_AND_SELECTION;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100231
232 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000233 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100234
Tony Mak6c4cc672018-09-17 11:48:50 +0100235 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
236 reinterpret_cast<const char*>(builder.GetBufferPointer()),
237 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100238 ASSERT_TRUE(classifier);
239
240 EXPECT_THAT(
241 classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}),
242 IsEmpty());
243}
Tony Mak378c1f52019-03-04 15:58:11 +0000244#endif
Lukas Zilkaba849e72018-03-08 14:48:21 +0100245
Tony Mak378c1f52019-03-04 15:58:11 +0000246#ifdef TC3_UNILIB_ICU
247TEST_F(AnnotatorTest, ClassifyTextFilteredCollections) {
248 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200249
Tony Mak6c4cc672018-09-17 11:48:50 +0100250 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
251 test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200252 ASSERT_TRUE(classifier);
253
254 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
255 "Call me at (800) 123-456 today", {11, 24})));
256
257 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
258 unpacked_model->output_options.reset(new OutputOptionsT);
259
260 // Disable phone classification
261 unpacked_model->output_options->filtered_collections_classification.push_back(
262 "phone");
263
264 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000265 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200266
Tony Mak6c4cc672018-09-17 11:48:50 +0100267 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200268 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +0100269 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200270 ASSERT_TRUE(classifier);
271
272 EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
273 "Call me at (800) 123-456 today", {11, 24})));
274
275 // Check that the address classification still passes.
276 EXPECT_EQ("address", FirstResult(classifier->ClassifyText(
277 "350 Third Street, Cambridge", {0, 27})));
278}
Tony Mak378c1f52019-03-04 15:58:11 +0000279#endif
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200280
Tony Mak378c1f52019-03-04 15:58:11 +0000281#ifdef TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100282std::unique_ptr<RegexModel_::PatternT> MakePattern(
283 const std::string& collection_name, const std::string& pattern,
284 const bool enabled_for_classification, const bool enabled_for_selection,
Tony Mak378c1f52019-03-04 15:58:11 +0000285 const bool enabled_for_annotation, const float score,
286 const float priority_score) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100287 std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
288 result->collection_name = collection_name;
289 result->pattern = pattern;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100290 // We cannot directly operate with |= on the flag, so use an int here.
291 int enabled_modes = ModeFlag_NONE;
292 if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
293 if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
294 if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
295 result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100296 result->target_classification_score = score;
Tony Mak378c1f52019-03-04 15:58:11 +0000297 result->priority_score = priority_score;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100298 return result;
299}
300
Tony Mak378c1f52019-03-04 15:58:11 +0000301// Shortcut function that doesn't need to specify the priority score.
302std::unique_ptr<RegexModel_::PatternT> MakePattern(
303 const std::string& collection_name, const std::string& pattern,
304 const bool enabled_for_classification, const bool enabled_for_selection,
305 const bool enabled_for_annotation, const float score) {
306 return MakePattern(collection_name, pattern, enabled_for_classification,
307 enabled_for_selection, enabled_for_annotation,
308 /*score=*/score,
309 /*priority_score=*/score);
310}
311#endif // TC3_UNILIB_ICU
312
Tony Maka0f598b2018-11-20 20:39:04 +0000313#ifdef TC3_UNILIB_ICU
Tony Mak378c1f52019-03-04 15:58:11 +0000314TEST_F(AnnotatorTest, ClassifyTextRegularExpression) {
315 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100316 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
317
318 // Add test regex models.
319 unpacked_model->regex_model->patterns.push_back(MakePattern(
320 "person", "Barack Obama", /*enabled_for_classification=*/true,
321 /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
322 unpacked_model->regex_model->patterns.push_back(MakePattern(
323 "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
324 /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));
Tony Mak6c4cc672018-09-17 11:48:50 +0100325 std::unique_ptr<RegexModel_::PatternT> verified_pattern =
326 MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}",
327 /*enabled_for_classification=*/true,
328 /*enabled_for_selection=*/false,
329 /*enabled_for_annotation=*/false, 1.0);
330 verified_pattern->verification_options.reset(new VerificationOptionsT);
331 verified_pattern->verification_options->verify_luhn_checksum = true;
332 unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100333
334 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000335 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100336
Tony Mak6c4cc672018-09-17 11:48:50 +0100337 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
338 reinterpret_cast<const char*>(builder.GetBufferPointer()),
339 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100340 ASSERT_TRUE(classifier);
341
342 EXPECT_EQ("flight",
343 FirstResult(classifier->ClassifyText(
344 "Your flight LX373 is delayed by 3 hours.", {12, 17})));
345 EXPECT_EQ("person",
346 FirstResult(classifier->ClassifyText(
347 "this afternoon Barack Obama gave a speech at", {15, 27})));
348 EXPECT_EQ("email",
349 FirstResult(classifier->ClassifyText("you@android.com", {0, 15})));
350 EXPECT_EQ("email", FirstResult(classifier->ClassifyText(
351 "Contact me at you@android.com", {14, 29})));
352
353 EXPECT_EQ("url", FirstResult(classifier->ClassifyText(
354 "Visit www.google.com every today!", {6, 20})));
355
356 EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("LX 37", {0, 5})));
357 EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd",
358 {7, 12})));
Tony Mak6c4cc672018-09-17 11:48:50 +0100359 EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
360 "cc: 4012 8888 8888 1881", {4, 23})));
361 EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
362 "2221 0067 4735 6281", {0, 19})));
363 // Luhn check fails.
364 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("2221 0067 4735 6282",
365 {0, 19})));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100366
367 // More lines.
368 EXPECT_EQ("url",
369 FirstResult(classifier->ClassifyText(
370 "this afternoon Barack Obama gave a speech at|Visit "
371 "www.google.com every today!|Call me at (800) 123-456 today.",
372 {51, 65})));
373}
Tony Mak378c1f52019-03-04 15:58:11 +0000374#endif // TC3_UNILIB_ICU
Tony Makd9446602019-02-20 18:25:39 +0000375
Tony Mak378c1f52019-03-04 15:58:11 +0000376#ifdef TC3_UNILIB_ICU
377TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityData) {
378 const std::string test_model = ReadFile(GetTestModelPath());
Tony Makd9446602019-02-20 18:25:39 +0000379 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
380
381 // Add fake entity schema metadata.
382 AddTestEntitySchemaData(unpacked_model.get());
383
384 // Add test regex models.
385 unpacked_model->regex_model->patterns.push_back(MakePattern(
Tony Mak378c1f52019-03-04 15:58:11 +0000386 "person_with_age", "(Barack) (Obama) is (\\d+) years old",
387 /*enabled_for_classification=*/true,
Tony Makd9446602019-02-20 18:25:39 +0000388 /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
389
390 // Use meta data to generate custom serialized entity data.
391 ReflectiveFlatbufferBuilder entity_data_builder(
392 flatbuffers::GetRoot<reflection::Schema>(
393 unpacked_model->entity_data_schema.data()));
394 std::unique_ptr<ReflectiveFlatbuffer> entity_data =
395 entity_data_builder.NewRoot();
396 entity_data->Set("is_alive", true);
397
398 RegexModel_::PatternT* pattern =
399 unpacked_model->regex_model->patterns.back().get();
400 pattern->serialized_entity_data = entity_data->Serialize();
401 pattern->capturing_group.emplace_back(
402 new RegexModel_::Pattern_::CapturingGroupT);
403 pattern->capturing_group.emplace_back(
404 new RegexModel_::Pattern_::CapturingGroupT);
405 pattern->capturing_group.emplace_back(
406 new RegexModel_::Pattern_::CapturingGroupT);
Tony Mak378c1f52019-03-04 15:58:11 +0000407 pattern->capturing_group.emplace_back(
408 new RegexModel_::Pattern_::CapturingGroupT);
Tony Makd9446602019-02-20 18:25:39 +0000409 // Group 0 is the full match, capturing groups starting at 1.
410 pattern->capturing_group[1]->entity_field_path.reset(
411 new FlatbufferFieldPathT);
412 pattern->capturing_group[1]->entity_field_path->field.emplace_back(
413 new FlatbufferFieldT);
414 pattern->capturing_group[1]->entity_field_path->field.back()->field_name =
415 "first_name";
416 pattern->capturing_group[2]->entity_field_path.reset(
417 new FlatbufferFieldPathT);
418 pattern->capturing_group[2]->entity_field_path->field.emplace_back(
419 new FlatbufferFieldT);
420 pattern->capturing_group[2]->entity_field_path->field.back()->field_name =
421 "last_name";
Tony Mak378c1f52019-03-04 15:58:11 +0000422 pattern->capturing_group[3]->entity_field_path.reset(
423 new FlatbufferFieldPathT);
424 pattern->capturing_group[3]->entity_field_path->field.emplace_back(
425 new FlatbufferFieldT);
426 pattern->capturing_group[3]->entity_field_path->field.back()->field_name =
427 "age";
Tony Makd9446602019-02-20 18:25:39 +0000428
429 flatbuffers::FlatBufferBuilder builder;
430 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
431
432 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
433 reinterpret_cast<const char*>(builder.GetBufferPointer()),
434 builder.GetSize(), &unilib_, &calendarlib_);
435 ASSERT_TRUE(classifier);
436
Tony Mak378c1f52019-03-04 15:58:11 +0000437 auto classifications =
438 classifier->ClassifyText("Barack Obama is 57 years old", {0, 28});
Tony Makd9446602019-02-20 18:25:39 +0000439 EXPECT_EQ(1, classifications.size());
Tony Mak378c1f52019-03-04 15:58:11 +0000440 EXPECT_EQ("person_with_age", classifications[0].collection);
Tony Makd9446602019-02-20 18:25:39 +0000441
442 // Check entity data.
443 const flatbuffers::Table* entity =
444 flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
445 classifications[0].serialized_entity_data.data()));
446 EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
447 "Barack");
448 EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
449 "Obama");
Tony Mak378c1f52019-03-04 15:58:11 +0000450 EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);
Tony Makd9446602019-02-20 18:25:39 +0000451 EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
452}
Tony Maka0f598b2018-11-20 20:39:04 +0000453#endif // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100454
Tony Maka0f598b2018-11-20 20:39:04 +0000455#ifdef TC3_UNILIB_ICU
Tony Mak378c1f52019-03-04 15:58:11 +0000456TEST_F(AnnotatorTest, ClassifyTextPriorityResolution) {
457 const std::string test_model = ReadFile(GetTestModelPath());
458 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
459 TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get()));
460 // Add test regex models.
461 unpacked_model->regex_model->patterns.clear();
462 unpacked_model->regex_model->patterns.push_back(MakePattern(
463 "flight1", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
464 /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
465 /*score=*/1.0, /*priority_score=*/1.0));
466 unpacked_model->regex_model->patterns.push_back(MakePattern(
467 "flight2", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
468 /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
469 /*score=*/1.0, /*priority_score=*/0.0));
470
471 {
472 flatbuffers::FlatBufferBuilder builder;
473 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
474 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
475 reinterpret_cast<const char*>(builder.GetBufferPointer()),
476 builder.GetSize(), &unilib_, &calendarlib_);
477 ASSERT_TRUE(classifier);
478
479 EXPECT_EQ("flight1",
480 FirstResult(classifier->ClassifyText(
481 "Your flight LX373 is delayed by 3 hours.", {12, 17})));
482 }
483
484 unpacked_model->regex_model->patterns.back()->priority_score = 3.0;
485 {
486 flatbuffers::FlatBufferBuilder builder;
487 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
488 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
489 reinterpret_cast<const char*>(builder.GetBufferPointer()),
490 builder.GetSize(), &unilib_, &calendarlib_);
491 ASSERT_TRUE(classifier);
492
493 EXPECT_EQ("flight2",
494 FirstResult(classifier->ClassifyText(
495 "Your flight LX373 is delayed by 3 hours.", {12, 17})));
496 }
497}
498#endif // TC3_UNILIB_ICU
499
500#ifdef TC3_UNILIB_ICU
501TEST_F(AnnotatorTest, SuggestSelectionRegularExpression) {
502 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100503 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
504
505 // Add test regex models.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100506 unpacked_model->regex_model->patterns.push_back(MakePattern(
507 "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
508 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
509 unpacked_model->regex_model->patterns.push_back(MakePattern(
510 "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
511 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
512 unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
Tony Mak6c4cc672018-09-17 11:48:50 +0100513 std::unique_ptr<RegexModel_::PatternT> verified_pattern =
514 MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
515 /*enabled_for_classification=*/false,
516 /*enabled_for_selection=*/true,
517 /*enabled_for_annotation=*/false, 1.0);
518 verified_pattern->verification_options.reset(new VerificationOptionsT);
519 verified_pattern->verification_options->verify_luhn_checksum = true;
520 unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100521
522 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000523 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100524
Tony Mak6c4cc672018-09-17 11:48:50 +0100525 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
526 reinterpret_cast<const char*>(builder.GetBufferPointer()),
527 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100528 ASSERT_TRUE(classifier);
529
530 // Check regular expression selection.
531 EXPECT_EQ(classifier->SuggestSelection(
532 "Your flight MA 0123 is delayed by 3 hours.", {12, 14}),
533 std::make_pair(12, 19));
534 EXPECT_EQ(classifier->SuggestSelection(
535 "this afternoon Barack Obama gave a speech at", {15, 21}),
536 std::make_pair(15, 27));
Tony Mak6c4cc672018-09-17 11:48:50 +0100537 EXPECT_EQ(classifier->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}),
538 std::make_pair(4, 23));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100539}
Tony Mak854015a2019-01-16 15:56:48 +0000540
Tony Mak378c1f52019-03-04 15:58:11 +0000541TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionCustomSelectionBounds) {
542 const std::string test_model = ReadFile(GetTestModelPath());
Tony Mak854015a2019-01-16 15:56:48 +0000543 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
544
545 // Add test regex models.
546 std::unique_ptr<RegexModel_::PatternT> custom_selection_bounds_pattern =
547 MakePattern("date_range",
548 "(?:(?:from )?(\\d{2}\\/\\d{2}\\/\\d{4}) to "
549 "(\\d{2}\\/\\d{2}\\/\\d{4}))|(for ever)",
550 /*enabled_for_classification=*/false,
551 /*enabled_for_selection=*/true,
552 /*enabled_for_annotation=*/false, 1.0);
553 custom_selection_bounds_pattern->capturing_group.emplace_back(
554 new RegexModel_::Pattern_::CapturingGroupT);
555 custom_selection_bounds_pattern->capturing_group.emplace_back(
556 new RegexModel_::Pattern_::CapturingGroupT);
557 custom_selection_bounds_pattern->capturing_group.emplace_back(
558 new RegexModel_::Pattern_::CapturingGroupT);
559 custom_selection_bounds_pattern->capturing_group.emplace_back(
560 new RegexModel_::Pattern_::CapturingGroupT);
561 custom_selection_bounds_pattern->capturing_group[0]->extend_selection = false;
562 custom_selection_bounds_pattern->capturing_group[1]->extend_selection = true;
563 custom_selection_bounds_pattern->capturing_group[2]->extend_selection = true;
564 custom_selection_bounds_pattern->capturing_group[3]->extend_selection = true;
565 unpacked_model->regex_model->patterns.push_back(
566 std::move(custom_selection_bounds_pattern));
567
568 flatbuffers::FlatBufferBuilder builder;
569 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
570
571 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
572 reinterpret_cast<const char*>(builder.GetBufferPointer()),
573 builder.GetSize(), &unilib_, &calendarlib_);
574 ASSERT_TRUE(classifier);
575
576 // Check regular expression selection.
577 EXPECT_EQ(classifier->SuggestSelection("it's from 04/30/1789 to 03/04/1797",
578 {21, 23}),
579 std::make_pair(10, 34));
580 EXPECT_EQ(classifier->SuggestSelection("it takes for ever", {9, 12}),
581 std::make_pair(9, 17));
582}
Tony Maka0f598b2018-11-20 20:39:04 +0000583#endif // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100584
Tony Maka0f598b2018-11-20 20:39:04 +0000585#ifdef TC3_UNILIB_ICU
Tony Mak378c1f52019-03-04 15:58:11 +0000586TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) {
587 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100588 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
589
590 // Add test regex models.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100591 unpacked_model->regex_model->patterns.push_back(MakePattern(
592 "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
593 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
594 unpacked_model->regex_model->patterns.push_back(MakePattern(
595 "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
596 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
597 unpacked_model->regex_model->patterns.back()->priority_score = 0.5;
598
599 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000600 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100601
Tony Mak6c4cc672018-09-17 11:48:50 +0100602 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
603 reinterpret_cast<const char*>(builder.GetBufferPointer()),
604 builder.GetSize());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100605 ASSERT_TRUE(classifier);
606
607 // Check conflict resolution.
608 EXPECT_EQ(
609 classifier->SuggestSelection(
610 "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
611 {55, 57}),
612 std::make_pair(26, 62));
613}
Tony Maka0f598b2018-11-20 20:39:04 +0000614#endif // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100615
Tony Maka0f598b2018-11-20 20:39:04 +0000616#ifdef TC3_UNILIB_ICU
Tony Mak378c1f52019-03-04 15:58:11 +0000617TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) {
618 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100619 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
620
621 // Add test regex models.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100622 unpacked_model->regex_model->patterns.push_back(MakePattern(
623 "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
624 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
625 unpacked_model->regex_model->patterns.push_back(MakePattern(
626 "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
627 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
628 unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
629
630 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000631 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100632
Tony Mak6c4cc672018-09-17 11:48:50 +0100633 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
634 reinterpret_cast<const char*>(builder.GetBufferPointer()),
635 builder.GetSize());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100636 ASSERT_TRUE(classifier);
637
638 // Check conflict resolution.
639 EXPECT_EQ(
640 classifier->SuggestSelection(
641 "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
642 {55, 57}),
643 std::make_pair(55, 62));
644}
Tony Maka0f598b2018-11-20 20:39:04 +0000645#endif // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100646
Tony Maka0f598b2018-11-20 20:39:04 +0000647#ifdef TC3_UNILIB_ICU
Tony Mak378c1f52019-03-04 15:58:11 +0000648TEST_F(AnnotatorTest, AnnotateRegex) {
649 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100650 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
651
652 // Add test regex models.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100653 unpacked_model->regex_model->patterns.push_back(MakePattern(
654 "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
655 /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
656 unpacked_model->regex_model->patterns.push_back(MakePattern(
657 "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
658 /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));
Tony Mak6c4cc672018-09-17 11:48:50 +0100659 std::unique_ptr<RegexModel_::PatternT> verified_pattern =
660 MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
661 /*enabled_for_classification=*/false,
662 /*enabled_for_selection=*/false,
663 /*enabled_for_annotation=*/true, 1.0);
664 verified_pattern->verification_options.reset(new VerificationOptionsT);
665 verified_pattern->verification_options->verify_luhn_checksum = true;
666 unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100667 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000668 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100669
Tony Mak6c4cc672018-09-17 11:48:50 +0100670 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
671 reinterpret_cast<const char*>(builder.GetBufferPointer()),
672 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100673 ASSERT_TRUE(classifier);
674
675 const std::string test_string =
676 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
Tony Mak6c4cc672018-09-17 11:48:50 +0100677 "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n";
Lukas Zilkab23e2122018-02-09 10:25:19 +0100678 EXPECT_THAT(classifier->Annotate(test_string),
Tony Mak6c4cc672018-09-17 11:48:50 +0100679 ElementsAreArray({IsAnnotatedSpan(6, 18, "person"),
680 IsAnnotatedSpan(28, 55, "address"),
681 IsAnnotatedSpan(79, 91, "phone"),
682 IsAnnotatedSpan(107, 126, "payment_card")}));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100683}
Tony Maka0f598b2018-11-20 20:39:04 +0000684#endif // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100685
Tony Mak378c1f52019-03-04 15:58:11 +0000686#ifdef TC3_UNILIB_ICU
687TEST_F(AnnotatorTest, PhoneFiltering) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100688 std::unique_ptr<Annotator> classifier =
Tony Mak378c1f52019-03-04 15:58:11 +0000689 Annotator::FromPath(GetTestModelPath(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100690 ASSERT_TRUE(classifier);
691
692 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
693 "phone: (123) 456 789", {7, 20})));
694 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
695 "phone: (123) 456 789,0001112", {7, 25})));
696 EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
697 "phone: (123) 456 789,0001112", {7, 28})));
698}
Tony Mak378c1f52019-03-04 15:58:11 +0000699#endif // TC3_UNILIB_ICU
Lukas Zilka21d8c982018-01-24 11:11:20 +0100700
Tony Mak378c1f52019-03-04 15:58:11 +0000701TEST_F(AnnotatorTest, SuggestSelection) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100702 std::unique_ptr<Annotator> classifier =
Tony Mak378c1f52019-03-04 15:58:11 +0000703 Annotator::FromPath(GetTestModelPath(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100704 ASSERT_TRUE(classifier);
705
706 EXPECT_EQ(classifier->SuggestSelection(
707 "this afternoon Barack Obama gave a speech at", {15, 21}),
708 std::make_pair(15, 21));
709
710 // Try passing whole string.
711 // If more than 1 token is specified, we should return back what entered.
712 EXPECT_EQ(
713 classifier->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
714 std::make_pair(0, 27));
715
716 // Single letter.
717 EXPECT_EQ(classifier->SuggestSelection("a", {0, 1}), std::make_pair(0, 1));
718
719 // Single word.
720 EXPECT_EQ(classifier->SuggestSelection("asdf", {0, 4}), std::make_pair(0, 4));
721
722 EXPECT_EQ(
723 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
724 std::make_pair(11, 23));
725
726 // Unpaired bracket stripping.
727 EXPECT_EQ(
728 classifier->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
729 std::make_pair(11, 25));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100730 EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {11, 15}),
731 std::make_pair(12, 15));
732 EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {11, 16}),
733 std::make_pair(11, 15));
734 EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {11, 16}),
735 std::make_pair(12, 15));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100736
737 // If the resulting selection would be empty, the original span is returned.
738 EXPECT_EQ(classifier->SuggestSelection("call me at )( today", {11, 13}),
739 std::make_pair(11, 13));
740 EXPECT_EQ(classifier->SuggestSelection("call me at ( today", {11, 12}),
741 std::make_pair(11, 12));
742 EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}),
743 std::make_pair(11, 12));
744}
745
Tony Mak378c1f52019-03-04 15:58:11 +0000746TEST_F(AnnotatorTest, SuggestSelectionDisabledFail) {
747 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100748 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
749
750 // Disable the selection model.
751 unpacked_model->selection_model.clear();
752 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
753 unpacked_model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;
754
755 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000756 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100757
Tony Mak6c4cc672018-09-17 11:48:50 +0100758 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
759 reinterpret_cast<const char*>(builder.GetBufferPointer()),
760 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100761 // Selection model needs to be present for annotation.
762 ASSERT_FALSE(classifier);
763}
764
Tony Mak378c1f52019-03-04 15:58:11 +0000765TEST_F(AnnotatorTest, SuggestSelectionDisabled) {
766 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkaba849e72018-03-08 14:48:21 +0100767 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
768
769 // Disable the selection model.
770 unpacked_model->selection_model.clear();
771 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
772 unpacked_model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
773 unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION;
774
775 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000776 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100777
Tony Mak6c4cc672018-09-17 11:48:50 +0100778 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
779 reinterpret_cast<const char*>(builder.GetBufferPointer()),
780 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100781 ASSERT_TRUE(classifier);
782
783 EXPECT_EQ(
784 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
785 std::make_pair(11, 14));
786
Tony Mak378c1f52019-03-04 15:58:11 +0000787#ifdef TC3_UNILIB_ICU
Lukas Zilkaba849e72018-03-08 14:48:21 +0100788 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
789 "call me at (800) 123-456 today", {11, 24})));
Tony Mak378c1f52019-03-04 15:58:11 +0000790#endif
Lukas Zilkaba849e72018-03-08 14:48:21 +0100791
792 EXPECT_THAT(classifier->Annotate("call me at (800) 123-456 today"),
793 IsEmpty());
794}
795
Tony Mak378c1f52019-03-04 15:58:11 +0000796TEST_F(AnnotatorTest, SuggestSelectionFilteredCollections) {
797 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200798
Tony Mak6c4cc672018-09-17 11:48:50 +0100799 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
800 test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200801 ASSERT_TRUE(classifier);
802
803 EXPECT_EQ(
804 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
805 std::make_pair(11, 23));
806
807 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
808 unpacked_model->output_options.reset(new OutputOptionsT);
809
810 // Disable phone selection
811 unpacked_model->output_options->filtered_collections_selection.push_back(
812 "phone");
813 // We need to force this for filtering.
814 unpacked_model->selection_options->always_classify_suggested_selection = true;
815
816 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000817 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200818
Tony Mak6c4cc672018-09-17 11:48:50 +0100819 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200820 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +0100821 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200822 ASSERT_TRUE(classifier);
823
824 EXPECT_EQ(
825 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
826 std::make_pair(11, 14));
827
828 // Address selection should still work.
829 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
830 std::make_pair(0, 27));
831}
832
Tony Mak378c1f52019-03-04 15:58:11 +0000833TEST_F(AnnotatorTest, SuggestSelectionsAreSymmetric) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100834 std::unique_ptr<Annotator> classifier =
Tony Mak378c1f52019-03-04 15:58:11 +0000835 Annotator::FromPath(GetTestModelPath(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100836 ASSERT_TRUE(classifier);
837
838 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}),
839 std::make_pair(0, 27));
840 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
841 std::make_pair(0, 27));
842 EXPECT_EQ(
843 classifier->SuggestSelection("350 Third Street, Cambridge", {10, 16}),
844 std::make_pair(0, 27));
845 EXPECT_EQ(classifier->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
846 {16, 22}),
847 std::make_pair(6, 33));
848}
849
Tony Mak378c1f52019-03-04 15:58:11 +0000850TEST_F(AnnotatorTest, SuggestSelectionWithNewLine) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100851 std::unique_ptr<Annotator> classifier =
Tony Mak378c1f52019-03-04 15:58:11 +0000852 Annotator::FromPath(GetTestModelPath(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100853 ASSERT_TRUE(classifier);
854
855 EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}),
856 std::make_pair(4, 16));
857 EXPECT_EQ(classifier->SuggestSelection("857 225 3556\nabc", {0, 3}),
858 std::make_pair(0, 12));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100859
860 SelectionOptions options;
861 EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options),
Tony Makd9446602019-02-20 18:25:39 +0000862 std::make_pair(0, 12));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100863}
864
Tony Mak378c1f52019-03-04 15:58:11 +0000865TEST_F(AnnotatorTest, SuggestSelectionWithPunctuation) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100866 std::unique_ptr<Annotator> classifier =
Tony Mak378c1f52019-03-04 15:58:11 +0000867 Annotator::FromPath(GetTestModelPath(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100868 ASSERT_TRUE(classifier);
869
870 // From the right.
871 EXPECT_EQ(classifier->SuggestSelection(
872 "this afternoon BarackObama, gave a speech at", {15, 26}),
873 std::make_pair(15, 26));
874
875 // From the right multiple.
876 EXPECT_EQ(classifier->SuggestSelection(
877 "this afternoon BarackObama,.,.,, gave a speech at", {15, 26}),
878 std::make_pair(15, 26));
879
880 // From the left multiple.
881 EXPECT_EQ(classifier->SuggestSelection(
882 "this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}),
883 std::make_pair(21, 32));
884
885 // From both sides.
886 EXPECT_EQ(classifier->SuggestSelection(
887 "this afternoon !BarackObama,- gave a speech at", {16, 27}),
888 std::make_pair(16, 27));
889}
890
Tony Mak378c1f52019-03-04 15:58:11 +0000891TEST_F(AnnotatorTest, SuggestSelectionNoCrashWithJunk) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100892 std::unique_ptr<Annotator> classifier =
Tony Mak378c1f52019-03-04 15:58:11 +0000893 Annotator::FromPath(GetTestModelPath(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100894 ASSERT_TRUE(classifier);
895
896 // Try passing in bunch of invalid selections.
897 EXPECT_EQ(classifier->SuggestSelection("", {0, 27}), std::make_pair(0, 27));
898 EXPECT_EQ(classifier->SuggestSelection("", {-10, 27}),
899 std::make_pair(-10, 27));
900 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {0, 27}),
901 std::make_pair(0, 27));
902 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-30, 300}),
903 std::make_pair(-30, 300));
904 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-10, -1}),
905 std::make_pair(-10, -1));
906 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}),
907 std::make_pair(100, 17));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200908
909 // Try passing invalid utf8.
910 EXPECT_EQ(classifier->SuggestSelection("\xf0\x9f\x98\x8b\x8b", {-1, -1}),
911 std::make_pair(-1, -1));
912}
913
Tony Mak378c1f52019-03-04 15:58:11 +0000914TEST_F(AnnotatorTest, SuggestSelectionSelectSpace) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100915 std::unique_ptr<Annotator> classifier =
Tony Mak378c1f52019-03-04 15:58:11 +0000916 Annotator::FromPath(GetTestModelPath(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200917 ASSERT_TRUE(classifier);
918
919 EXPECT_EQ(
920 classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}),
921 std::make_pair(11, 23));
922 EXPECT_EQ(
923 classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}),
924 std::make_pair(10, 11));
925 EXPECT_EQ(
926 classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}),
927 std::make_pair(23, 24));
928 EXPECT_EQ(
929 classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}),
930 std::make_pair(23, 24));
931 EXPECT_EQ(classifier->SuggestSelection("call me at 857 225 3556, today",
932 {14, 17}),
933 std::make_pair(11, 25));
934 EXPECT_EQ(
935 classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}),
936 std::make_pair(11, 23));
937 EXPECT_EQ(
938 classifier->SuggestSelection(
939 "let's meet at 350 Third Street Cambridge and go there", {30, 31}),
940 std::make_pair(14, 40));
941 EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}),
942 std::make_pair(4, 5));
943 EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}),
944 std::make_pair(7, 8));
945
946 // With a punctuation around the selected whitespace.
947 EXPECT_EQ(
948 classifier->SuggestSelection(
949 "let's meet at 350 Third Street, Cambridge and go there", {31, 32}),
950 std::make_pair(14, 41));
951
952 // When all's whitespace, should return the original indices.
953 EXPECT_EQ(classifier->SuggestSelection(" ", {0, 1}),
954 std::make_pair(0, 1));
955 EXPECT_EQ(classifier->SuggestSelection(" ", {0, 3}),
956 std::make_pair(0, 3));
957 EXPECT_EQ(classifier->SuggestSelection(" ", {2, 3}),
958 std::make_pair(2, 3));
959 EXPECT_EQ(classifier->SuggestSelection(" ", {5, 6}),
960 std::make_pair(5, 6));
961}
962
Tony Mak6c4cc672018-09-17 11:48:50 +0100963TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200964 UnicodeText text;
965
966 text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100967 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200968 std::make_pair(3, 4));
969 text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100970 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200971 std::make_pair(3, 4));
972
973 // Nothing on the left.
974 text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100975 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200976 std::make_pair(4, 5));
977 text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100978 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200979 std::make_pair(0, 1));
980
981 // Whitespace only.
982 text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100983 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200984 std::make_pair(2, 3));
985 text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100986 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200987 std::make_pair(4, 5));
988 text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100989 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200990 std::make_pair(0, 1));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100991}
992
Tony Mak378c1f52019-03-04 15:58:11 +0000993TEST_F(AnnotatorTest, Annotate) {
Tony Mak6c4cc672018-09-17 11:48:50 +0100994 std::unique_ptr<Annotator> classifier =
Tony Mak378c1f52019-03-04 15:58:11 +0000995 Annotator::FromPath(GetTestModelPath(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100996 ASSERT_TRUE(classifier);
997
998 const std::string test_string =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100999 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
1000 "number is 853 225 3556";
Lukas Zilka21d8c982018-01-24 11:11:20 +01001001 EXPECT_THAT(classifier->Annotate(test_string),
1002 ElementsAreArray({
Lukas Zilkab23e2122018-02-09 10:25:19 +01001003 IsAnnotatedSpan(28, 55, "address"),
1004 IsAnnotatedSpan(79, 91, "phone"),
Lukas Zilka21d8c982018-01-24 11:11:20 +01001005 }));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001006
1007 AnnotationOptions options;
1008 EXPECT_THAT(classifier->Annotate("853 225 3556", options),
1009 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
Tony Makd9446602019-02-20 18:25:39 +00001010 EXPECT_THAT(classifier->Annotate("853 225\n3556", options),
1011 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001012 // Try passing invalid utf8.
1013 EXPECT_TRUE(
1014 classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options)
1015 .empty());
Lukas Zilka21d8c982018-01-24 11:11:20 +01001016}
1017
Tony Mak378c1f52019-03-04 15:58:11 +00001018TEST_F(AnnotatorTest, AnnotateAnnotationsSuppressNumbers) {
1019 std::unique_ptr<Annotator> classifier =
1020 Annotator::FromPath(GetTestModelPath(), &unilib_, &calendarlib_);
1021 ASSERT_TRUE(classifier);
1022 AnnotationOptions options;
1023 options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
Tony Maka0f598b2018-11-20 20:39:04 +00001024
Tony Mak378c1f52019-03-04 15:58:11 +00001025 // Number annotator.
1026 EXPECT_THAT(
1027 classifier->Annotate("853 225 3556 and then turn it up 99%", options),
1028 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone"),
1029 IsAnnotatedSpan(33, 35, "number")}));
1030}
1031
1032TEST_F(AnnotatorTest, AnnotateSplitLines) {
1033 std::string model_buffer = ReadFile(GetTestModelPath());
1034 model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
1035 model->selection_feature_options->only_use_line_with_click = true;
1036 });
1037 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
1038 model_buffer.data(), model_buffer.size(), &unilib_, &calendarlib_);
1039
1040 ASSERT_TRUE(classifier);
1041
1042 const std::string str1 =
1043 "hey, sorry, just finished up. i didn't hear back from you in time.";
1044 const std::string str2 = "2000 Main Avenue, Apt #201, San Mateo";
1045
1046 const int kAnnotationLength = 26;
1047 EXPECT_THAT(classifier->Annotate(str1), IsEmpty());
1048 EXPECT_THAT(
1049 classifier->Annotate(str2),
1050 ElementsAreArray({IsAnnotatedSpan(0, kAnnotationLength, "address")}));
1051
1052 const std::string str3 = str1 + "\n" + str2;
1053 EXPECT_THAT(
1054 classifier->Annotate(str3),
1055 ElementsAreArray({IsAnnotatedSpan(
1056 str1.size() + 1, str1.size() + 1 + kAnnotationLength, "address")}));
1057}
1058
1059
1060TEST_F(AnnotatorTest, AnnotateSmallBatches) {
1061 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001062 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
1063
1064 // Set the batch size.
1065 unpacked_model->selection_options->batch_size = 4;
1066 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +00001067 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001068
Tony Mak6c4cc672018-09-17 11:48:50 +01001069 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
1070 reinterpret_cast<const char*>(builder.GetBufferPointer()),
1071 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001072 ASSERT_TRUE(classifier);
1073
1074 const std::string test_string =
1075 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
1076 "number is 853 225 3556";
1077 EXPECT_THAT(classifier->Annotate(test_string),
1078 ElementsAreArray({
Lukas Zilkab23e2122018-02-09 10:25:19 +01001079 IsAnnotatedSpan(28, 55, "address"),
1080 IsAnnotatedSpan(79, 91, "phone"),
1081 }));
1082
1083 AnnotationOptions options;
1084 EXPECT_THAT(classifier->Annotate("853 225 3556", options),
1085 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
Tony Makd9446602019-02-20 18:25:39 +00001086 EXPECT_THAT(classifier->Annotate("853 225\n3556", options),
1087 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001088}
1089
Tony Maka0f598b2018-11-20 20:39:04 +00001090#ifdef TC3_UNILIB_ICU
Tony Mak378c1f52019-03-04 15:58:11 +00001091TEST_F(AnnotatorTest, AnnotateFilteringDiscardAll) {
1092 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001093 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
1094
Lukas Zilkab23e2122018-02-09 10:25:19 +01001095 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001096 // Add test threshold.
Lukas Zilkab23e2122018-02-09 10:25:19 +01001097 unpacked_model->triggering_options->min_annotate_confidence =
1098 2.f; // Discards all results.
1099 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +00001100 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001101
Tony Mak6c4cc672018-09-17 11:48:50 +01001102 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
1103 reinterpret_cast<const char*>(builder.GetBufferPointer()),
1104 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001105 ASSERT_TRUE(classifier);
1106
1107 const std::string test_string =
1108 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
1109 "number is 853 225 3556";
Lukas Zilkaba849e72018-03-08 14:48:21 +01001110
Tony Mak6c4cc672018-09-17 11:48:50 +01001111 EXPECT_EQ(classifier->Annotate(test_string).size(), 0);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001112}
Tony Maka0f598b2018-11-20 20:39:04 +00001113#endif // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +01001114
Tony Mak378c1f52019-03-04 15:58:11 +00001115TEST_F(AnnotatorTest, AnnotateFilteringKeepAll) {
1116 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkab23e2122018-02-09 10:25:19 +01001117 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
1118
1119 // Add test thresholds.
1120 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
1121 unpacked_model->triggering_options->min_annotate_confidence =
1122 0.f; // Keeps all results.
Lukas Zilkaba849e72018-03-08 14:48:21 +01001123 unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001124 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +00001125 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +01001126
Tony Mak6c4cc672018-09-17 11:48:50 +01001127 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
1128 reinterpret_cast<const char*>(builder.GetBufferPointer()),
1129 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001130 ASSERT_TRUE(classifier);
1131
1132 const std::string test_string =
1133 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
1134 "number is 853 225 3556";
Lukas Zilkab23e2122018-02-09 10:25:19 +01001135 EXPECT_EQ(classifier->Annotate(test_string).size(), 2);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001136}
1137
Tony Mak378c1f52019-03-04 15:58:11 +00001138TEST_F(AnnotatorTest, AnnotateDisabled) {
1139 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkaba849e72018-03-08 14:48:21 +01001140 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
1141
1142 // Disable the model for annotation.
1143 unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;
1144 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +00001145 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +01001146
Tony Mak6c4cc672018-09-17 11:48:50 +01001147 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
1148 reinterpret_cast<const char*>(builder.GetBufferPointer()),
1149 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001150 ASSERT_TRUE(classifier);
1151 const std::string test_string =
1152 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
1153 "number is 853 225 3556";
1154 EXPECT_THAT(classifier->Annotate(test_string), IsEmpty());
1155}
1156
Tony Mak378c1f52019-03-04 15:58:11 +00001157TEST_F(AnnotatorTest, AnnotateFilteredCollections) {
1158 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001159
Tony Mak6c4cc672018-09-17 11:48:50 +01001160 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
1161 test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001162 ASSERT_TRUE(classifier);
1163
1164 const std::string test_string =
1165 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
1166 "number is 853 225 3556";
1167
1168 EXPECT_THAT(classifier->Annotate(test_string),
1169 ElementsAreArray({
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001170 IsAnnotatedSpan(28, 55, "address"),
1171 IsAnnotatedSpan(79, 91, "phone"),
1172 }));
1173
1174 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
1175 unpacked_model->output_options.reset(new OutputOptionsT);
1176
1177 // Disable phone annotation
1178 unpacked_model->output_options->filtered_collections_annotation.push_back(
1179 "phone");
1180
1181 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +00001182 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001183
Tony Mak6c4cc672018-09-17 11:48:50 +01001184 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001185 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +01001186 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001187 ASSERT_TRUE(classifier);
1188
1189 EXPECT_THAT(classifier->Annotate(test_string),
1190 ElementsAreArray({
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001191 IsAnnotatedSpan(28, 55, "address"),
1192 }));
1193}
1194
Tony Maka0f598b2018-11-20 20:39:04 +00001195#ifdef TC3_UNILIB_ICU
Tony Mak378c1f52019-03-04 15:58:11 +00001196TEST_F(AnnotatorTest, AnnotateFilteredCollectionsSuppress) {
1197 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001198
Tony Mak6c4cc672018-09-17 11:48:50 +01001199 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
1200 test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001201 ASSERT_TRUE(classifier);
1202
1203 const std::string test_string =
1204 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
1205 "number is 853 225 3556";
1206
1207 EXPECT_THAT(classifier->Annotate(test_string),
1208 ElementsAreArray({
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001209 IsAnnotatedSpan(28, 55, "address"),
1210 IsAnnotatedSpan(79, 91, "phone"),
1211 }));
1212
1213 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
1214 unpacked_model->output_options.reset(new OutputOptionsT);
1215
1216 // We add a custom annotator that wins against the phone classification
1217 // below and that we subsequently suppress.
1218 unpacked_model->output_options->filtered_collections_annotation.push_back(
1219 "suppress");
1220
1221 unpacked_model->regex_model->patterns.push_back(MakePattern(
1222 "suppress", "(\\d{3} ?\\d{4})",
1223 /*enabled_for_classification=*/false,
1224 /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));
1225
1226 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +00001227 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001228
Tony Mak6c4cc672018-09-17 11:48:50 +01001229 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001230 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +01001231 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001232 ASSERT_TRUE(classifier);
1233
1234 EXPECT_THAT(classifier->Annotate(test_string),
1235 ElementsAreArray({
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001236 IsAnnotatedSpan(28, 55, "address"),
1237 }));
1238}
Tony Maka0f598b2018-11-20 20:39:04 +00001239#endif // TC3_UNILIB_ICU
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001240
#ifdef TC3_CALENDAR_ICU
// Classifies a plain date with Europe/Zurich as the reference timezone and
// checks the resolved timestamp (in milliseconds) and the day granularity.
TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezone) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetTestModelPath());
  // Fatal assertion: the classifier is dereferenced below, so a failed model
  // load must stop the test instead of crashing on a null pointer.
  ASSERT_TRUE(classifier);
  ClassificationOptions options;
  options.reference_timezone = "Europe/Zurich";

  std::vector<ClassificationResult> result =
      classifier->ClassifyText("january 1, 2017", {0, 15}, options);

  EXPECT_THAT(result,
              ElementsAre(IsDateResult(1483225200000,
                                       DatetimeGranularity::GRANULARITY_DAY)));
}
#endif  // TC3_CALENDAR_ICU
1257
#ifdef TC3_CALENDAR_ICU
// Classifies a plain date with America/Los_Angeles as the reference timezone
// and checks the resolved timestamp (in milliseconds) and the day granularity.
TEST_F(AnnotatorTest, ClassifyTextDateInLATimezone) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetTestModelPath());
  // Fatal assertion: the classifier is dereferenced below, so a failed model
  // load must stop the test instead of crashing on a null pointer.
  ASSERT_TRUE(classifier);
  ClassificationOptions options;
  options.reference_timezone = "America/Los_Angeles";

  std::vector<ClassificationResult> result =
      classifier->ClassifyText("march 1, 2017", {0, 13}, options);

  EXPECT_THAT(result,
              ElementsAre(IsDateResult(1488355200000,
                                       DatetimeGranularity::GRANULARITY_DAY)));
}
#endif  // TC3_CALENDAR_ICU
1274
#ifdef TC3_CALENDAR_ICU
// Classifies a date with a time of day, using America/Los_Angeles as the
// reference timezone, and checks the timestamp and the minute granularity.
TEST_F(AnnotatorTest, ClassifyTextDateTimeInLATimezone) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetTestModelPath());
  // Fatal assertion: the classifier is dereferenced below, so a failed model
  // load must stop the test instead of crashing on a null pointer.
  ASSERT_TRUE(classifier);
  ClassificationOptions options;
  options.reference_timezone = "America/Los_Angeles";

  std::vector<ClassificationResult> result =
      classifier->ClassifyText("2018/01/01 22:00", {0, 16}, options);

  EXPECT_THAT(result,
              ElementsAre(IsDatetimeResult(
                  1514872800000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
#endif  // TC3_CALENDAR_ICU
1291
#ifdef TC3_CALENDAR_ICU
// Classifies a date that sits on the second line of the context; the selected
// span {35, 50} starts right after the newline that ends the first line.
// NOTE: the test name keeps its historical spelling ("Aother") so existing
// --gtest_filter expressions continue to match.
TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLine) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetTestModelPath());
  // Fatal assertion: the classifier is dereferenced below, so a failed model
  // load must stop the test instead of crashing on a null pointer.
  ASSERT_TRUE(classifier);
  ClassificationOptions options;
  options.reference_timezone = "Europe/Zurich";

  std::vector<ClassificationResult> result = classifier->ClassifyText(
      "hello world this is the first line\n"
      "january 1, 2017",
      {35, 50}, options);

  EXPECT_THAT(result,
              ElementsAre(IsDateResult(1483225200000,
                                       DatetimeGranularity::GRANULARITY_DAY)));
}
#endif  // TC3_CALENDAR_ICU
Lukas Zilkaba849e72018-03-08 14:48:21 +01001310
#ifdef TC3_CALENDAR_ICU
// With the en-US locale, "03.05.1970" must be read as <month>.<day>.
TEST_F(AnnotatorTest, ClassifyTextWhenLocaleUSParsesDateAsMonthDay) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetTestModelPath());
  // Fatal assertion: the classifier is dereferenced below, so a failed model
  // load must stop the test instead of crashing on a null pointer.
  ASSERT_TRUE(classifier);
  std::vector<ClassificationResult> result;
  ClassificationOptions options;

  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options);

  // In US, the date should be interpreted as <month>.<day>.
  EXPECT_THAT(result,
              ElementsAre(IsDatetimeResult(
                  5439600000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
#endif  // TC3_CALENDAR_ICU
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001329
#ifdef TC3_CALENDAR_ICU
// With the "de" locale, "03.05.1970" must be read as <day>.<month>.
// NOTE: despite the "MonthDay" suffix in the test name, the expectation here
// is day-month ordering; the name is kept so existing test filters still
// match.
TEST_F(AnnotatorTest, ClassifyTextWhenLocaleGermanyParsesDateAsMonthDay) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetTestModelPath());
  // Fatal assertion: the classifier is dereferenced below, so a failed model
  // load must stop the test instead of crashing on a null pointer.
  ASSERT_TRUE(classifier);
  std::vector<ClassificationResult> result;
  ClassificationOptions options;

  options.reference_timezone = "Europe/Zurich";
  options.locales = "de";
  result = classifier->ClassifyText("03.05.1970 00:00vorm", {0, 20}, options);

  // In Germany, the date should be interpreted as <day>.<month>.
  EXPECT_THAT(result,
              ElementsAre(IsDatetimeResult(
                  10537200000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
#endif  // TC3_CALENDAR_ICU
1348
#ifdef TC3_CALENDAR_ICU
// A time without an am/pm marker ("10:30") is ambiguous, so classification is
// expected to return two datetime results, one per interpretation.
TEST_F(AnnotatorTest, ClassifyTextAmbiguousDatetime) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetTestModelPath());
  // Fatal assertion: the classifier is dereferenced below, so a failed model
  // load must stop the test instead of crashing on a null pointer.
  ASSERT_TRUE(classifier);
  ClassificationOptions options;
  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  const std::vector<ClassificationResult> result =
      classifier->ClassifyText("set an alarm for 10:30", {17, 22}, options);

  EXPECT_THAT(
      result,
      ElementsAre(
          IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
          IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
#endif  // TC3_CALENDAR_ICU
1367
#ifdef TC3_CALENDAR_ICU
// Same ambiguous time as in ClassifyTextAmbiguousDatetime, but through
// Annotate(): a single span is expected, carrying both datetime
// interpretations in its classification list.
TEST_F(AnnotatorTest, AnnotateAmbiguousDatetime) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetTestModelPath());
  // Fatal assertion: the classifier is dereferenced below, so a failed model
  // load must stop the test instead of crashing on a null pointer.
  ASSERT_TRUE(classifier);
  AnnotationOptions options;
  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  const std::vector<AnnotatedSpan> spans =
      classifier->Annotate("set an alarm for 10:30", options);

  // Unsigned literal avoids a signed/unsigned comparison against size().
  // Fatal, because spans[0] is accessed right after.
  ASSERT_EQ(spans.size(), 1u);
  const std::vector<ClassificationResult>& result = spans[0].classification;
  EXPECT_THAT(
      result,
      ElementsAre(
          IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
          IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
#endif  // TC3_CALENDAR_ICU
1388
#ifdef TC3_CALENDAR_ICU
// With the datetime patterns restricted to annotation and classification,
// SuggestSelection must leave the clicked span unchanged while ClassifyText
// and Annotate still recognize the date.
TEST_F(AnnotatorTest, SuggestTextDateDisabled) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Disable the patterns for selection.
  for (auto& pattern : unpacked_model->datetime_model->patterns) {
    pattern->enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION;
  }
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_, &calendarlib_);
  ASSERT_TRUE(classifier);
  EXPECT_EQ("date",
            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
  // The selection is not expanded to the whole date.
  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
            std::make_pair(0, 7));
  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
}
#endif  // TC3_CALENDAR_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +01001414
Tony Mak6c4cc672018-09-17 11:48:50 +01001415class TestingAnnotator : public Annotator {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001416 public:
Tony Mak6c4cc672018-09-17 11:48:50 +01001417 TestingAnnotator(const std::string& model, const UniLib* unilib,
1418 const CalendarLib* calendarlib)
Tony Mak854015a2019-01-16 15:56:48 +00001419 : Annotator(libtextclassifier3::ViewModel(model.data(), model.size()),
1420 unilib, calendarlib) {}
Lukas Zilkab23e2122018-02-09 10:25:19 +01001421
Tony Mak6c4cc672018-09-17 11:48:50 +01001422 using Annotator::ResolveConflicts;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001423};
1424
Tony Mak378c1f52019-03-04 15:58:11 +00001425AnnotatedSpan MakeAnnotatedSpan(
1426 CodepointSpan span, const std::string& collection, const float score,
1427 AnnotatedSpan::Source source = AnnotatedSpan::Source::OTHER) {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001428 AnnotatedSpan result;
1429 result.span = span;
1430 result.classification.push_back({collection, score});
Tony Mak378c1f52019-03-04 15:58:11 +00001431 result.source = source;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001432 return result;
1433}
1434
Tony Mak6c4cc672018-09-17 11:48:50 +01001435TEST_F(AnnotatorTest, ResolveConflictsTrivial) {
1436 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001437
1438 std::vector<AnnotatedSpan> candidates{
1439 {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
Tony Mak378c1f52019-03-04 15:58:11 +00001440 std::vector<Locale> locales = {Locale::FromBCP47("en")};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001441
1442 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001443 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Tony Mak378c1f52019-03-04 15:58:11 +00001444 locales,
1445 AnnotationUsecase_ANNOTATION_USECASE_SMART,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001446 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001447 EXPECT_THAT(chosen, ElementsAreArray({0}));
1448}
1449
Tony Mak6c4cc672018-09-17 11:48:50 +01001450TEST_F(AnnotatorTest, ResolveConflictsSequence) {
1451 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001452
1453 std::vector<AnnotatedSpan> candidates{{
1454 MakeAnnotatedSpan({0, 1}, "phone", 1.0),
1455 MakeAnnotatedSpan({1, 2}, "phone", 1.0),
1456 MakeAnnotatedSpan({2, 3}, "phone", 1.0),
1457 MakeAnnotatedSpan({3, 4}, "phone", 1.0),
1458 MakeAnnotatedSpan({4, 5}, "phone", 1.0),
1459 }};
Tony Mak378c1f52019-03-04 15:58:11 +00001460 std::vector<Locale> locales = {Locale::FromBCP47("en")};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001461
1462 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001463 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Tony Mak378c1f52019-03-04 15:58:11 +00001464 locales,
1465 AnnotationUsecase_ANNOTATION_USECASE_SMART,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001466 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001467 EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
1468}
1469
Tony Mak6c4cc672018-09-17 11:48:50 +01001470TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) {
1471 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001472
1473 std::vector<AnnotatedSpan> candidates{{
1474 MakeAnnotatedSpan({0, 3}, "phone", 1.0),
1475 MakeAnnotatedSpan({1, 5}, "phone", 0.5), // Looser!
1476 MakeAnnotatedSpan({3, 7}, "phone", 1.0),
1477 }};
Tony Mak378c1f52019-03-04 15:58:11 +00001478 std::vector<Locale> locales = {Locale::FromBCP47("en")};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001479
1480 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001481 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Tony Mak378c1f52019-03-04 15:58:11 +00001482 locales,
1483 AnnotationUsecase_ANNOTATION_USECASE_SMART,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001484 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001485 EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
1486}
1487
Tony Mak6c4cc672018-09-17 11:48:50 +01001488TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) {
1489 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001490
1491 std::vector<AnnotatedSpan> candidates{{
1492 MakeAnnotatedSpan({0, 3}, "phone", 0.5), // Looser!
1493 MakeAnnotatedSpan({1, 5}, "phone", 1.0),
1494 MakeAnnotatedSpan({3, 7}, "phone", 0.6), // Looser!
1495 }};
Tony Mak378c1f52019-03-04 15:58:11 +00001496 std::vector<Locale> locales = {Locale::FromBCP47("en")};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001497
1498 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001499 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Tony Mak378c1f52019-03-04 15:58:11 +00001500 locales,
1501 AnnotationUsecase_ANNOTATION_USECASE_SMART,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001502 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001503 EXPECT_THAT(chosen, ElementsAreArray({1}));
1504}
1505
Tony Mak6c4cc672018-09-17 11:48:50 +01001506TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) {
1507 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001508
1509 std::vector<AnnotatedSpan> candidates{{
1510 MakeAnnotatedSpan({0, 3}, "phone", 0.5),
1511 MakeAnnotatedSpan({1, 5}, "other", 1.0), // Looser!
1512 MakeAnnotatedSpan({3, 7}, "phone", 0.6),
1513 MakeAnnotatedSpan({8, 12}, "phone", 0.6), // Looser!
1514 MakeAnnotatedSpan({11, 15}, "phone", 0.9),
1515 }};
Tony Mak378c1f52019-03-04 15:58:11 +00001516 std::vector<Locale> locales = {Locale::FromBCP47("en")};
Lukas Zilkab23e2122018-02-09 10:25:19 +01001517
1518 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001519 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Tony Mak378c1f52019-03-04 15:58:11 +00001520 locales,
1521 AnnotationUsecase_ANNOTATION_USECASE_SMART,
Lukas Zilkaba849e72018-03-08 14:48:21 +01001522 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001523 EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
1524}
Lukas Zilka21d8c982018-01-24 11:11:20 +01001525
Tony Mak378c1f52019-03-04 15:58:11 +00001526TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeFirst) {
1527 TestingAnnotator classifier("", &unilib_, &calendarlib_);
1528
1529 std::vector<AnnotatedSpan> candidates{{
1530 MakeAnnotatedSpan({0, 15}, "entity", 0.7,
1531 AnnotatedSpan::Source::KNOWLEDGE),
1532 MakeAnnotatedSpan({5, 10}, "address", 0.6),
1533 }};
1534 std::vector<Locale> locales = {Locale::FromBCP47("en")};
1535
1536 std::vector<int> chosen;
1537 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
1538 locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
1539 /*interpreter_manager=*/nullptr, &chosen);
1540 EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
1541}
1542
1543TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeSecond) {
1544 TestingAnnotator classifier("", &unilib_, &calendarlib_);
1545
1546 std::vector<AnnotatedSpan> candidates{{
1547 MakeAnnotatedSpan({0, 15}, "address", 0.7),
1548 MakeAnnotatedSpan({5, 10}, "entity", 0.6,
1549 AnnotatedSpan::Source::KNOWLEDGE),
1550 }};
1551 std::vector<Locale> locales = {Locale::FromBCP47("en")};
1552
1553 std::vector<int> chosen;
1554 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
1555 locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
1556 /*interpreter_manager=*/nullptr, &chosen);
1557 EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
1558}
1559
1560TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedBothKnowledge) {
1561 TestingAnnotator classifier("", &unilib_, &calendarlib_);
1562
1563 std::vector<AnnotatedSpan> candidates{{
1564 MakeAnnotatedSpan({0, 15}, "entity", 0.7,
1565 AnnotatedSpan::Source::KNOWLEDGE),
1566 MakeAnnotatedSpan({5, 10}, "entity", 0.6,
1567 AnnotatedSpan::Source::KNOWLEDGE),
1568 }};
1569 std::vector<Locale> locales = {Locale::FromBCP47("en")};
1570
1571 std::vector<int> chosen;
1572 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
1573 locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
1574 /*interpreter_manager=*/nullptr, &chosen);
1575 EXPECT_THAT(chosen, ElementsAreArray({0, 1}));
1576}
1577
1578TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsNotAllowed) {
1579 TestingAnnotator classifier("", &unilib_, &calendarlib_);
1580
1581 std::vector<AnnotatedSpan> candidates{{
1582 MakeAnnotatedSpan({0, 15}, "address", 0.7),
1583 MakeAnnotatedSpan({5, 10}, "date", 0.6),
1584 }};
1585 std::vector<Locale> locales = {Locale::FromBCP47("en")};
1586
1587 std::vector<int> chosen;
1588 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
1589 locales, AnnotationUsecase_ANNOTATION_USECASE_RAW,
1590 /*interpreter_manager=*/nullptr, &chosen);
1591 EXPECT_THAT(chosen, ElementsAreArray({0}));
1592}
1593
Tony Maka0f598b2018-11-20 20:39:04 +00001594#ifdef TC3_UNILIB_ICU
Tony Mak378c1f52019-03-04 15:58:11 +00001595TEST_F(AnnotatorTest, LongInput) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001596 std::unique_ptr<Annotator> classifier =
Tony Mak378c1f52019-03-04 15:58:11 +00001597 Annotator::FromPath(GetTestModelPath(), &unilib_, &calendarlib_);
Lukas Zilkadf710db2018-02-27 12:44:09 +01001598 ASSERT_TRUE(classifier);
1599
1600 for (const auto& type_value_pair :
1601 std::vector<std::pair<std::string, std::string>>{
1602 {"address", "350 Third Street, Cambridge"},
1603 {"phone", "123 456-7890"},
1604 {"url", "www.google.com"},
1605 {"email", "someone@gmail.com"},
1606 {"flight", "LX 38"},
1607 {"date", "September 1, 2018"}}) {
1608 const std::string input_100k = std::string(50000, ' ') +
1609 type_value_pair.second +
1610 std::string(50000, ' ');
1611 const int value_length = type_value_pair.second.size();
1612
1613 EXPECT_THAT(classifier->Annotate(input_100k),
1614 ElementsAreArray({IsAnnotatedSpan(50000, 50000 + value_length,
1615 type_value_pair.first)}));
1616 EXPECT_EQ(classifier->SuggestSelection(input_100k, {50000, 50001}),
1617 std::make_pair(50000, 50000 + value_length));
1618 EXPECT_EQ(type_value_pair.first,
1619 FirstResult(classifier->ClassifyText(
1620 input_100k, {50000, 50000 + value_length})));
1621 }
1622}
Tony Maka0f598b2018-11-20 20:39:04 +00001623#endif // TC3_UNILIB_ICU
Lukas Zilkadf710db2018-02-27 12:44:09 +01001624
Tony Maka0f598b2018-11-20 20:39:04 +00001625#ifdef TC3_UNILIB_ICU
Lukas Zilkaba849e72018-03-08 14:48:21 +01001626// These coarse tests are there only to make sure the execution happens in
1627// reasonable amount of time.
Tony Mak378c1f52019-03-04 15:58:11 +00001628TEST_F(AnnotatorTest, LongInputNoResultCheck) {
Tony Mak6c4cc672018-09-17 11:48:50 +01001629 std::unique_ptr<Annotator> classifier =
Tony Mak378c1f52019-03-04 15:58:11 +00001630 Annotator::FromPath(GetTestModelPath(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +01001631 ASSERT_TRUE(classifier);
1632
1633 for (const std::string& value :
1634 std::vector<std::string>{"http://www.aaaaaaaaaaaaaaaaaaaa.com "}) {
1635 const std::string input_100k =
1636 std::string(50000, ' ') + value + std::string(50000, ' ');
1637 const int value_length = value.size();
1638
1639 classifier->Annotate(input_100k);
1640 classifier->SuggestSelection(input_100k, {50000, 50001});
1641 classifier->ClassifyText(input_100k, {50000, 50000 + value_length});
1642 }
1643}
Tony Maka0f598b2018-11-20 20:39:04 +00001644#endif // TC3_UNILIB_ICU
Lukas Zilkaba849e72018-03-08 14:48:21 +01001645
Tony Maka0f598b2018-11-20 20:39:04 +00001646#ifdef TC3_UNILIB_ICU
Tony Mak378c1f52019-03-04 15:58:11 +00001647TEST_F(AnnotatorTest, MaxTokenLength) {
1648 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilka434442d2018-04-25 11:38:51 +02001649 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
1650
Tony Mak6c4cc672018-09-17 11:48:50 +01001651 std::unique_ptr<Annotator> classifier;
Lukas Zilka434442d2018-04-25 11:38:51 +02001652
1653 // With unrestricted number of tokens should behave normally.
1654 unpacked_model->classification_options->max_num_tokens = -1;
1655
1656 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +00001657 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Tony Mak6c4cc672018-09-17 11:48:50 +01001658 classifier = Annotator::FromUnownedBuffer(
Lukas Zilka434442d2018-04-25 11:38:51 +02001659 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +01001660 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilka434442d2018-04-25 11:38:51 +02001661 ASSERT_TRUE(classifier);
1662
1663 EXPECT_EQ(FirstResult(classifier->ClassifyText(
1664 "I live at 350 Third Street, Cambridge.", {10, 37})),
1665 "address");
1666
1667 // Raise the maximum number of tokens to suppress the classification.
1668 unpacked_model->classification_options->max_num_tokens = 3;
1669
1670 flatbuffers::FlatBufferBuilder builder2;
Tony Mak51a9e542018-11-02 13:36:22 +00001671 FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
Tony Mak6c4cc672018-09-17 11:48:50 +01001672 classifier = Annotator::FromUnownedBuffer(
Lukas Zilka434442d2018-04-25 11:38:51 +02001673 reinterpret_cast<const char*>(builder2.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +01001674 builder2.GetSize(), &unilib_, &calendarlib_);
Lukas Zilka434442d2018-04-25 11:38:51 +02001675 ASSERT_TRUE(classifier);
1676
1677 EXPECT_EQ(FirstResult(classifier->ClassifyText(
1678 "I live at 350 Third Street, Cambridge.", {10, 37})),
1679 "other");
1680}
Tony Maka0f598b2018-11-20 20:39:04 +00001681#endif // TC3_UNILIB_ICU
Lukas Zilka434442d2018-04-25 11:38:51 +02001682
Tony Maka0f598b2018-11-20 20:39:04 +00001683#ifdef TC3_UNILIB_ICU
Tony Mak378c1f52019-03-04 15:58:11 +00001684TEST_F(AnnotatorTest, MinAddressTokenLength) {
1685 const std::string test_model = ReadFile(GetTestModelPath());
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001686 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
1687
Tony Mak6c4cc672018-09-17 11:48:50 +01001688 std::unique_ptr<Annotator> classifier;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001689
1690 // With unrestricted number of address tokens should behave normally.
1691 unpacked_model->classification_options->address_min_num_tokens = 0;
1692
1693 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +00001694 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Tony Mak6c4cc672018-09-17 11:48:50 +01001695 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001696 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +01001697 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001698 ASSERT_TRUE(classifier);
1699
1700 EXPECT_EQ(FirstResult(classifier->ClassifyText(
1701 "I live at 350 Third Street, Cambridge.", {10, 37})),
1702 "address");
1703
1704 // Raise number of address tokens to suppress the address classification.
1705 unpacked_model->classification_options->address_min_num_tokens = 5;
1706
1707 flatbuffers::FlatBufferBuilder builder2;
Tony Mak51a9e542018-11-02 13:36:22 +00001708 FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
Tony Mak6c4cc672018-09-17 11:48:50 +01001709 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001710 reinterpret_cast<const char*>(builder2.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +01001711 builder2.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001712 ASSERT_TRUE(classifier);
1713
1714 EXPECT_EQ(FirstResult(classifier->ClassifyText(
1715 "I live at 350 Third Street, Cambridge.", {10, 37})),
1716 "other");
1717}
Tony Maka0f598b2018-11-20 20:39:04 +00001718#endif // TC3_UNILIB_ICU
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001719
Tony Mak854015a2019-01-16 15:56:48 +00001720TEST_F(AnnotatorTest, VisitAnnotatorModel) {
Tony Mak378c1f52019-03-04 15:58:11 +00001721 EXPECT_TRUE(
1722 VisitAnnotatorModel<bool>(GetTestModelPath(), [](const Model* model) {
1723 if (model == nullptr) {
1724 return false;
1725 }
1726 return true;
1727 }));
Tony Mak854015a2019-01-16 15:56:48 +00001728 EXPECT_FALSE(VisitAnnotatorModel<bool>(
1729 GetModelPath() + "non_existing_model.fb", [](const Model* model) {
1730 if (model == nullptr) {
1731 return false;
1732 }
1733 return true;
1734 }));
1735}
1736
Lukas Zilka21d8c982018-01-24 11:11:20 +01001737} // namespace
Tony Mak6c4cc672018-09-17 11:48:50 +01001738} // namespace libtextclassifier3