Snap for 4545621 from 065929754b11798dfb935448c238af1865d7c26f to pi-release
Change-Id: Ica7160f687f5e2a3ea082e5b95d5c1b8b9bf5ffc
diff --git a/models/textclassifier.smartselection.en.model b/models/textclassifier.smartselection.en.model
index 315e2b4..7af0897 100644
--- a/models/textclassifier.smartselection.en.model
+++ b/models/textclassifier.smartselection.en.model
Binary files differ
diff --git a/smartselect/feature-processor.cc b/smartselect/feature-processor.cc
index 08f18ea..c1db95a 100644
--- a/smartselect/feature-processor.cc
+++ b/smartselect/feature-processor.cc
@@ -119,34 +119,14 @@
}
}
-void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
- std::vector<UnicodeTextRange>* ranges) {
- UnicodeText::const_iterator start = t.begin();
- UnicodeText::const_iterator curr = start;
- UnicodeText::const_iterator end = t.end();
- for (; curr != end; ++curr) {
- if (codepoints.find(*curr) != codepoints.end()) {
- if (start != curr) {
- ranges->push_back(std::make_pair(start, curr));
- }
- start = curr;
- ++start;
- }
- }
- if (start != end) {
- ranges->push_back(std::make_pair(start, end));
- }
-}
+} // namespace internal
-void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
- std::vector<Token>* tokens) {
+void FeatureProcessor::StripTokensFromOtherLines(
+ const std::string& context, CodepointSpan span,
+ std::vector<Token>* tokens) const {
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
- std::vector<UnicodeTextRange> lines;
- std::set<char32> codepoints;
- codepoints.insert('\n');
- codepoints.insert('|');
- internal::FindSubstrings(context_unicode, codepoints, &lines);
+ std::vector<UnicodeTextRange> lines = SplitContext(context_unicode);
auto span_start = context_unicode.begin();
if (span.first > 0) {
@@ -176,8 +156,6 @@
}
}
-} // namespace internal
-
std::string FeatureProcessor::GetDefaultCollection() const {
if (options_.default_collection() < 0 ||
options_.default_collection() >= options_.collections_size()) {
@@ -249,8 +227,14 @@
token_begin, token_begin_unicode.end(), /*count_from_beginning=*/true);
const int end_ignored = CountIgnoredSpanBoundaryCodepoints(
token_end_unicode.begin(), token_end, /*count_from_beginning=*/false);
- *span = CodepointSpan({result_begin_codepoint + begin_ignored,
- result_end_codepoint - end_ignored});
+ // In case everything would be stripped, set the span to the original
+ // beginning and zero length.
+ if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
+ *span = {result_begin_codepoint, result_begin_codepoint};
+ } else {
+ *span = CodepointSpan({result_begin_codepoint + begin_ignored,
+ result_end_codepoint - end_ignored});
+ }
}
return true;
}
@@ -339,16 +323,23 @@
}
TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
- CodepointSpan codepoint_span) {
+ CodepointSpan codepoint_span,
+ bool snap_boundaries_to_containing_tokens) {
const int codepoint_start = std::get<0>(codepoint_span);
const int codepoint_end = std::get<1>(codepoint_span);
TokenIndex start_token = kInvalidIndex;
TokenIndex end_token = kInvalidIndex;
for (int i = 0; i < selectable_tokens.size(); ++i) {
- if (codepoint_start <= selectable_tokens[i].start &&
- codepoint_end >= selectable_tokens[i].end &&
- !selectable_tokens[i].is_padding) {
+ bool is_token_in_span;
+ if (snap_boundaries_to_containing_tokens) {
+ is_token_in_span = codepoint_start < selectable_tokens[i].end &&
+ codepoint_end > selectable_tokens[i].start;
+ } else {
+ is_token_in_span = codepoint_start <= selectable_tokens[i].start &&
+ codepoint_end >= selectable_tokens[i].end;
+ }
+ if (is_token_in_span && !selectable_tokens[i].is_padding) {
if (start_token == kInvalidIndex) {
start_token = i;
}
@@ -539,6 +530,43 @@
return num_ignored;
}
+namespace {
+
+void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
+ std::vector<UnicodeTextRange>* ranges) {
+ UnicodeText::const_iterator start = t.begin();
+ UnicodeText::const_iterator curr = start;
+ UnicodeText::const_iterator end = t.end();
+ for (; curr != end; ++curr) {
+ if (codepoints.find(*curr) != codepoints.end()) {
+ if (start != curr) {
+ ranges->push_back(std::make_pair(start, curr));
+ }
+ start = curr;
+ ++start;
+ }
+ }
+ if (start != end) {
+ ranges->push_back(std::make_pair(start, end));
+ }
+}
+
+} // namespace
+
+std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
+ const UnicodeText& context_unicode) const {
+ if (options_.only_use_line_with_click()) {
+ std::vector<UnicodeTextRange> lines;
+ std::set<char32> codepoints;
+ codepoints.insert('\n');
+ codepoints.insert('|');
+ FindSubstrings(context_unicode, codepoints, &lines);
+ return lines;
+ } else {
+ return {{context_unicode.begin(), context_unicode.end()}};
+ }
+}
+
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
const std::string& context, CodepointSpan span) const {
const UnicodeText context_unicode =
@@ -657,7 +685,7 @@
}
if (options_.only_use_line_with_click()) {
- internal::StripTokensFromOtherLines(context, input_span, tokens);
+ StripTokensFromOtherLines(context, input_span, tokens);
}
int local_click_pos;
@@ -712,17 +740,20 @@
std::unique_ptr<CachedFeatures>* cached_features) const {
TokenizeAndFindClick(context, input_span, tokens, click_pos);
- // If the default click method failed, let's try to do sub-token matching
- // before we fail.
- if (*click_pos == kInvalidIndex) {
- *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
+ if (input_span.first != kInvalidIndex && input_span.second != kInvalidIndex) {
+ // If the default click method failed, let's try to do sub-token matching
+ // before we fail.
if (*click_pos == kInvalidIndex) {
- return false;
+ *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
+ if (*click_pos == kInvalidIndex) {
+ return false;
+ }
}
- }
-
- if (relative_click_span == std::make_pair(kInvalidIndex, kInvalidIndex)) {
- relative_click_span = {tokens->size() - 1, tokens->size() - 1};
+ } else {
+ // If input_span is unspecified, click the first token and extract features
+ // from all tokens.
+ *click_pos = 0;
+ relative_click_span = {0, tokens->size()};
}
internal::StripOrPadTokens(relative_click_span, options_.context_size(),
diff --git a/smartselect/feature-processor.h b/smartselect/feature-processor.h
index a39a789..ef9a3df 100644
--- a/smartselect/feature-processor.h
+++ b/smartselect/feature-processor.h
@@ -53,11 +53,6 @@
TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
const FeatureProcessorOptions& options);
-// Removes tokens that are not part of a line of the context which contains
-// given span.
-void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
- std::vector<Token>* tokens);
-
// Splits tokens that contain the selection boundary inside them.
// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
@@ -81,8 +76,12 @@
} // namespace internal
// Converts a codepoint span to a token span in the given list of tokens.
-TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
- CodepointSpan codepoint_span);
+// If snap_boundaries_to_containing_tokens is set to true, it is enough for a
+// token to overlap with the codepoint range to be considered part of it.
+// Otherwise it must be fully included in the range.
+TokenSpan CodepointSpanToTokenSpan(
+ const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
+ bool snap_boundaries_to_containing_tokens = false);
// Converts a token span to a codepoint span in the given list of tokens.
CodepointSpan TokenSpanToCodepointSpan(
@@ -139,8 +138,8 @@
// Extracts features as a CachedFeatures object that can be used for repeated
// inference over token spans in the given context.
- // When relative_click_span == {kInvalidIndex, kInvalidIndex} then all tokens
- // extracted from context will be considered.
+ // When input_span == {kInvalidIndex, kInvalidIndex} then, relative_click_span
+ // is ignored, and all tokens extracted from context will be considered.
bool ExtractFeatures(const std::string& context, CodepointSpan input_span,
TokenSpan relative_click_span,
const FeatureVectorFn& feature_vector_fn,
@@ -159,6 +158,10 @@
return feature_extractor_.DenseFeaturesCount();
}
+ // Splits context to several segments according to configuration.
+ std::vector<UnicodeTextRange> SplitContext(
+ const UnicodeText& context_unicode) const;
+
// Strips boundary codepoints from the span in context and returns the new
// start and end indices. If the span comprises entirely of boundary
// codepoints, the first index of span is returned for both indices.
@@ -249,6 +252,11 @@
void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
std::vector<Token>* result) const;
+ // Removes all tokens from tokens that are not on a line (defined by calling
+ // SplitContext on the context) to which span points.
+ void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
+ std::vector<Token>* tokens) const;
+
const TokenFeatureExtractor feature_extractor_;
// Codepoint ranges that define what codepoints are supported by the model.
diff --git a/smartselect/feature-processor_test.cc b/smartselect/feature-processor_test.cc
index 1a9b9da..9bee67a 100644
--- a/smartselect/feature-processor_test.cc
+++ b/smartselect/feature-processor_test.cc
@@ -25,6 +25,18 @@
using testing::ElementsAreArray;
using testing::FloatEq;
+class TestingFeatureProcessor : public FeatureProcessor {
+ public:
+ using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
+ using FeatureProcessor::FeatureProcessor;
+ using FeatureProcessor::ICUTokenize;
+ using FeatureProcessor::IsCodepointInRanges;
+ using FeatureProcessor::SpanToLabel;
+ using FeatureProcessor::StripTokensFromOtherLines;
+ using FeatureProcessor::supported_codepoint_ranges_;
+ using FeatureProcessor::SupportedCodepointsRatio;
+};
+
TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) {
std::vector<Token> tokens{Token("Hělló", 0, 5),
Token("fěěbař@google.com", 6, 23),
@@ -107,6 +119,10 @@
}
TEST(FeatureProcessorTest, KeepLineWithClickFirst) {
+ FeatureProcessorOptions options;
+ options.set_only_use_line_with_click(true);
+ TestingFeatureProcessor feature_processor(options);
+
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {0, 5};
// clang-format off
@@ -119,12 +135,16 @@
// clang-format on
// Keeps the first line.
- internal::StripTokensFromOtherLines(context, span, &tokens);
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
EXPECT_THAT(tokens,
ElementsAreArray({Token("Fiřst", 0, 5), Token("Lině", 6, 10)}));
}
TEST(FeatureProcessorTest, KeepLineWithClickSecond) {
+ FeatureProcessorOptions options;
+ options.set_only_use_line_with_click(true);
+ TestingFeatureProcessor feature_processor(options);
+
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {18, 22};
// clang-format off
@@ -137,12 +157,16 @@
// clang-format on
// Keeps the first line.
- internal::StripTokensFromOtherLines(context, span, &tokens);
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
EXPECT_THAT(tokens, ElementsAreArray(
{Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
}
TEST(FeatureProcessorTest, KeepLineWithClickThird) {
+ FeatureProcessorOptions options;
+ options.set_only_use_line_with_click(true);
+ TestingFeatureProcessor feature_processor(options);
+
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {24, 33};
// clang-format off
@@ -155,12 +179,16 @@
// clang-format on
// Keeps the first line.
- internal::StripTokensFromOtherLines(context, span, &tokens);
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
EXPECT_THAT(tokens, ElementsAreArray(
{Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
}
TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
+ FeatureProcessorOptions options;
+ options.set_only_use_line_with_click(true);
+ TestingFeatureProcessor feature_processor(options);
+
const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
const CodepointSpan span = {18, 22};
// clang-format off
@@ -173,12 +201,16 @@
// clang-format on
// Keeps the first line.
- internal::StripTokensFromOtherLines(context, span, &tokens);
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
EXPECT_THAT(tokens, ElementsAreArray(
{Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
}
TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) {
+ FeatureProcessorOptions options;
+ options.set_only_use_line_with_click(true);
+ TestingFeatureProcessor feature_processor(options);
+
const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
const CodepointSpan span = {5, 23};
// clang-format off
@@ -191,24 +223,13 @@
// clang-format on
// Keeps the first line.
- internal::StripTokensFromOtherLines(context, span, &tokens);
+ feature_processor.StripTokensFromOtherLines(context, span, &tokens);
EXPECT_THAT(tokens, ElementsAreArray(
{Token("Fiřst", 0, 5), Token("Lině", 6, 10),
Token("Sěcond", 18, 23), Token("Lině", 19, 23),
Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
}
-class TestingFeatureProcessor : public FeatureProcessor {
- public:
- using FeatureProcessor::FeatureProcessor;
- using FeatureProcessor::SpanToLabel;
- using FeatureProcessor::SupportedCodepointsRatio;
- using FeatureProcessor::IsCodepointInRanges;
- using FeatureProcessor::ICUTokenize;
- using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
- using FeatureProcessor::supported_codepoint_ranges_;
-};
-
TEST(FeatureProcessorTest, SpanToLabel) {
FeatureProcessorOptions options;
options.set_context_size(1);
@@ -782,5 +803,35 @@
std::make_pair(0, 0));
}
+TEST(FeatureProcessorTest, CodepointSpanToTokenSpan) {
+ const std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
+
+ // Spans matching the tokens exactly.
+ EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}));
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}));
+ EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}));
+ EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}));
+ EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}));
+ EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}));
+
+ // Snapping to containing tokens has no effect.
+ EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}, true));
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}, true));
+ EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}, true));
+ EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}, true));
+ EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}, true));
+ EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}, true));
+
+ // Span boundaries inside tokens.
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {1, 28}));
+ EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {1, 28}, true));
+
+ // Tokens adjacent to the span, but not overlapping.
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}));
+ EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}, true));
+}
+
} // namespace
} // namespace libtextclassifier
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
index 3e5068d..e7ae09c 100644
--- a/smartselect/text-classification-model.cc
+++ b/smartselect/text-classification-model.cc
@@ -61,6 +61,17 @@
return count;
}
+std::string ExtractSelection(const std::string& context,
+ CodepointSpan selection_indices) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ auto selection_begin = context_unicode.begin();
+ std::advance(selection_begin, selection_indices.first);
+ auto selection_end = context_unicode.begin();
+ std::advance(selection_end, selection_indices.second);
+ return UnicodeText::UTF8Substring(selection_begin, selection_end);
+}
+
#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
bool MatchesRegex(const icu::RegexPattern* regex, const std::string& context) {
const icu::UnicodeString unicode_context(context.c_str(), context.size(),
@@ -153,6 +164,31 @@
} // namespace
+void TextClassificationModel::InitializeSharingRegexPatterns(
+    const std::vector<SharingModelOptions::RegexPattern>& patterns) {
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+  // Compile each pattern; ones that fail to compile are logged and skipped.
+  for (const auto& regex_pattern : patterns) {
+    UErrorCode status = U_ZERO_ERROR;
+    std::unique_ptr<icu::RegexPattern> compiled_pattern(
+        icu::RegexPattern::compile(
+            icu::UnicodeString(regex_pattern.pattern().c_str(),
+                               regex_pattern.pattern().size(), "utf-8"),
+            0 /* flags */, status));
+    if (U_FAILURE(status)) {
+      TC_LOG(WARNING) << "Failed to load pattern: " << regex_pattern.pattern();
+    } else {
+      regex_patterns_.push_back(
+          {regex_pattern.collection_name(), std::move(compiled_pattern)});
+    }
+  }
+#else
+  if (!patterns.empty()) {
+    TC_LOG(WARNING) << "ICU not supported regexp matchers ignored.";
+  }
+#endif
+}
+
bool TextClassificationModel::LoadModels(const void* addr, int size) {
const char *selection_model, *sharing_model;
int selection_model_length, sharing_model_length;
@@ -187,27 +223,9 @@
sharing_feature_fn_ = CreateFeatureVectorFn(
*sharing_network_, sharing_network_->EmbeddingSize(0));
-#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
- // Initialize pattern recognizers.
- for (const auto& regex_pattern : sharing_options_.regex_pattern()) {
- UErrorCode status = U_ZERO_ERROR;
- std::unique_ptr<icu::RegexPattern> compiled_pattern(
- icu::RegexPattern::compile(
- icu::UnicodeString(regex_pattern.pattern().c_str(),
- regex_pattern.pattern().size(), "utf-8"),
- 0 /* flags */, status));
- if (U_FAILURE(status)) {
- TC_LOG(WARNING) << "Failed to load pattern" << regex_pattern.pattern();
- } else {
- regex_patterns_.push_back(
- {regex_pattern.collection_name(), std::move(compiled_pattern)});
- }
- }
-#else
- if (sharing_options_.regex_pattern_size() > 0) {
- TC_LOG(WARNING) << "ICU not supported regexp matchers ignored.";
- }
-#endif
+ InitializeSharingRegexPatterns(std::vector<SharingModelOptions::RegexPattern>(
+ sharing_options_.regex_pattern().begin(),
+ sharing_options_.regex_pattern().end()));
return true;
}
@@ -279,6 +297,104 @@
return scores;
}
+namespace {
+
+// Returns true if given codepoint is contained in the given span in context.
+bool IsCodepointInSpan(const char32 codepoint, const std::string& context,
+ const CodepointSpan span) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+ auto begin_it = context_unicode.begin();
+ std::advance(begin_it, span.first);
+ auto end_it = context_unicode.begin();
+ std::advance(end_it, span.second);
+
+ return std::find(begin_it, end_it, codepoint) != end_it;
+}
+
+// Returns the first codepoint of the span.
+char32 FirstSpanCodepoint(const std::string& context,
+ const CodepointSpan span) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+ auto it = context_unicode.begin();
+ std::advance(it, span.first);
+ return *it;
+}
+
+// Returns the last codepoint of the span.
+char32 LastSpanCodepoint(const std::string& context, const CodepointSpan span) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+ auto it = context_unicode.begin();
+ std::advance(it, span.second - 1);
+ return *it;
+}
+
+} // namespace
+
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+
+namespace {
+
+bool IsOpenBracket(const char32 codepoint) {
+ return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) ==
+ U_BPT_OPEN;
+}
+
+bool IsClosingBracket(const char32 codepoint) {
+ return u_getIntPropertyValue(codepoint, UCHAR_BIDI_PAIRED_BRACKET_TYPE) ==
+ U_BPT_CLOSE;
+}
+
+} // namespace
+
+// If the first or the last codepoint of the given span is a bracket, the
+// bracket is stripped unless the span also contains its paired counterpart.
+// Empty spans and spans starting at a negative index are returned unchanged.
+CodepointSpan StripUnpairedBrackets(const std::string& context,
+                                    CodepointSpan span) {
+  if (context.empty() || span.first < 0 || span.first >= span.second) {
+    return span;
+  }
+
+  const char32 begin_char = FirstSpanCodepoint(context, span);
+
+  const char32 paired_begin_char = u_getBidiPairedBracket(begin_char);
+  if (paired_begin_char != begin_char) {
+    if (!IsOpenBracket(begin_char) ||
+        !IsCodepointInSpan(paired_begin_char, context, span)) {
+      ++span.first;
+    }
+  }
+
+  if (span.first == span.second) {
+    return span;
+  }
+
+  const char32 end_char = LastSpanCodepoint(context, span);
+  const char32 paired_end_char = u_getBidiPairedBracket(end_char);
+  if (paired_end_char != end_char) {
+    if (!IsClosingBracket(end_char) ||
+        !IsCodepointInSpan(paired_end_char, context, span)) {
+      --span.second;
+    }
+  }
+
+  // Should not happen, but let's make sure.
+  if (span.first > span.second) {
+    TC_LOG(WARNING) << "Inverse indices result: " << span.first << ", "
+                    << span.second;
+    span.second = span.first;
+  }
+
+  return span;
+}
+#endif
+
CodepointSpan TextClassificationModel::SuggestSelection(
const std::string& context, CodepointSpan click_indices) const {
if (!initialized_) {
@@ -286,19 +402,15 @@
return click_indices;
}
- if (std::get<0>(click_indices) >= std::get<1>(click_indices)) {
- TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices:"
- << std::get<0>(click_indices) << " "
- << std::get<1>(click_indices);
- return click_indices;
- }
+ const int context_codepoint_size =
+ UTF8ToUnicodeText(context, /*do_copy=*/false).size();
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
- const int context_length =
- std::distance(context_unicode.begin(), context_unicode.end());
- if (std::get<0>(click_indices) >= context_length ||
- std::get<1>(click_indices) > context_length) {
+ if (click_indices.first < 0 || click_indices.second < 0 ||
+ click_indices.first >= context_codepoint_size ||
+ click_indices.second > context_codepoint_size ||
+ click_indices.first >= click_indices.second) {
+ TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
+ << click_indices.first << " " << click_indices.second;
return click_indices;
}
@@ -310,6 +422,16 @@
std::tie(result, score) = SuggestSelectionInternal(context, click_indices);
}
+#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ if (selection_options_.strip_unpaired_brackets()) {
+ const CodepointSpan stripped_result =
+ StripUnpairedBrackets(context, result);
+ if (stripped_result.first != stripped_result.second) {
+ result = stripped_result;
+ }
+ }
+#endif
+
return result;
}
@@ -422,8 +544,10 @@
// Check whether any of the regular expressions match.
#ifndef LIBTEXTCLASSIFIER_DISABLE_ICU_SUPPORT
+ const std::string selection_text =
+ ExtractSelection(context, selection_indices);
for (const CompiledRegexPattern& regex_pattern : regex_patterns_) {
- if (MatchesRegex(regex_pattern.pattern.get(), context)) {
+ if (MatchesRegex(regex_pattern.pattern.get(), selection_text)) {
return {{regex_pattern.collection_name, 1.0}};
}
}
@@ -468,10 +592,7 @@
std::unique_ptr<CachedFeatures> cached_features;
std::vector<Token> tokens;
int click_index;
-
int embedding_size = selection_network_->EmbeddingSize(0);
- // TODO(zilka): Refactor the ExtractFeatures API to smoothly support the
- // different usecases. Now it's a lot click-centric.
if (!selection_feature_processor_->ExtractFeatures(
context, click_span, relative_click_span, selection_feature_fn_,
embedding_size + selection_feature_processor_->DenseFeaturesCount(),
@@ -480,8 +601,15 @@
return {};
}
- if (relative_click_span == std::make_pair(kInvalidIndex, kInvalidIndex)) {
- relative_click_span = {tokens.size() - 1, tokens.size() - 1};
+ int first_token;
+ int last_token;
+ if (relative_click_span.first == kInvalidIndex ||
+ relative_click_span.second == kInvalidIndex) {
+ first_token = 0;
+ last_token = tokens.size();
+ } else {
+ first_token = click_index - relative_click_span.first;
+ last_token = click_index + relative_click_span.second + 1;
}
struct SelectionProposal {
@@ -493,51 +621,45 @@
// Scan in the symmetry context for selection span proposals.
std::vector<SelectionProposal> proposals;
+ for (int token_index = first_token; token_index < last_token; ++token_index) {
+ if (token_index < 0 || token_index >= tokens.size() ||
+ tokens[token_index].is_padding) {
+ continue;
+ }
- for (int i = -relative_click_span.first; i < relative_click_span.second + 1;
- ++i) {
- const int token_index = click_index + i;
- if (token_index >= 0 && token_index < tokens.size() &&
- !tokens[token_index].is_padding) {
- float score;
- VectorSpan<float> features;
- VectorSpan<Token> output_tokens;
+ float score;
+ VectorSpan<float> features;
+ VectorSpan<Token> output_tokens;
+ std::vector<CodepointSpan> selection_label_spans;
+ CodepointSpan span;
+ if (cached_features->Get(token_index, &features, &output_tokens) &&
+ selection_feature_processor_->SelectionLabelSpans(
+ output_tokens, &selection_label_spans)) {
+ // Add an implicit proposal for each token to be by itself. Every
+ // token should be now represented in the results.
+ proposals.push_back(
+ SelectionProposal{0, token_index, selection_label_spans[0], 0.0});
- if (tokens[token_index].is_padding) {
- continue;
- }
+ std::vector<float> scores;
+ selection_network_->ComputeLogits(features, &scores);
- std::vector<CodepointSpan> selection_label_spans;
- CodepointSpan span;
- if (cached_features->Get(token_index, &features, &output_tokens) &&
- selection_feature_processor_->SelectionLabelSpans(
- output_tokens, &selection_label_spans)) {
- // Add an implicit proposal for each token to be by itself. Every
- // token should be now represented in the results.
+ scores = nlp_core::ComputeSoftmax(scores);
+ std::tie(span, score) = BestSelectionSpan({kInvalidIndex, kInvalidIndex},
+ scores, selection_label_spans);
+ if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
+ score >= 0) {
+ const int prediction = BestPrediction(scores);
proposals.push_back(
- SelectionProposal{0, token_index, selection_label_spans[0], 0.0});
-
- std::vector<float> scores;
- selection_network_->ComputeLogits(features, &scores);
-
- scores = nlp_core::ComputeSoftmax(scores);
- std::tie(span, score) = BestSelectionSpan(
- {kInvalidIndex, kInvalidIndex}, scores, selection_label_spans);
- if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
- score >= 0) {
- const int prediction = BestPrediction(scores);
- proposals.push_back(
- SelectionProposal{prediction, token_index, span, score});
- }
- } else {
- // Add an implicit proposal for each token to be by itself. Every token
- // should be now represented in the results.
- proposals.push_back(SelectionProposal{
- 0,
- token_index,
- {tokens[token_index].start, tokens[token_index].end},
- 0.0});
+ SelectionProposal{prediction, token_index, span, score});
}
+ } else {
+ // Add an implicit proposal for each token to be by itself. Every token
+ // should be now represented in the results.
+ proposals.push_back(SelectionProposal{
+ 0,
+ token_index,
+ {tokens[token_index].start, tokens[token_index].end},
+ 0.0});
}
}
@@ -592,9 +714,20 @@
std::vector<TextClassificationModel::AnnotatedSpan>
TextClassificationModel::Annotate(const std::string& context) const {
- std::vector<CodepointSpan> chunks =
- Chunk(context, /*click_span=*/{0, 1},
- /*relative_click_span=*/{kInvalidIndex, kInvalidIndex});
+ std::vector<CodepointSpan> chunks;
+ const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+ /*do_copy=*/false);
+ for (const UnicodeTextRange& line :
+ selection_feature_processor_->SplitContext(context_unicode)) {
+ const std::vector<CodepointSpan> local_chunks =
+ Chunk(UnicodeText::UTF8Substring(line.first, line.second),
+ /*click_span=*/{kInvalidIndex, kInvalidIndex},
+ /*relative_click_span=*/{kInvalidIndex, kInvalidIndex});
+ const int offset = std::distance(context_unicode.begin(), line.first);
+ for (CodepointSpan chunk : local_chunks) {
+ chunks.push_back({chunk.first + offset, chunk.second + offset});
+ }
+ }
std::vector<TextClassificationModel::AnnotatedSpan> result;
for (const CodepointSpan& chunk : chunks) {
diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h
index 5b58d89..d0df193 100644
--- a/smartselect/text-classification-model.h
+++ b/smartselect/text-classification-model.h
@@ -98,10 +98,8 @@
// token determined by click_span and looks at relative_click_span tokens
// left and right around the click position.
// If relative_click_span == {kInvalidIndex, kInvalidIndex} then the whole
- // context is considered, regardless of the click_span (which should point to
- // the beginning {0, 1}.
+ // context is considered, regardless of the click_span.
// Returns the chunks sorted by their position in the context string.
- // TODO(zilka): Tidy up the interface.
std::vector<CodepointSpan> Chunk(const std::string& context,
CodepointSpan click_span,
TokenSpan relative_click_span) const;
@@ -111,6 +109,9 @@
return selection_feature_processor_.get();
}
+ void InitializeSharingRegexPatterns(
+ const std::vector<SharingModelOptions::RegexPattern>& patterns);
+
// Collection name when url hint is accepted.
const std::string kUrlHintCollection = "url";
@@ -167,6 +168,12 @@
#endif
};
+// If the first or the last codepoint of the given span is a bracket, the
+// bracket is stripped if the span does not contain its corresponding paired
+// version.
+CodepointSpan StripUnpairedBrackets(const std::string& context,
+ CodepointSpan span);
+
// Parses the merged image given as a file descriptor, and reads
// the ModelOptions proto from the selection model.
bool ReadSelectionModelOptions(int fd, ModelOptions* model_options);
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
index ca10a0e..315e849 100644
--- a/smartselect/text-classification-model.proto
+++ b/smartselect/text-classification-model.proto
@@ -46,6 +46,12 @@
// enforcing symmetry.
optional int32 symmetry_context_size = 4;
+  // If true, unpaired brackets at either end of the predicted selection are
+  // stripped before the selection is returned.
+ // The bracket codepoints are defined in the Unicode standard:
+ // http://www.unicode.org/Public/UNIDATA/BidiBrackets.txt
+ optional bool strip_unpaired_brackets = 5 [default = true];
+
reserved 2;
}
@@ -71,7 +77,7 @@
repeated RegexPattern regex_pattern = 5;
}
-// Next ID: 39
+// Next ID: 41
message FeatureProcessorOptions {
// Number of buckets used for hashing charactergrams.
optional int32 num_buckets = 1 [default = -1];
@@ -207,7 +213,7 @@
// predicted spans.
repeated int32 ignored_span_boundary_codepoints = 36;
- reserved 7, 11, 12, 26, 27, 28, 29, 32, 35;
+ reserved 7, 11, 12, 26, 27, 28, 29, 32, 35, 39, 40;
// List of allowed charactergrams. The extracted charactergrams are filtered
// using this list, and charactergrams that are not present are interpreted as
diff --git a/smartselect/text-classification-model_test.cc b/smartselect/text-classification-model_test.cc
index 490b395..5550e53 100644
--- a/smartselect/text-classification-model_test.cc
+++ b/smartselect/text-classification-model_test.cc
@@ -18,6 +18,8 @@
#include <fcntl.h>
#include <stdio.h>
+#include <fstream>
+#include <iostream>
#include <memory>
#include <string>
@@ -26,10 +28,31 @@
namespace libtextclassifier {
namespace {
+std::string ReadFile(const std::string& file_name) {
+ std::ifstream file_stream(file_name);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
std::string GetModelPath() {
return TEST_DATA_DIR "smartselection.model";
}
+std::string GetURLRegexPath() {
+ return TEST_DATA_DIR "regex_url.txt";
+}
+
+std::string GetEmailRegexPath() {
+ return TEST_DATA_DIR "regex_email.txt";
+}
+
+TEST(TextClassificationModelTest, StripUnpairedBrackets) {
+ // Stripping brackets strip brackets from length 1 bracket only selections.
+ EXPECT_EQ(StripUnpairedBrackets("call me at ) today", {11, 12}),
+ std::make_pair(12, 12));
+ EXPECT_EQ(StripUnpairedBrackets("call me at ( today", {11, 12}),
+ std::make_pair(12, 12));
+}
+
TEST(TextClassificationModelTest, ReadModelOptions) {
const std::string model_path = GetModelPath();
int fd = open(model_path.c_str(), O_RDONLY);
@@ -62,6 +85,29 @@
// Single word.
EXPECT_EQ(std::make_pair(0, 4), model->SuggestSelection("asdf", {0, 4}));
+
+ EXPECT_EQ(model->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+ std::make_pair(11, 23));
+
+ // Unpaired bracket stripping.
+ EXPECT_EQ(
+ model->SuggestSelection("call me at (857) 225 3556 today", {11, 16}),
+ std::make_pair(11, 25));
+ EXPECT_EQ(model->SuggestSelection("call me at (857 225 3556 today", {11, 15}),
+ std::make_pair(12, 24));
+ EXPECT_EQ(model->SuggestSelection("call me at 857 225 3556) today", {11, 14}),
+ std::make_pair(11, 23));
+ EXPECT_EQ(
+ model->SuggestSelection("call me at )857 225 3556( today", {11, 15}),
+ std::make_pair(12, 24));
+
+ // If the resulting selection would be empty, the original span is returned.
+ EXPECT_EQ(model->SuggestSelection("call me at )( today", {11, 13}),
+ std::make_pair(11, 13));
+ EXPECT_EQ(model->SuggestSelection("call me at ( today", {11, 12}),
+ std::make_pair(11, 12));
+ EXPECT_EQ(model->SuggestSelection("call me at ) today", {11, 12}),
+ std::make_pair(11, 12));
}
TEST(TextClassificationModelTest, SuggestSelectionsAreSymmetric) {
@@ -139,6 +185,8 @@
explicit TestingTextClassificationModel(int fd)
: libtextclassifier::TextClassificationModel(fd) {}
+ using TextClassificationModel::InitializeSharingRegexPatterns;
+
void DisableClassificationHints() {
sharing_options_.set_always_accept_url_hint(false);
sharing_options_.set_always_accept_email_hint(false);
@@ -310,14 +358,14 @@
close(fd);
std::string test_string =
- "I saw Barak Obama today at 350 Third Street, Cambridge";
+ "& saw Barak Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225-3556.";
std::vector<TextClassificationModel::AnnotatedSpan> result =
model->Annotate(test_string);
std::vector<TextClassificationModel::AnnotatedSpan> expected;
expected.emplace_back();
- expected.back().span = {0, 1};
- expected.back().classification.push_back({"other", 1.0});
+ expected.back().span = {0, 0};
expected.emplace_back();
expected.back().span = {2, 5};
expected.back().classification.push_back({"other", 1.0});
@@ -328,20 +376,65 @@
expected.back().span = {18, 23};
expected.back().classification.push_back({"other", 1.0});
expected.emplace_back();
- expected.back().span = {24, 26};
- expected.back().classification.push_back({"other", 1.0});
+ expected.back().span = {24, 24};
expected.emplace_back();
expected.back().span = {27, 54};
expected.back().classification.push_back({"address", 1.0});
+ expected.emplace_back();
+ expected.back().span = {55, 58};
+ expected.back().classification.push_back({"other", 1.0});
+ expected.emplace_back();
+ expected.back().span = {59, 61};
+ expected.back().classification.push_back({"other", 1.0});
+ expected.emplace_back();
+ expected.back().span = {62, 74};
+ expected.back().classification.push_back({"other", 1.0});
+ expected.emplace_back();
+ expected.back().span = {75, 77};
+ expected.back().classification.push_back({"other", 1.0});
+ expected.emplace_back();
+ expected.back().span = {78, 90};
+ expected.back().classification.push_back({"phone", 1.0});
- ASSERT_EQ(result.size(), expected.size());
+ ASSERT_EQ(result.size(), expected.size());  // Fatal: result[i] indexed below.
for (int i = 0; i < expected.size(); ++i) {
EXPECT_EQ(result[i].span, expected[i].span) << result[i];
- EXPECT_EQ(result[i].classification[0].first,
- expected[i].classification[0].first)
- << result[i];
+ if (!expected[i].classification.empty()) {
+ ASSERT_FALSE(result[i].classification.empty()) << result[i];  // [0] read next.
+ EXPECT_EQ(result[i].classification[0].first,
+ expected[i].classification[0].first)
+ << result[i];
+ }
}
}
+TEST(TextClassificationModelTest, URLEmailRegex) {
+ const int fd = open(GetModelPath().c_str(), O_RDONLY);
+ std::unique_ptr<TestingTextClassificationModel> model(
+ new TestingTextClassificationModel(fd));
+ close(fd);
+
+ // Register email and URL regexes, each loaded from a test data file.
+ SharingModelOptions sharing_options;
+ SharingModelOptions::RegexPattern* email_regex =
+ sharing_options.add_regex_pattern();
+ email_regex->set_collection_name("email");
+ email_regex->set_pattern(ReadFile(GetEmailRegexPath()));
+ SharingModelOptions::RegexPattern* url_regex =
+ sharing_options.add_regex_pattern();
+ url_regex->set_collection_name("url");
+ url_regex->set_pattern(ReadFile(GetURLRegexPath()));
+ // TODO(b/69538802): Modify directly the model image instead.
+ const auto& patterns = sharing_options.regex_pattern();
+ model->InitializeSharingRegexPatterns({patterns.begin(), patterns.end()});
+
+ EXPECT_EQ("url", FindBestResult(model->ClassifyText(
+ "Visit www.google.com every today!", {6, 20})));
+ EXPECT_EQ("email", FindBestResult(model->ClassifyText(
+ "My email: asdf@something.cz", {10, 27})));
+ EXPECT_EQ("url", FindBestResult(model->ClassifyText(
+ "Login: http://asdf@something.cz", {7, 31})));
+}
+
} // namespace
} // namespace libtextclassifier
diff --git a/util/base/endian.h b/util/base/endian.h
index 5813288..f319f65 100644
--- a/util/base/endian.h
+++ b/util/base/endian.h
@@ -24,13 +24,34 @@
#if defined OS_LINUX || defined OS_CYGWIN || defined OS_ANDROID || \
defined(__ANDROID__)
#include <endian.h>
+#elif defined(__APPLE__)
+#include <machine/endian.h>
+// Alias glibc-style names (__BYTE_ORDER etc.) to <machine/endian.h>'s ones.
+#ifndef __BYTE_ORDER
+#define __BYTE_ORDER BYTE_ORDER
+#endif // __BYTE_ORDER
+#ifndef __LITTLE_ENDIAN
+#define __LITTLE_ENDIAN LITTLE_ENDIAN
+#endif // __LITTLE_ENDIAN
+#ifndef __BIG_ENDIAN
+#define __BIG_ENDIAN BIG_ENDIAN
+#endif // __BIG_ENDIAN
#endif
// The following guarantees declaration of the byte swap functions, and
// defines __BYTE_ORDER for MSVC
#if defined(__GLIBC__) || defined(__CYGWIN__)
#include <byteswap.h> // IWYU pragma: export
-
+// The following section defines the byte swap functions for OS X / iOS,
+// which do not ship with byteswap.h.
+#elif defined(__APPLE__)
+// Make sure that byte swap functions are not already defined.
+#if !defined(bswap_16)
+#include <libkern/OSByteOrder.h>
+#define bswap_16(x) OSSwapInt16(x)
+#define bswap_32(x) OSSwapInt32(x)
+#define bswap_64(x) OSSwapInt64(x)
+#endif // !defined(bswap_16)
#else
#define GG_LONGLONG(x) x##LL
#define GG_ULONGLONG(x) x##ULL
diff --git a/util/hash/farmhash.cc b/util/hash/farmhash.cc
index f4f2e84..673f45f 100644
--- a/util/hash/farmhash.cc
+++ b/util/hash/farmhash.cc
@@ -348,10 +348,7 @@
return x;
}
-} // namespace NAMESPACE_FOR_HASH_FUNCTIONS;
-
using namespace std;
-using namespace NAMESPACE_FOR_HASH_FUNCTIONS;
namespace farmhashna {
#undef Fetch
#define Fetch Fetch64
@@ -1407,7 +1404,6 @@
return CityHash128(s, len);
}
} // namespace farmhashcc
-namespace NAMESPACE_FOR_HASH_FUNCTIONS {
// BASIC STRING HASHING