Import libtextclassifier

Test: atest frameworks/base/core/tests/coretests/src/android/view/textclassifier/

Change-Id: I4255dcb44bdef06448d436c4166483eba46cf264
diff --git a/annotator/annotator.cc b/annotator/annotator.cc
index bd5f06f..ad492df 100644
--- a/annotator/annotator.cc
+++ b/annotator/annotator.cc
@@ -21,12 +21,14 @@
 #include <cmath>
 #include <iterator>
 #include <numeric>
+#include "annotator/model_generated.h"
 #include "annotator/types.h"
 
 #include "annotator/collections.h"
 #include "utils/base/logging.h"
 #include "utils/checksum.h"
 #include "utils/math/softmax.h"
+#include "utils/regex-match.h"
 #include "utils/utf8/unicodetext.h"
 
 namespace libtextclassifier3 {
@@ -333,6 +335,21 @@
     }
   }
 
+  if (model_->entity_data_schema()) {
+    entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
+        model_->entity_data_schema()->Data(),
+        model_->entity_data_schema()->size());
+    if (entity_data_schema_ == nullptr) {
+      TC3_LOG(ERROR) << "Could not load entity data schema data.";
+      return;
+    }
+
+    entity_data_builder_.reset(
+        new ReflectiveFlatbufferBuilder(entity_data_schema_));
+  } else {
+    entity_data_builder_ = nullptr;
+  }
+
   initialized_ = true;
 }
 
@@ -395,6 +412,18 @@
   return true;
 }
 
+bool Annotator::InitializeInstalledAppEngine(
+    const std::string& serialized_config) {
+  std::unique_ptr<InstalledAppEngine> engine(
+      new InstalledAppEngine(selection_feature_processor_.get()));
+  if (engine->Initialize(serialized_config)) {
+    installed_app_engine_ = std::move(engine);  // Replace only on success.
+    return true;
+  }
+  TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
+  return false;
+}
+
 namespace {
 
 int CountDigits(const std::string& str, CodepointSpan selection_indices) {
@@ -410,17 +439,6 @@
   return count;
 }
 
-std::string ExtractSelection(const std::string& context,
-                             CodepointSpan selection_indices) {
-  const UnicodeText context_unicode =
-      UTF8ToUnicodeText(context, /*do_copy=*/false);
-  auto selection_begin = context_unicode.begin();
-  std::advance(selection_begin, selection_indices.first);
-  auto selection_end = context_unicode.begin();
-  std::advance(selection_end, selection_indices.second);
-  return UnicodeText::UTF8Substring(selection_begin, selection_end);
-}
-
 bool VerifyCandidate(const VerificationOptions* verification_options,
                      const std::string& match) {
   if (!verification_options) {
@@ -558,7 +576,8 @@
   }
   if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
                      /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
-                     options.locales, ModeFlag_SELECTION, &candidates)) {
+                     options.locales, ModeFlag_SELECTION,
+                     options.annotation_usecase, &candidates)) {
     TC3_LOG(ERROR) << "Datetime suggest selection failed.";
     return original_click_indices;
   }
@@ -571,6 +590,11 @@
     TC3_LOG(ERROR) << "Contact suggest selection failed.";
     return original_click_indices;
   }
+  if (installed_app_engine_ &&
+      !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
+    TC3_LOG(ERROR) << "Installed app suggest selection failed.";
+    return original_click_indices;
+  }
 
   // Sort candidates according to their position in the input, so that the next
   // code can assume that any connected component of overlapping spans forms a
@@ -1057,7 +1081,8 @@
     const std::string& context, CodepointSpan selection_indices,
     ClassificationResult* classification_result) const {
   const std::string selection_text =
-      ExtractSelection(context, selection_indices);
+      UTF8ToUnicodeText(context, /*do_copy=*/false)
+          .UTF8Substring(selection_indices.first, selection_indices.second);
   const UnicodeText selection_text_unicode(
       UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
 
@@ -1082,6 +1107,13 @@
           regex_pattern.config->collection_name()->str(),
           regex_pattern.config->target_classification_score(),
           regex_pattern.config->priority_score()};
+
+      if (!SerializedEntityDataFromRegexMatch(
+              regex_pattern.config, matcher.get(),
+              &classification_result->serialized_entity_data)) {
+        TC3_LOG(ERROR) << "Could not get entity data.";
+        return false;
+      }
       return true;
     }
     if (status != UniLib::RegexMatcher::kNoError) {
@@ -1115,12 +1147,14 @@
   }
 
   const std::string selection_text =
-      ExtractSelection(context, selection_indices);
+      UTF8ToUnicodeText(context, /*do_copy=*/false)
+          .UTF8Substring(selection_indices.first, selection_indices.second);
 
   std::vector<DatetimeParseResultSpan> datetime_spans;
   if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
                                options.reference_timezone, options.locales,
                                ModeFlag_CLASSIFICATION,
+                               options.annotation_usecase,
                                /*anchor_start_end=*/true, &datetime_spans)) {
     TC3_LOG(ERROR) << "Error during parsing datetime.";
     return false;
@@ -1188,6 +1222,18 @@
     }
   }
 
