Sync of libtextclassifier from Google3.
Exported by: knowledge/cerebra/sense/text_classifier/lib/export_to_aosp.sh
Bug: 67618889
Test: Builds. Tested also with oc-mr1 and tested that smartselect/sharing features work.
Change-Id: I25ad82cdd5eed20c60e83e7eb94dae6ab08b3690
diff --git a/smartselect/cached-features.h b/smartselect/cached-features.h
index 6490748..990233c 100644
--- a/smartselect/cached-features.h
+++ b/smartselect/cached-features.h
@@ -20,7 +20,6 @@
#include <memory>
#include <vector>
-#include "base.h"
#include "common/vector-span.h"
#include "smartselect/types.h"
diff --git a/smartselect/cached-features_test.cc b/smartselect/cached-features_test.cc
new file mode 100644
index 0000000..b456816
--- /dev/null
+++ b/smartselect/cached-features_test.cc
@@ -0,0 +1,149 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/cached-features.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier {
+namespace {
+
+class TestingCachedFeatures : public CachedFeatures {
+ public:
+ using CachedFeatures::CachedFeatures;
+ using CachedFeatures::RemapV0FeatureVector;
+};
+
+TEST(CachedFeaturesTest, Simple) {
+ std::vector<Token> tokens;
+ tokens.push_back(Token());
+ tokens.push_back(Token());
+ tokens.push_back(Token("Hello", 0, 1));
+ tokens.push_back(Token("World", 1, 2));
+ tokens.push_back(Token("today!", 2, 3));
+ tokens.push_back(Token());
+ tokens.push_back(Token());
+
+ std::vector<std::vector<int>> sparse_features(tokens.size());
+ for (int i = 0; i < sparse_features.size(); ++i) {
+ sparse_features[i].push_back(i);
+ }
+ std::vector<std::vector<float>> dense_features(tokens.size());
+ for (int i = 0; i < dense_features.size(); ++i) {
+ dense_features[i].push_back(-i);
+ }
+
+ TestingCachedFeatures feature_extractor(
+ tokens, /*context_size=*/2, sparse_features, dense_features,
+ [](const std::vector<int>& sparse_features,
+ const std::vector<float>& dense_features, float* features) {
+ features[0] = sparse_features[0];
+ features[1] = sparse_features[0];
+ features[2] = dense_features[0];
+ features[3] = dense_features[0];
+ features[4] = 123;
+ return true;
+ },
+ 5);
+
+ VectorSpan<float> features;
+ VectorSpan<Token> output_tokens;
+ EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens));
+ for (int i = 0; i < 5; i++) {
+ EXPECT_EQ(features[i * 5 + 0], i) << "Feature " << i;
+ EXPECT_EQ(features[i * 5 + 1], i) << "Feature " << i;
+ EXPECT_EQ(features[i * 5 + 2], -i) << "Feature " << i;
+ EXPECT_EQ(features[i * 5 + 3], -i) << "Feature " << i;
+ EXPECT_EQ(features[i * 5 + 4], 123) << "Feature " << i;
+ }
+}
+
+TEST(CachedFeaturesTest, InvalidInput) {
+ std::vector<Token> tokens;
+ tokens.push_back(Token());
+ tokens.push_back(Token());
+ tokens.push_back(Token("Hello", 0, 1));
+ tokens.push_back(Token("World", 1, 2));
+ tokens.push_back(Token("today!", 2, 3));
+ tokens.push_back(Token());
+ tokens.push_back(Token());
+
+ std::vector<std::vector<int>> sparse_features(tokens.size());
+ std::vector<std::vector<float>> dense_features(tokens.size());
+
+ TestingCachedFeatures feature_extractor(
+ tokens, /*context_size=*/2, sparse_features, dense_features,
+ [](const std::vector<int>& sparse_features,
+ const std::vector<float>& dense_features,
+ float* features) { return true; },
+ /*feature_vector_size=*/5);
+
+ VectorSpan<float> features;
+ VectorSpan<Token> output_tokens;
+ EXPECT_FALSE(feature_extractor.Get(-1000, &features, &output_tokens));
+ EXPECT_FALSE(feature_extractor.Get(-1, &features, &output_tokens));
+ EXPECT_FALSE(feature_extractor.Get(0, &features, &output_tokens));
+ EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens));
+ EXPECT_TRUE(feature_extractor.Get(4, &features, &output_tokens));
+ EXPECT_FALSE(feature_extractor.Get(5, &features, &output_tokens));
+ EXPECT_FALSE(feature_extractor.Get(500, &features, &output_tokens));
+}
+
+TEST(CachedFeaturesTest, RemapV0FeatureVector) {
+ std::vector<Token> tokens;
+ tokens.push_back(Token());
+ tokens.push_back(Token());
+ tokens.push_back(Token("Hello", 0, 1));
+ tokens.push_back(Token("World", 1, 2));
+ tokens.push_back(Token("today!", 2, 3));
+ tokens.push_back(Token());
+ tokens.push_back(Token());
+
+ std::vector<std::vector<int>> sparse_features(tokens.size());
+ std::vector<std::vector<float>> dense_features(tokens.size());
+
+ TestingCachedFeatures feature_extractor(
+ tokens, /*context_size=*/2, sparse_features, dense_features,
+ [](const std::vector<int>& sparse_features,
+ const std::vector<float>& dense_features,
+ float* features) { return true; },
+ /*feature_vector_size=*/5);
+
+ std::vector<float> features_orig(5 * 5);
+ for (int i = 0; i < features_orig.size(); i++) {
+ features_orig[i] = i;
+ }
+ VectorSpan<float> features;
+
+ feature_extractor.SetV0FeatureMode(0);
+ features = VectorSpan<float>(features_orig);
+ feature_extractor.RemapV0FeatureVector(&features);
+ EXPECT_EQ(
+ std::vector<float>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+ 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}),
+ std::vector<float>(features.begin(), features.end()));
+
+ feature_extractor.SetV0FeatureMode(2);
+ features = VectorSpan<float>(features_orig);
+ feature_extractor.RemapV0FeatureVector(&features);
+ EXPECT_EQ(std::vector<float>({0, 1, 5, 6, 10, 11, 15, 16, 20, 21, 2, 3, 4,
+ 7, 8, 9, 12, 13, 14, 17, 18, 19, 22, 23, 24}),
+ std::vector<float>(features.begin(), features.end()));
+}
+
+} // namespace
+} // namespace libtextclassifier
diff --git a/smartselect/feature-processor.cc b/smartselect/feature-processor.cc
index 1b15982..08f18ea 100644
--- a/smartselect/feature-processor.cc
+++ b/smartselect/feature-processor.cc
@@ -24,9 +24,11 @@
#include "util/base/logging.h"
#include "util/strings/utf8.h"
#include "util/utf8/unicodetext.h"
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
#include "unicode/brkiter.h"
#include "unicode/errorcode.h"
#include "unicode/uchar.h"
+#endif
namespace libtextclassifier {
@@ -51,6 +53,10 @@
extractor_options.remap_digits = options.remap_digits();
extractor_options.lowercase_tokens = options.lowercase_tokens();
+ for (const auto& chargram : options.allowed_chargrams()) {
+ extractor_options.allowed_chargrams.insert(chargram);
+ }
+
return extractor_options;
}
@@ -173,8 +179,10 @@
} // namespace internal
std::string FeatureProcessor::GetDefaultCollection() const {
- if (options_.default_collection() >= options_.collections_size()) {
- TC_LOG(ERROR) << "No collections specified. Returning empty string.";
+ if (options_.default_collection() < 0 ||
+ options_.default_collection() >= options_.collections_size()) {
+ TC_LOG(ERROR)
+ << "Invalid or missing default collection. Returning empty string.";
return "";
}
return options_.collections(options_.default_collection());
@@ -217,18 +225,32 @@
return false;
}
- const int result_begin_token = token_span.first;
- const int result_begin_codepoint =
- tokens[options_.context_size() - result_begin_token].start;
- const int result_end_token = token_span.second;
- const int result_end_codepoint =
- tokens[options_.context_size() + result_end_token].end;
+ const int result_begin_token_index = token_span.first;
+ const Token& result_begin_token =
+ tokens[options_.context_size() - result_begin_token_index];
+ const int result_begin_codepoint = result_begin_token.start;
+ const int result_end_token_index = token_span.second;
+ const Token& result_end_token =
+ tokens[options_.context_size() + result_end_token_index];
+ const int result_end_codepoint = result_end_token.end;
if (result_begin_codepoint == kInvalidIndex ||
result_end_codepoint == kInvalidIndex) {
*span = CodepointSpan({kInvalidIndex, kInvalidIndex});
} else {
- *span = CodepointSpan({result_begin_codepoint, result_end_codepoint});
+ const UnicodeText token_begin_unicode =
+ UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
+ UnicodeText::const_iterator token_begin = token_begin_unicode.begin();
+ const UnicodeText token_end_unicode =
+ UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false);
+ UnicodeText::const_iterator token_end = token_end_unicode.end();
+
+ const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
+ token_begin, token_begin_unicode.end(), /*count_from_beginning=*/true);
+ const int end_ignored = CountIgnoredSpanBoundaryCodepoints(
+ token_end_unicode.begin(), token_end, /*count_from_beginning=*/false);
+ *span = CodepointSpan({result_begin_codepoint + begin_ignored,
+ result_end_codepoint - end_ignored});
}
return true;
}
@@ -274,14 +296,28 @@
// Check that the spanned tokens cover the whole span.
bool tokens_match_span;
+ const CodepointIndex tokens_start = tokens[click_position - span_left].start;
+ const CodepointIndex tokens_end = tokens[click_position + span_right].end;
if (options_.snap_label_span_boundaries_to_containing_tokens()) {
- tokens_match_span =
- tokens[click_position - span_left].start <= span.first &&
- tokens[click_position + span_right].end >= span.second;
+ tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
} else {
- tokens_match_span =
- tokens[click_position - span_left].start == span.first &&
- tokens[click_position + span_right].end == span.second;
+ const UnicodeText token_left_unicode = UTF8ToUnicodeText(
+ tokens[click_position - span_left].value, /*do_copy=*/false);
+ const UnicodeText token_right_unicode = UTF8ToUnicodeText(
+ tokens[click_position + span_right].value, /*do_copy=*/false);
+
+ UnicodeText::const_iterator span_begin = token_left_unicode.begin();
+ UnicodeText::const_iterator span_end = token_right_unicode.end();
+
+ const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
+ span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
+ const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
+ token_right_unicode.begin(), span_end, /*count_from_beginning=*/false);
+
+ tokens_match_span = tokens_start <= span.first &&
+ tokens_start + num_punctuation_start >= span.first &&
+ tokens_end >= span.second &&
+ tokens_end - num_punctuation_end <= span.second;
}
if (tokens_match_span) {
@@ -453,6 +489,77 @@
});
}
+void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
+ for (const int codepoint : options_.ignored_span_boundary_codepoints()) {
+ ignored_span_boundary_codepoints_.insert(codepoint);
+ }
+}
+
+int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_start,
+ const UnicodeText::const_iterator& span_end,
+ bool count_from_beginning) const {
+ if (span_start == span_end) {
+ return 0;
+ }
+
+ UnicodeText::const_iterator it;
+ UnicodeText::const_iterator it_last;
+ if (count_from_beginning) {
+ it = span_start;
+ it_last = span_end;
+ // We can assume that the string is non-zero length because of the check
+ // above, thus the decrement is always valid here.
+ --it_last;
+ } else {
+ it = span_end;
+ it_last = span_start;
+ // We can assume that the string is non-zero length because of the check
+ // above, thus the decrement is always valid here.
+ --it;
+ }
+
+ // Move until we encounter a non-ignored character.
+ int num_ignored = 0;
+ while (ignored_span_boundary_codepoints_.find(*it) !=
+ ignored_span_boundary_codepoints_.end()) {
+ ++num_ignored;
+
+ if (it == it_last) {
+ break;
+ }
+
+ if (count_from_beginning) {
+ ++it;
+ } else {
+ --it;
+ }
+ }
+
+ return num_ignored;
+}
+
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+ const std::string& context, CodepointSpan span) const {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ UnicodeText::const_iterator span_begin = context_unicode.begin();
+ std::advance(span_begin, span.first);
+ UnicodeText::const_iterator span_end = context_unicode.begin();
+ std::advance(span_end, span.second);
+
+ const int start_offset = CountIgnoredSpanBoundaryCodepoints(
+ span_begin, span_end, /*count_from_beginning=*/true);
+ const int end_offset = CountIgnoredSpanBoundaryCodepoints(
+ span_begin, span_end, /*count_from_beginning=*/false);
+
+ if (span.first + start_offset < span.second - end_offset) {
+ return {span.first + start_offset, span.second - end_offset};
+ } else {
+ return {span.first, span.first};
+ }
+}
+
float FeatureProcessor::SupportedCodepointsRatio(
int click_pos, const std::vector<Token>& tokens) const {
int num_supported = 0;
@@ -614,6 +721,10 @@
}
}
+ if (relative_click_span == std::make_pair(kInvalidIndex, kInvalidIndex)) {
+ relative_click_span = {tokens->size() - 1, tokens->size() - 1};
+ }
+
internal::StripOrPadTokens(relative_click_span, options_.context_size(),
tokens, click_pos);
@@ -621,8 +732,8 @@
const float supported_codepoint_ratio =
SupportedCodepointsRatio(*click_pos, *tokens);
if (supported_codepoint_ratio < options_.min_supported_codepoint_ratio()) {
- TC_LOG(INFO) << "Not enough supported codepoints in the context: "
- << supported_codepoint_ratio;
+ TC_VLOG(1) << "Not enough supported codepoints in the context: "
+ << supported_codepoint_ratio;
return false;
}
}
@@ -658,6 +769,7 @@
bool FeatureProcessor::ICUTokenize(const std::string& context,
std::vector<Token>* result) const {
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
icu::ErrorCode status;
icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(context);
std::unique_ptr<icu::BreakIterator> break_iterator(
@@ -699,6 +811,10 @@
}
return true;
+#else
+ TC_LOG(WARNING) << "Can't tokenize, ICU not supported";
+ return false;
+#endif
}
void FeatureProcessor::InternalRetokenize(const std::string& context,
@@ -758,6 +874,8 @@
// Run the tokenizer and update the token bounds to reflect the offset of the
// substring.
std::vector<Token> tokens = tokenizer_.Tokenize(text);
+ // Avoids progressive capacity increases in the for loop.
+ result->reserve(result->size() + tokens.size());
for (Token& token : tokens) {
token.start += span.first;
token.end += span.first;
diff --git a/smartselect/feature-processor.h b/smartselect/feature-processor.h
index 2c64b67..a39a789 100644
--- a/smartselect/feature-processor.h
+++ b/smartselect/feature-processor.h
@@ -20,6 +20,7 @@
#define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
#include <memory>
+#include <set>
#include <string>
#include <vector>
@@ -104,6 +105,7 @@
{options.internal_tokenizer_codepoint_ranges().begin(),
options.internal_tokenizer_codepoint_ranges().end()},
&internal_tokenizer_codepoint_ranges_);
+ PrepareIgnoredSpanBoundaryCodepoints();
}
explicit FeatureProcessor(const std::string& serialized_options)
@@ -137,6 +139,8 @@
// Extracts features as a CachedFeatures object that can be used for repeated
// inference over token spans in the given context.
+ // When relative_click_span == {kInvalidIndex, kInvalidIndex} then all tokens
+ // extracted from context will be considered.
bool ExtractFeatures(const std::string& context, CodepointSpan input_span,
TokenSpan relative_click_span,
const FeatureVectorFn& feature_vector_fn,
@@ -155,6 +159,12 @@
return feature_extractor_.DenseFeaturesCount();
}
+ // Strips boundary codepoints from the span in context and returns the new
+ // start and end indices. If the span comprises entirely of boundary
+ // codepoints, the first index of span is returned for both indices.
+ CodepointSpan StripBoundaryCodepoints(const std::string& context,
+ CodepointSpan span) const;
+
protected:
// Represents a codepoint range [start, end).
struct CodepointRange {
@@ -207,6 +217,18 @@
bool IsCodepointInRanges(
int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
+ void PrepareIgnoredSpanBoundaryCodepoints();
+
+ // Counts the number of span boundary codepoints. If count_from_beginning is
+ // True, the counting will start at the span_start iterator (inclusive) and at
+  // maximum end at span_end (exclusive). If count_from_beginning is False, the
+  // counting will start from span_end (exclusive) and end at span_start
+ // (inclusive).
+ int CountIgnoredSpanBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_start,
+ const UnicodeText::const_iterator& span_end,
+ bool count_from_beginning) const;
+
// Finds the center token index in tokens vector, using the method defined
// in options_.
int FindCenterToken(CodepointSpan span,
@@ -240,6 +262,10 @@
std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
private:
+ // Set of codepoints that will be stripped from beginning and end of
+ // predicted spans.
+ std::set<int32> ignored_span_boundary_codepoints_;
+
const FeatureProcessorOptions options_;
// Mapping between token selection spans and labels ids.
diff --git a/smartselect/feature-processor_test.cc b/smartselect/feature-processor_test.cc
new file mode 100644
index 0000000..1a9b9da
--- /dev/null
+++ b/smartselect/feature-processor_test.cc
@@ -0,0 +1,786 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/feature-processor.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier {
+namespace {
+
+using testing::ElementsAreArray;
+using testing::FloatEq;
+
+TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({9, 12}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hělló", 0, 5),
+ Token("fěě", 6, 9),
+ Token("bař", 9, 12),
+ Token("@google.com", 12, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({6, 12}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hělló", 0, 5),
+ Token("fěěbař", 6, 12),
+ Token("@google.com", 12, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({9, 23}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hělló", 0, 5),
+ Token("fěě", 6, 9),
+ Token("bař@google.com", 9, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({6, 23}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesCrossToken) {
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ internal::SplitTokensOnSelectionBoundaries({2, 9}, &tokens);
+
+ // clang-format off
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Hě", 0, 2),
+ Token("lló", 2, 5),
+ Token("fěě", 6, 9),
+ Token("bař@google.com", 9, 23),
+ Token("heře!", 24, 29)}));
+ // clang-format on
+}
+
+TEST(FeatureProcessorTest, KeepLineWithClickFirst) {
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {0, 5};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+ // Keeps the first line.
+ internal::StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens,
+ ElementsAreArray({Token("Fiřst", 0, 5), Token("Lině", 6, 10)}));
+}
+
+TEST(FeatureProcessorTest, KeepLineWithClickSecond) {
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {18, 22};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+  // Keeps the second line.
+ internal::StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
+}
+
+TEST(FeatureProcessorTest, KeepLineWithClickThird) {
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {24, 33};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+  // Keeps the third line.
+ internal::StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
+}
+
+TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
+ const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {18, 22};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 11, 17),
+ Token("Lině", 18, 22),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+  // Keeps the second line.
+ internal::StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
+}
+
+TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) {
+ const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+ const CodepointSpan span = {5, 23};
+ // clang-format off
+ std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+ Token("Lině", 6, 10),
+ Token("Sěcond", 18, 23),
+ Token("Lině", 19, 23),
+ Token("Thiřd", 23, 28),
+ Token("Lině", 29, 33)};
+ // clang-format on
+
+  // Keeps all lines, because the click span crosses line boundaries.
+ internal::StripTokensFromOtherLines(context, span, &tokens);
+ EXPECT_THAT(tokens, ElementsAreArray(
+ {Token("Fiřst", 0, 5), Token("Lině", 6, 10),
+ Token("Sěcond", 18, 23), Token("Lině", 19, 23),
+ Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
+}
+
+class TestingFeatureProcessor : public FeatureProcessor {
+ public:
+ using FeatureProcessor::FeatureProcessor;
+ using FeatureProcessor::SpanToLabel;
+ using FeatureProcessor::SupportedCodepointsRatio;
+ using FeatureProcessor::IsCodepointInRanges;
+ using FeatureProcessor::ICUTokenize;
+ using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
+ using FeatureProcessor::supported_codepoint_ranges_;
+};
+
+TEST(FeatureProcessorTest, SpanToLabel) {
+ FeatureProcessorOptions options;
+ options.set_context_size(1);
+ options.set_max_selection_span(1);
+ options.set_snap_label_span_boundaries_to_containing_tokens(false);
+
+ TokenizationCodepointRange* config =
+ options.add_tokenization_codepoint_config();
+ config->set_start(32);
+ config->set_end(33);
+ config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+
+ TestingFeatureProcessor feature_processor(options);
+ std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
+ ASSERT_EQ(3, tokens.size());
+ int label;
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
+ EXPECT_EQ(kInvalidLabel, label);
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
+ EXPECT_NE(kInvalidLabel, label);
+ TokenSpan token_span;
+ feature_processor.LabelToTokenSpan(label, &token_span);
+ EXPECT_EQ(0, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ // Reconfigure with snapping enabled.
+ options.set_snap_label_span_boundaries_to_containing_tokens(true);
+ TestingFeatureProcessor feature_processor2(options);
+ int label2;
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+
+ // Cross a token boundary.
+ ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+
+ // Multiple tokens.
+ options.set_context_size(2);
+ options.set_max_selection_span(2);
+ TestingFeatureProcessor feature_processor3(options);
+ tokens = feature_processor3.Tokenize("zero, one, two, three, four");
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
+ EXPECT_NE(kInvalidLabel, label2);
+ feature_processor3.LabelToTokenSpan(label2, &token_span);
+ EXPECT_EQ(1, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ int label3;
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+}
+
+TEST(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) {
+ FeatureProcessorOptions options;
+ options.set_context_size(1);
+ options.set_max_selection_span(1);
+ options.set_snap_label_span_boundaries_to_containing_tokens(false);
+
+ TokenizationCodepointRange* config =
+ options.add_tokenization_codepoint_config();
+ config->set_start(32);
+ config->set_end(33);
+ config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+
+ TestingFeatureProcessor feature_processor(options);
+ std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
+ ASSERT_EQ(3, tokens.size());
+ int label;
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
+ EXPECT_EQ(kInvalidLabel, label);
+ ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
+ EXPECT_NE(kInvalidLabel, label);
+ TokenSpan token_span;
+ feature_processor.LabelToTokenSpan(label, &token_span);
+ EXPECT_EQ(0, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ // Reconfigure with snapping enabled.
+ options.set_snap_label_span_boundaries_to_containing_tokens(true);
+ TestingFeatureProcessor feature_processor2(options);
+ int label2;
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
+ EXPECT_EQ(label, label2);
+
+ // Cross a token boundary.
+ ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+ ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
+ EXPECT_EQ(kInvalidLabel, label2);
+
+ // Multiple tokens.
+ options.set_context_size(2);
+ options.set_max_selection_span(2);
+ TestingFeatureProcessor feature_processor3(options);
+ tokens = feature_processor3.Tokenize("zero, one, two, three, four");
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
+ EXPECT_NE(kInvalidLabel, label2);
+ feature_processor3.LabelToTokenSpan(label2, &token_span);
+ EXPECT_EQ(1, token_span.first);
+ EXPECT_EQ(0, token_span.second);
+
+ int label3;
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+ ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
+ EXPECT_EQ(label2, label3);
+}
+
+TEST(FeatureProcessorTest, CenterTokenFromClick) {
+ int token_index;
+
+ // Exactly aligned indices.
+ token_index = internal::CenterTokenFromClick(
+ {6, 11},
+ {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
+ EXPECT_EQ(token_index, 1);
+
+ // Click is contained in a token.
+ token_index = internal::CenterTokenFromClick(
+ {13, 17},
+ {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
+ EXPECT_EQ(token_index, 2);
+
+ // Click spans two tokens.
+ token_index = internal::CenterTokenFromClick(
+ {6, 17},
+ {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
+ EXPECT_EQ(token_index, kInvalidIndex);
+}
+
+TEST(FeatureProcessorTest, CenterTokenFromMiddleOfSelection) {
+ int token_index;
+
+ // Selection of length 3. Exactly aligned indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {7, 27},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 2);
+
+ // Selection of length 1 token. Exactly aligned indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {21, 27},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 3);
+
+ // Selection marks sub-token range, with no tokens in it.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {29, 33},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, kInvalidIndex);
+
+ // Selection of length 2. Sub-token indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {3, 25},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 1);
+
+ // Selection of length 1. Sub-token indices.
+ token_index = internal::CenterTokenFromMiddleOfSelection(
+ {22, 34},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
+ EXPECT_EQ(token_index, 4);
+
+ // Some invalid ones.
+ token_index = internal::CenterTokenFromMiddleOfSelection({7, 27}, {});
+ EXPECT_EQ(token_index, -1);
+}
+
+TEST(FeatureProcessorTest, SupportedCodepointsRatio) {
+ FeatureProcessorOptions options;
+ options.set_context_size(2);
+ options.set_max_selection_span(2);
+ options.set_snap_label_span_boundaries_to_containing_tokens(false);
+
+ TokenizationCodepointRange* config =
+ options.add_tokenization_codepoint_config();
+ config->set_start(32);
+ config->set_end(33);
+ config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+
+ FeatureProcessorOptions::CodepointRange* range;
+ range = options.add_supported_codepoint_ranges();
+ range->set_start(0);
+ range->set_end(128);
+
+ range = options.add_supported_codepoint_ranges();
+ range->set_start(10000);
+ range->set_end(10001);
+
+ range = options.add_supported_codepoint_ranges();
+ range->set_start(20000);
+ range->set_end(30000);
+
+ TestingFeatureProcessor feature_processor(options);
+ EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+ 1, feature_processor.Tokenize("aaa bbb ccc")),
+ FloatEq(1.0));
+ EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+ 1, feature_processor.Tokenize("aaa bbb ěěě")),
+ FloatEq(2.0 / 3));
+ EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
+ 1, feature_processor.Tokenize("ěěě řřř ěěě")),
+ FloatEq(0.0));
+ EXPECT_FALSE(feature_processor.IsCodepointInRanges(
+ -1, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+ 0, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+ 10, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+ 127, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(feature_processor.IsCodepointInRanges(
+ 128, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(feature_processor.IsCodepointInRanges(
+ 9999, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+ 10000, feature_processor.supported_codepoint_ranges_));
+ EXPECT_FALSE(feature_processor.IsCodepointInRanges(
+ 10001, feature_processor.supported_codepoint_ranges_));
+ EXPECT_TRUE(feature_processor.IsCodepointInRanges(
+ 25000, feature_processor.supported_codepoint_ranges_));
+
+ std::vector<Token> tokens;
+ int click_pos;
+ std::vector<float> extra_features;
+ std::unique_ptr<CachedFeatures> cached_features;
+
+ auto feature_fn = [](const std::vector<int>& sparse_features,
+ const std::vector<float>& dense_features,
+ float* embedding) { return true; };
+
+ options.set_min_supported_codepoint_ratio(0.0);
+ TestingFeatureProcessor feature_processor2(options);
+ EXPECT_TRUE(feature_processor2.ExtractFeatures("ěěě řřř eee", {4, 7}, {0, 0},
+ feature_fn, 2, &tokens,
+ &click_pos, &cached_features));
+
+ options.set_min_supported_codepoint_ratio(0.2);
+ TestingFeatureProcessor feature_processor3(options);
+ EXPECT_TRUE(feature_processor3.ExtractFeatures("ěěě řřř eee", {4, 7}, {0, 0},
+ feature_fn, 2, &tokens,
+ &click_pos, &cached_features));
+
+ options.set_min_supported_codepoint_ratio(0.5);
+ TestingFeatureProcessor feature_processor4(options);
+ EXPECT_FALSE(feature_processor4.ExtractFeatures(
+ "ěěě řřř eee", {4, 7}, {0, 0}, feature_fn, 2, &tokens, &click_pos,
+ &cached_features));
+}
+
+TEST(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) {
+ std::vector<Token> tokens_orig{
+ Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
+ Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
+ Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
+ Token("12", 0, 0)};
+
+ std::vector<Token> tokens;
+ int click_index;
+
+ // Try to click first token and see if it gets padded from left.
+ tokens = tokens_orig;
+ click_index = 0;
+ internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token(),
+ Token(),
+ Token("0", 0, 0),
+ Token("1", 0, 0),
+ Token("2", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);
+
+ // When we click the second token nothing should get padded.
+ tokens = tokens_orig;
+ click_index = 2;
+ internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("0", 0, 0),
+ Token("1", 0, 0),
+ Token("2", 0, 0),
+ Token("3", 0, 0),
+ Token("4", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);
+
+ // When we click the last token tokens should get padded from the right.
+ tokens = tokens_orig;
+ click_index = 12;
+ internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("10", 0, 0),
+ Token("11", 0, 0),
+ Token("12", 0, 0),
+ Token(),
+ Token()}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);
+}
+
+TEST(FeatureProcessorTest, StripUnusedTokensWithRelativeClick) {
+ std::vector<Token> tokens_orig{
+ Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
+ Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
+ Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
+ Token("12", 0, 0)};
+
+ std::vector<Token> tokens;
+ int click_index;
+
+ // Try to click first token and see if it gets padded from left to maximum
+ // context_size.
+ tokens = tokens_orig;
+ click_index = 0;
+ internal::StripOrPadTokens({2, 3}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token(),
+ Token(),
+ Token("0", 0, 0),
+ Token("1", 0, 0),
+ Token("2", 0, 0),
+ Token("3", 0, 0),
+ Token("4", 0, 0),
+ Token("5", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);
+
+ // Clicking to the middle with enough context should not produce any padding.
+ tokens = tokens_orig;
+ click_index = 6;
+ internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("1", 0, 0),
+ Token("2", 0, 0),
+ Token("3", 0, 0),
+ Token("4", 0, 0),
+ Token("5", 0, 0),
+ Token("6", 0, 0),
+ Token("7", 0, 0),
+ Token("8", 0, 0),
+ Token("9", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 5);
+
+ // Clicking at the end should pad right to maximum context_size.
+ tokens = tokens_orig;
+ click_index = 11;
+ internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("6", 0, 0),
+ Token("7", 0, 0),
+ Token("8", 0, 0),
+ Token("9", 0, 0),
+ Token("10", 0, 0),
+ Token("11", 0, 0),
+ Token("12", 0, 0),
+ Token(),
+ Token()}));
+ // clang-format on
+ EXPECT_EQ(click_index, 5);
+}
+
+TEST(FeatureProcessorTest, ICUTokenize) {
+ FeatureProcessorOptions options;
+ options.set_tokenization_type(
+ libtextclassifier::FeatureProcessorOptions::ICU);
+
+ TestingFeatureProcessor feature_processor(options);
+ std::vector<Token> tokens = feature_processor.Tokenize("พระบาทสมเด็จพระปรมิ");
+ ASSERT_EQ(tokens,
+ // clang-format off
+ std::vector<Token>({Token("พระบาท", 0, 6),
+ Token("สมเด็จ", 6, 12),
+ Token("พระ", 12, 15),
+ Token("ปร", 15, 17),
+ Token("มิ", 17, 19)}));
+ // clang-format on
+}
+
+TEST(FeatureProcessorTest, ICUTokenizeWithWhitespaces) {
+ FeatureProcessorOptions options;
+ options.set_tokenization_type(
+ libtextclassifier::FeatureProcessorOptions::ICU);
+ options.set_icu_preserve_whitespace_tokens(true);
+
+ TestingFeatureProcessor feature_processor(options);
+ std::vector<Token> tokens =
+ feature_processor.Tokenize("พระบาท สมเด็จ พระ ปร มิ");
+ ASSERT_EQ(tokens,
+ // clang-format off
+ std::vector<Token>({Token("พระบาท", 0, 6),
+ Token(" ", 6, 7),
+ Token("สมเด็จ", 7, 13),
+ Token(" ", 13, 14),
+ Token("พระ", 14, 17),
+ Token(" ", 17, 18),
+ Token("ปร", 18, 20),
+ Token(" ", 20, 21),
+ Token("มิ", 21, 23)}));
+ // clang-format on
+}
+
+TEST(FeatureProcessorTest, MixedTokenize) {
+ FeatureProcessorOptions options;
+ options.set_tokenization_type(
+ libtextclassifier::FeatureProcessorOptions::MIXED);
+
+ TokenizationCodepointRange* config =
+ options.add_tokenization_codepoint_config();
+ config->set_start(32);
+ config->set_end(33);
+ config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+
+ FeatureProcessorOptions::CodepointRange* range;
+ range = options.add_internal_tokenizer_codepoint_ranges();
+ range->set_start(0);
+ range->set_end(128);
+
+ range = options.add_internal_tokenizer_codepoint_ranges();
+ range->set_start(128);
+ range->set_end(256);
+
+ range = options.add_internal_tokenizer_codepoint_ranges();
+ range->set_start(256);
+ range->set_end(384);
+
+ range = options.add_internal_tokenizer_codepoint_ranges();
+ range->set_start(384);
+ range->set_end(592);
+
+ TestingFeatureProcessor feature_processor(options);
+ std::vector<Token> tokens = feature_processor.Tokenize(
+ "こんにちはJapanese-ląnguagę text 世界 http://www.google.com/");
+ ASSERT_EQ(tokens,
+ // clang-format off
+ std::vector<Token>({Token("こんにちは", 0, 5),
+ Token("Japanese-ląnguagę", 5, 22),
+ Token("text", 23, 27),
+ Token("世界", 28, 30),
+ Token("http://www.google.com/", 31, 53)}));
+ // clang-format on
+}
+
+TEST(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
+ FeatureProcessorOptions options;
+ options.add_ignored_span_boundary_codepoints('.');
+ options.add_ignored_span_boundary_codepoints(',');
+ options.add_ignored_span_boundary_codepoints('[');
+ options.add_ignored_span_boundary_codepoints(']');
+
+ TestingFeatureProcessor feature_processor(options);
+
+ const std::string text1_utf8 = "ěščř";
+ const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text1.begin(), text1.end(),
+ /*count_from_beginning=*/true),
+ 0);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text1.begin(), text1.end(),
+ /*count_from_beginning=*/false),
+ 0);
+
+ const std::string text2_utf8 = ".,abčd";
+ const UnicodeText text2 = UTF8ToUnicodeText(text2_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text2.begin(), text2.end(),
+ /*count_from_beginning=*/true),
+ 2);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text2.begin(), text2.end(),
+ /*count_from_beginning=*/false),
+ 0);
+
+ const std::string text3_utf8 = ".,abčd[]";
+ const UnicodeText text3 = UTF8ToUnicodeText(text3_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text3.begin(), text3.end(),
+ /*count_from_beginning=*/true),
+ 2);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text3.begin(), text3.end(),
+ /*count_from_beginning=*/false),
+ 2);
+
+ const std::string text4_utf8 = "[abčd]";
+ const UnicodeText text4 = UTF8ToUnicodeText(text4_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text4.begin(), text4.end(),
+ /*count_from_beginning=*/true),
+ 1);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text4.begin(), text4.end(),
+ /*count_from_beginning=*/false),
+ 1);
+
+ const std::string text5_utf8 = "";
+ const UnicodeText text5 = UTF8ToUnicodeText(text5_utf8, /*do_copy=*/false);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text5.begin(), text5.end(),
+ /*count_from_beginning=*/true),
+ 0);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text5.begin(), text5.end(),
+ /*count_from_beginning=*/false),
+ 0);
+
+ const std::string text6_utf8 = "012345ěščř";
+ const UnicodeText text6 = UTF8ToUnicodeText(text6_utf8, /*do_copy=*/false);
+ UnicodeText::const_iterator text6_begin = text6.begin();
+ std::advance(text6_begin, 6);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text6_begin, text6.end(),
+ /*count_from_beginning=*/true),
+ 0);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text6_begin, text6.end(),
+ /*count_from_beginning=*/false),
+ 0);
+
+ const std::string text7_utf8 = "012345.,ěščř";
+ const UnicodeText text7 = UTF8ToUnicodeText(text7_utf8, /*do_copy=*/false);
+ UnicodeText::const_iterator text7_begin = text7.begin();
+ std::advance(text7_begin, 6);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text7_begin, text7.end(),
+ /*count_from_beginning=*/true),
+ 2);
+ UnicodeText::const_iterator text7_end = text7.begin();
+ std::advance(text7_end, 8);
+ EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
+ text7.begin(), text7_end,
+ /*count_from_beginning=*/false),
+ 2);
+
+ // Test not stripping.
+ EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
+ "Hello [[[Wořld]] or not?", {0, 24}),
+ std::make_pair(0, 24));
+ // Test basic stripping.
+ EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
+ "Hello [[[Wořld]] or not?", {6, 16}),
+ std::make_pair(9, 14));
+ // Test stripping when everything is stripped.
+ EXPECT_EQ(
+ feature_processor.StripBoundaryCodepoints("Hello [[[]] or not?", {6, 11}),
+ std::make_pair(6, 6));
+ // Test stripping empty string.
+ EXPECT_EQ(feature_processor.StripBoundaryCodepoints("", {0, 0}),
+ std::make_pair(0, 0));
+}
+
+} // namespace
+} // namespace libtextclassifier
diff --git a/smartselect/model-params.cc b/smartselect/model-params.cc
index 9a31bab..65c4f93 100644
--- a/smartselect/model-params.cc
+++ b/smartselect/model-params.cc
@@ -43,6 +43,7 @@
reader.trimmed_proto().GetExtension(feature_processor_extension_id);
// If no tokenization codepoint config is present, tokenize on space.
+ // TODO(zilka): Remove the default config.
if (feature_processor_options.tokenization_codepoint_config_size() == 0) {
TokenizationCodepointRange* config;
// New line character.
@@ -67,42 +68,16 @@
if (reader.trimmed_proto().HasExtension(selection_options_extension_id)) {
selection_options =
reader.trimmed_proto().GetExtension(selection_options_extension_id);
- } else {
- // Default values when SelectionModelOptions is not present.
- for (const auto codepoint_pair : std::vector<std::pair<int, int>>(
- {{33, 35}, {37, 39}, {42, 42}, {44, 47},
- {58, 59}, {63, 64}, {91, 93}, {95, 95},
- {123, 123}, {125, 125}, {161, 161}, {171, 171},
- {183, 183}, {187, 187}, {191, 191}, {894, 894},
- {903, 903}, {1370, 1375}, {1417, 1418}, {1470, 1470},
- {1472, 1472}, {1475, 1475}, {1478, 1478}, {1523, 1524},
- {1548, 1549}, {1563, 1563}, {1566, 1567}, {1642, 1645},
- {1748, 1748}, {1792, 1805}, {2404, 2405}, {2416, 2416},
- {3572, 3572}, {3663, 3663}, {3674, 3675}, {3844, 3858},
- {3898, 3901}, {3973, 3973}, {4048, 4049}, {4170, 4175},
- {4347, 4347}, {4961, 4968}, {5741, 5742}, {5787, 5788},
- {5867, 5869}, {5941, 5942}, {6100, 6102}, {6104, 6106},
- {6144, 6154}, {6468, 6469}, {6622, 6623}, {6686, 6687},
- {8208, 8231}, {8240, 8259}, {8261, 8273}, {8275, 8286},
- {8317, 8318}, {8333, 8334}, {9001, 9002}, {9140, 9142},
- {10088, 10101}, {10181, 10182}, {10214, 10219}, {10627, 10648},
- {10712, 10715}, {10748, 10749}, {11513, 11516}, {11518, 11519},
- {11776, 11799}, {11804, 11805}, {12289, 12291}, {12296, 12305},
- {12308, 12319}, {12336, 12336}, {12349, 12349}, {12448, 12448},
- {12539, 12539}, {64830, 64831}, {65040, 65049}, {65072, 65106},
- {65108, 65121}, {65123, 65123}, {65128, 65128}, {65130, 65131},
- {65281, 65283}, {65285, 65290}, {65292, 65295}, {65306, 65307},
- {65311, 65312}, {65339, 65341}, {65343, 65343}, {65371, 65371},
- {65373, 65373}, {65375, 65381}, {65792, 65793}, {66463, 66463},
- {68176, 68184}})) {
- for (int i = codepoint_pair.first; i <= codepoint_pair.second; i++) {
- selection_options.add_punctuation_to_strip(i);
- }
- selection_options.set_strip_punctuation(true);
- selection_options.set_enforce_symmetry(true);
- selection_options.set_symmetry_context_size(
- feature_processor_options.context_size() * 2);
+
+ // For backward compatibility with the current models.
+ if (!feature_processor_options.ignored_span_boundary_codepoints_size()) {
+ *feature_processor_options.mutable_ignored_span_boundary_codepoints() =
+ selection_options.deprecated_punctuation_to_strip();
}
+ } else {
+ selection_options.set_enforce_symmetry(true);
+ selection_options.set_symmetry_context_size(
+ feature_processor_options.context_size() * 2);
}
SharingModelOptions sharing_options;
diff --git a/smartselect/model-parser.cc b/smartselect/model-parser.cc
new file mode 100644
index 0000000..0cf05e3
--- /dev/null
+++ b/smartselect/model-parser.cc
@@ -0,0 +1,91 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/model-parser.h"
+#include "util/base/endian.h"
+
+namespace libtextclassifier {
+namespace {
+
+// Small helper class for parsing the merged model format.
+// The merged model consists of interleaved <int32 data_size, char* data>
+// segments.
+class MergedModelParser {
+ public:
+ MergedModelParser(const void* addr, const int size)
+ : addr_(reinterpret_cast<const char*>(addr)), size_(size), pos_(addr_) {}
+
+ bool ReadBytesAndAdvance(int num_bytes, const char** result) {
+ const char* read_addr = pos_;
+ if (Advance(num_bytes)) {
+ *result = read_addr;
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ bool ReadInt32AndAdvance(int* result) {
+ const char* read_addr = pos_;
+ if (Advance(sizeof(int))) {
+ *result =
+ LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(read_addr));
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ bool IsDone() { return pos_ == addr_ + size_; }
+
+ private:
+ bool Advance(int num_bytes) {
+ pos_ += num_bytes;
+ return pos_ <= addr_ + size_;
+ }
+
+ const char* addr_;
+ const int size_;
+ const char* pos_;
+};
+
+} // namespace
+
+bool ParseMergedModel(const void* addr, const int size,
+ const char** selection_model, int* selection_model_length,
+ const char** sharing_model, int* sharing_model_length) {
+ MergedModelParser parser(addr, size);
+
+ if (!parser.ReadInt32AndAdvance(selection_model_length)) {
+ return false;
+ }
+
+ if (!parser.ReadBytesAndAdvance(*selection_model_length, selection_model)) {
+ return false;
+ }
+
+ if (!parser.ReadInt32AndAdvance(sharing_model_length)) {
+ return false;
+ }
+
+ if (!parser.ReadBytesAndAdvance(*sharing_model_length, sharing_model)) {
+ return false;
+ }
+
+ return parser.IsDone();
+}
+
+} // namespace libtextclassifier
diff --git a/smartselect/model-parser.h b/smartselect/model-parser.h
new file mode 100644
index 0000000..801262f
--- /dev/null
+++ b/smartselect/model-parser.h
@@ -0,0 +1,29 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_
+
+namespace libtextclassifier {
+
+// Parse a merged model image.
+bool ParseMergedModel(const void* addr, const int size,
+ const char** selection_model, int* selection_model_length,
+ const char** sharing_model, int* sharing_model_length);
+
+} // namespace libtextclassifier
+
+#endif // LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARSER_H_
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
index dee8f8b..3e5068d 100644
--- a/smartselect/text-classification-model.cc
+++ b/smartselect/text-classification-model.cc
@@ -16,6 +16,7 @@
#include "smartselect/text-classification-model.h"
+#include <cctype>
#include <cmath>
#include <iterator>
#include <numeric>
@@ -26,10 +27,14 @@
#include "common/memory_image/memory-image-reader.h"
#include "common/mmap.h"
#include "common/softmax.h"
+#include "smartselect/model-parser.h"
#include "smartselect/text-classification-model.pb.h"
#include "util/base/logging.h"
#include "util/utf8/unicodetext.h"
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+#include "unicode/regex.h"
#include "unicode/uchar.h"
+#endif
namespace libtextclassifier {
@@ -49,65 +54,60 @@
const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
if (i >= selection_indices.first && i < selection_indices.second &&
- u_isdigit(*it)) {
+ isdigit(*it)) {
++count;
}
}
return count;
}
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+bool MatchesRegex(const icu::RegexPattern* regex, const std::string& context) {
+ const icu::UnicodeString unicode_context(context.c_str(), context.size(),
+ "utf-8");
+ UErrorCode status = U_ZERO_ERROR;
+ std::unique_ptr<icu::RegexMatcher> matcher(
+ regex->matcher(unicode_context, status));
+ return matcher->matches(0 /* start */, status);
+}
+#endif
+
} // namespace
-CodepointSpan TextClassificationModel::StripPunctuation(
- CodepointSpan selection, const std::string& context) const {
- UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false);
- int context_length =
- std::distance(context_unicode.begin(), context_unicode.end());
-
- // Check that the indices are valid.
- if (selection.first < 0 || selection.first > context_length ||
- selection.second < 0 || selection.second > context_length) {
- return selection;
- }
-
- // Move the left border until we encounter a non-punctuation character.
- UnicodeText::const_iterator it_from_begin = context_unicode.begin();
- std::advance(it_from_begin, selection.first);
- for (; punctuation_to_strip_.find(*it_from_begin) !=
- punctuation_to_strip_.end();
- ++it_from_begin, ++selection.first) {
- }
-
- // Unless we are already at the end, move the right border until we encounter
- // a non-punctuation character.
- UnicodeText::const_iterator it_from_end = context_unicode.begin();
- std::advance(it_from_end, selection.second);
- if (it_from_begin != it_from_end) {
- --it_from_end;
- for (; punctuation_to_strip_.find(*it_from_end) !=
- punctuation_to_strip_.end();
- --it_from_end, --selection.second) {
- }
- return selection;
- } else {
- // When the token is all punctuation.
- return {0, 0};
- }
+TextClassificationModel::TextClassificationModel(const std::string& path)
+ : mmap_(new nlp_core::ScopedMmap(path)) {
+ InitFromMmap();
}
-TextClassificationModel::TextClassificationModel(int fd) : mmap_(fd) {
- initialized_ = LoadModels(mmap_.handle());
+TextClassificationModel::TextClassificationModel(int fd)
+ : mmap_(new nlp_core::ScopedMmap(fd)) {
+ InitFromMmap();
+}
+
+TextClassificationModel::TextClassificationModel(int fd, int offset, int size)
+ : mmap_(new nlp_core::ScopedMmap(fd, offset, size)) {
+ InitFromMmap();
+}
+
+TextClassificationModel::TextClassificationModel(const void* addr, int size) {
+ initialized_ = LoadModels(addr, size);
if (!initialized_) {
TC_LOG(ERROR) << "Failed to load models";
return;
}
+}
- selection_options_ = selection_params_->GetSelectionModelOptions();
- for (const int codepoint : selection_options_.punctuation_to_strip()) {
- punctuation_to_strip_.insert(codepoint);
+void TextClassificationModel::InitFromMmap() {
+ if (!mmap_->handle().ok()) {
+ return;
}
- sharing_options_ = selection_params_->GetSharingModelOptions();
+ initialized_ =
+ LoadModels(mmap_->handle().start(), mmap_->handle().num_bytes());
+ if (!initialized_) {
+ TC_LOG(ERROR) << "Failed to load models";
+ return;
+ }
}
namespace {
@@ -151,40 +151,23 @@
};
}
-void ParseMergedModel(const MmapHandle& mmap_handle,
- const char** selection_model, int* selection_model_length,
- const char** sharing_model, int* sharing_model_length) {
- // Read the length of the selection model.
- const char* model_data = reinterpret_cast<const char*>(mmap_handle.start());
- *selection_model_length =
- LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
- model_data += sizeof(*selection_model_length);
- *selection_model = model_data;
- model_data += *selection_model_length;
-
- *sharing_model_length =
- LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
- model_data += sizeof(*sharing_model_length);
- *sharing_model = model_data;
-}
-
} // namespace
-bool TextClassificationModel::LoadModels(const MmapHandle& mmap_handle) {
- if (!mmap_handle.ok()) {
- return false;
- }
-
+bool TextClassificationModel::LoadModels(const void* addr, int size) {
const char *selection_model, *sharing_model;
int selection_model_length, sharing_model_length;
- ParseMergedModel(mmap_handle, &selection_model, &selection_model_length,
- &sharing_model, &sharing_model_length);
+ if (!ParseMergedModel(addr, size, &selection_model, &selection_model_length,
+ &sharing_model, &sharing_model_length)) {
+ TC_LOG(ERROR) << "Couldn't parse the model.";
+ return false;
+ }
selection_params_.reset(
ModelParamsBuilder(selection_model, selection_model_length, nullptr));
if (!selection_params_.get()) {
return false;
}
+ selection_options_ = selection_params_->GetSelectionModelOptions();
selection_network_.reset(new EmbeddingNetwork(selection_params_.get()));
selection_feature_processor_.reset(
new FeatureProcessor(selection_params_->GetFeatureProcessorOptions()));
@@ -197,12 +180,35 @@
if (!sharing_params_.get()) {
return false;
}
+ sharing_options_ = selection_params_->GetSharingModelOptions();
sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get()));
sharing_feature_processor_.reset(
new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions()));
sharing_feature_fn_ = CreateFeatureVectorFn(
*sharing_network_, sharing_network_->EmbeddingSize(0));
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ // Initialize pattern recognizers.
+ for (const auto& regex_pattern : sharing_options_.regex_pattern()) {
+ UErrorCode status = U_ZERO_ERROR;
+ std::unique_ptr<icu::RegexPattern> compiled_pattern(
+ icu::RegexPattern::compile(
+ icu::UnicodeString(regex_pattern.pattern().c_str(),
+ regex_pattern.pattern().size(), "utf-8"),
+ 0 /* flags */, status));
+ if (U_FAILURE(status)) {
+ TC_LOG(WARNING) << "Failed to load pattern" << regex_pattern.pattern();
+ } else {
+ regex_patterns_.push_back(
+ {regex_pattern.collection_name(), std::move(compiled_pattern)});
+ }
+ }
+#else
+ if (sharing_options_.regex_pattern_size() > 0) {
+ TC_LOG(WARNING) << "ICU not supported regexp matchers ignored.";
+ }
+#endif
+
return true;
}
@@ -215,8 +221,12 @@
const char *selection_model, *sharing_model;
int selection_model_length, sharing_model_length;
- ParseMergedModel(mmap.handle(), &selection_model, &selection_model_length,
- &sharing_model, &sharing_model_length);
+ if (!ParseMergedModel(mmap.handle().start(), mmap.handle().num_bytes(),
+ &selection_model, &selection_model_length,
+ &sharing_model, &sharing_model_length)) {
+ TC_LOG(ERROR) << "Couldn't parse merged model.";
+ return false;
+ }
MemoryImageReader<EmbeddingNetworkProto> reader(selection_model,
selection_model_length);
@@ -245,14 +255,14 @@
CreateFeatureVectorFn(network, embedding_size),
embedding_size + feature_processor.DenseFeaturesCount(), &tokens,
&click_pos, &cached_features)) {
- TC_LOG(ERROR) << "Could not extract features.";
+ TC_VLOG(1) << "Could not extract features.";
return {};
}
VectorSpan<float> features;
VectorSpan<Token> output_tokens;
if (!cached_features->Get(click_pos, &features, &output_tokens)) {
- TC_LOG(ERROR) << "Could not extract features.";
+ TC_VLOG(1) << "Could not extract features.";
return {};
}
@@ -277,9 +287,9 @@
}
if (std::get<0>(click_indices) >= std::get<1>(click_indices)) {
- TC_LOG(ERROR) << "Trying to run SuggestSelection with invalid indices:"
- << std::get<0>(click_indices) << " "
- << std::get<1>(click_indices);
+ TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices:"
+ << std::get<0>(click_indices) << " "
+ << std::get<1>(click_indices);
return click_indices;
}
@@ -300,28 +310,32 @@
std::tie(result, score) = SuggestSelectionInternal(context, click_indices);
}
- if (selection_options_.strip_punctuation()) {
- result = StripPunctuation(result, context);
- }
-
return result;
}
namespace {
-std::pair<CodepointSpan, float> BestSelectionSpan(
- CodepointSpan original_click_indices, const std::vector<float>& scores,
- const std::vector<CodepointSpan>& selection_label_spans) {
+int BestPrediction(const std::vector<float>& scores) {
if (!scores.empty()) {
const int prediction =
std::max_element(scores.begin(), scores.end()) - scores.begin();
+ return prediction;
+ } else {
+ return kInvalidLabel;
+ }
+}
+
+std::pair<CodepointSpan, float> BestSelectionSpan(
+ CodepointSpan original_click_indices, const std::vector<float>& scores,
+ const std::vector<CodepointSpan>& selection_label_spans) {
+ const int prediction = BestPrediction(scores);
+ if (prediction != kInvalidLabel) {
std::pair<CodepointIndex, CodepointIndex> selection =
selection_label_spans[prediction];
if (selection.first == kInvalidIndex || selection.second == kInvalidIndex) {
- TC_LOG(ERROR) << "Invalid indices predicted, returning input: "
- << prediction << " " << selection.first << " "
- << selection.second;
+ TC_VLOG(1) << "Invalid indices predicted, returning input: " << prediction
+ << " " << selection.first << " " << selection.second;
return {original_click_indices, -1.0};
}
@@ -367,86 +381,17 @@
CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical(
const std::string& context, CodepointSpan click_indices) const {
const int symmetry_context_size = selection_options_.symmetry_context_size();
- std::vector<Token> tokens;
- std::unique_ptr<CachedFeatures> cached_features;
- int click_index;
- int embedding_size = selection_network_->EmbeddingSize(0);
- if (!selection_feature_processor_->ExtractFeatures(
- context, click_indices, /*relative_click_span=*/
- {symmetry_context_size, symmetry_context_size + 1},
- selection_feature_fn_,
- embedding_size + selection_feature_processor_->DenseFeaturesCount(),
- &tokens, &click_index, &cached_features)) {
- TC_LOG(ERROR) << "Couldn't ExtractFeatures.";
- return click_indices;
- }
-
- // Scan in the symmetry context for selection span proposals.
- std::vector<std::pair<CodepointSpan, float>> proposals;
-
- for (int i = -symmetry_context_size; i < symmetry_context_size + 1; ++i) {
- const int token_index = click_index + i;
- if (token_index >= 0 && token_index < tokens.size() &&
- !tokens[token_index].is_padding) {
- float score;
- VectorSpan<float> features;
- VectorSpan<Token> output_tokens;
-
- CodepointSpan span;
- if (cached_features->Get(token_index, &features, &output_tokens)) {
- std::vector<float> scores;
- selection_network_->ComputeLogits(features, &scores);
-
- std::vector<CodepointSpan> selection_label_spans;
- if (selection_feature_processor_->SelectionLabelSpans(
- output_tokens, &selection_label_spans)) {
- scores = nlp_core::ComputeSoftmax(scores);
- std::tie(span, score) =
- BestSelectionSpan(click_indices, scores, selection_label_spans);
- if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
- score >= 0) {
- proposals.push_back({span, score});
- }
- }
- }
+ std::vector<CodepointSpan> chunks = Chunk(
+ context, click_indices, {symmetry_context_size, symmetry_context_size});
+ for (const CodepointSpan& chunk : chunks) {
+ // If chunk and click indices have an overlap, return the chunk.
+ if (!(click_indices.first >= chunk.second ||
+ click_indices.second <= chunk.first)) {
+ return chunk;
}
}
- // Sort selection span proposals by their respective probabilities.
- std::sort(
- proposals.begin(), proposals.end(),
- [](std::pair<CodepointSpan, float> a, std::pair<CodepointSpan, float> b) {
- return a.second > b.second;
- });
-
- // Go from the highest-scoring proposal and claim tokens. Tokens are marked as
- // claimed by the higher-scoring selection proposals, so that the
- // lower-scoring ones cannot use them. Returns the selection proposal if it
- // contains the clicked token.
- std::vector<int> used_tokens(tokens.size(), 0);
- for (auto span_result : proposals) {
- TokenSpan span = CodepointSpanToTokenSpan(tokens, span_result.first);
- if (span.first != kInvalidIndex && span.second != kInvalidIndex) {
- bool feasible = true;
- for (int i = span.first; i < span.second; i++) {
- if (used_tokens[i] != 0) {
- feasible = false;
- break;
- }
- }
-
- if (feasible) {
- if (span.first <= click_index && span.second > click_index) {
- return {span_result.first.first, span_result.first.second};
- }
- for (int i = span.first; i < span.second; i++) {
- used_tokens[i] = 1;
- }
- }
- }
- }
-
- return {click_indices.first, click_indices.second};
+ return click_indices;
}
std::vector<std::pair<std::string, float>>
@@ -459,9 +404,9 @@
}
if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
- TC_LOG(ERROR) << "Trying to run ClassifyText with invalid indices: "
- << std::get<0>(selection_indices) << " "
- << std::get<1>(selection_indices);
+ TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
+ << std::get<0>(selection_indices) << " "
+ << std::get<1>(selection_indices);
return {};
}
@@ -475,21 +420,29 @@
return {{kEmailHintCollection, 1.0}};
}
+ // Check whether any of the regular expressions match.
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ for (const CompiledRegexPattern& regex_pattern : regex_patterns_) {
+ if (MatchesRegex(regex_pattern.pattern.get(), context)) {
+ return {{regex_pattern.collection_name, 1.0}};
+ }
+ }
+#endif
+
EmbeddingNetwork::Vector scores =
InferInternal(context, selection_indices, *sharing_feature_processor_,
*sharing_network_, sharing_feature_fn_, nullptr);
if (scores.empty() ||
scores.size() != sharing_feature_processor_->NumCollections()) {
- TC_LOG(ERROR) << "Using default class: scores.size() = " << scores.size();
+ TC_VLOG(1) << "Using default class: scores.size() = " << scores.size();
return {};
}
scores = nlp_core::ComputeSoftmax(scores);
- std::vector<std::pair<std::string, float>> result;
+ std::vector<std::pair<std::string, float>> result(scores.size());
for (int i = 0; i < scores.size(); i++) {
- result.push_back(
- {sharing_feature_processor_->LabelToCollection(i), scores[i]});
+ result[i] = {sharing_feature_processor_->LabelToCollection(i), scores[i]};
}
std::sort(result.begin(), result.end(),
[](const std::pair<std::string, float>& a,
@@ -509,4 +462,147 @@
return result;
}
+std::vector<CodepointSpan> TextClassificationModel::Chunk(
+ const std::string& context, CodepointSpan click_span,
+ TokenSpan relative_click_span) const {
+ std::unique_ptr<CachedFeatures> cached_features;
+ std::vector<Token> tokens;
+ int click_index;
+
+ int embedding_size = selection_network_->EmbeddingSize(0);
+ // TODO(zilka): Refactor the ExtractFeatures API to smoothly support the
+  // different use cases. It is currently very click-centric.
+ if (!selection_feature_processor_->ExtractFeatures(
+ context, click_span, relative_click_span, selection_feature_fn_,
+ embedding_size + selection_feature_processor_->DenseFeaturesCount(),
+ &tokens, &click_index, &cached_features)) {
+ TC_VLOG(1) << "Couldn't ExtractFeatures.";
+ return {};
+ }
+
+ if (relative_click_span == std::make_pair(kInvalidIndex, kInvalidIndex)) {
+ relative_click_span = {tokens.size() - 1, tokens.size() - 1};
+ }
+
+ struct SelectionProposal {
+ int label;
+ int token_index;
+ CodepointSpan span;
+ float score;
+ };
+
+ // Scan in the symmetry context for selection span proposals.
+ std::vector<SelectionProposal> proposals;
+
+ for (int i = -relative_click_span.first; i < relative_click_span.second + 1;
+ ++i) {
+ const int token_index = click_index + i;
+ if (token_index >= 0 && token_index < tokens.size() &&
+ !tokens[token_index].is_padding) {
+ float score;
+ VectorSpan<float> features;
+ VectorSpan<Token> output_tokens;
+
+ if (tokens[token_index].is_padding) {
+ continue;
+ }
+
+ std::vector<CodepointSpan> selection_label_spans;
+ CodepointSpan span;
+ if (cached_features->Get(token_index, &features, &output_tokens) &&
+ selection_feature_processor_->SelectionLabelSpans(
+ output_tokens, &selection_label_spans)) {
+ // Add an implicit proposal for each token to be by itself. Every
+      // token should now be represented in the results.
+ proposals.push_back(
+ SelectionProposal{0, token_index, selection_label_spans[0], 0.0});
+
+ std::vector<float> scores;
+ selection_network_->ComputeLogits(features, &scores);
+
+ scores = nlp_core::ComputeSoftmax(scores);
+ std::tie(span, score) = BestSelectionSpan(
+ {kInvalidIndex, kInvalidIndex}, scores, selection_label_spans);
+ if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
+ score >= 0) {
+ const int prediction = BestPrediction(scores);
+ proposals.push_back(
+ SelectionProposal{prediction, token_index, span, score});
+ }
+ } else {
+ // Add an implicit proposal for each token to be by itself. Every token
+        // should now be represented in the results.
+ proposals.push_back(SelectionProposal{
+ 0,
+ token_index,
+ {tokens[token_index].start, tokens[token_index].end},
+ 0.0});
+ }
+ }
+ }
+
+ // Sort selection span proposals by their respective probabilities.
+ std::sort(proposals.begin(), proposals.end(),
+ [](const SelectionProposal& a, const SelectionProposal& b) {
+ return a.score > b.score;
+ });
+
+ // Go from the highest-scoring proposal and claim tokens. Tokens are marked as
+ // claimed by the higher-scoring selection proposals, so that the
+ // lower-scoring ones cannot use them. Returns the selection proposal if it
+ // contains the clicked token.
+ std::vector<CodepointSpan> result;
+ std::vector<bool> token_used(tokens.size(), false);
+ for (const SelectionProposal& proposal : proposals) {
+ const int predicted_label = proposal.label;
+ TokenSpan relative_span;
+ if (!selection_feature_processor_->LabelToTokenSpan(predicted_label,
+ &relative_span)) {
+ continue;
+ }
+ TokenSpan span;
+ span.first = proposal.token_index - relative_span.first;
+ span.second = proposal.token_index + relative_span.second + 1;
+
+ if (span.first != kInvalidIndex && span.second != kInvalidIndex) {
+ bool feasible = true;
+ for (int i = span.first; i < span.second; i++) {
+ if (token_used[i]) {
+ feasible = false;
+ break;
+ }
+ }
+
+ if (feasible) {
+ result.push_back(proposal.span);
+ for (int i = span.first; i < span.second; i++) {
+ token_used[i] = true;
+ }
+ }
+ }
+ }
+
+ std::sort(result.begin(), result.end(),
+ [](const CodepointSpan& a, const CodepointSpan& b) {
+ return a.first < b.first;
+ });
+
+ return result;
+}
+
+std::vector<TextClassificationModel::AnnotatedSpan>
+TextClassificationModel::Annotate(const std::string& context) const {
+ std::vector<CodepointSpan> chunks =
+ Chunk(context, /*click_span=*/{0, 1},
+ /*relative_click_span=*/{kInvalidIndex, kInvalidIndex});
+
+ std::vector<TextClassificationModel::AnnotatedSpan> result;
+ for (const CodepointSpan& chunk : chunks) {
+ result.emplace_back();
+ result.back().span = chunk;
+ result.back().classification = ClassifyText(context, chunk);
+ }
+ return result;
+}
+
} // namespace libtextclassifier
diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h
index 522372c..5b58d89 100644
--- a/smartselect/text-classification-model.h
+++ b/smartselect/text-classification-model.h
@@ -23,7 +23,6 @@
#include <set>
#include <string>
-#include "base.h"
#include "common/embedding-network.h"
#include "common/feature-extractor.h"
#include "common/memory_image/embedding-network-params-from-image.h"
@@ -38,10 +37,33 @@
// SmartSelection/Sharing feed-forward model.
class TextClassificationModel {
public:
+ // Represents a result of Annotate call.
+ struct AnnotatedSpan {
+ // Unicode codepoint indices in the input string.
+ CodepointSpan span = {kInvalidIndex, kInvalidIndex};
+
+ // Classification result for the span.
+ std::vector<std::pair<std::string, float>> classification;
+ };
+
// Loads TextClassificationModel from given file given by an int
// file descriptor.
+  // Offset is a byte position in the file of the beginning of the model data.
+ TextClassificationModel(int fd, int offset, int size);
+
+ // Same as above but the whole file is mapped and it is assumed the model
+ // starts at offset 0.
explicit TextClassificationModel(int fd);
+ // Loads TextClassificationModel from given file.
+ explicit TextClassificationModel(const std::string& path);
+
+ // Loads TextClassificationModel from given location in memory.
+ TextClassificationModel(const void* addr, int size);
+
+ // Returns true if the model is ready for use.
+ bool IsInitialized() { return initialized_; }
+
// Bit flags for the input selection.
enum SelectionInputFlags { SELECTION_IS_URL = 0x1, SELECTION_IS_EMAIL = 0x2 };
@@ -63,11 +85,26 @@
const std::string& context, CodepointSpan click_indices,
int input_flags = 0) const;
+ // Annotates given input text. The annotations should cover the whole input
+ // context except for whitespaces, and are sorted by their position in the
+ // context string.
+ std::vector<AnnotatedSpan> Annotate(const std::string& context) const;
+
protected:
- // Removes punctuation from the beginning and end of the selection and returns
- // the new selection span.
- CodepointSpan StripPunctuation(CodepointSpan selection,
- const std::string& context) const;
+ // Initializes the model from mmap_ file.
+ void InitFromMmap();
+
+ // Extracts chunks from the context. The extraction proceeds from the center
+ // token determined by click_span and looks at relative_click_span tokens
+ // left and right around the click position.
+ // If relative_click_span == {kInvalidIndex, kInvalidIndex} then the whole
+ // context is considered, regardless of the click_span (which should point to
+  // the beginning, i.e. {0, 1}).
+ // Returns the chunks sorted by their position in the context string.
+ // TODO(zilka): Tidy up the interface.
+ std::vector<CodepointSpan> Chunk(const std::string& context,
+ CodepointSpan click_span,
+ TokenSpan relative_click_span) const;
// During evaluation we need access to the feature processor.
FeatureProcessor* SelectionFeatureProcessor() const {
@@ -90,7 +127,14 @@
SharingModelOptions sharing_options_;
private:
- bool LoadModels(const nlp_core::MmapHandle& mmap_handle);
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ struct CompiledRegexPattern {
+ std::string collection_name;
+ std::unique_ptr<icu::RegexPattern> pattern;
+ };
+#endif
+
+ bool LoadModels(const void* addr, int size);
nlp_core::EmbeddingNetwork::Vector InferInternal(
const std::string& context, CodepointSpan span,
@@ -108,8 +152,8 @@
CodepointSpan SuggestSelectionSymmetrical(const std::string& context,
CodepointSpan click_indices) const;
- bool initialized_;
- nlp_core::ScopedMmap mmap_;
+ bool initialized_ = false;
+ std::unique_ptr<nlp_core::ScopedMmap> mmap_;
std::unique_ptr<ModelParams> selection_params_;
std::unique_ptr<FeatureProcessor> selection_feature_processor_;
std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
@@ -118,14 +162,28 @@
std::unique_ptr<ModelParams> sharing_params_;
std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_;
FeatureVectorFn sharing_feature_fn_;
-
- std::set<int> punctuation_to_strip_;
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ std::vector<CompiledRegexPattern> regex_patterns_;
+#endif
};
// Parses the merged image given as a file descriptor, and reads
// the ModelOptions proto from the selection model.
bool ReadSelectionModelOptions(int fd, ModelOptions* model_options);
+// Pretty-printing function for TextClassificationModel::AnnotatedSpan.
+inline std::ostream& operator<<(
+ std::ostream& os, const TextClassificationModel::AnnotatedSpan& span) {
+ std::string best_class;
+ float best_score = -1;
+ if (!span.classification.empty()) {
+ best_class = span.classification[0].first;
+ best_score = span.classification[0].second;
+ }
+ return os << "Span(" << span.span.first << ", " << span.span.second << ", "
+ << best_class << ", " << best_score << ")";
+}
+
} // namespace libtextclassifier
#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
index b5b0287..ca10a0e 100644
--- a/smartselect/text-classification-model.proto
+++ b/smartselect/text-classification-model.proto
@@ -37,10 +37,7 @@
message SelectionModelOptions {
// A list of Unicode codepoints to strip from predicted selections.
- repeated int32 punctuation_to_strip = 1;
-
- // Whether to strip punctuation after the selection is made.
- optional bool strip_punctuation = 2;
+ repeated int32 deprecated_punctuation_to_strip = 1;
// Enforce symmetrical selections.
optional bool enforce_symmetry = 3;
@@ -48,6 +45,8 @@
// Number of inferences made around the click position (to one side), for
// enforcing symmetry.
optional int32 symmetry_context_size = 4;
+
+ reserved 2;
}
message SharingModelOptions {
@@ -60,8 +59,19 @@
// Limits for phone numbers.
optional int32 phone_min_num_digits = 3 [default = 7];
optional int32 phone_max_num_digits = 4 [default = 15];
+
+ // List of regular expression matchers to check.
+ message RegexPattern {
+ // The name of the collection of a match.
+ optional string collection_name = 1;
+
+ // The pattern to check.
+ optional string pattern = 2;
+ }
+ repeated RegexPattern regex_pattern = 5;
}
+// Next ID: 39
message FeatureProcessorOptions {
// Number of buckets used for hashing charactergrams.
optional int32 num_buckets = 1 [default = -1];
@@ -193,7 +203,18 @@
[default = INTERNAL_TOKENIZER];
optional bool icu_preserve_whitespace_tokens = 31 [default = false];
- reserved 7, 11, 12, 17, 26, 27, 28, 29, 32;
+ // List of codepoints that will be stripped from beginning and end of
+ // predicted spans.
+ repeated int32 ignored_span_boundary_codepoints = 36;
+
+ reserved 7, 11, 12, 26, 27, 28, 29, 32, 35;
+
+ // List of allowed charactergrams. The extracted charactergrams are filtered
+ // using this list, and charactergrams that are not present are interpreted as
+ // out-of-vocabulary.
+ // If no allowed_chargrams are specified, all charactergrams are allowed.
+ // The field is typed as bytes type to allow non-UTF8 chargrams.
+ repeated bytes allowed_chargrams = 38;
};
extend nlp_core.EmbeddingNetworkProto {
diff --git a/smartselect/text-classification-model_test.cc b/smartselect/text-classification-model_test.cc
new file mode 100644
index 0000000..490b395
--- /dev/null
+++ b/smartselect/text-classification-model_test.cc
@@ -0,0 +1,347 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/text-classification-model.h"
+
+#include <fcntl.h>
+#include <stdio.h>
+#include <memory>
+#include <string>
+
+#include "gtest/gtest.h"
+
+namespace libtextclassifier {
+namespace {
+
+std::string GetModelPath() {
+ return TEST_DATA_DIR "smartselection.model";
+}
+
+TEST(TextClassificationModelTest, ReadModelOptions) {
+ const std::string model_path = GetModelPath();
+ int fd = open(model_path.c_str(), O_RDONLY);
+ ModelOptions model_options;
+ ASSERT_TRUE(ReadSelectionModelOptions(fd, &model_options));
+ close(fd);
+
+ EXPECT_EQ("en", model_options.language());
+ EXPECT_GT(model_options.version(), 0);
+}
+
+TEST(TextClassificationModelTest, SuggestSelection) {
+ const std::string model_path = GetModelPath();
+ int fd = open(model_path.c_str(), O_RDONLY);
+ std::unique_ptr<TextClassificationModel> model(
+ new TextClassificationModel(fd));
+ close(fd);
+
+ EXPECT_EQ(model->SuggestSelection(
+ "this afternoon Barack Obama gave a speech at", {15, 21}),
+ std::make_pair(15, 27));
+
+ // Try passing whole string.
+  // If more than 1 token is specified, we should return back what was entered.
+ EXPECT_EQ(model->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
+ std::make_pair(0, 27));
+
+ // Single letter.
+ EXPECT_EQ(std::make_pair(0, 1), model->SuggestSelection("a", {0, 1}));
+
+ // Single word.
+ EXPECT_EQ(std::make_pair(0, 4), model->SuggestSelection("asdf", {0, 4}));
+}
+
+TEST(TextClassificationModelTest, SuggestSelectionsAreSymmetric) {
+ const std::string model_path = GetModelPath();
+ int fd = open(model_path.c_str(), O_RDONLY);
+ std::unique_ptr<TextClassificationModel> model(
+ new TextClassificationModel(fd));
+ close(fd);
+
+ EXPECT_EQ(std::make_pair(0, 27),
+ model->SuggestSelection("350 Third Street, Cambridge", {0, 3}));
+ EXPECT_EQ(std::make_pair(0, 27),
+ model->SuggestSelection("350 Third Street, Cambridge", {4, 9}));
+ EXPECT_EQ(std::make_pair(0, 27),
+ model->SuggestSelection("350 Third Street, Cambridge", {10, 16}));
+ EXPECT_EQ(std::make_pair(6, 33),
+ model->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
+ {16, 22}));
+}
+
+TEST(TextClassificationModelTest, SuggestSelectionWithNewLine) {
+ const std::string model_path = GetModelPath();
+ int fd = open(model_path.c_str(), O_RDONLY);
+ std::unique_ptr<TextClassificationModel> model(
+ new TextClassificationModel(fd));
+ close(fd);
+
+ std::tuple<int, int> selection;
+ selection = model->SuggestSelection("abc\nBarack Obama", {4, 10});
+ EXPECT_EQ(4, std::get<0>(selection));
+ EXPECT_EQ(16, std::get<1>(selection));
+
+ selection = model->SuggestSelection("Barack Obama\nabc", {0, 6});
+ EXPECT_EQ(0, std::get<0>(selection));
+ EXPECT_EQ(12, std::get<1>(selection));
+}
+
+TEST(TextClassificationModelTest, SuggestSelectionWithPunctuation) {
+ const std::string model_path = GetModelPath();
+ int fd = open(model_path.c_str(), O_RDONLY);
+ std::unique_ptr<TextClassificationModel> model(
+ new TextClassificationModel(fd));
+ close(fd);
+
+ std::tuple<int, int> selection;
+
+ // From the right.
+ selection = model->SuggestSelection(
+ "this afternoon Barack Obama, gave a speech at", {15, 21});
+ EXPECT_EQ(15, std::get<0>(selection));
+ EXPECT_EQ(27, std::get<1>(selection));
+
+ // From the right multiple.
+ selection = model->SuggestSelection(
+ "this afternoon Barack Obama,.,.,, gave a speech at", {15, 21});
+ EXPECT_EQ(15, std::get<0>(selection));
+ EXPECT_EQ(27, std::get<1>(selection));
+
+ // From the left multiple.
+ selection = model->SuggestSelection(
+ "this afternoon ,.,.,,Barack Obama gave a speech at", {21, 27});
+ EXPECT_EQ(21, std::get<0>(selection));
+ EXPECT_EQ(27, std::get<1>(selection));
+
+ // From both sides.
+ selection = model->SuggestSelection(
+ "this afternoon !Barack Obama,- gave a speech at", {16, 22});
+ EXPECT_EQ(16, std::get<0>(selection));
+ EXPECT_EQ(28, std::get<1>(selection));
+}
+
+class TestingTextClassificationModel
+ : public libtextclassifier::TextClassificationModel {
+ public:
+ explicit TestingTextClassificationModel(int fd)
+ : libtextclassifier::TextClassificationModel(fd) {}
+
+ void DisableClassificationHints() {
+ sharing_options_.set_always_accept_url_hint(false);
+ sharing_options_.set_always_accept_email_hint(false);
+ }
+};
+
+TEST(TextClassificationModelTest, SuggestSelectionNoCrashWithJunk) {
+ const std::string model_path = GetModelPath();
+ int fd = open(model_path.c_str(), O_RDONLY);
+ std::unique_ptr<TextClassificationModel> ff_model(
+ new TextClassificationModel(fd));
+ close(fd);
+
+ std::tuple<int, int> selection;
+
+ // Try passing in bunch of invalid selections.
+ selection = ff_model->SuggestSelection("", {0, 27});
+  // If more than 1 token is specified, we should return back what was entered.
+ EXPECT_EQ(0, std::get<0>(selection));
+ EXPECT_EQ(27, std::get<1>(selection));
+
+ selection = ff_model->SuggestSelection("", {-10, 27});
+  // If more than 1 token is specified, we should return back what was entered.
+ EXPECT_EQ(-10, std::get<0>(selection));
+ EXPECT_EQ(27, std::get<1>(selection));
+
+ selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {0, 27});
+  // If more than 1 token is specified, we should return back what was entered.
+ EXPECT_EQ(0, std::get<0>(selection));
+ EXPECT_EQ(27, std::get<1>(selection));
+
+ selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {-30, 300});
+  // If more than 1 token is specified, we should return back what was entered.
+ EXPECT_EQ(-30, std::get<0>(selection));
+ EXPECT_EQ(300, std::get<1>(selection));
+
+ selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {-10, -1});
+  // If more than 1 token is specified, we should return back what was entered.
+ EXPECT_EQ(-10, std::get<0>(selection));
+ EXPECT_EQ(-1, std::get<1>(selection));
+
+ selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {100, 17});
+  // If more than 1 token is specified, we should return back what was entered.
+ EXPECT_EQ(100, std::get<0>(selection));
+ EXPECT_EQ(17, std::get<1>(selection));
+}
+
+namespace {
+
+std::string FindBestResult(std::vector<std::pair<std::string, float>> results) {
+ if (results.empty()) {
+ return "<INVALID RESULTS>";
+ }
+
+ std::sort(results.begin(), results.end(),
+ [](const std::pair<std::string, float> a,
+ const std::pair<std::string, float> b) {
+ return a.second > b.second;
+ });
+ return results[0].first;
+}
+
+} // namespace
+
+TEST(TextClassificationModelTest, ClassifyText) {
+ const std::string model_path = GetModelPath();
+ int fd = open(model_path.c_str(), O_RDONLY);
+ std::unique_ptr<TestingTextClassificationModel> model(
+ new TestingTextClassificationModel(fd));
+ close(fd);
+
+ model->DisableClassificationHints();
+ EXPECT_EQ("other",
+ FindBestResult(model->ClassifyText(
+ "this afternoon Barack Obama gave a speech at", {15, 27})));
+ EXPECT_EQ("other",
+ FindBestResult(model->ClassifyText("you@android.com", {0, 15})));
+ EXPECT_EQ("other", FindBestResult(model->ClassifyText(
+ "Contact me at you@android.com", {14, 29})));
+ EXPECT_EQ("phone", FindBestResult(model->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+ EXPECT_EQ("other", FindBestResult(model->ClassifyText(
+ "Visit www.google.com every today!", {6, 20})));
+
+ // More lines.
+ EXPECT_EQ("other",
+ FindBestResult(model->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {15, 27})));
+ EXPECT_EQ("other",
+ FindBestResult(model->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {51, 65})));
+ EXPECT_EQ("phone",
+ FindBestResult(model->ClassifyText(
+ "this afternoon Barack Obama gave a speech at|Visit "
+ "www.google.com every today!|Call me at (800) 123-456 today.",
+ {90, 103})));
+
+ // Single word.
+ EXPECT_EQ("other", FindBestResult(model->ClassifyText("obama", {0, 5})));
+ EXPECT_EQ("other", FindBestResult(model->ClassifyText("asdf", {0, 4})));
+ EXPECT_EQ("<INVALID RESULTS>",
+ FindBestResult(model->ClassifyText("asdf", {0, 0})));
+
+ // Junk.
+ EXPECT_EQ("<INVALID RESULTS>",
+ FindBestResult(model->ClassifyText("", {0, 0})));
+ EXPECT_EQ("<INVALID RESULTS>", FindBestResult(model->ClassifyText(
+ "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
+}
+
+TEST(TextClassificationModelTest, ClassifyTextWithHints) {
+ const std::string model_path = GetModelPath();
+ int fd = open(model_path.c_str(), O_RDONLY);
+ std::unique_ptr<TestingTextClassificationModel> model(
+ new TestingTextClassificationModel(fd));
+ close(fd);
+
+ // When EMAIL hint is passed, the result should be email.
+ EXPECT_EQ("email",
+ FindBestResult(model->ClassifyText(
+ "x", {0, 1}, TextClassificationModel::SELECTION_IS_EMAIL)));
+  // When URL hint is passed, the result should be url.
+ EXPECT_EQ("url",
+ FindBestResult(model->ClassifyText(
+ "x", {0, 1}, TextClassificationModel::SELECTION_IS_URL)));
+ // When both hints are passed, the result should be url (as it's probably
+ // better to let Chrome handle this case).
+ EXPECT_EQ("url", FindBestResult(model->ClassifyText(
+ "x", {0, 1},
+ TextClassificationModel::SELECTION_IS_EMAIL |
+ TextClassificationModel::SELECTION_IS_URL)));
+
+ // With disabled hints, we should get the same prediction regardless of the
+ // hint.
+ model->DisableClassificationHints();
+ EXPECT_EQ(model->ClassifyText("x", {0, 1}, 0),
+ model->ClassifyText("x", {0, 1},
+ TextClassificationModel::SELECTION_IS_EMAIL));
+
+ EXPECT_EQ(model->ClassifyText("x", {0, 1}, 0),
+ model->ClassifyText("x", {0, 1},
+ TextClassificationModel::SELECTION_IS_URL));
+}
+
+TEST(TextClassificationModelTest, PhoneFiltering) {
+ const std::string model_path = GetModelPath();
+ int fd = open(model_path.c_str(), O_RDONLY);
+ std::unique_ptr<TestingTextClassificationModel> model(
+ new TestingTextClassificationModel(fd));
+ close(fd);
+
+ EXPECT_EQ("phone", FindBestResult(model->ClassifyText("phone: (123) 456 789",
+ {7, 20}, 0)));
+ EXPECT_EQ("phone", FindBestResult(model->ClassifyText(
+ "phone: (123) 456 789,0001112", {7, 25}, 0)));
+ EXPECT_EQ("other", FindBestResult(model->ClassifyText(
+ "phone: (123) 456 789,0001112", {7, 28}, 0)));
+}
+
+TEST(TextClassificationModelTest, Annotate) {
+ const std::string model_path = GetModelPath();
+ int fd = open(model_path.c_str(), O_RDONLY);
+ std::unique_ptr<TestingTextClassificationModel> model(
+ new TestingTextClassificationModel(fd));
+ close(fd);
+
+ std::string test_string =
+ "I saw Barak Obama today at 350 Third Street, Cambridge";
+ std::vector<TextClassificationModel::AnnotatedSpan> result =
+ model->Annotate(test_string);
+
+ std::vector<TextClassificationModel::AnnotatedSpan> expected;
+ expected.emplace_back();
+ expected.back().span = {0, 1};
+ expected.back().classification.push_back({"other", 1.0});
+ expected.emplace_back();
+ expected.back().span = {2, 5};
+ expected.back().classification.push_back({"other", 1.0});
+ expected.emplace_back();
+ expected.back().span = {6, 17};
+ expected.back().classification.push_back({"other", 1.0});
+ expected.emplace_back();
+ expected.back().span = {18, 23};
+ expected.back().classification.push_back({"other", 1.0});
+ expected.emplace_back();
+ expected.back().span = {24, 26};
+ expected.back().classification.push_back({"other", 1.0});
+ expected.emplace_back();
+ expected.back().span = {27, 54};
+ expected.back().classification.push_back({"address", 1.0});
+
+ ASSERT_EQ(result.size(), expected.size());
+ for (int i = 0; i < expected.size(); ++i) {
+ EXPECT_EQ(result[i].span, expected[i].span) << result[i];
+ EXPECT_EQ(result[i].classification[0].first,
+ expected[i].classification[0].first)
+ << result[i];
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier
diff --git a/smartselect/token-feature-extractor.cc b/smartselect/token-feature-extractor.cc
index 479be41..6afd951 100644
--- a/smartselect/token-feature-extractor.cc
+++ b/smartselect/token-feature-extractor.cc
@@ -16,14 +16,17 @@
#include "smartselect/token-feature-extractor.h"
+#include <cctype>
#include <string>
#include "util/base/logging.h"
#include "util/hash/farmhash.h"
#include "util/strings/stringpiece.h"
#include "util/utf8/unicodetext.h"
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
#include "unicode/regex.h"
#include "unicode/uchar.h"
+#endif
namespace libtextclassifier {
@@ -47,6 +50,7 @@
return copy;
}
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
void RemapTokenUnicode(const std::string& token,
const TokenFeatureExtractorOptions& options,
UnicodeText* remapped) {
@@ -70,12 +74,14 @@
icu_string.toUTF8String(utf8_str);
remapped->CopyUTF8(utf8_str.data(), utf8_str.length());
}
+#endif
} // namespace
TokenFeatureExtractor::TokenFeatureExtractor(
const TokenFeatureExtractorOptions& options)
: options_(options) {
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
UErrorCode status;
for (const std::string& pattern : options.regexp_features) {
status = U_ZERO_ERROR;
@@ -87,10 +93,44 @@
TC_LOG(WARNING) << "Failed to load pattern" << pattern;
}
}
+#else
+ bool found_unsupported_regexp_features = false;
+ for (const std::string& pattern : options.regexp_features) {
+ // A temporary solution to support this specific regexp pattern without
+ // adding too much binary size.
+ if (pattern == "^[^a-z]*$") {
+ enable_all_caps_feature_ = true;
+ } else {
+ found_unsupported_regexp_features = true;
+ }
+ }
+ if (found_unsupported_regexp_features) {
+ TC_LOG(WARNING) << "ICU not supported regexp features ignored.";
+ }
+#endif
}
int TokenFeatureExtractor::HashToken(StringPiece token) const {
- return tcfarmhash::Fingerprint64(token) % options_.num_buckets;
+ if (options_.allowed_chargrams.empty()) {
+ return tcfarmhash::Fingerprint64(token) % options_.num_buckets;
+ } else {
+ // Padding and out-of-vocabulary tokens have extra buckets reserved because
+ // they are special and important tokens, and we don't want them to share
+ // embedding with other charactergrams.
+ // TODO(zilka): Experimentally verify.
+ const int kNumExtraBuckets = 2;
+ const std::string token_string = token.ToString();
+ if (token_string == "<PAD>") {
+ return 1;
+ } else if (options_.allowed_chargrams.find(token_string) ==
+ options_.allowed_chargrams.end()) {
+ return 0; // Out-of-vocabulary.
+ } else {
+ return (tcfarmhash::Fingerprint64(token) %
+ (options_.num_buckets - kNumExtraBuckets)) +
+ kNumExtraBuckets;
+ }
+ }
}
std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
@@ -126,19 +166,23 @@
// Upper-bound the number of charactergram extracted to avoid resizing.
result.reserve(options_.chargram_orders.size() * feature_word.size());
- // Generate the character-grams.
- for (int chargram_order : options_.chargram_orders) {
- if (chargram_order == 1) {
- for (int i = 1; i < feature_word.size() - 1; ++i) {
- result.push_back(
- HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
- }
- } else {
- for (int i = 0;
- i < static_cast<int>(feature_word.size()) - chargram_order + 1;
- ++i) {
- result.push_back(HashToken(
- StringPiece(feature_word, /*offset=*/i, /*len=*/chargram_order)));
+ if (options_.chargram_orders.empty()) {
+ result.push_back(HashToken(feature_word));
+ } else {
+ // Generate the character-grams.
+ for (int chargram_order : options_.chargram_orders) {
+ if (chargram_order == 1) {
+ for (int i = 1; i < feature_word.size() - 1; ++i) {
+ result.push_back(
+ HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
+ }
+ } else {
+ for (int i = 0;
+ i < static_cast<int>(feature_word.size()) - chargram_order + 1;
+ ++i) {
+ result.push_back(HashToken(StringPiece(feature_word, /*offset=*/i,
+ /*len=*/chargram_order)));
+ }
}
}
}
@@ -148,6 +192,7 @@
std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
const Token& token) const {
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
std::vector<int> result;
if (token.is_padding || token.value.empty()) {
result.push_back(HashToken("<PAD>"));
@@ -186,39 +231,47 @@
// Upper-bound the number of charactergram extracted to avoid resizing.
result.reserve(options_.chargram_orders.size() * feature_word.size());
- // Generate the character-grams.
- for (int chargram_order : options_.chargram_orders) {
- UnicodeText::const_iterator it_start = feature_word_unicode.begin();
- UnicodeText::const_iterator it_end = feature_word_unicode.end();
- if (chargram_order == 1) {
- ++it_start;
- --it_end;
- }
-
- UnicodeText::const_iterator it_chargram_start = it_start;
- UnicodeText::const_iterator it_chargram_end = it_start;
- bool chargram_is_complete = true;
- for (int i = 0; i < chargram_order; ++i) {
- if (it_chargram_end == it_end) {
- chargram_is_complete = false;
- break;
+ if (options_.chargram_orders.empty()) {
+ result.push_back(HashToken(feature_word));
+ } else {
+ // Generate the character-grams.
+ for (int chargram_order : options_.chargram_orders) {
+ UnicodeText::const_iterator it_start = feature_word_unicode.begin();
+ UnicodeText::const_iterator it_end = feature_word_unicode.end();
+ if (chargram_order == 1) {
+ ++it_start;
+ --it_end;
}
- ++it_chargram_end;
- }
- if (!chargram_is_complete) {
- continue;
- }
- for (; it_chargram_end <= it_end;
- ++it_chargram_start, ++it_chargram_end) {
- const int length_bytes =
- it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
- result.push_back(HashToken(
- StringPiece(it_chargram_start.utf8_data(), length_bytes)));
+ UnicodeText::const_iterator it_chargram_start = it_start;
+ UnicodeText::const_iterator it_chargram_end = it_start;
+ bool chargram_is_complete = true;
+ for (int i = 0; i < chargram_order; ++i) {
+ if (it_chargram_end == it_end) {
+ chargram_is_complete = false;
+ break;
+ }
+ ++it_chargram_end;
+ }
+ if (!chargram_is_complete) {
+ continue;
+ }
+
+ for (; it_chargram_end <= it_end;
+ ++it_chargram_start, ++it_chargram_end) {
+ const int length_bytes =
+ it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
+ result.push_back(HashToken(
+ StringPiece(it_chargram_start.utf8_data(), length_bytes)));
+ }
}
}
}
return result;
+#else
+ TC_LOG(WARNING) << "ICU not supported. No feature extracted.";
+ return {};
+#endif
}
bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
@@ -234,7 +287,14 @@
if (options_.unicode_aware_features) {
UnicodeText token_unicode =
UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- if (!token.value.empty() && u_isupper(*token_unicode.begin())) {
+ bool is_upper;
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ is_upper = u_isupper(*token_unicode.begin());
+#else
+ TC_LOG(WARNING) << "Using non-unicode isupper because ICU is disabled.";
+ is_upper = isupper(*token_unicode.begin());
+#endif
+ if (!token.value.empty() && is_upper) {
dense_features->push_back(1.0);
} else {
dense_features->push_back(-1.0);
@@ -260,6 +320,7 @@
}
}
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
// Add regexp features.
if (!regex_patterns_.empty()) {
icu::UnicodeString unicode_str(token.value.c_str(), token.value.size(),
@@ -281,6 +342,23 @@
}
}
}
+#else
+ if (enable_all_caps_feature_) {
+ bool is_all_caps = true;
+ for (const char character_byte : token.value) {
+ if (islower(character_byte)) {
+ is_all_caps = false;
+ break;
+ }
+ }
+ if (is_all_caps) {
+ dense_features->push_back(1.0);
+ } else {
+ dense_features->push_back(-1.0);
+ }
+ }
+#endif
+
return true;
}
diff --git a/smartselect/token-feature-extractor.h b/smartselect/token-feature-extractor.h
index 8287fbd..5afeca4 100644
--- a/smartselect/token-feature-extractor.h
+++ b/smartselect/token-feature-extractor.h
@@ -18,12 +18,14 @@
#define LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_
#include <memory>
+#include <unordered_set>
#include <vector>
-#include "base.h"
#include "smartselect/types.h"
#include "util/strings/stringpiece.h"
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
#include "unicode/regex.h"
+#endif
namespace libtextclassifier {
@@ -55,6 +57,12 @@
// Maximum length of a word.
int max_word_length = 20;
+
+ // List of allowed charactergrams. The extracted charactergrams are filtered
+ // using this list, and charactergrams that are not present are interpreted as
+ // out-of-vocabulary.
+ // If no allowed_chargrams are specified, all charactergrams are allowed.
+ std::unordered_set<std::string> allowed_chargrams;
};
class TokenFeatureExtractor {
@@ -73,8 +81,16 @@
std::vector<float>* dense_features) const;
int DenseFeaturesCount() const {
- return options_.extract_case_feature +
- options_.extract_selection_mask_feature + regex_patterns_.size();
+ int feature_count =
+ options_.extract_case_feature + options_.extract_selection_mask_feature;
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ feature_count += regex_patterns_.size();
+#else
+ if (enable_all_caps_feature_) {
+ feature_count += 1;
+ }
+#endif
+ return feature_count;
}
protected:
@@ -94,8 +110,11 @@
private:
TokenFeatureExtractorOptions options_;
-
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
std::vector<std::unique_ptr<icu::RegexPattern>> regex_patterns_;
+#else
+ bool enable_all_caps_feature_ = false;
+#endif
};
} // namespace libtextclassifier
diff --git a/smartselect/token-feature-extractor_test.cc b/smartselect/token-feature-extractor_test.cc
new file mode 100644
index 0000000..4b635fd
--- /dev/null
+++ b/smartselect/token-feature-extractor_test.cc
@@ -0,0 +1,543 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/token-feature-extractor.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier {
+namespace {
+
+class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
+ public:
+ using TokenFeatureExtractor::TokenFeatureExtractor;
+ using TokenFeatureExtractor::HashToken;
+};
+
+TEST(TokenFeatureExtractorTest, ExtractAscii) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2, 3};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = false;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("H"),
+ extractor.HashToken("e"),
+ extractor.HashToken("l"),
+ extractor.HashToken("l"),
+ extractor.HashToken("o"),
+ extractor.HashToken("^H"),
+ extractor.HashToken("He"),
+ extractor.HashToken("el"),
+ extractor.HashToken("ll"),
+ extractor.HashToken("lo"),
+ extractor.HashToken("o$"),
+ extractor.HashToken("^He"),
+ extractor.HashToken("Hel"),
+ extractor.HashToken("ell"),
+ extractor.HashToken("llo"),
+ extractor.HashToken("lo$")
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("w"),
+ extractor.HashToken("o"),
+ extractor.HashToken("r"),
+ extractor.HashToken("l"),
+ extractor.HashToken("d"),
+ extractor.HashToken("!"),
+ extractor.HashToken("^w"),
+ extractor.HashToken("wo"),
+ extractor.HashToken("or"),
+ extractor.HashToken("rl"),
+ extractor.HashToken("ld"),
+ extractor.HashToken("d!"),
+ extractor.HashToken("!$"),
+ extractor.HashToken("^wo"),
+ extractor.HashToken("wor"),
+ extractor.HashToken("orl"),
+ extractor.HashToken("rld"),
+ extractor.HashToken("ld!"),
+ extractor.HashToken("d!$"),
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+}
+
+TEST(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = false;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("^Hello$")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("^world!$")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+}
+
+TEST(TokenFeatureExtractorTest, ExtractUnicode) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2, 3};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = true;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("H"),
+ extractor.HashToken("ě"),
+ extractor.HashToken("l"),
+ extractor.HashToken("l"),
+ extractor.HashToken("ó"),
+ extractor.HashToken("^H"),
+ extractor.HashToken("Hě"),
+ extractor.HashToken("ěl"),
+ extractor.HashToken("ll"),
+ extractor.HashToken("ló"),
+ extractor.HashToken("ó$"),
+ extractor.HashToken("^Hě"),
+ extractor.HashToken("Hěl"),
+ extractor.HashToken("ěll"),
+ extractor.HashToken("lló"),
+ extractor.HashToken("ló$")
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("w"),
+ extractor.HashToken("o"),
+ extractor.HashToken("r"),
+ extractor.HashToken("l"),
+ extractor.HashToken("d"),
+ extractor.HashToken("!"),
+ extractor.HashToken("^w"),
+ extractor.HashToken("wo"),
+ extractor.HashToken("or"),
+ extractor.HashToken("rl"),
+ extractor.HashToken("ld"),
+ extractor.HashToken("d!"),
+ extractor.HashToken("!$"),
+ extractor.HashToken("^wo"),
+ extractor.HashToken("wor"),
+ extractor.HashToken("orl"),
+ extractor.HashToken("rld"),
+ extractor.HashToken("ld!"),
+ extractor.HashToken("d!$"),
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+}
+
+TEST(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = true;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("^Hělló$")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray({
+ extractor.HashToken("^world!$"),
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+}
+
+TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = true;
+ options.extract_selection_mask_feature = false;
+ TokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"Ř", 23, 29}, false, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"ř", 23, 29}, false, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
+}
+
+TEST(TokenFeatureExtractorTest, DigitRemapping) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.remap_digits = true;
+ options.unicode_aware_features = false;
+ TokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
+ &dense_features);
+
+ std::vector<int> sparse_features2;
+ extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+
+ extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features,
+ testing::Not(testing::ElementsAreArray(sparse_features2)));
+}
+
+TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.remap_digits = true;
+ options.unicode_aware_features = true;
+ TokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
+ &dense_features);
+
+ std::vector<int> sparse_features2;
+ extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+
+ extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features,
+ testing::Not(testing::ElementsAreArray(sparse_features2)));
+}
+
+TEST(TokenFeatureExtractorTest, LowercaseAscii) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.lowercase_tokens = true;
+ options.unicode_aware_features = false;
+ TokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features,
+ &dense_features);
+
+ std::vector<int> sparse_features2;
+ extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+
+ extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+}
+
+TEST(TokenFeatureExtractorTest, LowercaseUnicode) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.lowercase_tokens = true;
+ options.unicode_aware_features = true;
+ TokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"ŘŘ", 0, 6}, true, &sparse_features, &dense_features);
+
+ std::vector<int> sparse_features2;
+ extractor.Extract(Token{"řř", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+}
+
+TEST(TokenFeatureExtractorTest, RegexFeatures) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.remap_digits = false;
+ options.unicode_aware_features = false;
+ options.regexp_features.push_back("^[a-z]+$"); // all lower case.
+ options.regexp_features.push_back("^[0-9]+$"); // all digits.
+ TokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+
+ dense_features.clear();
+ extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));
+
+ dense_features.clear();
+ extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+
+ dense_features.clear();
+ extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
+}
+
+TEST(TokenFeatureExtractorTest, ExtractTooLongWord) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{22};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = true;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options);
+
+ // Test that this runs. ASAN should catch problems.
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true,
+ &sparse_features, &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
+ extractor.HashToken("abcdefghij\1qřstuvwxyz$"),
+ // clang-format on
+ }));
+}
+
+TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = true;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor_unicode(options);
+
+ options.unicode_aware_features = false;
+ TestingTokenFeatureExtractor extractor_ascii(options);
+
+ for (const std::string& input :
+ {"https://www.abcdefgh.com/in/xxxkkkvayio",
+ "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
+ "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w", "abcd",
+ "x", "Hello", "Hey,", "Hi", ""}) {
+ std::vector<int> sparse_features_unicode;
+ std::vector<float> dense_features_unicode;
+ extractor_unicode.Extract(Token{input, 0, 0}, true,
+ &sparse_features_unicode,
+ &dense_features_unicode);
+
+ std::vector<int> sparse_features_ascii;
+ std::vector<float> dense_features_ascii;
+ extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii,
+ &dense_features_ascii);
+
+ EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input;
+ EXPECT_THAT(dense_features_unicode, dense_features_ascii) << input;
+ }
+}
+
+TEST(TokenFeatureExtractorTest, ExtractForPadToken) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = false;
+ options.extract_selection_mask_feature = true;
+
+ TestingTokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token(), false, &sparse_features, &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+}
+
+TEST(TokenFeatureExtractorTest, ExtractFiltered) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2, 3};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = false;
+ options.extract_selection_mask_feature = true;
+ options.allowed_chargrams.insert("^H");
+ options.allowed_chargrams.insert("ll");
+ options.allowed_chargrams.insert("llo");
+ options.allowed_chargrams.insert("w");
+ options.allowed_chargrams.insert("!");
+ options.allowed_chargrams.insert("\xc4"); // Lead byte of a UTF8 two-byte sequence.
+
+ TestingTokenFeatureExtractor extractor(options);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hěllo", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ 0,
+ extractor.HashToken("\xc4"),
+ 0,
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("^H"),
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("ll"),
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("llo"),
+ 0
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("w"),
+ 0,
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("!"),
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+ EXPECT_EQ(extractor.HashToken("<PAD>"), 1);
+}
+
+} // namespace
+} // namespace libtextclassifier
diff --git a/smartselect/tokenizer.cc b/smartselect/tokenizer.cc
index 2093fde..2489a61 100644
--- a/smartselect/tokenizer.cc
+++ b/smartselect/tokenizer.cc
@@ -16,49 +16,42 @@
#include "smartselect/tokenizer.h"
+#include <algorithm>
+
#include "util/strings/utf8.h"
#include "util/utf8/unicodetext.h"
namespace libtextclassifier {
-void Tokenizer::PrepareTokenizationCodepointRanges(
- const std::vector<TokenizationCodepointRange>& codepoint_range_configs) {
- codepoint_ranges_.clear();
- codepoint_ranges_.reserve(codepoint_range_configs.size());
- for (const TokenizationCodepointRange& range : codepoint_range_configs) {
- codepoint_ranges_.push_back(
- CodepointRange(range.start(), range.end(), range.role()));
- }
-
+Tokenizer::Tokenizer(
+ const std::vector<TokenizationCodepointRange>& codepoint_ranges)
+ : codepoint_ranges_(codepoint_ranges) {
std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
- [](const CodepointRange& a, const CodepointRange& b) {
- return a.start < b.start;
+ [](const TokenizationCodepointRange& a,
+ const TokenizationCodepointRange& b) {
+ return a.start() < b.start();
});
}
TokenizationCodepointRange::Role Tokenizer::FindTokenizationRole(
int codepoint) const {
- auto it = std::lower_bound(codepoint_ranges_.begin(), codepoint_ranges_.end(),
- codepoint,
- [](const CodepointRange& range, int codepoint) {
- // This function compares range with the
- // codepoint for the purpose of finding the first
- // greater or equal range. Because of the use of
- // std::lower_bound it needs to return true when
- // range < codepoint; the first time it will
- // return false the lower bound is found and
- // returned.
- //
- // It might seem weird that the condition is
- // range.end <= codepoint here but when codepoint
- // == range.end it means it's actually just
- // outside of the range, thus the range is less
- // than the codepoint.
- return range.end <= codepoint;
- });
- if (it != codepoint_ranges_.end() && it->start <= codepoint &&
- it->end > codepoint) {
- return it->role;
+ auto it = std::lower_bound(
+ codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
+ [](const TokenizationCodepointRange& range, int codepoint) {
+ // This function compares range with the codepoint for the purpose of
+ // finding the first greater or equal range. Because of the use of
+ // std::lower_bound it needs to return true when range < codepoint;
+ // the first time it will return false the lower bound is found and
+ // returned.
+ //
+ // It might seem weird that the condition is range.end <= codepoint
+ // here but when codepoint == range.end it means it's actually just
+ // outside of the range, thus the range is less than the codepoint.
+ return range.end() <= codepoint;
+ });
+ if (it != codepoint_ranges_.end() && it->start() <= codepoint &&
+ it->end() > codepoint) {
+ return it->role();
} else {
return TokenizationCodepointRange::DEFAULT_ROLE;
}
diff --git a/smartselect/tokenizer.h b/smartselect/tokenizer.h
index 897f7c4..4eb78f9 100644
--- a/smartselect/tokenizer.h
+++ b/smartselect/tokenizer.h
@@ -22,7 +22,6 @@
#include "smartselect/tokenizer.pb.h"
#include "smartselect/types.h"
-#include "util/base/integral_types.h"
namespace libtextclassifier {
@@ -31,29 +30,12 @@
class Tokenizer {
public:
explicit Tokenizer(
- const std::vector<TokenizationCodepointRange>& codepoint_range_configs) {
- PrepareTokenizationCodepointRanges(codepoint_range_configs);
- }
+ const std::vector<TokenizationCodepointRange>& codepoint_ranges);
// Tokenizes the input string using the selected tokenization method.
std::vector<Token> Tokenize(const std::string& utf8_text) const;
protected:
- // Represents a codepoint range [start, end) with its role for tokenization.
- struct CodepointRange {
- int32 start;
- int32 end;
- TokenizationCodepointRange::Role role;
-
- CodepointRange(int32 arg_start, int32 arg_end,
- TokenizationCodepointRange::Role arg_role)
- : start(arg_start), end(arg_end), role(arg_role) {}
- };
-
- // Prepares tokenization codepoint ranges for use in tokenization.
- void PrepareTokenizationCodepointRanges(
- const std::vector<TokenizationCodepointRange>& codepoint_range_configs);
-
// Finds the tokenization role for given codepoint.
// If the character is not found returns DEFAULT_ROLE.
// Internally uses binary search so should be O(log(# of codepoint_ranges)).
@@ -62,7 +44,7 @@
private:
// Codepoint ranges that determine how different codepoints are tokenized.
// The ranges must not overlap.
- std::vector<CodepointRange> codepoint_ranges_;
+ std::vector<TokenizationCodepointRange> codepoint_ranges_;
};
} // namespace libtextclassifier
diff --git a/smartselect/tokenizer_test.cc b/smartselect/tokenizer_test.cc
new file mode 100644
index 0000000..cdb90a9
--- /dev/null
+++ b/smartselect/tokenizer_test.cc
@@ -0,0 +1,261 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/tokenizer.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier {
+namespace {
+
+using testing::ElementsAreArray;
+
+class TestingTokenizer : public Tokenizer {
+ public:
+ explicit TestingTokenizer(
+ const std::vector<TokenizationCodepointRange>& codepoint_range_configs)
+ : Tokenizer(codepoint_range_configs) {}
+
+ TokenizationCodepointRange::Role TestFindTokenizationRole(int c) const {
+ return FindTokenizationRole(c);
+ }
+};
+
+TEST(TokenizerTest, FindTokenizationRole) {
+ std::vector<TokenizationCodepointRange> configs;
+ TokenizationCodepointRange* config;
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(0);
+ config->set_end(10);
+ config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(32);
+ config->set_end(33);
+ config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(1234);
+ config->set_end(12345);
+ config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
+
+ TestingTokenizer tokenizer(configs);
+
+ // Test hits to the first group.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(0),
+ TokenizationCodepointRange::TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(5),
+ TokenizationCodepointRange::TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(10),
+ TokenizationCodepointRange::DEFAULT_ROLE);
+
+ // Test a hit to the second group.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(31),
+ TokenizationCodepointRange::DEFAULT_ROLE);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(32),
+ TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(33),
+ TokenizationCodepointRange::DEFAULT_ROLE);
+
+ // Test hits to the third group.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(1233),
+ TokenizationCodepointRange::DEFAULT_ROLE);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(1234),
+ TokenizationCodepointRange::TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(12344),
+ TokenizationCodepointRange::TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(12345),
+ TokenizationCodepointRange::DEFAULT_ROLE);
+
+ // Test a hit outside.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(99),
+ TokenizationCodepointRange::DEFAULT_ROLE);
+}
+
+TEST(TokenizerTest, TokenizeOnSpace) {
+ std::vector<TokenizationCodepointRange> configs;
+ TokenizationCodepointRange* config;
+
+ configs.emplace_back();
+ config = &configs.back();
+ // Space character.
+ config->set_start(32);
+ config->set_end(33);
+ config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+
+ TestingTokenizer tokenizer(configs);
+ std::vector<Token> tokens = tokenizer.Tokenize("Hello world!");
+
+ EXPECT_THAT(tokens,
+ ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)}));
+}
+
+TEST(TokenizerTest, TokenizeComplex) {
+ std::vector<TokenizationCodepointRange> configs;
+ TokenizationCodepointRange* config;
+
+ // Source: http://www.unicode.org/Public/10.0.0/ucd/Blocks-10.0.0d1.txt
+ // Latin - Cyrillic.
+ // 0000..007F; Basic Latin
+ // 0080..00FF; Latin-1 Supplement
+ // 0100..017F; Latin Extended-A
+ // 0180..024F; Latin Extended-B
+ // 0250..02AF; IPA Extensions
+ // 02B0..02FF; Spacing Modifier Letters
+ // 0300..036F; Combining Diacritical Marks
+ // 0370..03FF; Greek and Coptic
+ // 0400..04FF; Cyrillic
+ // 0500..052F; Cyrillic Supplement
+ // 0530..058F; Armenian
+ // 0590..05FF; Hebrew
+ // 0600..06FF; Arabic
+ // 0700..074F; Syriac
+ // 0750..077F; Arabic Supplement
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(0);
+ config->set_end(32);
+ config->set_role(TokenizationCodepointRange::DEFAULT_ROLE);
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(32);
+ config->set_end(33);
+ config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(33);
+ config->set_end(0x77F + 1);
+ config->set_role(TokenizationCodepointRange::DEFAULT_ROLE);
+
+ // CJK
+ // 2E80..2EFF; CJK Radicals Supplement
+ // 3000..303F; CJK Symbols and Punctuation
+ // 3040..309F; Hiragana
+ // 30A0..30FF; Katakana
+ // 3100..312F; Bopomofo
+ // 3130..318F; Hangul Compatibility Jamo
+ // 3190..319F; Kanbun
+ // 31A0..31BF; Bopomofo Extended
+ // 31C0..31EF; CJK Strokes
+ // 31F0..31FF; Katakana Phonetic Extensions
+ // 3200..32FF; Enclosed CJK Letters and Months
+ // 3300..33FF; CJK Compatibility
+ // 3400..4DBF; CJK Unified Ideographs Extension A
+ // 4DC0..4DFF; Yijing Hexagram Symbols
+ // 4E00..9FFF; CJK Unified Ideographs
+ // A000..A48F; Yi Syllables
+ // A490..A4CF; Yi Radicals
+ // A4D0..A4FF; Lisu
+ // A500..A63F; Vai
+ // F900..FAFF; CJK Compatibility Ideographs
+ // FE30..FE4F; CJK Compatibility Forms
+ // 20000..2A6DF; CJK Unified Ideographs Extension B
+ // 2A700..2B73F; CJK Unified Ideographs Extension C
+ // 2B740..2B81F; CJK Unified Ideographs Extension D
+ // 2B820..2CEAF; CJK Unified Ideographs Extension E
+ // 2CEB0..2EBEF; CJK Unified Ideographs Extension F
+ // 2F800..2FA1F; CJK Compatibility Ideographs Supplement
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(0x2E80);
+ config->set_end(0x2EFF + 1);
+ config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(0x3000);
+ config->set_end(0xA63F + 1);
+ config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(0xF900);
+ config->set_end(0xFAFF + 1);
+ config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(0xFE30);
+ config->set_end(0xFE4F + 1);
+ config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(0x20000);
+ config->set_end(0x2A6DF + 1);
+ config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(0x2A700);
+ config->set_end(0x2B73F + 1);
+ config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(0x2B740);
+ config->set_end(0x2B81F + 1);
+ config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(0x2B820);
+ config->set_end(0x2CEAF + 1);
+ config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(0x2CEB0);
+ config->set_end(0x2EBEF + 1);
+ config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(0x2F800);
+ config->set_end(0x2FA1F + 1);
+ config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
+
+ // Thai.
+ // 0E00..0E7F; Thai
+ configs.emplace_back();
+ config = &configs.back();
+ config->set_start(0x0E00);
+ config->set_end(0x0E7F + 1);
+ config->set_role(TokenizationCodepointRange::TOKEN_SEPARATOR);
+
+ Tokenizer tokenizer(configs);
+ std::vector<Token> tokens;
+
+ tokens = tokenizer.Tokenize(
+ "問少目木輸走猶術権自京門録球変。細開括省用掲情結傍走愛明氷。");
+ EXPECT_EQ(tokens.size(), 30);
+
+ tokens = tokenizer.Tokenize("問少目 hello 木輸ยามきゃ");
+ // clang-format off
+ EXPECT_THAT(
+ tokens,
+ ElementsAreArray({Token("問", 0, 1),
+ Token("少", 1, 2),
+ Token("目", 2, 3),
+ Token("hello", 4, 9),
+ Token("木", 10, 11),
+ Token("輸", 11, 12),
+ Token("ย", 12, 13),
+ Token("า", 13, 14),
+ Token("ม", 14, 15),
+ Token("き", 15, 16),
+ Token("ゃ", 16, 17)}));
+ // clang-format on
+}
+
+} // namespace
+} // namespace libtextclassifier