Sync with google3 version.

Test: Verified on device.

Bug: 36838725
Change-Id: I61747fc619bd7ee298e828d2fd9705f3531a233f
diff --git a/tests/token-feature-extractor_test.cc b/tests/token-feature-extractor_test.cc
index 7f6ba18..58097bc 100644
--- a/tests/token-feature-extractor_test.cc
+++ b/tests/token-feature-extractor_test.cc
@@ -22,23 +22,48 @@
 namespace libtextclassifier {
 namespace {
 
-TEST(TokenFeatureExtractorTest, Extract) {
+class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
+ public:
+  using TokenFeatureExtractor::TokenFeatureExtractor;  // Inherit ctors.
+  using TokenFeatureExtractor::HashToken;  // Exposed to compute expected ids.
+};
+
+TEST(TokenFeatureExtractorTest, ExtractAscii) {
   TokenFeatureExtractorOptions options;
-  options.num_buckets = 10;
-  options.chargram_orders = std::vector<int>{1, 2};
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2, 3};
   options.extract_case_feature = true;
+  options.unicode_aware_features = false;
   options.extract_selection_mask_feature = true;
-  TokenFeatureExtractor extractor(options);
+  TestingTokenFeatureExtractor extractor(options);
 
   std::vector<int> sparse_features;
   std::vector<float> dense_features;
 
-  extractor.Extract(Token{"Hělló", 0, 5, true}, &sparse_features,
+  extractor.Extract(Token{"Hello", 0, 5, true}, &sparse_features,
                     &dense_features);
 
-  EXPECT_THAT(
-      sparse_features,
-      testing::ElementsAreArray({8, 6, 0, 1, 1, 4, 7, 8, 8, 1, 4, 2, 7, 0, 4}));
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({
+                  // clang-format off
+                  extractor.HashToken("H"),
+                  extractor.HashToken("e"),
+                  extractor.HashToken("l"),
+                  extractor.HashToken("l"),
+                  extractor.HashToken("o"),
+                  extractor.HashToken("^H"),  // ^/$ mark token start/end.
+                  extractor.HashToken("He"),
+                  extractor.HashToken("el"),
+                  extractor.HashToken("ll"),
+                  extractor.HashToken("lo"),
+                  extractor.HashToken("o$"),
+                  extractor.HashToken("^He"),
+                  extractor.HashToken("Hel"),
+                  extractor.HashToken("ell"),
+                  extractor.HashToken("llo"),
+                  extractor.HashToken("lo$")
+                  // clang-format on
+              }));
   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
 
   sparse_features.clear();
@@ -46,26 +71,274 @@
   extractor.Extract(Token{"world!", 23, 29, false}, &sparse_features,
                     &dense_features);
 
-  EXPECT_THAT(sparse_features, testing::ElementsAreArray(
-                                   {9, 3, 3, 1, 5, 6, 7, 3, 5, 5, 2, 3, 7}));
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({
+                  // clang-format off
+                  extractor.HashToken("w"),
+                  extractor.HashToken("o"),
+                  extractor.HashToken("r"),
+                  extractor.HashToken("l"),
+                  extractor.HashToken("d"),
+                  extractor.HashToken("!"),
+                  extractor.HashToken("^w"),
+                  extractor.HashToken("wo"),
+                  extractor.HashToken("or"),
+                  extractor.HashToken("rl"),
+                  extractor.HashToken("ld"),
+                  extractor.HashToken("d!"),
+                  extractor.HashToken("!$"),
+                  extractor.HashToken("^wo"),
+                  extractor.HashToken("wor"),
+                  extractor.HashToken("orl"),
+                  extractor.HashToken("rld"),
+                  extractor.HashToken("ld!"),
+                  extractor.HashToken("d!$"),
+                  // clang-format on
+              }));
   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
 }
 
