Fixes UTF-8 handling, the datetime model, and the flight number model; makes
models smaller by compressing the regex rules; and adds i18n models.
(sync from google3)
Test: bit FrameworksCoreTests:android.view.textclassifier.TextClassificationManagerTest
Test: bit CtsViewTestCases:android.view.textclassifier.cts.TextClassificationManagerTest
Bug: 64929062
Bug: 77223425
Change-Id: I04472e6b247e824bf2b745077c50fcde4269aefc
diff --git a/datetime/extractor.cc b/datetime/extractor.cc
index 8c6c3ff..f4ab8f4 100644
--- a/datetime/extractor.cc
+++ b/datetime/extractor.cc
@@ -20,156 +20,119 @@
namespace libtextclassifier2 {
-constexpr char const* kGroupYear = "YEAR";
-constexpr char const* kGroupMonth = "MONTH";
-constexpr char const* kGroupDay = "DAY";
-constexpr char const* kGroupHour = "HOUR";
-constexpr char const* kGroupMinute = "MINUTE";
-constexpr char const* kGroupSecond = "SECOND";
-constexpr char const* kGroupAmpm = "AMPM";
-constexpr char const* kGroupRelationDistance = "RELATIONDISTANCE";
-constexpr char const* kGroupRelation = "RELATION";
-constexpr char const* kGroupRelationType = "RELATIONTYPE";
-// Dummy groups serve just as an inflator of the selection. E.g. we might want
-// to select more text than was contained in an envelope of all extractor spans.
-constexpr char const* kGroupDummy1 = "DUMMY1";
-constexpr char const* kGroupDummy2 = "DUMMY2";
-
bool DatetimeExtractor::Extract(DateParseData* result,
CodepointSpan* result_span) const {
result->field_set_mask = 0;
*result_span = {kInvalidIndex, kInvalidIndex};
- UnicodeText group_text;
- if (GroupNotEmpty(kGroupYear, &group_text)) {
- result->field_set_mask |= DateParseData::YEAR_FIELD;
- if (!ParseYear(group_text, &(result->year))) {
- TC_LOG(ERROR) << "Couldn't extract YEAR.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupYear, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
+ if (rule_.regex->groups() == nullptr) {
+ return false;
}
- if (GroupNotEmpty(kGroupMonth, &group_text)) {
- result->field_set_mask |= DateParseData::MONTH_FIELD;
- if (!ParseMonth(group_text, &(result->month))) {
- TC_LOG(ERROR) << "Couldn't extract MONTH.";
+ for (int group_id = 0; group_id < rule_.regex->groups()->size(); group_id++) {
+ UnicodeText group_text;
+ const int group_type = rule_.regex->groups()->Get(group_id);
+ if (group_type == DatetimeGroupType_GROUP_UNUSED) {
+ continue;
+ }
+ if (!GroupTextFromMatch(group_id, &group_text)) {
+ TC_LOG(ERROR) << "Couldn't retrieve group.";
return false;
}
- if (!UpdateMatchSpan(kGroupMonth, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
+ // The pattern can have a group defined in a part that was not matched,
+ // e.g. an optional part. In this case we'll get an empty content here.
+ if (group_text.empty()) {
+ continue;
}
- }
-
- if (GroupNotEmpty(kGroupDay, &group_text)) {
- result->field_set_mask |= DateParseData::DAY_FIELD;
- if (!ParseDigits(group_text, &(result->day_of_month))) {
- TC_LOG(ERROR) << "Couldn't extract DAY.";
- return false;
+ switch (group_type) {
+ case DatetimeGroupType_GROUP_YEAR: {
+ if (!ParseYear(group_text, &(result->year))) {
+ TC_LOG(ERROR) << "Couldn't extract YEAR.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::YEAR_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_MONTH: {
+ if (!ParseMonth(group_text, &(result->month))) {
+ TC_LOG(ERROR) << "Couldn't extract MONTH.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::MONTH_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_DAY: {
+ if (!ParseDigits(group_text, &(result->day_of_month))) {
+ TC_LOG(ERROR) << "Couldn't extract DAY.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::DAY_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_HOUR: {
+ if (!ParseDigits(group_text, &(result->hour))) {
+ TC_LOG(ERROR) << "Couldn't extract HOUR.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::HOUR_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_MINUTE: {
+ if (!ParseDigits(group_text, &(result->minute))) {
+ TC_LOG(ERROR) << "Couldn't extract MINUTE.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::MINUTE_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_SECOND: {
+ if (!ParseDigits(group_text, &(result->second))) {
+ TC_LOG(ERROR) << "Couldn't extract SECOND.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::SECOND_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_AMPM: {
+ if (!ParseAMPM(group_text, &(result->ampm))) {
+ TC_LOG(ERROR) << "Couldn't extract AMPM.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::AMPM_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_RELATIONDISTANCE: {
+ if (!ParseRelationDistance(group_text, &(result->relation_distance))) {
+ TC_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::RELATION_DISTANCE_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_RELATION: {
+ if (!ParseRelation(group_text, &(result->relation))) {
+ TC_LOG(ERROR) << "Couldn't extract RELATION_FIELD.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::RELATION_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_RELATIONTYPE: {
+ if (!ParseRelationType(group_text, &(result->relation_type))) {
+ TC_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
+ return false;
+ }
+ result->field_set_mask |= DateParseData::RELATION_TYPE_FIELD;
+ break;
+ }
+ case DatetimeGroupType_GROUP_DUMMY1:
+ case DatetimeGroupType_GROUP_DUMMY2:
+ break;
+ default:
+ TC_LOG(INFO) << "Unknown group type.";
+ continue;
}
- if (!UpdateMatchSpan(kGroupDay, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupHour, &group_text)) {
- result->field_set_mask |= DateParseData::HOUR_FIELD;
- if (!ParseDigits(group_text, &(result->hour))) {
- TC_LOG(ERROR) << "Couldn't extract HOUR.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupHour, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupMinute, &group_text)) {
- result->field_set_mask |= DateParseData::MINUTE_FIELD;
- if (!ParseDigits(group_text, &(result->minute))) {
- TC_LOG(ERROR) << "Couldn't extract MINUTE.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupMinute, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupSecond, &group_text)) {
- result->field_set_mask |= DateParseData::SECOND_FIELD;
- if (!ParseDigits(group_text, &(result->second))) {
- TC_LOG(ERROR) << "Couldn't extract SECOND.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupSecond, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupAmpm, &group_text)) {
- result->field_set_mask |= DateParseData::AMPM_FIELD;
- if (!ParseAMPM(group_text, &(result->ampm))) {
- TC_LOG(ERROR) << "Couldn't extract AMPM.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupAmpm, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupRelationDistance, &group_text)) {
- result->field_set_mask |= DateParseData::RELATION_DISTANCE_FIELD;
- if (!ParseRelationDistance(group_text, &(result->relation_distance))) {
- TC_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupRelationDistance, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupRelation, &group_text)) {
- result->field_set_mask |= DateParseData::RELATION_FIELD;
- if (!ParseRelation(group_text, &(result->relation))) {
- TC_LOG(ERROR) << "Couldn't extract RELATION_FIELD.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupRelation, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupRelationType, &group_text)) {
- result->field_set_mask |= DateParseData::RELATION_TYPE_FIELD;
- if (!ParseRelationType(group_text, &(result->relation_type))) {
- TC_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
- return false;
- }
- if (!UpdateMatchSpan(kGroupRelationType, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupDummy1, &group_text)) {
- if (!UpdateMatchSpan(kGroupDummy1, result_span)) {
- TC_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (GroupNotEmpty(kGroupDummy2, &group_text)) {
- if (!UpdateMatchSpan(kGroupDummy2, result_span)) {
+ if (!UpdateMatchSpan(group_id, result_span)) {
TC_LOG(ERROR) << "Couldn't update span.";
return false;
}
@@ -226,24 +189,24 @@
return true;
}
-bool DatetimeExtractor::GroupNotEmpty(StringPiece name,
- UnicodeText* result) const {
+bool DatetimeExtractor::GroupTextFromMatch(int group_id,
+ UnicodeText* result) const {
int status;
- *result = matcher_.Group(name, &status);
+ *result = matcher_.Group(group_id, &status);
if (status != UniLib::RegexMatcher::kNoError) {
return false;
}
- return !result->empty();
+ return true;
}
-bool DatetimeExtractor::UpdateMatchSpan(StringPiece name,
+bool DatetimeExtractor::UpdateMatchSpan(int group_id,
CodepointSpan* span) const {
int status;
- const int match_start = matcher_.Start(name, &status);
+ const int match_start = matcher_.Start(group_id, &status);
if (status != UniLib::RegexMatcher::kNoError) {
return false;
}
- const int match_end = matcher_.End(name, &status);
+ const int match_end = matcher_.End(group_id, &status);
if (status != UniLib::RegexMatcher::kNoError) {
return false;
}
diff --git a/datetime/extractor.h b/datetime/extractor.h
index ceeb9cf..5c36ec4 100644
--- a/datetime/extractor.h
+++ b/datetime/extractor.h
@@ -29,18 +29,31 @@
namespace libtextclassifier2 {
+struct CompiledRule {
+ // The compiled regular expression.
+ std::unique_ptr<const UniLib::RegexPattern> compiled_regex;
+
+ // The uncompiled pattern and information about the pattern groups.
+ const DatetimeModelPattern_::Regex* regex;
+
+ // DatetimeModelPattern which 'regex' is part of and comes from.
+ const DatetimeModelPattern* pattern;
+};
+
// A helper class for DatetimeParser that extracts structured data
// (DateParseDate) from the current match of the passed RegexMatcher.
class DatetimeExtractor {
public:
DatetimeExtractor(
- const UniLib::RegexMatcher& matcher, int locale_id, const UniLib& unilib,
+ const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+ int locale_id, const UniLib& unilib,
const std::vector<std::unique_ptr<const UniLib::RegexPattern>>&
extractor_rules,
const std::unordered_map<DatetimeExtractorType,
std::unordered_map<int, int>>&
type_and_locale_to_extractor_rule)
- : matcher_(matcher),
+ : rule_(rule),
+ matcher_(matcher),
locale_id_(locale_id),
unilib_(unilib),
rules_(extractor_rules),
@@ -57,10 +70,10 @@
DatetimeExtractorType extractor_type,
UnicodeText* match_result = nullptr) const;
- bool GroupNotEmpty(StringPiece name, UnicodeText* result) const;
+ bool GroupTextFromMatch(int group_id, UnicodeText* result) const;
// Updates the span to include the current match for the given group.
- bool UpdateMatchSpan(StringPiece group_name, CodepointSpan* span) const;
+ bool UpdateMatchSpan(int group_id, CodepointSpan* span) const;
// Returns true if any of the extractors from 'mapping' matched. If it did,
// will fill 'result' with the associated value from 'mapping'.
@@ -84,6 +97,7 @@
DateParseData::RelationType* parsed_relation_type) const;
bool ParseWeekday(const UnicodeText& input, int* parsed_weekday) const;
+ const CompiledRule& rule_;
const UniLib::RegexMatcher& matcher_;
int locale_id_;
const UniLib& unilib_;
diff --git a/datetime/parser.cc b/datetime/parser.cc
index 8ad3d33..e9a6eb1 100644
--- a/datetime/parser.cc
+++ b/datetime/parser.cc
@@ -21,61 +21,78 @@
#include "datetime/extractor.h"
#include "util/calendar/calendar.h"
+#include "util/i18n/locale.h"
#include "util/strings/split.h"
namespace libtextclassifier2 {
std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
- const DatetimeModel* model, const UniLib& unilib) {
- std::unique_ptr<DatetimeParser> result(new DatetimeParser(model, unilib));
+ const DatetimeModel* model, const UniLib& unilib,
+ ZlibDecompressor* decompressor) {
+ std::unique_ptr<DatetimeParser> result(
+ new DatetimeParser(model, unilib, decompressor));
if (!result->initialized_) {
result.reset();
}
return result;
}
-DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib)
+DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
+ ZlibDecompressor* decompressor)
: unilib_(unilib) {
initialized_ = false;
- for (int i = 0; i < model->patterns()->Length(); ++i) {
- const DatetimeModelPattern* pattern = model->patterns()->Get(i);
- for (int j = 0; j < pattern->regexes()->Length(); ++j) {
+
+ if (model == nullptr) {
+ return;
+ }
+
+ if (model->patterns() != nullptr) {
+ for (const DatetimeModelPattern* pattern : *model->patterns()) {
+ if (pattern->regexes()) {
+ for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
+ std::unique_ptr<UniLib::RegexPattern> regex_pattern =
+ UncompressMakeRegexPattern(unilib, regex->pattern(),
+ regex->compressed_pattern(),
+ decompressor);
+ if (!regex_pattern) {
+ TC_LOG(ERROR) << "Couldn't create rule pattern.";
+ return;
+ }
+ rules_.push_back({std::move(regex_pattern), regex, pattern});
+ if (pattern->locales()) {
+ for (int locale : *pattern->locales()) {
+ locale_to_rules_[locale].push_back(rules_.size() - 1);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if (model->extractors() != nullptr) {
+ for (const DatetimeModelExtractor* extractor : *model->extractors()) {
std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- unilib.CreateRegexPattern(UTF8ToUnicodeText(
- pattern->regexes()->Get(j)->str(), /*do_copy=*/false));
+ UncompressMakeRegexPattern(unilib, extractor->pattern(),
+ extractor->compressed_pattern(),
+ decompressor);
if (!regex_pattern) {
- TC_LOG(ERROR) << "Couldn't create pattern: "
- << pattern->regexes()->Get(j)->str();
+ TC_LOG(ERROR) << "Couldn't create extractor pattern";
return;
}
- rules_.push_back(std::move(regex_pattern));
- rule_id_to_pattern_.push_back(pattern);
- for (int k = 0; k < pattern->locales()->Length(); ++k) {
- locale_to_rules_[pattern->locales()->Get(k)].push_back(rules_.size() -
- 1);
+ extractor_rules_.push_back(std::move(regex_pattern));
+
+ if (extractor->locales()) {
+ for (int locale : *extractor->locales()) {
+ type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
+ extractor_rules_.size() - 1;
+ }
}
}
}
- for (int i = 0; i < model->extractors()->Length(); ++i) {
- const DatetimeModelExtractor* extractor = model->extractors()->Get(i);
- std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- unilib.CreateRegexPattern(
- UTF8ToUnicodeText(extractor->pattern()->str(), /*do_copy=*/false));
- if (!regex_pattern) {
- TC_LOG(ERROR) << "Couldn't create pattern: "
- << extractor->pattern()->str();
- return;
+ if (model->locales() != nullptr) {
+ for (int i = 0; i < model->locales()->Length(); ++i) {
+ locale_string_to_id_[model->locales()->Get(i)->str()] = i;
}
- extractor_rules_.push_back(std::move(regex_pattern));
-
- for (int j = 0; j < extractor->locales()->Length(); ++j) {
- type_and_locale_to_extractor_rule_[extractor->extractor()]
- [extractor->locales()->Get(j)] = i;
- }
- }
-
- for (int i = 0; i < model->locales()->Length(); ++i) {
- locale_string_to_id_[model->locales()->Get(i)->str()] = i;
}
use_extractors_for_locating_ = model->use_extractors_for_locating();
@@ -86,19 +103,21 @@
bool DatetimeParser::Parse(
const std::string& input, const int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, std::vector<DatetimeParseResultSpan>* results) const {
+ ModeFlag mode, bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* results) const {
return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
reference_time_ms_utc, reference_timezone, locales, mode,
- results);
+ anchor_start_end, results);
}
bool DatetimeParser::Parse(
const UnicodeText& input, const int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, std::vector<DatetimeParseResultSpan>* results) const {
+ ModeFlag mode, bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* results) const {
std::vector<DatetimeParseResultSpan> found_spans;
std::unordered_set<int> executed_rules;
- for (const int locale_id : ParseLocales(locales)) {
+ for (const int locale_id : ParseAndExpandLocales(locales)) {
auto rules_it = locale_to_rules_.find(locale_id);
if (rules_it == locale_to_rules_.end()) {
continue;
@@ -110,26 +129,45 @@
continue;
}
- if (!(rule_id_to_pattern_[rule_id]->enabled_modes() & mode)) {
+ if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
continue;
}
executed_rules.insert(rule_id);
- if (!ParseWithRule(*rules_[rule_id], rule_id_to_pattern_[rule_id], input,
- reference_time_ms_utc, reference_timezone, locale_id,
+ if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
+ reference_timezone, locale_id, anchor_start_end,
&found_spans)) {
return false;
}
}
}
- // Resolve conflicts by always picking the longer span.
- std::sort(
- found_spans.begin(), found_spans.end(),
- [](const DatetimeParseResultSpan& a, const DatetimeParseResultSpan& b) {
- return (a.span.second - a.span.first) > (b.span.second - b.span.first);
- });
+ std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
+ int counter = 0;
+ for (const auto& found_span : found_spans) {
+ indexed_found_spans.push_back({found_span, counter});
+ counter++;
+ }
+
+ // Resolve conflicts by always picking the longer span and breaking ties by
+ // selecting the earlier entry in the list for a given locale.
+ std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
+ [](const std::pair<DatetimeParseResultSpan, int>& a,
+ const std::pair<DatetimeParseResultSpan, int>& b) {
+ if ((a.first.span.second - a.first.span.first) !=
+ (b.first.span.second - b.first.span.first)) {
+ return (a.first.span.second - a.first.span.first) >
+ (b.first.span.second - b.first.span.first);
+ } else {
+ return a.second < b.second;
+ }
+ });
+
+ found_spans.clear();
+ for (auto& span_index_pair : indexed_found_spans) {
+ found_spans.push_back(span_index_pair.first);
+ }
std::set<int, std::function<bool(int, int)>> chosen_indices_set(
[&found_spans](int a, int b) {
@@ -145,60 +183,119 @@
return true;
}
-bool DatetimeParser::ParseWithRule(
- const UniLib::RegexPattern& regex, const DatetimeModelPattern* pattern,
- const UnicodeText& input, const int64 reference_time_ms_utc,
- const std::string& reference_timezone, const int locale_id,
- std::vector<DatetimeParseResultSpan>* result) const {
- std::unique_ptr<UniLib::RegexMatcher> matcher = regex.Matcher(input);
-
+bool DatetimeParser::HandleParseMatch(
+ const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc, const std::string& reference_timezone,
+ int locale_id, std::vector<DatetimeParseResultSpan>* result) const {
int status = UniLib::RegexMatcher::kNoError;
- while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
- const int start = matcher->Start(&status);
- if (status != UniLib::RegexMatcher::kNoError) {
- return false;
- }
+ const int start = matcher.Start(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
- const int end = matcher->End(&status);
- if (status != UniLib::RegexMatcher::kNoError) {
- return false;
- }
+ const int end = matcher.End(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
- DatetimeParseResultSpan parse_result;
- if (!ExtractDatetime(*matcher, reference_time_ms_utc, reference_timezone,
- locale_id, &(parse_result.data), &parse_result.span)) {
- return false;
- }
- if (!use_extractors_for_locating_) {
- parse_result.span = {start, end};
- }
+ DatetimeParseResultSpan parse_result;
+ if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
+ locale_id, &(parse_result.data), &parse_result.span)) {
+ return false;
+ }
+ if (!use_extractors_for_locating_) {
+ parse_result.span = {start, end};
+ }
+ if (parse_result.span.first != kInvalidIndex &&
+ parse_result.span.second != kInvalidIndex) {
parse_result.target_classification_score =
- pattern->target_classification_score();
- parse_result.priority_score = pattern->priority_score();
-
+ rule.pattern->target_classification_score();
+ parse_result.priority_score = rule.pattern->priority_score();
result->push_back(parse_result);
}
return true;
}
+bool DatetimeParser::ParseWithRule(
+ const CompiledRule& rule, const UnicodeText& input,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const int locale_id, bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* result) const {
+ std::unique_ptr<UniLib::RegexMatcher> matcher =
+ rule.compiled_regex->Matcher(input);
+ int status = UniLib::RegexMatcher::kNoError;
+ if (anchor_start_end) {
+ if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
+ if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, locale_id, result)) {
+ return false;
+ }
+ }
+ } else {
+ while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+ if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, locale_id, result)) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
constexpr char const* kDefaultLocale = "";
-std::vector<int> DatetimeParser::ParseLocales(
+std::vector<int> DatetimeParser::ParseAndExpandLocales(
const std::string& locales) const {
- std::vector<std::string> split_locales = strings::Split(locales, ',');
-
- // Add a default fallback locale to the end of the list.
- split_locales.push_back(kDefaultLocale);
+ std::vector<StringPiece> split_locales = strings::Split(locales, ',');
std::vector<int> result;
- for (const std::string& locale : split_locales) {
- auto locale_it = locale_string_to_id_.find(locale);
- if (locale_it == locale_string_to_id_.end()) {
- TC_LOG(INFO) << "Ignoring locale: " << locale;
+ for (const StringPiece& locale_str : split_locales) {
+ auto locale_it = locale_string_to_id_.find(locale_str.ToString());
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+
+ const Locale locale = Locale::FromBCP47(locale_str.ToString());
+ if (!locale.IsValid()) {
continue;
}
- result.push_back(locale_it->second);
+
+ const std::string language = locale.Language();
+ const std::string script = locale.Script();
+ const std::string region = locale.Region();
+
+ // First, try adding language-script-* locale.
+ if (!script.empty()) {
+ locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
+ // Second, try adding language-* locale.
+ if (!language.empty()) {
+ locale_it = locale_string_to_id_.find(language + "-*");
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
+
+ // Third, try adding *-region locale.
+ if (!region.empty()) {
+ locale_it = locale_string_to_id_.find("*-" + region);
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
}
+
+ // Add a default fallback locale to the end of the list.
+ auto locale_it = locale_string_to_id_.find(kDefaultLocale);
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ } else {
+ TC_VLOG(1) << "Could not add default locale.";
+ }
+
return result;
}
@@ -250,13 +347,15 @@
} // namespace
-bool DatetimeParser::ExtractDatetime(const UniLib::RegexMatcher& matcher,
+bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
const int64 reference_time_ms_utc,
const std::string& reference_timezone,
int locale_id, DatetimeParseResult* result,
CodepointSpan* result_span) const {
DateParseData parse;
- DatetimeExtractor extractor(matcher, locale_id, unilib_, extractor_rules_,
+ DatetimeExtractor extractor(rule, matcher, locale_id, unilib_,
+ extractor_rules_,
type_and_locale_to_extractor_rule_);
if (!extractor.Extract(&parse, result_span)) {
return false;
diff --git a/datetime/parser.h b/datetime/parser.h
index 9f31142..c9d2119 100644
--- a/datetime/parser.h
+++ b/datetime/parser.h
@@ -22,11 +22,13 @@
#include <unordered_map>
#include <vector>
+#include "datetime/extractor.h"
#include "model_generated.h"
#include "types.h"
#include "util/base/integral_types.h"
#include "util/calendar/calendar.h"
#include "util/utf8/unilib.h"
+#include "zlib-utils.h"
namespace libtextclassifier2 {
@@ -34,46 +36,58 @@
// time.
class DatetimeParser {
public:
- static std::unique_ptr<DatetimeParser> Instance(const DatetimeModel* model,
- const UniLib& unilib);
+ static std::unique_ptr<DatetimeParser> Instance(
+ const DatetimeModel* model, const UniLib& unilib,
+ ZlibDecompressor* decompressor);
// Parses the dates in 'input' and fills result. Makes sure that the results
// do not overlap.
+ // If 'anchor_start_end' is true the extracted results need to start at the
+ // beginning of 'input' and end at the end of it.
bool Parse(const std::string& input, int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode,
+ ModeFlag mode, bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const;
// Same as above but takes UnicodeText.
bool Parse(const UnicodeText& input, int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode,
+ ModeFlag mode, bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const;
protected:
- DatetimeParser(const DatetimeModel* model, const UniLib& unilib);
+ DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
+ ZlibDecompressor* decompressor);
// Returns a list of locale ids for given locale spec string (comma-separated
// locale names).
- std::vector<int> ParseLocales(const std::string& locales) const;
- bool ParseWithRule(const UniLib::RegexPattern& regex,
- const DatetimeModelPattern* pattern,
- const UnicodeText& input, int64 reference_time_ms_utc,
+ std::vector<int> ParseAndExpandLocales(const std::string& locales) const;
+
+ bool ParseWithRule(const CompiledRule& rule, const UnicodeText& input,
+ int64 reference_time_ms_utc,
const std::string& reference_timezone, const int locale_id,
+ bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* result) const;
// Converts the current match in 'matcher' into DatetimeParseResult.
- bool ExtractDatetime(const UniLib::RegexMatcher& matcher,
+ bool ExtractDatetime(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
int64 reference_time_ms_utc,
const std::string& reference_timezone, int locale_id,
DatetimeParseResult* result,
CodepointSpan* result_span) const;
+ // Parse and extract information from current match in 'matcher'.
+ bool HandleParseMatch(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone, int locale_id,
+ std::vector<DatetimeParseResultSpan>* result) const;
+
private:
bool initialized_;
const UniLib& unilib_;
- std::vector<const DatetimeModelPattern*> rule_id_to_pattern_;
- std::vector<std::unique_ptr<const UniLib::RegexPattern>> rules_;
+ std::vector<CompiledRule> rules_;
std::unordered_map<int, std::vector<int>> locale_to_rules_;
std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_;
std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
diff --git a/datetime/parser_test.cc b/datetime/parser_test.cc
index 1df959f..36525e2 100644
--- a/datetime/parser_test.cc
+++ b/datetime/parser_test.cc
@@ -25,6 +25,7 @@
#include "datetime/parser.h"
#include "model_generated.h"
+#include "text-classifier.h"
#include "types-test-util.h"
using testing::ElementsAreArray;
@@ -54,15 +55,27 @@
public:
void SetUp() override {
model_buffer_ = ReadFile(GetModelPath() + "test_model.fb");
- const Model* model = GetModel(model_buffer_.data());
- ASSERT_TRUE(model != nullptr);
- ASSERT_TRUE(model->datetime_model() != nullptr);
- parser_ = DatetimeParser::Instance(model->datetime_model(), unilib_);
+ classifier_ = TextClassifier::FromUnownedBuffer(
+ model_buffer_.data(), model_buffer_.size(), &unilib_);
+ TC_CHECK(classifier_);
+ parser_ = classifier_->DatetimeParserForTests();
+ }
+
+ bool HasNoResult(const std::string& text, bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich") {
+ std::vector<DatetimeParseResultSpan> results;
+ if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
+ anchor_start_end, &results)) {
+ TC_LOG(ERROR) << text;
+ TC_CHECK(false);
+ }
+ return results.empty();
}
bool ParsesCorrectly(const std::string& marked_text,
const int64 expected_ms_utc,
DatetimeGranularity expected_granularity,
+ bool anchor_start_end = false,
const std::string& timezone = "Europe/Zurich") {
auto expected_start_index = marked_text.find("{");
EXPECT_TRUE(expected_start_index != std::string::npos);
@@ -80,7 +93,7 @@
std::vector<DatetimeParseResultSpan> results;
if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
- &results)) {
+ anchor_start_end, &results)) {
TC_LOG(ERROR) << text;
TC_CHECK(false);
}
@@ -115,7 +128,8 @@
protected:
std::string model_buffer_;
- std::unique_ptr<DatetimeParser> parser_;
+ std::unique_ptr<TextClassifier> classifier_;
+ const DatetimeParser* parser_;
UniLib unilib_;
};
@@ -129,7 +143,7 @@
TEST_F(ParserTest, Parse) {
EXPECT_TRUE(
ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectly("{1 2 2018}", 1517439600000, GRANULARITY_DAY));
+ EXPECT_TRUE(ParsesCorrectly("{1 2 2018}", 1514847600000, GRANULARITY_DAY));
EXPECT_TRUE(
ParsesCorrectly("{january 31 2018}", 1517353200000, GRANULARITY_DAY));
EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
@@ -205,6 +219,7 @@
EXPECT_TRUE(ParsesCorrectly("{today}", -3600000, GRANULARITY_DAY));
EXPECT_TRUE(ParsesCorrectly("{today}", -57600000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
"America/Los_Angeles"));
EXPECT_TRUE(ParsesCorrectly("{next week}", 255600000, GRANULARITY_WEEK));
EXPECT_TRUE(ParsesCorrectly("{next day}", 82800000, GRANULARITY_DAY));
@@ -224,7 +239,103 @@
EXPECT_TRUE(ParsesCorrectly("{three days ago}", -262800000, GRANULARITY_DAY));
}
-// TODO(zilka): Add a test that tests multiple locales.
+TEST_F(ParserTest, ParseWithAnchor) {
+ EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000,
+ GRANULARITY_DAY, /*anchor_start_end=*/false));
+ EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000,
+ GRANULARITY_DAY, /*anchor_start_end=*/true));
+ EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
+ GRANULARITY_DAY, /*anchor_start_end=*/false));
+ EXPECT_TRUE(HasNoResult("lorem 1 january 2018 ipsum",
+ /*anchor_start_end=*/true));
+}
+
+class ParserLocaleTest : public testing::Test {
+ public:
+ void SetUp() override;
+ bool HasResult(const std::string& input, const std::string& locales);
+
+ protected:
+ UniLib unilib_;
+ flatbuffers::FlatBufferBuilder builder_;
+ std::unique_ptr<DatetimeParser> parser_;
+};
+
+void AddPattern(const std::string& regex, int locale,
+ std::vector<std::unique_ptr<DatetimeModelPatternT>>* patterns) {
+ patterns->emplace_back(new DatetimeModelPatternT);
+ patterns->back()->regexes.emplace_back(new DatetimeModelPattern_::RegexT);
+ patterns->back()->regexes.back()->pattern = regex;
+ patterns->back()->regexes.back()->groups.push_back(
+ DatetimeGroupType_GROUP_UNUSED);
+ patterns->back()->locales.push_back(locale);
+}
+
+void ParserLocaleTest::SetUp() {
+ DatetimeModelT model;
+ model.use_extractors_for_locating = false;
+ model.locales.clear();
+ model.locales.push_back("en-US");
+ model.locales.push_back("en-CH");
+ model.locales.push_back("zh-Hant");
+ model.locales.push_back("en-*");
+ model.locales.push_back("zh-Hant-*");
+ model.locales.push_back("*-CH");
+ model.locales.push_back("");
+
+ AddPattern(/*regex=*/"en-US", /*locale=*/0, &model.patterns);
+ AddPattern(/*regex=*/"en-CH", /*locale=*/1, &model.patterns);
+ AddPattern(/*regex=*/"zh-Hant", /*locale=*/2, &model.patterns);
+ AddPattern(/*regex=*/"en-all", /*locale=*/3, &model.patterns);
+ AddPattern(/*regex=*/"zh-Hant-all", /*locale=*/4, &model.patterns);
+ AddPattern(/*regex=*/"all-CH", /*locale=*/5, &model.patterns);
+ AddPattern(/*regex=*/"default", /*locale=*/6, &model.patterns);
+
+ builder_.Finish(DatetimeModel::Pack(builder_, &model));
+ const DatetimeModel* model_fb =
+ flatbuffers::GetRoot<DatetimeModel>(builder_.GetBufferPointer());
+ ASSERT_TRUE(model_fb);
+
+ parser_ = DatetimeParser::Instance(model_fb, unilib_,
+ /*decompressor=*/nullptr);
+ ASSERT_TRUE(parser_);
+}
+
+bool ParserLocaleTest::HasResult(const std::string& input,
+ const std::string& locales) {
+ std::vector<DatetimeParseResultSpan> results;
+ EXPECT_TRUE(parser_->Parse(input, /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"", locales,
+ ModeFlag_ANNOTATION, false, &results));
+ return results.size() == 1;
+}
+
+TEST_F(ParserLocaleTest, English) {
+ EXPECT_TRUE(HasResult("en-US", /*locales=*/"en-US"));
+ EXPECT_FALSE(HasResult("en-CH", /*locales=*/"en-US"));
+ EXPECT_FALSE(HasResult("en-US", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("en-CH", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
+}
+
+TEST_F(ParserLocaleTest, TraditionalChinese) {
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant"));
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-TW"));
+ EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-SG"));
+ EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh-SG"));
+ EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"zh"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"zh-Hant-SG"));
+}
+
+TEST_F(ParserLocaleTest, SwissEnglish) {
+ EXPECT_TRUE(HasResult("all-CH", /*locales=*/"de-CH"));
+ EXPECT_TRUE(HasResult("all-CH", /*locales=*/"en-CH"));
+ EXPECT_TRUE(HasResult("en-all", /*locales=*/"en-CH"));
+ EXPECT_FALSE(HasResult("all-CH", /*locales=*/"de-DE"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"de-CH"));
+ EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
+}
} // namespace
} // namespace libtextclassifier2