Import libtextclassifier
Test: atest frameworks/base/core/tests/coretests/src/android/view/textclassifier/
Change-Id: I4255dcb44bdef06448d436c4166483eba46cf264
diff --git a/annotator/annotator_test.cc b/annotator/annotator_test.cc
index b5198d4..d807ad8 100644
--- a/annotator/annotator_test.cc
+++ b/annotator/annotator_test.cc
@@ -55,6 +55,56 @@
return TC3_TEST_DATA_DIR;
}
+// Fills `entity_data_schema` of the model with a fake "EntityData" schema.
+void AddTestEntitySchemaData(ModelT* unpacked_model) {
+ // Cannot use the object-oriented API here as it is not available for the
+ // reflection schema.
+ flatbuffers::FlatBufferBuilder schema_builder;
+ std::vector<flatbuffers::Offset<reflection::Field>> fields = {
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("first_name"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/0,
+ /*offset=*/4),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("is_alive"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Bool),
+ /*id=*/1,
+ /*offset=*/6),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("last_name"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/2,
+ /*offset=*/8),
+ };
+ std::vector<flatbuffers::Offset<reflection::Enum>> enums;
+ std::vector<flatbuffers::Offset<reflection::Object>> objects = {
+ reflection::CreateObject(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("EntityData"),
+ /*fields=*/
+ schema_builder.CreateVectorOfSortedTables(&fields))};
+ schema_builder.Finish(reflection::CreateSchema(
+ schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
+ schema_builder.CreateVectorOfSortedTables(&enums),
+ /*(unused) file_ident=*/0,
+ /*(unused) file_ext=*/0,
+ /*root_table=*/ objects[0]));
+ // Copy the finished schema buffer into the model's entity_data_schema.
+ unpacked_model->entity_data_schema.assign(
+ schema_builder.GetBufferPointer(),
+ schema_builder.GetBufferPointer() + schema_builder.GetSize());
+}
+
class AnnotatorTest : public ::testing::TestWithParam<const char*> {
protected:
AnnotatorTest()
@@ -70,8 +120,6 @@
EXPECT_FALSE(classifier);
}
-INSTANTIATE_TEST_SUITE_P(ClickContext, AnnotatorTest,
- Values("test_model_cc.fb"));
INSTANTIATE_TEST_SUITE_P(BoundsSensitive, AnnotatorTest,
Values("test_model.fb"));
@@ -266,6 +314,73 @@
"www.google.com every today!|Call me at (800) 123-456 today.",
{51, 65})));
}
+
+TEST_P(AnnotatorTest, ClassifyTextRegularExpressionEntityData) {
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add fake entity schema metadata ("EntityData" with three test fields).
+ AddTestEntitySchemaData(unpacked_model.get());
+
+ // Add a classification-only regex pattern with two capturing groups.
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "person", "(Barack) (Obama)", /*enabled_for_classification=*/true,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
+
+ // Use the schema metadata to build the pattern's fixed entity data.
+ ReflectiveFlatbufferBuilder entity_data_builder(
+ flatbuffers::GetRoot<reflection::Schema>(
+ unpacked_model->entity_data_schema.data()));
+ std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ entity_data_builder.NewRoot();
+ entity_data->Set("is_alive", true);
+
+ RegexModel_::PatternT* pattern =
+ unpacked_model->regex_model->patterns.back().get();
+ pattern->serialized_entity_data = entity_data->Serialize();
+ pattern->capturing_group.emplace_back(
+ new RegexModel_::Pattern_::CapturingGroupT);
+ pattern->capturing_group.emplace_back(
+ new RegexModel_::Pattern_::CapturingGroupT);
+ pattern->capturing_group.emplace_back(
+ new RegexModel_::Pattern_::CapturingGroupT);
+ // Group 0 is the full match; the capturing groups start at index 1.
+ pattern->capturing_group[1]->entity_field_path.reset(
+ new FlatbufferFieldPathT);
+ pattern->capturing_group[1]->entity_field_path->field.emplace_back(
+ new FlatbufferFieldT);
+ pattern->capturing_group[1]->entity_field_path->field.back()->field_name =
+ "first_name";
+ pattern->capturing_group[2]->entity_field_path.reset(
+ new FlatbufferFieldPathT);
+ pattern->capturing_group[2]->entity_field_path->field.emplace_back(
+ new FlatbufferFieldT);
+ pattern->capturing_group[2]->entity_field_path->field.back()->field_name =
+ "last_name";
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_, &calendarlib_);
+ ASSERT_TRUE(classifier);
+
+ auto classifications = classifier->ClassifyText(
+ "this afternoon Barack Obama gave a speech at", {15, 27});
+ ASSERT_EQ(1, classifications.size());  // Fatal: [0] is indexed below.
+ EXPECT_EQ("person", classifications[0].collection);
+
+ // Check entity data via the vtable offsets declared in the test schema.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ classifications[0].serialized_entity_data.data()));
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "Barack");
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "Obama");
+ EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
+}
#endif // TC3_UNILIB_ICU
#ifdef TC3_UNILIB_ICU
@@ -626,7 +741,7 @@
SelectionOptions options;
EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options),
- std::make_pair(0, 7));
+ std::make_pair(0, 12));
}
TEST_P(AnnotatorTest, SuggestSelectionWithPunctuation) {
@@ -774,8 +889,8 @@
AnnotationOptions options;
EXPECT_THAT(classifier->Annotate("853 225 3556", options),
ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
- EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
-
+ EXPECT_THAT(classifier->Annotate("853 225\n3556", options),
+ ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
// Try passing invalid utf8.
EXPECT_TRUE(
classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options)
@@ -809,7 +924,8 @@
AnnotationOptions options;
EXPECT_THAT(classifier->Annotate("853 225 3556", options),
ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
- EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
+ EXPECT_THAT(classifier->Annotate("853 225\n3556", options),
+ ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
}
#ifdef TC3_UNILIB_ICU