+TEST(TokenFeatureExtractorTest, ExtractUnicode) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2, 3};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = true;  // Chargrams over codepoints.
+  options.extract_selection_mask_feature = true;
+  TestingTokenFeatureExtractor extractor(options);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+
+  extractor.Extract(Token{"Hělló", 0, 5, true}, &sparse_features,
+                    &dense_features);
+
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({
+                  // clang-format off
+                  extractor.HashToken("H"),
+                  extractor.HashToken("ě"),
+                  extractor.HashToken("l"),
+                  extractor.HashToken("l"),
+                  extractor.HashToken("ó"),
+                  extractor.HashToken("^H"),
+                  extractor.HashToken("Hě"),
+                  extractor.HashToken("ěl"),
+                  extractor.HashToken("ll"),
+                  extractor.HashToken("ló"),
+                  extractor.HashToken("ó$"),
+                  extractor.HashToken("^Hě"),
+                  extractor.HashToken("Hěl"),
+                  extractor.HashToken("ěll"),
+                  extractor.HashToken("lló"),
+                  extractor.HashToken("ló$")
+                  // clang-format on
+              }));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+  // Unlike ExtractAscii, an unselected token's mask feature is -1.0 here.
+  sparse_features.clear();
+  dense_features.clear();
+  extractor.Extract(Token{"world!", 23, 29, false}, &sparse_features,
+                    &dense_features);
+
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({
+                  // clang-format off
+                  extractor.HashToken("w"),
+                  extractor.HashToken("o"),
+                  extractor.HashToken("r"),
+                  extractor.HashToken("l"),
+                  extractor.HashToken("d"),
+                  extractor.HashToken("!"),
+                  extractor.HashToken("^w"),
+                  extractor.HashToken("wo"),
+                  extractor.HashToken("or"),
+                  extractor.HashToken("rl"),
+                  extractor.HashToken("ld"),
+                  extractor.HashToken("d!"),
+                  extractor.HashToken("!$"),
+                  extractor.HashToken("^wo"),
+                  extractor.HashToken("wor"),
+                  extractor.HashToken("orl"),
+                  extractor.HashToken("rld"),
+                  extractor.HashToken("ld!"),
+                  extractor.HashToken("d!$"),
+                  // clang-format on
+              }));
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+}
+
+TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = true;
+  options.extract_selection_mask_feature = false;  // Case feature only.
+  TokenFeatureExtractor extractor(options);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  extractor.Extract(Token{"Hělló", 0, 5, true}, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
+
+  sparse_features.clear();
+  dense_features.clear();
+  extractor.Extract(Token{"world!", 23, 29, false}, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
+  // Non-ASCII letters: "Ř" counts as uppercase, "ř" as lowercase.
+  sparse_features.clear();
+  dense_features.clear();
+  extractor.Extract(Token{"Ř", 23, 29, false}, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
+
+  sparse_features.clear();
+  dense_features.clear();
+  extractor.Extract(Token{"ř", 23, 29, false}, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
+}
+
+// Digit values are remapped: "9:30am" matches "5:32am", but not "10:32am".
+TEST(TokenFeatureExtractorTest, DigitRemapping) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2};
+  options.remap_digits = true;
+  options.unicode_aware_features = false;
+  TokenFeatureExtractor extractor(options);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  extractor.Extract(Token{"9:30am", 0, 6, true}, &sparse_features,
+                    &dense_features);
+  std::vector<int> sparse_features2;
+  extractor.Extract(Token{"5:32am", 0, 6, true}, &sparse_features2,
+                    &dense_features);
+  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+  sparse_features2.clear();  // Extract() appends; reset before reuse.
+  extractor.Extract(Token{"10:32am", 0, 7, true}, &sparse_features2,
+                    &dense_features);
+  EXPECT_THAT(sparse_features,
+              testing::Not(testing::ElementsAreArray(sparse_features2)));
+}
+
+// Unicode-aware digit remapping: "9:30am" matches "5:32am", not "10:32am".
+TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2};
+  options.remap_digits = true;
+  options.unicode_aware_features = true;
+  TokenFeatureExtractor extractor(options);
+
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  extractor.Extract(Token{"9:30am", 0, 6, true}, &sparse_features,
+                    &dense_features);
+  std::vector<int> sparse_features2;
+  extractor.Extract(Token{"5:32am", 0, 6, true}, &sparse_features2,
+                    &dense_features);
+  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+  sparse_features2.clear();  // Extract() appends; reset before reuse.
+  extractor.Extract(Token{"10:32am", 0, 7, true}, &sparse_features2,
+                    &dense_features);
+  EXPECT_THAT(sparse_features,
+              testing::Not(testing::ElementsAreArray(sparse_features2)));
+}
+
+TEST(TokenFeatureExtractorTest, RegexFeatures) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2};
+  options.remap_digits = false;
+  options.unicode_aware_features = false;
+  options.regexp_features.push_back("^[a-z]+$");  // all lower case.
+  options.regexp_features.push_back("^[0-9]+$");  // all digits.
+  TokenFeatureExtractor extractor(options);
+  // Each regexp yields one dense feature: 1.0 on a match, -1.0 otherwise.
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  extractor.Extract(Token{"abCde", 0, 6, true}, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+
+  dense_features.clear();
+  extractor.Extract(Token{"abcde", 0, 6, true}, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));
+
+  dense_features.clear();
+  extractor.Extract(Token{"12c45", 0, 6, true}, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+
+  dense_features.clear();
+  extractor.Extract(Token{"12345", 0, 6, true}, &sparse_features,
+                    &dense_features);
+  EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
+}
+
+TEST(TokenFeatureExtractorTest, ExtractInvalidUTF8) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5, 100};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = true;
+  options.extract_selection_mask_feature = true;
+  TestingTokenFeatureExtractor extractor(options);
+  // Stray \xf0 bytes make this invalid UTF-8; order 100 exceeds its length.
+  // Test that this runs. ASAN should catch problems.
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  extractor.Extract(Token{"\xf0👶👶¾👶🏿\xf0", 0, 7, true},
+                    &sparse_features, &dense_features);
+}
+
+TEST(TokenFeatureExtractorTest, ExtractTooLongWord) {
+  TokenFeatureExtractorOptions options;
+  options.num_buckets = 1000;
+  options.chargram_orders = std::vector<int>{22};
+  options.extract_case_feature = true;
+  options.unicode_aware_features = true;
+  options.extract_selection_mask_feature = true;
+  TestingTokenFeatureExtractor extractor(options);
+
+  // Test that this runs. ASAN should catch problems.
+  std::vector<int> sparse_features;
+  std::vector<float> dense_features;
+  extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0, true},
+                    &sparse_features, &dense_features);
+  // The word is cut to its first and last 10 codepoints, joined by \1.
+  EXPECT_THAT(sparse_features,
+              testing::ElementsAreArray({
+                  // clang-format off
+                  extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
+                  extractor.HashToken("abcdefghij\1qřstuvwxyz$"),
+                  // clang-format on
+              }));
+}
+
 TEST(TokenFeatureExtractorTest, ExtractForPadToken) {
   TokenFeatureExtractorOptions options;
-  options.num_buckets = 10;
+  options.num_buckets = 1000;
   options.chargram_orders = std::vector<int>{1, 2};
   options.extract_case_feature = true;
+  options.unicode_aware_features = false;
   options.extract_selection_mask_feature = true;
 
-  TokenFeatureExtractor extractor(options);
+  TestingTokenFeatureExtractor extractor(options);
 
   std::vector<int> sparse_features;
   std::vector<float> dense_features;
 
   extractor.Extract(Token(), &sparse_features, &dense_features);
 
-  EXPECT_THAT(sparse_features, testing::ElementsAreArray({5}));
+  EXPECT_THAT(sparse_features,  // Token() is hashed as "<PAD>".
+              testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
 }