Import libtextclassifier
Test: atest frameworks/base/core/tests/coretests/src/android/view/textclassifier/
Change-Id: I4255dcb44bdef06448d436c4166483eba46cf264
diff --git a/annotator/annotator.cc b/annotator/annotator.cc
index bd5f06f..ad492df 100644
--- a/annotator/annotator.cc
+++ b/annotator/annotator.cc
@@ -21,12 +21,14 @@
#include <cmath>
#include <iterator>
#include <numeric>
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "annotator/collections.h"
#include "utils/base/logging.h"
#include "utils/checksum.h"
#include "utils/math/softmax.h"
+#include "utils/regex-match.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
@@ -333,6 +335,21 @@
}
}
+ if (model_->entity_data_schema()) {
+ entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
+ model_->entity_data_schema()->Data(),
+ model_->entity_data_schema()->size());
+ if (entity_data_schema_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not load entity data schema data.";
+ return;
+ }
+
+ entity_data_builder_.reset(
+ new ReflectiveFlatbufferBuilder(entity_data_schema_));
+ } else {
+ entity_data_builder_ = nullptr;
+ }
+
initialized_ = true;
}
@@ -395,6 +412,18 @@
return true;
}
+// Creates an InstalledAppEngine from `serialized_config` and, on success,
+// installs it as this annotator's installed app engine.
+//
+// Returns false (after logging an error) when the engine rejects the config;
+// in that case `installed_app_engine_` is left as it was.
+bool Annotator::InitializeInstalledAppEngine(
+    const std::string& serialized_config) {
+  // Build the engine on the side so it only replaces the member once it has
+  // been initialized successfully.
+  std::unique_ptr<InstalledAppEngine> candidate_engine(
+      new InstalledAppEngine(selection_feature_processor_.get()));
+  const bool initialized = candidate_engine->Initialize(serialized_config);
+  if (initialized) {
+    installed_app_engine_ = std::move(candidate_engine);
+  } else {
+    TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
+  }
+  return initialized;
+}
+
namespace {
int CountDigits(const std::string& str, CodepointSpan selection_indices) {
@@ -410,17 +439,6 @@
return count;
}
-std::string ExtractSelection(const std::string& context,
- CodepointSpan selection_indices) {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
- auto selection_begin = context_unicode.begin();
- std::advance(selection_begin, selection_indices.first);
- auto selection_end = context_unicode.begin();
- std::advance(selection_end, selection_indices.second);
- return UnicodeText::UTF8Substring(selection_begin, selection_end);
-}
-
bool VerifyCandidate(const VerificationOptions* verification_options,
const std::string& match) {
if (!verification_options) {
@@ -558,7 +576,8 @@
}
if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
/*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
- options.locales, ModeFlag_SELECTION, &candidates)) {
+ options.locales, ModeFlag_SELECTION,
+ options.annotation_usecase, &candidates)) {
TC3_LOG(ERROR) << "Datetime suggest selection failed.";
return original_click_indices;
}
@@ -571,6 +590,11 @@
TC3_LOG(ERROR) << "Contact suggest selection failed.";
return original_click_indices;
}
+ if (installed_app_engine_ &&
+ !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ TC3_LOG(ERROR) << "Installed app suggest selection failed.";
+ return original_click_indices;
+ }
// Sort candidates according to their position in the input, so that the next
// code can assume that any connected component of overlapping spans forms a
@@ -1057,7 +1081,8 @@
const std::string& context, CodepointSpan selection_indices,
ClassificationResult* classification_result) const {
const std::string selection_text =
- ExtractSelection(context, selection_indices);
+ UTF8ToUnicodeText(context, /*do_copy=*/false)
+ .UTF8Substring(selection_indices.first, selection_indices.second);
const UnicodeText selection_text_unicode(
UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
@@ -1082,6 +1107,13 @@
regex_pattern.config->collection_name()->str(),
regex_pattern.config->target_classification_score(),
regex_pattern.config->priority_score()};
+
+ if (!SerializedEntityDataFromRegexMatch(
+ regex_pattern.config, matcher.get(),
+ &classification_result->serialized_entity_data)) {
+ TC3_LOG(ERROR) << "Could not get entity data.";
+ return false;
+ }
return true;
}
if (status != UniLib::RegexMatcher::kNoError) {
@@ -1115,12 +1147,14 @@
}
const std::string selection_text =
- ExtractSelection(context, selection_indices);
+ UTF8ToUnicodeText(context, /*do_copy=*/false)
+ .UTF8Substring(selection_indices.first, selection_indices.second);
std::vector<DatetimeParseResultSpan> datetime_spans;
if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
options.reference_timezone, options.locales,
ModeFlag_CLASSIFICATION,
+ options.annotation_usecase,
/*anchor_start_end=*/true, &datetime_spans)) {
TC3_LOG(ERROR) << "Error during parsing datetime.";
return false;
@@ -1188,6 +1222,18 @@
}
}
+ // Try the installed app engine.
+ ClassificationResult installed_app_result;
+ if (installed_app_engine_ &&
+ installed_app_engine_->ClassifyText(context, selection_indices,
+ &installed_app_result)) {
+ if (!FilteredForClassification(installed_app_result)) {
+ return {installed_app_result};
+ } else {
+ return {{Collections::Other(), 1.0}};
+ }
+ }
+
// Try the regular expression models.
ClassificationResult regex_result;
if (RegexClassifyText(context, selection_indices, &regex_result)) {
@@ -1377,7 +1423,8 @@
// Annotate with the datetime model.
if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
options.reference_time_ms_utc, options.reference_timezone,
- options.locales, ModeFlag_ANNOTATION, &candidates)) {
+ options.locales, ModeFlag_ANNOTATION,
+ options.annotation_usecase, &candidates)) {
TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
return {};
}
@@ -1395,6 +1442,13 @@
return {};
}
+ // Annotate with the installed app engine.
+ if (installed_app_engine_ &&
+ !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ TC3_LOG(ERROR) << "Couldn't run installed app engine Chunk.";
+ return {};
+ }
+
// Sort candidates according to their position in the input, so that the next
// code can assume that any connected component of overlapping spans forms a
// contiguous block.
@@ -1464,6 +1518,66 @@
return result;
}
+// Returns true if `pattern` declares any entity data: either a static
+// serialized payload, or at least one capturing group that is mapped to an
+// entity field path.
+bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
+  // Static entity data attached directly to the pattern.
+  if (pattern->serialized_entity_data() != nullptr) {
+    return true;
+  }
+
+  // Otherwise look for a capturing group that feeds an entity field.
+  const auto* capturing_groups = pattern->capturing_group();
+  if (capturing_groups == nullptr) {
+    return false;
+  }
+  for (const RegexModel_::Pattern_::CapturingGroup* capturing_group :
+       *capturing_groups) {
+    if (capturing_group->entity_field_path() != nullptr) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Builds the serialized entity data for a regex match of `pattern`.
+//
+// The result combines the pattern's static serialized entity data (if any)
+// with the fields filled from the matcher's capturing groups that declare an
+// entity field path.
+//
+// Returns true on success and writes the result to `serialized_entity_data`;
+// the output is cleared when the pattern declares no entity data at all.
+// Returns false (after logging an error) if a capturing group could not be
+// written to its entity field.
+bool Annotator::SerializedEntityDataFromRegexMatch(
+    const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
+    std::string* serialized_entity_data) const {
+  if (!HasEntityData(pattern)) {
+    serialized_entity_data->clear();
+    return true;
+  }
+  // A pattern that carries entity data requires the entity data builder.
+  TC3_CHECK(entity_data_builder_ != nullptr);
+
+  std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+      entity_data_builder_->NewRoot();
+
+  // Checked once here: `entity_data` is never reassigned below, so repeating
+  // the assertion before each use adds nothing.
+  TC3_CHECK(entity_data != nullptr);
+
+  // Set static entity data.
+  if (pattern->serialized_entity_data() != nullptr) {
+    // NOTE(review): the outcome of the merge is not inspected here; confirm
+    // whether MergeFromSerializedFlatbuffer can report a failure that should
+    // be propagated to the caller.
+    entity_data->MergeFromSerializedFlatbuffer(
+        StringPiece(pattern->serialized_entity_data()->c_str(),
+                    pattern->serialized_entity_data()->size()));
+  }
+
+  // Add entity data from rule capturing groups.
+  const auto* capturing_groups = pattern->capturing_group();
+  if (capturing_groups != nullptr) {
+    const int num_groups = capturing_groups->size();
+    for (int i = 0; i < num_groups; i++) {
+      const FlatbufferFieldPath* field_path =
+          capturing_groups->Get(i)->entity_field_path();
+      if (field_path == nullptr) {
+        // This group is not mapped to an entity data field.
+        continue;
+      }
+      // The index in the capturing group list is passed as the group id.
+      if (!SetFieldFromCapturingGroup(/*group_id=*/i, field_path, matcher,
+                                      entity_data.get())) {
+        TC3_LOG(ERROR)
+            << "Could not set entity data from rule capturing group.";
+        return false;
+      }
+    }
+  }
+
+  *serialized_entity_data = entity_data->Serialize();
+  return true;
+}
+
bool Annotator::RegexChunk(const UnicodeText& context_unicode,
const std::vector<int>& rules,
std::vector<AnnotatedSpan>* result) const {
@@ -1484,6 +1598,14 @@
continue;
}
}
+
+ std::string serialized_entity_data;
+ if (!SerializedEntityDataFromRegexMatch(
+ regex_pattern.config, matcher.get(), &serialized_entity_data)) {
+ TC3_LOG(ERROR) << "Could not get entity data.";
+ return false;
+ }
+
result->emplace_back();
// Selection/annotation regular expressions need to specify a capturing
@@ -1495,6 +1617,9 @@
{regex_pattern.config->collection_name()->str(),
regex_pattern.config->target_classification_score(),
regex_pattern.config->priority_score()}};
+
+ result->back().classification[0].serialized_entity_data =
+ serialized_entity_data;
}
}
return true;
@@ -1747,6 +1872,7 @@
int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& locales, ModeFlag mode,
+ AnnotationUsecase annotation_usecase,
std::vector<AnnotatedSpan>* result) const {
if (!datetime_parser_) {
return true;
@@ -1755,6 +1881,7 @@
std::vector<DatetimeParseResultSpan> datetime_spans;
if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
reference_timezone, locales, mode,
+ annotation_usecase,
/*anchor_start_end=*/false, &datetime_spans)) {
return false;
}
@@ -1775,6 +1902,9 @@
}
const Model* Annotator::ViewModel() const { return model_; }
+// Returns the entity data schema held by this annotator. The pointer is not
+// owned by the caller.
+// NOTE(review): presumably null when the model ships no entity data schema;
+// confirm against the member's default value.
+const reflection::Schema* Annotator::entity_data_schema() const {
+  return entity_data_schema_;
+}
const Model* ViewModel(const void* buffer, int size) {
if (!buffer) {
@@ -1784,4 +1914,10 @@
return LoadAndVerifyModel(buffer, size);
}
+// Looks up the entity `id` through the knowledge engine and forwards the
+// engine's result, which is given `serialized_knowledge_result` as its output
+// argument. Returns false without touching the output when no knowledge
+// engine is set.
+bool Annotator::LookUpKnowledgeEntity(
+    const std::string& id, std::string* serialized_knowledge_result) const {
+  if (!knowledge_engine_) {
+    return false;
+  }
+  return knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
+}
+
} // namespace libtextclassifier3