Import libtextclassifier
Test: atest frameworks/base/core/tests/coretests/src/android/view/textclassifier/
Change-Id: I4255dcb44bdef06448d436c4166483eba46cf264
diff --git a/annotator/annotator.cc b/annotator/annotator.cc
index bd5f06f..ad492df 100644
--- a/annotator/annotator.cc
+++ b/annotator/annotator.cc
@@ -21,12 +21,14 @@
#include <cmath>
#include <iterator>
#include <numeric>
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "annotator/collections.h"
#include "utils/base/logging.h"
#include "utils/checksum.h"
#include "utils/math/softmax.h"
+#include "utils/regex-match.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
@@ -333,6 +335,21 @@
}
}
+ if (model_->entity_data_schema()) {
+ entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
+ model_->entity_data_schema()->Data(),
+ model_->entity_data_schema()->size());
+ if (entity_data_schema_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not load entity data schema data.";
+ return;
+ }
+
+ entity_data_builder_.reset(
+ new ReflectiveFlatbufferBuilder(entity_data_schema_));
+ } else {
+ entity_data_builder_ = nullptr;
+ }
+
initialized_ = true;
}
@@ -395,6 +412,18 @@
return true;
}
+// Creates a new installed app engine from `serialized_config` and stores it in
+// `installed_app_engine_` only after it initialized successfully. On failure
+// an error is logged, the currently stored engine is left untouched and false
+// is returned.
+bool Annotator::InitializeInstalledAppEngine(
+    const std::string& serialized_config) {
+  std::unique_ptr<InstalledAppEngine> candidate_engine(
+      new InstalledAppEngine(selection_feature_processor_.get()));
+  if (candidate_engine->Initialize(serialized_config)) {
+    installed_app_engine_ = std::move(candidate_engine);
+    return true;
+  }
+  TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
+  return false;
+}
+
namespace {
int CountDigits(const std::string& str, CodepointSpan selection_indices) {
@@ -410,17 +439,6 @@
return count;
}
-std::string ExtractSelection(const std::string& context,
- CodepointSpan selection_indices) {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
- auto selection_begin = context_unicode.begin();
- std::advance(selection_begin, selection_indices.first);
- auto selection_end = context_unicode.begin();
- std::advance(selection_end, selection_indices.second);
- return UnicodeText::UTF8Substring(selection_begin, selection_end);
-}
-
bool VerifyCandidate(const VerificationOptions* verification_options,
const std::string& match) {
if (!verification_options) {
@@ -558,7 +576,8 @@
}
if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
/*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
- options.locales, ModeFlag_SELECTION, &candidates)) {
+ options.locales, ModeFlag_SELECTION,
+ options.annotation_usecase, &candidates)) {
TC3_LOG(ERROR) << "Datetime suggest selection failed.";
return original_click_indices;
}
@@ -571,6 +590,11 @@
TC3_LOG(ERROR) << "Contact suggest selection failed.";
return original_click_indices;
}
+ if (installed_app_engine_ &&
+ !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ TC3_LOG(ERROR) << "Installed app suggest selection failed.";
+ return original_click_indices;
+ }
// Sort candidates according to their position in the input, so that the next
// code can assume that any connected component of overlapping spans forms a
@@ -1057,7 +1081,8 @@
const std::string& context, CodepointSpan selection_indices,
ClassificationResult* classification_result) const {
const std::string selection_text =
- ExtractSelection(context, selection_indices);
+ UTF8ToUnicodeText(context, /*do_copy=*/false)
+ .UTF8Substring(selection_indices.first, selection_indices.second);
const UnicodeText selection_text_unicode(
UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
@@ -1082,6 +1107,13 @@
regex_pattern.config->collection_name()->str(),
regex_pattern.config->target_classification_score(),
regex_pattern.config->priority_score()};
+
+ if (!SerializedEntityDataFromRegexMatch(
+ regex_pattern.config, matcher.get(),
+ &classification_result->serialized_entity_data)) {
+ TC3_LOG(ERROR) << "Could not get entity data.";
+ return false;
+ }
return true;
}
if (status != UniLib::RegexMatcher::kNoError) {
@@ -1115,12 +1147,14 @@
}
const std::string selection_text =
- ExtractSelection(context, selection_indices);
+ UTF8ToUnicodeText(context, /*do_copy=*/false)
+ .UTF8Substring(selection_indices.first, selection_indices.second);
std::vector<DatetimeParseResultSpan> datetime_spans;
if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
options.reference_timezone, options.locales,
ModeFlag_CLASSIFICATION,
+ options.annotation_usecase,
/*anchor_start_end=*/true, &datetime_spans)) {
TC3_LOG(ERROR) << "Error during parsing datetime.";
return false;
@@ -1188,6 +1222,18 @@
}
}
+ // Try the installed app engine.
+ ClassificationResult installed_app_result;
+ if (installed_app_engine_ &&
+ installed_app_engine_->ClassifyText(context, selection_indices,
+ &installed_app_result)) {
+ if (!FilteredForClassification(installed_app_result)) {
+ return {installed_app_result};
+ } else {
+ return {{Collections::Other(), 1.0}};
+ }
+ }
+
// Try the regular expression models.
ClassificationResult regex_result;
if (RegexClassifyText(context, selection_indices, &regex_result)) {
@@ -1377,7 +1423,8 @@
// Annotate with the datetime model.
if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
options.reference_time_ms_utc, options.reference_timezone,
- options.locales, ModeFlag_ANNOTATION, &candidates)) {
+ options.locales, ModeFlag_ANNOTATION,
+ options.annotation_usecase, &candidates)) {
TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
return {};
}
@@ -1395,6 +1442,13 @@
return {};
}
+ // Annotate with the installed app engine.
+ if (installed_app_engine_ &&
+ !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ TC3_LOG(ERROR) << "Couldn't run installed app engine Chunk.";
+ return {};
+ }
+
// Sort candidates according to their position in the input, so that the next
// code can assume that any connected component of overlapping spans forms a
// contiguous block.
@@ -1464,6 +1518,66 @@
return result;
}
+// Returns true if the pattern carries static serialized entity data, or if at
+// least one of its capturing groups specifies an entity field path.
+bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
+  const bool has_static_data = pattern->serialized_entity_data() != nullptr;
+  if (has_static_data) {
+    return true;
+  }
+  const auto* groups = pattern->capturing_group();
+  if (groups == nullptr) {
+    return false;
+  }
+  for (const RegexModel_::Pattern_::CapturingGroup* candidate : *groups) {
+    if (candidate->entity_field_path() != nullptr) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Builds the serialized entity data for a regex match.
+//
+// If the pattern provides no entity data (see HasEntityData),
+// `serialized_entity_data` is cleared and true is returned. Otherwise a new
+// entity data root is created, the pattern's static serialized entity data (if
+// any) is merged into it, and every capturing group that specifies an entity
+// field path is handed to SetFieldFromCapturingGroup together with `matcher`.
+// The result is serialized into `serialized_entity_data`.
+//
+// Returns false after logging an error if no entity data object could be
+// created or a capturing group could not be stored.
+bool Annotator::SerializedEntityDataFromRegexMatch(
+    const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
+    std::string* serialized_entity_data) const {
+  if (!HasEntityData(pattern)) {
+    serialized_entity_data->clear();
+    return true;
+  }
+
+  // The builder is only created when the model carries an entity data schema.
+  // A pattern that references entity data without one is reported as a failed
+  // match instead of aborting the process.
+  if (entity_data_builder_ == nullptr) {
+    TC3_LOG(ERROR) << "Pattern has entity data but no entity data schema is "
+                      "available.";
+    return false;
+  }
+
+  std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+      entity_data_builder_->NewRoot();
+  if (entity_data == nullptr) {
+    TC3_LOG(ERROR) << "Could not create entity data root.";
+    return false;
+  }
+
+  // Set static entity data.
+  if (pattern->serialized_entity_data() != nullptr) {
+    entity_data->MergeFromSerializedFlatbuffer(
+        StringPiece(pattern->serialized_entity_data()->c_str(),
+                    pattern->serialized_entity_data()->size()));
+  }
+
+  // Add entity data from rule capturing groups. The index into the
+  // capturing_group vector is used as the regex group id.
+  if (pattern->capturing_group() != nullptr) {
+    const int num_groups = pattern->capturing_group()->size();
+    for (int i = 0; i < num_groups; i++) {
+      const FlatbufferFieldPath* field_path =
+          pattern->capturing_group()->Get(i)->entity_field_path();
+      if (field_path == nullptr) {
+        continue;
+      }
+      if (!SetFieldFromCapturingGroup(/*group_id=*/i, field_path, matcher,
+                                      entity_data.get())) {
+        TC3_LOG(ERROR)
+            << "Could not set entity data from rule capturing group.";
+        return false;
+      }
+    }
+  }
+
+  *serialized_entity_data = entity_data->Serialize();
+  return true;
+}
+
bool Annotator::RegexChunk(const UnicodeText& context_unicode,
const std::vector<int>& rules,
std::vector<AnnotatedSpan>* result) const {
@@ -1484,6 +1598,14 @@
continue;
}
}
+
+ std::string serialized_entity_data;
+ if (!SerializedEntityDataFromRegexMatch(
+ regex_pattern.config, matcher.get(), &serialized_entity_data)) {
+ TC3_LOG(ERROR) << "Could not get entity data.";
+ return false;
+ }
+
result->emplace_back();
// Selection/annotation regular expressions need to specify a capturing
@@ -1495,6 +1617,9 @@
{regex_pattern.config->collection_name()->str(),
regex_pattern.config->target_classification_score(),
regex_pattern.config->priority_score()}};
+
+ result->back().classification[0].serialized_entity_data =
+ serialized_entity_data;
}
}
return true;
@@ -1747,6 +1872,7 @@
int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& locales, ModeFlag mode,
+ AnnotationUsecase annotation_usecase,
std::vector<AnnotatedSpan>* result) const {
if (!datetime_parser_) {
return true;
@@ -1755,6 +1881,7 @@
std::vector<DatetimeParseResultSpan> datetime_spans;
if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
reference_timezone, locales, mode,
+ annotation_usecase,
/*anchor_start_end=*/false, &datetime_spans)) {
return false;
}
@@ -1775,6 +1902,9 @@
}
const Model* Annotator::ViewModel() const { return model_; }
+// Returns the entity data schema pointer stored in this annotator.
+// NOTE(review): the initialization code assigns the member only when the model
+// provides an entity data schema; for other models the returned value is
+// whatever the member was initialized with.
+const reflection::Schema* Annotator::entity_data_schema() const {
+ return entity_data_schema_;
+}
const Model* ViewModel(const void* buffer, int size) {
if (!buffer) {
@@ -1784,4 +1914,10 @@
return LoadAndVerifyModel(buffer, size);
}
+// Forwards the lookup of `id` to the knowledge engine. Returns false when no
+// knowledge engine is set; otherwise returns the engine's lookup result, with
+// `serialized_knowledge_result` passed through as the output argument.
+bool Annotator::LookUpKnowledgeEntity(
+    const std::string& id, std::string* serialized_knowledge_result) const {
+  if (!knowledge_engine_) {
+    return false;
+  }
+  return knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
+}
+
} // namespace libtextclassifier3
diff --git a/annotator/annotator.h b/annotator/annotator.h
index fa3e3d2..68dc05e 100644
--- a/annotator/annotator.h
+++ b/annotator/annotator.h
@@ -27,24 +27,38 @@
#include "annotator/contact/contact-engine.h"
#include "annotator/datetime/parser.h"
#include "annotator/feature-processor.h"
+#include "annotator/installed_app/installed-app-engine.h"
#include "annotator/knowledge/knowledge-engine.h"
#include "annotator/model-executor.h"
#include "annotator/model_generated.h"
#include "annotator/strip-unpaired-brackets.h"
#include "annotator/types.h"
#include "annotator/zlib-utils.h"
+#include "utils/flatbuffers.h"
#include "utils/memory/mmap.h"
#include "utils/utf8/unilib.h"
#include "utils/zlib/zlib.h"
namespace libtextclassifier3 {
+// Short aliases for the long AnnotationUsecase enum value names.
+// ANNOTATION_USECASE_SMART is the default use-case of the option structs
+// declared in this header.
+const AnnotationUsecase ANNOTATION_USECASE_SMART =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART;
+const AnnotationUsecase ANNOTATION_USECASE_RAW =
+ AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
+// Options accepted by Annotator::SuggestSelection.
struct SelectionOptions {
// Comma-separated list of locale specification for the input text (BCP 47
// tags).
std::string locales;
- static SelectionOptions Default() { return SelectionOptions(); }
+ // Tailors the output annotations according to the specified use-case.
+ AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
+
+ // Two option sets compare equal when both the locales and the annotation
+ // use-case match.
+ bool operator==(const SelectionOptions& other) const {
+ return this->locales == other.locales &&
+ this->annotation_usecase == other.annotation_usecase;
+ }
};
struct ClassificationOptions {
@@ -60,7 +74,15 @@
// tags).
std::string locales;
- static ClassificationOptions Default() { return ClassificationOptions(); }
+ // Tailors the output annotations according to the specified use-case.
+ AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
+
+ bool operator==(const ClassificationOptions& other) const {
+ return this->reference_time_ms_utc == other.reference_time_ms_utc &&
+ this->reference_timezone == other.reference_timezone &&
+ this->locales == other.locales &&
+ this->annotation_usecase == other.annotation_usecase;
+ }
};
struct AnnotationOptions {
@@ -76,7 +98,15 @@
// tags).
std::string locales;
- static AnnotationOptions Default() { return AnnotationOptions(); }
+ // Tailors the output annotations according to the specified use-case.
+ AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
+
+ bool operator==(const AnnotationOptions& other) const {
+ return this->reference_time_ms_utc == other.reference_time_ms_utc &&
+ this->reference_timezone == other.reference_timezone &&
+ this->locales == other.locales &&
+ this->annotation_usecase == other.annotation_usecase;
+ }
};
// Holds TFLite interpreters for selection and classification models.
@@ -137,6 +167,9 @@
// Initializes the contact engine with the given config.
bool InitializeContactEngine(const std::string& serialized_config);
+ // Initializes the installed app engine with the given config.
+ bool InitializeInstalledAppEngine(const std::string& serialized_config);
+
// Runs inference for given a context and current selection (i.e. index
// of the first and one past last selected characters (utf8 codepoint
// offsets)). Returns the indices (utf8 codepoint offsets) of the selection
@@ -147,22 +180,27 @@
// Requires that the model is a smart selection model.
CodepointSpan SuggestSelection(
const std::string& context, CodepointSpan click_indices,
- const SelectionOptions& options = SelectionOptions::Default()) const;
+ const SelectionOptions& options = SelectionOptions()) const;
// Classifies the selected text given the context string.
// Returns an empty result if an error occurs.
std::vector<ClassificationResult> ClassifyText(
const std::string& context, CodepointSpan selection_indices,
- const ClassificationOptions& options =
- ClassificationOptions::Default()) const;
+ const ClassificationOptions& options = ClassificationOptions()) const;
// Annotates given input text. The annotations are sorted by their position
// in the context string and exclude spans classified as 'other'.
std::vector<AnnotatedSpan> Annotate(
const std::string& context,
- const AnnotationOptions& options = AnnotationOptions::Default()) const;
+ const AnnotationOptions& options = AnnotationOptions()) const;
+
+ // Looks up a knowledge entity by its id. If successful, populates the
+ // serialized knowledge result and returns true.
+ bool LookUpKnowledgeEntity(const std::string& id,
+ std::string* serialized_knowledge_result) const;
const Model* ViewModel() const;
+ const reflection::Schema* entity_data_schema() const;
// Exposes the feature processor for tests and evaluations.
const FeatureProcessor* SelectionFeatureProcessorForTests() const;
@@ -315,6 +353,7 @@
int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& locales, ModeFlag mode,
+ AnnotationUsecase annotation_usecase,
std::vector<AnnotatedSpan>* result) const;
// Returns whether a classification should be filtered.
@@ -328,6 +367,14 @@
const UniLib::RegexMatcher* match,
const RegexModel_::Pattern* config) const;
+ // Returns whether a regex pattern provides entity data from a match.
+ bool HasEntityData(const RegexModel_::Pattern* pattern) const;
+
+ // Constructs and serializes entity data from regex matches.
+ bool SerializedEntityDataFromRegexMatch(
+ const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
+ std::string* serialized_entity_data) const;
+
const Model* model_;
std::unique_ptr<const ModelExecutor> selection_executor_;
@@ -367,6 +414,11 @@
std::unique_ptr<const KnowledgeEngine> knowledge_engine_;
std::unique_ptr<const ContactEngine> contact_engine_;
+ std::unique_ptr<const InstalledAppEngine> installed_app_engine_;
+
+  // Schema and builder for creating entity data. The schema pointer is
+  // explicitly initialized to null: initialization only assigns it when the
+  // model provides an entity data schema, while entity_data_schema() returns
+  // it to callers unconditionally.
+  const reflection::Schema* entity_data_schema_ = nullptr;
+  std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
};
namespace internal {
diff --git a/annotator/annotator_jni.cc b/annotator/annotator_jni.cc
index 955cf52..e760c5c 100644
--- a/annotator/annotator_jni.cc
+++ b/annotator/annotator_jni.cc
@@ -91,7 +91,8 @@
result_class.get(), "<init>",
"(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
"$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/String;"
- "Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
+ "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;"
+ "Ljava/lang/String;[L" TC3_PACKAGE_PATH
TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";)V");
const jmethodID datetime_parse_class_constructor =
env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
@@ -153,6 +154,23 @@
classification_result[i].contact_phone_number.c_str());
}
+ jstring contact_id = nullptr;
+ if (!classification_result[i].contact_id.empty()) {
+ contact_id =
+ env->NewStringUTF(classification_result[i].contact_id.c_str());
+ }
+
+ jstring app_name = nullptr;
+ if (!classification_result[i].app_name.empty()) {
+ app_name = env->NewStringUTF(classification_result[i].app_name.c_str());
+ }
+
+ jstring app_package_name = nullptr;
+ if (!classification_result[i].app_package_name.empty()) {
+ app_package_name =
+ env->NewStringUTF(classification_result[i].app_package_name.c_str());
+ }
+
jobject remote_action_templates_result = nullptr;
// Only generate RemoteActionTemplate for the top classification result
// as classifyText does not need RemoteAction from other results anyway.
@@ -172,7 +190,7 @@
static_cast<jfloat>(classification_result[i].score), row_datetime_parse,
serialized_knowledge_result, contact_name, contact_given_name,
contact_nickname, contact_email_address, contact_phone_number,
- remote_action_templates_result);
+ contact_id, app_name, app_package_name, remote_action_templates_result);
env->SetObjectArrayElement(results, i, result);
env->DeleteLocalRef(result);
}
@@ -380,6 +398,25 @@
return model->InitializeContactEngine(serialized_config_string);
}
+// JNI entry point: copies the serialized config bytes out of the Java byte
+// array and initializes the installed app engine of the annotator behind
+// `ptr`. Returns false if `ptr` or `serialized_config` is null, or if the
+// engine fails to initialize.
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+               nativeInitializeInstalledAppEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
+  // GetArrayLength must not be called with a null array reference.
+  if (!ptr || serialized_config == nullptr) {
+    return false;
+  }
+
+  Annotator* model = reinterpret_cast<Annotator*>(ptr);
+
+  std::string serialized_config_string;
+  const int length = env->GetArrayLength(serialized_config);
+  serialized_config_string.resize(length);
+  // &str[0] yields a writable pointer into the string's buffer, so no
+  // const_cast on data() is needed.
+  env->GetByteArrayRegion(
+      serialized_config, 0, length,
+      reinterpret_cast<jbyte*>(&serialized_config_string[0]));
+
+  return model->InitializeInstalledAppEngine(serialized_config_string);
+}
+
TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
jint selection_end, jobject options) {
@@ -427,7 +464,8 @@
intent_generator =
libtextclassifier3::IntentGenerator::CreateIntentGenerator(
ff_model->ViewModel()->intent_options(),
- ff_model->ViewModel()->resources(), jni_cache, app_context);
+ ff_model->ViewModel()->resources(), jni_cache, app_context,
+ ff_model->entity_data_schema());
std::unique_ptr<libtextclassifier3::RemoteActionTemplatesHandler>
remote_actions_templates_handler =
libtextclassifier3::RemoteActionTemplatesHandler::Create(env,
@@ -483,6 +521,25 @@
return results;
}
+// JNI entry point: looks up the knowledge entity with the given id on the
+// annotator behind `ptr` and returns the serialized result as a Java byte
+// array. Returns null if `ptr` is null, the lookup fails, or the result array
+// cannot be allocated.
+TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
+               nativeLookUpKnowledgeEntity)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring id) {
+  if (!ptr) {
+    return nullptr;
+  }
+  Annotator* model = reinterpret_cast<Annotator*>(ptr);
+  const std::string id_utf8 = ToStlString(env, id);
+  std::string serialized_knowledge_result;
+  if (!model->LookUpKnowledgeEntity(id_utf8, &serialized_knowledge_result)) {
+    return nullptr;
+  }
+  const jsize result_size =
+      static_cast<jsize>(serialized_knowledge_result.size());
+  jbyteArray result = env->NewByteArray(result_size);
+  if (result == nullptr) {
+    // NewByteArray returns null when the array cannot be constructed; writing
+    // into it would be invalid.
+    return nullptr;
+  }
+  env->SetByteArrayRegion(
+      result, 0, result_size,
+      reinterpret_cast<const jbyte*>(serialized_knowledge_result.data()));
+  return result;
+}
+
TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
(JNIEnv* env, jobject thiz, jlong ptr) {
Annotator* model = reinterpret_cast<Annotator*>(ptr);
diff --git a/annotator/annotator_jni.h b/annotator/annotator_jni.h
index 59e02a9..9f8b55b 100644
--- a/annotator/annotator_jni.h
+++ b/annotator/annotator_jni.h
@@ -46,6 +46,10 @@
nativeInitializeContactEngine)
(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializeInstalledAppEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);
+
TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
jint selection_end, jobject options);
@@ -58,6 +62,10 @@
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options);
+TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
+ nativeLookUpKnowledgeEntity)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring id);
+
TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
(JNIEnv* env, jobject thiz, jlong ptr);
diff --git a/annotator/annotator_test.cc b/annotator/annotator_test.cc
index b5198d4..d807ad8 100644
--- a/annotator/annotator_test.cc
+++ b/annotator/annotator_test.cc
@@ -55,6 +55,56 @@
return TC3_TEST_DATA_DIR;
}
+// Creates fake entity data schema metadata for tests.
+//
+// Builds a reflection schema containing a single object named "EntityData"
+// with the fields "first_name" (String), "is_alive" (Bool) and "last_name"
+// (String), uses that object as the root table, and stores the finished
+// buffer bytes in `unpacked_model->entity_data_schema`.
+void AddTestEntitySchemaData(ModelT* unpacked_model) {
+ // Cannot use object oriented API here as that is not available for the
+ // reflection schema.
+ flatbuffers::FlatBufferBuilder schema_builder;
+ // The fields get ids 0, 1, 2 and offsets 4, 6, 8; the entity data test reads
+ // the values back through these offsets.
+ // NOTE(review): the fields are listed in alphabetical order of their names,
+ // presumably to match the ordering CreateVectorOfSortedTables expects --
+ // confirm before adding fields.
+ std::vector<flatbuffers::Offset<reflection::Field>> fields = {
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("first_name"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/0,
+ /*offset=*/4),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("is_alive"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Bool),
+ /*id=*/1,
+ /*offset=*/6),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("last_name"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/2,
+ /*offset=*/8),
+ };
+ // No enums are defined; an empty vector is still passed to CreateSchema.
+ std::vector<flatbuffers::Offset<reflection::Enum>> enums;
+ std::vector<flatbuffers::Offset<reflection::Object>> objects = {
+ reflection::CreateObject(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("EntityData"),
+ /*fields=*/
+ schema_builder.CreateVectorOfSortedTables(&fields))};
+ schema_builder.Finish(reflection::CreateSchema(
+ schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
+ schema_builder.CreateVectorOfSortedTables(&enums),
+ /*(unused) file_ident=*/0,
+ /*(unused) file_ext=*/0,
+ /*root_table*/ objects[0]));
+
+ // Copy the finished buffer into the model's schema byte vector.
+ unpacked_model->entity_data_schema.assign(
+ schema_builder.GetBufferPointer(),
+ schema_builder.GetBufferPointer() + schema_builder.GetSize());
+}
+
class AnnotatorTest : public ::testing::TestWithParam<const char*> {
protected:
AnnotatorTest()
@@ -70,8 +120,6 @@
EXPECT_FALSE(classifier);
}
-INSTANTIATE_TEST_SUITE_P(ClickContext, AnnotatorTest,
- Values("test_model_cc.fb"));
INSTANTIATE_TEST_SUITE_P(BoundsSensitive, AnnotatorTest,
Values("test_model.fb"));
@@ -266,6 +314,73 @@
"www.google.com every today!|Call me at (800) 123-456 today.",
{51, 65})));
}
+
+// Checks that a classification regex pattern with static entity data and
+// capturing groups mapped to entity fields yields serialized entity data
+// containing both the static value and the captured strings.
+TEST_P(AnnotatorTest, ClassifyTextRegularExpressionEntityData) {
+  const std::string test_model = ReadFile(GetModelPath() + GetParam());
+  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+  // Add fake entity schema metadata.
+  AddTestEntitySchemaData(unpacked_model.get());
+
+  // Add test regex models.
+  unpacked_model->regex_model->patterns.push_back(MakePattern(
+      "person", "(Barack) (Obama)", /*enabled_for_classification=*/true,
+      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
+
+  // Use meta data to generate custom serialized entity data.
+  ReflectiveFlatbufferBuilder entity_data_builder(
+      flatbuffers::GetRoot<reflection::Schema>(
+          unpacked_model->entity_data_schema.data()));
+  std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+      entity_data_builder.NewRoot();
+  entity_data->Set("is_alive", true);
+
+  RegexModel_::PatternT* pattern =
+      unpacked_model->regex_model->patterns.back().get();
+  pattern->serialized_entity_data = entity_data->Serialize();
+  pattern->capturing_group.emplace_back(
+      new RegexModel_::Pattern_::CapturingGroupT);
+  pattern->capturing_group.emplace_back(
+      new RegexModel_::Pattern_::CapturingGroupT);
+  pattern->capturing_group.emplace_back(
+      new RegexModel_::Pattern_::CapturingGroupT);
+  // Group 0 is the full match, capturing groups starting at 1.
+  pattern->capturing_group[1]->entity_field_path.reset(
+      new FlatbufferFieldPathT);
+  pattern->capturing_group[1]->entity_field_path->field.emplace_back(
+      new FlatbufferFieldT);
+  pattern->capturing_group[1]->entity_field_path->field.back()->field_name =
+      "first_name";
+  pattern->capturing_group[2]->entity_field_path.reset(
+      new FlatbufferFieldPathT);
+  pattern->capturing_group[2]->entity_field_path->field.emplace_back(
+      new FlatbufferFieldT);
+  pattern->capturing_group[2]->entity_field_path->field.back()->field_name =
+      "last_name";
+
+  flatbuffers::FlatBufferBuilder builder;
+  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+      reinterpret_cast<const char*>(builder.GetBufferPointer()),
+      builder.GetSize(), &unilib_, &calendarlib_);
+  ASSERT_TRUE(classifier);
+
+  auto classifications = classifier->ClassifyText(
+      "this afternoon Barack Obama gave a speech at", {15, 27});
+  // Fatal assertion: the checks below index into the first result.
+  ASSERT_EQ(1u, classifications.size());
+  EXPECT_EQ("person", classifications[0].collection);
+
+  // Check entity data. The buffer must be non-empty before it is interpreted
+  // as a flatbuffer, and string fields must be present before they are
+  // dereferenced.
+  ASSERT_FALSE(classifications[0].serialized_entity_data.empty());
+  const flatbuffers::Table* entity =
+      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+          classifications[0].serialized_entity_data.data()));
+  ASSERT_TRUE(entity != nullptr);
+  const flatbuffers::String* first_name =
+      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
+  ASSERT_TRUE(first_name != nullptr);
+  EXPECT_EQ(first_name->str(), "Barack");
+  const flatbuffers::String* last_name =
+      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
+  ASSERT_TRUE(last_name != nullptr);
+  EXPECT_EQ(last_name->str(), "Obama");
+  EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
+}
#endif // TC3_UNILIB_ICU
#ifdef TC3_UNILIB_ICU
@@ -626,7 +741,7 @@
SelectionOptions options;
EXPECT_EQ(classifier->SuggestSelection("857 225\n3556\nabc", {0, 3}, options),
- std::make_pair(0, 7));
+ std::make_pair(0, 12));
}
TEST_P(AnnotatorTest, SuggestSelectionWithPunctuation) {
@@ -774,8 +889,8 @@
AnnotationOptions options;
EXPECT_THAT(classifier->Annotate("853 225 3556", options),
ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
- EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
-
+ EXPECT_THAT(classifier->Annotate("853 225\n3556", options),
+ ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
// Try passing invalid utf8.
EXPECT_TRUE(
classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options)
@@ -809,7 +924,8 @@
AnnotationOptions options;
EXPECT_THAT(classifier->Annotate("853 225 3556", options),
ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
- EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
+ EXPECT_THAT(classifier->Annotate("853 225\n3556", options),
+ ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
}
#ifdef TC3_UNILIB_ICU
diff --git a/annotator/collections.h b/annotator/collections.h
index 170aafc..089d7cb 100644
--- a/annotator/collections.h
+++ b/annotator/collections.h
@@ -29,6 +29,11 @@
*[]() { return new std::string("address"); }();
return value;
}
+ // Returns the collection name "app". The string is heap-allocated by the
+ // initializer of the function-local static and never deleted, so the
+ // returned reference stays valid.
+ static const std::string& App() {
+ static const std::string& value =
+ *[]() { return new std::string("app"); }();
+ return value;
+ }
static const std::string& Contact() {
static const std::string& value =
*[]() { return new std::string("contact"); }();
diff --git a/annotator/datetime/extractor.cc b/annotator/datetime/extractor.cc
index 0328dcf..b9d0c30 100644
--- a/annotator/datetime/extractor.cc
+++ b/annotator/datetime/extractor.cc
@@ -430,6 +430,9 @@
{DatetimeExtractorType_SATURDAY,
DateParseData::RelationType::SATURDAY},
{DatetimeExtractorType_SUNDAY, DateParseData::RelationType::SUNDAY},
+ {DatetimeExtractorType_SECONDS, DateParseData::RelationType::SECOND},
+ {DatetimeExtractorType_MINUTES, DateParseData::RelationType::MINUTE},
+ {DatetimeExtractorType_HOURS, DateParseData::RelationType::HOUR},
{DatetimeExtractorType_DAY, DateParseData::RelationType::DAY},
{DatetimeExtractorType_WEEK, DateParseData::RelationType::WEEK},
{DatetimeExtractorType_MONTH, DateParseData::RelationType::MONTH},
@@ -438,40 +441,4 @@
parsed_relation_type);
}
-bool DatetimeExtractor::ParseTimeUnit(
- const UnicodeText& input, DateParseData::TimeUnit* parsed_time_unit) const {
- return MapInput(
- input,
- {
- {DatetimeExtractorType_DAYS, DateParseData::TimeUnit::DAYS},
- {DatetimeExtractorType_WEEKS, DateParseData::TimeUnit::WEEKS},
- {DatetimeExtractorType_MONTHS, DateParseData::TimeUnit::MONTHS},
- {DatetimeExtractorType_HOURS, DateParseData::TimeUnit::HOURS},
- {DatetimeExtractorType_MINUTES, DateParseData::TimeUnit::MINUTES},
- {DatetimeExtractorType_SECONDS, DateParseData::TimeUnit::SECONDS},
- {DatetimeExtractorType_YEARS, DateParseData::TimeUnit::YEARS},
- },
- parsed_time_unit);
-}
-
-bool DatetimeExtractor::ParseWeekday(
- const UnicodeText& input,
- DateParseData::RelationType* parsed_weekday) const {
- return MapInput(
- input,
- {
- {DatetimeExtractorType_MONDAY, DateParseData::RelationType::MONDAY},
- {DatetimeExtractorType_TUESDAY, DateParseData::RelationType::TUESDAY},
- {DatetimeExtractorType_WEDNESDAY,
- DateParseData::RelationType::WEDNESDAY},
- {DatetimeExtractorType_THURSDAY,
- DateParseData::RelationType::THURSDAY},
- {DatetimeExtractorType_FRIDAY, DateParseData::RelationType::FRIDAY},
- {DatetimeExtractorType_SATURDAY,
- DateParseData::RelationType::SATURDAY},
- {DatetimeExtractorType_SUNDAY, DateParseData::RelationType::SUNDAY},
- },
- parsed_weekday);
-}
-
} // namespace libtextclassifier3
diff --git a/annotator/datetime/parser.cc b/annotator/datetime/parser.cc
index e2a2266..e2ff978 100644
--- a/annotator/datetime/parser.cc
+++ b/annotator/datetime/parser.cc
@@ -112,17 +112,18 @@
+// Convenience overload taking UTF-8 input: wraps `input` as UnicodeText
+// without copying and forwards all arguments unchanged to the overload that
+// takes a UnicodeText.
bool DatetimeParser::Parse(
const std::string& input, const int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, bool anchor_start_end,
+ ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const {
return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
reference_time_ms_utc, reference_timezone, locales, mode,
- anchor_start_end, results);
+ annotation_usecase, anchor_start_end, results);
}
bool DatetimeParser::FindSpansUsingLocales(
const std::vector<int>& locale_ids, const UnicodeText& input,
const int64 reference_time_ms_utc, const std::string& reference_timezone,
- ModeFlag mode, bool anchor_start_end, const std::string& reference_locale,
+ ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
+ const std::string& reference_locale,
std::unordered_set<int>* executed_rules,
std::vector<DatetimeParseResultSpan>* found_spans) const {
for (const int locale_id : locale_ids) {
@@ -137,6 +138,11 @@
continue;
}
+ if ((rules_[rule_id].pattern->enabled_annotation_usecases() &
+ annotation_usecase) == 0) {
+ continue;
+ }
+
if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
continue;
}
@@ -156,7 +162,7 @@
bool DatetimeParser::Parse(
const UnicodeText& input, const int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, bool anchor_start_end,
+ ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const {
std::vector<DatetimeParseResultSpan> found_spans;
std::unordered_set<int> executed_rules;
@@ -164,8 +170,9 @@
const std::vector<int> requested_locales =
ParseAndExpandLocales(locales, &reference_locale);
if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
- reference_timezone, mode, anchor_start_end,
- reference_locale, &executed_rules, &found_spans)) {
+ reference_timezone, mode, annotation_usecase,
+ anchor_start_end, reference_locale,
+ &executed_rules, &found_spans)) {
return false;
}
@@ -337,97 +344,47 @@
return result;
}
-namespace {
+void DatetimeParser::FillInterpretations(
+ const DateParseData& parse,
+ std::vector<DateParseData>* interpretations) const {
+ DatetimeGranularity granularity = calendarlib_.GetGranularity(parse);
-DatetimeGranularity GetGranularity(const DateParseData& data) {
- DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR;
- if ((data.field_set_mask & DateParseData::YEAR_FIELD) ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::YEAR))) {
- granularity = DatetimeGranularity::GRANULARITY_YEAR;
+ DateParseData modified_parse(parse);
+ // If the relation field is not set, but relation_type field *is*, assume
+ // the relation field is NEXT_OR_SAME. This is necessary to handle e.g.
+ // "monday 3pm" (otherwise only "this monday 3pm" would work).
+ if (!(modified_parse.field_set_mask &
+ DateParseData::Fields::RELATION_FIELD) &&
+ (modified_parse.field_set_mask &
+ DateParseData::Fields::RELATION_TYPE_FIELD)) {
+ modified_parse.relation = DateParseData::Relation::NEXT_OR_SAME;
+ modified_parse.field_set_mask |= DateParseData::Fields::RELATION_FIELD;
}
- if ((data.field_set_mask & DateParseData::MONTH_FIELD) ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::MONTH))) {
- granularity = DatetimeGranularity::GRANULARITY_MONTH;
- }
- if (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::WEEK)) {
- granularity = DatetimeGranularity::GRANULARITY_WEEK;
- }
- if (data.field_set_mask & DateParseData::DAY_FIELD ||
- (data.field_set_mask & DateParseData::RELATION_FIELD &&
- (data.relation == DateParseData::Relation::NOW ||
- data.relation == DateParseData::Relation::TOMORROW ||
- data.relation == DateParseData::Relation::YESTERDAY)) ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::MONDAY ||
- data.relation_type == DateParseData::RelationType::TUESDAY ||
- data.relation_type == DateParseData::RelationType::WEDNESDAY ||
- data.relation_type == DateParseData::RelationType::THURSDAY ||
- data.relation_type == DateParseData::RelationType::FRIDAY ||
- data.relation_type == DateParseData::RelationType::SATURDAY ||
- data.relation_type == DateParseData::RelationType::SUNDAY ||
- data.relation_type == DateParseData::RelationType::DAY))) {
- granularity = DatetimeGranularity::GRANULARITY_DAY;
- }
- if (data.field_set_mask & DateParseData::HOUR_FIELD) {
- granularity = DatetimeGranularity::GRANULARITY_HOUR;
- }
- if (data.field_set_mask & DateParseData::MINUTE_FIELD) {
- granularity = DatetimeGranularity::GRANULARITY_MINUTE;
- }
- if (data.field_set_mask & DateParseData::SECOND_FIELD) {
- granularity = DatetimeGranularity::GRANULARITY_SECOND;
- }
- return granularity;
-}
-
-void FillInterpretations(const DateParseData& parse,
- std::vector<DateParseData>* interpretations) {
- DatetimeGranularity granularity = GetGranularity(parse);
// Multiple interpretations of ambiguous datetime expressions are generated
// here.
- if (granularity > DatetimeGranularity::GRANULARITY_DAY && parse.hour <= 12 &&
- (parse.field_set_mask & DateParseData::Fields::AMPM_FIELD) == 0) {
+ if (granularity > DatetimeGranularity::GRANULARITY_DAY &&
+ (modified_parse.field_set_mask & DateParseData::Fields::HOUR_FIELD) &&
+ modified_parse.hour <= 12 &&
+ !(modified_parse.field_set_mask & DateParseData::Fields::AMPM_FIELD)) {
// If it's not clear if the time is AM or PM, generate all variants.
- interpretations->push_back(parse);
+ interpretations->push_back(modified_parse);
interpretations->back().field_set_mask |= DateParseData::Fields::AMPM_FIELD;
interpretations->back().ampm = DateParseData::AMPM::AM;
- interpretations->push_back(parse);
+ interpretations->push_back(modified_parse);
interpretations->back().field_set_mask |= DateParseData::Fields::AMPM_FIELD;
interpretations->back().ampm = DateParseData::AMPM::PM;
- } else if (((parse.field_set_mask &
- DateParseData::Fields::RELATION_TYPE_FIELD) != 0) &&
- ((parse.field_set_mask & DateParseData::Fields::RELATION_FIELD) ==
- 0)) {
- // If it's not clear if it's this monday next monday or last monday,
- // generate
- // all variants.
- interpretations->push_back(parse);
- interpretations->back().field_set_mask |=
- DateParseData::Fields::RELATION_FIELD;
- interpretations->back().relation = DateParseData::Relation::LAST;
-
- interpretations->push_back(parse);
- interpretations->back().field_set_mask |=
- DateParseData::Fields::RELATION_FIELD;
- interpretations->back().relation = DateParseData::Relation::NEXT;
-
- interpretations->push_back(parse);
- interpretations->back().field_set_mask |=
- DateParseData::Fields::RELATION_FIELD;
- interpretations->back().relation = DateParseData::Relation::NEXT_OR_SAME;
} else {
// Otherwise just generate 1 variant.
- interpretations->push_back(parse);
+ interpretations->push_back(modified_parse);
}
+ // TODO(zilka): Add support for generating alternatives for "monday" -> "this
+ // monday", "next monday", "last monday". The previous implementation did not
+ // work as expected: it didn't work correctly for this/previous day of week,
+ // and sometimes resulted in the same date being proposed.
}
-} // namespace
-
bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
const UniLib::RegexMatcher& matcher,
const int64 reference_time_ms_utc,
@@ -454,10 +411,9 @@
results->reserve(results->size() + interpretations.size());
for (const DateParseData& interpretation : interpretations) {
DatetimeParseResult result;
- result.granularity = GetGranularity(interpretation);
if (!calendarlib_.InterpretParseData(
interpretation, reference_time_ms_utc, reference_timezone,
- reference_locale, result.granularity, &(result.time_ms_utc))) {
+ reference_locale, &(result.time_ms_utc), &(result.granularity))) {
return false;
}
results->push_back(result);
diff --git a/annotator/datetime/parser.h b/annotator/datetime/parser.h
index 133d674..3f0c143 100644
--- a/annotator/datetime/parser.h
+++ b/annotator/datetime/parser.h
@@ -47,13 +47,15 @@
// beginning of 'input' and end at the end of it.
bool Parse(const std::string& input, int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, bool anchor_start_end,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const;
// Same as above but takes UnicodeText.
bool Parse(const UnicodeText& input, int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, bool anchor_start_end,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const;
#ifdef TC3_TEST_ONLY
@@ -77,7 +79,8 @@
bool FindSpansUsingLocales(
const std::vector<int>& locale_ids, const UnicodeText& input,
const int64 reference_time_ms_utc, const std::string& reference_timezone,
- ModeFlag mode, bool anchor_start_end, const std::string& reference_locale,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end, const std::string& reference_locale,
std::unordered_set<int>* executed_rules,
std::vector<DatetimeParseResultSpan>* found_spans) const;
@@ -88,6 +91,9 @@
bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* result) const;
+ void FillInterpretations(const DateParseData& parse,
+ std::vector<DateParseData>* interpretations) const;
+
// Converts the current match in 'matcher' into DatetimeParseResult.
bool ExtractDatetime(const CompiledRule& rule,
const UniLib::RegexMatcher& matcher,
diff --git a/annotator/datetime/parser_test.cc b/annotator/datetime/parser_test.cc
index 997d780..ad5c462 100644
--- a/annotator/datetime/parser_test.cc
+++ b/annotator/datetime/parser_test.cc
@@ -63,10 +63,12 @@
}
bool HasNoResult(const std::string& text, bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich") {
+ const std::string& timezone = "Europe/Zurich",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART) {
std::vector<DatetimeParseResultSpan> results;
if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
- anchor_start_end, &results)) {
+ annotation_usecase, anchor_start_end, &results)) {
TC3_LOG(ERROR) << text;
TC3_CHECK(false);
}
@@ -78,7 +80,9 @@
DatetimeGranularity expected_granularity,
bool anchor_start_end = false,
const std::string& timezone = "Europe/Zurich",
- const std::string& locales = "en-US") {
+ const std::string& locales = "en-US",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART) {
const UnicodeText marked_text_unicode =
UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
auto brace_open_it =
@@ -98,7 +102,7 @@
std::vector<DatetimeParseResultSpan> results;
if (!parser_->Parse(text, 0, timezone, locales, ModeFlag_ANNOTATION,
- anchor_start_end, &results)) {
+ annotation_usecase, anchor_start_end, &results)) {
TC3_LOG(ERROR) << text;
TC3_CHECK(false);
}
@@ -149,10 +153,12 @@
DatetimeGranularity expected_granularity,
bool anchor_start_end = false,
const std::string& timezone = "Europe/Zurich",
- const std::string& locales = "en-US") {
+ const std::string& locales = "en-US",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART) {
return ParsesCorrectly(marked_text, std::vector<int64>{expected_ms_utc},
expected_granularity, anchor_start_end, timezone,
- locales);
+ locales, annotation_usecase);
}
bool ParsesCorrectlyGerman(const std::string& marked_text,
@@ -258,6 +264,10 @@
ParsesCorrectly("{wednesday at 4am}", 529200000, GRANULARITY_HOUR));
EXPECT_TRUE(ParsesCorrectly("last seen {today at 9:01 PM}", 72060000,
GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly("set an alarm for {7am tomorrow}", 108000000,
+ GRANULARITY_HOUR));
+ EXPECT_TRUE(
+ ParsesCorrectly("set an alarm for {7 a.m}", 21600000, GRANULARITY_HOUR));
}
TEST_F(ParserTest, ParseWithAnchor) {
@@ -271,6 +281,43 @@
/*anchor_start_end=*/true));
}
+TEST_F(ParserTest, ParseWithRawUsecase) {
+ // Annotated for RAW usecase.
+ EXPECT_TRUE(ParsesCorrectly(
+ "{tomorrow}", 82800000, GRANULARITY_DAY, /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "call me {in two hours}", 7200000, GRANULARITY_HOUR,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "call me {next month}", 2674800000, GRANULARITY_MONTH,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+ EXPECT_TRUE(ParsesCorrectly(
+ "what's the time {now}", -3600000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "call me on {Saturday}", 169200000, GRANULARITY_DAY,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
+
+ // Not annotated for Smart usecase.
+ EXPECT_TRUE(HasNoResult(
+ "{tomorrow}", /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich",
+ /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_SMART));
+}
+
TEST_F(ParserTest, ParseGerman) {
EXPECT_TRUE(
ParsesCorrectlyGerman("{Januar 1 2018}", 1514761200000, GRANULARITY_DAY));
@@ -366,7 +413,7 @@
/*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
}
-TEST_F(ParserTest, WhenEnabled_GeneratesAlternatives) {
+TEST_F(ParserTest, WhenAlternativesEnabledGeneratesAlternatives) {
LoadModel([](ModelT* model) {
model->datetime_model->generate_alternative_interpretations_when_ambiguous =
true;
@@ -375,9 +422,12 @@
EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}",
{1514777400000, 1514820600000},
GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly("{monday 3pm}", 396000000, GRANULARITY_HOUR));
+ EXPECT_TRUE(ParsesCorrectly("{monday 3:00}", {352800000, 396000000},
+ GRANULARITY_MINUTE));
}
-TEST_F(ParserTest, WhenDisabled_DoesNotGenerateAlternatives) {
+TEST_F(ParserTest, WhenAlternativesDisabledDoesNotGenerateAlternatives) {
LoadModel([](ModelT* model) {
model->datetime_model->generate_alternative_interpretations_when_ambiguous =
false;
@@ -443,9 +493,10 @@
bool ParserLocaleTest::HasResult(const std::string& input,
const std::string& locales) {
std::vector<DatetimeParseResultSpan> results;
- EXPECT_TRUE(parser_->Parse(input, /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"", locales,
- ModeFlag_ANNOTATION, false, &results));
+ EXPECT_TRUE(parser_->Parse(
+ input, /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"", locales, ModeFlag_ANNOTATION,
+ AnnotationUsecase_ANNOTATION_USECASE_SMART, false, &results));
return results.size() == 1;
}
diff --git a/annotator/entity-data.fbs b/annotator/entity-data.fbs
new file mode 100755
index 0000000..779b047
--- /dev/null
+++ b/annotator/entity-data.fbs
@@ -0,0 +1,22 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Entity information and data associated with text classifications.
+namespace libtextclassifier3;
+table EntityData {
+}
+
+root_type libtextclassifier3.EntityData;
diff --git a/annotator/installed_app/installed-app-engine-dummy.h b/annotator/installed_app/installed-app-engine-dummy.h
new file mode 100644
index 0000000..a45f5d0
--- /dev/null
+++ b/annotator/installed_app/installed-app-engine-dummy.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_INSTALLED_APP_INSTALLED_APP_ENGINE_DUMMY_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_INSTALLED_APP_INSTALLED_APP_ENGINE_DUMMY_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/feature-processor.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+// A dummy implementation of the installed app engine.
+class InstalledAppEngine {
+ public:
+ explicit InstalledAppEngine(const FeatureProcessor* feature_processor) {}
+
+ bool Initialize(const std::string& serialized_config) {
+ TC3_LOG(ERROR) << "No installed app engine to initialize.";
+ return false;
+ }
+
+ bool ClassifyText(const std::string& context, CodepointSpan selection_indices,
+ ClassificationResult* classification_result) const {
+ return false;
+ }
+
+ bool Chunk(const UnicodeText& context_unicode,
+ const std::vector<Token>& tokens,
+ std::vector<AnnotatedSpan>* result) const {
+ return true;
+ }
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_INSTALLED_APP_INSTALLED_APP_ENGINE_DUMMY_H_
diff --git a/annotator/installed_app/installed-app-engine.h b/annotator/installed_app/installed-app-engine.h
new file mode 100644
index 0000000..d05d357
--- /dev/null
+++ b/annotator/installed_app/installed-app-engine.h
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_INSTALLED_APP_INSTALLED_APP_ENGINE_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_INSTALLED_APP_INSTALLED_APP_ENGINE_H_
+
+#include "annotator/installed_app/installed-app-engine-dummy.h"
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_INSTALLED_APP_INSTALLED_APP_ENGINE_H_
diff --git a/annotator/knowledge/knowledge-engine-dummy.h b/annotator/knowledge/knowledge-engine-dummy.h
index a6285dc..96d77c5 100644
--- a/annotator/knowledge/knowledge-engine-dummy.h
+++ b/annotator/knowledge/knowledge-engine-dummy.h
@@ -40,6 +40,11 @@
std::vector<AnnotatedSpan>* result) const {
return true;
}
+
+ bool LookUpEntity(const std::string& id,
+ std::string* serialized_knowledge_result) const {
+ return false;
+ }
};
} // namespace libtextclassifier3
diff --git a/annotator/model.fbs b/annotator/model.fbs
index be27cc0..303f1a0 100755
--- a/annotator/model.fbs
+++ b/annotator/model.fbs
@@ -14,6 +14,7 @@
// limitations under the License.
//
+include "utils/flatbuffers.fbs";
include "utils/intents/intent-config.fbs";
include "utils/resources.fbs";
include "utils/zlib/buffer.fbs";
@@ -33,6 +34,18 @@
ALL = 7,
}
+// Enum for specifying the annotation usecase. The values are meant to be used
+// as flags, and thus are not mutually exclusive.
+namespace libtextclassifier3;
+enum AnnotationUsecase : int {
+ ANNOTATION_USECASE_NONE = 0,
+ ANNOTATION_USECASE_SMART = 1,
+ ANNOTATION_USECASE_RAW = 2,
+
+ // Max signed 32-bit int (all non-sign bits set); the enum's base type is int.
+ ANNOTATION_USECASE_ALL = 2147483647,
+}
+
namespace libtextclassifier3;
enum DatetimeExtractorType : int {
UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0,
@@ -72,7 +85,10 @@
DAYS = 34,
WEEKS = 35,
MONTHS = 36,
+
+ // TODO(zilka): Make the following 3 values singular for consistency.
HOURS = 37,
+
MINUTES = 38,
SECONDS = 39,
YEARS = 40,
@@ -180,6 +196,10 @@
// If true, the span of the capturing group will be used to
// extend the selection.
extend_selection:bool = true;
+
+ // If set, the text of the capturing group will be used to set a field in
+ // the classification result entity data.
+ entity_field_path:FlatbufferFieldPath;
}
// List of regular expression matchers to check.
@@ -211,6 +231,9 @@
verification_options:VerificationOptions;
capturing_group:[Pattern_.CapturingGroup];
+
+ // Serialized entity data to set for a match.
+ serialized_entity_data:string;
}
namespace libtextclassifier3;
@@ -246,6 +269,9 @@
// The modes for which to apply the patterns.
enabled_modes:ModeFlag = ALL;
+
+ // The annotation usecases for which to apply the patterns.
+ enabled_annotation_usecases:AnnotationUsecase = ANNOTATION_USECASE_ALL;
}
namespace libtextclassifier3;
@@ -364,6 +390,9 @@
// Model resources.
resources:ResourcePool;
+
+ // Schema data for handling entity data.
+ entity_data_schema:[ubyte];
}
// Role of the codepoints in the range.
diff --git a/annotator/test_data/test_model.fb b/annotator/test_data/test_model.fb
index bca3c2b..2f8aeb6 100644
--- a/annotator/test_data/test_model.fb
+++ b/annotator/test_data/test_model.fb
Binary files differ
diff --git a/annotator/test_data/test_model_cc.fb b/annotator/test_data/test_model_cc.fb
deleted file mode 100644
index 9ca37ce..0000000
--- a/annotator/test_data/test_model_cc.fb
+++ /dev/null
Binary files differ
diff --git a/annotator/test_data/wrong_embeddings.fb b/annotator/test_data/wrong_embeddings.fb
index 3d18ce5..ce9cd83 100644
--- a/annotator/test_data/wrong_embeddings.fb
+++ b/annotator/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/annotator/types-test-util.h b/annotator/types-test-util.h
index fbbdd63..9d40d40 100644
--- a/annotator/types-test-util.h
+++ b/annotator/types-test-util.h
@@ -44,6 +44,13 @@
return stream << tmp_stream.message;
}
+inline std::ostream& operator<<(std::ostream& stream,
+ const DateParseData& value) {
+ logging::LoggingStringStream tmp_stream;
+ tmp_stream << value;
+ return stream << tmp_stream.message;
+}
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_
diff --git a/annotator/types.cc b/annotator/types.cc
index 78d72df..d92e4d1 100644
--- a/annotator/types.cc
+++ b/annotator/types.cc
@@ -80,4 +80,26 @@
<< ", " << best_class << ", " << best_score << ")";
}
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const DateParseData& data) {
+ // TODO(zilka): Add human-readable form of field_set_mask and the enum fields.
+ stream = stream << "DateParseData {\n";
+ stream = stream << " field_set_mask: " << data.field_set_mask << "\n";
+ stream = stream << " year: " << data.year << "\n";
+ stream = stream << " month: " << data.month << "\n";
+ stream = stream << " day_of_month: " << data.day_of_month << "\n";
+ stream = stream << " hour: " << data.hour << "\n";
+ stream = stream << " minute: " << data.minute << "\n";
+ stream = stream << " second: " << data.second << "\n";
+ stream = stream << " ampm: " << static_cast<int>(data.ampm) << "\n";
+ stream = stream << " zone_offset: " << data.zone_offset << "\n";
+ stream = stream << " dst_offset: " << data.dst_offset << "\n";
+ stream = stream << " relation: " << static_cast<int>(data.relation) << "\n";
+ stream = stream << " relation_type: " << static_cast<int>(data.relation_type)
+ << "\n";
+ stream = stream << " relation_distance: " << data.relation_distance << "\n";
+ stream = stream << "}";
+ return stream;
+}
+
} // namespace libtextclassifier3
diff --git a/annotator/types.h b/annotator/types.h
index 71acaf4..03ebef2 100644
--- a/annotator/types.h
+++ b/annotator/types.h
@@ -27,8 +27,10 @@
#include <utility>
#include <vector>
+#include "annotator/entity-data_generated.h"
#include "utils/base/integral_types.h"
#include "utils/base/logging.h"
+#include "utils/flatbuffers.h"
#include "utils/variant.h"
namespace libtextclassifier3 {
@@ -164,9 +166,7 @@
};
struct DatetimeParseResult {
- // The absolute time in milliseconds since the epoch in UTC. This is derived
- // from the reference time and the fields specified in the text - so it may
- // be imperfect where the time was ambiguous. (e.g. "at 7:30" may be am or pm)
+ // The absolute time in milliseconds since the epoch in UTC.
int64 time_ms_utc;
// The precision of the estimate then in to calculating the milliseconds
@@ -212,13 +212,19 @@
DatetimeParseResult datetime_parse_result;
std::string serialized_knowledge_result;
std::string contact_name, contact_given_name, contact_nickname,
- contact_email_address, contact_phone_number;
+ contact_email_address, contact_phone_number, contact_id;
+ std::string app_name, app_package_name;
// Internal score used for conflict resolution.
float priority_score;
- // Extra information.
- std::map<std::string, Variant> extra;
+
+ // Entity data information.
+ std::string serialized_entity_data;
+ const EntityData* entity_data() {
+ return LoadAndVerifyFlatbuffer<EntityData>(serialized_entity_data.data(),
+ serialized_entity_data.size());
+ }
explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {}
@@ -306,7 +312,10 @@
DAY = 8,
WEEK = 9,
MONTH = 10,
- YEAR = 11
+ YEAR = 11,
+ HOUR = 12,
+ MINUTE = 13,
+ SECOND = 14,
};
enum Fields {
@@ -390,6 +399,10 @@
}
};
+// Pretty-printing function for DateParseData.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const DateParseData& data);
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_