Fixes crashes by making the native library thread-safe, makes Annotate calls much
faster by re-using tokens, and fixes default values of enums in the FlatBuffer schema.
Test: bit FrameworksCoreTests:android.view.textclassifier.TextClassificationManagerTest
Test: bit CtsViewTestCases:android.view.textclassifier.cts.TextClassificationManagerTest
Bug: 74193987
Bug: 68239358
Change-Id: Ic5ca42b628280bece59d31203748072084ac452c
(cherry picked from commit 2191547d7109587d73077f9d4818c691f7d7dafb)
Merged-In: Ic5ca42b628280bece59d31203748072084ac452c
diff --git a/text-classifier_test.cc b/text-classifier_test.cc
index 1145ac5..74534e2 100644
--- a/text-classifier_test.cc
+++ b/text-classifier_test.cc
@@ -30,6 +30,7 @@
namespace {
using testing::ElementsAreArray;
+using testing::IsEmpty;
using testing::Pair;
using testing::Values;
@@ -105,6 +106,50 @@
"a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
}
+TEST_P(TextClassifierTest, ClassifyTextDisabledFail) {
+  CREATE_UNILIB_FOR_TESTING;
+  const std::string model_bytes = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+
+  // Drop the classification model and enable only the selection mode.
+  model->classification_model.clear();
+  model->triggering_options.reset(new ModelTriggeringOptionsT);
+  model->triggering_options->enabled_modes = ModeFlag_SELECTION;
+
+  // Re-pack the modified model and try to load it from the raw buffer.
+  flatbuffers::FlatBufferBuilder fbb;
+  fbb.Finish(Model::Pack(fbb, model.get()));
+  const char* packed = reinterpret_cast<const char*>(fbb.GetBufferPointer());
+  std::unique_ptr<TextClassifier> classifier =
+      TextClassifier::FromUnownedBuffer(packed, fbb.GetSize(), &unilib);
+
+  // Loading must fail: selection scoring still needs the classification model.
+  ASSERT_FALSE(classifier);
+}
+
+TEST_P(TextClassifierTest, ClassifyTextDisabled) {
+  CREATE_UNILIB_FOR_TESTING;
+  const std::string model_bytes = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
+
+  // Enable annotation and selection only, leaving classification off.
+  model->triggering_options.reset(new ModelTriggeringOptionsT);
+  model->triggering_options->enabled_modes = ModeFlag_ANNOTATION_AND_SELECTION;
+
+  // Re-pack the modified model and load it from the raw buffer.
+  flatbuffers::FlatBufferBuilder fbb;
+  fbb.Finish(Model::Pack(fbb, model.get()));
+  const char* packed = reinterpret_cast<const char*>(fbb.GetBufferPointer());
+  std::unique_ptr<TextClassifier> classifier =
+      TextClassifier::FromUnownedBuffer(packed, fbb.GetSize(), &unilib);
+  ASSERT_TRUE(classifier);
+
+  // With classification disabled, ClassifyText must yield no results.
+  const auto results =
+      classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24});
+  EXPECT_THAT(results, IsEmpty());
+}
+
std::unique_ptr<RegexModel_::PatternT> MakePattern(
const std::string& collection_name, const std::string& pattern,
const bool enabled_for_classification, const bool enabled_for_selection,
@@ -112,9 +157,12 @@
std::unique_ptr<RegexModel_::PatternT> result(new RegexModel_::PatternT);
result->collection_name = collection_name;
result->pattern = pattern;
- result->enabled_for_selection = enabled_for_selection;
- result->enabled_for_classification = enabled_for_classification;
- result->enabled_for_annotation = enabled_for_annotation;
+ // We cannot directly operate with |= on the flag, so use an int here.
+ int enabled_modes = ModeFlag_NONE;
+ if (enabled_for_annotation) enabled_modes |= ModeFlag_ANNOTATION;
+ if (enabled_for_classification) enabled_modes |= ModeFlag_CLASSIFICATION;
+ if (enabled_for_selection) enabled_modes |= ModeFlag_SELECTION;
+ result->enabled_modes = static_cast<ModeFlag>(enabled_modes);
result->target_classification_score = score;
result->priority_score = score;
return result;
@@ -171,7 +219,6 @@
#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
-
TEST_P(TextClassifierTest, SuggestSelectionRegularExpression) {
CREATE_UNILIB_FOR_TESTING;
const std::string test_model = ReadFile(GetModelPath() + GetParam());
@@ -308,7 +355,6 @@
IsAnnotatedSpan(79, 91, "phone"),
}));
}
-
#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
TEST_P(TextClassifierTest, PhoneFiltering) {
@@ -371,6 +417,58 @@
std::make_pair(11, 12));
}
+TEST_P(TextClassifierTest, SuggestSelectionDisabledFail) {
+  CREATE_UNILIB_FOR_TESTING;
+  const std::string serialized = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> model = UnPackModel(serialized.c_str());
+
+  // Remove the selection model while leaving only annotation enabled.
+  model->selection_model.clear();
+  model->triggering_options.reset(new ModelTriggeringOptionsT);
+  model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;
+
+  // Re-pack the modified model and try to load it from the raw buffer.
+  flatbuffers::FlatBufferBuilder fbb;
+  fbb.Finish(Model::Pack(fbb, model.get()));
+  const char* packed = reinterpret_cast<const char*>(fbb.GetBufferPointer());
+  std::unique_ptr<TextClassifier> classifier =
+      TextClassifier::FromUnownedBuffer(packed, fbb.GetSize(), &unilib);
+
+  // Loading must fail: annotation requires the selection model.
+  ASSERT_FALSE(classifier);
+}
+
+TEST_P(TextClassifierTest, SuggestSelectionDisabled) {
+  CREATE_UNILIB_FOR_TESTING;
+  const std::string serialized = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> model = UnPackModel(serialized.c_str());
+
+  // Remove the selection model and restrict both the triggering options and
+  // the model itself to classification.
+  model->selection_model.clear();
+  model->triggering_options.reset(new ModelTriggeringOptionsT);
+  model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
+  model->enabled_modes = ModeFlag_CLASSIFICATION;
+
+  flatbuffers::FlatBufferBuilder fbb;
+  fbb.Finish(Model::Pack(fbb, model.get()));
+  const char* packed = reinterpret_cast<const char*>(fbb.GetBufferPointer());
+  std::unique_ptr<TextClassifier> classifier =
+      TextClassifier::FromUnownedBuffer(packed, fbb.GetSize(), &unilib);
+  ASSERT_TRUE(classifier);
+
+  // Selection is disabled, so the input span comes back unchanged.
+  const std::string selection_text = "call me at 857 225 3556 today";
+  EXPECT_EQ(classifier->SuggestSelection(selection_text, {11, 14}),
+            std::make_pair(11, 14));
+
+  // Classification still works, while annotation produces nothing.
+  const std::string phone_text = "call me at (800) 123-456 today";
+  EXPECT_EQ("phone",
+            FirstResult(classifier->ClassifyText(phone_text, {11, 24})));
+  EXPECT_THAT(classifier->Annotate(phone_text), IsEmpty());
+}
+
TEST_P(TextClassifierTest, SuggestSelectionsAreSymmetric) {
CREATE_UNILIB_FOR_TESTING;
std::unique_ptr<TextClassifier> classifier =
@@ -510,13 +608,14 @@
EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
}
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
TEST_P(TextClassifierTest, AnnotateFilteringDiscardAll) {
CREATE_UNILIB_FOR_TESTING;
const std::string test_model = ReadFile(GetModelPath() + GetParam());
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
- // Add test thresholds.
unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
+ // Add test threshold.
unpacked_model->triggering_options->min_annotate_confidence =
2.f; // Discards all results.
flatbuffers::FlatBufferBuilder builder;
@@ -531,8 +630,10 @@
const std::string test_string =
"& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
"number is 853 225 3556";
- EXPECT_TRUE(classifier->Annotate(test_string).empty());
+
+ EXPECT_EQ(classifier->Annotate(test_string).size(), 1);
}
+#endif
TEST_P(TextClassifierTest, AnnotateFilteringKeepAll) {
CREATE_UNILIB_FOR_TESTING;
@@ -543,6 +644,7 @@
unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
unpacked_model->triggering_options->min_annotate_confidence =
0.f; // Keeps all results.
+ unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
flatbuffers::FlatBufferBuilder builder;
builder.Finish(Model::Pack(builder, unpacked_model.get()));
@@ -563,6 +665,27 @@
#endif
}
+TEST_P(TextClassifierTest, AnnotateDisabled) {
+  CREATE_UNILIB_FOR_TESTING;
+  const std::string serialized = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> model = UnPackModel(serialized.c_str());
+
+  // Enable classification and selection only, leaving annotation off.
+  model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;
+  flatbuffers::FlatBufferBuilder fbb;
+  fbb.Finish(Model::Pack(fbb, model.get()));
+  const char* packed = reinterpret_cast<const char*>(fbb.GetBufferPointer());
+  std::unique_ptr<TextClassifier> classifier =
+      TextClassifier::FromUnownedBuffer(packed, fbb.GetSize(), &unilib);
+  ASSERT_TRUE(classifier);
+
+  // Annotation is disabled, so nothing in the text may be annotated.
+  const std::string text =
+      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+      "number is 853 225 3556";
+  EXPECT_THAT(classifier->Annotate(text), IsEmpty());
+}
+
#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
TEST_P(TextClassifierTest, ClassifyTextDate) {
std::unique_ptr<TextClassifier> classifier =
@@ -613,6 +736,32 @@
DatetimeGranularity::GRANULARITY_DAY);
result.clear();
}
+
+TEST_P(TextClassifierTest, SuggestTextDateDisabled) {
+  CREATE_UNILIB_FOR_TESTING;
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  // Disable every datetime pattern for selection, keeping the other modes.
+  for (const auto& pattern : unpacked_model->datetime_model->patterns) {
+    pattern->enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION;
+  }
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<TextClassifier> classifier =
+      TextClassifier::FromUnownedBuffer(
+          reinterpret_cast<const char*>(builder.GetBufferPointer()),
+          builder.GetSize(), &unilib);
+  ASSERT_TRUE(classifier);
+  // Classification and annotation still find the date; selection is a no-op.
+  EXPECT_EQ("date",
+            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
+  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
+            std::make_pair(0, 7));
+  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
+              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
+}
#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
class TestingTextClassifier : public TextClassifier {
@@ -640,7 +789,8 @@
{MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ classifier.ResolveConflicts(candidates, /*context=*/"",
+ /*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0}));
}
@@ -657,7 +807,8 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ classifier.ResolveConflicts(candidates, /*context=*/"",
+ /*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
}
@@ -672,7 +823,8 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ classifier.ResolveConflicts(candidates, /*context=*/"",
+ /*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
}
@@ -687,7 +839,8 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ classifier.ResolveConflicts(candidates, /*context=*/"",
+ /*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({1}));
}
@@ -704,7 +857,8 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"", &chosen);
+ classifier.ResolveConflicts(candidates, /*context=*/"",
+ /*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
}
@@ -740,5 +894,27 @@
}
#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+// Coarse test that ignores the results: it only makes sure that very long
+// inputs are processed in a reasonable amount of time.
+TEST_P(TextClassifierTest, LongInputNoResultCheck) {
+  CREATE_UNILIB_FOR_TESTING;
+  std::unique_ptr<TextClassifier> classifier =
+      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
+  ASSERT_TRUE(classifier);
+
+  const std::string padding(50000, ' ');
+  const std::vector<std::string> values = {
+      "http://www.aaaaaaaaaaaaaaaaaaaa.com "};
+  for (const std::string& value : values) {
+    const std::string input_100k = padding + value + padding;
+    const int value_end = 50000 + static_cast<int>(value.size());
+    classifier->Annotate(input_100k);
+    classifier->SuggestSelection(input_100k, {50000, 50001});
+    classifier->ClassifyText(input_100k, {50000, value_end});
+  }
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
} // namespace
} // namespace libtextclassifier2