Fixes utf8 handling, datetime model and flight number model, makes
models smaller by compressing the regex rules, and adds i18n models.
(sync from google3)
Test: bit FrameworksCoreTests:android.view.textclassifier.TextClassificationManagerTest
Test: bit CtsViewTestCases:android.view.textclassifier.cts.TextClassificationManagerTest
Bug: 64929062
Bug: 77223425
Change-Id: I04472e6b247e824bf2b745077c50fcde4269aefc
diff --git a/text-classifier_test.cc b/text-classifier_test.cc
index 74534e2..440cedf 100644
--- a/text-classifier_test.cc
+++ b/text-classifier_test.cc
@@ -104,6 +104,9 @@
FirstResult(classifier->ClassifyText("", {0, 0})));
EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
"a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
+ // Test invalid utf8 input.
+ EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
+ "\xf0\x9f\x98\x8b\x8b", {0, 0})));
}
TEST_P(TextClassifierTest, ClassifyTextDisabledFail) {
@@ -150,6 +153,41 @@
IsEmpty());
}
+TEST_P(TextClassifierTest, ClassifyTextFilteredCollections) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ // Checks that filtered_collections_classification hides a collection.
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
+ &unilib);
+ ASSERT_TRUE(classifier);
+ // Baseline: the unmodified model classifies the span as "phone".
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // Disable phone classification by listing "phone" as a filtered collection.
+ unpacked_model->output_options->filtered_collections_classification.push_back(
+ "phone");
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+ // With "phone" filtered, the same span is expected to come back as "other".
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+
+ // Collections that were not filtered (address) must still be classified.
+ EXPECT_EQ("address", FirstResult(classifier->ClassifyText(
+ "350 Third Street, Cambridge", {0, 27})));
+}
+
std::unique_ptr<RegexModel_::PatternT> MakePattern(
const std::string& collection_name, const std::string& pattern,
const bool enabled_for_classification, const bool enabled_for_selection,
@@ -225,7 +263,6 @@
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
// Add test regex models.
- unpacked_model->regex_model.reset(new RegexModelT);
unpacked_model->regex_model->patterns.push_back(MakePattern(
"person", " (Barack Obama) ", /*enabled_for_classification=*/false,
/*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
@@ -260,7 +297,6 @@
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
// Add test regex models.
- unpacked_model->regex_model.reset(new RegexModelT);
unpacked_model->regex_model->patterns.push_back(MakePattern(
"person", " (Barack Obama) ", /*enabled_for_classification=*/false,
/*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
@@ -294,7 +330,6 @@
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
// Add test regex models.
- unpacked_model->regex_model.reset(new RegexModelT);
unpacked_model->regex_model->patterns.push_back(MakePattern(
"person", " (Barack Obama) ", /*enabled_for_classification=*/false,
/*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
@@ -328,7 +363,6 @@
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
// Add test regex models.
- unpacked_model->regex_model.reset(new RegexModelT);
unpacked_model->regex_model->patterns.push_back(MakePattern(
"person", " (Barack Obama) ", /*enabled_for_classification=*/false,
/*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
@@ -469,6 +503,45 @@
IsEmpty());
}
+TEST_P(TextClassifierTest, SuggestSelectionFilteredCollections) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ // Checks that filtered_collections_selection stops a collection's expansion.
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
+ &unilib);
+ ASSERT_TRUE(classifier);
+ // Baseline: the unmodified model expands {11, 14} to {11, 23}.
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+ std::make_pair(11, 23));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // Disable phone selection by listing "phone" as a filtered collection.
+ unpacked_model->output_options->filtered_collections_selection.push_back(
+ "phone");
+ // Always classify the suggested selection; this test needs it for filtering.
+ unpacked_model->selection_options->always_classify_suggested_selection = true;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+ // With "phone" filtered, the input selection {11, 14} is returned unchanged.
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+ std::make_pair(11, 14));
+
+ // Address selection should still work.
+ EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
+ std::make_pair(0, 27));
+}
+
TEST_P(TextClassifierTest, SuggestSelectionsAreSymmetric) {
CREATE_UNILIB_FOR_TESTING;
std::unique_ptr<TextClassifier> classifier =
@@ -548,6 +621,91 @@
std::make_pair(-10, -1));
EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}),
std::make_pair(100, 17));
+
+ // Try passing invalid utf8.
+ EXPECT_EQ(classifier->SuggestSelection("\xf0\x9f\x98\x8b\x8b", {-1, -1}),
+ std::make_pair(-1, -1));
+}
+
+TEST_P(TextClassifierTest, SuggestSelectionSelectSpace) {
+ CREATE_UNILIB_FOR_TESTING;
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
+ ASSERT_TRUE(classifier);
+ // Expected spans for selections on or next to whitespace, for several inputs.
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}),
+ std::make_pair(11, 23));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}),
+ std::make_pair(10, 11));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}),
+ std::make_pair(23, 24));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}),
+ std::make_pair(23, 24));
+ EXPECT_EQ(classifier->SuggestSelection("call me at 857 225 3556, today",
+ {14, 17}),
+ std::make_pair(11, 25));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}),
+ std::make_pair(11, 23));
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "let's meet at 350 Third Street Cambridge and go there", {30, 31}),
+ std::make_pair(14, 40));
+ EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}),
+ std::make_pair(4, 5));
+ EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}),
+ std::make_pair(7, 8));
+
+ // With punctuation next to the selected whitespace.
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "let's meet at 350 Third Street, Cambridge and go there", {31, 32}),
+ std::make_pair(14, 41));
+
+ // When the text is all whitespace, the original indices should be returned.
+ EXPECT_EQ(classifier->SuggestSelection(" ", {0, 1}),
+ std::make_pair(0, 1));
+ EXPECT_EQ(classifier->SuggestSelection(" ", {0, 3}),
+ std::make_pair(0, 3));
+ EXPECT_EQ(classifier->SuggestSelection(" ", {2, 3}),
+ std::make_pair(2, 3));
+ EXPECT_EQ(classifier->SuggestSelection(" ", {5, 6}),
+ std::make_pair(5, 6));
+}
+
+TEST(TextClassifierTest, SnapLeftIfWhitespaceSelection) {
+ CREATE_UNILIB_FOR_TESTING;
+ UnicodeText text;
+ // Checks internal::SnapLeftIfWhitespaceSelection against expected spans.
+ text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib),
+ std::make_pair(3, 4));
+ text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib),
+ std::make_pair(3, 4));
+
+ // Nothing to snap to on the left: the selection is expected back unchanged.
+ text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib),
+ std::make_pair(4, 5));
+ text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib),
+ std::make_pair(0, 1));
+
+ // Whitespace-only text: the selection is expected back unchanged.
+ text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, unilib),
+ std::make_pair(2, 3));
+ text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib),
+ std::make_pair(4, 5));
+ text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib),
+ std::make_pair(0, 1));
}
TEST_P(TextClassifierTest, Annotate) {
@@ -572,6 +730,11 @@
EXPECT_THAT(classifier->Annotate("853 225 3556", options),
ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
+
+ // Try passing invalid utf8.
+ EXPECT_TRUE(
+ classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options)
+ .empty());
}
TEST_P(TextClassifierTest, AnnotateSmallBatches) {
@@ -686,6 +849,104 @@
EXPECT_THAT(classifier->Annotate(test_string), IsEmpty());
}
+TEST_P(TextClassifierTest, AnnotateFilteredCollections) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ // Checks that filtered_collections_annotation removes a collection's spans.
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
+ &unilib);
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+ // Baseline ("date" is expected only if LIBTEXTCLASSIFIER_UNILIB_ICU is set).
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+ IsAnnotatedSpan(19, 24, "date"),
+#endif
+ IsAnnotatedSpan(28, 55, "address"),
+ IsAnnotatedSpan(79, 91, "phone"),
+ }));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // Disable phone annotation by listing "phone" as a filtered collection.
+ unpacked_model->output_options->filtered_collections_annotation.push_back(
+ "phone");
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+ // The "phone" span is gone; the other expected annotations are unchanged.
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+ IsAnnotatedSpan(19, 24, "date"),
+#endif
+ IsAnnotatedSpan(28, 55, "address"),
+ }));
+}
+
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+TEST_P(TextClassifierTest, AnnotateFilteredCollectionsSuppress) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ // Checks that a filtered collection can still suppress another annotation.
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
+ &unilib);
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+ // Baseline: the unmodified model reports date, address and phone spans.
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU // NOTE(review): redundant, already inside this guard.
+ IsAnnotatedSpan(19, 24, "date"),
+#endif
+ IsAnnotatedSpan(28, 55, "address"),
+ IsAnnotatedSpan(79, 91, "phone"),
+ }));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // Add a custom "suppress" annotator (score 2.0) that is meant to win against
+ // the phone annotation, and filter "suppress" itself from the output.
+ unpacked_model->output_options->filtered_collections_annotation.push_back(
+ "suppress");
+
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "suppress", "(\\d{3} ?\\d{4})",
+ /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+ // Neither a "suppress" span nor a "phone" span is expected in the output.
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+ IsAnnotatedSpan(19, 24, "date"),
+ IsAnnotatedSpan(28, 55, "address"),
+ }));
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
TEST_P(TextClassifierTest, ClassifyTextDate) {
std::unique_ptr<TextClassifier> classifier =
@@ -734,9 +995,43 @@
EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
EXPECT_EQ(result[0].datetime_parse_result.granularity,
DatetimeGranularity::GRANULARITY_DAY);
- result.clear();
}
+#endif // LIBTEXTCLASSIFIER_CALENDAR_ICU
+#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
+TEST_P(TextClassifierTest, ClassifyTextDatePriorities) {
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + GetParam());
+ ASSERT_TRUE(classifier);
+ // The same text "03/05" must resolve differently depending on the locales.
+ std::vector<ClassificationResult> result;
+ ClassificationOptions options;
+
+ // en-US only: expected timestamp is 5 March 1970, 00:00 at UTC+1.
+ options.reference_timezone = "Europe/Zurich";
+ options.locales = "en-US";
+ result = classifier->ClassifyText("03/05", {0, 5}, options);
+
+ ASSERT_EQ(result.size(), 1);
+ EXPECT_THAT(result[0].collection, "date");
+ EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 5439600000);
+ EXPECT_EQ(result[0].datetime_parse_result.granularity,
+ DatetimeGranularity::GRANULARITY_DAY);
+
+ // en-GB listed first: expected timestamp is 3 May 1970, 00:00 at UTC+1.
+ options.reference_timezone = "Europe/Zurich";
+ options.locales = "en-GB,en-US";
+ result = classifier->ClassifyText("03/05", {0, 5}, options);
+
+ ASSERT_EQ(result.size(), 1);
+ EXPECT_THAT(result[0].collection, "date");
+ EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 10537200000);
+ EXPECT_EQ(result[0].datetime_parse_result.granularity,
+ DatetimeGranularity::GRANULARITY_DAY);
+}
+#endif // LIBTEXTCLASSIFIER_CALENDAR_ICU
+
+#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
TEST_P(TextClassifierTest, SuggestTextDateDisabled) {
CREATE_UNILIB_FOR_TESTING;
const std::string test_model = ReadFile(GetModelPath() + GetParam());
@@ -789,7 +1084,7 @@
{MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"",
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0}));
}
@@ -807,7 +1102,7 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"",
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
}
@@ -823,7 +1118,7 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"",
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
}
@@ -839,7 +1134,7 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"",
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({1}));
}
@@ -857,7 +1152,7 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"",
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
}
@@ -916,5 +1211,43 @@
}
#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+TEST_P(TextClassifierTest, MinAddressTokenLength) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ // Checks the effect of classification_options->address_min_num_tokens.
+ std::unique_ptr<TextClassifier> classifier;
+
+ // With no minimum number of address tokens, classification behaves normally.
+ unpacked_model->classification_options->address_min_num_tokens = 0;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(FirstResult(classifier->ClassifyText(
+ "I live at 350 Third Street, Cambridge.", {10, 37})),
+ "address");
+
+ // Raise the minimum number of address tokens to suppress the address result.
+ unpacked_model->classification_options->address_min_num_tokens = 5;
+
+ flatbuffers::FlatBufferBuilder builder2;
+ builder2.Finish(Model::Pack(builder2, unpacked_model.get()));
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder2.GetBufferPointer()),
+ builder2.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(FirstResult(classifier->ClassifyText(
+ "I live at 350 Third Street, Cambridge.", {10, 37})),
+ "other");
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
} // namespace
} // namespace libtextclassifier2