Import libtextclassifier changes from google3.
This includes an upgrade to the latest version of the model.
We now use a test data path constant in the unit tests. This is
consistent with another test project, minikin_tests.
Test: Tests pass on-device.
Bug: 34865247
Change-Id: Ia061888a0bba371c6b429ad7c5457af611ba12f2
diff --git a/tests/feature-processor_test.cc b/tests/feature-processor_test.cc
index ecdad16..652db84 100644
--- a/tests/feature-processor_test.cc
+++ b/tests/feature-processor_test.cc
@@ -106,113 +106,95 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickFirst) {
- SelectionWithContext selection;
- selection.context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {0, 5};  // Covers "Fiřst" on the first line.
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
// Keeps the first line.
- selection.click_start = 0;
- selection.click_end = 5;
- selection.selection_start = 6;
- selection.selection_end = 10;
-
- SelectionWithContext line_selection;
- int shift;
- std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
-
- EXPECT_EQ(line_selection.context, "Fiřst Lině");
- EXPECT_EQ(line_selection.click_start, 0);
- EXPECT_EQ(line_selection.click_end, 5);
- EXPECT_EQ(line_selection.selection_start, 6);
- EXPECT_EQ(line_selection.selection_end, 10);
- EXPECT_EQ(shift, 0);
+ internal::StripTokensFromOtherLines(context, span, &tokens);  // Edits in place.
+ EXPECT_THAT(tokens,
+ ElementsAreArray({Token("Fiřst", 0, 5), Token("Lině", 6, 10)}));  // Line 1 only.
}
TEST(FeatureProcessorTest, KeepLineWithClickSecond) {
- SelectionWithContext selection;
- selection.context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {18, 22};  // Covers the second "Lině".
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
- // Keeps the second line.
- selection.click_start = 11;
- selection.click_end = 17;
- selection.selection_start = 18;
- selection.selection_end = 22;
-
- SelectionWithContext line_selection;
- int shift;
- std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
-
- EXPECT_EQ(line_selection.context, "Sěcond Lině");
- EXPECT_EQ(line_selection.click_start, 0);
- EXPECT_EQ(line_selection.click_end, 6);
- EXPECT_EQ(line_selection.selection_start, 7);
- EXPECT_EQ(line_selection.selection_end, 11);
- EXPECT_EQ(shift, 11);
+ // Keeps the second line.
+ internal::StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
}
TEST(FeatureProcessorTest, KeepLineWithClickThird) {
- SelectionWithContext selection;
- selection.context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {24, 33};  // From inside "Thiřd" to the end of the text.
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
- // Keeps the third line.
- selection.click_start = 29;
- selection.click_end = 33;
- selection.selection_start = 23;
- selection.selection_end = 28;
-
- SelectionWithContext line_selection;
- int shift;
- std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
-
- EXPECT_EQ(line_selection.context, "Thiřd Lině");
- EXPECT_EQ(line_selection.click_start, 6);
- EXPECT_EQ(line_selection.click_end, 10);
- EXPECT_EQ(line_selection.selection_start, 0);
- EXPECT_EQ(line_selection.selection_end, 5);
- EXPECT_EQ(shift, 23);
+ // Keeps the third line.
+ internal::StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
}
TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
- SelectionWithContext selection;
- selection.context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
+ const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {18, 22};  // Covers the second "Lině".
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
- // Keeps the second line.
- selection.click_start = 11;
- selection.click_end = 17;
- selection.selection_start = 18;
- selection.selection_end = 22;
-
- SelectionWithContext line_selection;
- int shift;
- std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
-
- EXPECT_EQ(line_selection.context, "Sěcond Lině");
- EXPECT_EQ(line_selection.click_start, 0);
- EXPECT_EQ(line_selection.click_end, 6);
- EXPECT_EQ(line_selection.selection_start, 7);
- EXPECT_EQ(line_selection.selection_end, 11);
- EXPECT_EQ(shift, 11);
+ // Keeps the second line: '|' is expected to act as a line break too.
+ internal::StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
}
TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) {
- SelectionWithContext selection;
- selection.context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {5, 23};  // Contains the '\n' at 10 and the one at 22.
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
- // Selects across lines, so KeepLine should not do any changes.
- selection.click_start = 6;
- selection.click_end = 17;
- selection.selection_start = 0;
- selection.selection_end = 22;
-
- SelectionWithContext line_selection;
- int shift;
- std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
-
- EXPECT_EQ(line_selection.context, "Fiřst Lině\nSěcond Lině\nThiřd Lině");
- EXPECT_EQ(line_selection.click_start, 6);
- EXPECT_EQ(line_selection.click_end, 17);
- EXPECT_EQ(line_selection.selection_start, 0);
- EXPECT_EQ(line_selection.selection_end, 22);
- EXPECT_EQ(shift, 0);
+ // Selects across lines, so no tokens should be stripped.
+ internal::StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Fiřst", 0, 5), Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17), Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
}
TEST(FeatureProcessorTest, GetFeaturesWithContextDropout) {
@@ -231,14 +213,6 @@
config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
FeatureProcessor feature_processor(options);
- SelectionWithContext selection_with_context;
- selection_with_context.context = "1 2 3 c o n t e x t X c o n t e x t 1 2 3";
- // Selection and click indices of the X in the middle:
- selection_with_context.selection_start = 20;
- selection_with_context.selection_end = 21;
- selection_with_context.click_start = 20;
- selection_with_context.click_end = 21;
-
// Test that two subsequent runs with random context dropout produce
// different features.
feature_processor.SetRandom(new std::mt19937);
@@ -251,13 +225,13 @@
CodepointSpan selection_codepoint_label;
int classification_label;
EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
- selection_with_context, &features, &extra_features,
- &selection_label_spans, &selection_label, &selection_codepoint_label,
- &classification_label));
+ "1 2 3 c o n t e x t X c o n t e x t 1 2 3", {20, 21}, {20, 21}, "",  // {20, 21} is the "X".
+ &features, &extra_features, &selection_label_spans, &selection_label,
+ &selection_codepoint_label, &classification_label));
EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
- selection_with_context, &features2, &extra_features,
- &selection_label_spans, &selection_label, &selection_codepoint_label,
- &classification_label));
+ "1 2 3 c o n t e x t X c o n t e x t 1 2 3", {20, 21}, {20, 21}, "",  // Same input again.
+ &features2, &extra_features, &selection_label_spans, &selection_label,
+ &selection_codepoint_label, &classification_label));
EXPECT_NE(features, features2);
}
@@ -276,14 +250,6 @@
config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
FeatureProcessor feature_processor(options);
- SelectionWithContext selection_with_context;
- selection_with_context.context = "1 2 3 c o n t e x t X c o n t e x t 1 2 3";
- // Selection and click indices of the X in the middle:
- selection_with_context.selection_start = 20;
- selection_with_context.selection_end = 21;
- selection_with_context.click_start = 20;
- selection_with_context.click_end = 21;
-
std::vector<std::vector<std::pair<int, float>>> features;
std::vector<float> extra_features;
std::vector<CodepointSpan> selection_label_spans;
@@ -291,19 +257,14 @@
CodepointSpan selection_codepoint_label;
int classification_label;
EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
- selection_with_context, &features, &extra_features,
- &selection_label_spans, &selection_label, &selection_codepoint_label,
- &classification_label));
+ "1 2 3 c o n t e x t X c o n t e x t 1 2 3", {20, 21}, {20, 21}, "",  // {20, 21} is the "X".
+ &features, &extra_features, &selection_label_spans, &selection_label,
+ &selection_codepoint_label, &classification_label));
EXPECT_EQ(19, features.size());
// Should pad the string.
- selection_with_context.context = "X";
- selection_with_context.selection_start = 0;
- selection_with_context.selection_end = 1;
- selection_with_context.click_start = 0;
- selection_with_context.click_end = 1;
EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
- selection_with_context, &features, &extra_features,
+ "X", {0, 1}, {0, 1}, "", &features, &extra_features,  // One-codepoint context.
&selection_label_spans, &selection_label, &selection_codepoint_label,
&classification_label));
EXPECT_EQ(19, features.size());
@@ -432,5 +393,67 @@
}));
}
+TEST(FeatureProcessorTest, CenterTokenFromClick) {
+ // Three tokens separated by one-codepoint gaps, shared by every case.
+ const std::vector<Token> tokens = {Token("Hělló", 0, 5, false),
+ Token("world", 6, 11, false),
+ Token("heře!", 12, 17, false)};
+
+ // Exactly aligned indices.
+ const int aligned_index =
+ internal::CenterTokenFromClick({6, 11}, tokens);
+ EXPECT_EQ(aligned_index, 1);
+
+ // Click is contained in a token.
+ const int contained_index =
+ internal::CenterTokenFromClick({13, 17}, tokens);
+ EXPECT_EQ(contained_index, 2);
+
+ // Click spans two tokens, so there is no single center token.
+ const int spanning_index =
+ internal::CenterTokenFromClick({6, 17}, tokens);
+ EXPECT_EQ(spanning_index, kInvalidIndex);
+}
+
+TEST(FeatureProcessorTest, CenterTokenFromMiddleOfSelection) {
+ // Every case uses five 6-codepoint tokens separated by one-codepoint gaps.
+ int token_index;
+
+ // Selection covering 3 whole tokens. Exactly aligned indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {7, 27}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
+ Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
+ Token("Token5", 28, 34, false)});
+ EXPECT_EQ(token_index, 2);
+
+ // Selection covering 1 whole token. Exactly aligned indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {21, 27}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
+ Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
+ Token("Token5", 28, 34, false)});
+ EXPECT_EQ(token_index, 3);
+
+ // Selection lies strictly inside Token5 and covers no whole token.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {29, 33}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
+ Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
+ Token("Token5", 28, 34, false)});
+ EXPECT_EQ(token_index, kInvalidIndex);
+
+ // Selection covering 2 whole tokens (Token2, Token3). Sub-token indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {3, 25}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
+ Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
+ Token("Token5", 28, 34, false)});
+ EXPECT_EQ(token_index, 1);
+
+ // Selection covering 1 whole token (Token5). Sub-token indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {22, 34}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
+ Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
+ Token("Token5", 28, 34, false)});
+ EXPECT_EQ(token_index, 4);
+}
+
} // namespace
} // namespace libtextclassifier