Import libtextclassifier
Major change: Support intent config for action generation.
BUG: 123745079
Test: atest TextClassifierTest
Change-Id: Ia6dded6065ff37e6736abc1749a4078593fff87d
diff --git a/annotator/annotator.cc b/annotator/annotator.cc
index 330cf0b..bd5f06f 100644
--- a/annotator/annotator.cc
+++ b/annotator/annotator.cc
@@ -385,7 +385,8 @@
}
bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
- std::unique_ptr<ContactEngine> contact_engine(new ContactEngine());
+ std::unique_ptr<ContactEngine> contact_engine(
+ new ContactEngine(selection_feature_processor_.get()));
if (!contact_engine->Initialize(serialized_config)) {
TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
return false;
@@ -670,7 +671,7 @@
inline bool ClassifiedAsOther(
const std::vector<ClassificationResult>& classification) {
return !classification.empty() &&
- classification[0].collection == Collections::kOther;
+ classification[0].collection == Collections::Other();
}
float GetPriorityScore(
@@ -936,7 +937,7 @@
if (model_->classification_options()->max_num_tokens() > 0 &&
model_->classification_options()->max_num_tokens() <
selection_num_tokens) {
- *classification_results = {{Collections::kOther, 1.0}};
+ *classification_results = {{Collections::Other(), 1.0}};
return true;
}
@@ -976,7 +977,7 @@
if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
tokens, extraction_span)) {
- *classification_results = {{Collections::kOther, 1.0}};
+ *classification_results = {{Collections::Other(), 1.0}};
return true;
}
@@ -1030,22 +1031,22 @@
// Phone class sanity check.
if (!classification_results->empty() &&
- classification_results->begin()->collection == Collections::kPhone) {
+ classification_results->begin()->collection == Collections::Phone()) {
const int digit_count = CountDigits(context, selection_indices);
if (digit_count <
model_->classification_options()->phone_min_num_digits() ||
digit_count >
model_->classification_options()->phone_max_num_digits()) {
- *classification_results = {{Collections::kOther, 1.0}};
+ *classification_results = {{Collections::Other(), 1.0}};
}
}
// Address class sanity check.
if (!classification_results->empty() &&
- classification_results->begin()->collection == Collections::kAddress) {
+ classification_results->begin()->collection == Collections::Address()) {
if (selection_num_tokens <
model_->classification_options()->address_min_num_tokens()) {
- *classification_results = {{Collections::kOther, 1.0}};
+ *classification_results = {{Collections::Other(), 1.0}};
}
}
@@ -1091,6 +1092,20 @@
return false;
}
+namespace {
+std::string PickCollectionForDatetime(  // Maps granularity to a collection.
+ const DatetimeParseResult& datetime_parse_result) {
+ switch (datetime_parse_result.granularity) {
+ case GRANULARITY_HOUR:
+ case GRANULARITY_MINUTE:
+ case GRANULARITY_SECOND:  // Hour, minute and second all yield DateTime().
+ return Collections::DateTime();
+ default:  // Every other granularity value yields Date().
+ return Collections::Date();
+ }
+}
+} // namespace
+
bool Annotator::DatetimeClassifyText(
const std::string& context, CodepointSpan selection_indices,
const ClassificationOptions& options,
@@ -1118,7 +1133,8 @@
selection_indices) {
for (const DatetimeParseResult& parse_result : datetime_span.data) {
classification_results->emplace_back(
- Collections::kDate, datetime_span.target_classification_score);
+ PickCollectionForDatetime(parse_result),
+ datetime_span.target_classification_score);
classification_results->back().datetime_parse_result = parse_result;
}
return true;
@@ -1157,7 +1173,7 @@
if (!FilteredForClassification(knowledge_result)) {
return {knowledge_result};
} else {
- return {{Collections::kOther, 1.0}};
+ return {{Collections::Other(), 1.0}};
}
}
@@ -1168,7 +1184,7 @@
if (!FilteredForClassification(contact_result)) {
return {contact_result};
} else {
- return {{Collections::kOther, 1.0}};
+ return {{Collections::Other(), 1.0}};
}
}
@@ -1178,7 +1194,7 @@
if (!FilteredForClassification(regex_result)) {
return {regex_result};
} else {
- return {{Collections::kOther, 1.0}};
+ return {{Collections::Other(), 1.0}};
}
}
@@ -1196,7 +1212,7 @@
if (!datetime_results.empty()) {
return datetime_results;
} else {
- return {{Collections::kOther, 1.0}};
+ return {{Collections::Other(), 1.0}};
}
}
@@ -1211,7 +1227,7 @@
if (!FilteredForClassification(model_result[0])) {
return model_result;
} else {
- return {{Collections::kOther, 1.0}};
+ return {{Collections::Other(), 1.0}};
}
}
@@ -1747,7 +1763,8 @@
AnnotatedSpan annotated_span;
annotated_span.span = datetime_span.span;
annotated_span.classification = {
- {Collections::kDate, datetime_span.target_classification_score,
+ {PickCollectionForDatetime(parse_result),
+ datetime_span.target_classification_score,
datetime_span.priority_score}};
annotated_span.classification[0].datetime_parse_result = parse_result;