+  // Try the installed app engine.
+  ClassificationResult installed_app_result;
+  if (installed_app_engine_ &&
+      installed_app_engine_->ClassifyText(context, selection_indices,
+                                          &installed_app_result)) {
+    if (!FilteredForClassification(installed_app_result)) {
+      return {installed_app_result};
+    } else {
+      return {{Collections::Other(), 1.0}};
+    }
+  }
+
   // Try the regular expression models.
   ClassificationResult regex_result;
   if (RegexClassifyText(context, selection_indices, &regex_result)) {
@@ -1377,7 +1423,8 @@
   // Annotate with the datetime model.
   if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
                      options.reference_time_ms_utc, options.reference_timezone,
-                     options.locales, ModeFlag_ANNOTATION, &candidates)) {
+                     options.locales, ModeFlag_ANNOTATION,
+                     options.annotation_usecase, &candidates)) {
     TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
     return {};
   }
@@ -1395,6 +1442,13 @@
     return {};
   }
 
+  // Annotate with the installed app engine.
+  if (installed_app_engine_ &&
+      !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
+    TC3_LOG(ERROR) << "Couldn't run installed app engine Chunk.";
+    return {};
+  }
+
   // Sort candidates according to their position in the input, so that the next
   // code can assume that any connected component of overlapping spans forms a
   // contiguous block.
@@ -1464,6 +1518,66 @@
   return result;
 }
 
+// Returns whether `pattern` provides entity data, statically or via groups.
+bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
+  // Static entity data attached directly to the pattern.
+  if (pattern->serialized_entity_data() != nullptr) return true;
+  const auto* groups = pattern->capturing_group();
+  if (groups == nullptr) return false;
+  // Otherwise one capturing group with an entity field path is enough.
+  for (const RegexModel_::Pattern_::CapturingGroup* group : *groups) {
+    if (group->entity_field_path() != nullptr) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Serializes the entity data for a regex match: the pattern's static data
+// plus fields set from its capturing groups. Returns false on failure.
+bool Annotator::SerializedEntityDataFromRegexMatch(
+    const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
+    std::string* serialized_entity_data) const {
+  // Patterns without any entity data yield an empty serialization.
+  if (!HasEntityData(pattern)) {
+    serialized_entity_data->clear();
+    return true;
+  }
+
+  // A pattern with entity data requires the schema-backed builder.
+  TC3_CHECK(entity_data_builder_ != nullptr);
+  std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+      entity_data_builder_->NewRoot();
+  TC3_CHECK(entity_data != nullptr);
+
+  // Start from the static entity data attached to the pattern, if any.
+  if (const auto* static_data = pattern->serialized_entity_data()) {
+    entity_data->MergeFromSerializedFlatbuffer(
+        StringPiece(static_data->c_str(), static_data->size()));
+  }
+
+  // Then fill in fields from the capturing groups that name a field path.
+  if (const auto* groups = pattern->capturing_group()) {
+    const int num_groups = groups->size();
+    for (int group_id = 0; group_id < num_groups; ++group_id) {
+      const FlatbufferFieldPath* field_path =
+          groups->Get(group_id)->entity_field_path();
+      if (field_path == nullptr) {
+        continue;
+      }
+      if (!SetFieldFromCapturingGroup(group_id, field_path, matcher,
+                                      entity_data.get())) {
+        TC3_LOG(ERROR)
+            << "Could not set entity data from rule capturing group.";
+        return false;
+      }
+    }
+  }
+
+  *serialized_entity_data = entity_data->Serialize();
+  return true;
+}
+
 bool Annotator::RegexChunk(const UnicodeText& context_unicode,
                            const std::vector<int>& rules,
                            std::vector<AnnotatedSpan>* result) const {
@@ -1484,6 +1598,14 @@
           continue;
         }
       }
+
+      std::string serialized_entity_data;
+      if (!SerializedEntityDataFromRegexMatch(
+              regex_pattern.config, matcher.get(), &serialized_entity_data)) {
+        TC3_LOG(ERROR) << "Could not get entity data.";
+        return false;
+      }
+
       result->emplace_back();
 
       // Selection/annotation regular expressions need to specify a capturing
@@ -1495,6 +1617,9 @@
           {regex_pattern.config->collection_name()->str(),
            regex_pattern.config->target_classification_score(),
            regex_pattern.config->priority_score()}};
+
+      result->back().classification[0].serialized_entity_data =
+          serialized_entity_data;
     }
   }
   return true;
@@ -1747,6 +1872,7 @@
                               int64 reference_time_ms_utc,
                               const std::string& reference_timezone,
                               const std::string& locales, ModeFlag mode,
+                              AnnotationUsecase annotation_usecase,
                               std::vector<AnnotatedSpan>* result) const {
   if (!datetime_parser_) {
     return true;
@@ -1755,6 +1881,7 @@
   std::vector<DatetimeParseResultSpan> datetime_spans;
   if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
                                reference_timezone, locales, mode,
+                               annotation_usecase,
                                /*anchor_start_end=*/false, &datetime_spans)) {
     return false;
   }
@@ -1775,6 +1902,9 @@
 }
 
 const Model* Annotator::ViewModel() const { return model_; }
+const reflection::Schema* Annotator::entity_data_schema() const {
+  return entity_data_schema_;  // Loaded from the model, if present.
+}
 
 const Model* ViewModel(const void* buffer, int size) {
   if (!buffer) {
@@ -1784,4 +1914,10 @@
   return LoadAndVerifyModel(buffer, size);
 }
 
+bool Annotator::LookUpKnowledgeEntity(
+    const std::string& id, std::string* serialized_knowledge_result) const {
+  if (!knowledge_engine_) return false;  // No knowledge engine available.
+  return knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
+}
+
 }  // namespace libtextclassifier3