blob: d807ad81b28f600b25d9679f1b0f68b327ac6459 [file] [log] [blame]
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
Tony Mak6c4cc672018-09-17 11:48:50 +010017#include "annotator/annotator.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010018
19#include <fstream>
20#include <iostream>
21#include <memory>
22#include <string>
23
Tony Mak6c4cc672018-09-17 11:48:50 +010024#include "annotator/model_generated.h"
25#include "annotator/types-test-util.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010026#include "gmock/gmock.h"
27#include "gtest/gtest.h"
28
Tony Mak6c4cc672018-09-17 11:48:50 +010029namespace libtextclassifier3 {
Lukas Zilka21d8c982018-01-24 11:11:20 +010030namespace {
31
32using testing::ElementsAreArray;
Lukas Zilkaba849e72018-03-08 14:48:21 +010033using testing::IsEmpty;
Lukas Zilka21d8c982018-01-24 11:11:20 +010034using testing::Pair;
Lukas Zilkab23e2122018-02-09 10:25:19 +010035using testing::Values;
Lukas Zilka21d8c982018-01-24 11:11:20 +010036
Lukas Zilkab23e2122018-02-09 10:25:19 +010037std::string FirstResult(const std::vector<ClassificationResult>& results) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010038 if (results.empty()) {
39 return "<INVALID RESULTS>";
40 }
Lukas Zilkab23e2122018-02-09 10:25:19 +010041 return results[0].collection;
Lukas Zilka21d8c982018-01-24 11:11:20 +010042}
43
44MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") {
45 return testing::Value(arg.span, Pair(start, end)) &&
46 testing::Value(FirstResult(arg.classification), best_class);
47}
48
// Reads the whole file `file_name` into a string and returns it.
//
// The stream is opened in binary mode because the tests use this helper to
// load serialized model buffers; text mode may translate bytes (e.g. line
// endings) on some platforms and corrupt the buffer.
//
// Returns an empty string if the file cannot be opened.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::in | std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream), {});
}
53
54std::string GetModelPath() {
Tony Maka0f598b2018-11-20 20:39:04 +000055 return TC3_TEST_DATA_DIR;
Lukas Zilka21d8c982018-01-24 11:11:20 +010056}
57
Tony Makd9446602019-02-20 18:25:39 +000058// Create fake entity data schema meta data.
59void AddTestEntitySchemaData(ModelT* unpacked_model) {
60 // Cannot use object oriented API here as that is not available for the
61 // reflection schema.
62 flatbuffers::FlatBufferBuilder schema_builder;
63 std::vector<flatbuffers::Offset<reflection::Field>> fields = {
64 reflection::CreateField(
65 schema_builder,
66 /*name=*/schema_builder.CreateString("first_name"),
67 /*type=*/
68 reflection::CreateType(schema_builder,
69 /*base_type=*/reflection::String),
70 /*id=*/0,
71 /*offset=*/4),
72 reflection::CreateField(
73 schema_builder,
74 /*name=*/schema_builder.CreateString("is_alive"),
75 /*type=*/
76 reflection::CreateType(schema_builder,
77 /*base_type=*/reflection::Bool),
78 /*id=*/1,
79 /*offset=*/6),
80 reflection::CreateField(
81 schema_builder,
82 /*name=*/schema_builder.CreateString("last_name"),
83 /*type=*/
84 reflection::CreateType(schema_builder,
85 /*base_type=*/reflection::String),
86 /*id=*/2,
87 /*offset=*/8),
88 };
89 std::vector<flatbuffers::Offset<reflection::Enum>> enums;
90 std::vector<flatbuffers::Offset<reflection::Object>> objects = {
91 reflection::CreateObject(
92 schema_builder,
93 /*name=*/schema_builder.CreateString("EntityData"),
94 /*fields=*/
95 schema_builder.CreateVectorOfSortedTables(&fields))};
96 schema_builder.Finish(reflection::CreateSchema(
97 schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
98 schema_builder.CreateVectorOfSortedTables(&enums),
99 /*(unused) file_ident=*/0,
100 /*(unused) file_ext=*/0,
101 /*root_table*/ objects[0]));
102
103 unpacked_model->entity_data_schema.assign(
104 schema_builder.GetBufferPointer(),
105 schema_builder.GetBufferPointer() + schema_builder.GetSize());
106}
107
Tony Mak6c4cc672018-09-17 11:48:50 +0100108class AnnotatorTest : public ::testing::TestWithParam<const char*> {
109 protected:
110 AnnotatorTest()
111 : INIT_UNILIB_FOR_TESTING(unilib_),
112 INIT_CALENDARLIB_FOR_TESTING(calendarlib_) {}
113 UniLib unilib_;
114 CalendarLib calendarlib_;
115};
116
117TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) {
118 std::unique_ptr<Annotator> classifier = Annotator::FromPath(
119 GetModelPath() + "wrong_embeddings.fb", &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100120 EXPECT_FALSE(classifier);
121}
122
Tony Mak5dc5e112019-02-01 14:52:10 +0000123INSTANTIATE_TEST_SUITE_P(BoundsSensitive, AnnotatorTest,
124 Values("test_model.fb"));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100125
Tony Mak6c4cc672018-09-17 11:48:50 +0100126TEST_P(AnnotatorTest, ClassifyText) {
127 std::unique_ptr<Annotator> classifier =
128 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100129 ASSERT_TRUE(classifier);
130
131 EXPECT_EQ("other",
132 FirstResult(classifier->ClassifyText(
133 "this afternoon Barack Obama gave a speech at", {15, 27})));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100134 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
135 "Call me at (800) 123-456 today", {11, 24})));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100136
137 // More lines.
138 EXPECT_EQ("other",
139 FirstResult(classifier->ClassifyText(
140 "this afternoon Barack Obama gave a speech at|Visit "
141 "www.google.com every today!|Call me at (800) 123-456 today.",
142 {15, 27})));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100143 EXPECT_EQ("phone",
144 FirstResult(classifier->ClassifyText(
145 "this afternoon Barack Obama gave a speech at|Visit "
146 "www.google.com every today!|Call me at (800) 123-456 today.",
147 {90, 103})));
148
149 // Single word.
150 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
151 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
152 EXPECT_EQ("<INVALID RESULTS>",
153 FirstResult(classifier->ClassifyText("asdf", {0, 0})));
154
155 // Junk.
156 EXPECT_EQ("<INVALID RESULTS>",
157 FirstResult(classifier->ClassifyText("", {0, 0})));
158 EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
159 "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200160 // Test invalid utf8 input.
161 EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
162 "\xf0\x9f\x98\x8b\x8b", {0, 0})));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100163}
164
Tony Mak6c4cc672018-09-17 11:48:50 +0100165TEST_P(AnnotatorTest, ClassifyTextDisabledFail) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100166 const std::string test_model = ReadFile(GetModelPath() + GetParam());
167 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
168
169 unpacked_model->classification_model.clear();
170 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
171 unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION;
172
173 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000174 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100175
Tony Mak6c4cc672018-09-17 11:48:50 +0100176 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
177 reinterpret_cast<const char*>(builder.GetBufferPointer()),
178 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100179
180 // The classification model is still needed for selection scores.
181 ASSERT_FALSE(classifier);
182}
183
Tony Mak6c4cc672018-09-17 11:48:50 +0100184TEST_P(AnnotatorTest, ClassifyTextDisabled) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100185 const std::string test_model = ReadFile(GetModelPath() + GetParam());
186 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
187
188 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
189 unpacked_model->triggering_options->enabled_modes =
190 ModeFlag_ANNOTATION_AND_SELECTION;
191
192 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000193 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100194
Tony Mak6c4cc672018-09-17 11:48:50 +0100195 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
196 reinterpret_cast<const char*>(builder.GetBufferPointer()),
197 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100198 ASSERT_TRUE(classifier);
199
200 EXPECT_THAT(
201 classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}),
202 IsEmpty());
203}
204
Tony Mak6c4cc672018-09-17 11:48:50 +0100205TEST_P(AnnotatorTest, ClassifyTextFilteredCollections) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200206 const std::string test_model = ReadFile(GetModelPath() + GetParam());
207
Tony Mak6c4cc672018-09-17 11:48:50 +0100208 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
209 test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200210 ASSERT_TRUE(classifier);
211
212 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
213 "Call me at (800) 123-456 today", {11, 24})));
214
215 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
216 unpacked_model->output_options.reset(new OutputOptionsT);
217
218 // Disable phone classification
219 unpacked_model->output_options->filtered_collections_classification.push_back(
220 "phone");
221
222 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000223 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200224
Tony Mak6c4cc672018-09-17 11:48:50 +0100225 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200226 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +0100227 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200228 ASSERT_TRUE(classifier);
229
230 EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
231 "Call me at (800) 123-456 today", {11, 24})));
232
233 // Check that the address classification still passes.
234 EXPECT_EQ("address", FirstResult(classifier->ClassifyText(
235 "350 Third Street, Cambridge", {0, 27})));
236}
237
Lukas Zilkab23e2122018-02-09 10:25:19 +0100238std::unique_ptr<RegexModel_::PatternT> MakePattern(
239 const std::string& collection_name, const std::string& pattern,
240 const bool enabled_for_classification, const bool enabled_for_selection,
241 const bool enabled_for_annotation, const float score) {
242 std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
243 result->collection_name = collection_name;
244 result->pattern = pattern;
Lukas Zilkaba849e72018-03-08 14:48:21 +0100245 // We cannot directly operate with |= on the flag, so use an int here.
246 int enabled_modes = ModeFlag_NONE;
247 if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
248 if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
249 if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
250 result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100251 result->target_classification_score = score;
252 result->priority_score = score;
253 return result;
254}
255
Tony Maka0f598b2018-11-20 20:39:04 +0000256#ifdef TC3_UNILIB_ICU
Tony Mak6c4cc672018-09-17 11:48:50 +0100257TEST_P(AnnotatorTest, ClassifyTextRegularExpression) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100258 const std::string test_model = ReadFile(GetModelPath() + GetParam());
259 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
260
261 // Add test regex models.
262 unpacked_model->regex_model->patterns.push_back(MakePattern(
263 "person", "Barack Obama", /*enabled_for_classification=*/true,
264 /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
265 unpacked_model->regex_model->patterns.push_back(MakePattern(
266 "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
267 /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));
Tony Mak6c4cc672018-09-17 11:48:50 +0100268 std::unique_ptr<RegexModel_::PatternT> verified_pattern =
269 MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}",
270 /*enabled_for_classification=*/true,
271 /*enabled_for_selection=*/false,
272 /*enabled_for_annotation=*/false, 1.0);
273 verified_pattern->verification_options.reset(new VerificationOptionsT);
274 verified_pattern->verification_options->verify_luhn_checksum = true;
275 unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100276
277 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000278 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100279
Tony Mak6c4cc672018-09-17 11:48:50 +0100280 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
281 reinterpret_cast<const char*>(builder.GetBufferPointer()),
282 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100283 ASSERT_TRUE(classifier);
284
285 EXPECT_EQ("flight",
286 FirstResult(classifier->ClassifyText(
287 "Your flight LX373 is delayed by 3 hours.", {12, 17})));
288 EXPECT_EQ("person",
289 FirstResult(classifier->ClassifyText(
290 "this afternoon Barack Obama gave a speech at", {15, 27})));
291 EXPECT_EQ("email",
292 FirstResult(classifier->ClassifyText("you@android.com", {0, 15})));
293 EXPECT_EQ("email", FirstResult(classifier->ClassifyText(
294 "Contact me at you@android.com", {14, 29})));
295
296 EXPECT_EQ("url", FirstResult(classifier->ClassifyText(
297 "Visit www.google.com every today!", {6, 20})));
298
299 EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("LX 37", {0, 5})));
300 EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd",
301 {7, 12})));
Tony Mak6c4cc672018-09-17 11:48:50 +0100302 EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
303 "cc: 4012 8888 8888 1881", {4, 23})));
304 EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
305 "2221 0067 4735 6281", {0, 19})));
306 // Luhn check fails.
307 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("2221 0067 4735 6282",
308 {0, 19})));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100309
310 // More lines.
311 EXPECT_EQ("url",
312 FirstResult(classifier->ClassifyText(
313 "this afternoon Barack Obama gave a speech at|Visit "
314 "www.google.com every today!|Call me at (800) 123-456 today.",
315 {51, 65})));
316}
Tony Makd9446602019-02-20 18:25:39 +0000317
318TEST_P(AnnotatorTest, ClassifyTextRegularExpressionEntityData) {
319 const std::string test_model = ReadFile(GetModelPath() + GetParam());
320 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
321
322 // Add fake entity schema metadata.
323 AddTestEntitySchemaData(unpacked_model.get());
324
325 // Add test regex models.
326 unpacked_model->regex_model->patterns.push_back(MakePattern(
327 "person", "(Barack) (Obama)", /*enabled_for_classification=*/true,
328 /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
329
330 // Use meta data to generate custom serialized entity data.
331 ReflectiveFlatbufferBuilder entity_data_builder(
332 flatbuffers::GetRoot<reflection::Schema>(
333 unpacked_model->entity_data_schema.data()));
334 std::unique_ptr<ReflectiveFlatbuffer> entity_data =
335 entity_data_builder.NewRoot();
336 entity_data->Set("is_alive", true);
337
338 RegexModel_::PatternT* pattern =
339 unpacked_model->regex_model->patterns.back().get();
340 pattern->serialized_entity_data = entity_data->Serialize();
341 pattern->capturing_group.emplace_back(
342 new RegexModel_::Pattern_::CapturingGroupT);
343 pattern->capturing_group.emplace_back(
344 new RegexModel_::Pattern_::CapturingGroupT);
345 pattern->capturing_group.emplace_back(
346 new RegexModel_::Pattern_::CapturingGroupT);
347 // Group 0 is the full match, capturing groups starting at 1.
348 pattern->capturing_group[1]->entity_field_path.reset(
349 new FlatbufferFieldPathT);
350 pattern->capturing_group[1]->entity_field_path->field.emplace_back(
351 new FlatbufferFieldT);
352 pattern->capturing_group[1]->entity_field_path->field.back()->field_name =
353 "first_name";
354 pattern->capturing_group[2]->entity_field_path.reset(
355 new FlatbufferFieldPathT);
356 pattern->capturing_group[2]->entity_field_path->field.emplace_back(
357 new FlatbufferFieldT);
358 pattern->capturing_group[2]->entity_field_path->field.back()->field_name =
359 "last_name";
360
361 flatbuffers::FlatBufferBuilder builder;
362 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
363
364 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
365 reinterpret_cast<const char*>(builder.GetBufferPointer()),
366 builder.GetSize(), &unilib_, &calendarlib_);
367 ASSERT_TRUE(classifier);
368
369 auto classifications = classifier->ClassifyText(
370 "this afternoon Barack Obama gave a speech at", {15, 27});
371 EXPECT_EQ(1, classifications.size());
372 EXPECT_EQ("person", classifications[0].collection);
373
374 // Check entity data.
375 const flatbuffers::Table* entity =
376 flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
377 classifications[0].serialized_entity_data.data()));
378 EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
379 "Barack");
380 EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
381 "Obama");
382 EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
383}
Tony Maka0f598b2018-11-20 20:39:04 +0000384#endif // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100385
Tony Maka0f598b2018-11-20 20:39:04 +0000386#ifdef TC3_UNILIB_ICU
Tony Mak6c4cc672018-09-17 11:48:50 +0100387TEST_P(AnnotatorTest, SuggestSelectionRegularExpression) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100388 const std::string test_model = ReadFile(GetModelPath() + GetParam());
389 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
390
391 // Add test regex models.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100392 unpacked_model->regex_model->patterns.push_back(MakePattern(
393 "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
394 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
395 unpacked_model->regex_model->patterns.push_back(MakePattern(
396 "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
397 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
398 unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
Tony Mak6c4cc672018-09-17 11:48:50 +0100399 std::unique_ptr<RegexModel_::PatternT> verified_pattern =
400 MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
401 /*enabled_for_classification=*/false,
402 /*enabled_for_selection=*/true,
403 /*enabled_for_annotation=*/false, 1.0);
404 verified_pattern->verification_options.reset(new VerificationOptionsT);
405 verified_pattern->verification_options->verify_luhn_checksum = true;
406 unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100407
408 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000409 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100410
Tony Mak6c4cc672018-09-17 11:48:50 +0100411 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
412 reinterpret_cast<const char*>(builder.GetBufferPointer()),
413 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100414 ASSERT_TRUE(classifier);
415
416 // Check regular expression selection.
417 EXPECT_EQ(classifier->SuggestSelection(
418 "Your flight MA 0123 is delayed by 3 hours.", {12, 14}),
419 std::make_pair(12, 19));
420 EXPECT_EQ(classifier->SuggestSelection(
421 "this afternoon Barack Obama gave a speech at", {15, 21}),
422 std::make_pair(15, 27));
Tony Mak6c4cc672018-09-17 11:48:50 +0100423 EXPECT_EQ(classifier->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}),
424 std::make_pair(4, 23));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100425}
Tony Mak854015a2019-01-16 15:56:48 +0000426
427TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionCustomSelectionBounds) {
428 const std::string test_model = ReadFile(GetModelPath() + GetParam());
429 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
430
431 // Add test regex models.
432 std::unique_ptr<RegexModel_::PatternT> custom_selection_bounds_pattern =
433 MakePattern("date_range",
434 "(?:(?:from )?(\\d{2}\\/\\d{2}\\/\\d{4}) to "
435 "(\\d{2}\\/\\d{2}\\/\\d{4}))|(for ever)",
436 /*enabled_for_classification=*/false,
437 /*enabled_for_selection=*/true,
438 /*enabled_for_annotation=*/false, 1.0);
439 custom_selection_bounds_pattern->capturing_group.emplace_back(
440 new RegexModel_::Pattern_::CapturingGroupT);
441 custom_selection_bounds_pattern->capturing_group.emplace_back(
442 new RegexModel_::Pattern_::CapturingGroupT);
443 custom_selection_bounds_pattern->capturing_group.emplace_back(
444 new RegexModel_::Pattern_::CapturingGroupT);
445 custom_selection_bounds_pattern->capturing_group.emplace_back(
446 new RegexModel_::Pattern_::CapturingGroupT);
447 custom_selection_bounds_pattern->capturing_group[0]->extend_selection = false;
448 custom_selection_bounds_pattern->capturing_group[1]->extend_selection = true;
449 custom_selection_bounds_pattern->capturing_group[2]->extend_selection = true;
450 custom_selection_bounds_pattern->capturing_group[3]->extend_selection = true;
451 unpacked_model->regex_model->patterns.push_back(
452 std::move(custom_selection_bounds_pattern));
453
454 flatbuffers::FlatBufferBuilder builder;
455 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
456
457 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
458 reinterpret_cast<const char*>(builder.GetBufferPointer()),
459 builder.GetSize(), &unilib_, &calendarlib_);
460 ASSERT_TRUE(classifier);
461
462 // Check regular expression selection.
463 EXPECT_EQ(classifier->SuggestSelection("it's from 04/30/1789 to 03/04/1797",
464 {21, 23}),
465 std::make_pair(10, 34));
466 EXPECT_EQ(classifier->SuggestSelection("it takes for ever", {9, 12}),
467 std::make_pair(9, 17));
468}
Tony Maka0f598b2018-11-20 20:39:04 +0000469#endif // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100470
Tony Maka0f598b2018-11-20 20:39:04 +0000471#ifdef TC3_UNILIB_ICU
Tony Mak6c4cc672018-09-17 11:48:50 +0100472TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100473 const std::string test_model = ReadFile(GetModelPath() + GetParam());
474 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
475
476 // Add test regex models.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100477 unpacked_model->regex_model->patterns.push_back(MakePattern(
478 "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
479 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
480 unpacked_model->regex_model->patterns.push_back(MakePattern(
481 "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
482 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
483 unpacked_model->regex_model->patterns.back()->priority_score = 0.5;
484
485 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000486 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100487
Tony Mak6c4cc672018-09-17 11:48:50 +0100488 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
489 reinterpret_cast<const char*>(builder.GetBufferPointer()),
490 builder.GetSize());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100491 ASSERT_TRUE(classifier);
492
493 // Check conflict resolution.
494 EXPECT_EQ(
495 classifier->SuggestSelection(
496 "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
497 {55, 57}),
498 std::make_pair(26, 62));
499}
Tony Maka0f598b2018-11-20 20:39:04 +0000500#endif // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100501
Tony Maka0f598b2018-11-20 20:39:04 +0000502#ifdef TC3_UNILIB_ICU
Tony Mak6c4cc672018-09-17 11:48:50 +0100503TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100504 const std::string test_model = ReadFile(GetModelPath() + GetParam());
505 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
506
507 // Add test regex models.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100508 unpacked_model->regex_model->patterns.push_back(MakePattern(
509 "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
510 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
511 unpacked_model->regex_model->patterns.push_back(MakePattern(
512 "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
513 /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
514 unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
515
516 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000517 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100518
Tony Mak6c4cc672018-09-17 11:48:50 +0100519 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
520 reinterpret_cast<const char*>(builder.GetBufferPointer()),
521 builder.GetSize());
Lukas Zilkab23e2122018-02-09 10:25:19 +0100522 ASSERT_TRUE(classifier);
523
524 // Check conflict resolution.
525 EXPECT_EQ(
526 classifier->SuggestSelection(
527 "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
528 {55, 57}),
529 std::make_pair(55, 62));
530}
Tony Maka0f598b2018-11-20 20:39:04 +0000531#endif // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100532
Tony Maka0f598b2018-11-20 20:39:04 +0000533#ifdef TC3_UNILIB_ICU
Tony Mak6c4cc672018-09-17 11:48:50 +0100534TEST_P(AnnotatorTest, AnnotateRegex) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100535 const std::string test_model = ReadFile(GetModelPath() + GetParam());
536 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
537
538 // Add test regex models.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100539 unpacked_model->regex_model->patterns.push_back(MakePattern(
540 "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
541 /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
542 unpacked_model->regex_model->patterns.push_back(MakePattern(
543 "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
544 /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));
Tony Mak6c4cc672018-09-17 11:48:50 +0100545 std::unique_ptr<RegexModel_::PatternT> verified_pattern =
546 MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
547 /*enabled_for_classification=*/false,
548 /*enabled_for_selection=*/false,
549 /*enabled_for_annotation=*/true, 1.0);
550 verified_pattern->verification_options.reset(new VerificationOptionsT);
551 verified_pattern->verification_options->verify_luhn_checksum = true;
552 unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100553 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000554 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100555
Tony Mak6c4cc672018-09-17 11:48:50 +0100556 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
557 reinterpret_cast<const char*>(builder.GetBufferPointer()),
558 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100559 ASSERT_TRUE(classifier);
560
561 const std::string test_string =
562 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
Tony Mak6c4cc672018-09-17 11:48:50 +0100563 "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n";
Lukas Zilkab23e2122018-02-09 10:25:19 +0100564 EXPECT_THAT(classifier->Annotate(test_string),
Tony Mak6c4cc672018-09-17 11:48:50 +0100565 ElementsAreArray({IsAnnotatedSpan(6, 18, "person"),
566 IsAnnotatedSpan(28, 55, "address"),
567 IsAnnotatedSpan(79, 91, "phone"),
568 IsAnnotatedSpan(107, 126, "payment_card")}));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100569}
Tony Maka0f598b2018-11-20 20:39:04 +0000570#endif // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100571
Tony Mak6c4cc672018-09-17 11:48:50 +0100572TEST_P(AnnotatorTest, PhoneFiltering) {
573 std::unique_ptr<Annotator> classifier =
574 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100575 ASSERT_TRUE(classifier);
576
577 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
578 "phone: (123) 456 789", {7, 20})));
579 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
580 "phone: (123) 456 789,0001112", {7, 25})));
581 EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
582 "phone: (123) 456 789,0001112", {7, 28})));
583}
584
Tony Mak6c4cc672018-09-17 11:48:50 +0100585TEST_P(AnnotatorTest, SuggestSelection) {
586 std::unique_ptr<Annotator> classifier =
587 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100588 ASSERT_TRUE(classifier);
589
590 EXPECT_EQ(classifier->SuggestSelection(
591 "this afternoon Barack Obama gave a speech at", {15, 21}),
592 std::make_pair(15, 21));
593
594 // Try passing whole string.
595 // If more than 1 token is specified, we should return back what entered.
596 EXPECT_EQ(
597 classifier->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
598 std::make_pair(0, 27));
599
600 // Single letter.
601 EXPECT_EQ(classifier->SuggestSelection("a", {0, 1}), std::make_pair(0, 1));
602
603 // Single word.
604 EXPECT_EQ(classifier->SuggestSelection("asdf", {0, 4}), std::make_pair(0, 4));
605
606 EXPECT_EQ(
607 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
608 std::make_pair(11, 23));
609
610 // Unpaired bracket stripping.
611 EXPECT_EQ(
612 classifier->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
613 std::make_pair(11, 25));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100614 EXPECT_EQ(classifier->SuggestSelection("call me at (857 today", {11, 15}),
615 std::make_pair(12, 15));
616 EXPECT_EQ(classifier->SuggestSelection("call me at 3556) today", {11, 16}),
617 std::make_pair(11, 15));
618 EXPECT_EQ(classifier->SuggestSelection("call me at )857( today", {11, 16}),
619 std::make_pair(12, 15));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100620
621 // If the resulting selection would be empty, the original span is returned.
622 EXPECT_EQ(classifier->SuggestSelection("call me at )( today", {11, 13}),
623 std::make_pair(11, 13));
624 EXPECT_EQ(classifier->SuggestSelection("call me at ( today", {11, 12}),
625 std::make_pair(11, 12));
626 EXPECT_EQ(classifier->SuggestSelection("call me at ) today", {11, 12}),
627 std::make_pair(11, 12));
628}
629
Tony Mak6c4cc672018-09-17 11:48:50 +0100630TEST_P(AnnotatorTest, SuggestSelectionDisabledFail) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100631 const std::string test_model = ReadFile(GetModelPath() + GetParam());
632 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
633
634 // Disable the selection model.
635 unpacked_model->selection_model.clear();
636 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
637 unpacked_model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;
638
639 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000640 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100641
Tony Mak6c4cc672018-09-17 11:48:50 +0100642 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
643 reinterpret_cast<const char*>(builder.GetBufferPointer()),
644 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100645 // Selection model needs to be present for annotation.
646 ASSERT_FALSE(classifier);
647}
648
Tony Mak6c4cc672018-09-17 11:48:50 +0100649TEST_P(AnnotatorTest, SuggestSelectionDisabled) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100650 const std::string test_model = ReadFile(GetModelPath() + GetParam());
651 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
652
653 // Disable the selection model.
654 unpacked_model->selection_model.clear();
655 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
656 unpacked_model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
657 unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION;
658
659 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000660 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100661
Tony Mak6c4cc672018-09-17 11:48:50 +0100662 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
663 reinterpret_cast<const char*>(builder.GetBufferPointer()),
664 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100665 ASSERT_TRUE(classifier);
666
667 EXPECT_EQ(
668 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
669 std::make_pair(11, 14));
670
671 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
672 "call me at (800) 123-456 today", {11, 24})));
673
674 EXPECT_THAT(classifier->Annotate("call me at (800) 123-456 today"),
675 IsEmpty());
676}
677
Tony Mak6c4cc672018-09-17 11:48:50 +0100678TEST_P(AnnotatorTest, SuggestSelectionFilteredCollections) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200679 const std::string test_model = ReadFile(GetModelPath() + GetParam());
680
Tony Mak6c4cc672018-09-17 11:48:50 +0100681 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
682 test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200683 ASSERT_TRUE(classifier);
684
685 EXPECT_EQ(
686 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
687 std::make_pair(11, 23));
688
689 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
690 unpacked_model->output_options.reset(new OutputOptionsT);
691
692 // Disable phone selection
693 unpacked_model->output_options->filtered_collections_selection.push_back(
694 "phone");
695 // We need to force this for filtering.
696 unpacked_model->selection_options->always_classify_suggested_selection = true;
697
698 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000699 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200700
Tony Mak6c4cc672018-09-17 11:48:50 +0100701 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200702 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +0100703 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200704 ASSERT_TRUE(classifier);
705
706 EXPECT_EQ(
707 classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
708 std::make_pair(11, 14));
709
710 // Address selection should still work.
711 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
712 std::make_pair(0, 27));
713}
714
Tony Mak6c4cc672018-09-17 11:48:50 +0100715TEST_P(AnnotatorTest, SuggestSelectionsAreSymmetric) {
716 std::unique_ptr<Annotator> classifier =
717 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100718 ASSERT_TRUE(classifier);
719
720 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {0, 3}),
721 std::make_pair(0, 27));
722 EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
723 std::make_pair(0, 27));
724 EXPECT_EQ(
725 classifier->SuggestSelection("350 Third Street, Cambridge", {10, 16}),
726 std::make_pair(0, 27));
727 EXPECT_EQ(classifier->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
728 {16, 22}),
729 std::make_pair(6, 33));
730}
731
Tony Mak6c4cc672018-09-17 11:48:50 +0100732TEST_P(AnnotatorTest, SuggestSelectionWithNewLine) {
733 std::unique_ptr<Annotator> classifier =
734 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100735 ASSERT_TRUE(classifier);
736
737 EXPECT_EQ(classifier->SuggestSelection("abc\n857 225 3556", {4, 7}),
738 std::make_pair(4, 16));
739 EXPECT_EQ(classifier->SuggestSelection("857 225 3556\nabc", {0, 3}),
740 std::make_pair(0, 12));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100741
742 SelectionOptions options;
743 EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options),
Tony Makd9446602019-02-20 18:25:39 +0000744 std::make_pair(0, 12));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100745}
746
Tony Mak6c4cc672018-09-17 11:48:50 +0100747TEST_P(AnnotatorTest, SuggestSelectionWithPunctuation) {
748 std::unique_ptr<Annotator> classifier =
749 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100750 ASSERT_TRUE(classifier);
751
752 // From the right.
753 EXPECT_EQ(classifier->SuggestSelection(
754 "this afternoon BarackObama, gave a speech at", {15, 26}),
755 std::make_pair(15, 26));
756
757 // From the right multiple.
758 EXPECT_EQ(classifier->SuggestSelection(
759 "this afternoon BarackObama,.,.,, gave a speech at", {15, 26}),
760 std::make_pair(15, 26));
761
762 // From the left multiple.
763 EXPECT_EQ(classifier->SuggestSelection(
764 "this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}),
765 std::make_pair(21, 32));
766
767 // From both sides.
768 EXPECT_EQ(classifier->SuggestSelection(
769 "this afternoon !BarackObama,- gave a speech at", {16, 27}),
770 std::make_pair(16, 27));
771}
772
Tony Mak6c4cc672018-09-17 11:48:50 +0100773TEST_P(AnnotatorTest, SuggestSelectionNoCrashWithJunk) {
774 std::unique_ptr<Annotator> classifier =
775 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100776 ASSERT_TRUE(classifier);
777
778 // Try passing in bunch of invalid selections.
779 EXPECT_EQ(classifier->SuggestSelection("", {0, 27}), std::make_pair(0, 27));
780 EXPECT_EQ(classifier->SuggestSelection("", {-10, 27}),
781 std::make_pair(-10, 27));
782 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {0, 27}),
783 std::make_pair(0, 27));
784 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-30, 300}),
785 std::make_pair(-30, 300));
786 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {-10, -1}),
787 std::make_pair(-10, -1));
788 EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}),
789 std::make_pair(100, 17));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200790
791 // Try passing invalid utf8.
792 EXPECT_EQ(classifier->SuggestSelection("\xf0\x9f\x98\x8b\x8b", {-1, -1}),
793 std::make_pair(-1, -1));
794}
795
Tony Mak6c4cc672018-09-17 11:48:50 +0100796TEST_P(AnnotatorTest, SuggestSelectionSelectSpace) {
797 std::unique_ptr<Annotator> classifier =
798 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200799 ASSERT_TRUE(classifier);
800
801 EXPECT_EQ(
802 classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}),
803 std::make_pair(11, 23));
804 EXPECT_EQ(
805 classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}),
806 std::make_pair(10, 11));
807 EXPECT_EQ(
808 classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}),
809 std::make_pair(23, 24));
810 EXPECT_EQ(
811 classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}),
812 std::make_pair(23, 24));
813 EXPECT_EQ(classifier->SuggestSelection("call me at 857 225 3556, today",
814 {14, 17}),
815 std::make_pair(11, 25));
816 EXPECT_EQ(
817 classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}),
818 std::make_pair(11, 23));
819 EXPECT_EQ(
820 classifier->SuggestSelection(
821 "let's meet at 350 Third Street Cambridge and go there", {30, 31}),
822 std::make_pair(14, 40));
823 EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}),
824 std::make_pair(4, 5));
825 EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}),
826 std::make_pair(7, 8));
827
828 // With a punctuation around the selected whitespace.
829 EXPECT_EQ(
830 classifier->SuggestSelection(
831 "let's meet at 350 Third Street, Cambridge and go there", {31, 32}),
832 std::make_pair(14, 41));
833
834 // When all's whitespace, should return the original indices.
835 EXPECT_EQ(classifier->SuggestSelection(" ", {0, 1}),
836 std::make_pair(0, 1));
837 EXPECT_EQ(classifier->SuggestSelection(" ", {0, 3}),
838 std::make_pair(0, 3));
839 EXPECT_EQ(classifier->SuggestSelection(" ", {2, 3}),
840 std::make_pair(2, 3));
841 EXPECT_EQ(classifier->SuggestSelection(" ", {5, 6}),
842 std::make_pair(5, 6));
843}
844
Tony Mak6c4cc672018-09-17 11:48:50 +0100845TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200846 UnicodeText text;
847
848 text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100849 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200850 std::make_pair(3, 4));
851 text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100852 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200853 std::make_pair(3, 4));
854
855 // Nothing on the left.
856 text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100857 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200858 std::make_pair(4, 5));
859 text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100860 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200861 std::make_pair(0, 1));
862
863 // Whitespace only.
864 text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100865 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200866 std::make_pair(2, 3));
867 text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100868 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200869 std::make_pair(4, 5));
870 text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
Tony Mak6c4cc672018-09-17 11:48:50 +0100871 EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib_),
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200872 std::make_pair(0, 1));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100873}
874
Tony Mak6c4cc672018-09-17 11:48:50 +0100875TEST_P(AnnotatorTest, Annotate) {
876 std::unique_ptr<Annotator> classifier =
877 Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100878 ASSERT_TRUE(classifier);
879
880 const std::string test_string =
Lukas Zilkab23e2122018-02-09 10:25:19 +0100881 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
882 "number is 853 225 3556";
Lukas Zilka21d8c982018-01-24 11:11:20 +0100883 EXPECT_THAT(classifier->Annotate(test_string),
884 ElementsAreArray({
Lukas Zilkab23e2122018-02-09 10:25:19 +0100885 IsAnnotatedSpan(28, 55, "address"),
886 IsAnnotatedSpan(79, 91, "phone"),
Lukas Zilka21d8c982018-01-24 11:11:20 +0100887 }));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100888
889 AnnotationOptions options;
890 EXPECT_THAT(classifier->Annotate("853 225 3556", options),
891 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
Tony Makd9446602019-02-20 18:25:39 +0000892 EXPECT_THAT(classifier->Annotate("853 225\n3556", options),
893 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200894 // Try passing invalid utf8.
895 EXPECT_TRUE(
896 classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options)
897 .empty());
Lukas Zilka21d8c982018-01-24 11:11:20 +0100898}
899
Tony Maka0f598b2018-11-20 20:39:04 +0000900
Tony Mak6c4cc672018-09-17 11:48:50 +0100901TEST_P(AnnotatorTest, AnnotateSmallBatches) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100902 const std::string test_model = ReadFile(GetModelPath() + GetParam());
903 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
904
905 // Set the batch size.
906 unpacked_model->selection_options->batch_size = 4;
907 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000908 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100909
Tony Mak6c4cc672018-09-17 11:48:50 +0100910 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
911 reinterpret_cast<const char*>(builder.GetBufferPointer()),
912 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100913 ASSERT_TRUE(classifier);
914
915 const std::string test_string =
916 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
917 "number is 853 225 3556";
918 EXPECT_THAT(classifier->Annotate(test_string),
919 ElementsAreArray({
Lukas Zilkab23e2122018-02-09 10:25:19 +0100920 IsAnnotatedSpan(28, 55, "address"),
921 IsAnnotatedSpan(79, 91, "phone"),
922 }));
923
924 AnnotationOptions options;
925 EXPECT_THAT(classifier->Annotate("853 225 3556", options),
926 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
Tony Makd9446602019-02-20 18:25:39 +0000927 EXPECT_THAT(classifier->Annotate("853 225\n3556", options),
928 ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100929}
930
#ifdef TC3_UNILIB_ICU
// Sets min_annotate_confidence to 2.0 and expects every annotation to be
// discarded.
TEST_P(AnnotatorTest, AnnotateFilteringDiscardAll) {
  const std::string model_bytes = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());

  // Install the test threshold; 2.0 discards all results.
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->min_annotate_confidence = 2.f;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const char* const packed =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed, fbb.GetSize(), &unilib_, &calendarlib_);
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  EXPECT_THAT(annotator->Annotate(text), IsEmpty());
}
#endif  // TC3_UNILIB_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +0100955
Tony Mak6c4cc672018-09-17 11:48:50 +0100956TEST_P(AnnotatorTest, AnnotateFilteringKeepAll) {
Lukas Zilkab23e2122018-02-09 10:25:19 +0100957 const std::string test_model = ReadFile(GetModelPath() + GetParam());
958 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
959
960 // Add test thresholds.
961 unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
962 unpacked_model->triggering_options->min_annotate_confidence =
963 0.f; // Keeps all results.
Lukas Zilkaba849e72018-03-08 14:48:21 +0100964 unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100965 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000966 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkab23e2122018-02-09 10:25:19 +0100967
Tony Mak6c4cc672018-09-17 11:48:50 +0100968 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
969 reinterpret_cast<const char*>(builder.GetBufferPointer()),
970 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100971 ASSERT_TRUE(classifier);
972
973 const std::string test_string =
974 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
975 "number is 853 225 3556";
Lukas Zilkab23e2122018-02-09 10:25:19 +0100976 EXPECT_EQ(classifier->Annotate(test_string).size(), 2);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100977}
978
Tony Mak6c4cc672018-09-17 11:48:50 +0100979TEST_P(AnnotatorTest, AnnotateDisabled) {
Lukas Zilkaba849e72018-03-08 14:48:21 +0100980 const std::string test_model = ReadFile(GetModelPath() + GetParam());
981 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
982
983 // Disable the model for annotation.
984 unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;
985 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +0000986 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkaba849e72018-03-08 14:48:21 +0100987
Tony Mak6c4cc672018-09-17 11:48:50 +0100988 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
989 reinterpret_cast<const char*>(builder.GetBufferPointer()),
990 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkaba849e72018-03-08 14:48:21 +0100991 ASSERT_TRUE(classifier);
992 const std::string test_string =
993 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
994 "number is 853 225 3556";
995 EXPECT_THAT(classifier->Annotate(test_string), IsEmpty());
996}
997
Tony Mak6c4cc672018-09-17 11:48:50 +0100998TEST_P(AnnotatorTest, AnnotateFilteredCollections) {
Lukas Zilkae7962cc2018-03-28 18:09:48 +0200999 const std::string test_model = ReadFile(GetModelPath() + GetParam());
1000
Tony Mak6c4cc672018-09-17 11:48:50 +01001001 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
1002 test_model.c_str(), test_model.size(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001003 ASSERT_TRUE(classifier);
1004
1005 const std::string test_string =
1006 "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
1007 "number is 853 225 3556";
1008
1009 EXPECT_THAT(classifier->Annotate(test_string),
1010 ElementsAreArray({
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001011 IsAnnotatedSpan(28, 55, "address"),
1012 IsAnnotatedSpan(79, 91, "phone"),
1013 }));
1014
1015 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
1016 unpacked_model->output_options.reset(new OutputOptionsT);
1017
1018 // Disable phone annotation
1019 unpacked_model->output_options->filtered_collections_annotation.push_back(
1020 "phone");
1021
1022 flatbuffers::FlatBufferBuilder builder;
Tony Mak51a9e542018-11-02 13:36:22 +00001023 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001024
Tony Mak6c4cc672018-09-17 11:48:50 +01001025 classifier = Annotator::FromUnownedBuffer(
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001026 reinterpret_cast<const char*>(builder.GetBufferPointer()),
Tony Mak6c4cc672018-09-17 11:48:50 +01001027 builder.GetSize(), &unilib_, &calendarlib_);
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001028 ASSERT_TRUE(classifier);
1029
1030 EXPECT_THAT(classifier->Annotate(test_string),
1031 ElementsAreArray({
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001032 IsAnnotatedSpan(28, 55, "address"),
1033 }));
1034}
1035
#ifdef TC3_UNILIB_ICU
// Adds a high-scoring "suppress" regex pattern that overlaps the phone number
// and filters the "suppress" collection from annotation; afterwards only the
// address annotation is expected.
TEST_P(AnnotatorTest, AnnotateFilteredCollectionsSuppress) {
  const std::string model_bytes = ReadFile(GetModelPath() + GetParam());
  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // Baseline: the unmodified model reports both entities.
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      model_bytes.c_str(), model_bytes.size(), &unilib_, &calendarlib_);
  ASSERT_TRUE(annotator);
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
  model->output_options.reset(new OutputOptionsT);

  // Filter the "suppress" collection. The custom pattern added next wins
  // against the phone classification and is then suppressed by this filter.
  model->output_options->filtered_collections_annotation.push_back("suppress");
  model->regex_model->patterns.push_back(MakePattern(
      "suppress", "(\\d{3} ?\\d{4})",
      /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const char* const packed =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  annotator = Annotator::FromUnownedBuffer(packed, fbb.GetSize(), &unilib_,
                                           &calendarlib_);
  ASSERT_TRUE(annotator);

  // Only the address is left.
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
              }));
}
#endif  // TC3_UNILIB_ICU
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001081
#ifdef TC3_CALENDAR_ICU
// Classifies several date and datetime strings under different reference
// timezones and checks collection, UTC timestamp and granularity of each
// result.
TEST_P(AnnotatorTest, ClassifyTextDate) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetModelPath() + GetParam());
  // Fatal assertion: the pointer is dereferenced unconditionally below, so a
  // failed load must stop the test instead of crashing it.
  ASSERT_TRUE(classifier);

  std::vector<ClassificationResult> result;
  ClassificationOptions options;

  options.reference_timezone = "Europe/Zurich";
  result = classifier->ClassifyText("january 1, 2017", {0, 15}, options);

  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
  result.clear();

  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("march 1, 2017", {0, 13}, options);
  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1488355200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
  result.clear();

  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("2018/01/01 10:30:20", {0, 19}, options);
  ASSERT_EQ(result.size(), 2);  // Has 2 interpretations - a.m. or p.m.
  EXPECT_THAT(result[0].collection, "datetime");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514831420000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_SECOND);
  EXPECT_THAT(result[1].collection, "datetime");
  EXPECT_EQ(result[1].datetime_parse_result.time_ms_utc, 1514874620000);
  EXPECT_EQ(result[1].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_SECOND);
  result.clear();

  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("2018/01/01 22:00", {0, 16}, options);
  ASSERT_EQ(result.size(), 1);  // Has only 1 interpretation - 10 p.m.
  EXPECT_THAT(result[0].collection, "datetime");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514872800000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_MINUTE);
  result.clear();

  // Date on another line.
  options.reference_timezone = "Europe/Zurich";
  result = classifier->ClassifyText(
      "hello world this is the first line\n"
      "january 1, 2017",
      {35, 50}, options);
  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
}
#endif  // TC3_CALENDAR_ICU
Lukas Zilkaba849e72018-03-08 14:48:21 +01001144
#ifdef TC3_CALENDAR_ICU
// Classifies the ambiguous date "03.05.1970" under the "en-US" and "de"
// locales and checks that each locale yields its own expected timestamp.
TEST_P(AnnotatorTest, ClassifyTextDatePriorities) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetModelPath() + GetParam());
  // Fatal assertion: the pointer is dereferenced unconditionally below, so a
  // failed load must stop the test instead of crashing it.
  ASSERT_TRUE(classifier);

  std::vector<ClassificationResult> result;
  ClassificationOptions options;

  result.clear();
  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  result = classifier->ClassifyText("03.05.1970", {0, 10}, options);

  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 5439600000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);

  result.clear();
  options.reference_timezone = "Europe/Zurich";
  options.locales = "de";
  result = classifier->ClassifyText("03.05.1970", {0, 10}, options);

  ASSERT_EQ(result.size(), 1);
  EXPECT_THAT(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 10537200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
}
#endif  // TC3_CALENDAR_ICU
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001177
#ifdef TC3_CALENDAR_ICU
// Restricts every datetime pattern to annotation and classification, then
// checks that classification and annotation still report the date while
// selection returns the clicked span unchanged.
TEST_P(AnnotatorTest, SuggestTextDateDisabled) {
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Disable the patterns for selection. Range-for avoids comparing a signed
  // index against the container's unsigned size.
  for (auto& pattern : unpacked_model->datetime_model->patterns) {
    pattern->enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION;
  }
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_, &calendarlib_);
  ASSERT_TRUE(classifier);

  // Classification is still enabled for the datetime patterns.
  EXPECT_EQ("date",
            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
  // Selection is disabled: the clicked span is not expanded.
  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
            std::make_pair(0, 7));
  // Annotation is still enabled for the datetime patterns.
  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
}
#endif  // TC3_CALENDAR_ICU
Lukas Zilkab23e2122018-02-09 10:25:19 +01001203
Tony Mak6c4cc672018-09-17 11:48:50 +01001204class TestingAnnotator : public Annotator {
Lukas Zilkab23e2122018-02-09 10:25:19 +01001205 public:
Tony Mak6c4cc672018-09-17 11:48:50 +01001206 TestingAnnotator(const std::string& model, const UniLib* unilib,
1207 const CalendarLib* calendarlib)
Tony Mak854015a2019-01-16 15:56:48 +00001208 : Annotator(libtextclassifier3::ViewModel(model.data(), model.size()),
1209 unilib, calendarlib) {}
Lukas Zilkab23e2122018-02-09 10:25:19 +01001210
Tony Mak6c4cc672018-09-17 11:48:50 +01001211 using Annotator::ResolveConflicts;
Lukas Zilkab23e2122018-02-09 10:25:19 +01001212};
1213
1214AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
1215 const std::string& collection,
1216 const float score) {
1217 AnnotatedSpan result;
1218 result.span = span;
1219 result.classification.push_back({collection, score});
1220 return result;
1221}
1222
Tony Mak6c4cc672018-09-17 11:48:50 +01001223TEST_F(AnnotatorTest, ResolveConflictsTrivial) {
1224 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001225
1226 std::vector<AnnotatedSpan> candidates{
1227 {MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
1228
1229 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001230 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001231 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001232 EXPECT_THAT(chosen, ElementsAreArray({0}));
1233}
1234
Tony Mak6c4cc672018-09-17 11:48:50 +01001235TEST_F(AnnotatorTest, ResolveConflictsSequence) {
1236 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001237
1238 std::vector<AnnotatedSpan> candidates{{
1239 MakeAnnotatedSpan({0, 1}, "phone", 1.0),
1240 MakeAnnotatedSpan({1, 2}, "phone", 1.0),
1241 MakeAnnotatedSpan({2, 3}, "phone", 1.0),
1242 MakeAnnotatedSpan({3, 4}, "phone", 1.0),
1243 MakeAnnotatedSpan({4, 5}, "phone", 1.0),
1244 }};
1245
1246 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001247 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001248 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001249 EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
1250}
1251
Tony Mak6c4cc672018-09-17 11:48:50 +01001252TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) {
1253 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001254
1255 std::vector<AnnotatedSpan> candidates{{
1256 MakeAnnotatedSpan({0, 3}, "phone", 1.0),
1257 MakeAnnotatedSpan({1, 5}, "phone", 0.5), // Looser!
1258 MakeAnnotatedSpan({3, 7}, "phone", 1.0),
1259 }};
1260
1261 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001262 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001263 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001264 EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
1265}
1266
Tony Mak6c4cc672018-09-17 11:48:50 +01001267TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) {
1268 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001269
1270 std::vector<AnnotatedSpan> candidates{{
1271 MakeAnnotatedSpan({0, 3}, "phone", 0.5), // Looser!
1272 MakeAnnotatedSpan({1, 5}, "phone", 1.0),
1273 MakeAnnotatedSpan({3, 7}, "phone", 0.6), // Looser!
1274 }};
1275
1276 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001277 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001278 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001279 EXPECT_THAT(chosen, ElementsAreArray({1}));
1280}
1281
Tony Mak6c4cc672018-09-17 11:48:50 +01001282TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) {
1283 TestingAnnotator classifier("", &unilib_, &calendarlib_);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001284
1285 std::vector<AnnotatedSpan> candidates{{
1286 MakeAnnotatedSpan({0, 3}, "phone", 0.5),
1287 MakeAnnotatedSpan({1, 5}, "other", 1.0), // Looser!
1288 MakeAnnotatedSpan({3, 7}, "phone", 0.6),
1289 MakeAnnotatedSpan({8, 12}, "phone", 0.6), // Looser!
1290 MakeAnnotatedSpan({11, 15}, "phone", 0.9),
1291 }};
1292
1293 std::vector<int> chosen;
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001294 classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
Lukas Zilkaba849e72018-03-08 14:48:21 +01001295 /*interpreter_manager=*/nullptr, &chosen);
Lukas Zilkab23e2122018-02-09 10:25:19 +01001296 EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
1297}
Lukas Zilka21d8c982018-01-24 11:11:20 +01001298
#ifdef TC3_UNILIB_ICU
// Places one entity of each listed type between two runs of 50000 spaces and
// checks that Annotate, SuggestSelection and ClassifyText all report it at
// the expected offsets and with the expected type.
TEST_P(AnnotatorTest, LongInput) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
  ASSERT_TRUE(classifier);

  // Number of spaces on each side of the entity text; also the index at which
  // the entity starts in the padded input.
  const int kPadding = 50000;

  // (expected collection, entity text) pairs.
  const std::vector<std::pair<std::string, std::string>> kTypedValues = {
      {"address", "350 Third Street, Cambridge"},
      {"phone", "123 456-7890"},
      {"url", "www.google.com"},
      {"email", "someone@gmail.com"},
      {"flight", "LX 38"},
      {"date", "September 1, 2018"}};

  for (const auto& typed_value : kTypedValues) {
    const std::string& expected_type = typed_value.first;
    const std::string& entity_text = typed_value.second;

    const std::string padding(kPadding, ' ');
    const std::string input_100k = padding + entity_text + padding;
    const int entity_length = entity_text.size();
    const int entity_end = kPadding + entity_length;

    EXPECT_THAT(
        classifier->Annotate(input_100k),
        ElementsAreArray({IsAnnotatedSpan(kPadding, entity_end,
                                          expected_type)}));
    EXPECT_EQ(
        classifier->SuggestSelection(input_100k, {kPadding, kPadding + 1}),
        std::make_pair(kPadding, entity_end));
    EXPECT_EQ(expected_type,
              FirstResult(classifier->ClassifyText(input_100k,
                                                   {kPadding, entity_end})));
  }
}
#endif  // TC3_UNILIB_ICU
Lukas Zilkadf710db2018-02-27 12:44:09 +01001329
#ifdef TC3_UNILIB_ICU
// Coarse test whose only purpose is to make sure that processing a very long
// input finishes in a reasonable amount of time; the results of the calls are
// deliberately not inspected.
TEST_P(AnnotatorTest, LongInputNoResultCheck) {
  std::unique_ptr<Annotator> classifier =
      Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
  ASSERT_TRUE(classifier);

  // Number of spaces placed on each side of the embedded text.
  const int kPadding = 50000;
  const std::vector<std::string> kValues = {
      "http://www.aaaaaaaaaaaaaaaaaaaa.com "};

  for (const std::string& embedded : kValues) {
    const std::string padding(kPadding, ' ');
    const std::string input_100k = padding + embedded + padding;
    const int embedded_length = embedded.size();

    classifier->Annotate(input_100k);
    classifier->SuggestSelection(input_100k, {kPadding, kPadding + 1});
    classifier->ClassifyText(input_100k,
                             {kPadding, kPadding + embedded_length});
  }
}
#endif  // TC3_UNILIB_ICU
Lukas Zilkaba849e72018-03-08 14:48:21 +01001350
#ifdef TC3_UNILIB_ICU
// Checks the effect of classification_options->max_num_tokens: with the limit
// disabled (-1) the address span is classified as "address"; with a limit of
// 3 tokens the same span is expected to be classified as "other".
TEST_P(AnnotatorTest, MaxTokenLength) {
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  // Fail the test cleanly instead of dereferencing a null pointer below.
  ASSERT_TRUE(unpacked_model != nullptr);
  ASSERT_TRUE(unpacked_model->classification_options != nullptr);

  std::unique_ptr<Annotator> classifier;

  // With unrestricted number of tokens should behave normally.
  unpacked_model->classification_options->max_num_tokens = -1;

  // The builders are kept alive for the rest of the test because the
  // annotator is created from an unowned view of their buffers.
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
  classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_, &calendarlib_);
  ASSERT_TRUE(classifier);

  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "I live at 350 Third Street, Cambridge.", {10, 37})),
            "address");

  // Restrict the maximum number of tokens to suppress the classification.
  unpacked_model->classification_options->max_num_tokens = 3;

  flatbuffers::FlatBufferBuilder builder2;
  FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
  classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder2.GetBufferPointer()),
      builder2.GetSize(), &unilib_, &calendarlib_);
  ASSERT_TRUE(classifier);

  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "I live at 350 Third Street, Cambridge.", {10, 37})),
            "other");
}
#endif  // TC3_UNILIB_ICU
Lukas Zilka434442d2018-04-25 11:38:51 +02001387
#ifdef TC3_UNILIB_ICU
// Checks the effect of classification_options->address_min_num_tokens: with
// a minimum of 0 the address span is classified as "address"; with a minimum
// of 5 the same span is expected to be classified as "other".
TEST_P(AnnotatorTest, MinAddressTokenLength) {
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  // Fail the test cleanly instead of dereferencing a null pointer below.
  ASSERT_TRUE(unpacked_model != nullptr);
  ASSERT_TRUE(unpacked_model->classification_options != nullptr);

  std::unique_ptr<Annotator> classifier;

  // With unrestricted number of address tokens should behave normally.
  unpacked_model->classification_options->address_min_num_tokens = 0;

  // The builders are kept alive for the rest of the test because the
  // annotator is created from an unowned view of their buffers.
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
  classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), &unilib_, &calendarlib_);
  ASSERT_TRUE(classifier);

  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "I live at 350 Third Street, Cambridge.", {10, 37})),
            "address");

  // Raise number of address tokens to suppress the address classification.
  unpacked_model->classification_options->address_min_num_tokens = 5;

  flatbuffers::FlatBufferBuilder builder2;
  FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
  classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder2.GetBufferPointer()),
      builder2.GetSize(), &unilib_, &calendarlib_);
  ASSERT_TRUE(classifier);

  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "I live at 350 Third Street, Cambridge.", {10, 37})),
            "other");
}
#endif  // TC3_UNILIB_ICU
Lukas Zilkae7962cc2018-03-28 18:09:48 +02001424
Tony Mak854015a2019-01-16 15:56:48 +00001425TEST_F(AnnotatorTest, VisitAnnotatorModel) {
1426 EXPECT_TRUE(VisitAnnotatorModel<bool>(GetModelPath() + "test_model.fb",
1427 [](const Model* model) {
1428 if (model == nullptr) {
1429 return false;
1430 }
1431 return true;
1432 }));
1433 EXPECT_FALSE(VisitAnnotatorModel<bool>(
1434 GetModelPath() + "non_existing_model.fb", [](const Model* model) {
1435 if (model == nullptr) {
1436 return false;
1437 }
1438 return true;
1439 }));
1440}
1441
Lukas Zilka21d8c982018-01-24 11:11:20 +01001442} // namespace
Tony Mak6c4cc672018-09-17 11:48:50 +01001443} // namespace libtextclassifier3