Fixes utf8 handling, datetime model and flight number model, makes
models smaller by compressing the regex rules, and adds i18n models.
(sync from google3)
Test: bit FrameworksCoreTests:android.view.textclassifier.TextClassificationManagerTest
Test: bit CtsViewTestCases:android.view.textclassifier.cts.TextClassificationManagerTest
Bug: 64929062
Bug: 77223425
Change-Id: I04472e6b247e824bf2b745077c50fcde4269aefc
diff --git a/Android.mk b/Android.mk
index 2317b83..760c4c8 100644
--- a/Android.mk
+++ b/Android.mk
@@ -37,7 +37,8 @@
MY_LIBTEXTCLASSIFIER_CFLAGS := \
$(MY_LIBTEXTCLASSIFIER_WARNING_CFLAGS) \
-fvisibility=hidden \
- -DLIBTEXTCLASSIFIER_UNILIB_ICU
+ -DLIBTEXTCLASSIFIER_UNILIB_ICU \
+ -DZLIB_CONST
# Only enable debug logging in userdebug/eng builds.
ifneq (,$(filter userdebug eng, $(TARGET_BUILD_VARIANT)))
@@ -70,17 +71,23 @@
LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS)
LOCAL_SRC_FILES := $(filter-out tests/% %_test.cc test-util.%,$(call all-subdir-cpp-files))
-LOCAL_C_INCLUDES := $(TOP)/external/tensorflow $(TOP)/external/flatbuffers/include
+
+LOCAL_C_INCLUDES := $(TOP)/external/zlib
+LOCAL_C_INCLUDES += $(TOP)/external/tensorflow
+LOCAL_C_INCLUDES += $(TOP)/external/flatbuffers/include
LOCAL_SHARED_LIBRARIES += liblog
-LOCAL_SHARED_LIBRARIES += libicuuc libicui18n
+LOCAL_SHARED_LIBRARIES += libicuuc
+LOCAL_SHARED_LIBRARIES += libicui18n
LOCAL_SHARED_LIBRARIES += libtflite
+LOCAL_SHARED_LIBRARIES += libz
+
LOCAL_STATIC_LIBRARIES += flatbuffers
+
LOCAL_REQUIRED_MODULES := textclassifier.en.model
LOCAL_ADDITIONAL_DEPENDENCIES += $(LOCAL_PATH)/jni.lds
LOCAL_LDFLAGS += -Wl,-version-script=$(LOCAL_PATH)/jni.lds
-
LOCAL_CPPFLAGS_32 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\""
LOCAL_CPPFLAGS_64 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\""
@@ -106,12 +113,19 @@
LOCAL_CPPFLAGS_64 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\""
LOCAL_SRC_FILES := $(call all-subdir-cpp-files)
-LOCAL_C_INCLUDES := $(TOP)/external/tensorflow $(TOP)/external/flatbuffers/include
+
+LOCAL_C_INCLUDES := $(TOP)/external/zlib
+LOCAL_C_INCLUDES += $(TOP)/external/tensorflow
+LOCAL_C_INCLUDES += $(TOP)/external/flatbuffers/include
LOCAL_STATIC_LIBRARIES += libgmock
LOCAL_SHARED_LIBRARIES += liblog
-LOCAL_SHARED_LIBRARIES += libicuuc libicui18n
+LOCAL_SHARED_LIBRARIES += libicuuc
+LOCAL_SHARED_LIBRARIES += libicui18n
LOCAL_SHARED_LIBRARIES += libtflite
+LOCAL_SHARED_LIBRARIES += libz
+
+LOCAL_STATIC_LIBRARIES += flatbuffers
include $(BUILD_NATIVE_TEST)
diff --git a/datetime/extractor.cc b/datetime/extractor.cc
index 8c6c3ff..f4ab8f4 100644
--- a/datetime/extractor.cc
+++ b/datetime/extractor.cc
@@ -20,156 +20,119 @@
namespace libtextclassifier2 {
-constexpr char const* kGroupYear = "YEAR";
-constexpr char const* kGroupMonth = "MONTH";
-constexpr char const* kGroupDay = "DAY";
-constexpr char const* kGroupHour = "HOUR";
-constexpr char const* kGroupMinute = "MINUTE";
-constexpr char const* kGroupSecond = "SECOND";
-constexpr char const* kGroupAmpm = "AMPM";
-constexpr char const* kGroupRelationDistance = "RELATIONDISTANCE";
-constexpr char const* kGroupRelation = "RELATION";
-constexpr char const* kGroupRelationType = "RELATIONTYPE";
-// Dummy groups serve just as an inflator of the selection. E.g. we might want
-// to select more text than was contained in an envelope of all extractor spans.
-constexpr char const* kGroupDummy1 = "DUMMY1";
-constexpr char const* kGroupDummy2 = "DUMMY2";
-
bool DatetimeExtractor::Extract(DateParseData* result,
CodepointSpan* result_span) const {
result->field_set_mask = 0;
*result_span = {kInvalidIndex, kInvalidIndex};
- UnicodeText group_text;
- if (GroupNotEmpty(kGroupYear, &group_text)) {
- result->field_set_mask |= DateParseData::YEAR_FIELD;
- if (!ParseYear(group_text, &(result->year))) {
- TC_LOG(ERROR) << "Couldn't extract YEAR.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupYear, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
+ if (rule_.regex->groups() == nullptr) {
+ return false;
}
- if (GroupNotEmpty(kGroupMonth, &group_text)) {
- result->field_set_mask |= DateParseData::MONTH_FIELD;
- if (!ParseMonth(group_text, &(result->month))) {
- TC_LOG(ERROR) << "Couldn't extract MONTH.";
+ for (int group_id = 0; group_id < rule_.regex->groups()->size(); group_id++) {
+ UnicodeText group_text;
+ const int group_type = rule_.regex->groups()->Get(group_id);
+ if (group_type == DatetimeGroupType_GROUP_UNUSED) {
+ continue;
+ }
+ if (!GroupTextFromMatch(group_id, &group_text)) {
+ TC_LOG(ERROR) << "Couldn't retrieve group.";
return false;
}
- if (!UpdateMatchSpan(kGroupMonth, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
+ // The pattern can have a group defined in a part that was not matched,
+ // e.g. an optional part. In this case we'll get an empty content here.
+ if (group_text.empty()) {
+ continue;
}
- }
-
- if (GroupNotEmpty(kGroupDay, &group_text)) {
- result->field_set_mask |= DateParseData::DAY_FIELD;
- if (!ParseDigits(group_text, &(result->day_of_month))) {
- TC_LOG(ERROR) << "Couldn't extract DAY.";
- return false;
+ switch (group_type) {
+ case DatetimeGroupType_GROUP_YEAR: {
+ if (!ParseYear(group_text, &(result->year))) {
+ TC_LOG(ERROR) << "Couldn't extract YEAR.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::YEAR_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_MONTH: {
+ if (!ParseMonth(group_text, &(result->month))) {
+ TC_LOG(ERROR) << "Couldn't extract MONTH.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::MONTH_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_DAY: {
+ if (!ParseDigits(group_text, &(result->day_of_month))) {
+ TC_LOG(ERROR) << "Couldn't extract DAY.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::DAY_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_HOUR: {
+ if (!ParseDigits(group_text, &(result->hour))) {
+ TC_LOG(ERROR) << "Couldn't extract HOUR.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::HOUR_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_MINUTE: {
+ if (!ParseDigits(group_text, &(result->minute))) {
+ TC_LOG(ERROR) << "Couldn't extract MINUTE.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::MINUTE_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_SECOND: {
+ if (!ParseDigits(group_text, &(result->second))) {
+ TC_LOG(ERROR) << "Couldn't extract SECOND.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::SECOND_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_AMPM: {
+ if (!ParseAMPM(group_text, &(result->ampm))) {
+ TC_LOG(ERROR) << "Couldn't extract AMPM.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::AMPM_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_RELATIONDISTANCE: {
+ if (!ParseRelationDistance(group_text, &(result->relation_distance))) {
+ TC_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::RELATION_DISTANCE_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_RELATION: {
+ if (!ParseRelation(group_text, &(result->relation))) {
+ TC_LOG(ERROR) << "Couldn't extract RELATION_FIELD.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::RELATION_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_RELATIONTYPE: {
+ if (!ParseRelationType(group_text, &(result->relation_type))) {
+ TC_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::RELATION_TYPE_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_DUMMY1:
+ case DatetimeGroupType_GROUP_DUMMY2:
+ break;
+ default:
+ TC_LOG(INFO) << "Unknown group type.";
+ continue;
}
- if (!UpdateMatchSpan(kGroupDay, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupHour, &group_text)) {
- result->field_set_mask |= DateParseData::HOUR_FIELD;
- if (!ParseDigits(group_text, &(result->hour))) {
- TC_LOG(ERROR) << "Couldn't extract HOUR.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupHour, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupMinute, &group_text)) {
- result->field_set_mask |= DateParseData::MINUTE_FIELD;
- if (!ParseDigits(group_text, &(result->minute))) {
- TC_LOG(ERROR) << "Couldn't extract MINUTE.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupMinute, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupSecond, &group_text)) {
- result->field_set_mask |= DateParseData::SECOND_FIELD;
- if (!ParseDigits(group_text, &(result->second))) {
- TC_LOG(ERROR) << "Couldn't extract SECOND.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupSecond, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupAmpm, &group_text)) {
- result->field_set_mask |= DateParseData::AMPM_FIELD;
- if (!ParseAMPM(group_text, &(result->ampm))) {
- TC_LOG(ERROR) << "Couldn't extract AMPM.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupAmpm, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupRelationDistance, &group_text)) {
- result->field_set_mask |= DateParseData::RELATION_DISTANCE_FIELD;
- if (!ParseRelationDistance(group_text, &(result->relation_distance))) {
- TC_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupRelationDistance, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupRelation, &group_text)) {
- result->field_set_mask |= DateParseData::RELATION_FIELD;
- if (!ParseRelation(group_text, &(result->relation))) {
- TC_LOG(ERROR) << "Couldn't extract RELATION_FIELD.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupRelation, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupRelationType, &group_text)) {
- result->field_set_mask |= DateParseData::RELATION_TYPE_FIELD;
- if (!ParseRelationType(group_text, &(result->relation_type))) {
- TC_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupRelationType, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupDummy1, &group_text)) {
- if (!UpdateMatchSpan(kGroupDummy1, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupDummy2, &group_text)) {
- if (!UpdateMatchSpan(kGroupDummy2, result_span)) {
+ if (!UpdateMatchSpan(group_id, result_span)) {
TC_LOG(ERROR) << "Couldn't update span.";
return false;
}
@@ -226,24 +189,24 @@
return true;
}
-bool DatetimeExtractor::GroupNotEmpty(StringPiece name,
- UnicodeText* result) const {
+bool DatetimeExtractor::GroupTextFromMatch(int group_id,
+ UnicodeText* result) const {
int status;
- *result = matcher_.Group(name, &status);
+ *result = matcher_.Group(group_id, &status);
if (status != UniLib::RegexMatcher::kNoError) {
return false;
}
- return !result->empty();
+ return true;
}
-bool DatetimeExtractor::UpdateMatchSpan(StringPiece name,
+bool DatetimeExtractor::UpdateMatchSpan(int group_id,
CodepointSpan* span) const {
int status;
- const int match_start = matcher_.Start(name, &status);
+ const int match_start = matcher_.Start(group_id, &status);
if (status != UniLib::RegexMatcher::kNoError) {
return false;
}
- const int match_end = matcher_.End(name, &status);
+ const int match_end = matcher_.End(group_id, &status);
if (status != UniLib::RegexMatcher::kNoError) {
return false;
}
diff --git a/datetime/extractor.h b/datetime/extractor.h
index ceeb9cf..5c36ec4 100644
--- a/datetime/extractor.h
+++ b/datetime/extractor.h
@@ -29,18 +29,31 @@
namespace libtextclassifier2 {
+struct CompiledRule {
+ // The compiled regular expression.
+ std::unique_ptr<const UniLib::RegexPattern> compiled_regex;
+
+ // The uncompiled pattern and information about the pattern groups.
+ const DatetimeModelPattern_::Regex* regex;
+
+ // DatetimeModelPattern which 'regex' is part of and comes from.
+ const DatetimeModelPattern* pattern;
+};
+
// A helper class for DatetimeParser that extracts structured data
// (DateParseDate) from the current match of the passed RegexMatcher.
class DatetimeExtractor {
public:
DatetimeExtractor(
- const UniLib::RegexMatcher& matcher, int locale_id, const UniLib& unilib,
+ const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+ int locale_id, const UniLib& unilib,
const std::vector<std::unique_ptr<const UniLib::RegexPattern>>&
extractor_rules,
const std::unordered_map<DatetimeExtractorType,
std::unordered_map<int, int>>&
type_and_locale_to_extractor_rule)
- : matcher_(matcher),
+ : rule_(rule),
+ matcher_(matcher),
locale_id_(locale_id),
unilib_(unilib),
rules_(extractor_rules),
@@ -57,10 +70,10 @@
DatetimeExtractorType extractor_type,
UnicodeText* match_result = nullptr) const;
- bool GroupNotEmpty(StringPiece name, UnicodeText* result) const;
+ bool GroupTextFromMatch(int group_id, UnicodeText* result) const;
// Updates the span to include the current match for the given group.
- bool UpdateMatchSpan(StringPiece group_name, CodepointSpan* span) const;
+ bool UpdateMatchSpan(int group_id, CodepointSpan* span) const;
// Returns true if any of the extractors from 'mapping' matched. If it did,
// will fill 'result' with the associated value from 'mapping'.
@@ -84,6 +97,7 @@
DateParseData::RelationType* parsed_relation_type) const;
bool ParseWeekday(const UnicodeText& input, int* parsed_weekday) const;
+ const CompiledRule& rule_;
const UniLib::RegexMatcher& matcher_;
int locale_id_;
const UniLib& unilib_;
diff --git a/datetime/parser.cc b/datetime/parser.cc
index 8ad3d33..e9a6eb1 100644
--- a/datetime/parser.cc
+++ b/datetime/parser.cc
@@ -21,61 +21,78 @@
#include "datetime/extractor.h"
#include "util/calendar/calendar.h"
+#include "util/i18n/locale.h"
#include "util/strings/split.h"
namespace libtextclassifier2 {
std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
- const DatetimeModel* model, const UniLib& unilib) {
- std::unique_ptr<DatetimeParser> result(new DatetimeParser(model, unilib));
+ const DatetimeModel* model, const UniLib& unilib,
+ ZlibDecompressor* decompressor) {
+ std::unique_ptr<DatetimeParser> result(
+ new DatetimeParser(model, unilib, decompressor));
if (!result->initialized_) {
result.reset();
}
return result;
}
-DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib)
+DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
+ ZlibDecompressor* decompressor)
: unilib_(unilib) {
initialized_ = false;
- for (int i = 0; i < model->patterns()->Length(); ++i) {
- const DatetimeModelPattern* pattern = model->patterns()->Get(i);
- for (int j = 0; j < pattern->regexes()->Length(); ++j) {
+
+ if (model == nullptr) {
+ return;
+ }
+
+ if (model->patterns() != nullptr) {
+ for (const DatetimeModelPattern* pattern : *model->patterns()) {
+ if (pattern->regexes()) {
+ for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
+ std::unique_ptr<UniLib::RegexPattern> regex_pattern =
+ UncompressMakeRegexPattern(unilib, regex->pattern(),
+ regex->compressed_pattern(),
+ decompressor);
+ if (!regex_pattern) {
+ TC_LOG(ERROR) << "Couldn't create rule pattern.";
+ return;
+ }
+ rules_.push_back({std::move(regex_pattern), regex, pattern});
+ if (pattern->locales()) {
+ for (int locale : *pattern->locales()) {
+ locale_to_rules_[locale].push_back(rules_.size() - 1);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if (model->extractors() != nullptr) {
+ for (const DatetimeModelExtractor* extractor : *model->extractors()) {
std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- unilib.CreateRegexPattern(UTF8ToUnicodeText(
- pattern->regexes()->Get(j)->str(), /*do_copy=*/false));
+ UncompressMakeRegexPattern(unilib, extractor->pattern(),
+ extractor->compressed_pattern(),
+ decompressor);
if (!regex_pattern) {
- TC_LOG(ERROR) << "Couldn't create pattern: "
- << pattern->regexes()->Get(j)->str();
+ TC_LOG(ERROR) << "Couldn't create extractor pattern";
return;
}
- rules_.push_back(std::move(regex_pattern));
- rule_id_to_pattern_.push_back(pattern);
- for (int k = 0; k < pattern->locales()->Length(); ++k) {
- locale_to_rules_[pattern->locales()->Get(k)].push_back(rules_.size() -
- 1);
+ extractor_rules_.push_back(std::move(regex_pattern));
+
+ if (extractor->locales()) {
+ for (int locale : *extractor->locales()) {
+ type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
+ extractor_rules_.size() - 1;
+ }
}
}
}
- for (int i = 0; i < model->extractors()->Length(); ++i) {
- const DatetimeModelExtractor* extractor = model->extractors()->Get(i);
- std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- unilib.CreateRegexPattern(
- UTF8ToUnicodeText(extractor->pattern()->str(), /*do_copy=*/false));
- if (!regex_pattern) {
- TC_LOG(ERROR) << "Couldn't create pattern: "
- << extractor->pattern()->str();
- return;
+ if (model->locales() != nullptr) {
+ for (int i = 0; i < model->locales()->Length(); ++i) {
+ locale_string_to_id_[model->locales()->Get(i)->str()] = i;
}
- extractor_rules_.push_back(std::move(regex_pattern));
-
- for (int j = 0; j < extractor->locales()->Length(); ++j) {
- type_and_locale_to_extractor_rule_[extractor->extractor()]
- [extractor->locales()->Get(j)] = i;
- }
- }
-
- for (int i = 0; i < model->locales()->Length(); ++i) {
- locale_string_to_id_[model->locales()->Get(i)->str()] = i;
}
use_extractors_for_locating_ = model->use_extractors_for_locating();
@@ -86,19 +103,21 @@
bool DatetimeParser::Parse(
const std::string& input, const int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, std::vector<DatetimeParseResultSpan>* results) const {
+ ModeFlag mode, bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* results) const {
return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
reference_time_ms_utc, reference_timezone, locales, mode,
- results);
+ anchor_start_end, results);
}
bool DatetimeParser::Parse(
const UnicodeText& input, const int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, std::vector<DatetimeParseResultSpan>* results) const {
+ ModeFlag mode, bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* results) const {
std::vector<DatetimeParseResultSpan> found_spans;
std::unordered_set<int> executed_rules;
- for (const int locale_id : ParseLocales(locales)) {
+ for (const int locale_id : ParseAndExpandLocales(locales)) {
auto rules_it = locale_to_rules_.find(locale_id);
if (rules_it == locale_to_rules_.end()) {
continue;
@@ -110,26 +129,45 @@
continue;
}
- if (!(rule_id_to_pattern_[rule_id]->enabled_modes() & mode)) {
+ if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
continue;
}
executed_rules.insert(rule_id);
- if (!ParseWithRule(*rules_[rule_id], rule_id_to_pattern_[rule_id], input,
- reference_time_ms_utc, reference_timezone, locale_id,
+ if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
+ reference_timezone, locale_id, anchor_start_end,
&found_spans)) {
return false;
}
}
}
- // Resolve conflicts by always picking the longer span.
- std::sort(
- found_spans.begin(), found_spans.end(),
- [](const DatetimeParseResultSpan& a, const DatetimeParseResultSpan& b) {
- return (a.span.second - a.span.first) > (b.span.second - b.span.first);
- });
+ std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
+ int counter = 0;
+ for (const auto& found_span : found_spans) {
+ indexed_found_spans.push_back({found_span, counter});
+ counter++;
+ }
+
+ // Resolve conflicts by always picking the longer span and breaking ties by
+ // selecting the earlier entry in the list for a given locale.
+ std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
+ [](const std::pair<DatetimeParseResultSpan, int>& a,
+ const std::pair<DatetimeParseResultSpan, int>& b) {
+ if ((a.first.span.second - a.first.span.first) !=
+ (b.first.span.second - b.first.span.first)) {
+ return (a.first.span.second - a.first.span.first) >
+ (b.first.span.second - b.first.span.first);
+ } else {
+ return a.second < b.second;
+ }
+ });
+
+ found_spans.clear();
+ for (auto& span_index_pair : indexed_found_spans) {
+ found_spans.push_back(span_index_pair.first);
+ }
std::set<int, std::function<bool(int, int)>> chosen_indices_set(
[&found_spans](int a, int b) {
@@ -145,60 +183,119 @@
return true;
}
-bool DatetimeParser::ParseWithRule(
- const UniLib::RegexPattern& regex, const DatetimeModelPattern* pattern,
- const UnicodeText& input, const int64 reference_time_ms_utc,
- const std::string& reference_timezone, const int locale_id,
- std::vector<DatetimeParseResultSpan>* result) const {
- std::unique_ptr<UniLib::RegexMatcher> matcher = regex.Matcher(input);
-
+bool DatetimeParser::HandleParseMatch(
+ const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc, const std::string& reference_timezone,
+ int locale_id, std::vector<DatetimeParseResultSpan>* result) const {
int status = UniLib::RegexMatcher::kNoError;
- while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
- const int start = matcher->Start(&status);
- if (status != UniLib::RegexMatcher::kNoError) {
- return false;
- }
+ const int start = matcher.Start(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
- const int end = matcher->End(&status);
- if (status != UniLib::RegexMatcher::kNoError) {
- return false;
- }
+ const int end = matcher.End(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
- DatetimeParseResultSpan parse_result;
- if (!ExtractDatetime(*matcher, reference_time_ms_utc, reference_timezone,
- locale_id, &(parse_result.data), &parse_result.span)) {
- return false;
- }
- if (!use_extractors_for_locating_) {
- parse_result.span = {start, end};
- }
+ DatetimeParseResultSpan parse_result;
+ if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
+ locale_id, &(parse_result.data), &parse_result.span)) {
+ return false;
+ }
+ if (!use_extractors_for_locating_) {
+ parse_result.span = {start, end};
+ }
+ if (parse_result.span.first != kInvalidIndex &&
+ parse_result.span.second != kInvalidIndex) {
parse_result.target_classification_score =
- pattern->target_classification_score();
- parse_result.priority_score = pattern->priority_score();
-
+ rule.pattern->target_classification_score();
+ parse_result.priority_score = rule.pattern->priority_score();
result->push_back(parse_result);
}
return true;
}
+bool DatetimeParser::ParseWithRule(
+ const CompiledRule& rule, const UnicodeText& input,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const int locale_id, bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* result) const {
+ std::unique_ptr<UniLib::RegexMatcher> matcher =
+ rule.compiled_regex->Matcher(input);
+ int status = UniLib::RegexMatcher::kNoError;
+ if (anchor_start_end) {
+ if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
+ if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, locale_id, result)) {
+ return false;
+ }
+ }
+ } else {
+ while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+ if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, locale_id, result)) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
constexpr char const* kDefaultLocale = "";
-std::vector<int> DatetimeParser::ParseLocales(
+std::vector<int> DatetimeParser::ParseAndExpandLocales(
const std::string& locales) const {
- std::vector<std::string> split_locales = strings::Split(locales, ',');
-
- // Add a default fallback locale to the end of the list.
- split_locales.push_back(kDefaultLocale);
+ std::vector<StringPiece> split_locales = strings::Split(locales, ',');
std::vector<int> result;
- for (const std::string& locale : split_locales) {
- auto locale_it = locale_string_to_id_.find(locale);
- if (locale_it == locale_string_to_id_.end()) {
- TC_LOG(INFO) << "Ignoring locale: " << locale;
+ for (const StringPiece& locale_str : split_locales) {
+ auto locale_it = locale_string_to_id_.find(locale_str.ToString());
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+
+ const Locale locale = Locale::FromBCP47(locale_str.ToString());
+ if (!locale.IsValid()) {
continue;
}
- result.push_back(locale_it->second);
+
+ const std::string language = locale.Language();
+ const std::string script = locale.Script();
+ const std::string region = locale.Region();
+
+ // First, try adding language-script-* locale.
+ if (!script.empty()) {
+ locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
+ // Second, try adding language-* locale.
+ if (!language.empty()) {
+ locale_it = locale_string_to_id_.find(language + "-*");
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
+
+ // Second, try adding *-region locale.
+ if (!region.empty()) {
+ locale_it = locale_string_to_id_.find("*-" + region);
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
}
+
+ // Add a default fallback locale to the end of the list.
+ auto locale_it = locale_string_to_id_.find(kDefaultLocale);
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ } else {
+ TC_VLOG(1) << "Could not add default locale.";
+ }
+
return result;
}
@@ -250,13 +347,15 @@
} // namespace
-bool DatetimeParser::ExtractDatetime(const UniLib::RegexMatcher& matcher,
+bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
const int64 reference_time_ms_utc,
const std::string& reference_timezone,
int locale_id, DatetimeParseResult* result,
CodepointSpan* result_span) const {
DateParseData parse;
- DatetimeExtractor extractor(matcher, locale_id, unilib_, extractor_rules_,
+ DatetimeExtractor extractor(rule, matcher, locale_id, unilib_,
+ extractor_rules_,
type_and_locale_to_extractor_rule_);
if (!extractor.Extract(&parse, result_span)) {
return false;
diff --git a/datetime/parser.h b/datetime/parser.h
index 9f31142..c9d2119 100644
--- a/datetime/parser.h
+++ b/datetime/parser.h
@@ -22,11 +22,13 @@
#include <unordered_map>
#include <vector>
+#include "datetime/extractor.h"
#include "model_generated.h"
#include "types.h"
#include "util/base/integral_types.h"
#include "util/calendar/calendar.h"
#include "util/utf8/unilib.h"
+#include "zlib-utils.h"
namespace libtextclassifier2 {
@@ -34,46 +36,58 @@
// time.
class DatetimeParser {
public:
- static std::unique_ptr<DatetimeParser> Instance(const DatetimeModel* model,
- const UniLib& unilib);
+ static std::unique_ptr<DatetimeParser> Instance(
+ const DatetimeModel* model, const UniLib& unilib,
+ ZlibDecompressor* decompressor);
// Parses the dates in 'input' and fills result. Makes sure that the results
// do not overlap.
+ // If 'anchor_start_end' is true the extracted results need to start at the
+ // beginning of 'input' and end at the end of it.
bool Parse(const std::string& input, int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode,
+ ModeFlag mode, bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const;
// Same as above but takes UnicodeText.
bool Parse(const UnicodeText& input, int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode,
+ ModeFlag mode, bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const;
protected:
- DatetimeParser(const DatetimeModel* model, const UniLib& unilib);
+ DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
+ ZlibDecompressor* decompressor);
// Returns a list of locale ids for given locale spec string (comma-separated
// locale names).
- std::vector<int> ParseLocales(const std::string& locales) const;
- bool ParseWithRule(const UniLib::RegexPattern& regex,
- const DatetimeModelPattern* pattern,
- const UnicodeText& input, int64 reference_time_ms_utc,
+ std::vector<int> ParseAndExpandLocales(const std::string& locales) const;
+
+ bool ParseWithRule(const CompiledRule& rule, const UnicodeText& input,
+ int64 reference_time_ms_utc,
const std::string& reference_timezone, const int locale_id,
+ bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* result) const;
// Converts the current match in 'matcher' into DatetimeParseResult.
- bool ExtractDatetime(const UniLib::RegexMatcher& matcher,
+ bool ExtractDatetime(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
int64 reference_time_ms_utc,
const std::string& reference_timezone, int locale_id,
DatetimeParseResult* result,
CodepointSpan* result_span) const;
+ // Parse and extract information from current match in 'matcher'.
+ bool HandleParseMatch(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone, int locale_id,
+ std::vector<DatetimeParseResultSpan>* result) const;
+
private:
bool initialized_;
const UniLib& unilib_;
- std::vector<const DatetimeModelPattern*> rule_id_to_pattern_;
- std::vector<std::unique_ptr<const UniLib::RegexPattern>> rules_;
+ std::vector<CompiledRule> rules_;
std::unordered_map<int, std::vector<int>> locale_to_rules_;
std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_;
std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
diff --git a/datetime/parser_test.cc b/datetime/parser_test.cc
index 1df959f..36525e2 100644
--- a/datetime/parser_test.cc
+++ b/datetime/parser_test.cc
@@ -25,6 +25,7 @@
#include "datetime/parser.h"
#include "model_generated.h"
+#include "text-classifier.h"
#include "types-test-util.h"
using testing::ElementsAreArray;
@@ -54,15 +55,27 @@
public:
void SetUp() override {
model_buffer_ = ReadFile(GetModelPath() + "test_model.fb");
- const Model* model = GetModel(model_buffer_.data());
- ASSERT_TRUE(model != nullptr);
- ASSERT_TRUE(model->datetime_model() != nullptr);
- parser_ = DatetimeParser::Instance(model->datetime_model(), unilib_);
+ classifier_ = TextClassifier::FromUnownedBuffer(
+ model_buffer_.data(), model_buffer_.size(), &unilib_);
+ TC_CHECK(classifier_);
+ parser_ = classifier_->DatetimeParserForTests();
+ }
+
+ bool HasNoResult(const std::string& text, bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich") {
+ std::vector<DatetimeParseResultSpan> results;
+ if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
+ anchor_start_end, &results)) {
+ TC_LOG(ERROR) << text;
+ TC_CHECK(false);
+ }
+ return results.empty();
}
bool ParsesCorrectly(const std::string& marked_text,
const int64 expected_ms_utc,
DatetimeGranularity expected_granularity,
+ bool anchor_start_end = false,
const std::string& timezone = "Europe/Zurich") {
auto expected_start_index = marked_text.find("{");
EXPECT_TRUE(expected_start_index != std::string::npos);
@@ -80,7 +93,7 @@
std::vector<DatetimeParseResultSpan> results;
if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
- &results)) {
+ anchor_start_end, &results)) {
TC_LOG(ERROR) << text;
TC_CHECK(false);
}
@@ -115,7 +128,8 @@
protected:
std::string model_buffer_;
- std::unique_ptr<DatetimeParser> parser_;
+ std::unique_ptr<TextClassifier> classifier_;
+ const DatetimeParser* parser_;
UniLib unilib_;
};
@@ -129,7 +143,7 @@
TEST_F(ParserTest, Parse) {
EXPECT_TRUE(
ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectly("{1 2 2018}", 1517439600000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectly("{1 2 2018}", 1514847600000, GRANULARITY_DAY));
EXPECT_TRUE(
ParsesCorrectly("{january 31 2018}", 1517353200000, GRANULARITY_DAY));
EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
@@ -205,6 +219,7 @@
EXPECT_TRUE(ParsesCorrectly("{today}", -3600000, GRANULARITY_DAY));
EXPECT_TRUE(ParsesCorrectly("{today}", -57600000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
"America/Los_Angeles"));
EXPECT_TRUE(ParsesCorrectly("{next week}", 255600000, GRANULARITY_WEEK));
EXPECT_TRUE(ParsesCorrectly("{next day}", 82800000, GRANULARITY_DAY));
@@ -224,7 +239,103 @@
EXPECT_TRUE(ParsesCorrectly("{three days ago}", -262800000, GRANULARITY_DAY));
}
-// TODO(zilka): Add a test that tests multiple locales.
+TEST_F(ParserTest, ParseWithAnchor) {
+ EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000,
+ GRANULARITY_DAY, /*anchor_start_end=*/false));
+ EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000,
+ GRANULARITY_DAY, /*anchor_start_end=*/true));
+ EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
+ GRANULARITY_DAY, /*anchor_start_end=*/false));
+ EXPECT_TRUE(HasNoResult("lorem 1 january 2018 ipsum",
+ /*anchor_start_end=*/true));
+}
+
+class ParserLocaleTest : public testing::Test {
+ public:
+ void SetUp() override;
+ bool HasResult(const std::string& input, const std::string& locales);
+
+ protected:
+ UniLib unilib_;
+ flatbuffers::FlatBufferBuilder builder_;
+ std::unique_ptr<DatetimeParser> parser_;
+};
+
+void AddPattern(const std::string& regex, int locale,
+ std::vector<std::unique_ptr<DatetimeModelPatternT>>* patterns) {
+ patterns->emplace_back(new DatetimeModelPatternT);
+ patterns->back()->regexes.emplace_back(new DatetimeModelPattern_::RegexT);
+ patterns->back()->regexes.back()->pattern = regex;
+ patterns->back()->regexes.back()->groups.push_back(
+ DatetimeGroupType_GROUP_UNUSED);
+ patterns->back()->locales.push_back(locale);
+}
+
+void ParserLocaleTest::SetUp() {
+ DatetimeModelT model;
+ model.use_extractors_for_locating = false;
+ model.locales.clear();
+ model.locales.push_back("en-US");
+ model.locales.push_back("en-CH");
+ model.locales.push_back("zh-Hant");
+ model.locales.push_back("en-*");
+ model.locales.push_back("zh-Hant-*");
+ model.locales.push_back("*-CH");
+ model.locales.push_back("");
+
+ AddPattern(/*regex=*/"en-US", /*locale=*/0, &model.patterns);
+ AddPattern(/*regex=*/"en-CH", /*locale=*/1, &model.patterns);
+ AddPattern(/*regex=*/"zh-Hant", /*locale=*/2, &model.patterns);
+ AddPattern(/*regex=*/"en-all", /*locale=*/3, &model.patterns);
+ AddPattern(/*regex=*/"zh-Hant-all", /*locale=*/4, &model.patterns);
+ AddPattern(/*regex=*/"all-CH", /*locale=*/5, &model.patterns);
+ AddPattern(/*regex=*/"default", /*locale=*/6, &model.patterns);
+
+ builder_.Finish(DatetimeModel::Pack(builder_, &model));
+ const DatetimeModel* model_fb =
+ flatbuffers::GetRoot<DatetimeModel>(builder_.GetBufferPointer());
+ ASSERT_TRUE(model_fb);
+
+ parser_ = DatetimeParser::Instance(model_fb, unilib_,
+ /*decompressor=*/nullptr);
+ ASSERT_TRUE(parser_);
+}
+
+bool ParserLocaleTest::HasResult(const std::string& input,
+ const std::string& locales) {
+ std::vector<DatetimeParseResultSpan> results;
+ EXPECT_TRUE(parser_->Parse(input, /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"", locales,
+ ModeFlag_ANNOTATION, false, &results));
+ return results.size() == 1;
+}
+
+TEST_F(ParserLocaleTest, English) {
+ EXPECT_TRUE(HasResult("en-US", /*locales=*/"en-US"));
+ EXPECT_FALSE(HasResult("en-CH", /*locales=*/"en-US"));
+ EXPECT_FALSE(HasResult("en-US", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("en-CH", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
+}
+
+TEST_F(ParserLocaleTest, TraditionalChinese) {
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant"));
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-TW"));
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-SG"));
+ EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh-SG"));
+ EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"zh"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"zh-Hant-SG"));
+}
+
+TEST_F(ParserLocaleTest, SwissEnglish) {
+ EXPECT_TRUE(HasResult("all-CH", /*locales=*/"de-CH"));
+ EXPECT_TRUE(HasResult("all-CH", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("en-all", /*locales=*/"en-CH"));
+ EXPECT_FALSE(HasResult("all-CH", /*locales=*/"de-DE"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"de-CH"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
+}
} // namespace
} // namespace libtextclassifier2
diff --git a/feature-processor.cc b/feature-processor.cc
index e4df94f..aa71740 100644
--- a/feature-processor.cc
+++ b/feature-processor.cc
@@ -166,6 +166,7 @@
std::string FeatureProcessor::GetDefaultCollection() const {
if (options_->default_collection() < 0 ||
+ options_->collections() == nullptr ||
options_->default_collection() >= options_->collections()->size()) {
TC_LOG(ERROR)
<< "Invalid or missing default collection. Returning empty string.";
diff --git a/model-executor.cc b/model-executor.cc
index c79056f..69931cb 100644
--- a/model-executor.cc
+++ b/model-executor.cc
@@ -38,44 +38,70 @@
return interpreter;
}
-TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor(
- const tflite::Model* model_spec, const int embedding_size,
- const int quantization_bits)
- : quantization_bits_(quantization_bits),
- output_embedding_size_(embedding_size) {
- internal::FromModelSpec(model_spec, &model_);
- tflite::InterpreterBuilder(*model_, builtins_)(&interpreter_);
- if (!interpreter_) {
+std::unique_ptr<TFLiteEmbeddingExecutor> TFLiteEmbeddingExecutor::Instance(
+ const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
+ int quantization_bits) {
+ const tflite::Model* model_spec =
+ flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data());
+ flatbuffers::Verifier verifier(model_spec_buffer->data(),
+ model_spec_buffer->Length());
+ std::unique_ptr<const tflite::FlatBufferModel> model;
+ if (!model_spec->Verify(verifier) ||
+ !internal::FromModelSpec(model_spec, &model)) {
+ TC_LOG(ERROR) << "Could not load TFLite model.";
+ return nullptr;
+ }
+
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ tflite::ops::builtin::BuiltinOpResolver builtins;
+ tflite::InterpreterBuilder(*model, builtins)(&interpreter);
+ if (!interpreter) {
TC_LOG(ERROR) << "Could not build TFLite interpreter for embeddings.";
- return;
+ return nullptr;
}
- if (interpreter_->tensors_size() != 2) {
- return;
+ if (interpreter->tensors_size() != 2) {
+ return nullptr;
}
- embeddings_ = interpreter_->tensor(0);
- if (embeddings_->dims->size != 2) {
- return;
+ const TfLiteTensor* embeddings = interpreter->tensor(0);
+ if (embeddings->dims->size != 2) {
+ return nullptr;
}
- num_buckets_ = embeddings_->dims->data[0];
- scales_ = interpreter_->tensor(1);
- if (scales_->dims->size != 2 || scales_->dims->data[0] != num_buckets_ ||
- scales_->dims->data[1] != 1) {
- return;
+ int num_buckets = embeddings->dims->data[0];
+ const TfLiteTensor* scales = interpreter->tensor(1);
+ if (scales->dims->size != 2 || scales->dims->data[0] != num_buckets ||
+ scales->dims->data[1] != 1) {
+ return nullptr;
}
- bytes_per_embedding_ = embeddings_->dims->data[1];
- if (!CheckQuantizationParams(bytes_per_embedding_, quantization_bits_,
- output_embedding_size_)) {
+ int bytes_per_embedding = embeddings->dims->data[1];
+ if (!CheckQuantizationParams(bytes_per_embedding, quantization_bits,
+ embedding_size)) {
TC_LOG(ERROR) << "Mismatch in quantization parameters.";
- return;
+ return nullptr;
}
- initialized_ = true;
+ return std::unique_ptr<TFLiteEmbeddingExecutor>(new TFLiteEmbeddingExecutor(
+ std::move(model), quantization_bits, num_buckets, bytes_per_embedding,
+ embedding_size, scales, embeddings, std::move(interpreter)));
}
+TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor(
+ std::unique_ptr<const tflite::FlatBufferModel> model, int quantization_bits,
+ int num_buckets, int bytes_per_embedding, int output_embedding_size,
+ const TfLiteTensor* scales, const TfLiteTensor* embeddings,
+ std::unique_ptr<tflite::Interpreter> interpreter)
+ : model_(std::move(model)),
+ quantization_bits_(quantization_bits),
+ num_buckets_(num_buckets),
+ bytes_per_embedding_(bytes_per_embedding),
+ output_embedding_size_(output_embedding_size),
+ scales_(scales),
+ embeddings_(embeddings),
+ interpreter_(std::move(interpreter)) {}
+
bool TFLiteEmbeddingExecutor::AddEmbedding(
const TensorView<int>& sparse_features, float* dest, int dest_size) const {
- if (!initialized_ || dest_size != output_embedding_size_) {
+ if (dest_size != output_embedding_size_) {
TC_LOG(ERROR) << "Mismatching dest_size and output_embedding_size: "
<< dest_size << " " << output_embedding_size_;
return false;
diff --git a/model-executor.h b/model-executor.h
index 547d596..ef6d36f 100644
--- a/model-executor.h
+++ b/model-executor.h
@@ -45,8 +45,25 @@
// Executor for the text selection prediction and classification models.
class ModelExecutor {
public:
- explicit ModelExecutor(const tflite::Model* model_spec) {
- internal::FromModelSpec(model_spec, &model_);
+ static std::unique_ptr<const ModelExecutor> Instance(
+ const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
+ const tflite::Model* model =
+ flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data());
+ flatbuffers::Verifier verifier(model_spec_buffer->data(),
+ model_spec_buffer->Length());
+ if (!model->Verify(verifier)) {
+ return nullptr;
+ }
+ return Instance(model);
+ }
+
+ static std::unique_ptr<const ModelExecutor> Instance(
+ const tflite::Model* model_spec) {
+ std::unique_ptr<const tflite::FlatBufferModel> model;
+ if (!internal::FromModelSpec(model_spec, &model)) {
+ return nullptr;
+ }
+ return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
}
// Creates an Interpreter for the model that serves as a scratch-pad for the
@@ -60,10 +77,13 @@
}
protected:
+ explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)
+ : model_(std::move(model)) {}
+
static const int kInputIndexFeatures = 0;
static const int kOutputIndexLogits = 0;
- std::unique_ptr<const tflite::FlatBufferModel> model_ = nullptr;
+ std::unique_ptr<const tflite::FlatBufferModel> model_;
tflite::ops::builtin::BuiltinOpResolver builtins_;
};
@@ -83,27 +103,33 @@
class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
public:
- explicit TFLiteEmbeddingExecutor(const tflite::Model* model_spec,
- int embedding_size, int quantization_bits);
+ static std::unique_ptr<TFLiteEmbeddingExecutor> Instance(
+ const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
+ int quantization_bits);
+
bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
int dest_size) const override;
- bool IsReady() const override { return initialized_; }
-
protected:
+ explicit TFLiteEmbeddingExecutor(
+ std::unique_ptr<const tflite::FlatBufferModel> model,
+ int quantization_bits, int num_buckets, int bytes_per_embedding,
+ int output_embedding_size, const TfLiteTensor* scales,
+ const TfLiteTensor* embeddings,
+ std::unique_ptr<tflite::Interpreter> interpreter);
+
+ std::unique_ptr<const tflite::FlatBufferModel> model_;
+
int quantization_bits_;
- bool initialized_ = false;
int num_buckets_ = -1;
int bytes_per_embedding_ = -1;
int output_embedding_size_ = -1;
const TfLiteTensor* scales_ = nullptr;
const TfLiteTensor* embeddings_ = nullptr;
- std::unique_ptr<const tflite::FlatBufferModel> model_ = nullptr;
// NOTE: This interpreter is used in a read-only way (as a storage for the
// model params), thus is still thread-safe.
- std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
- tflite::ops::builtin::BuiltinOpResolver builtins_;
+ std::unique_ptr<tflite::Interpreter> interpreter_;
};
} // namespace libtextclassifier2
diff --git a/model.fbs b/model.fbs
index 590c815..23cf229 100755
--- a/model.fbs
+++ b/model.fbs
@@ -106,6 +106,35 @@
THOUSAND = 72,
}
+namespace libtextclassifier2;
+enum DatetimeGroupType : int {
+ GROUP_UNKNOWN = 0,
+ GROUP_UNUSED = 1,
+ GROUP_YEAR = 2,
+ GROUP_MONTH = 3,
+ GROUP_DAY = 4,
+ GROUP_HOUR = 5,
+ GROUP_MINUTE = 6,
+ GROUP_SECOND = 7,
+ GROUP_AMPM = 8,
+ GROUP_RELATIONDISTANCE = 9,
+ GROUP_RELATION = 10,
+ GROUP_RELATIONTYPE = 11,
+
+ // Dummy groups serve just as an inflator of the selection. E.g. we might want
+ // to select more text than was contained in an envelope of all extractor
+ // spans.
+ GROUP_DUMMY1 = 12,
+
+ GROUP_DUMMY2 = 13,
+}
+
+namespace libtextclassifier2;
+table CompressedBuffer {
+ buffer:[ubyte];
+ uncompressed_size:int;
+}
+
// Options for the model that predicts text selection.
namespace libtextclassifier2;
table SelectionModelOptions {
@@ -121,6 +150,9 @@
// Number of examples to bundle in one batch for inference.
batch_size:int = 1024;
+
+ // Whether to always classify a suggested selection or only on demand.
+ always_classify_suggested_selection:bool = 0;
}
// Options for the model that classifies a text selection.
@@ -130,6 +162,9 @@
phone_min_num_digits:int = 7;
phone_max_num_digits:int = 15;
+
+ // Limits for addresses.
+ address_min_num_tokens:int;
}
// List of regular expression matchers to check.
@@ -155,6 +190,8 @@
// using Find() instead of the true Match(). This approximate matching will
// use the first Find() result and then check that it spans the whole input.
use_approximate_matching:bool = 0;
+
+ compressed_pattern:libtextclassifier2.CompressedBuffer;
}
namespace libtextclassifier2;
@@ -162,10 +199,21 @@
patterns:[libtextclassifier2.RegexModel_.Pattern];
}
+// List of regex patterns.
+namespace libtextclassifier2.DatetimeModelPattern_;
+table Regex {
+ pattern:string;
+
+ // The ith entry specifies the type of the ith capturing group.
+ // This is used to decide how the matched content has to be parsed.
+ groups:[libtextclassifier2.DatetimeGroupType];
+
+ compressed_pattern:libtextclassifier2.CompressedBuffer;
+}
+
namespace libtextclassifier2;
table DatetimeModelPattern {
- // List of regex patterns.
- regexes:[string];
+ regexes:[libtextclassifier2.DatetimeModelPattern_.Regex];
// List of locale indices in DatetimeModel that represent the locales that
// these patterns should be used for. If empty, can be used for all locales.
@@ -186,6 +234,7 @@
extractor:libtextclassifier2.DatetimeExtractorType;
pattern:string;
locales:[int];
+ compressed_pattern:libtextclassifier2.CompressedBuffer;
}
namespace libtextclassifier2;
@@ -212,6 +261,20 @@
enabled_modes:libtextclassifier2.ModeFlag = ALL;
}
+// Options controlling the output of the classifier.
+namespace libtextclassifier2;
+table OutputOptions {
+ // Lists of collection names that will be filtered out at the output:
+ // - For annotation, the spans of given collection are simply dropped.
+ // - For classification, the result is mapped to the class "other".
+ // - For selection, the spans of given class are returned as
+ // single-selection.
+ filtered_collections_annotation:[string];
+
+ filtered_collections_classification:[string];
+ filtered_collections_selection:[string];
+}
+
namespace libtextclassifier2;
table Model {
// Comma-separated list of locales supported by the model as BCP 47 tags.
@@ -244,6 +307,15 @@
// Global switch that controls if SuggestSelection(), ClassifyText() and
// Annotate() will run. If a mode is disabled it returns empty/no-op results.
enabled_modes:libtextclassifier2.ModeFlag = ALL;
+
+ // If true, will snap the selections that consist only of whitespaces to the
+ // containing suggested span. Otherwise, no suggestion is proposed, since the
+ // selections are not part of any token.
+ snap_whitespace_selections:bool = 1;
+
+ // Global configuration for the output of SuggestSelection(), ClassifyText()
+ // and Annotate().
+ output_options:libtextclassifier2.OutputOptions;
}
// Role of the codepoints in the range.
@@ -409,7 +481,7 @@
// If true, the selection classifier output will contain only the selections
// that are feasible (e.g., those that are shorter than max_selection_span),
// if false, the output will be a complete cross-product of possible
- // selections to the left and posible selections to the right, including the
+ // selections to the left and possible selections to the right, including the
// infeasible ones.
// NOTE: Exists mainly for compatibility with older models that were trained
// with the non-reduced output space.
diff --git a/model_generated.h b/model_generated.h
index 21c1b85..ecf08fc 100755
--- a/model_generated.h
+++ b/model_generated.h
@@ -24,6 +24,9 @@
namespace libtextclassifier2 {
+struct CompressedBuffer;
+struct CompressedBufferT;
+
struct SelectionModelOptions;
struct SelectionModelOptionsT;
@@ -40,6 +43,13 @@
struct RegexModel;
struct RegexModelT;
+namespace DatetimeModelPattern_ {
+
+struct Regex;
+struct RegexT;
+
+} // namespace DatetimeModelPattern_
+
struct DatetimeModelPattern;
struct DatetimeModelPatternT;
@@ -52,6 +62,9 @@
struct ModelTriggeringOptions;
struct ModelTriggeringOptionsT;
+struct OutputOptions;
+struct OutputOptionsT;
+
struct Model;
struct ModelT;
@@ -363,6 +376,71 @@
return EnumNamesDatetimeExtractorType()[index];
}
+enum DatetimeGroupType {
+ DatetimeGroupType_GROUP_UNKNOWN = 0,
+ DatetimeGroupType_GROUP_UNUSED = 1,
+ DatetimeGroupType_GROUP_YEAR = 2,
+ DatetimeGroupType_GROUP_MONTH = 3,
+ DatetimeGroupType_GROUP_DAY = 4,
+ DatetimeGroupType_GROUP_HOUR = 5,
+ DatetimeGroupType_GROUP_MINUTE = 6,
+ DatetimeGroupType_GROUP_SECOND = 7,
+ DatetimeGroupType_GROUP_AMPM = 8,
+ DatetimeGroupType_GROUP_RELATIONDISTANCE = 9,
+ DatetimeGroupType_GROUP_RELATION = 10,
+ DatetimeGroupType_GROUP_RELATIONTYPE = 11,
+ DatetimeGroupType_GROUP_DUMMY1 = 12,
+ DatetimeGroupType_GROUP_DUMMY2 = 13,
+ DatetimeGroupType_MIN = DatetimeGroupType_GROUP_UNKNOWN,
+ DatetimeGroupType_MAX = DatetimeGroupType_GROUP_DUMMY2
+};
+
+inline DatetimeGroupType (&EnumValuesDatetimeGroupType())[14] {
+ static DatetimeGroupType values[] = {
+ DatetimeGroupType_GROUP_UNKNOWN,
+ DatetimeGroupType_GROUP_UNUSED,
+ DatetimeGroupType_GROUP_YEAR,
+ DatetimeGroupType_GROUP_MONTH,
+ DatetimeGroupType_GROUP_DAY,
+ DatetimeGroupType_GROUP_HOUR,
+ DatetimeGroupType_GROUP_MINUTE,
+ DatetimeGroupType_GROUP_SECOND,
+ DatetimeGroupType_GROUP_AMPM,
+ DatetimeGroupType_GROUP_RELATIONDISTANCE,
+ DatetimeGroupType_GROUP_RELATION,
+ DatetimeGroupType_GROUP_RELATIONTYPE,
+ DatetimeGroupType_GROUP_DUMMY1,
+ DatetimeGroupType_GROUP_DUMMY2
+ };
+ return values;
+}
+
+inline const char **EnumNamesDatetimeGroupType() {
+ static const char *names[] = {
+ "GROUP_UNKNOWN",
+ "GROUP_UNUSED",
+ "GROUP_YEAR",
+ "GROUP_MONTH",
+ "GROUP_DAY",
+ "GROUP_HOUR",
+ "GROUP_MINUTE",
+ "GROUP_SECOND",
+ "GROUP_AMPM",
+ "GROUP_RELATIONDISTANCE",
+ "GROUP_RELATION",
+ "GROUP_RELATIONTYPE",
+ "GROUP_DUMMY1",
+ "GROUP_DUMMY2",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameDatetimeGroupType(DatetimeGroupType e) {
+ const size_t index = static_cast<int>(e);
+ return EnumNamesDatetimeGroupType()[index];
+}
+
namespace TokenizationCodepointRange_ {
enum Role {
@@ -481,15 +559,93 @@
} // namespace FeatureProcessorOptions_
+struct CompressedBufferT : public flatbuffers::NativeTable {
+ typedef CompressedBuffer TableType;
+ std::vector<uint8_t> buffer;
+ int32_t uncompressed_size;
+ CompressedBufferT()
+ : uncompressed_size(0) {
+ }
+};
+
+struct CompressedBuffer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef CompressedBufferT NativeTableType;
+ enum {
+ VT_BUFFER = 4,
+ VT_UNCOMPRESSED_SIZE = 6
+ };
+ const flatbuffers::Vector<uint8_t> *buffer() const {
+ return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_BUFFER);
+ }
+ int32_t uncompressed_size() const {
+ return GetField<int32_t>(VT_UNCOMPRESSED_SIZE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_BUFFER) &&
+ verifier.Verify(buffer()) &&
+ VerifyField<int32_t>(verifier, VT_UNCOMPRESSED_SIZE) &&
+ verifier.EndTable();
+ }
+ CompressedBufferT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(CompressedBufferT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<CompressedBuffer> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CompressedBufferT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct CompressedBufferBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_buffer(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> buffer) {
+ fbb_.AddOffset(CompressedBuffer::VT_BUFFER, buffer);
+ }
+ void add_uncompressed_size(int32_t uncompressed_size) {
+ fbb_.AddElement<int32_t>(CompressedBuffer::VT_UNCOMPRESSED_SIZE, uncompressed_size, 0);
+ }
+ explicit CompressedBufferBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CompressedBufferBuilder &operator=(const CompressedBufferBuilder &);
+ flatbuffers::Offset<CompressedBuffer> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<CompressedBuffer>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<CompressedBuffer> CreateCompressedBuffer(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> buffer = 0,
+ int32_t uncompressed_size = 0) {
+ CompressedBufferBuilder builder_(_fbb);
+ builder_.add_uncompressed_size(uncompressed_size);
+ builder_.add_buffer(buffer);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<CompressedBuffer> CreateCompressedBufferDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<uint8_t> *buffer = nullptr,
+ int32_t uncompressed_size = 0) {
+ return libtextclassifier2::CreateCompressedBuffer(
+ _fbb,
+ buffer ? _fbb.CreateVector<uint8_t>(*buffer) : 0,
+ uncompressed_size);
+}
+
+flatbuffers::Offset<CompressedBuffer> CreateCompressedBuffer(flatbuffers::FlatBufferBuilder &_fbb, const CompressedBufferT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct SelectionModelOptionsT : public flatbuffers::NativeTable {
typedef SelectionModelOptions TableType;
bool strip_unpaired_brackets;
int32_t symmetry_context_size;
int32_t batch_size;
+ bool always_classify_suggested_selection;
SelectionModelOptionsT()
: strip_unpaired_brackets(true),
symmetry_context_size(0),
- batch_size(1024) {
+ batch_size(1024),
+ always_classify_suggested_selection(false) {
}
};
@@ -498,7 +654,8 @@
enum {
VT_STRIP_UNPAIRED_BRACKETS = 4,
VT_SYMMETRY_CONTEXT_SIZE = 6,
- VT_BATCH_SIZE = 8
+ VT_BATCH_SIZE = 8,
+ VT_ALWAYS_CLASSIFY_SUGGESTED_SELECTION = 10
};
bool strip_unpaired_brackets() const {
return GetField<uint8_t>(VT_STRIP_UNPAIRED_BRACKETS, 1) != 0;
@@ -509,11 +666,15 @@
int32_t batch_size() const {
return GetField<int32_t>(VT_BATCH_SIZE, 1024);
}
+ bool always_classify_suggested_selection() const {
+ return GetField<uint8_t>(VT_ALWAYS_CLASSIFY_SUGGESTED_SELECTION, 0) != 0;
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, VT_STRIP_UNPAIRED_BRACKETS) &&
VerifyField<int32_t>(verifier, VT_SYMMETRY_CONTEXT_SIZE) &&
VerifyField<int32_t>(verifier, VT_BATCH_SIZE) &&
+ VerifyField<uint8_t>(verifier, VT_ALWAYS_CLASSIFY_SUGGESTED_SELECTION) &&
verifier.EndTable();
}
SelectionModelOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -533,6 +694,9 @@
void add_batch_size(int32_t batch_size) {
fbb_.AddElement<int32_t>(SelectionModelOptions::VT_BATCH_SIZE, batch_size, 1024);
}
+ void add_always_classify_suggested_selection(bool always_classify_suggested_selection) {
+ fbb_.AddElement<uint8_t>(SelectionModelOptions::VT_ALWAYS_CLASSIFY_SUGGESTED_SELECTION, static_cast<uint8_t>(always_classify_suggested_selection), 0);
+ }
explicit SelectionModelOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -549,10 +713,12 @@
flatbuffers::FlatBufferBuilder &_fbb,
bool strip_unpaired_brackets = true,
int32_t symmetry_context_size = 0,
- int32_t batch_size = 1024) {
+ int32_t batch_size = 1024,
+ bool always_classify_suggested_selection = false) {
SelectionModelOptionsBuilder builder_(_fbb);
builder_.add_batch_size(batch_size);
builder_.add_symmetry_context_size(symmetry_context_size);
+ builder_.add_always_classify_suggested_selection(always_classify_suggested_selection);
builder_.add_strip_unpaired_brackets(strip_unpaired_brackets);
return builder_.Finish();
}
@@ -563,9 +729,11 @@
typedef ClassificationModelOptions TableType;
int32_t phone_min_num_digits;
int32_t phone_max_num_digits;
+ int32_t address_min_num_tokens;
ClassificationModelOptionsT()
: phone_min_num_digits(7),
- phone_max_num_digits(15) {
+ phone_max_num_digits(15),
+ address_min_num_tokens(0) {
}
};
@@ -573,7 +741,8 @@
typedef ClassificationModelOptionsT NativeTableType;
enum {
VT_PHONE_MIN_NUM_DIGITS = 4,
- VT_PHONE_MAX_NUM_DIGITS = 6
+ VT_PHONE_MAX_NUM_DIGITS = 6,
+ VT_ADDRESS_MIN_NUM_TOKENS = 8
};
int32_t phone_min_num_digits() const {
return GetField<int32_t>(VT_PHONE_MIN_NUM_DIGITS, 7);
@@ -581,10 +750,14 @@
int32_t phone_max_num_digits() const {
return GetField<int32_t>(VT_PHONE_MAX_NUM_DIGITS, 15);
}
+ int32_t address_min_num_tokens() const {
+ return GetField<int32_t>(VT_ADDRESS_MIN_NUM_TOKENS, 0);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_PHONE_MIN_NUM_DIGITS) &&
VerifyField<int32_t>(verifier, VT_PHONE_MAX_NUM_DIGITS) &&
+ VerifyField<int32_t>(verifier, VT_ADDRESS_MIN_NUM_TOKENS) &&
verifier.EndTable();
}
ClassificationModelOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -601,6 +774,9 @@
void add_phone_max_num_digits(int32_t phone_max_num_digits) {
fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_PHONE_MAX_NUM_DIGITS, phone_max_num_digits, 15);
}
+ void add_address_min_num_tokens(int32_t address_min_num_tokens) {
+ fbb_.AddElement<int32_t>(ClassificationModelOptions::VT_ADDRESS_MIN_NUM_TOKENS, address_min_num_tokens, 0);
+ }
explicit ClassificationModelOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -616,8 +792,10 @@
inline flatbuffers::Offset<ClassificationModelOptions> CreateClassificationModelOptions(
flatbuffers::FlatBufferBuilder &_fbb,
int32_t phone_min_num_digits = 7,
- int32_t phone_max_num_digits = 15) {
+ int32_t phone_max_num_digits = 15,
+ int32_t address_min_num_tokens = 0) {
ClassificationModelOptionsBuilder builder_(_fbb);
+ builder_.add_address_min_num_tokens(address_min_num_tokens);
builder_.add_phone_max_num_digits(phone_max_num_digits);
builder_.add_phone_min_num_digits(phone_min_num_digits);
return builder_.Finish();
@@ -635,6 +813,7 @@
float target_classification_score;
float priority_score;
bool use_approximate_matching;
+ std::unique_ptr<libtextclassifier2::CompressedBufferT> compressed_pattern;
PatternT()
: enabled_modes(libtextclassifier2::ModeFlag_ALL),
target_classification_score(1.0f),
@@ -651,7 +830,8 @@
VT_ENABLED_MODES = 8,
VT_TARGET_CLASSIFICATION_SCORE = 10,
VT_PRIORITY_SCORE = 12,
- VT_USE_APPROXIMATE_MATCHING = 14
+ VT_USE_APPROXIMATE_MATCHING = 14,
+ VT_COMPRESSED_PATTERN = 16
};
const flatbuffers::String *collection_name() const {
return GetPointer<const flatbuffers::String *>(VT_COLLECTION_NAME);
@@ -671,6 +851,9 @@
bool use_approximate_matching() const {
return GetField<uint8_t>(VT_USE_APPROXIMATE_MATCHING, 0) != 0;
}
+ const libtextclassifier2::CompressedBuffer *compressed_pattern() const {
+ return GetPointer<const libtextclassifier2::CompressedBuffer *>(VT_COMPRESSED_PATTERN);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_COLLECTION_NAME) &&
@@ -681,6 +864,8 @@
VerifyField<float>(verifier, VT_TARGET_CLASSIFICATION_SCORE) &&
VerifyField<float>(verifier, VT_PRIORITY_SCORE) &&
VerifyField<uint8_t>(verifier, VT_USE_APPROXIMATE_MATCHING) &&
+ VerifyOffset(verifier, VT_COMPRESSED_PATTERN) &&
+ verifier.VerifyTable(compressed_pattern()) &&
verifier.EndTable();
}
PatternT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -709,6 +894,9 @@
void add_use_approximate_matching(bool use_approximate_matching) {
fbb_.AddElement<uint8_t>(Pattern::VT_USE_APPROXIMATE_MATCHING, static_cast<uint8_t>(use_approximate_matching), 0);
}
+ void add_compressed_pattern(flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern) {
+ fbb_.AddOffset(Pattern::VT_COMPRESSED_PATTERN, compressed_pattern);
+ }
explicit PatternBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -728,8 +916,10 @@
libtextclassifier2::ModeFlag enabled_modes = libtextclassifier2::ModeFlag_ALL,
float target_classification_score = 1.0f,
float priority_score = 0.0f,
- bool use_approximate_matching = false) {
+ bool use_approximate_matching = false,
+ flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern = 0) {
PatternBuilder builder_(_fbb);
+ builder_.add_compressed_pattern(compressed_pattern);
builder_.add_priority_score(priority_score);
builder_.add_target_classification_score(target_classification_score);
builder_.add_enabled_modes(enabled_modes);
@@ -746,7 +936,8 @@
libtextclassifier2::ModeFlag enabled_modes = libtextclassifier2::ModeFlag_ALL,
float target_classification_score = 1.0f,
float priority_score = 0.0f,
- bool use_approximate_matching = false) {
+ bool use_approximate_matching = false,
+ flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern = 0) {
return libtextclassifier2::RegexModel_::CreatePattern(
_fbb,
collection_name ? _fbb.CreateString(collection_name) : 0,
@@ -754,7 +945,8 @@
enabled_modes,
target_classification_score,
priority_score,
- use_approximate_matching);
+ use_approximate_matching,
+ compressed_pattern);
}
flatbuffers::Offset<Pattern> CreatePattern(flatbuffers::FlatBufferBuilder &_fbb, const PatternT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -824,9 +1016,103 @@
flatbuffers::Offset<RegexModel> CreateRegexModel(flatbuffers::FlatBufferBuilder &_fbb, const RegexModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+namespace DatetimeModelPattern_ {
+
+struct RegexT : public flatbuffers::NativeTable {
+ typedef Regex TableType;
+ std::string pattern;
+ std::vector<libtextclassifier2::DatetimeGroupType> groups;
+ std::unique_ptr<libtextclassifier2::CompressedBufferT> compressed_pattern;
+ RegexT() {
+ }
+};
+
+struct Regex FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef RegexT NativeTableType;
+ enum {
+ VT_PATTERN = 4,
+ VT_GROUPS = 6,
+ VT_COMPRESSED_PATTERN = 8
+ };
+ const flatbuffers::String *pattern() const {
+ return GetPointer<const flatbuffers::String *>(VT_PATTERN);
+ }
+ const flatbuffers::Vector<int32_t> *groups() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_GROUPS);
+ }
+ const libtextclassifier2::CompressedBuffer *compressed_pattern() const {
+ return GetPointer<const libtextclassifier2::CompressedBuffer *>(VT_COMPRESSED_PATTERN);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_PATTERN) &&
+ verifier.Verify(pattern()) &&
+ VerifyOffset(verifier, VT_GROUPS) &&
+ verifier.Verify(groups()) &&
+ VerifyOffset(verifier, VT_COMPRESSED_PATTERN) &&
+ verifier.VerifyTable(compressed_pattern()) &&
+ verifier.EndTable();
+ }
+ RegexT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(RegexT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<Regex> Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct RegexBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_pattern(flatbuffers::Offset<flatbuffers::String> pattern) {
+ fbb_.AddOffset(Regex::VT_PATTERN, pattern);
+ }
+ void add_groups(flatbuffers::Offset<flatbuffers::Vector<int32_t>> groups) {
+ fbb_.AddOffset(Regex::VT_GROUPS, groups);
+ }
+ void add_compressed_pattern(flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern) {
+ fbb_.AddOffset(Regex::VT_COMPRESSED_PATTERN, compressed_pattern);
+ }
+ explicit RegexBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ RegexBuilder &operator=(const RegexBuilder &);
+ flatbuffers::Offset<Regex> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Regex>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Regex> CreateRegex(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> pattern = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> groups = 0,
+ flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern = 0) {
+ RegexBuilder builder_(_fbb);
+ builder_.add_compressed_pattern(compressed_pattern);
+ builder_.add_groups(groups);
+ builder_.add_pattern(pattern);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Regex> CreateRegexDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *pattern = nullptr,
+ const std::vector<int32_t> *groups = nullptr,
+ flatbuffers::Offset<libtextclassifier2::CompressedBuffer> compressed_pattern = 0) {
+ return libtextclassifier2::DatetimeModelPattern_::CreateRegex(
+ _fbb,
+ pattern ? _fbb.CreateString(pattern) : 0,
+ groups ? _fbb.CreateVector<int32_t>(*groups) : 0,
+ compressed_pattern);
+}
+
+flatbuffers::Offset<Regex> CreateRegex(flatbuffers::FlatBufferBuilder &_fbb, const RegexT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+} // namespace DatetimeModelPattern_
+
struct DatetimeModelPatternT : public flatbuffers::NativeTable {
typedef DatetimeModelPattern TableType;
- std::vector<std::string> regexes;
+ std::vector<std::unique_ptr<libtextclassifier2::DatetimeModelPattern_::RegexT>> regexes;
std::vector<int32_t> locales;
float target_classification_score;
float priority_score;
@@ -847,8 +1133,8 @@
VT_PRIORITY_SCORE = 10,
VT_ENABLED_MODES = 12
};
- const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *regexes() const {
- return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_REGEXES);
+ const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>> *regexes() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>> *>(VT_REGEXES);
}
const flatbuffers::Vector<int32_t> *locales() const {
return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_LOCALES);
@@ -866,7 +1152,7 @@
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_REGEXES) &&
verifier.Verify(regexes()) &&
- verifier.VerifyVectorOfStrings(regexes()) &&
+ verifier.VerifyVectorOfTables(regexes()) &&
VerifyOffset(verifier, VT_LOCALES) &&
verifier.Verify(locales()) &&
VerifyField<float>(verifier, VT_TARGET_CLASSIFICATION_SCORE) &&
@@ -882,7 +1168,7 @@
struct DatetimeModelPatternBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
- void add_regexes(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexes) {
+ void add_regexes(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>>> regexes) {
fbb_.AddOffset(DatetimeModelPattern::VT_REGEXES, regexes);
}
void add_locales(flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales) {
@@ -911,7 +1197,7 @@
inline flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPattern(
flatbuffers::FlatBufferBuilder &_fbb,
- flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> regexes = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>>> regexes = 0,
flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales = 0,
float target_classification_score = 1.0f,
float priority_score = 0.0f,
@@ -927,14 +1213,14 @@
inline flatbuffers::Offset<DatetimeModelPattern> CreateDatetimeModelPatternDirect(
flatbuffers::FlatBufferBuilder &_fbb,
- const std::vector<flatbuffers::Offset<flatbuffers::String>> *regexes = nullptr,
+ const std::vector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>> *regexes = nullptr,
const std::vector<int32_t> *locales = nullptr,
float target_classification_score = 1.0f,
float priority_score = 0.0f,
ModeFlag enabled_modes = ModeFlag_ALL) {
return libtextclassifier2::CreateDatetimeModelPattern(
_fbb,
- regexes ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*regexes) : 0,
+ regexes ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>>(*regexes) : 0,
locales ? _fbb.CreateVector<int32_t>(*locales) : 0,
target_classification_score,
priority_score,
@@ -948,6 +1234,7 @@
DatetimeExtractorType extractor;
std::string pattern;
std::vector<int32_t> locales;
+ std::unique_ptr<CompressedBufferT> compressed_pattern;
DatetimeModelExtractorT()
: extractor(DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE) {
}
@@ -958,7 +1245,8 @@
enum {
VT_EXTRACTOR = 4,
VT_PATTERN = 6,
- VT_LOCALES = 8
+ VT_LOCALES = 8,
+ VT_COMPRESSED_PATTERN = 10
};
DatetimeExtractorType extractor() const {
return static_cast<DatetimeExtractorType>(GetField<int32_t>(VT_EXTRACTOR, 0));
@@ -969,6 +1257,9 @@
const flatbuffers::Vector<int32_t> *locales() const {
return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_LOCALES);
}
+ const CompressedBuffer *compressed_pattern() const {
+ return GetPointer<const CompressedBuffer *>(VT_COMPRESSED_PATTERN);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_EXTRACTOR) &&
@@ -976,6 +1267,8 @@
verifier.Verify(pattern()) &&
VerifyOffset(verifier, VT_LOCALES) &&
verifier.Verify(locales()) &&
+ VerifyOffset(verifier, VT_COMPRESSED_PATTERN) &&
+ verifier.VerifyTable(compressed_pattern()) &&
verifier.EndTable();
}
DatetimeModelExtractorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -995,6 +1288,9 @@
void add_locales(flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales) {
fbb_.AddOffset(DatetimeModelExtractor::VT_LOCALES, locales);
}
+ void add_compressed_pattern(flatbuffers::Offset<CompressedBuffer> compressed_pattern) {
+ fbb_.AddOffset(DatetimeModelExtractor::VT_COMPRESSED_PATTERN, compressed_pattern);
+ }
explicit DatetimeModelExtractorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -1011,8 +1307,10 @@
flatbuffers::FlatBufferBuilder &_fbb,
DatetimeExtractorType extractor = DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE,
flatbuffers::Offset<flatbuffers::String> pattern = 0,
- flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales = 0) {
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> locales = 0,
+ flatbuffers::Offset<CompressedBuffer> compressed_pattern = 0) {
DatetimeModelExtractorBuilder builder_(_fbb);
+ builder_.add_compressed_pattern(compressed_pattern);
builder_.add_locales(locales);
builder_.add_pattern(pattern);
builder_.add_extractor(extractor);
@@ -1023,12 +1321,14 @@
flatbuffers::FlatBufferBuilder &_fbb,
DatetimeExtractorType extractor = DatetimeExtractorType_UNKNOWN_DATETIME_EXTRACTOR_TYPE,
const char *pattern = nullptr,
- const std::vector<int32_t> *locales = nullptr) {
+ const std::vector<int32_t> *locales = nullptr,
+ flatbuffers::Offset<CompressedBuffer> compressed_pattern = 0) {
return libtextclassifier2::CreateDatetimeModelExtractor(
_fbb,
extractor,
pattern ? _fbb.CreateString(pattern) : 0,
- locales ? _fbb.CreateVector<int32_t>(*locales) : 0);
+ locales ? _fbb.CreateVector<int32_t>(*locales) : 0,
+ compressed_pattern);
}
flatbuffers::Offset<DatetimeModelExtractor> CreateDatetimeModelExtractor(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -1206,6 +1506,99 @@
flatbuffers::Offset<ModelTriggeringOptions> CreateModelTriggeringOptions(flatbuffers::FlatBufferBuilder &_fbb, const ModelTriggeringOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct OutputOptionsT : public flatbuffers::NativeTable {
+ typedef OutputOptions TableType;
+ std::vector<std::string> filtered_collections_annotation;
+ std::vector<std::string> filtered_collections_classification;
+ std::vector<std::string> filtered_collections_selection;
+ OutputOptionsT() {
+ }
+};
+
+struct OutputOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef OutputOptionsT NativeTableType;
+ enum {
+ VT_FILTERED_COLLECTIONS_ANNOTATION = 4,
+ VT_FILTERED_COLLECTIONS_CLASSIFICATION = 6,
+ VT_FILTERED_COLLECTIONS_SELECTION = 8
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_annotation() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_FILTERED_COLLECTIONS_ANNOTATION);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_classification() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_FILTERED_COLLECTIONS_CLASSIFICATION);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_selection() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_FILTERED_COLLECTIONS_SELECTION);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_FILTERED_COLLECTIONS_ANNOTATION) &&
+ verifier.Verify(filtered_collections_annotation()) &&
+ verifier.VerifyVectorOfStrings(filtered_collections_annotation()) &&
+ VerifyOffset(verifier, VT_FILTERED_COLLECTIONS_CLASSIFICATION) &&
+ verifier.Verify(filtered_collections_classification()) &&
+ verifier.VerifyVectorOfStrings(filtered_collections_classification()) &&
+ VerifyOffset(verifier, VT_FILTERED_COLLECTIONS_SELECTION) &&
+ verifier.Verify(filtered_collections_selection()) &&
+ verifier.VerifyVectorOfStrings(filtered_collections_selection()) &&
+ verifier.EndTable();
+ }
+ OutputOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(OutputOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<OutputOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const OutputOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct OutputOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_filtered_collections_annotation(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_annotation) {
+ fbb_.AddOffset(OutputOptions::VT_FILTERED_COLLECTIONS_ANNOTATION, filtered_collections_annotation);
+ }
+ void add_filtered_collections_classification(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_classification) {
+ fbb_.AddOffset(OutputOptions::VT_FILTERED_COLLECTIONS_CLASSIFICATION, filtered_collections_classification);
+ }
+ void add_filtered_collections_selection(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_selection) {
+ fbb_.AddOffset(OutputOptions::VT_FILTERED_COLLECTIONS_SELECTION, filtered_collections_selection);
+ }
+ explicit OutputOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ OutputOptionsBuilder &operator=(const OutputOptionsBuilder &);
+ flatbuffers::Offset<OutputOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<OutputOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<OutputOptions> CreateOutputOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_annotation = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_classification = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> filtered_collections_selection = 0) {
+ OutputOptionsBuilder builder_(_fbb);
+ builder_.add_filtered_collections_selection(filtered_collections_selection);
+ builder_.add_filtered_collections_classification(filtered_collections_classification);
+ builder_.add_filtered_collections_annotation(filtered_collections_annotation);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<OutputOptions> CreateOutputOptionsDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_annotation = nullptr,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_classification = nullptr,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *filtered_collections_selection = nullptr) {
+ return libtextclassifier2::CreateOutputOptions(
+ _fbb,
+ filtered_collections_annotation ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*filtered_collections_annotation) : 0,
+ filtered_collections_classification ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*filtered_collections_classification) : 0,
+ filtered_collections_selection ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*filtered_collections_selection) : 0);
+}
+
+flatbuffers::Offset<OutputOptions> CreateOutputOptions(flatbuffers::FlatBufferBuilder &_fbb, const OutputOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct ModelT : public flatbuffers::NativeTable {
typedef Model TableType;
std::string locales;
@@ -1222,9 +1615,12 @@
std::unique_ptr<DatetimeModelT> datetime_model;
std::unique_ptr<ModelTriggeringOptionsT> triggering_options;
ModeFlag enabled_modes;
+ bool snap_whitespace_selections;
+ std::unique_ptr<OutputOptionsT> output_options;
ModelT()
: version(0),
- enabled_modes(ModeFlag_ALL) {
+ enabled_modes(ModeFlag_ALL),
+ snap_whitespace_selections(true) {
}
};
@@ -1244,7 +1640,9 @@
VT_REGEX_MODEL = 24,
VT_DATETIME_MODEL = 26,
VT_TRIGGERING_OPTIONS = 28,
- VT_ENABLED_MODES = 30
+ VT_ENABLED_MODES = 30,
+ VT_SNAP_WHITESPACE_SELECTIONS = 32,
+ VT_OUTPUT_OPTIONS = 34
};
const flatbuffers::String *locales() const {
return GetPointer<const flatbuffers::String *>(VT_LOCALES);
@@ -1288,6 +1686,12 @@
ModeFlag enabled_modes() const {
return static_cast<ModeFlag>(GetField<int32_t>(VT_ENABLED_MODES, 7));
}
+ bool snap_whitespace_selections() const {
+ return GetField<uint8_t>(VT_SNAP_WHITESPACE_SELECTIONS, 1) != 0;
+ }
+ const OutputOptions *output_options() const {
+ return GetPointer<const OutputOptions *>(VT_OUTPUT_OPTIONS);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_LOCALES) &&
@@ -1316,6 +1720,9 @@
VerifyOffset(verifier, VT_TRIGGERING_OPTIONS) &&
verifier.VerifyTable(triggering_options()) &&
VerifyField<int32_t>(verifier, VT_ENABLED_MODES) &&
+ VerifyField<uint8_t>(verifier, VT_SNAP_WHITESPACE_SELECTIONS) &&
+ VerifyOffset(verifier, VT_OUTPUT_OPTIONS) &&
+ verifier.VerifyTable(output_options()) &&
verifier.EndTable();
}
ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -1368,6 +1775,12 @@
void add_enabled_modes(ModeFlag enabled_modes) {
fbb_.AddElement<int32_t>(Model::VT_ENABLED_MODES, static_cast<int32_t>(enabled_modes), 7);
}
+ void add_snap_whitespace_selections(bool snap_whitespace_selections) {
+ fbb_.AddElement<uint8_t>(Model::VT_SNAP_WHITESPACE_SELECTIONS, static_cast<uint8_t>(snap_whitespace_selections), 1);
+ }
+ void add_output_options(flatbuffers::Offset<OutputOptions> output_options) {
+ fbb_.AddOffset(Model::VT_OUTPUT_OPTIONS, output_options);
+ }
explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -1395,8 +1808,11 @@
flatbuffers::Offset<RegexModel> regex_model = 0,
flatbuffers::Offset<DatetimeModel> datetime_model = 0,
flatbuffers::Offset<ModelTriggeringOptions> triggering_options = 0,
- ModeFlag enabled_modes = ModeFlag_ALL) {
+ ModeFlag enabled_modes = ModeFlag_ALL,
+ bool snap_whitespace_selections = true,
+ flatbuffers::Offset<OutputOptions> output_options = 0) {
ModelBuilder builder_(_fbb);
+ builder_.add_output_options(output_options);
builder_.add_enabled_modes(enabled_modes);
builder_.add_triggering_options(triggering_options);
builder_.add_datetime_model(datetime_model);
@@ -1411,6 +1827,7 @@
builder_.add_name(name);
builder_.add_version(version);
builder_.add_locales(locales);
+ builder_.add_snap_whitespace_selections(snap_whitespace_selections);
return builder_.Finish();
}
@@ -1429,7 +1846,9 @@
flatbuffers::Offset<RegexModel> regex_model = 0,
flatbuffers::Offset<DatetimeModel> datetime_model = 0,
flatbuffers::Offset<ModelTriggeringOptions> triggering_options = 0,
- ModeFlag enabled_modes = ModeFlag_ALL) {
+ ModeFlag enabled_modes = ModeFlag_ALL,
+ bool snap_whitespace_selections = true,
+ flatbuffers::Offset<OutputOptions> output_options = 0) {
return libtextclassifier2::CreateModel(
_fbb,
locales ? _fbb.CreateString(locales) : 0,
@@ -1445,7 +1864,9 @@
regex_model,
datetime_model,
triggering_options,
- enabled_modes);
+ enabled_modes,
+ snap_whitespace_selections,
+ output_options);
}
flatbuffers::Offset<Model> CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -2312,6 +2733,35 @@
flatbuffers::Offset<FeatureProcessorOptions> CreateFeatureProcessorOptions(flatbuffers::FlatBufferBuilder &_fbb, const FeatureProcessorOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+inline CompressedBufferT *CompressedBuffer::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new CompressedBufferT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void CompressedBuffer::UnPackTo(CompressedBufferT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = buffer(); if (_e) { _o->buffer.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->buffer[_i] = _e->Get(_i); } } };
+ { auto _e = uncompressed_size(); _o->uncompressed_size = _e; };
+}
+
+inline flatbuffers::Offset<CompressedBuffer> CompressedBuffer::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CompressedBufferT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateCompressedBuffer(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<CompressedBuffer> CreateCompressedBuffer(flatbuffers::FlatBufferBuilder &_fbb, const CompressedBufferT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CompressedBufferT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _buffer = _o->buffer.size() ? _fbb.CreateVector(_o->buffer) : 0;
+ auto _uncompressed_size = _o->uncompressed_size;
+ return libtextclassifier2::CreateCompressedBuffer(
+ _fbb,
+ _buffer,
+ _uncompressed_size);
+}
+
inline SelectionModelOptionsT *SelectionModelOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new SelectionModelOptionsT();
UnPackTo(_o, _resolver);
@@ -2324,6 +2774,7 @@
{ auto _e = strip_unpaired_brackets(); _o->strip_unpaired_brackets = _e; };
{ auto _e = symmetry_context_size(); _o->symmetry_context_size = _e; };
{ auto _e = batch_size(); _o->batch_size = _e; };
+ { auto _e = always_classify_suggested_selection(); _o->always_classify_suggested_selection = _e; };
}
inline flatbuffers::Offset<SelectionModelOptions> SelectionModelOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SelectionModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -2337,11 +2788,13 @@
auto _strip_unpaired_brackets = _o->strip_unpaired_brackets;
auto _symmetry_context_size = _o->symmetry_context_size;
auto _batch_size = _o->batch_size;
+ auto _always_classify_suggested_selection = _o->always_classify_suggested_selection;
return libtextclassifier2::CreateSelectionModelOptions(
_fbb,
_strip_unpaired_brackets,
_symmetry_context_size,
- _batch_size);
+ _batch_size,
+ _always_classify_suggested_selection);
}
inline ClassificationModelOptionsT *ClassificationModelOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -2355,6 +2808,7 @@
(void)_resolver;
{ auto _e = phone_min_num_digits(); _o->phone_min_num_digits = _e; };
{ auto _e = phone_max_num_digits(); _o->phone_max_num_digits = _e; };
+ { auto _e = address_min_num_tokens(); _o->address_min_num_tokens = _e; };
}
inline flatbuffers::Offset<ClassificationModelOptions> ClassificationModelOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ClassificationModelOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -2367,10 +2821,12 @@
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ClassificationModelOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _phone_min_num_digits = _o->phone_min_num_digits;
auto _phone_max_num_digits = _o->phone_max_num_digits;
+ auto _address_min_num_tokens = _o->address_min_num_tokens;
return libtextclassifier2::CreateClassificationModelOptions(
_fbb,
_phone_min_num_digits,
- _phone_max_num_digits);
+ _phone_max_num_digits,
+ _address_min_num_tokens);
}
namespace RegexModel_ {
@@ -2390,6 +2846,7 @@
{ auto _e = target_classification_score(); _o->target_classification_score = _e; };
{ auto _e = priority_score(); _o->priority_score = _e; };
{ auto _e = use_approximate_matching(); _o->use_approximate_matching = _e; };
+ { auto _e = compressed_pattern(); if (_e) _o->compressed_pattern = std::unique_ptr<libtextclassifier2::CompressedBufferT>(_e->UnPack(_resolver)); };
}
inline flatbuffers::Offset<Pattern> Pattern::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PatternT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -2406,6 +2863,7 @@
auto _target_classification_score = _o->target_classification_score;
auto _priority_score = _o->priority_score;
auto _use_approximate_matching = _o->use_approximate_matching;
+ auto _compressed_pattern = _o->compressed_pattern ? CreateCompressedBuffer(_fbb, _o->compressed_pattern.get(), _rehasher) : 0;
return libtextclassifier2::RegexModel_::CreatePattern(
_fbb,
_collection_name,
@@ -2413,7 +2871,8 @@
_enabled_modes,
_target_classification_score,
_priority_score,
- _use_approximate_matching);
+ _use_approximate_matching,
+ _compressed_pattern);
}
} // namespace RegexModel_
@@ -2444,6 +2903,42 @@
_patterns);
}
+namespace DatetimeModelPattern_ {
+
+inline RegexT *Regex::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new RegexT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void Regex::UnPackTo(RegexT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = pattern(); if (_e) _o->pattern = _e->str(); };
+ { auto _e = groups(); if (_e) { _o->groups.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->groups[_i] = (DatetimeGroupType)_e->Get(_i); } } };
+ { auto _e = compressed_pattern(); if (_e) _o->compressed_pattern = std::unique_ptr<libtextclassifier2::CompressedBufferT>(_e->UnPack(_resolver)); };
+}
+
+inline flatbuffers::Offset<Regex> Regex::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RegexT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateRegex(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<Regex> CreateRegex(flatbuffers::FlatBufferBuilder &_fbb, const RegexT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RegexT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern);
+ auto _groups = _o->groups.size() ? _fbb.CreateVector((const int32_t*)_o->groups.data(), _o->groups.size()) : 0;
+ auto _compressed_pattern = _o->compressed_pattern ? CreateCompressedBuffer(_fbb, _o->compressed_pattern.get(), _rehasher) : 0;
+ return libtextclassifier2::DatetimeModelPattern_::CreateRegex(
+ _fbb,
+ _pattern,
+ _groups,
+ _compressed_pattern);
+}
+
+} // namespace DatetimeModelPattern_
+
inline DatetimeModelPatternT *DatetimeModelPattern::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new DatetimeModelPatternT();
UnPackTo(_o, _resolver);
@@ -2453,7 +2948,7 @@
inline void DatetimeModelPattern::UnPackTo(DatetimeModelPatternT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
- { auto _e = regexes(); if (_e) { _o->regexes.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->regexes[_i] = _e->Get(_i)->str(); } } };
+ { auto _e = regexes(); if (_e) { _o->regexes.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->regexes[_i] = std::unique_ptr<libtextclassifier2::DatetimeModelPattern_::RegexT>(_e->Get(_i)->UnPack(_resolver)); } } };
{ auto _e = locales(); if (_e) { _o->locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->locales[_i] = _e->Get(_i); } } };
{ auto _e = target_classification_score(); _o->target_classification_score = _e; };
{ auto _e = priority_score(); _o->priority_score = _e; };
@@ -2468,7 +2963,7 @@
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DatetimeModelPatternT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
- auto _regexes = _o->regexes.size() ? _fbb.CreateVectorOfStrings(_o->regexes) : 0;
+ auto _regexes = _o->regexes.size() ? _fbb.CreateVector<flatbuffers::Offset<libtextclassifier2::DatetimeModelPattern_::Regex>> (_o->regexes.size(), [](size_t i, _VectorArgs *__va) { return CreateRegex(*__va->__fbb, __va->__o->regexes[i].get(), __va->__rehasher); }, &_va ) : 0;
auto _locales = _o->locales.size() ? _fbb.CreateVector(_o->locales) : 0;
auto _target_classification_score = _o->target_classification_score;
auto _priority_score = _o->priority_score;
@@ -2494,6 +2989,7 @@
{ auto _e = extractor(); _o->extractor = _e; };
{ auto _e = pattern(); if (_e) _o->pattern = _e->str(); };
{ auto _e = locales(); if (_e) { _o->locales.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->locales[_i] = _e->Get(_i); } } };
+ { auto _e = compressed_pattern(); if (_e) _o->compressed_pattern = std::unique_ptr<CompressedBufferT>(_e->UnPack(_resolver)); };
}
inline flatbuffers::Offset<DatetimeModelExtractor> DatetimeModelExtractor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DatetimeModelExtractorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -2507,11 +3003,13 @@
auto _extractor = _o->extractor;
auto _pattern = _o->pattern.empty() ? 0 : _fbb.CreateString(_o->pattern);
auto _locales = _o->locales.size() ? _fbb.CreateVector(_o->locales) : 0;
+ auto _compressed_pattern = _o->compressed_pattern ? CreateCompressedBuffer(_fbb, _o->compressed_pattern.get(), _rehasher) : 0;
return libtextclassifier2::CreateDatetimeModelExtractor(
_fbb,
_extractor,
_pattern,
- _locales);
+ _locales,
+ _compressed_pattern);
}
inline DatetimeModelT *DatetimeModel::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -2578,6 +3076,38 @@
_enabled_modes);
}
+inline OutputOptionsT *OutputOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new OutputOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void OutputOptions::UnPackTo(OutputOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = filtered_collections_annotation(); if (_e) { _o->filtered_collections_annotation.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->filtered_collections_annotation[_i] = _e->Get(_i)->str(); } } };
+ { auto _e = filtered_collections_classification(); if (_e) { _o->filtered_collections_classification.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->filtered_collections_classification[_i] = _e->Get(_i)->str(); } } };
+ { auto _e = filtered_collections_selection(); if (_e) { _o->filtered_collections_selection.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->filtered_collections_selection[_i] = _e->Get(_i)->str(); } } };
+}
+
+inline flatbuffers::Offset<OutputOptions> OutputOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OutputOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateOutputOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<OutputOptions> CreateOutputOptions(flatbuffers::FlatBufferBuilder &_fbb, const OutputOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OutputOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _filtered_collections_annotation = _o->filtered_collections_annotation.size() ? _fbb.CreateVectorOfStrings(_o->filtered_collections_annotation) : 0;
+ auto _filtered_collections_classification = _o->filtered_collections_classification.size() ? _fbb.CreateVectorOfStrings(_o->filtered_collections_classification) : 0;
+ auto _filtered_collections_selection = _o->filtered_collections_selection.size() ? _fbb.CreateVectorOfStrings(_o->filtered_collections_selection) : 0;
+ return libtextclassifier2::CreateOutputOptions(
+ _fbb,
+ _filtered_collections_annotation,
+ _filtered_collections_classification,
+ _filtered_collections_selection);
+}
+
inline ModelT *Model::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new ModelT();
UnPackTo(_o, _resolver);
@@ -2601,6 +3131,8 @@
{ auto _e = datetime_model(); if (_e) _o->datetime_model = std::unique_ptr<DatetimeModelT>(_e->UnPack(_resolver)); };
{ auto _e = triggering_options(); if (_e) _o->triggering_options = std::unique_ptr<ModelTriggeringOptionsT>(_e->UnPack(_resolver)); };
{ auto _e = enabled_modes(); _o->enabled_modes = _e; };
+ { auto _e = snap_whitespace_selections(); _o->snap_whitespace_selections = _e; };
+ { auto _e = output_options(); if (_e) _o->output_options = std::unique_ptr<OutputOptionsT>(_e->UnPack(_resolver)); };
}
inline flatbuffers::Offset<Model> Model::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -2625,6 +3157,8 @@
auto _datetime_model = _o->datetime_model ? CreateDatetimeModel(_fbb, _o->datetime_model.get(), _rehasher) : 0;
auto _triggering_options = _o->triggering_options ? CreateModelTriggeringOptions(_fbb, _o->triggering_options.get(), _rehasher) : 0;
auto _enabled_modes = _o->enabled_modes;
+ auto _snap_whitespace_selections = _o->snap_whitespace_selections;
+ auto _output_options = _o->output_options ? CreateOutputOptions(_fbb, _o->output_options.get(), _rehasher) : 0;
return libtextclassifier2::CreateModel(
_fbb,
_locales,
@@ -2640,7 +3174,9 @@
_regex_model,
_datetime_model,
_triggering_options,
- _enabled_modes);
+ _enabled_modes,
+ _snap_whitespace_selections,
+ _output_options);
}
inline TokenizationCodepointRangeT *TokenizationCodepointRange::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
diff --git a/models/textclassifier.ar.model b/models/textclassifier.ar.model
new file mode 100644
index 0000000..39e7ea2
--- /dev/null
+++ b/models/textclassifier.ar.model
Binary files differ
diff --git a/models/textclassifier.en.model b/models/textclassifier.en.model
index 0452c1e..04f90b7 100644
--- a/models/textclassifier.en.model
+++ b/models/textclassifier.en.model
Binary files differ
diff --git a/models/textclassifier.es.model b/models/textclassifier.es.model
new file mode 100644
index 0000000..bc79119
--- /dev/null
+++ b/models/textclassifier.es.model
Binary files differ
diff --git a/models/textclassifier.fr.model b/models/textclassifier.fr.model
new file mode 100644
index 0000000..768968e
--- /dev/null
+++ b/models/textclassifier.fr.model
Binary files differ
diff --git a/models/textclassifier.it.model b/models/textclassifier.it.model
new file mode 100644
index 0000000..823d02b
--- /dev/null
+++ b/models/textclassifier.it.model
Binary files differ
diff --git a/models/textclassifier.ja.model b/models/textclassifier.ja.model
new file mode 100644
index 0000000..c65b9b0
--- /dev/null
+++ b/models/textclassifier.ja.model
Binary files differ
diff --git a/models/textclassifier.ko.model b/models/textclassifier.ko.model
new file mode 100644
index 0000000..0c12ebe
--- /dev/null
+++ b/models/textclassifier.ko.model
Binary files differ
diff --git a/models/textclassifier.nl.model b/models/textclassifier.nl.model
new file mode 100644
index 0000000..a4aedb5
--- /dev/null
+++ b/models/textclassifier.nl.model
Binary files differ
diff --git a/models/textclassifier.pl.model b/models/textclassifier.pl.model
new file mode 100644
index 0000000..6797f93
--- /dev/null
+++ b/models/textclassifier.pl.model
Binary files differ
diff --git a/models/textclassifier.pt.model b/models/textclassifier.pt.model
new file mode 100644
index 0000000..39fa301
--- /dev/null
+++ b/models/textclassifier.pt.model
Binary files differ
diff --git a/models/textclassifier.ru.model b/models/textclassifier.ru.model
new file mode 100644
index 0000000..a824d4c
--- /dev/null
+++ b/models/textclassifier.ru.model
Binary files differ
diff --git a/models/textclassifier.th.model b/models/textclassifier.th.model
new file mode 100644
index 0000000..5430511
--- /dev/null
+++ b/models/textclassifier.th.model
Binary files differ
diff --git a/models/textclassifier.tr.model b/models/textclassifier.tr.model
new file mode 100644
index 0000000..2132f89
--- /dev/null
+++ b/models/textclassifier.tr.model
Binary files differ
diff --git a/models/textclassifier.zh-Hant.model b/models/textclassifier.zh-Hant.model
new file mode 100644
index 0000000..96341ce
--- /dev/null
+++ b/models/textclassifier.zh-Hant.model
Binary files differ
diff --git a/models/textclassifier.zh.model b/models/textclassifier.zh.model
new file mode 100644
index 0000000..adcab0f
--- /dev/null
+++ b/models/textclassifier.zh.model
Binary files differ
diff --git a/test_data/test_model.fb b/test_data/test_model.fb
index fc8353a..0f0161d 100644
--- a/test_data/test_model.fb
+++ b/test_data/test_model.fb
Binary files differ
diff --git a/test_data/test_model_cc.fb b/test_data/test_model_cc.fb
index b396943..2500551 100644
--- a/test_data/test_model_cc.fb
+++ b/test_data/test_model_cc.fb
Binary files differ
diff --git a/test_data/wrong_embeddings.fb b/test_data/wrong_embeddings.fb
index 000f739..9879e0b 100644
--- a/test_data/wrong_embeddings.fb
+++ b/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/text-classifier.cc b/text-classifier.cc
index 67346b9..417c84a 100644
--- a/text-classifier.cc
+++ b/text-classifier.cc
@@ -31,6 +31,8 @@
*[]() { return new std::string("other"); }();
const std::string& TextClassifier::kPhoneCollection =
*[]() { return new std::string("phone"); }();
+const std::string& TextClassifier::kAddressCollection =
+ *[]() { return new std::string("address"); }();
const std::string& TextClassifier::kDateCollection =
*[]() { return new std::string("date"); }();
@@ -163,9 +165,7 @@
TC_LOG(ERROR) << "No selection model.";
return;
}
- selection_executor_.reset(
- new ModelExecutor(flatbuffers::GetRoot<tflite::Model>(
- model_->selection_model()->data())));
+ selection_executor_ = ModelExecutor::Instance(model_->selection_model());
if (!selection_executor_) {
TC_LOG(ERROR) << "Could not initialize selection executor.";
return;
@@ -199,9 +199,8 @@
return;
}
- classification_executor_.reset(
- new ModelExecutor(flatbuffers::GetRoot<tflite::Model>(
- model_->classification_model()->data())));
+ classification_executor_ =
+ ModelExecutor::Instance(model_->classification_model());
if (!classification_executor_) {
TC_LOG(ERROR) << "Could not initialize classification executor.";
return;
@@ -232,54 +231,73 @@
return;
}
- embedding_executor_.reset(new TFLiteEmbeddingExecutor(
- flatbuffers::GetRoot<tflite::Model>(model_->embedding_model()->data()),
+ embedding_executor_ = TFLiteEmbeddingExecutor::Instance(
+ model_->embedding_model(),
model_->classification_feature_options()->embedding_size(),
model_->classification_feature_options()
- ->embedding_quantization_bits()));
- if (!embedding_executor_ || !embedding_executor_->IsReady()) {
+ ->embedding_quantization_bits());
+ if (!embedding_executor_) {
TC_LOG(ERROR) << "Could not initialize embedding executor.";
return;
}
}
+ std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
if (model_->regex_model()) {
- if (!InitializeRegexModel()) {
+ if (!InitializeRegexModel(decompressor.get())) {
TC_LOG(ERROR) << "Could not initialize regex model.";
}
}
if (model_->datetime_model()) {
- datetime_parser_ =
- DatetimeParser::Instance(model_->datetime_model(), *unilib_);
+ datetime_parser_ = DatetimeParser::Instance(model_->datetime_model(),
+ *unilib_, decompressor.get());
if (!datetime_parser_) {
TC_LOG(ERROR) << "Could not initialize datetime parser.";
return;
}
}
+ if (model_->output_options()) {
+ if (model_->output_options()->filtered_collections_annotation()) {
+ for (const auto collection :
+ *model_->output_options()->filtered_collections_annotation()) {
+ filtered_collections_annotation_.insert(collection->str());
+ }
+ }
+ if (model_->output_options()->filtered_collections_classification()) {
+ for (const auto collection :
+ *model_->output_options()->filtered_collections_classification()) {
+ filtered_collections_classification_.insert(collection->str());
+ }
+ }
+ if (model_->output_options()->filtered_collections_selection()) {
+ for (const auto collection :
+ *model_->output_options()->filtered_collections_selection()) {
+ filtered_collections_selection_.insert(collection->str());
+ }
+ }
+ }
+
initialized_ = true;
}
-bool TextClassifier::InitializeRegexModel() {
+bool TextClassifier::InitializeRegexModel(ZlibDecompressor* decompressor) {
if (!model_->regex_model()->patterns()) {
initialized_ = false;
- TC_LOG(ERROR) << "No patterns in the regex config.";
return false;
}
// Initialize pattern recognizers.
int regex_pattern_id = 0;
for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
- std::unique_ptr<UniLib::RegexPattern> compiled_pattern(
- unilib_->CreateRegexPattern(UTF8ToUnicodeText(
- regex_pattern->pattern()->c_str(),
- regex_pattern->pattern()->Length(), /*do_copy=*/false)));
-
+ std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
+ UncompressMakeRegexPattern(*unilib_, regex_pattern->pattern(),
+ regex_pattern->compressed_pattern(),
+ decompressor);
if (!compiled_pattern) {
- TC_LOG(INFO) << "Failed to load pattern"
- << regex_pattern->pattern()->str();
- continue;
+ TC_LOG(INFO) << "Failed to load regex pattern";
+ return false;
}
if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
@@ -331,19 +349,86 @@
}
} // namespace
+namespace internal {
+// Helper function, which if the initial 'span' contains only white-spaces,
+// moves the selection to a single-codepoint selection on a left or right side
+// of this space.
+CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
+ const UnicodeText& context_unicode,
+ const UniLib& unilib) {
+ TC_CHECK(ValidNonEmptySpan(span));
+
+ UnicodeText::const_iterator it;
+
+ // Check that the current selection is all whitespaces.
+ it = context_unicode.begin();
+ std::advance(it, span.first);
+ for (int i = 0; i < (span.second - span.first); ++i, ++it) {
+ if (!unilib.IsWhitespace(*it)) {
+ return span;
+ }
+ }
+
+ CodepointSpan result;
+
+ // Try moving left.
+ result = span;
+ it = context_unicode.begin();
+ std::advance(it, span.first);
+ while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
+ --result.first;
+ --it;
+ }
+ result.second = result.first + 1;
+ if (!unilib.IsWhitespace(*it)) {
+ return result;
+ }
+
+ // If moving left didn't find a non-whitespace character, just return the
+ // original span.
+ return span;
+}
+} // namespace internal
+
+bool TextClassifier::FilteredForAnnotation(const AnnotatedSpan& span) const {
+ return !span.classification.empty() &&
+ filtered_collections_annotation_.find(
+ span.classification[0].collection) !=
+ filtered_collections_annotation_.end();
+}
+
+bool TextClassifier::FilteredForClassification(
+ const ClassificationResult& classification) const {
+ return filtered_collections_classification_.find(classification.collection) !=
+ filtered_collections_classification_.end();
+}
+
+bool TextClassifier::FilteredForSelection(const AnnotatedSpan& span) const {
+ return !span.classification.empty() &&
+ filtered_collections_selection_.find(
+ span.classification[0].collection) !=
+ filtered_collections_selection_.end();
+}
+
CodepointSpan TextClassifier::SuggestSelection(
const std::string& context, CodepointSpan click_indices,
const SelectionOptions& options) const {
+ CodepointSpan original_click_indices = click_indices;
if (!initialized_) {
TC_LOG(ERROR) << "Not initialized";
- return click_indices;
+ return original_click_indices;
}
if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
- return click_indices;
+ return original_click_indices;
}
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
+
+ if (!context_unicode.is_valid()) {
+ return original_click_indices;
+ }
+
const int context_codepoint_size = context_unicode.size_codepoints();
if (click_indices.first < 0 || click_indices.second < 0 ||
@@ -352,26 +437,44 @@
click_indices.first >= click_indices.second) {
TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
<< click_indices.first << " " << click_indices.second;
- return click_indices;
+ return original_click_indices;
+ }
+
+ if (model_->snap_whitespace_selections()) {
+ // We want to expand a purely white-space selection to a multi-selection it
+ // would've been part of. But with this feature disabled we would do a no-
+ // op, because no token is found. Therefore, we need to modify the
+ // 'click_indices' a bit to include a part of the token, so that the click-
+ // finding logic finds the clicked token correctly. This modification is
+ // done by the following function. Note, that it's enough to check the left
+ // side of the current selection, because if the white-space is a part of a
+ // multi-selection, neccessarily both tokens - on the left and the right
+ // sides need to be selected. Thus snapping only to the left is sufficient
+ // (there's a check at the bottom that makes sure that if we snap to the
+ // left token but the result does not contain the initial white-space,
+ // returns the original indices).
+ click_indices = internal::SnapLeftIfWhitespaceSelection(
+ click_indices, context_unicode, *unilib_);
}
std::vector<AnnotatedSpan> candidates;
InterpreterManager interpreter_manager(selection_executor_.get(),
classification_executor_.get());
+ std::vector<Token> tokens;
if (!ModelSuggestSelection(context_unicode, click_indices,
- &interpreter_manager, &candidates)) {
+ &interpreter_manager, &tokens, &candidates)) {
TC_LOG(ERROR) << "Model suggest selection failed.";
- return click_indices;
+ return original_click_indices;
}
if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates)) {
TC_LOG(ERROR) << "Regex suggest selection failed.";
- return click_indices;
+ return original_click_indices;
}
if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
/*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
options.locales, ModeFlag_SELECTION, &candidates)) {
TC_LOG(ERROR) << "Datetime suggest selection failed.";
- return click_indices;
+ return original_click_indices;
}
// Sort candidates according to their position in the input, so that the next
@@ -383,19 +486,37 @@
});
std::vector<int> candidate_indices;
- if (!ResolveConflicts(candidates, context, &interpreter_manager,
+ if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
&candidate_indices)) {
TC_LOG(ERROR) << "Couldn't resolve conflicts.";
- return click_indices;
+ return original_click_indices;
}
for (const int i : candidate_indices) {
- if (SpansOverlap(candidates[i].span, click_indices)) {
+ if (SpansOverlap(candidates[i].span, click_indices) &&
+ SpansOverlap(candidates[i].span, original_click_indices)) {
+ // Run model classification if not present but requested and there's a
+ // classification collection filter specified.
+ if (candidates[i].classification.empty() &&
+ model_->selection_options()->always_classify_suggested_selection() &&
+ !filtered_collections_selection_.empty()) {
+ if (!ModelClassifyText(
+ context, candidates[i].span, &interpreter_manager,
+ /*embedding_cache=*/nullptr, &candidates[i].classification)) {
+ return original_click_indices;
+ }
+ }
+
+ // Ignore if span classification is filtered.
+ if (FilteredForSelection(candidates[i])) {
+ return original_click_indices;
+ }
+
return candidates[i].span;
}
}
- return click_indices;
+ return original_click_indices;
}
namespace {
@@ -422,6 +543,7 @@
bool TextClassifier::ResolveConflicts(
const std::vector<AnnotatedSpan>& candidates, const std::string& context,
+ const std::vector<Token>& cached_tokens,
InterpreterManager* interpreter_manager, std::vector<int>* result) const {
result->clear();
result->reserve(candidates.size());
@@ -432,8 +554,9 @@
const bool conflict_found = first_non_overlapping != (i + 1);
if (conflict_found) {
std::vector<int> candidate_indices;
- if (!ResolveConflict(context, candidates, i, first_non_overlapping,
- interpreter_manager, &candidate_indices)) {
+ if (!ResolveConflict(context, cached_tokens, candidates, i,
+ first_non_overlapping, interpreter_manager,
+ &candidate_indices)) {
return false;
}
result->insert(result->end(), candidate_indices.begin(),
@@ -466,8 +589,9 @@
} // namespace
bool TextClassifier::ResolveConflict(
- const std::string& context, const std::vector<AnnotatedSpan>& candidates,
- int start_index, int end_index, InterpreterManager* interpreter_manager,
+ const std::string& context, const std::vector<Token>& cached_tokens,
+ const std::vector<AnnotatedSpan>& candidates, int start_index,
+ int end_index, InterpreterManager* interpreter_manager,
std::vector<int>* chosen_indices) const {
std::vector<int> conflicting_indices;
std::unordered_map<int, float> scores;
@@ -484,7 +608,8 @@
// candidate conflicts and comes from the model, we need to run a
// classification to determine its priority:
std::vector<ClassificationResult> classification;
- if (!ModelClassifyText(context, candidates[i].span, interpreter_manager,
+ if (!ModelClassifyText(context, cached_tokens, candidates[i].span,
+ interpreter_manager,
/*embedding_cache=*/nullptr, &classification)) {
return false;
}
@@ -522,7 +647,7 @@
bool TextClassifier::ModelSuggestSelection(
const UnicodeText& context_unicode, CodepointSpan click_indices,
- InterpreterManager* interpreter_manager,
+ InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const {
if (model_->triggering_options() == nullptr ||
!(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
@@ -530,12 +655,11 @@
}
int click_pos;
- std::vector<Token> tokens =
- selection_feature_processor_->Tokenize(context_unicode);
+ *tokens = selection_feature_processor_->Tokenize(context_unicode);
selection_feature_processor_->RetokenizeAndFindClick(
context_unicode, click_indices,
selection_feature_processor_->GetOptions()->only_use_line_with_click(),
- &tokens, &click_pos);
+ tokens, &click_pos);
if (click_pos == kInvalidIndex) {
TC_VLOG(1) << "Could not calculate the click position.";
return false;
@@ -553,7 +677,7 @@
ExpandTokenSpan(SingleTokenSpan(click_pos),
/*num_tokens_left=*/symmetry_context_size,
/*num_tokens_right=*/symmetry_context_size),
- {0, tokens.size()});
+ {0, tokens->size()});
// Compute the extraction span based on the model type.
TokenSpan extraction_span;
@@ -579,11 +703,11 @@
/*num_tokens_left=*/context_size,
/*num_tokens_right=*/context_size);
}
- extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()});
+ extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
std::unique_ptr<CachedFeatures> cached_features;
if (!selection_feature_processor_->ExtractFeatures(
- tokens, extraction_span,
+ *tokens, extraction_span,
/*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
embedding_executor_.get(),
/*embedding_cache=*/nullptr,
@@ -596,7 +720,7 @@
// Produce selection model candidates.
std::vector<TokenSpan> chunks;
- if (!ModelChunk(tokens.size(), /*span_of_interest=*/symmetry_context_span,
+ if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
interpreter_manager->SelectionInterpreter(), *cached_features,
&chunks)) {
TC_LOG(ERROR) << "Could not chunk.";
@@ -606,7 +730,7 @@
for (const TokenSpan& chunk : chunks) {
AnnotatedSpan candidate;
candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
- context_unicode, TokenSpanToCodepointSpan(tokens, chunk));
+ context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
if (model_->selection_options()->strip_unpaired_brackets()) {
candidate.span =
StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
@@ -801,6 +925,15 @@
}
}
+ // Address class sanity check.
+ if (!classification_results->empty() &&
+ classification_results->begin()->collection == kAddressCollection) {
+ if (TokenSpanSize(selection_token_span) <
+ model_->classification_options()->address_min_num_tokens()) {
+ *classification_results = {{kOtherCollection, 1.0}};
+ }
+ }
+
return true;
}
@@ -856,7 +989,8 @@
std::vector<DatetimeParseResultSpan> datetime_spans;
if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
options.reference_timezone, options.locales,
- ModeFlag_CLASSIFICATION, &datetime_spans)) {
+ ModeFlag_CLASSIFICATION,
+ /*anchor_start_end=*/true, &datetime_spans)) {
TC_LOG(ERROR) << "Error during parsing datetime.";
return false;
}
@@ -887,6 +1021,10 @@
return {};
}
+ if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
+ return {};
+ }
+
if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
<< std::get<0>(selection_indices) << " "
@@ -897,14 +1035,22 @@
// Try the regular expression models.
ClassificationResult regex_result;
if (RegexClassifyText(context, selection_indices, ®ex_result)) {
- return {regex_result};
+ if (!FilteredForClassification(regex_result)) {
+ return {regex_result};
+ } else {
+ return {{kOtherCollection, 1.0}};
+ }
}
// Try the date model.
ClassificationResult datetime_result;
if (DatetimeClassifyText(context, selection_indices, options,
&datetime_result)) {
- return {datetime_result};
+ if (!FilteredForClassification(datetime_result)) {
+ return {datetime_result};
+ } else {
+ return {{kOtherCollection, 1.0}};
+ }
}
// Fallback to the model.
@@ -913,8 +1059,13 @@
InterpreterManager interpreter_manager(selection_executor_.get(),
classification_executor_.get());
if (ModelClassifyText(context, selection_indices, &interpreter_manager,
- /*embedding_cache=*/nullptr, &model_result)) {
- return model_result;
+ /*embedding_cache=*/nullptr, &model_result) &&
+ !model_result.empty()) {
+ if (!FilteredForClassification(model_result[0])) {
+ return model_result;
+ } else {
+ return {{kOtherCollection, 1.0}};
+ }
}
// No classifications.
@@ -923,6 +1074,7 @@
bool TextClassifier::ModelAnnotate(const std::string& context,
InterpreterManager* interpreter_manager,
+ std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const {
if (model_->triggering_options() == nullptr ||
!(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
@@ -948,18 +1100,17 @@
const std::string line_str =
UnicodeText::UTF8Substring(line.first, line.second);
- std::vector<Token> tokens =
- selection_feature_processor_->Tokenize(line_str);
+ *tokens = selection_feature_processor_->Tokenize(line_str);
selection_feature_processor_->RetokenizeAndFindClick(
line_str, {0, std::distance(line.first, line.second)},
selection_feature_processor_->GetOptions()->only_use_line_with_click(),
- &tokens,
+ tokens,
/*click_pos=*/nullptr);
- const TokenSpan full_line_span = {0, tokens.size()};
+ const TokenSpan full_line_span = {0, tokens->size()};
std::unique_ptr<CachedFeatures> cached_features;
if (!selection_feature_processor_->ExtractFeatures(
- tokens, full_line_span,
+ *tokens, full_line_span,
/*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
embedding_executor_.get(),
/*embedding_cache=*/nullptr,
@@ -971,7 +1122,7 @@
}
std::vector<TokenSpan> local_chunks;
- if (!ModelChunk(tokens.size(), /*span_of_interest=*/full_line_span,
+ if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
interpreter_manager->SelectionInterpreter(),
*cached_features, &local_chunks)) {
TC_LOG(ERROR) << "Could not chunk.";
@@ -982,12 +1133,12 @@
for (const TokenSpan& chunk : local_chunks) {
const CodepointSpan codepoint_span =
selection_feature_processor_->StripBoundaryCodepoints(
- line_str, TokenSpanToCodepointSpan(tokens, chunk));
+ line_str, TokenSpanToCodepointSpan(*tokens, chunk));
// Skip empty spans.
if (codepoint_span.first != codepoint_span.second) {
std::vector<ClassificationResult> classification;
- if (!ModelClassifyText(line_str, tokens, codepoint_span,
+ if (!ModelClassifyText(line_str, *tokens, codepoint_span,
interpreter_manager, &embedding_cache,
&classification)) {
TC_LOG(ERROR) << "Could not classify text: "
@@ -1016,6 +1167,10 @@
return *selection_feature_processor_;
}
+const DatetimeParser* TextClassifier::DatetimeParserForTests() const {
+ return datetime_parser_.get();
+}
+
std::vector<AnnotatedSpan> TextClassifier::Annotate(
const std::string& context, const AnnotationOptions& options) const {
std::vector<AnnotatedSpan> candidates;
@@ -1024,10 +1179,15 @@
return {};
}
+ if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
+ return {};
+ }
+
InterpreterManager interpreter_manager(selection_executor_.get(),
classification_executor_.get());
// Annotate with the selection model.
- if (!ModelAnnotate(context, &interpreter_manager, &candidates)) {
+ std::vector<Token> tokens;
+ if (!ModelAnnotate(context, &interpreter_manager, &tokens, &candidates)) {
TC_LOG(ERROR) << "Couldn't run ModelAnnotate.";
return {};
}
@@ -1056,7 +1216,7 @@
});
std::vector<int> candidate_indices;
- if (!ResolveConflicts(candidates, context, &interpreter_manager,
+ if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
&candidate_indices)) {
TC_LOG(ERROR) << "Couldn't resolve conflicts.";
return {};
@@ -1066,7 +1226,8 @@
result.reserve(candidate_indices.size());
for (const int i : candidate_indices) {
if (!candidates[i].classification.empty() &&
- !ClassifiedAsOther(candidates[i].classification)) {
+ !ClassifiedAsOther(candidates[i].classification) &&
+ !FilteredForAnnotation(candidates[i])) {
result.push_back(std::move(candidates[i]));
}
}
@@ -1351,10 +1512,14 @@
const std::string& reference_timezone,
const std::string& locales, ModeFlag mode,
std::vector<AnnotatedSpan>* result) const {
+ if (!datetime_parser_) {
+ return false;
+ }
+
std::vector<DatetimeParseResultSpan> datetime_spans;
if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
reference_timezone, locales, mode,
- &datetime_spans)) {
+ /*anchor_start_end=*/false, &datetime_spans)) {
return false;
}
for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
diff --git a/text-classifier.h b/text-classifier.h
index 0c79429..ad94dc4 100644
--- a/text-classifier.h
+++ b/text-classifier.h
@@ -32,6 +32,7 @@
#include "types.h"
#include "util/memory/mmap.h"
#include "util/utf8/unilib.h"
+#include "zlib-utils.h"
namespace libtextclassifier2 {
@@ -150,9 +151,13 @@
// Exposes the selection feature processor for tests and evaluations.
const FeatureProcessor& SelectionFeatureProcessorForTests() const;
+ // Exposes the date time parser for tests and evaluations.
+ const DatetimeParser* DatetimeParserForTests() const;
+
// String collection names for various classes.
static const std::string& kOtherCollection;
static const std::string& kPhoneCollection;
+ static const std::string& kAddressCollection;
static const std::string& kDateCollection;
protected:
@@ -186,7 +191,7 @@
void ValidateAndInitialize();
// Initializes regular expressions for the regex model.
- bool InitializeRegexModel();
+ bool InitializeRegexModel(ZlibDecompressor* decompressor);
// Resolves conflicts in the list of candidates by removing some overlapping
// ones. Returns indices of the surviving ones.
@@ -194,6 +199,7 @@
// the span.
bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
const std::string& context,
+ const std::vector<Token>& cached_tokens,
InterpreterManager* interpreter_manager,
std::vector<int>* result) const;
@@ -201,15 +207,19 @@
// (inclusive) and 'end_index' (exclusive). Assigns the winning candidate
// indices to 'chosen_indices'. Returns false if a problem arises.
bool ResolveConflict(const std::string& context,
+ const std::vector<Token>& cached_tokens,
const std::vector<AnnotatedSpan>& candidates,
int start_index, int end_index,
InterpreterManager* interpreter_manager,
std::vector<int>* chosen_indices) const;
// Gets selection candidates from the ML model.
+ // Provides the tokens produced during tokenization of the context string for
+ // reuse.
bool ModelSuggestSelection(const UnicodeText& context_unicode,
CodepointSpan click_indices,
InterpreterManager* interpreter_manager,
+ std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const;
// Classifies the selected text given the context string with the
@@ -249,8 +259,11 @@
// with the classification model.
// The annotations are sorted by their position in the context string and
// exclude spans classified as 'other'.
+ // Provides the tokens produced during tokenization of the context string for
+ // reuse.
bool ModelAnnotate(const std::string& context,
InterpreterManager* interpreter_manager,
+ std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const;
// Groups the tokens into chunks. A chunk is a token span that should be the
@@ -297,6 +310,12 @@
const std::string& locales, ModeFlag mode,
std::vector<AnnotatedSpan>* result) const;
+ // Returns whether a classification should be filtered.
+ bool FilteredForAnnotation(const AnnotatedSpan& span) const;
+ bool FilteredForClassification(
+ const ClassificationResult& classification) const;
+ bool FilteredForSelection(const AnnotatedSpan& span) const;
+
const Model* model_;
std::unique_ptr<const ModelExecutor> selection_executor_;
@@ -321,6 +340,9 @@
bool enabled_for_annotation_ = false;
bool enabled_for_classification_ = false;
bool enabled_for_selection_ = false;
+ std::unordered_set<std::string> filtered_collections_annotation_;
+ std::unordered_set<std::string> filtered_collections_classification_;
+ std::unordered_set<std::string> filtered_collections_selection_;
std::vector<CompiledRegexPattern> regex_patterns_;
std::unordered_set<int> regex_approximate_match_pattern_ids_;
@@ -334,6 +356,14 @@
};
namespace internal {
+
+// Helper function, which if the initial 'span' contains only white-spaces,
+// moves the selection to a single-codepoint selection on the left side
+// of this block of white-space.
+CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
+ const UnicodeText& context_unicode,
+ const UniLib& unilib);
+
// Copies tokens from 'cached_tokens' that are
// 'tokens_around_selection_to_copy' (on the left, and right) tokens distant
// from the tokens that correspond to 'selection_indices'.
diff --git a/text-classifier_test.cc b/text-classifier_test.cc
index 74534e2..440cedf 100644
--- a/text-classifier_test.cc
+++ b/text-classifier_test.cc
@@ -104,6 +104,9 @@
FirstResult(classifier->ClassifyText("", {0, 0})));
EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
"a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
+ // Test invalid utf8 input.
+ EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
+ "\xf0\x9f\x98\x8b\x8b", {0, 0})));
}
TEST_P(TextClassifierTest, ClassifyTextDisabledFail) {
@@ -150,6 +153,41 @@
IsEmpty());
}
+TEST_P(TextClassifierTest, ClassifyTextFilteredCollections) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
+ &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // Disable phone classification
+ unpacked_model->output_options->filtered_collections_classification.push_back(
+ "phone");
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
+ "Call me at (800) 123-456 today", {11, 24})));
+
+ // Check that the address classification still passes.
+ EXPECT_EQ("address", FirstResult(classifier->ClassifyText(
+ "350 Third Street, Cambridge", {0, 27})));
+}
+
std::unique_ptr<RegexModel_::PatternT> MakePattern(
const std::string& collection_name, const std::string& pattern,
const bool enabled_for_classification, const bool enabled_for_selection,
@@ -225,7 +263,6 @@
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
// Add test regex models.
- unpacked_model->regex_model.reset(new RegexModelT);
unpacked_model->regex_model->patterns.push_back(MakePattern(
"person", " (Barack Obama) ", /*enabled_for_classification=*/false,
/*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
@@ -260,7 +297,6 @@
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
// Add test regex models.
- unpacked_model->regex_model.reset(new RegexModelT);
unpacked_model->regex_model->patterns.push_back(MakePattern(
"person", " (Barack Obama) ", /*enabled_for_classification=*/false,
/*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
@@ -294,7 +330,6 @@
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
// Add test regex models.
- unpacked_model->regex_model.reset(new RegexModelT);
unpacked_model->regex_model->patterns.push_back(MakePattern(
"person", " (Barack Obama) ", /*enabled_for_classification=*/false,
/*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
@@ -328,7 +363,6 @@
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
// Add test regex models.
- unpacked_model->regex_model.reset(new RegexModelT);
unpacked_model->regex_model->patterns.push_back(MakePattern(
"person", " (Barack Obama) ", /*enabled_for_classification=*/false,
/*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
@@ -469,6 +503,45 @@
IsEmpty());
}
+TEST_P(TextClassifierTest, SuggestSelectionFilteredCollections) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
+ &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+ std::make_pair(11, 23));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // Disable phone selection
+ unpacked_model->output_options->filtered_collections_selection.push_back(
+ "phone");
+ // We need to force this for filtering.
+ unpacked_model->selection_options->always_classify_suggested_selection = true;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
+ std::make_pair(11, 14));
+
+ // Address selection should still work.
+ EXPECT_EQ(classifier->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
+ std::make_pair(0, 27));
+}
+
TEST_P(TextClassifierTest, SuggestSelectionsAreSymmetric) {
CREATE_UNILIB_FOR_TESTING;
std::unique_ptr<TextClassifier> classifier =
@@ -548,6 +621,91 @@
std::make_pair(-10, -1));
EXPECT_EQ(classifier->SuggestSelection("Word 1 2 3 hello!", {100, 17}),
std::make_pair(100, 17));
+
+ // Try passing invalid utf8.
+ EXPECT_EQ(classifier->SuggestSelection("\xf0\x9f\x98\x8b\x8b", {-1, -1}),
+ std::make_pair(-1, -1));
+}
+
+TEST_P(TextClassifierTest, SuggestSelectionSelectSpace) {
+ CREATE_UNILIB_FOR_TESTING;
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}),
+ std::make_pair(11, 23));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}),
+ std::make_pair(10, 11));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}),
+ std::make_pair(23, 24));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}),
+ std::make_pair(23, 24));
+ EXPECT_EQ(classifier->SuggestSelection("call me at 857 225 3556, today",
+ {14, 17}),
+ std::make_pair(11, 25));
+ EXPECT_EQ(
+ classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}),
+ std::make_pair(11, 23));
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "let's meet at 350 Third Street Cambridge and go there", {30, 31}),
+ std::make_pair(14, 40));
+ EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}),
+ std::make_pair(4, 5));
+ EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}),
+ std::make_pair(7, 8));
+
+ // With a punctuation around the selected whitespace.
+ EXPECT_EQ(
+ classifier->SuggestSelection(
+ "let's meet at 350 Third Street, Cambridge and go there", {31, 32}),
+ std::make_pair(14, 41));
+
+ // When all's whitespace, should return the original indices.
+ EXPECT_EQ(classifier->SuggestSelection(" ", {0, 1}),
+ std::make_pair(0, 1));
+ EXPECT_EQ(classifier->SuggestSelection(" ", {0, 3}),
+ std::make_pair(0, 3));
+ EXPECT_EQ(classifier->SuggestSelection(" ", {2, 3}),
+ std::make_pair(2, 3));
+ EXPECT_EQ(classifier->SuggestSelection(" ", {5, 6}),
+ std::make_pair(5, 6));
+}
+
+TEST(TextClassifierTest, SnapLeftIfWhitespaceSelection) {
+ CREATE_UNILIB_FOR_TESTING;
+ UnicodeText text;
+
+ text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib),
+ std::make_pair(3, 4));
+ text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib),
+ std::make_pair(3, 4));
+
+ // Nothing on the left.
+ text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib),
+ std::make_pair(4, 5));
+ text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib),
+ std::make_pair(0, 1));
+
+ // Whitespace only.
+ text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, unilib),
+ std::make_pair(2, 3));
+ text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib),
+ std::make_pair(4, 5));
+ text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
+ EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib),
+ std::make_pair(0, 1));
}
TEST_P(TextClassifierTest, Annotate) {
@@ -572,6 +730,11 @@
EXPECT_THAT(classifier->Annotate("853 225 3556", options),
ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
+
+ // Try passing invalid utf8.
+ EXPECT_TRUE(
+ classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options)
+ .empty());
}
TEST_P(TextClassifierTest, AnnotateSmallBatches) {
@@ -686,6 +849,104 @@
EXPECT_THAT(classifier->Annotate(test_string), IsEmpty());
}
+TEST_P(TextClassifierTest, AnnotateFilteredCollections) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
+ &unilib);
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+ IsAnnotatedSpan(19, 24, "date"),
+#endif
+ IsAnnotatedSpan(28, 55, "address"),
+ IsAnnotatedSpan(79, 91, "phone"),
+ }));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // Disable phone annotation
+ unpacked_model->output_options->filtered_collections_annotation.push_back(
+ "phone");
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+ IsAnnotatedSpan(19, 24, "date"),
+#endif
+ IsAnnotatedSpan(28, 55, "address"),
+ }));
+}
+
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+TEST_P(TextClassifierTest, AnnotateFilteredCollectionsSuppress) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
+ &unilib);
+ ASSERT_TRUE(classifier);
+
+ const std::string test_string =
+ "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
+ "number is 853 225 3556";
+
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+ IsAnnotatedSpan(19, 24, "date"),
+#endif
+ IsAnnotatedSpan(28, 55, "address"),
+ IsAnnotatedSpan(79, 91, "phone"),
+ }));
+
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+ unpacked_model->output_options.reset(new OutputOptionsT);
+
+ // We add a custom annotator that wins against the phone classification
+ // below and that we subsequently suppress.
+ unpacked_model->output_options->filtered_collections_annotation.push_back(
+ "suppress");
+
+ unpacked_model->regex_model->patterns.push_back(MakePattern(
+ "suppress", "(\\d{3} ?\\d{4})",
+ /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_THAT(classifier->Annotate(test_string),
+ ElementsAreArray({
+ IsAnnotatedSpan(19, 24, "date"),
+ IsAnnotatedSpan(28, 55, "address"),
+ }));
+}
+#endif
+
#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
TEST_P(TextClassifierTest, ClassifyTextDate) {
std::unique_ptr<TextClassifier> classifier =
@@ -734,9 +995,43 @@
EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
EXPECT_EQ(result[0].datetime_parse_result.granularity,
DatetimeGranularity::GRANULARITY_DAY);
- result.clear();
}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
+TEST_P(TextClassifierTest, ClassifyTextDatePriorities) {
+ std::unique_ptr<TextClassifier> classifier =
+ TextClassifier::FromPath(GetModelPath() + GetParam());
+ EXPECT_TRUE(classifier);
+
+ std::vector<ClassificationResult> result;
+ ClassificationOptions options;
+
+ result.clear();
+ options.reference_timezone = "Europe/Zurich";
+ options.locales = "en-US";
+ result = classifier->ClassifyText("03/05", {0, 5}, options);
+
+ ASSERT_EQ(result.size(), 1);
+ EXPECT_THAT(result[0].collection, "date");
+ EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 5439600000);
+ EXPECT_EQ(result[0].datetime_parse_result.granularity,
+ DatetimeGranularity::GRANULARITY_DAY);
+
+ result.clear();
+ options.reference_timezone = "Europe/Zurich";
+ options.locales = "en-GB,en-US";
+ result = classifier->ClassifyText("03/05", {0, 5}, options);
+
+ ASSERT_EQ(result.size(), 1);
+ EXPECT_THAT(result[0].collection, "date");
+ EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 10537200000);
+ EXPECT_EQ(result[0].datetime_parse_result.granularity,
+ DatetimeGranularity::GRANULARITY_DAY);
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
+#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
TEST_P(TextClassifierTest, SuggestTextDateDisabled) {
CREATE_UNILIB_FOR_TESTING;
const std::string test_model = ReadFile(GetModelPath() + GetParam());
@@ -789,7 +1084,7 @@
{MakeAnnotatedSpan({0, 1}, "phone", 1.0)}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"",
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0}));
}
@@ -807,7 +1102,7 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"",
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
}
@@ -823,7 +1118,7 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"",
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
}
@@ -839,7 +1134,7 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"",
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({1}));
}
@@ -857,7 +1152,7 @@
}};
std::vector<int> chosen;
- classifier.ResolveConflicts(candidates, /*context=*/"",
+ classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
/*interpreter_manager=*/nullptr, &chosen);
EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
}
@@ -916,5 +1211,43 @@
}
#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+TEST_P(TextClassifierTest, MinAddressTokenLength) {
+ CREATE_UNILIB_FOR_TESTING;
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ std::unique_ptr<TextClassifier> classifier;
+
+ // With unrestricted number of address tokens should behave normally.
+ unpacked_model->classification_options->address_min_num_tokens = 0;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(FirstResult(classifier->ClassifyText(
+ "I live at 350 Third Street, Cambridge.", {10, 37})),
+ "address");
+
+ // Raise number of address tokens to suppress the address classification.
+ unpacked_model->classification_options->address_min_num_tokens = 5;
+
+ flatbuffers::FlatBufferBuilder builder2;
+ builder2.Finish(Model::Pack(builder2, unpacked_model.get()));
+ classifier = TextClassifier::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder2.GetBufferPointer()),
+ builder2.GetSize(), &unilib);
+ ASSERT_TRUE(classifier);
+
+ EXPECT_EQ(FirstResult(classifier->ClassifyText(
+ "I live at 350 Third Street, Cambridge.", {10, 37})),
+ "other");
+}
+#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+
} // namespace
} // namespace libtextclassifier2
diff --git a/util/i18n/locale.cc b/util/i18n/locale.cc
new file mode 100644
index 0000000..c587d2d
--- /dev/null
+++ b/util/i18n/locale.cc
@@ -0,0 +1,110 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/i18n/locale.h"
+
+#include "util/strings/split.h"
+
+namespace libtextclassifier2 {
+
+namespace {
+
+bool CheckLanguage(StringPiece language) {
+ if (language.size() != 2 && language.size() != 3) {
+ return false;
+ }
+
+ // Needs to be all lowercase.
+ for (int i = 0; i < language.size(); ++i) {
+ if (!std::islower(language[i])) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool CheckScript(StringPiece script) {
+ if (script.size() != 4) {
+ return false;
+ }
+
+ if (!std::isupper(script[0])) {
+ return false;
+ }
+
+ // Needs to be all lowercase.
+ for (int i = 1; i < script.size(); ++i) {
+ if (!std::islower(script[i])) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool CheckRegion(StringPiece region) {
+ if (region.size() == 2) {
+ return std::isupper(region[0]) && std::isupper(region[1]);
+ } else if (region.size() == 3) {
+ return std::isdigit(region[0]) && std::isdigit(region[1]) &&
+ std::isdigit(region[2]);
+ } else {
+ return false;
+ }
+}
+
+} // namespace
+
+Locale Locale::FromBCP47(const std::string& locale_tag) {
+ std::vector<StringPiece> parts = strings::Split(locale_tag, '-');
+ if (parts.empty()) {
+ return Locale::Invalid();
+ }
+
+ auto parts_it = parts.begin();
+ StringPiece language = *parts_it;
+ if (!CheckLanguage(language)) {
+ return Locale::Invalid();
+ }
+ ++parts_it;
+
+ StringPiece script;
+ if (parts_it != parts.end()) {
+ script = *parts_it;
+ if (!CheckScript(script)) {
+ script = "";
+ } else {
+ ++parts_it;
+ }
+ }
+
+ StringPiece region;
+ if (parts_it != parts.end()) {
+ region = *parts_it;
+ if (!CheckRegion(region)) {
+ region = "";
+ } else {
+ ++parts_it;
+ }
+ }
+
+ // NOTE: We don't parse the rest of the BCP47 tag here even if specified.
+
+ return Locale(language.ToString(), script.ToString(), region.ToString());
+}
+
+} // namespace libtextclassifier2
diff --git a/util/i18n/locale.h b/util/i18n/locale.h
new file mode 100644
index 0000000..16f10dc
--- /dev/null
+++ b/util/i18n/locale.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTIL_I18N_LOCALE_H_
+#define LIBTEXTCLASSIFIER_UTIL_I18N_LOCALE_H_
+
+#include <string>
+
+#include "util/base/integral_types.h"
+
+namespace libtextclassifier2 {
+
+class Locale {
+ public:
+ // Constructs the object from a valid BCP47 tag. If the tag is invalid,
+ // an object is created that gives false when IsInvalid() is called.
+ static Locale FromBCP47(const std::string& locale_tag);
+
+ // Creates a prototypical invalid locale object.
+ static Locale Invalid() {
+ Locale locale(/*language=*/"", /*script=*/"", /*region=*/"");
+ locale.is_valid_ = false;
+ return locale;
+ }
+
+ std::string Language() const { return language_; }
+
+ std::string Script() const { return script_; }
+
+ std::string Region() const { return region_; }
+
+ bool IsValid() const { return is_valid_; }
+
+ private:
+ Locale(const std::string& language, const std::string& script,
+ const std::string& region)
+ : language_(language),
+ script_(script),
+ region_(region),
+ is_valid_(true) {}
+
+ std::string language_;
+ std::string script_;
+ std::string region_;
+ bool is_valid_;
+};
+
+} // namespace libtextclassifier2
+
+#endif // LIBTEXTCLASSIFIER_UTIL_I18N_LOCALE_H_
diff --git a/util/i18n/locale_test.cc b/util/i18n/locale_test.cc
new file mode 100644
index 0000000..72ece98
--- /dev/null
+++ b/util/i18n/locale_test.cc
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/i18n/locale.h"
+
+#include "gtest/gtest.h"
+
+namespace libtextclassifier2 {
+namespace {
+
+TEST(LocaleTest, ParseUnknown) {
+ Locale locale = Locale::Invalid();
+ EXPECT_FALSE(locale.IsValid());
+}
+
+TEST(LocaleTest, ParseSwissEnglish) {
+ Locale locale = Locale::FromBCP47("en-CH");
+ EXPECT_TRUE(locale.IsValid());
+ EXPECT_EQ(locale.Language(), "en");
+ EXPECT_EQ(locale.Script(), "");
+ EXPECT_EQ(locale.Region(), "CH");
+}
+
+TEST(LocaleTest, ParseChineseChina) {
+ Locale locale = Locale::FromBCP47("zh-CN");
+ EXPECT_TRUE(locale.IsValid());
+ EXPECT_EQ(locale.Language(), "zh");
+ EXPECT_EQ(locale.Script(), "");
+ EXPECT_EQ(locale.Region(), "CN");
+}
+
+TEST(LocaleTest, ParseChineseTaiwan) {
+ Locale locale = Locale::FromBCP47("zh-Hant-TW");
+ EXPECT_TRUE(locale.IsValid());
+ EXPECT_EQ(locale.Language(), "zh");
+ EXPECT_EQ(locale.Script(), "Hant");
+ EXPECT_EQ(locale.Region(), "TW");
+}
+
+TEST(LocaleTest, ParseEnglish) {
+ Locale locale = Locale::FromBCP47("en");
+ EXPECT_TRUE(locale.IsValid());
+ EXPECT_EQ(locale.Language(), "en");
+ EXPECT_EQ(locale.Script(), "");
+ EXPECT_EQ(locale.Region(), "");
+}
+
+TEST(LocaleTest, ParseCineseTraditional) {
+ Locale locale = Locale::FromBCP47("zh-Hant");
+ EXPECT_TRUE(locale.IsValid());
+ EXPECT_EQ(locale.Language(), "zh");
+ EXPECT_EQ(locale.Script(), "Hant");
+ EXPECT_EQ(locale.Region(), "");
+}
+
+} // namespace
+} // namespace libtextclassifier2
diff --git a/util/strings/split.cc b/util/strings/split.cc
index e61e3ba..2c610ba 100644
--- a/util/strings/split.cc
+++ b/util/strings/split.cc
@@ -19,14 +19,14 @@
namespace libtextclassifier2 {
namespace strings {
-std::vector<std::string> Split(const std::string &text, char delim) {
- std::vector<std::string> result;
+std::vector<StringPiece> Split(const StringPiece &text, char delim) {
+ std::vector<StringPiece> result;
int token_start = 0;
if (!text.empty()) {
for (size_t i = 0; i < text.size() + 1; i++) {
if ((i == text.size()) || (text[i] == delim)) {
result.push_back(
- std::string(text.data() + token_start, i - token_start));
+ StringPiece(text.data() + token_start, i - token_start));
token_start = i + 1;
}
}
diff --git a/util/strings/split.h b/util/strings/split.h
index abd453b..96f73fe 100644
--- a/util/strings/split.h
+++ b/util/strings/split.h
@@ -20,10 +20,12 @@
#include <string>
#include <vector>
+#include "util/strings/stringpiece.h"
+
namespace libtextclassifier2 {
namespace strings {
-std::vector<std::string> Split(const std::string &text, char delim);
+std::vector<StringPiece> Split(const StringPiece &text, char delim);
} // namespace strings
} // namespace libtextclassifier2
diff --git a/util/strings/stringpiece.h b/util/strings/stringpiece.h
index bd62274..cd07848 100644
--- a/util/strings/stringpiece.h
+++ b/util/strings/stringpiece.h
@@ -51,6 +51,8 @@
size_t size() const { return size_; }
size_t length() const { return size_; }
+ bool empty() const { return size_ == 0; }
+
// Returns a std::string containing a copy of the underlying data.
std::string ToString() const {
return std::string(data(), size());
diff --git a/util/strings/utf8.cc b/util/strings/utf8.cc
new file mode 100644
index 0000000..39dcb4e
--- /dev/null
+++ b/util/strings/utf8.cc
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/strings/utf8.h"
+
+namespace libtextclassifier2 {
+bool IsValidUTF8(const char *src, int size) {
+ for (int i = 0; i < size;) {
+ // Unexpected trail byte.
+ if (IsTrailByte(src[i])) {
+ return false;
+ }
+
+ const int num_codepoint_bytes = GetNumBytesForUTF8Char(&src[i]);
+ if (num_codepoint_bytes <= 0 || i + num_codepoint_bytes > size) {
+ return false;
+ }
+
+ // Check that remaining bytes in the codepoint are trailing bytes.
+ i++;
+ for (int k = 1; k < num_codepoint_bytes; k++, i++) {
+ if (!IsTrailByte(src[i])) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+} // namespace libtextclassifier2
diff --git a/util/strings/utf8.h b/util/strings/utf8.h
index e54c18a..1e75da2 100644
--- a/util/strings/utf8.h
+++ b/util/strings/utf8.h
@@ -44,6 +44,9 @@
return static_cast<signed char>(x) < -0x40;
}
+// Returns true iff src points to a well-formed UTF-8 string.
+bool IsValidUTF8(const char *src, int size);
+
} // namespace libtextclassifier2
#endif // LIBTEXTCLASSIFIER_UTIL_STRINGS_UTF8_H_
diff --git a/util/utf8/unicodetext.cc b/util/utf8/unicodetext.cc
index 90a581f..70fecd4 100644
--- a/util/utf8/unicodetext.cc
+++ b/util/utf8/unicodetext.cc
@@ -191,6 +191,10 @@
bool UnicodeText::empty() const { return size_bytes() == 0; }
+bool UnicodeText::is_valid() const {
+ return IsValidUTF8(repr_.data_, repr_.size_);
+}
+
bool UnicodeText::operator==(const UnicodeText& other) const {
if (repr_.size_ != other.repr_.size_) {
return false;
diff --git a/util/utf8/unicodetext.h b/util/utf8/unicodetext.h
index 8e13496..7300111 100644
--- a/util/utf8/unicodetext.h
+++ b/util/utf8/unicodetext.h
@@ -154,6 +154,9 @@
bool empty() const;
+ // Checks whether the underlying data is valid utf8 data.
+ bool is_valid() const;
+
bool operator==(const UnicodeText& other) const;
// x.PointToUTF8(buf,len) changes x so that it points to buf
diff --git a/util/utf8/unicodetext_test.cc b/util/utf8/unicodetext_test.cc
index 8aef952..9ec7621 100644
--- a/util/utf8/unicodetext_test.cc
+++ b/util/utf8/unicodetext_test.cc
@@ -65,6 +65,28 @@
EXPECT_NE(t.data(), alias.data());
}
+TEST(UnicodeTextTest, Validation) {
+ EXPECT_TRUE(UTF8ToUnicodeText("1234πhello", /*do_copy=*/false).is_valid());
+ EXPECT_TRUE(
+ UTF8ToUnicodeText("\u304A\u00B0\u106B", /*do_copy=*/false).is_valid());
+ EXPECT_TRUE(
+ UTF8ToUnicodeText("this is a testπππ", /*do_copy=*/false).is_valid());
+ EXPECT_TRUE(
+ UTF8ToUnicodeText("\xf0\x9f\x98\x8b", /*do_copy=*/false).is_valid());
+ // Too short (string is too short).
+ EXPECT_FALSE(UTF8ToUnicodeText("\xf0\x9f", /*do_copy=*/false).is_valid());
+ // Too long (too many trailing bytes).
+ EXPECT_FALSE(
+ UTF8ToUnicodeText("\xf0\x9f\x98\x8b\x8b", /*do_copy=*/false).is_valid());
+ // Too short (too few trailing bytes).
+ EXPECT_FALSE(
+ UTF8ToUnicodeText("\xf0\x9f\x98\x61\x61", /*do_copy=*/false).is_valid());
+ // Invalid with context.
+ EXPECT_FALSE(
+ UTF8ToUnicodeText("hello \xf0\x9f\x98\x61\x61 world1", /*do_copy=*/false)
+ .is_valid());
+}
+
class IteratorTest : public UnicodeTextTest {};
TEST_F(IteratorTest, Iterates) {
diff --git a/util/utf8/unilib-icu.cc b/util/utf8/unilib-icu.cc
index b1eac2c..9e9ce19 100644
--- a/util/utf8/unilib-icu.cc
+++ b/util/utf8/unilib-icu.cc
@@ -66,7 +66,10 @@
UniLib::RegexMatcher::RegexMatcher(icu::RegexPattern* pattern,
icu::UnicodeString text)
- : pattern_(pattern), text_(std::move(text)) {
+ : text_(std::move(text)),
+ last_find_offset_(0),
+ last_find_offset_codepoints_(0),
+ last_find_offset_dirty_(true) {
UErrorCode status = U_ZERO_ERROR;
matcher_.reset(pattern->matcher(text_, status));
if (U_FAILURE(status)) {
@@ -125,6 +128,25 @@
return true;
}
+bool UniLib::RegexMatcher::UpdateLastFindOffset() const {
+ if (!last_find_offset_dirty_) {
+ return true;
+ }
+
+ // Update the position of the match.
+ UErrorCode icu_status = U_ZERO_ERROR;
+ const int find_offset = matcher_->start(0, icu_status);
+ if (U_FAILURE(icu_status)) {
+ return false;
+ }
+ last_find_offset_codepoints_ +=
+ text_.countChar32(last_find_offset_, find_offset - last_find_offset_);
+ last_find_offset_ = find_offset;
+ last_find_offset_dirty_ = false;
+
+ return true;
+}
+
bool UniLib::RegexMatcher::Find(int* status) {
if (!matcher_) {
*status = kError;
@@ -136,6 +158,8 @@
*status = kError;
return false;
}
+
+ last_find_offset_dirty_ = true;
*status = kNoError;
return result;
}
@@ -145,10 +169,11 @@
}
int UniLib::RegexMatcher::Start(int group_idx, int* status) const {
- if (!matcher_) {
+ if (!matcher_ || !UpdateLastFindOffset()) {
*status = kError;
return kError;
}
+
UErrorCode icu_status = U_ZERO_ERROR;
const int result = matcher_->start(group_idx, icu_status);
if (U_FAILURE(icu_status)) {
@@ -156,20 +181,16 @@
return kError;
}
*status = kNoError;
- return text_.countChar32(/*start=*/0, /*length=*/result);
-}
-int UniLib::RegexMatcher::Start(StringPiece group_name, int* status) const {
- UErrorCode icu_status = U_ZERO_ERROR;
- const int group_idx = pattern_->groupNumberFromName(
- icu::UnicodeString::fromUTF8(
- icu::StringPiece(group_name.data(), group_name.size())),
- icu_status);
- if (U_FAILURE(icu_status)) {
- *status = kError;
- return kError;
+ // If the group didn't participate in the match the result is -1 and is
+ // incompatible with the caching logic bellow.
+ if (result == -1) {
+ return -1;
}
- return Start(group_idx, status);
+
+ return last_find_offset_codepoints_ +
+ text_.countChar32(/*start=*/last_find_offset_,
+ /*length=*/result - last_find_offset_);
}
int UniLib::RegexMatcher::End(int* status) const {
@@ -177,7 +198,7 @@
}
int UniLib::RegexMatcher::End(int group_idx, int* status) const {
- if (!matcher_) {
+ if (!matcher_ || !UpdateLastFindOffset()) {
*status = kError;
return kError;
}
@@ -188,20 +209,16 @@
return kError;
}
*status = kNoError;
- return text_.countChar32(/*start=*/0, /*length=*/result);
-}
-int UniLib::RegexMatcher::End(StringPiece group_name, int* status) const {
- UErrorCode icu_status = U_ZERO_ERROR;
- const int group_idx = pattern_->groupNumberFromName(
- icu::UnicodeString::fromUTF8(
- icu::StringPiece(group_name.data(), group_name.size())),
- icu_status);
- if (U_FAILURE(icu_status)) {
- *status = kError;
- return kError;
+ // If the group didn't participate in the match the result is -1 and is
+ // incompatible with the caching logic bellow.
+ if (result == -1) {
+ return -1;
}
- return End(group_idx, status);
+
+ return last_find_offset_codepoints_ +
+ text_.countChar32(/*start=*/last_find_offset_,
+ /*length=*/result - last_find_offset_);
}
UnicodeText UniLib::RegexMatcher::Group(int* status) const {
@@ -225,20 +242,6 @@
return UTF8ToUnicodeText(result, /*do_copy=*/true);
}
-UnicodeText UniLib::RegexMatcher::Group(StringPiece group_name,
- int* status) const {
- UErrorCode icu_status = U_ZERO_ERROR;
- const int group_idx = pattern_->groupNumberFromName(
- icu::UnicodeString::fromUTF8(
- icu::StringPiece(group_name.data(), group_name.size())),
- icu_status);
- if (U_FAILURE(icu_status)) {
- *status = kError;
- return UTF8ToUnicodeText("", /*do_copy=*/false);
- }
- return Group(group_idx, status);
-}
-
constexpr int UniLib::BreakIterator::kDone;
UniLib::BreakIterator::BreakIterator(const UnicodeText& text)
diff --git a/util/utf8/unilib-icu.h b/util/utf8/unilib-icu.h
index 9488a00..8983756 100644
--- a/util/utf8/unilib-icu.h
+++ b/util/utf8/unilib-icu.h
@@ -23,7 +23,6 @@
#include <memory>
#include "util/base/integral_types.h"
-#include "util/strings/stringpiece.h"
#include "util/utf8/unicodetext.h"
#include "unicode/brkiter.h"
#include "unicode/errorcode.h"
@@ -81,9 +80,6 @@
// was not called previously.
int Start(int group_idx, int* status) const;
- // Same as above but uses the group name instead of the index.
- int Start(StringPiece group_name, int* status) const;
-
// Gets the end offset of the last match (from 'Find').
// Sets status to 'kError' if 'Find'
// was not called previously.
@@ -95,9 +91,6 @@
// was not called previously.
int End(int group_idx, int* status) const;
- // Same as above but uses the group name instead of the index.
- int End(StringPiece group_name, int* status) const;
-
// Gets the text of the last match (from 'Find').
// Sets status to 'kError' if 'Find' was not called previously.
UnicodeText Group(int* status) const;
@@ -107,19 +100,18 @@
// was not called previously.
UnicodeText Group(int group_idx, int* status) const;
- // Gets the text of the specified group of the last match (from 'Find').
- // Sets status to 'kError' if an invalid group was specified or if 'Find'
- // was not called previously.
- UnicodeText Group(StringPiece group_name, int* status) const;
-
protected:
friend class RegexPattern;
explicit RegexMatcher(icu::RegexPattern* pattern, icu::UnicodeString text);
private:
+ bool UpdateLastFindOffset() const;
+
std::unique_ptr<icu::RegexMatcher> matcher_;
- icu::RegexPattern* pattern_;
icu::UnicodeString text_;
+ mutable int last_find_offset_;
+ mutable int last_find_offset_codepoints_;
+ mutable bool last_find_offset_dirty_;
};
class RegexPattern {
diff --git a/util/utf8/unilib_test.cc b/util/utf8/unilib_test.cc
index 665bfec..13b1347 100644
--- a/util/utf8/unilib_test.cc
+++ b/util/utf8/unilib_test.cc
@@ -82,9 +82,7 @@
TC_LOG(INFO) << matcher->Matches(&status);
TC_LOG(INFO) << matcher->Find(&status);
TC_LOG(INFO) << matcher->Start(0, &status);
- TC_LOG(INFO) << matcher->Start("group_name", &status);
TC_LOG(INFO) << matcher->End(0, &status);
- TC_LOG(INFO) << matcher->End("group_name", &status);
TC_LOG(INFO) << matcher->Group(0, &status).size_codepoints();
}
@@ -151,22 +149,14 @@
EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
EXPECT_EQ(matcher->Start(1, &status), 8);
EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Start("group1", &status), 8);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
EXPECT_EQ(matcher->Start(2, &status), 9);
EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Start("group2", &status), 9);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
EXPECT_EQ(matcher->End(0, &status), 13);
EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
EXPECT_EQ(matcher->End(1, &status), 9);
EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->End("group1", &status), 9);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
EXPECT_EQ(matcher->End(2, &status), 12);
EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->End("group2", &status), 12);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123π");
EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
EXPECT_EQ(matcher->Group(1, &status).ToUTF8String(), "0");
diff --git a/zlib-utils.cc b/zlib-utils.cc
new file mode 100644
index 0000000..8650c9c
--- /dev/null
+++ b/zlib-utils.cc
@@ -0,0 +1,198 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "zlib-utils.h"
+
+#include <memory>
+
+#include "util/base/logging.h"
+
+namespace libtextclassifier2 {
+
+std::unique_ptr<ZlibDecompressor> ZlibDecompressor::Instance() {
+ std::unique_ptr<ZlibDecompressor> result(new ZlibDecompressor());
+ if (!result->initialized_) {
+ result.reset();
+ }
+ return result;
+}
+
+ZlibDecompressor::ZlibDecompressor() {
+ memset(&stream_, 0, sizeof(stream_));
+ stream_.zalloc = Z_NULL;
+ stream_.zfree = Z_NULL;
+ initialized_ = (inflateInit(&stream_) == Z_OK);
+}
+
+ZlibDecompressor::~ZlibDecompressor() {
+ if (initialized_) {
+ inflateEnd(&stream_);
+ }
+}
+
+bool ZlibDecompressor::Decompress(const CompressedBuffer* compressed_buffer,
+ std::string* out) {
+ out->resize(compressed_buffer->uncompressed_size());
+ stream_.next_in =
+ reinterpret_cast<const Bytef*>(compressed_buffer->buffer()->Data());
+ stream_.avail_in = compressed_buffer->buffer()->Length();
+ stream_.next_out = reinterpret_cast<Bytef*>(const_cast<char*>(out->c_str()));
+ stream_.avail_out = compressed_buffer->uncompressed_size();
+ return (inflate(&stream_, Z_SYNC_FLUSH) == Z_OK);
+}
+
+std::unique_ptr<ZlibCompressor> ZlibCompressor::Instance() {
+ std::unique_ptr<ZlibCompressor> result(new ZlibCompressor());
+ if (!result->initialized_) {
+ result.reset();
+ }
+ return result;
+}
+
+ZlibCompressor::ZlibCompressor(int level, int tmp_buffer_size) {
+ memset(&stream_, 0, sizeof(stream_));
+ stream_.zalloc = Z_NULL;
+ stream_.zfree = Z_NULL;
+ buffer_size_ = tmp_buffer_size;
+ buffer_.reset(new Bytef[buffer_size_]);
+ initialized_ = (deflateInit(&stream_, level) == Z_OK);
+}
+
+ZlibCompressor::~ZlibCompressor() { deflateEnd(&stream_); }
+
+void ZlibCompressor::Compress(const std::string& uncompressed_content,
+ CompressedBufferT* out) {
+ out->uncompressed_size = uncompressed_content.size();
+ out->buffer.clear();
+ stream_.next_in =
+ reinterpret_cast<const Bytef*>(uncompressed_content.c_str());
+ stream_.avail_in = uncompressed_content.size();
+ stream_.next_out = buffer_.get();
+ stream_.avail_out = buffer_size_;
+ unsigned char* buffer_deflate_start_position =
+ reinterpret_cast<unsigned char*>(buffer_.get());
+ int status;
+ do {
+ // Deflate chunk-wise.
+ // Z_SYNC_FLUSH causes all pending output to be flushed, but doesn't
+ // reset the compression state.
+ // As we do not know how big the compressed buffer will be, we compress
+ // chunk wise and append the flushed content to the output string buffer.
+ // As we store the uncompressed size, we do not have to do this during
+ // decompression.
+ status = deflate(&stream_, Z_SYNC_FLUSH);
+ unsigned char* buffer_deflate_end_position =
+ reinterpret_cast<unsigned char*>(stream_.next_out);
+ if (buffer_deflate_end_position != buffer_deflate_start_position) {
+ out->buffer.insert(out->buffer.end(), buffer_deflate_start_position,
+ buffer_deflate_end_position);
+ stream_.next_out = buffer_deflate_start_position;
+ stream_.avail_out = buffer_size_;
+ } else {
+ break;
+ }
+ } while (status == Z_OK);
+}
+
+// Compress rule fields in the model.
+bool CompressModel(ModelT* model) {
+ std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance();
+ if (!zlib_compressor) {
+ TC_LOG(ERROR) << "Cannot compress model.";
+ return false;
+ }
+
+ // Compress regex rules.
+ if (model->regex_model != nullptr) {
+ for (int i = 0; i < model->regex_model->patterns.size(); i++) {
+ RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get();
+ pattern->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(pattern->pattern,
+ pattern->compressed_pattern.get());
+ pattern->pattern.clear();
+ }
+ }
+
+ // Compress date-time rules.
+ if (model->datetime_model != nullptr) {
+ for (int i = 0; i < model->datetime_model->patterns.size(); i++) {
+ DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get();
+ for (int j = 0; j < pattern->regexes.size(); j++) {
+ DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get();
+ regex->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(regex->pattern,
+ regex->compressed_pattern.get());
+ regex->pattern.clear();
+ }
+ }
+ for (int i = 0; i < model->datetime_model->extractors.size(); i++) {
+ DatetimeModelExtractorT* extractor =
+ model->datetime_model->extractors[i].get();
+ extractor->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(extractor->pattern,
+ extractor->compressed_pattern.get());
+ extractor->pattern.clear();
+ }
+ }
+ return true;
+}
+
+std::string CompressSerializedModel(const std::string& model) {
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str());
+ TC_CHECK(unpacked_model != nullptr);
+ TC_CHECK(CompressModel(unpacked_model.get()));
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern(
+ const UniLib& unilib, const flatbuffers::String* uncompressed_pattern,
+ const CompressedBuffer* compressed_pattern,
+ ZlibDecompressor* decompressor) {
+ UnicodeText unicode_regex_pattern;
+ std::string decompressed_pattern;
+ if (compressed_pattern != nullptr &&
+ compressed_pattern->buffer() != nullptr) {
+ if (decompressor == nullptr ||
+ !decompressor->Decompress(compressed_pattern, &decompressed_pattern)) {
+ TC_LOG(ERROR) << "Cannot decompress pattern.";
+ return nullptr;
+ }
+ unicode_regex_pattern =
+ UTF8ToUnicodeText(decompressed_pattern.data(),
+ decompressed_pattern.size(), /*do_copy=*/false);
+ } else {
+ if (uncompressed_pattern == nullptr) {
+ TC_LOG(ERROR) << "Cannot load uncompressed pattern.";
+ return nullptr;
+ }
+ unicode_regex_pattern =
+ UTF8ToUnicodeText(uncompressed_pattern->c_str(),
+ uncompressed_pattern->Length(), /*do_copy=*/false);
+ }
+
+ std::unique_ptr<UniLib::RegexPattern> regex_pattern =
+ unilib.CreateRegexPattern(unicode_regex_pattern);
+ if (!regex_pattern) {
+ TC_LOG(ERROR) << "Could not create pattern: "
+ << unicode_regex_pattern.ToUTF8String();
+ }
+ return regex_pattern;
+}
+
+} // namespace libtextclassifier2
diff --git a/zlib-utils.h b/zlib-utils.h
new file mode 100644
index 0000000..d79f76e
--- /dev/null
+++ b/zlib-utils.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Functions to compress and decompress low entropy entries in the model.
+
+#ifndef LIBTEXTCLASSIFIER_ZLIB_UTILS_H_
+#define LIBTEXTCLASSIFIER_ZLIB_UTILS_H_
+
+#include <memory>
+
+#include "model_generated.h"
+#include "util/utf8/unilib.h"
+#include "zlib.h"
+
+namespace libtextclassifier2 {
+
+class ZlibDecompressor {
+ public:
+ static std::unique_ptr<ZlibDecompressor> Instance();
+ ~ZlibDecompressor();
+
+ bool Decompress(const CompressedBuffer* compressed_buffer, std::string* out);
+
+ private:
+ ZlibDecompressor();
+ z_stream stream_;
+ bool initialized_;
+};
+
+class ZlibCompressor {
+ public:
+ static std::unique_ptr<ZlibCompressor> Instance();
+ ~ZlibCompressor();
+
+ void Compress(const std::string& uncompressed_content,
+ CompressedBufferT* out);
+
+ private:
+ explicit ZlibCompressor(int level = Z_BEST_COMPRESSION,
+ // Tmp. buffer size was set based on the current set
+ // of patterns to be compressed.
+ int tmp_buffer_size = 64 * 1024);
+ z_stream stream_;
+ std::unique_ptr<Bytef[]> buffer_;
+ unsigned int buffer_size_;
+ bool initialized_;
+};
+
+// Compresses regex and datetime rules in the model in place.
+bool CompressModel(ModelT* model);
+
+// Compresses regex and datetime rules in the model.
+std::string CompressSerializedModel(const std::string& model);
+
+// Create and compile a regex pattern from optionally compressed pattern.
+std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern(
+ const UniLib& unilib, const flatbuffers::String* uncompressed_pattern,
+ const CompressedBuffer* compressed_pattern, ZlibDecompressor* decompressor);
+
+} // namespace libtextclassifier2
+
+#endif // LIBTEXTCLASSIFIER_ZLIB_UTILS_H_
diff --git a/zlib-utils_test.cc b/zlib-utils_test.cc
new file mode 100644
index 0000000..d3b5a19
--- /dev/null
+++ b/zlib-utils_test.cc
@@ -0,0 +1,89 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "zlib-utils.h"
+
+#include <memory>
+
+#include "model_generated.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier2 {
+
+TEST(ZlibUtilsTest, CompressModel) {
+ ModelT model;
+ model.regex_model.reset(new RegexModelT);
+ model.regex_model->patterns.emplace_back(new RegexModel_::PatternT);
+ model.regex_model->patterns.back()->pattern = "this is a test pattern";
+ model.regex_model->patterns.emplace_back(new RegexModel_::PatternT);
+ model.regex_model->patterns.back()->pattern = "this is a second test pattern";
+
+ model.datetime_model.reset(new DatetimeModelT);
+ model.datetime_model->patterns.emplace_back(new DatetimeModelPatternT);
+ model.datetime_model->patterns.back()->regexes.emplace_back(
+ new DatetimeModelPattern_::RegexT);
+ model.datetime_model->patterns.back()->regexes.back()->pattern =
+ "an example datetime pattern";
+ model.datetime_model->extractors.emplace_back(new DatetimeModelExtractorT);
+ model.datetime_model->extractors.back()->pattern =
+ "an example datetime extractor";
+
+ // Compress the model.
+ EXPECT_TRUE(CompressModel(&model));
+
+ // Sanity check that uncompressed field is removed.
+ EXPECT_TRUE(model.regex_model->patterns[0]->pattern.empty());
+ EXPECT_TRUE(model.regex_model->patterns[1]->pattern.empty());
+ EXPECT_TRUE(model.datetime_model->patterns[0]->regexes[0]->pattern.empty());
+ EXPECT_TRUE(model.datetime_model->extractors[0]->pattern.empty());
+
+ // Pack and load the model.
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, &model));
+ const Model* compressed_model =
+ GetModel(reinterpret_cast<const char*>(builder.GetBufferPointer()));
+ ASSERT_TRUE(compressed_model != nullptr);
+
+ // Decompress the fields again and check that they match the original.
+ std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
+ ASSERT_TRUE(decompressor != nullptr);
+ std::string uncompressed_pattern;
+ EXPECT_TRUE(decompressor->Decompress(
+ compressed_model->regex_model()->patterns()->Get(0)->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, "this is a test pattern");
+ EXPECT_TRUE(decompressor->Decompress(
+ compressed_model->regex_model()->patterns()->Get(1)->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, "this is a second test pattern");
+ EXPECT_TRUE(decompressor->Decompress(compressed_model->datetime_model()
+ ->patterns()
+ ->Get(0)
+ ->regexes()
+ ->Get(0)
+ ->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, "an example datetime pattern");
+ EXPECT_TRUE(decompressor->Decompress(compressed_model->datetime_model()
+ ->extractors()
+ ->Get(0)
+ ->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, "an example datetime extractor");
+}
+
+} // namespace libtextclassifier2