[automerger skipped] Add dependencies on jni_headers am: 8fd625660f am: a25f7792dc -s ours
am skip reason: Change-Id I54806e51f2d1715ef1c18b0103a8ec5f6b7a540d with SHA-1 9671164f19 is in history
Change-Id: I065faab214a328ebc7849582a72a43f3340db20f
diff --git a/Android.bp b/Android.bp
deleted file mode 100644
index 7683387..0000000
--- a/Android.bp
+++ /dev/null
@@ -1,347 +0,0 @@
-// Copyright (C) 2017 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-cc_library_headers {
- name: "libtextclassifier_hash_headers",
- vendor_available: true,
- export_include_dirs: ["."],
- apex_available: [
- "//apex_available:platform",
- "com.android.neuralnetworks",
- "test_com.android.neuralnetworks",
- ],
-}
-
-cc_defaults {
- name: "libtextclassifier_hash_defaults",
- srcs: [
- "utils/hash/farmhash.cc",
- "util/hash/hash.cc",
- ],
- cflags: [
- "-DNAMESPACE_FOR_HASH_FUNCTIONS=farmhash",
- "-Wall",
- "-Werror",
- "-Wno-unused-function",
- ],
-}
-
-cc_library_shared {
- name: "libtextclassifier_hash",
- defaults: ["libtextclassifier_hash_defaults"],
- vendor_available: true,
- double_loadable: true,
-}
-
-cc_library_static {
- name: "libtextclassifier_hash_static",
- defaults: ["libtextclassifier_hash_defaults"],
- vendor_available: true,
- sdk_version: "current",
- stl: "libc++_static",
- apex_available: [
- "//apex_available:platform",
- "com.android.neuralnetworks",
- "test_com.android.neuralnetworks",
- ],
-}
-
-java_library_static {
- name: "libtextclassifier-java",
- sdk_version: "core_current",
- srcs: ["java/**/*.java"],
-}
-
-cc_defaults {
- name: "libtextclassifier_defaults",
-
- // For debug / treemap purposes.
- //strip: {
- // keep_symbols: true,
- //},
-
- cflags: [
- "-Wall",
- "-Werror",
- "-Wno-deprecated-declarations",
- "-Wno-ignored-qualifiers",
- "-Wno-missing-field-initializers",
- "-Wno-sign-compare",
- "-Wno-tautological-constant-out-of-range-compare",
- "-Wno-undefined-var-template",
- "-Wno-unused-function",
- "-Wno-unused-parameter",
- "-Wno-extern-c-compat",
-
- "-funsigned-char",
- "-fvisibility=hidden",
- "-DLIBTEXTCLASSIFIER_UNILIB_ICU",
- "-DZLIB_CONST",
- "-DSAFTM_COMPACT_LOGGING",
- "-DTC3_WITH_ACTIONS_OPS",
- "-DTC3_UNILIB_JAVAICU",
- "-DTC3_CALENDAR_JAVAICU",
- "-DTC3_AOSP"
- ],
-
- product_variables: {
- debuggable: {
- // Only enable debug logging in userdebug/eng builds.
- cflags: ["-DTC_DEBUG_LOGGING=1"],
- },
- },
-
- generated_headers: [
- "libtextclassifier_fbgen_flatbuffers",
- "libtextclassifier_fbgen_tokenizer",
- "libtextclassifier_fbgen_codepoint_range",
- "libtextclassifier_fbgen_entity-data",
- "libtextclassifier_fbgen_zlib_buffer",
- "libtextclassifier_fbgen_resources_extra",
- "libtextclassifier_fbgen_intent_config",
- "libtextclassifier_fbgen_annotator_model",
- "libtextclassifier_fbgen_actions_model",
- "libtextclassifier_fbgen_tflite_text_encoder_config",
- "libtextclassifier_fbgen_lang_id_embedded_network",
- "libtextclassifier_fbgen_lang_id_model",
- "libtextclassifier_fbgen_actions-entity-data",
- ],
-
- header_libs: [
- "jni_headers",
- "tensorflow_headers",
- "flatbuffer_headers",
- ],
-
- shared_libs: [
- "liblog",
- "libtflite",
- "libz",
- ],
-
- static_libs: [
- "liblua",
- "libutf",
- ],
-}
-
-// -----------------
-// Generate headers with FlatBuffer schema compiler.
-// -----------------
-genrule_defaults {
- name: "fbgen",
- tools: ["flatc"],
- // "depfile" is used here in conjunction with flatc's -M to gather the deps
- cmd: "$(location flatc) --cpp --no-union-value-namespacing --gen-object-api --keep-prefix -I external/libtextclassifier -M $(in) >$(depfile) && " +
- "$(location flatc) --cpp --no-union-value-namespacing --gen-object-api --keep-prefix -I external/libtextclassifier -o $$(dirname $(out)) $(in)",
- depfile: true,
-}
-
-genrule {
- name: "libtextclassifier_fbgen_flatbuffers",
- srcs: ["utils/flatbuffers.fbs"],
- out: ["utils/flatbuffers_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_tokenizer",
- srcs: ["utils/tokenizer.fbs"],
- out: ["utils/tokenizer_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_codepoint_range",
- srcs: ["utils/codepoint-range.fbs"],
- out: ["utils/codepoint-range_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_resources_extra",
- srcs: ["utils/resources.fbs"],
- out: ["utils/resources_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_entity-data",
- srcs: ["annotator/entity-data.fbs"],
- out: ["annotator/entity-data_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_zlib_buffer",
- srcs: ["utils/zlib/buffer.fbs"],
- out: ["utils/zlib/buffer_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_intent_config",
- srcs: ["utils/intents/intent-config.fbs"],
- out: ["utils/intents/intent-config_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_annotator_model",
- srcs: ["annotator/model.fbs"],
- out: ["annotator/model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_actions_model",
- srcs: ["actions/actions_model.fbs"],
- out: ["actions/actions_model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_tflite_text_encoder_config",
- srcs: ["utils/tflite/text_encoder_config.fbs"],
- out: ["utils/tflite/text_encoder_config_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_lang_id_embedded_network",
- srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"],
- out: ["lang_id/common/flatbuffers/embedding-network_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_lang_id_model",
- srcs: ["lang_id/common/flatbuffers/model.fbs"],
- out: ["lang_id/common/flatbuffers/model_generated.h"],
- defaults: ["fbgen"],
-}
-
-genrule {
- name: "libtextclassifier_fbgen_actions-entity-data",
- srcs: ["actions/actions-entity-data.fbs"],
- out: ["actions/actions-entity-data_generated.h"],
- defaults: ["fbgen"],
-}
-
-// -----------------
-// libtextclassifier
-// -----------------
-cc_library_shared {
- name: "libtextclassifier",
- defaults: ["libtextclassifier_defaults"],
-
- srcs: ["**/*.cc"],
- exclude_srcs: [
- "**/*_test.cc",
- "**/*-test-lib.cc",
- "utils/testing/*.cc",
- "test-util.*",
- "utils/calendar/*_test-include.*",
- "utils/utf8/*_test-include.*"
- ],
-
- required: [
- "libtextclassifier_annotator_en_model",
- "libtextclassifier_annotator_universal_model",
- "libtextclassifier_actions_suggestions_universal_model",
- "libtextclassifier_lang_id_model",
- ],
-
- version_script: "jni.lds",
-}
-
-// -----------------------
-// libtextclassifier_tests
-// -----------------------
-cc_test {
- name: "libtextclassifier_tests",
- defaults: ["libtextclassifier_defaults"],
-
- test_suites: ["device-tests"],
-
- data: [
- "annotator/test_data/**/*",
- "actions/test_data/**/*",
- ],
-
- srcs: ["**/*.cc"],
- // TODO: Do not filter out tflite test once the dependency issue is resolved.
- exclude_srcs: [
- "utils/tflite/*_test.cc",
- "utils/flatbuffers_test.cc",
- "utils/calendar/*_test-include.*",
- "utils/utf8/*_test-include.*"
- ],
-
- static_libs: ["libgmock"],
- header_libs: ["jni_headers"],
-
- multilib: {
- lib32: {
- cppflags: ["-DTC3_TEST_DATA_DIR=\"/data/nativetest/libtextclassifier_tests/test_data/\""],
- },
- lib64: {
- cppflags: ["-DTC3_TEST_DATA_DIR=\"/data/nativetest64/libtextclassifier_tests/test_data/\""],
- },
- },
-}
-
-// ----------------
-// Annotator models
-// ----------------
-
-prebuilt_etc {
- name: "libtextclassifier_annotator_en_model",
- filename: "textclassifier.en.model",
- owner: "google",
- src: "models/textclassifier.en.model",
- sub_dir: "textclassifier",
-}
-
-prebuilt_etc {
- name: "libtextclassifier_annotator_universal_model",
- filename: "textclassifier.universal.model",
- owner: "google",
- src: "models/textclassifier.universal.model",
- sub_dir: "textclassifier",
-}
-
-// ---------------------------
-// Actions Suggestions models
-// ---------------------------
-
-prebuilt_etc {
- name: "libtextclassifier_actions_suggestions_universal_model",
- filename: "actions_suggestions.universal.model",
- owner: "google",
- src: "models/actions_suggestions.universal.model",
- sub_dir: "textclassifier",
-}
-
-// ------------
-// LangId model
-// ------------
-
-prebuilt_etc {
- name: "libtextclassifier_lang_id_model",
- filename: "lang_id.model",
- owner: "google",
- src: "models/lang_id.model",
- sub_dir: "textclassifier",
-}
diff --git a/TEST_MAPPING b/TEST_MAPPING
new file mode 100644
index 0000000..3c8e10b
--- /dev/null
+++ b/TEST_MAPPING
@@ -0,0 +1,15 @@
+{
+ "presubmit": [
+ {
+ "name": "TextClassifierServiceTest",
+ "options": [
+ {
+ "exclude-annotation": "androidx.test.filters.FlakyTest"
+ }
+ ]
+ },
+ {
+ "name": "libtextclassifier_tests"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/actions/actions-entity-data.fbs b/actions/actions-entity-data.fbs
deleted file mode 100755
index 4ed68bb..0000000
--- a/actions/actions-entity-data.fbs
+++ /dev/null
@@ -1,24 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-// Extra information and data associated with actions.
-namespace libtextclassifier3;
-table ActionsEntityData {
- // Extracted text.
- text:string;
-}
-
-root_type libtextclassifier3.ActionsEntityData;
diff --git a/actions/actions-suggestions.cc b/actions/actions-suggestions.cc
deleted file mode 100644
index 29a4424..0000000
--- a/actions/actions-suggestions.cc
+++ /dev/null
@@ -1,1450 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/actions-suggestions.h"
-
-#include <memory>
-
-#include "actions/lua-actions.h"
-#include "actions/types.h"
-#include "actions/zlib-utils.h"
-#include "utils/base/logging.h"
-#include "utils/flatbuffers.h"
-#include "utils/lua-utils.h"
-#include "utils/regex-match.h"
-#include "utils/strings/split.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/zlib/zlib_regex.h"
-#include "tensorflow/lite/string_util.h"
-
-namespace libtextclassifier3 {
-
-const std::string& ActionsSuggestions::kViewCalendarType =
- *[]() { return new std::string("view_calendar"); }();
-const std::string& ActionsSuggestions::kViewMapType =
- *[]() { return new std::string("view_map"); }();
-const std::string& ActionsSuggestions::kTrackFlightType =
- *[]() { return new std::string("track_flight"); }();
-const std::string& ActionsSuggestions::kOpenUrlType =
- *[]() { return new std::string("open_url"); }();
-const std::string& ActionsSuggestions::kSendSmsType =
- *[]() { return new std::string("send_sms"); }();
-const std::string& ActionsSuggestions::kCallPhoneType =
- *[]() { return new std::string("call_phone"); }();
-const std::string& ActionsSuggestions::kSendEmailType =
- *[]() { return new std::string("send_email"); }();
-const std::string& ActionsSuggestions::kShareLocation =
- *[]() { return new std::string("share_location"); }();
-
-namespace {
-
-const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
- flatbuffers::Verifier verifier(addr, size);
- if (VerifyActionsModelBuffer(verifier)) {
- return GetActionsModel(addr);
- } else {
- return nullptr;
- }
-}
-
-template <typename T>
-T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
- const T default_value) {
- if (values == nullptr) {
- return default_value;
- }
- return values->GetField<T>(field_offset, default_value);
-}
-
-// Returns number of (tail) messages of a conversation to consider.
-int NumMessagesToConsider(const Conversation& conversation,
- const int max_conversation_history_length) {
- return ((max_conversation_history_length < 0 ||
- conversation.messages.size() < max_conversation_history_length)
- ? conversation.messages.size()
- : max_conversation_history_length);
-}
-
-} // namespace
-
-std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
- const uint8_t* buffer, const int size, const UniLib* unilib,
- const std::string& triggering_preconditions_overlay) {
- auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
- const ActionsModel* model = LoadAndVerifyModel(buffer, size);
- if (model == nullptr) {
- return nullptr;
- }
- actions->model_ = model;
- actions->SetOrCreateUnilib(unilib);
- actions->triggering_preconditions_overlay_buffer_ =
- triggering_preconditions_overlay;
- if (!actions->ValidateAndInitialize()) {
- return nullptr;
- }
- return actions;
-}
-
-std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
- std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
- const std::string& triggering_preconditions_overlay) {
- if (!mmap->handle().ok()) {
- TC3_VLOG(1) << "Mmap failed.";
- return nullptr;
- }
- const ActionsModel* model = LoadAndVerifyModel(
- reinterpret_cast<const uint8_t*>(mmap->handle().start()),
- mmap->handle().num_bytes());
- if (!model) {
- TC3_LOG(ERROR) << "Model verification failed.";
- return nullptr;
- }
- auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
- actions->model_ = model;
- actions->mmap_ = std::move(mmap);
- actions->SetOrCreateUnilib(unilib);
- actions->triggering_preconditions_overlay_buffer_ =
- triggering_preconditions_overlay;
- if (!actions->ValidateAndInitialize()) {
- return nullptr;
- }
- return actions;
-}
-
-std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
- std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
- std::unique_ptr<UniLib> unilib,
- const std::string& triggering_preconditions_overlay) {
- if (!mmap->handle().ok()) {
- TC3_VLOG(1) << "Mmap failed.";
- return nullptr;
- }
- const ActionsModel* model = LoadAndVerifyModel(
- reinterpret_cast<const uint8_t*>(mmap->handle().start()),
- mmap->handle().num_bytes());
- if (!model) {
- TC3_LOG(ERROR) << "Model verification failed.";
- return nullptr;
- }
- auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
- actions->model_ = model;
- actions->mmap_ = std::move(mmap);
- actions->owned_unilib_ = std::move(unilib);
- actions->unilib_ = actions->owned_unilib_.get();
- actions->triggering_preconditions_overlay_buffer_ =
- triggering_preconditions_overlay;
- if (!actions->ValidateAndInitialize()) {
- return nullptr;
- }
- return actions;
-}
-
-std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
- const int fd, const int offset, const int size, const UniLib* unilib,
- const std::string& triggering_preconditions_overlay) {
- std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
- if (offset >= 0 && size >= 0) {
- mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
- } else {
- mmap.reset(new libtextclassifier3::ScopedMmap(fd));
- }
- return FromScopedMmap(std::move(mmap), unilib,
- triggering_preconditions_overlay);
-}
-
-std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
- const int fd, const int offset, const int size,
- std::unique_ptr<UniLib> unilib,
- const std::string& triggering_preconditions_overlay) {
- std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
- if (offset >= 0 && size >= 0) {
- mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
- } else {
- mmap.reset(new libtextclassifier3::ScopedMmap(fd));
- }
- return FromScopedMmap(std::move(mmap), std::move(unilib),
- triggering_preconditions_overlay);
-}
-
-std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
- const int fd, const UniLib* unilib,
- const std::string& triggering_preconditions_overlay) {
- std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
- new libtextclassifier3::ScopedMmap(fd));
- return FromScopedMmap(std::move(mmap), unilib,
- triggering_preconditions_overlay);
-}
-
-std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
- const int fd, std::unique_ptr<UniLib> unilib,
- const std::string& triggering_preconditions_overlay) {
- std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
- new libtextclassifier3::ScopedMmap(fd));
- return FromScopedMmap(std::move(mmap), std::move(unilib),
- triggering_preconditions_overlay);
-}
-
-std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
- const std::string& path, const UniLib* unilib,
- const std::string& triggering_preconditions_overlay) {
- std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
- new libtextclassifier3::ScopedMmap(path));
- return FromScopedMmap(std::move(mmap), unilib,
- triggering_preconditions_overlay);
-}
-
-std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
- const std::string& path, std::unique_ptr<UniLib> unilib,
- const std::string& triggering_preconditions_overlay) {
- std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
- new libtextclassifier3::ScopedMmap(path));
- return FromScopedMmap(std::move(mmap), std::move(unilib),
- triggering_preconditions_overlay);
-}
-
-void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
- if (unilib != nullptr) {
- unilib_ = unilib;
- } else {
- owned_unilib_.reset(new UniLib);
- unilib_ = owned_unilib_.get();
- }
-}
-
-bool ActionsSuggestions::ValidateAndInitialize() {
- if (model_ == nullptr) {
- TC3_LOG(ERROR) << "No model specified.";
- return false;
- }
-
- if (model_->smart_reply_action_type() == nullptr) {
- TC3_LOG(ERROR) << "No smart reply action type specified.";
- return false;
- }
-
- if (!InitializeTriggeringPreconditions()) {
- TC3_LOG(ERROR) << "Could not initialize preconditions.";
- return false;
- }
-
- if (model_->locales() &&
- !ParseLocales(model_->locales()->c_str(), &locales_)) {
- TC3_LOG(ERROR) << "Could not parse model supported locales.";
- return false;
- }
-
- if (model_->tflite_model_spec() != nullptr) {
- model_executor_ = TfLiteModelExecutor::FromBuffer(
- model_->tflite_model_spec()->tflite_model());
- if (!model_executor_) {
- TC3_LOG(ERROR) << "Could not initialize model executor.";
- return false;
- }
- }
-
- if (model_->annotation_actions_spec() != nullptr &&
- model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
- for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
- *model_->annotation_actions_spec()->annotation_mapping()) {
- annotation_entity_types_.insert(mapping->annotation_collection()->str());
- }
- }
-
- std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
- if (!InitializeRules(decompressor.get())) {
- TC3_LOG(ERROR) << "Could not initialize rules.";
- return false;
- }
-
- if (model_->actions_entity_data_schema() != nullptr) {
- entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
- model_->actions_entity_data_schema()->Data(),
- model_->actions_entity_data_schema()->size());
- if (entity_data_schema_ == nullptr) {
- TC3_LOG(ERROR) << "Could not load entity data schema data.";
- return false;
- }
-
- entity_data_builder_.reset(
- new ReflectiveFlatbufferBuilder(entity_data_schema_));
- } else {
- entity_data_schema_ = nullptr;
- }
-
- std::string actions_script;
- if (GetUncompressedString(model_->lua_actions_script(),
- model_->compressed_lua_actions_script(),
- decompressor.get(), &actions_script) &&
- !actions_script.empty()) {
- if (!Compile(actions_script, &lua_bytecode_)) {
- TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
- return false;
- }
- }
-
- if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
- model_->ranking_options(), decompressor.get(),
- model_->smart_reply_action_type()->str()))) {
- TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
- return false;
- }
-
- // Create feature processor if specified.
- const ActionsTokenFeatureProcessorOptions* options =
- model_->feature_processor_options();
- if (options != nullptr) {
- if (options->tokenizer_options() == nullptr) {
- TC3_LOG(ERROR) << "No tokenizer options specified.";
- return false;
- }
-
- feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
- embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
- options->embedding_model(), options->embedding_size(),
- options->embedding_quantization_bits());
-
- if (embedding_executor_ == nullptr) {
- TC3_LOG(ERROR) << "Could not initialize embedding executor.";
- return false;
- }
-
- // Cache embedding of padding, start and end token.
- if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
- !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
- !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
- TC3_LOG(ERROR) << "Could not precompute token embeddings.";
- return false;
- }
- token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
- }
-
- // Create low confidence model if specified.
- if (model_->low_confidence_ngram_model() != nullptr) {
- ngram_model_ = NGramModel::Create(model_->low_confidence_ngram_model(),
- feature_processor_ == nullptr
- ? nullptr
- : feature_processor_->tokenizer(),
- unilib_);
- if (ngram_model_ == nullptr) {
- TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
- return false;
- }
- }
-
- return true;
-}
-
-bool ActionsSuggestions::InitializeTriggeringPreconditions() {
- triggering_preconditions_overlay_ =
- LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
- triggering_preconditions_overlay_buffer_);
-
- if (triggering_preconditions_overlay_ == nullptr &&
- !triggering_preconditions_overlay_buffer_.empty()) {
- TC3_LOG(ERROR) << "Could not load triggering preconditions overwrites.";
- return false;
- }
- const flatbuffers::Table* overlay =
- reinterpret_cast<const flatbuffers::Table*>(
- triggering_preconditions_overlay_);
- const TriggeringPreconditions* defaults = model_->preconditions();
- if (defaults == nullptr) {
- TC3_LOG(ERROR) << "No triggering conditions specified.";
- return false;
- }
-
- preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
- overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
- defaults->min_smart_reply_triggering_score());
- preconditions_.max_sensitive_topic_score = ValueOrDefault(
- overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
- defaults->max_sensitive_topic_score());
- preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
- overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
- defaults->suppress_on_sensitive_topic());
- preconditions_.min_input_length =
- ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
- defaults->min_input_length());
- preconditions_.max_input_length =
- ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
- defaults->max_input_length());
- preconditions_.min_locale_match_fraction = ValueOrDefault(
- overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
- defaults->min_locale_match_fraction());
- preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
- overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
- defaults->handle_missing_locale_as_supported());
- preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
- overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
- defaults->handle_unknown_locale_as_supported());
- preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
- overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
- defaults->suppress_on_low_confidence_input());
- preconditions_.diversification_distance_threshold = ValueOrDefault(
- overlay, TriggeringPreconditions::VT_DIVERSIFICATION_DISTANCE_THRESHOLD,
- defaults->diversification_distance_threshold());
- preconditions_.confidence_threshold =
- ValueOrDefault(overlay, TriggeringPreconditions::VT_CONFIDENCE_THRESHOLD,
- defaults->confidence_threshold());
- preconditions_.empirical_probability_factor = ValueOrDefault(
- overlay, TriggeringPreconditions::VT_EMPIRICAL_PROBABILITY_FACTOR,
- defaults->empirical_probability_factor());
- preconditions_.min_reply_score_threshold = ValueOrDefault(
- overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
- defaults->min_reply_score_threshold());
-
- return true;
-}
-
-bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
- std::vector<float>* embedding) const {
- return feature_processor_->AppendFeatures(
- {token_id},
- /*dense_features=*/{}, embedding_executor_.get(), embedding);
-}
-
-bool ActionsSuggestions::InitializeRules(ZlibDecompressor* decompressor) {
- if (model_->rules() != nullptr) {
- if (!InitializeRules(decompressor, model_->rules(), &rules_)) {
- TC3_LOG(ERROR) << "Could not initialize action rules.";
- return false;
- }
- }
-
- if (model_->low_confidence_rules() != nullptr) {
- if (!InitializeRules(decompressor, model_->low_confidence_rules(),
- &low_confidence_rules_)) {
- TC3_LOG(ERROR) << "Could not initialize low confidence rules.";
- return false;
- }
- }
-
- // Extend by rules provided by the overwrite.
- // NOTE: The rules from the original models are *not* cleared.
- if (triggering_preconditions_overlay_ != nullptr &&
- triggering_preconditions_overlay_->low_confidence_rules() != nullptr) {
- // These rules are optionally compressed, but separately.
- std::unique_ptr<ZlibDecompressor> overwrite_decompressor =
- ZlibDecompressor::Instance();
- if (overwrite_decompressor == nullptr) {
- TC3_LOG(ERROR) << "Could not initialze decompressor for overwrite rules.";
- return false;
- }
- if (!InitializeRules(
- overwrite_decompressor.get(),
- triggering_preconditions_overlay_->low_confidence_rules(),
- &low_confidence_rules_)) {
- TC3_LOG(ERROR)
- << "Could not initialize low confidence rules from overwrite.";
- return false;
- }
- }
-
- return true;
-}
-
-bool ActionsSuggestions::InitializeRules(
- ZlibDecompressor* decompressor, const RulesModel* rules,
- std::vector<CompiledRule>* compiled_rules) const {
- for (const RulesModel_::Rule* rule : *rules->rule()) {
- std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
- UncompressMakeRegexPattern(
- *unilib_, rule->pattern(), rule->compressed_pattern(),
- rules->lazy_regex_compilation(), decompressor);
- if (compiled_pattern == nullptr) {
- TC3_LOG(ERROR) << "Failed to load rule pattern.";
- return false;
- }
-
- // Check whether there is a check on the output.
- std::unique_ptr<UniLib::RegexPattern> compiled_output_pattern;
- if (rule->output_pattern() != nullptr ||
- rule->compressed_output_pattern() != nullptr) {
- compiled_output_pattern = UncompressMakeRegexPattern(
- *unilib_, rule->output_pattern(), rule->compressed_output_pattern(),
- rules->lazy_regex_compilation(), decompressor);
- if (compiled_output_pattern == nullptr) {
- TC3_LOG(ERROR) << "Failed to load rule output pattern.";
- return false;
- }
- }
-
- compiled_rules->emplace_back(rule, std::move(compiled_pattern),
- std::move(compiled_output_pattern));
- }
-
- return true;
-}
-
-bool ActionsSuggestions::IsLowConfidenceInput(
- const Conversation& conversation, const int num_messages,
- std::vector<int>* post_check_rules) const {
- for (int i = 1; i <= num_messages; i++) {
- const std::string& message =
- conversation.messages[conversation.messages.size() - i].text;
- const UnicodeText message_unicode(
- UTF8ToUnicodeText(message, /*do_copy=*/false));
-
- // Run ngram linear regression model.
- if (ngram_model_ != nullptr) {
- if (ngram_model_->Eval(message_unicode)) {
- return true;
- }
- }
-
- // Run the regex based rules.
- for (int low_confidence_rule = 0;
- low_confidence_rule < low_confidence_rules_.size();
- low_confidence_rule++) {
- const CompiledRule& rule = low_confidence_rules_[low_confidence_rule];
- const std::unique_ptr<UniLib::RegexMatcher> matcher =
- rule.pattern->Matcher(message_unicode);
- int status = UniLib::RegexMatcher::kNoError;
- if (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
- // Rule only applies to input-output pairs, so defer the check.
- if (rule.output_pattern != nullptr) {
- post_check_rules->push_back(low_confidence_rule);
- continue;
- }
- return true;
- }
- }
- }
- return false;
-}
-
-bool ActionsSuggestions::FilterConfidenceOutput(
- const std::vector<int>& post_check_rules,
- std::vector<ActionSuggestion>* actions) const {
- if (post_check_rules.empty() || actions->empty()) {
- return true;
- }
- std::vector<ActionSuggestion> filtered_text_replies;
- for (const ActionSuggestion& action : *actions) {
- if (action.response_text.empty()) {
- filtered_text_replies.push_back(action);
- continue;
- }
- bool passes_post_check = true;
- const UnicodeText text_reply_unicode(
- UTF8ToUnicodeText(action.response_text, /*do_copy=*/false));
- for (const int rule_id : post_check_rules) {
- const std::unique_ptr<UniLib::RegexMatcher> matcher =
- low_confidence_rules_[rule_id].output_pattern->Matcher(
- text_reply_unicode);
- if (matcher == nullptr) {
- TC3_LOG(ERROR) << "Could not create matcher for post check rule.";
- return false;
- }
- int status = UniLib::RegexMatcher::kNoError;
- if (matcher->Find(&status) || status != UniLib::RegexMatcher::kNoError) {
- passes_post_check = false;
- break;
- }
- }
- if (passes_post_check) {
- filtered_text_replies.push_back(action);
- }
- }
- *actions = std::move(filtered_text_replies);
- return true;
-}
-
-ActionSuggestion ActionsSuggestions::SuggestionFromSpec(
- const ActionSuggestionSpec* action, const std::string& default_type,
- const std::string& default_response_text,
- const std::string& default_serialized_entity_data,
- const float default_score, const float default_priority_score) const {
- ActionSuggestion suggestion;
- suggestion.score = action != nullptr ? action->score() : default_score;
- suggestion.priority_score =
- action != nullptr ? action->priority_score() : default_priority_score;
- suggestion.type = action != nullptr && action->type() != nullptr
- ? action->type()->str()
- : default_type;
- suggestion.response_text =
- action != nullptr && action->response_text() != nullptr
- ? action->response_text()->str()
- : default_response_text;
- suggestion.serialized_entity_data =
- action != nullptr && action->serialized_entity_data() != nullptr
- ? action->serialized_entity_data()->str()
- : default_serialized_entity_data;
- return suggestion;
-}
-
-std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
- const std::vector<std::string>& context) const {
- std::vector<std::vector<Token>> tokens;
- tokens.reserve(context.size());
- for (const std::string& message : context) {
- tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
- }
- return tokens;
-}
-
-bool ActionsSuggestions::EmbedTokensPerMessage(
- const std::vector<std::vector<Token>>& tokens,
- std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
- const int num_messages = tokens.size();
- *max_num_tokens_per_message = 0;
- for (int i = 0; i < num_messages; i++) {
- const int num_message_tokens = tokens[i].size();
- if (num_message_tokens > *max_num_tokens_per_message) {
- *max_num_tokens_per_message = num_message_tokens;
- }
- }
-
- if (model_->feature_processor_options()->min_num_tokens_per_message() >
- *max_num_tokens_per_message) {
- *max_num_tokens_per_message =
- model_->feature_processor_options()->min_num_tokens_per_message();
- }
- if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
- *max_num_tokens_per_message >
- model_->feature_processor_options()->max_num_tokens_per_message()) {
- *max_num_tokens_per_message =
- model_->feature_processor_options()->max_num_tokens_per_message();
- }
-
- // Embed all tokens and add paddings to pad tokens of each message to the
- // maximum number of tokens in a message of the conversation.
- // If a number of tokens is specified in the model config, tokens at the
- // beginning of a message are dropped if they don't fit in the limit.
- for (int i = 0; i < num_messages; i++) {
- const int start =
- std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
- for (int pos = start; pos < tokens[i].size(); pos++) {
- if (!feature_processor_->AppendTokenFeatures(
- tokens[i][pos], embedding_executor_.get(), embeddings)) {
- TC3_LOG(ERROR) << "Could not run token feature extractor.";
- return false;
- }
- }
- // Add padding.
- for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
- embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
- embedded_padding_token_.end());
- }
- }
-
- return true;
-}
-
-bool ActionsSuggestions::EmbedAndFlattenTokens(
- const std::vector<std::vector<Token>> tokens,
- std::vector<float>* embeddings, int* total_token_count) const {
- const int num_messages = tokens.size();
- int start_message = 0;
- int message_token_offset = 0;
-
- // If a maximum model input length is specified, we need to check how
- // much we need to trim at the start.
- const int max_num_total_tokens =
- model_->feature_processor_options()->max_num_total_tokens();
- if (max_num_total_tokens > 0) {
- int total_tokens = 0;
- start_message = num_messages - 1;
- for (; start_message >= 0; start_message--) {
- // Tokens of the message + start and end token.
- const int num_message_tokens = tokens[start_message].size() + 2;
- total_tokens += num_message_tokens;
-
- // Check whether we exhausted the budget.
- if (total_tokens >= max_num_total_tokens) {
- message_token_offset = total_tokens - max_num_total_tokens;
- break;
- }
- }
- }
-
- // Add embeddings.
- *total_token_count = 0;
- for (int i = start_message; i < num_messages; i++) {
- if (message_token_offset == 0) {
- ++(*total_token_count);
- // Add `start message` token.
- embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
- embedded_start_token_.end());
- }
-
- for (int pos = std::max(0, message_token_offset - 1);
- pos < tokens[i].size(); pos++) {
- ++(*total_token_count);
- if (!feature_processor_->AppendTokenFeatures(
- tokens[i][pos], embedding_executor_.get(), embeddings)) {
- TC3_LOG(ERROR) << "Could not run token feature extractor.";
- return false;
- }
- }
-
- // Add `end message` token.
- ++(*total_token_count);
- embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
- embedded_end_token_.end());
-
- // Reset for the subsequent messages.
- message_token_offset = 0;
- }
-
- // Add optional padding.
- const int min_num_total_tokens =
- model_->feature_processor_options()->min_num_total_tokens();
- for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
- embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
- embedded_padding_token_.end());
- }
-
- return true;
-}
-
-bool ActionsSuggestions::AllocateInput(const int conversation_length,
- const int max_tokens,
- const int total_token_count,
- tflite::Interpreter* interpreter) const {
- if (model_->tflite_model_spec()->resize_inputs()) {
- if (model_->tflite_model_spec()->input_context() >= 0) {
- interpreter->ResizeInputTensor(
- interpreter->inputs()[model_->tflite_model_spec()->input_context()],
- {1, conversation_length});
- }
- if (model_->tflite_model_spec()->input_user_id() >= 0) {
- interpreter->ResizeInputTensor(
- interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
- {1, conversation_length});
- }
- if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
- interpreter->ResizeInputTensor(
- interpreter
- ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
- {1, conversation_length});
- }
- if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
- interpreter->ResizeInputTensor(
- interpreter
- ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
- {conversation_length, 1});
- }
- if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
- interpreter->ResizeInputTensor(
- interpreter
- ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
- {conversation_length, max_tokens, token_embedding_size_});
- }
- if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
- interpreter->ResizeInputTensor(
- interpreter->inputs()[model_->tflite_model_spec()
- ->input_flattened_token_embeddings()],
- {1, total_token_count});
- }
- }
-
- return interpreter->AllocateTensors() == kTfLiteOk;
-}
-
-bool ActionsSuggestions::SetupModelInput(
- const std::vector<std::string>& context, const std::vector<int>& user_ids,
- const std::vector<float>& time_diffs, const int num_suggestions,
- const float confidence_threshold, const float diversification_distance,
- const float empirical_probability_factor,
- tflite::Interpreter* interpreter) const {
- // Compute token embeddings.
- std::vector<std::vector<Token>> tokens;
- std::vector<float> token_embeddings;
- std::vector<float> flattened_token_embeddings;
- int max_tokens = 0;
- int total_token_count = 0;
- if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
- model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
- model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
- if (feature_processor_ == nullptr) {
- TC3_LOG(ERROR) << "No feature processor specified.";
- return false;
- }
-
- // Tokenize the messages in the conversation.
- tokens = Tokenize(context);
- if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
- if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
- TC3_LOG(ERROR) << "Could not extract token features.";
- return false;
- }
- }
- if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
- if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
- &total_token_count)) {
- TC3_LOG(ERROR) << "Could not extract token features.";
- return false;
- }
- }
- }
-
- if (!AllocateInput(context.size(), max_tokens, total_token_count,
- interpreter)) {
- TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
- return false;
- }
- if (model_->tflite_model_spec()->input_context() >= 0) {
- model_executor_->SetInput<std::string>(
- model_->tflite_model_spec()->input_context(), context, interpreter);
- }
- if (model_->tflite_model_spec()->input_context_length() >= 0) {
- model_executor_->SetInput<int>(
- model_->tflite_model_spec()->input_context_length(), context.size(),
- interpreter);
- }
- if (model_->tflite_model_spec()->input_user_id() >= 0) {
- model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
- user_ids, interpreter);
- }
- if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
- model_executor_->SetInput<int>(
- model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
- interpreter);
- }
- if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
- model_executor_->SetInput<float>(
- model_->tflite_model_spec()->input_time_diffs(), time_diffs,
- interpreter);
- }
- if (model_->tflite_model_spec()->input_diversification_distance() >= 0) {
- model_executor_->SetInput<float>(
- model_->tflite_model_spec()->input_diversification_distance(),
- diversification_distance, interpreter);
- }
- if (model_->tflite_model_spec()->input_confidence_threshold() >= 0) {
- model_executor_->SetInput<float>(
- model_->tflite_model_spec()->input_confidence_threshold(),
- confidence_threshold, interpreter);
- }
- if (model_->tflite_model_spec()->input_empirical_probability_factor() >= 0) {
- model_executor_->SetInput<float>(
- model_->tflite_model_spec()->input_empirical_probability_factor(),
- confidence_threshold, interpreter);
- }
- if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
- std::vector<int> num_tokens_per_message(tokens.size());
- for (int i = 0; i < tokens.size(); i++) {
- num_tokens_per_message[i] = tokens[i].size();
- }
- model_executor_->SetInput<int>(
- model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
- interpreter);
- }
- if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
- model_executor_->SetInput<float>(
- model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
- interpreter);
- }
- if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
- model_executor_->SetInput<float>(
- model_->tflite_model_spec()->input_flattened_token_embeddings(),
- flattened_token_embeddings, interpreter);
- }
- return true;
-}
-
-bool ActionsSuggestions::ReadModelOutput(
- tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
- ActionsSuggestionsResponse* response) const {
- // Read sensitivity and triggering score predictions.
- if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
- const TensorView<float>& triggering_score =
- model_executor_->OutputView<float>(
- model_->tflite_model_spec()->output_triggering_score(),
- interpreter);
- if (!triggering_score.is_valid() || triggering_score.size() == 0) {
- TC3_LOG(ERROR) << "Could not compute triggering score.";
- return false;
- }
- response->triggering_score = triggering_score.data()[0];
- response->output_filtered_min_triggering_score =
- (response->triggering_score <
- preconditions_.min_smart_reply_triggering_score);
- }
- if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
- const TensorView<float>& sensitive_topic_score =
- model_executor_->OutputView<float>(
- model_->tflite_model_spec()->output_sensitive_topic_score(),
- interpreter);
- if (!sensitive_topic_score.is_valid() ||
- sensitive_topic_score.dim(0) != 1) {
- TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
- return false;
- }
- response->sensitivity_score = sensitive_topic_score.data()[0];
- response->output_filtered_sensitivity =
- (response->sensitivity_score >
- preconditions_.max_sensitive_topic_score);
- }
-
- // Suppress model outputs.
- if (response->output_filtered_sensitivity) {
- return true;
- }
-
- // Read smart reply predictions.
- std::vector<ActionSuggestion> text_replies;
- if (!response->output_filtered_min_triggering_score &&
- model_->tflite_model_spec()->output_replies() >= 0) {
- const std::vector<tflite::StringRef> replies =
- model_executor_->Output<tflite::StringRef>(
- model_->tflite_model_spec()->output_replies(), interpreter);
- TensorView<float> scores = model_executor_->OutputView<float>(
- model_->tflite_model_spec()->output_replies_scores(), interpreter);
- for (int i = 0; i < replies.size(); i++) {
- if (replies[i].len == 0) continue;
- const float score = scores.data()[i];
- if (score < preconditions_.min_reply_score_threshold) {
- continue;
- }
- response->actions.push_back({std::string(replies[i].str, replies[i].len),
- model_->smart_reply_action_type()->str(),
- score});
- }
- }
-
- // Read actions suggestions.
- if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
- const TensorView<float> actions_scores = model_executor_->OutputView<float>(
- model_->tflite_model_spec()->output_actions_scores(), interpreter);
- for (int i = 0; i < model_->action_type()->Length(); i++) {
- const ActionTypeOptions* action_type = model_->action_type()->Get(i);
- // Skip disabled action classes, such as the default other category.
- if (!action_type->enabled()) {
- continue;
- }
- const float score = actions_scores.data()[i];
- if (score < action_type->min_triggering_score()) {
- continue;
- }
- ActionSuggestion suggestion =
- SuggestionFromSpec(action_type->action(),
- /*default_type=*/action_type->name()->str());
- suggestion.score = score;
- response->actions.push_back(suggestion);
- }
- }
-
- return true;
-}
-
-bool ActionsSuggestions::SuggestActionsFromModel(
- const Conversation& conversation, const int num_messages,
- const ActionSuggestionOptions& options,
- ActionsSuggestionsResponse* response,
- std::unique_ptr<tflite::Interpreter>* interpreter) const {
- TC3_CHECK_LE(num_messages, conversation.messages.size());
-
- if (!model_executor_) {
- return true;
- }
- *interpreter = model_executor_->CreateInterpreter();
-
- if (!*interpreter) {
- TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
- "actions suggestions model.";
- return false;
- }
-
- std::vector<std::string> context;
- std::vector<int> user_ids;
- std::vector<float> time_diffs;
- context.reserve(num_messages);
- user_ids.reserve(num_messages);
- time_diffs.reserve(num_messages);
-
- // Gather last `num_messages` messages from the conversation.
- int64 last_message_reference_time_ms_utc = 0;
- const float second_in_ms = 1000;
- for (int i = conversation.messages.size() - num_messages;
- i < conversation.messages.size(); i++) {
- const ConversationMessage& message = conversation.messages[i];
- context.push_back(message.text);
- user_ids.push_back(message.user_id);
-
- float time_diff_secs = 0;
- if (message.reference_time_ms_utc != 0 &&
- last_message_reference_time_ms_utc != 0) {
- time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
- last_message_reference_time_ms_utc) /
- second_in_ms);
- }
- if (message.reference_time_ms_utc != 0) {
- last_message_reference_time_ms_utc = message.reference_time_ms_utc;
- }
- time_diffs.push_back(time_diff_secs);
- }
-
- if (!SetupModelInput(context, user_ids, time_diffs,
- /*num_suggestions=*/model_->num_smart_replies(),
- preconditions_.confidence_threshold,
- preconditions_.diversification_distance_threshold,
- preconditions_.empirical_probability_factor,
- interpreter->get())) {
- TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
- return false;
- }
-
- if ((*interpreter)->Invoke() != kTfLiteOk) {
- TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
- return false;
- }
-
- return ReadModelOutput(interpreter->get(), options, response);
-}
-
-AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
- const ConversationMessage& message) const {
- AnnotationOptions options;
- options.detected_text_language_tags = message.detected_text_language_tags;
- options.reference_time_ms_utc = message.reference_time_ms_utc;
- options.reference_timezone = message.reference_timezone;
- options.annotation_usecase =
- model_->annotation_actions_spec()->annotation_usecase();
- options.is_serialized_entity_data_enabled =
- model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
- options.entity_types = annotation_entity_types_;
- return options;
-}
-
-void ActionsSuggestions::SuggestActionsFromAnnotations(
- const Conversation& conversation, const ActionSuggestionOptions& options,
- const Annotator* annotator, std::vector<ActionSuggestion>* actions) const {
- if (model_->annotation_actions_spec() == nullptr ||
- model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
- model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
- return;
- }
-
- // Create actions based on the annotations.
- const int max_from_any_person =
- model_->annotation_actions_spec()->max_history_from_any_person();
- const int max_from_last_person =
- model_->annotation_actions_spec()->max_history_from_last_person();
- const int last_person = conversation.messages.back().user_id;
-
- int num_messages_last_person = 0;
- int num_messages_any_person = 0;
- bool all_from_last_person = true;
- for (int message_index = conversation.messages.size() - 1; message_index >= 0;
- message_index--) {
- const ConversationMessage& message = conversation.messages[message_index];
- std::vector<AnnotatedSpan> annotations = message.annotations;
-
- // Update how many messages we have processed from the last person in the
- // conversation and from any person in the conversation.
- num_messages_any_person++;
- if (all_from_last_person && message.user_id == last_person) {
- num_messages_last_person++;
- } else {
- all_from_last_person = false;
- }
-
- if (num_messages_any_person > max_from_any_person &&
- (!all_from_last_person ||
- num_messages_last_person > max_from_last_person)) {
- break;
- }
-
- if (message.user_id == kLocalUserId) {
- if (model_->annotation_actions_spec()->only_until_last_sent()) {
- break;
- }
- if (!model_->annotation_actions_spec()->include_local_user_messages()) {
- continue;
- }
- }
-
- if (annotations.empty() && annotator != nullptr) {
- annotations = annotator->Annotate(message.text,
- AnnotationOptionsForMessage(message));
- }
- std::vector<ActionSuggestionAnnotation> action_annotations;
- action_annotations.reserve(annotations.size());
- for (const AnnotatedSpan& annotation : annotations) {
- if (annotation.classification.empty()) {
- continue;
- }
-
- const ClassificationResult& classification_result =
- annotation.classification[0];
-
- ActionSuggestionAnnotation action_annotation;
- action_annotation.span = {
- message_index, annotation.span,
- UTF8ToUnicodeText(message.text, /*do_copy=*/false)
- .UTF8Substring(annotation.span.first, annotation.span.second)};
- action_annotation.entity = classification_result;
- action_annotation.name = classification_result.collection;
- action_annotations.push_back(action_annotation);
- }
-
- if (model_->annotation_actions_spec()->deduplicate_annotations()) {
- // Create actions only for deduplicated annotations.
- for (const int annotation_id :
- DeduplicateAnnotations(action_annotations)) {
- SuggestActionsFromAnnotation(
- message_index, action_annotations[annotation_id], actions);
- }
- } else {
- // Create actions for all annotations.
- for (const ActionSuggestionAnnotation& annotation : action_annotations) {
- SuggestActionsFromAnnotation(message_index, annotation, actions);
- }
- }
- }
-}
-
-void ActionsSuggestions::SuggestActionsFromAnnotation(
- const int message_index, const ActionSuggestionAnnotation& annotation,
- std::vector<ActionSuggestion>* actions) const {
- for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
- *model_->annotation_actions_spec()->annotation_mapping()) {
- if (annotation.entity.collection ==
- mapping->annotation_collection()->str()) {
- if (annotation.entity.score < mapping->min_annotation_score()) {
- continue;
- }
- ActionSuggestion suggestion = SuggestionFromSpec(mapping->action());
- if (mapping->use_annotation_score()) {
- suggestion.score = annotation.entity.score;
- }
-
- // Set annotation text as (additional) entity data field.
- if (mapping->entity_field() != nullptr) {
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
- entity_data_builder_->NewRoot();
- TC3_CHECK(entity_data != nullptr);
-
- // Merge existing static entity data.
- if (!suggestion.serialized_entity_data.empty()) {
- entity_data->MergeFromSerializedFlatbuffer(
- StringPiece(suggestion.serialized_entity_data.c_str(),
- suggestion.serialized_entity_data.size()));
- }
-
- entity_data->ParseAndSet(mapping->entity_field(), annotation.span.text);
- suggestion.serialized_entity_data = entity_data->Serialize();
- }
-
- suggestion.annotations = {annotation};
- actions->push_back(suggestion);
- }
- }
-}
-
-std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
- const std::vector<ActionSuggestionAnnotation>& annotations) const {
- std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;
-
- for (int i = 0; i < annotations.size(); i++) {
- const std::pair<std::string, std::string> key = {annotations[i].name,
- annotations[i].span.text};
- auto entry = deduplicated_annotations.find(key);
- if (entry != deduplicated_annotations.end()) {
- // Kepp the annotation with the higher score.
- if (annotations[entry->second].entity.score <
- annotations[i].entity.score) {
- entry->second = i;
- }
- continue;
- }
- deduplicated_annotations.insert(entry, {key, i});
- }
-
- std::vector<int> result;
- result.reserve(deduplicated_annotations.size());
- for (const auto& key_and_annotation : deduplicated_annotations) {
- result.push_back(key_and_annotation.second);
- }
- return result;
-}
-
-bool ActionsSuggestions::FillAnnotationFromMatchGroup(
- const UniLib::RegexMatcher* matcher,
- const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group,
- const int message_index, ActionSuggestionAnnotation* annotation) const {
- if (group->annotation_name() != nullptr ||
- group->annotation_type() != nullptr) {
- int status = UniLib::RegexMatcher::kNoError;
- const CodepointSpan span = {matcher->Start(group->group_id(), &status),
- matcher->End(group->group_id(), &status)};
- std::string text =
- matcher->Group(group->group_id(), &status).ToUTF8String();
- if (status != UniLib::RegexMatcher::kNoError) {
- TC3_LOG(ERROR) << "Could not extract span from rule capturing group.";
- return false;
- }
-
- // The capturing group was not part of the match.
- if (span.first == kInvalidIndex || span.second == kInvalidIndex) {
- return false;
- }
- annotation->span.span = span;
- annotation->span.message_index = message_index;
- annotation->span.text = text;
- if (group->annotation_name() != nullptr) {
- annotation->name = group->annotation_name()->str();
- }
- if (group->annotation_type() != nullptr) {
- annotation->entity.collection = group->annotation_type()->str();
- }
- }
- return true;
-}
-
-bool ActionsSuggestions::SuggestActionsFromRules(
- const Conversation& conversation,
- std::vector<ActionSuggestion>* actions) const {
- // Create actions based on rules checking the last message.
- const int message_index = conversation.messages.size() - 1;
- const std::string& message = conversation.messages.back().text;
- const UnicodeText message_unicode(
- UTF8ToUnicodeText(message, /*do_copy=*/false));
- for (const CompiledRule& rule : rules_) {
- const std::unique_ptr<UniLib::RegexMatcher> matcher =
- rule.pattern->Matcher(message_unicode);
- int status = UniLib::RegexMatcher::kNoError;
- while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
- for (const RulesModel_::Rule_::RuleActionSpec* rule_action :
- *rule.rule->actions()) {
- const ActionSuggestionSpec* action = rule_action->action();
- std::vector<ActionSuggestionAnnotation> annotations;
-
- bool sets_entity_data = false;
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
- entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
- : nullptr;
-
- // Set static entity data.
- if (action != nullptr && action->serialized_entity_data() != nullptr) {
- TC3_CHECK(entity_data != nullptr);
- sets_entity_data = true;
- entity_data->MergeFromSerializedFlatbuffer(
- StringPiece(action->serialized_entity_data()->c_str(),
- action->serialized_entity_data()->size()));
- }
-
- // Add entity data from rule capturing groups.
- if (rule_action->capturing_group() != nullptr) {
- for (const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup*
- group : *rule_action->capturing_group()) {
- if (group->entity_field() != nullptr) {
- TC3_CHECK(entity_data != nullptr);
- sets_entity_data = true;
- if (!SetFieldFromCapturingGroup(
- group->group_id(), group->entity_field(), matcher.get(),
- entity_data.get())) {
- TC3_LOG(ERROR)
- << "Could not set entity data from rule capturing group.";
- return false;
- }
- }
-
- // Create a text annotation for the group span.
- ActionSuggestionAnnotation annotation;
- if (FillAnnotationFromMatchGroup(matcher.get(), group,
- message_index, &annotation)) {
- annotations.push_back(annotation);
- }
-
- // Create text reply.
- if (group->text_reply() != nullptr) {
- int status = UniLib::RegexMatcher::kNoError;
- const std::string group_text =
- matcher->Group(group->group_id(), &status).ToUTF8String();
- if (status != UniLib::RegexMatcher::kNoError) {
- TC3_LOG(ERROR) << "Could get text from capturing group.";
- return false;
- }
- if (group_text.empty()) {
- // The group was not part of the match, ignore and continue.
- continue;
- }
- actions->push_back(SuggestionFromSpec(
- group->text_reply(),
- /*default_type=*/model_->smart_reply_action_type()->str(),
- /*default_response_text=*/group_text));
- }
- }
- }
-
- if (action != nullptr) {
- ActionSuggestion suggestion = SuggestionFromSpec(action);
- suggestion.annotations = annotations;
- if (sets_entity_data) {
- suggestion.serialized_entity_data = entity_data->Serialize();
- }
- actions->push_back(suggestion);
- }
- }
- }
- }
- return true;
-}
-
-bool ActionsSuggestions::SuggestActionsFromLua(
- const Conversation& conversation, const TfLiteModelExecutor* model_executor,
- const tflite::Interpreter* interpreter,
- const reflection::Schema* annotation_entity_data_schema,
- std::vector<ActionSuggestion>* actions) const {
- if (lua_bytecode_.empty()) {
- return true;
- }
-
- auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
- lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
- interpreter, entity_data_schema_, annotation_entity_data_schema);
- if (lua_actions == nullptr) {
- TC3_LOG(ERROR) << "Could not create lua actions.";
- return false;
- }
- return lua_actions->SuggestActions(actions);
-}
-
-bool ActionsSuggestions::GatherActionsSuggestions(
- const Conversation& conversation, const Annotator* annotator,
- const ActionSuggestionOptions& options,
- ActionsSuggestionsResponse* response) const {
- if (conversation.messages.empty()) {
- return true;
- }
-
- const int num_messages = NumMessagesToConsider(
- conversation, model_->max_conversation_history_length());
-
- if (num_messages <= 0) {
- TC3_LOG(INFO) << "No messages provided for actions suggestions.";
- return false;
- }
-
- SuggestActionsFromAnnotations(conversation, options, annotator,
- &response->actions);
-
- int input_text_length = 0;
- int num_matching_locales = 0;
- for (int i = conversation.messages.size() - num_messages;
- i < conversation.messages.size(); i++) {
- input_text_length += conversation.messages[i].text.length();
- std::vector<Locale> message_languages;
- if (!ParseLocales(conversation.messages[i].detected_text_language_tags,
- &message_languages)) {
- continue;
- }
- if (Locale::IsAnyLocaleSupported(
- message_languages, locales_,
- preconditions_.handle_unknown_locale_as_supported)) {
- ++num_matching_locales;
- }
- }
-
- // Bail out if we are provided with too few or too much input.
- if (input_text_length < preconditions_.min_input_length ||
- (preconditions_.max_input_length >= 0 &&
- input_text_length > preconditions_.max_input_length)) {
- TC3_LOG(INFO) << "Too much or not enough input for inference.";
- return response;
- }
-
- // Bail out if the text does not look like it can be handled by the model.
- const float matching_fraction =
- static_cast<float>(num_matching_locales) / num_messages;
- if (matching_fraction < preconditions_.min_locale_match_fraction) {
- TC3_LOG(INFO) << "Not enough locale matches.";
- response->output_filtered_locale_mismatch = true;
- return true;
- }
-
- std::vector<int> post_check_rules;
- if (preconditions_.suppress_on_low_confidence_input &&
- IsLowConfidenceInput(conversation, num_messages, &post_check_rules)) {
- response->output_filtered_low_confidence = true;
- return true;
- }
-
- std::unique_ptr<tflite::Interpreter> interpreter;
- if (!SuggestActionsFromModel(conversation, num_messages, options, response,
- &interpreter)) {
- TC3_LOG(ERROR) << "Could not run model.";
- return false;
- }
-
- // Suppress all predictions if the conversation was deemed sensitive.
- if (preconditions_.suppress_on_sensitive_topic &&
- response->output_filtered_sensitivity) {
- return true;
- }
-
- if (!SuggestActionsFromLua(
- conversation, model_executor_.get(), interpreter.get(),
- annotator != nullptr ? annotator->entity_data_schema() : nullptr,
- &response->actions)) {
- TC3_LOG(ERROR) << "Could not suggest actions from script.";
- return false;
- }
-
- if (!SuggestActionsFromRules(conversation, &response->actions)) {
- TC3_LOG(ERROR) << "Could not suggest actions from rules.";
- return false;
- }
-
- if (preconditions_.suppress_on_low_confidence_input &&
- !FilterConfidenceOutput(post_check_rules, &response->actions)) {
- TC3_LOG(ERROR) << "Could not post-check actions.";
- return false;
- }
-
- return true;
-}
-
-ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
- const Conversation& conversation, const Annotator* annotator,
- const ActionSuggestionOptions& options) const {
- ActionsSuggestionsResponse response;
- if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
- TC3_LOG(ERROR) << "Could not gather actions suggestions.";
- response.actions.clear();
- } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
- annotator != nullptr
- ? annotator->entity_data_schema()
- : nullptr)) {
- TC3_LOG(ERROR) << "Could not rank actions.";
- response.actions.clear();
- }
- return response;
-}
-
-ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
- const Conversation& conversation,
- const ActionSuggestionOptions& options) const {
- return SuggestActions(conversation, /*annotator=*/nullptr, options);
-}
-
-const ActionsModel* ActionsSuggestions::model() const { return model_; }
-const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
- return entity_data_schema_;
-}
-
-const ActionsModel* ViewActionsModel(const void* buffer, int size) {
- if (buffer == nullptr) {
- return nullptr;
- }
- return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
-}
-
-} // namespace libtextclassifier3
diff --git a/actions/actions-suggestions.h b/actions/actions-suggestions.h
deleted file mode 100644
index 2dde133..0000000
--- a/actions/actions-suggestions.h
+++ /dev/null
@@ -1,319 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
-#define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
-
-#include <map>
-#include <memory>
-#include <string>
-#include <unordered_set>
-#include <vector>
-
-#include "actions/actions_model_generated.h"
-#include "actions/feature-processor.h"
-#include "actions/ngram-model.h"
-#include "actions/ranker.h"
-#include "actions/types.h"
-#include "annotator/annotator.h"
-#include "annotator/model-executor.h"
-#include "annotator/types.h"
-#include "utils/flatbuffers.h"
-#include "utils/i18n/locale.h"
-#include "utils/memory/mmap.h"
-#include "utils/tflite-model-executor.h"
-#include "utils/utf8/unilib.h"
-#include "utils/variant.h"
-#include "utils/zlib/zlib.h"
-
-namespace libtextclassifier3 {
-
-// Options for suggesting actions.
-struct ActionSuggestionOptions {
- static ActionSuggestionOptions Default() { return ActionSuggestionOptions(); }
-};
-
-// Class for predicting actions following a conversation.
-class ActionsSuggestions {
- public:
- // Creates ActionsSuggestions from given data buffer with model.
- static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
- const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
- const std::string& triggering_preconditions_overlay = "");
-
- // Creates ActionsSuggestions from model in the ScopedMmap object and takes
- // ownership of it.
- static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
- std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
- const UniLib* unilib = nullptr,
- const std::string& triggering_preconditions_overlay = "");
- // Same as above, but also takes ownership of the unilib.
- static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
- std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
- std::unique_ptr<UniLib> unilib,
- const std::string& triggering_preconditions_overlay);
-
- // Creates ActionsSuggestions from model given as a file descriptor, offset
- // and size in it. If offset and size are less than 0, will ignore them and
- // will just use the fd.
- static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
- const int fd, const int offset, const int size,
- const UniLib* unilib = nullptr,
- const std::string& triggering_preconditions_overlay = "");
- // Same as above, but also takes ownership of the unilib.
- static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
- const int fd, const int offset, const int size,
- std::unique_ptr<UniLib> unilib,
- const std::string& triggering_preconditions_overlay = "");
-
- // Creates ActionsSuggestions from model given as a file descriptor.
- static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
- const int fd, const UniLib* unilib = nullptr,
- const std::string& triggering_preconditions_overlay = "");
- // Same as above, but also takes ownership of the unilib.
- static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
- const int fd, std::unique_ptr<UniLib> unilib,
- const std::string& triggering_preconditions_overlay);
-
- // Creates ActionsSuggestions from model given as a POSIX path.
- static std::unique_ptr<ActionsSuggestions> FromPath(
- const std::string& path, const UniLib* unilib = nullptr,
- const std::string& triggering_preconditions_overlay = "");
- // Same as above, but also takes ownership of unilib.
- static std::unique_ptr<ActionsSuggestions> FromPath(
- const std::string& path, std::unique_ptr<UniLib> unilib,
- const std::string& triggering_preconditions_overlay);
-
- ActionsSuggestionsResponse SuggestActions(
- const Conversation& conversation,
- const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
-
- ActionsSuggestionsResponse SuggestActions(
- const Conversation& conversation, const Annotator* annotator,
- const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
-
- const ActionsModel* model() const;
- const reflection::Schema* entity_data_schema() const;
-
- static const int kLocalUserId = 0;
-
- // Should be in sync with those defined in Android.
- // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
- static const std::string& kViewCalendarType;
- static const std::string& kViewMapType;
- static const std::string& kTrackFlightType;
- static const std::string& kOpenUrlType;
- static const std::string& kSendSmsType;
- static const std::string& kCallPhoneType;
- static const std::string& kSendEmailType;
- static const std::string& kShareLocation;
-
- protected:
- // Exposed for testing.
- bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
-
- // Embeds the tokens per message separately. Each message is padded to the
- // maximum length with the padding token.
- bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
- std::vector<float>* embeddings,
- int* max_num_tokens_per_message) const;
-
- // Concatenates the embedded message tokens - separated by start and end
- // token between messages.
- // If the total token count is greater than the maximum length, tokens at the
- // start are dropped to fit into the limit.
- // If the total token count is smaller than the minimum length, padding tokens
- // are added to the end.
- // Messages are assumed to be ordered by recency - most recent is last.
- bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>> tokens,
- std::vector<float>* embeddings,
- int* total_token_count) const;
-
- const ActionsModel* model_;
-
- // Feature extractor and options.
- std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
- std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
- std::vector<float> embedded_padding_token_;
- std::vector<float> embedded_start_token_;
- std::vector<float> embedded_end_token_;
- int token_embedding_size_;
-
- private:
- struct CompiledRule {
- const RulesModel_::Rule* rule;
- std::unique_ptr<UniLib::RegexPattern> pattern;
- std::unique_ptr<UniLib::RegexPattern> output_pattern;
- CompiledRule(const RulesModel_::Rule* rule,
- std::unique_ptr<UniLib::RegexPattern> pattern,
- std::unique_ptr<UniLib::RegexPattern> output_pattern)
- : rule(rule),
- pattern(std::move(pattern)),
- output_pattern(std::move(output_pattern)) {}
- };
-
- // Checks that model contains all required fields, and initializes internal
- // datastructures.
- bool ValidateAndInitialize();
-
- void SetOrCreateUnilib(const UniLib* unilib);
-
- // Initializes regular expression rules.
- bool InitializeRules(ZlibDecompressor* decompressor);
- bool InitializeRules(ZlibDecompressor* decompressor, const RulesModel* rules,
- std::vector<CompiledRule>* compiled_rules) const;
-
- // Prepare preconditions.
- // Takes values from flag provided data, but falls back to model provided
- // values for parameters that are not explicitly provided.
- bool InitializeTriggeringPreconditions();
-
- // Tokenizes a conversation and produces the tokens per message.
- std::vector<std::vector<Token>> Tokenize(
- const std::vector<std::string>& context) const;
-
- bool AllocateInput(const int conversation_length, const int max_tokens,
- const int total_token_count,
- tflite::Interpreter* interpreter) const;
-
- bool SetupModelInput(const std::vector<std::string>& context,
- const std::vector<int>& user_ids,
- const std::vector<float>& time_diffs,
- const int num_suggestions,
- const float confidence_threshold,
- const float diversification_distance,
- const float empirical_probability_factor,
- tflite::Interpreter* interpreter) const;
- bool ReadModelOutput(tflite::Interpreter* interpreter,
- const ActionSuggestionOptions& options,
- ActionsSuggestionsResponse* response) const;
-
- bool SuggestActionsFromModel(
- const Conversation& conversation, const int num_messages,
- const ActionSuggestionOptions& options,
- ActionsSuggestionsResponse* response,
- std::unique_ptr<tflite::Interpreter>* interpreter) const;
-
- // Creates options for annotation of a message.
- AnnotationOptions AnnotationOptionsForMessage(
- const ConversationMessage& message) const;
-
- void SuggestActionsFromAnnotations(
- const Conversation& conversation, const ActionSuggestionOptions& options,
- const Annotator* annotator, std::vector<ActionSuggestion>* actions) const;
-
- void SuggestActionsFromAnnotation(
- const int message_index, const ActionSuggestionAnnotation& annotation,
- std::vector<ActionSuggestion>* actions) const;
-
- // Deduplicates equivalent annotations - annotations that have the same type
- // and same span text.
- // Returns the indices of the deduplicated annotations.
- std::vector<int> DeduplicateAnnotations(
- const std::vector<ActionSuggestionAnnotation>& annotations) const;
-
- bool SuggestActionsFromRules(const Conversation& conversation,
- std::vector<ActionSuggestion>* actions) const;
-
- bool SuggestActionsFromLua(
- const Conversation& conversation,
- const TfLiteModelExecutor* model_executor,
- const tflite::Interpreter* interpreter,
- const reflection::Schema* annotation_entity_data_schema,
- std::vector<ActionSuggestion>* actions) const;
-
- bool GatherActionsSuggestions(const Conversation& conversation,
- const Annotator* annotator,
- const ActionSuggestionOptions& options,
- ActionsSuggestionsResponse* response) const;
-
- // Checks whether the input triggers the low confidence checks.
- bool IsLowConfidenceInput(const Conversation& conversation,
- const int num_messages,
- std::vector<int>* post_check_rules) const;
- // Checks and filters suggestions triggering the low confidence post checks.
- bool FilterConfidenceOutput(const std::vector<int>& post_check_rules,
- std::vector<ActionSuggestion>* actions) const;
-
- ActionSuggestion SuggestionFromSpec(
- const ActionSuggestionSpec* action, const std::string& default_type = "",
- const std::string& default_response_text = "",
- const std::string& default_serialized_entity_data = "",
- const float default_score = 0.0f,
- const float default_priority_score = 0.0f) const;
-
- bool FillAnnotationFromMatchGroup(
- const UniLib::RegexMatcher* matcher,
- const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group,
- const int message_index, ActionSuggestionAnnotation* annotation) const;
-
- std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
-
- // Tensorflow Lite models.
- std::unique_ptr<const TfLiteModelExecutor> model_executor_;
-
- // Rules.
- std::vector<CompiledRule> rules_, low_confidence_rules_;
-
- std::unique_ptr<UniLib> owned_unilib_;
- const UniLib* unilib_;
-
- // Locales supported by the model.
- std::vector<Locale> locales_;
-
- // Annotation entities used by the model.
- std::unordered_set<std::string> annotation_entity_types_;
-
- // Builder for creating extra data.
- const reflection::Schema* entity_data_schema_;
- std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
- std::unique_ptr<ActionsSuggestionsRanker> ranker_;
-
- std::string lua_bytecode_;
-
- // Triggering preconditions. These parameters can be backed by the model and
- // (partially) be provided by flags.
- TriggeringPreconditionsT preconditions_;
- std::string triggering_preconditions_overlay_buffer_;
- const TriggeringPreconditions* triggering_preconditions_overlay_;
-
- // Low confidence input ngram classifier.
- std::unique_ptr<const NGramModel> ngram_model_;
-};
-
-// Interprets the buffer as a Model flatbuffer and returns it for reading.
-const ActionsModel* ViewActionsModel(const void* buffer, int size);
-
-// Opens model from given path and runs a function, passing the loaded Model
-// flatbuffer as an argument.
-//
-// This is mainly useful if we don't want to pay the cost for the model
-// initialization because we'll be only reading some flatbuffer values from the
-// file.
-template <typename ReturnType, typename Func>
-ReturnType VisitActionsModel(const std::string& path, Func function) {
- ScopedMmap mmap(path);
- if (!mmap.handle().ok()) {
- function(/*model=*/nullptr);
- }
- const ActionsModel* model =
- ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
- return function(model);
-}
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
diff --git a/actions/actions-suggestions_test.cc b/actions/actions-suggestions_test.cc
deleted file mode 100644
index e0cfbaa..0000000
--- a/actions/actions-suggestions_test.cc
+++ /dev/null
@@ -1,1332 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/actions-suggestions.h"
-
-#include <fstream>
-#include <iterator>
-#include <memory>
-
-#include "actions/actions_model_generated.h"
-#include "actions/test_utils.h"
-#include "actions/zlib-utils.h"
-#include "annotator/collections.h"
-#include "annotator/types.h"
-#include "utils/flatbuffers.h"
-#include "utils/flatbuffers_generated.h"
-#include "utils/hash/farmhash.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/reflection.h"
-
-namespace libtextclassifier3 {
-namespace {
-using testing::_;
-
-constexpr char kModelFileName[] = "actions_suggestions_test.model";
-constexpr char kHashGramModelFileName[] =
- "actions_suggestions_test.hashgram.model";
-
-std::string ReadFile(const std::string& file_name) {
- std::ifstream file_stream(file_name);
- return std::string(std::istreambuf_iterator<char>(file_stream), {});
-}
-
-std::string GetModelPath() {
- return "";
-}
-
-class ActionsSuggestionsTest : public testing::Test {
- protected:
- ActionsSuggestionsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
- std::unique_ptr<ActionsSuggestions> LoadTestModel() {
- return ActionsSuggestions::FromPath(GetModelPath() + kModelFileName,
- &unilib_);
- }
- std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
- return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
- &unilib_);
- }
- UniLib unilib_;
-};
-
-TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
- EXPECT_THAT(LoadTestModel(), testing::NotNull());
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActions) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- EXPECT_EQ(response.actions.size(), 3 /* share_location + 2 smart replies*/);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestNoActionsForUnknownLocale) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"zz"}}});
- EXPECT_THAT(response.actions, testing::IsEmpty());
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotations) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- AnnotatedSpan annotation;
- annotation.span = {11, 15};
- annotation.classification = {ClassificationResult("address", 1.0)};
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "are you at home?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
- ASSERT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions.front().type, "view_map");
- EXPECT_EQ(response.actions.front().score, 1.0);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotationsWithEntityData) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- SetTestEntityDataSchema(actions_model.get());
-
- // Set custom actions from annotations config.
- actions_model->annotation_actions_spec->annotation_mapping.clear();
- actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
- new AnnotationActionsSpec_::AnnotationMappingT);
- AnnotationActionsSpec_::AnnotationMappingT* mapping =
- actions_model->annotation_actions_spec->annotation_mapping.back().get();
- mapping->annotation_collection = "address";
- mapping->action.reset(new ActionSuggestionSpecT);
- mapping->action->type = "save_location";
- mapping->action->score = 1.0;
- mapping->action->priority_score = 2.0;
- mapping->entity_field.reset(new FlatbufferFieldPathT);
- mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
- mapping->entity_field->field.back()->field_name = "location";
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
-
- AnnotatedSpan annotation;
- annotation.span = {11, 15};
- annotation.classification = {ClassificationResult("address", 1.0)};
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "are you at home?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
- ASSERT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions.front().type, "save_location");
- EXPECT_EQ(response.actions.front().score, 1.0);
-
- // Check that the `location` entity field holds the text from the address
- // annotation.
- const flatbuffers::Table* entity =
- flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
- response.actions.front().serialized_entity_data.data()));
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
- "home");
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsFromDuplicatedAnnotations) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- AnnotatedSpan flight_annotation;
- flight_annotation.span = {11, 15};
- flight_annotation.classification = {ClassificationResult("flight", 2.5)};
- AnnotatedSpan flight_annotation2;
- flight_annotation2.span = {35, 39};
- flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
- AnnotatedSpan email_annotation;
- email_annotation.span = {55, 68};
- email_annotation.classification = {ClassificationResult("email", 2.0)};
-
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1,
- "call me at LX38 or send message to LX38 or test@test.com.",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/
- {flight_annotation, flight_annotation2, email_annotation},
- /*locales=*/"en"}}});
-
- ASSERT_GE(response.actions.size(), 2);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[0].score, 3.0);
- EXPECT_EQ(response.actions[1].type, "send_email");
- EXPECT_EQ(response.actions[1].score, 2.0);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsAnnotationsNoDeduplication) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- // Disable deduplication.
- actions_model->annotation_actions_spec->deduplicate_annotations = false;
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
- AnnotatedSpan flight_annotation;
- flight_annotation.span = {11, 15};
- flight_annotation.classification = {ClassificationResult("flight", 2.5)};
- AnnotatedSpan flight_annotation2;
- flight_annotation2.span = {35, 39};
- flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
- AnnotatedSpan email_annotation;
- email_annotation.span = {55, 68};
- email_annotation.classification = {ClassificationResult("email", 2.0)};
-
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1,
- "call me at LX38 or send message to LX38 or test@test.com.",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/
- {flight_annotation, flight_annotation2, email_annotation},
- /*locales=*/"en"}}});
-
- ASSERT_GE(response.actions.size(), 3);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[0].score, 3.0);
- EXPECT_EQ(response.actions[1].type, "track_flight");
- EXPECT_EQ(response.actions[1].score, 2.5);
- EXPECT_EQ(response.actions[2].type, "send_email");
- EXPECT_EQ(response.actions[2].score, 2.0);
-}
-
-ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
- const std::function<void(ActionsModelT*)>& set_config_fn,
- const UniLib* unilib = nullptr) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
-
- // Set custom config.
- set_config_fn(actions_model.get());
-
- // Disable smart reply for easier testing.
- actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), unilib);
-
- AnnotatedSpan flight_annotation;
- flight_annotation.span = {15, 19};
- flight_annotation.classification = {ClassificationResult("flight", 2.0)};
- AnnotatedSpan email_annotation;
- email_annotation.span = {0, 16};
- email_annotation.classification = {ClassificationResult("email", 1.0)};
-
- return actions_suggestions->SuggestActions(
- {{{/*user_id=*/ActionsSuggestions::kLocalUserId,
- "hehe@android.com",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/
- {email_annotation},
- /*locales=*/"en"},
- {/*user_id=*/2,
- "yoyo@android.com",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/
- {email_annotation},
- /*locales=*/"en"},
- {/*user_id=*/1,
- "test@android.com",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/
- {email_annotation},
- /*locales=*/"en"},
- {/*user_id=*/1,
- "I am on flight LX38.",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/
- {flight_annotation},
- /*locales=*/"en"}}});
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastMessage) {
- const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
- [](ActionsModelT* actions_model) {
- actions_model->annotation_actions_spec->include_local_user_messages =
- false;
- actions_model->annotation_actions_spec->only_until_last_sent = true;
- actions_model->annotation_actions_spec->max_history_from_any_person = 1;
- actions_model->annotation_actions_spec->max_history_from_last_person =
- 1;
- },
- &unilib_);
- EXPECT_EQ(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].type, "track_flight");
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastPerson) {
- const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
- [](ActionsModelT* actions_model) {
- actions_model->annotation_actions_spec->include_local_user_messages =
- false;
- actions_model->annotation_actions_spec->only_until_last_sent = true;
- actions_model->annotation_actions_spec->max_history_from_any_person = 1;
- actions_model->annotation_actions_spec->max_history_from_last_person =
- 3;
- },
- &unilib_);
- EXPECT_EQ(response.actions.size(), 2);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[1].type, "send_email");
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsFromAny) {
- const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
- [](ActionsModelT* actions_model) {
- actions_model->annotation_actions_spec->include_local_user_messages =
- false;
- actions_model->annotation_actions_spec->only_until_last_sent = true;
- actions_model->annotation_actions_spec->max_history_from_any_person = 2;
- actions_model->annotation_actions_spec->max_history_from_last_person =
- 1;
- },
- &unilib_);
- EXPECT_EQ(response.actions.size(), 2);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[1].type, "send_email");
-}
-
-TEST_F(ActionsSuggestionsTest,
- SuggestActionsWithAnnotationsFromAnyManyMessages) {
- const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
- [](ActionsModelT* actions_model) {
- actions_model->annotation_actions_spec->include_local_user_messages =
- false;
- actions_model->annotation_actions_spec->only_until_last_sent = true;
- actions_model->annotation_actions_spec->max_history_from_any_person = 3;
- actions_model->annotation_actions_spec->max_history_from_last_person =
- 1;
- },
- &unilib_);
- EXPECT_EQ(response.actions.size(), 3);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[1].type, "send_email");
- EXPECT_EQ(response.actions[2].type, "send_email");
-}
-
-TEST_F(ActionsSuggestionsTest,
- SuggestActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
- const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
- [](ActionsModelT* actions_model) {
- actions_model->annotation_actions_spec->include_local_user_messages =
- false;
- actions_model->annotation_actions_spec->only_until_last_sent = true;
- actions_model->annotation_actions_spec->max_history_from_any_person = 5;
- actions_model->annotation_actions_spec->max_history_from_last_person =
- 1;
- },
- &unilib_);
- EXPECT_EQ(response.actions.size(), 3);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[1].type, "send_email");
- EXPECT_EQ(response.actions[2].type, "send_email");
-}
-
-TEST_F(ActionsSuggestionsTest,
- SuggestActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
- const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
- [](ActionsModelT* actions_model) {
- actions_model->annotation_actions_spec->include_local_user_messages =
- true;
- actions_model->annotation_actions_spec->only_until_last_sent = false;
- actions_model->annotation_actions_spec->max_history_from_any_person = 5;
- actions_model->annotation_actions_spec->max_history_from_last_person =
- 1;
- },
- &unilib_);
- EXPECT_EQ(response.actions.size(), 4);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[1].type, "send_email");
- EXPECT_EQ(response.actions[2].type, "send_email");
- EXPECT_EQ(response.actions[3].type, "send_email");
-}
-
-void TestSuggestActionsWithThreshold(
- const std::function<void(ActionsModelT*)>& set_value_fn,
- const UniLib* unilib = nullptr, const int expected_size = 0,
- const std::string& preconditions_overwrite = "") {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- set_value_fn(actions_model.get());
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), unilib, preconditions_overwrite);
- ASSERT_TRUE(actions_suggestions);
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "I have the low-ground. Where are you?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- EXPECT_LE(response.actions.size(), expected_size);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithTriggeringScore) {
- TestSuggestActionsWithThreshold(
- [](ActionsModelT* actions_model) {
- actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
- },
- &unilib_,
- /*expected_size=*/1 /*no smart reply, only actions*/
- );
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinReplyScore) {
- TestSuggestActionsWithThreshold(
- [](ActionsModelT* actions_model) {
- actions_model->preconditions->min_reply_score_threshold = 1.0;
- },
- &unilib_,
- /*expected_size=*/1 /*no smart reply, only actions*/
- );
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithSensitiveTopicScore) {
- TestSuggestActionsWithThreshold(
- [](ActionsModelT* actions_model) {
- actions_model->preconditions->max_sensitive_topic_score = 0.0;
- },
- &unilib_,
- /*expected_size=*/4 /* no sensitive prediction in test model*/);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithMaxInputLength) {
- TestSuggestActionsWithThreshold(
- [](ActionsModelT* actions_model) {
- actions_model->preconditions->max_input_length = 0;
- },
- &unilib_);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinInputLength) {
- TestSuggestActionsWithThreshold(
- [](ActionsModelT* actions_model) {
- actions_model->preconditions->min_input_length = 100;
- },
- &unilib_);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithPreconditionsOverwrite) {
- TriggeringPreconditionsT preconditions_overwrite;
- preconditions_overwrite.max_input_length = 0;
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(
- TriggeringPreconditions::Pack(builder, &preconditions_overwrite));
- TestSuggestActionsWithThreshold(
- // Keep model untouched.
- [](ActionsModelT* actions_model) {}, &unilib_,
- /*expected_size=*/0,
- std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize()));
-}
-
-#ifdef TC3_UNILIB_ICU
-TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidence) {
- TestSuggestActionsWithThreshold(
- [](ActionsModelT* actions_model) {
- actions_model->preconditions->suppress_on_low_confidence_input = true;
- actions_model->low_confidence_rules.reset(new RulesModelT);
- actions_model->low_confidence_rules->rule.emplace_back(
- new RulesModel_::RuleT);
- actions_model->low_confidence_rules->rule.back()->pattern =
- "low-ground";
- },
- &unilib_);
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidenceInputOutput) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- // Add custom triggering rule.
- actions_model->rules.reset(new RulesModelT());
- actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
- RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
- rule->pattern = "^(?i:hello\\s(there))$";
- {
- std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
- new RulesModel_::Rule_::RuleActionSpecT);
- rule_action->action.reset(new ActionSuggestionSpecT);
- rule_action->action->type = "text_reply";
- rule_action->action->response_text = "General Desaster!";
- rule_action->action->score = 1.0f;
- rule_action->action->priority_score = 1.0f;
- rule->actions.push_back(std::move(rule_action));
- }
- {
- std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
- new RulesModel_::Rule_::RuleActionSpecT);
- rule_action->action.reset(new ActionSuggestionSpecT);
- rule_action->action->type = "text_reply";
- rule_action->action->response_text = "General Kenobi!";
- rule_action->action->score = 1.0f;
- rule_action->action->priority_score = 1.0f;
- rule->actions.push_back(std::move(rule_action));
- }
-
- // Add input-output low confidence rule.
- actions_model->preconditions->suppress_on_low_confidence_input = true;
- actions_model->low_confidence_rules.reset(new RulesModelT);
- actions_model->low_confidence_rules->rule.emplace_back(
- new RulesModel_::RuleT);
- actions_model->low_confidence_rules->rule.back()->pattern = "hello";
- actions_model->low_confidence_rules->rule.back()->output_pattern =
- "(?i:desaster)";
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
- ASSERT_TRUE(actions_suggestions);
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "hello there",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- ASSERT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
-}
-
-TEST_F(ActionsSuggestionsTest,
- SuggestActionsLowConfidenceInputOutputOverwrite) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- actions_model->low_confidence_rules.reset();
-
- // Add custom triggering rule.
- actions_model->rules.reset(new RulesModelT());
- actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
- RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
- rule->pattern = "^(?i:hello\\s(there))$";
- {
- std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
- new RulesModel_::Rule_::RuleActionSpecT);
- rule_action->action.reset(new ActionSuggestionSpecT);
- rule_action->action->type = "text_reply";
- rule_action->action->response_text = "General Desaster!";
- rule_action->action->score = 1.0f;
- rule_action->action->priority_score = 1.0f;
- rule->actions.push_back(std::move(rule_action));
- }
- {
- std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action(
- new RulesModel_::Rule_::RuleActionSpecT);
- rule_action->action.reset(new ActionSuggestionSpecT);
- rule_action->action->type = "text_reply";
- rule_action->action->response_text = "General Kenobi!";
- rule_action->action->score = 1.0f;
- rule_action->action->priority_score = 1.0f;
- rule->actions.push_back(std::move(rule_action));
- }
-
- // Add custom triggering rule via overwrite.
- actions_model->preconditions->low_confidence_rules.reset();
- TriggeringPreconditionsT preconditions;
- preconditions.suppress_on_low_confidence_input = true;
- preconditions.low_confidence_rules.reset(new RulesModelT);
- preconditions.low_confidence_rules->rule.emplace_back(new RulesModel_::RuleT);
- preconditions.low_confidence_rules->rule.back()->pattern = "hello";
- preconditions.low_confidence_rules->rule.back()->output_pattern =
- "(?i:desaster)";
- flatbuffers::FlatBufferBuilder preconditions_builder;
- preconditions_builder.Finish(
- TriggeringPreconditions::Pack(preconditions_builder, &preconditions));
- std::string serialize_preconditions = std::string(
- reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()),
- preconditions_builder.GetSize());
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_, serialize_preconditions);
-
- ASSERT_TRUE(actions_suggestions);
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "hello there",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- ASSERT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
-}
-#endif
-
-TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
-
- // Don't test if no sensitivity score is produced
- if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
- return;
- }
-
- actions_model->preconditions->max_sensitive_topic_score = 0.0;
- actions_model->preconditions->suppress_on_sensitive_topic = true;
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
- AnnotatedSpan annotation;
- annotation.span = {11, 15};
- annotation.classification = {
- ClassificationResult(Collections::Address(), 1.0)};
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "are you at home?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
- EXPECT_THAT(response.actions, testing::IsEmpty());
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithLongerConversation) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
-
- // Allow a larger conversation context.
- actions_model->max_conversation_history_length = 10;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
- AnnotatedSpan annotation;
- annotation.span = {11, 15};
- annotation.classification = {
- ClassificationResult(Collections::Address(), 1.0)};
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
- /*reference_time_ms_utc=*/10000,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"},
- {/*user_id=*/1, "good! are you at home?",
- /*reference_time_ms_utc=*/15000,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
- ASSERT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].type, "view_map");
- EXPECT_EQ(response.actions[0].score, 1.0);
-}
-
-TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- AnnotatedSpan annotation;
- annotation.span = {8, 12};
- annotation.classification = {
- ClassificationResult(Collections::Flight(), 1.0)};
-
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "I'm on LX38?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
-
- ASSERT_GE(response.actions.size(), 2);
- EXPECT_EQ(response.actions[0].type, "track_flight");
- EXPECT_EQ(response.actions[0].score, 1.0);
- EXPECT_EQ(response.actions[0].annotations.size(), 1);
- EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
- EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
-}
-
-#ifdef TC3_UNILIB_ICU
-TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
-
- actions_model->rules.reset(new RulesModelT());
- actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
- RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
- rule->pattern = "^(?i:hello\\s(there))$";
- rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
- rule->actions.back()->action.reset(new ActionSuggestionSpecT);
- ActionSuggestionSpecT* action = rule->actions.back()->action.get();
- action->type = "text_reply";
- action->response_text = "General Kenobi!";
- action->score = 1.0f;
- action->priority_score = 1.0f;
-
- // Set capturing groups for entity data.
- rule->actions.back()->capturing_group.emplace_back(
- new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
- RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
- rule->actions.back()->capturing_group.back().get();
- greeting_group->group_id = 0;
- greeting_group->entity_field.reset(new FlatbufferFieldPathT);
- greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
- greeting_group->entity_field->field.back()->field_name = "greeting";
- rule->actions.back()->capturing_group.emplace_back(
- new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
- RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* location_group =
- rule->actions.back()->capturing_group.back().get();
- location_group->group_id = 1;
- location_group->entity_field.reset(new FlatbufferFieldPathT);
- location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
- location_group->entity_field->field.back()->field_name = "location";
-
- // Set test entity data schema.
- SetTestEntityDataSchema(actions_model.get());
-
- // Use meta data to generate custom serialized entity data.
- ReflectiveFlatbufferBuilder entity_data_builder(
- flatbuffers::GetRoot<reflection::Schema>(
- actions_model->actions_entity_data_schema.data()));
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
- entity_data_builder.NewRoot();
- entity_data->Set("person", "Kenobi");
- action->serialized_entity_data = entity_data->Serialize();
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
-
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- EXPECT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
-
- // Check entity data.
- const flatbuffers::Table* entity =
- flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
- response.actions[0].serialized_entity_data.data()));
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
- "hello there");
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
- "there");
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
- "Kenobi");
-}
-
-TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
-
- actions_model->rules.reset(new RulesModelT());
- actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
- RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
- rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
- rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
-
- // Set capturing groups for entity data.
- rule->actions.back()->capturing_group.emplace_back(
- new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
- RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
- rule->actions.back()->capturing_group.back().get();
- code_group->group_id = 1;
- code_group->text_reply.reset(new ActionSuggestionSpecT);
- code_group->text_reply->score = 1.0f;
- code_group->text_reply->priority_score = 1.0f;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
-
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1,
- "visit test.com or reply STOP to cancel your subscription",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- EXPECT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].response_text, "STOP");
-}
-
-TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
-
- // Check that the location sharing model triggered.
- bool has_location_sharing_action = false;
- for (const ActionSuggestion action : response.actions) {
- if (action.type == ActionsSuggestions::kShareLocation) {
- has_location_sharing_action = true;
- break;
- }
- }
- EXPECT_TRUE(has_location_sharing_action);
- const int num_actions = response.actions.size();
-
- // Add custom rule for location sharing.
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
-
- actions_model->rules.reset(new RulesModelT());
- actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
- actions_model->rules->rule.back()->pattern = "^(?i:where are you[.?]?)$";
- actions_model->rules->rule.back()->actions.emplace_back(
- new RulesModel_::Rule_::RuleActionSpecT);
- actions_model->rules->rule.back()->actions.back()->action.reset(
- new ActionSuggestionSpecT);
- ActionSuggestionSpecT* action =
- actions_model->rules->rule.back()->actions.back()->action.get();
- action->score = 1.0f;
- action->type = ActionsSuggestions::kShareLocation;
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
-
- response = actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{}, /*locales=*/"en"}}});
- EXPECT_EQ(response.actions.size(), num_actions);
-}
-
-TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- AnnotatedSpan annotation;
- annotation.span = {7, 11};
- annotation.classification = {
- ClassificationResult(Collections::Flight(), 1.0)};
- ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "I'm on LX38",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
-
- // Check that the phone actions are present.
- EXPECT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].type, "track_flight");
-
- // Add custom rule.
- const std::string actions_model_string =
- ReadFile(GetModelPath() + kModelFileName);
- std::unique_ptr<ActionsModelT> actions_model =
- UnPackActionsModel(actions_model_string.c_str());
- ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
-
- actions_model->rules.reset(new RulesModelT());
- actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
- RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
- rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
- rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
- rule->actions.back()->action.reset(new ActionSuggestionSpecT);
- ActionSuggestionSpecT* action = rule->actions.back()->action.get();
- action->score = 1.0f;
- action->priority_score = 2.0f;
- action->type = "test_code";
- rule->actions.back()->capturing_group.emplace_back(
- new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
- RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
- rule->actions.back()->capturing_group.back().get();
- code_group->group_id = 1;
- code_group->annotation_name = "code";
- code_group->annotation_type = "code";
-
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, actions_model.get()));
- actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
- reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
- builder.GetSize(), &unilib_);
-
- response = actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "I'm on LX38",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}});
- EXPECT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].type, "test_code");
-}
-#endif
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsRanking) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- std::vector<AnnotatedSpan> annotations(2);
- annotations[0].span = {11, 15};
- annotations[0].classification = {ClassificationResult("address", 1.0)};
- annotations[1].span = {19, 23};
- annotations[1].classification = {ClassificationResult("address", 2.0)};
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "are you at home or work?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/annotations,
- /*locales=*/"en"}}});
- EXPECT_GE(response.actions.size(), 2);
- EXPECT_EQ(response.actions[0].type, "view_map");
- EXPECT_EQ(response.actions[0].score, 2.0);
- EXPECT_EQ(response.actions[1].type, "view_map");
- EXPECT_EQ(response.actions[1].score, 1.0);
-}
-
-TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
- EXPECT_TRUE(VisitActionsModel<bool>(GetModelPath() + kModelFileName,
- [](const ActionsModel* model) {
- if (model == nullptr) {
- return false;
- }
- return true;
- }));
- EXPECT_FALSE(VisitActionsModel<bool>(GetModelPath() + "non_existing_model.fb",
- [](const ActionsModel* model) {
- if (model == nullptr) {
- return false;
- }
- return true;
- }));
-}
-
-TEST_F(ActionsSuggestionsTest, SuggestActionsWithHashGramModel) {
- std::unique_ptr<ActionsSuggestions> actions_suggestions =
- LoadHashGramTestModel();
- ASSERT_TRUE(actions_suggestions != nullptr);
- {
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "hello",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{},
- /*locales=*/"en"}}});
- EXPECT_THAT(response.actions, testing::IsEmpty());
- }
- {
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "where are you",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{},
- /*locales=*/"en"}}});
- EXPECT_THAT(
- response.actions,
- ElementsAre(testing::Field(&ActionSuggestion::type, "share_location")));
- }
- {
- const ActionsSuggestionsResponse& response =
- actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "do you know johns number",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{},
- /*locales=*/"en"}}});
- EXPECT_THAT(
- response.actions,
- ElementsAre(testing::Field(&ActionSuggestion::type, "share_contact")));
- }
-}
-
-// Test class to expose token embedding methods for testing.
-class TestingMessageEmbedder : private ActionsSuggestions {
- public:
- explicit TestingMessageEmbedder(const ActionsModel* model);
-
- using ActionsSuggestions::EmbedAndFlattenTokens;
- using ActionsSuggestions::EmbedTokensPerMessage;
-
- protected:
- // EmbeddingExecutor that always returns features based on
- // the id of the sparse features.
- class FakeEmbeddingExecutor : public EmbeddingExecutor {
- public:
- bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
- const int dest_size) const override {
- TC3_CHECK_GE(dest_size, 1);
- EXPECT_EQ(sparse_features.size(), 1);
- dest[0] = sparse_features.data()[0];
- return true;
- }
- };
-};
-
-TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model) {
- model_ = model;
- const ActionsTokenFeatureProcessorOptions* options =
- model->feature_processor_options();
- feature_processor_.reset(
- new ActionsFeatureProcessor(options, /*unilib=*/nullptr));
- embedding_executor_.reset(new FakeEmbeddingExecutor());
- EXPECT_TRUE(
- EmbedTokenId(options->padding_token_id(), &embedded_padding_token_));
- EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_));
- EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_));
- token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
- EXPECT_EQ(token_embedding_size_, 1);
-}
-
-class EmbeddingTest : public testing::Test {
- protected:
- EmbeddingTest() {
- model_.feature_processor_options.reset(
- new ActionsTokenFeatureProcessorOptionsT);
- options_ = model_.feature_processor_options.get();
- options_->chargram_orders = {1};
- options_->num_buckets = 1000;
- options_->embedding_size = 1;
- options_->start_token_id = 0;
- options_->end_token_id = 1;
- options_->padding_token_id = 2;
- options_->tokenizer_options.reset(new ActionsTokenizerOptionsT);
- }
-
- TestingMessageEmbedder CreateTestingMessageEmbedder() {
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
- buffer_ = builder.ReleaseBufferPointer();
- return TestingMessageEmbedder(
- flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
- }
-
- flatbuffers::DetachedBuffer buffer_;
- ActionsModelT model_;
- ActionsTokenFeatureProcessorOptionsT* options_;
-};
-
-TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) {
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
- std::vector<float> embeddings;
- int max_num_tokens_per_message = 0;
-
- EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
- &max_num_tokens_per_message));
-
- EXPECT_EQ(max_num_tokens_per_message, 3);
- EXPECT_EQ(embeddings.size(), 3);
- EXPECT_THAT(embeddings[0],
- testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
-}
-
-TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
- options_->min_num_tokens_per_message = 5;
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
- std::vector<float> embeddings;
- int max_num_tokens_per_message = 0;
-
- EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
- &max_num_tokens_per_message));
-
- EXPECT_EQ(max_num_tokens_per_message, 5);
- EXPECT_EQ(embeddings.size(), 5);
- EXPECT_THAT(embeddings[0],
- testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[3], testing::FloatEq(options_->padding_token_id));
- EXPECT_THAT(embeddings[4], testing::FloatEq(options_->padding_token_id));
-}
-
-TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
- options_->max_num_tokens_per_message = 2;
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
- std::vector<float> embeddings;
- int max_num_tokens_per_message = 0;
-
- EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
- &max_num_tokens_per_message));
-
- EXPECT_EQ(max_num_tokens_per_message, 2);
- EXPECT_EQ(embeddings.size(), 2);
- EXPECT_THAT(embeddings[0],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
-}
-
-TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
- {Token("d", 0, 1), Token("e", 2, 3)}};
- std::vector<float> embeddings;
- int max_num_tokens_per_message = 0;
-
- EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
- &max_num_tokens_per_message));
-
- EXPECT_EQ(max_num_tokens_per_message, 3);
- EXPECT_THAT(embeddings[0],
- testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[3],
- testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[4],
- testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id));
-}
-
-TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
- std::vector<float> embeddings;
- int total_token_count = 0;
-
- EXPECT_TRUE(
- embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
-
- EXPECT_EQ(total_token_count, 5);
- EXPECT_EQ(embeddings.size(), 5);
- EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[3],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
-}
-
-TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
- options_->min_num_total_tokens = 7;
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
- std::vector<float> embeddings;
- int total_token_count = 0;
-
- EXPECT_TRUE(
- embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
-
- EXPECT_EQ(total_token_count, 7);
- EXPECT_EQ(embeddings.size(), 7);
- EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[3],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
- EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id));
- EXPECT_THAT(embeddings[6], testing::FloatEq(options_->padding_token_id));
-}
-
-TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
- options_->max_num_total_tokens = 3;
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
- std::vector<float> embeddings;
- int total_token_count = 0;
-
- EXPECT_TRUE(
- embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
-
- EXPECT_EQ(total_token_count, 3);
- EXPECT_EQ(embeddings.size(), 3);
- EXPECT_THAT(embeddings[0],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2], testing::FloatEq(options_->end_token_id));
-}
-
-TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
- {Token("d", 0, 1), Token("e", 2, 3)}};
- std::vector<float> embeddings;
- int total_token_count = 0;
-
- EXPECT_TRUE(
- embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
-
- EXPECT_EQ(total_token_count, 9);
- EXPECT_EQ(embeddings.size(), 9);
- EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id));
- EXPECT_THAT(embeddings[1],
- testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[2],
- testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[3],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id));
- EXPECT_THAT(embeddings[5], testing::FloatEq(options_->start_token_id));
- EXPECT_THAT(embeddings[6],
- testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[7],
- testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[8], testing::FloatEq(options_->end_token_id));
-}
-
-TEST_F(EmbeddingTest,
- EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
- options_->max_num_total_tokens = 7;
- const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
- std::vector<std::vector<Token>> tokens = {
- {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
- {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
- std::vector<float> embeddings;
- int total_token_count = 0;
-
- EXPECT_TRUE(
- embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
-
- EXPECT_EQ(total_token_count, 7);
- EXPECT_EQ(embeddings.size(), 7);
- EXPECT_THAT(embeddings[0],
- testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[1], testing::FloatEq(options_->end_token_id));
- EXPECT_THAT(embeddings[2], testing::FloatEq(options_->start_token_id));
- EXPECT_THAT(embeddings[3],
- testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[4],
- testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[5],
- testing::FloatEq(tc3farmhash::Fingerprint64("f", 1) %
- options_->num_buckets));
- EXPECT_THAT(embeddings[6], testing::FloatEq(options_->end_token_id));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/actions/actions_jni.cc b/actions/actions_jni.cc
deleted file mode 100644
index 20891fa..0000000
--- a/actions/actions_jni.cc
+++ /dev/null
@@ -1,408 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// JNI wrapper for actions.
-
-#include "actions/actions_jni.h"
-
-#include <jni.h>
-#include <map>
-#include <type_traits>
-#include <vector>
-
-#include "actions/actions-suggestions.h"
-#include "annotator/annotator.h"
-#include "annotator/annotator_jni_common.h"
-#include "utils/base/integral_types.h"
-#include "utils/intents/intent-generator.h"
-#include "utils/intents/jni.h"
-#include "utils/java/jni-cache.h"
-#include "utils/java/scoped_local_ref.h"
-#include "utils/java/string_utils.h"
-#include "utils/memory/mmap.h"
-
-using libtextclassifier3::ActionsSuggestions;
-using libtextclassifier3::ActionsSuggestionsResponse;
-using libtextclassifier3::ActionSuggestion;
-using libtextclassifier3::ActionSuggestionOptions;
-using libtextclassifier3::Annotator;
-using libtextclassifier3::Conversation;
-using libtextclassifier3::IntentGenerator;
-using libtextclassifier3::ScopedLocalRef;
-using libtextclassifier3::ToStlString;
-
-// When using the Java's ICU, UniLib needs to be instantiated with a JavaVM
-// pointer from JNI. When using a standard ICU the pointer is not needed and the
-// objects are instantiated implicitly.
-#ifdef TC3_UNILIB_JAVAICU
-using libtextclassifier3::UniLib;
-#endif
-
-namespace libtextclassifier3 {
-
-namespace {
-
-// Cached state for model inference.
-// Keeps a jni cache, intent generator and model instance so that they don't
-// have to be recreated for each call.
-class ActionsSuggestionsJniContext {
- public:
- static ActionsSuggestionsJniContext* Create(
- const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
- std::unique_ptr<ActionsSuggestions> model) {
- if (jni_cache == nullptr || model == nullptr) {
- return nullptr;
- }
- std::unique_ptr<IntentGenerator> intent_generator =
- IntentGenerator::Create(model->model()->android_intent_options(),
- model->model()->resources(), jni_cache);
- std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
- libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
-
- if (intent_generator == nullptr || template_handler == nullptr) {
- return nullptr;
- }
-
- return new ActionsSuggestionsJniContext(jni_cache, std::move(model),
- std::move(intent_generator),
- std::move(template_handler));
- }
-
- std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const {
- return jni_cache_;
- }
-
- ActionsSuggestions* model() const { return model_.get(); }
-
- IntentGenerator* intent_generator() const { return intent_generator_.get(); }
-
- RemoteActionTemplatesHandler* template_handler() const {
- return template_handler_.get();
- }
-
- private:
- ActionsSuggestionsJniContext(
- const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
- std::unique_ptr<ActionsSuggestions> model,
- std::unique_ptr<IntentGenerator> intent_generator,
- std::unique_ptr<RemoteActionTemplatesHandler> template_handler)
- : jni_cache_(jni_cache),
- model_(std::move(model)),
- intent_generator_(std::move(intent_generator)),
- template_handler_(std::move(template_handler)) {}
-
- std::shared_ptr<libtextclassifier3::JniCache> jni_cache_;
- std::unique_ptr<ActionsSuggestions> model_;
- std::unique_ptr<IntentGenerator> intent_generator_;
- std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
-};
-
-ActionSuggestionOptions FromJavaActionSuggestionOptions(JNIEnv* env,
- jobject joptions) {
- ActionSuggestionOptions options = ActionSuggestionOptions::Default();
- return options;
-}
-
-jobjectArray ActionSuggestionsToJObjectArray(
- JNIEnv* env, const ActionsSuggestionsJniContext* context,
- jobject app_context,
- const reflection::Schema* annotations_entity_data_schema,
- const std::vector<ActionSuggestion>& action_result,
- const Conversation& conversation, const jstring device_locales,
- const bool generate_intents) {
- const ScopedLocalRef<jclass> result_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
- "$ActionSuggestion"),
- env);
- if (!result_class) {
- TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class.";
- return nullptr;
- }
-
- const jmethodID result_class_constructor = env->GetMethodID(
- result_class.get(), "<init>",
- "(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH
- TC3_NAMED_VARIANT_CLASS_NAME_STR
- ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";)V");
- const jobjectArray results =
- env->NewObjectArray(action_result.size(), result_class.get(), nullptr);
- for (int i = 0; i < action_result.size(); i++) {
- jobject extras = nullptr;
-
- const reflection::Schema* actions_entity_data_schema =
- context->model()->entity_data_schema();
- if (actions_entity_data_schema != nullptr &&
- !action_result[i].serialized_entity_data.empty()) {
- extras = context->template_handler()->EntityDataAsNamedVariantArray(
- actions_entity_data_schema, action_result[i].serialized_entity_data);
- }
-
- jbyteArray serialized_entity_data = nullptr;
- if (!action_result[i].serialized_entity_data.empty()) {
- serialized_entity_data =
- env->NewByteArray(action_result[i].serialized_entity_data.size());
- env->SetByteArrayRegion(
- serialized_entity_data, 0,
- action_result[i].serialized_entity_data.size(),
- reinterpret_cast<const jbyte*>(
- action_result[i].serialized_entity_data.data()));
- }
-
- jobject remote_action_templates_result = nullptr;
- if (generate_intents) {
- std::vector<RemoteActionTemplate> remote_action_templates;
- if (context->intent_generator()->GenerateIntents(
- device_locales, action_result[i], conversation, app_context,
- actions_entity_data_schema, annotations_entity_data_schema,
- &remote_action_templates)) {
- remote_action_templates_result =
- context->template_handler()->RemoteActionTemplatesToJObjectArray(
- remote_action_templates);
- }
- }
-
- ScopedLocalRef<jstring> reply = context->jni_cache()->ConvertToJavaString(
- action_result[i].response_text);
-
- ScopedLocalRef<jobject> result(env->NewObject(
- result_class.get(), result_class_constructor, reply.get(),
- env->NewStringUTF(action_result[i].type.c_str()),
- static_cast<jfloat>(action_result[i].score), extras,
- serialized_entity_data, remote_action_templates_result));
- env->SetObjectArrayElement(results, i, result.get());
- }
- return results;
-}
-
-ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) {
- if (!jmessage) {
- return {};
- }
-
- const ScopedLocalRef<jclass> message_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
- "$ConversationMessage"),
- env);
- const std::pair<bool, jobject> status_or_text = CallJniMethod0<jobject>(
- env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getText",
- "Ljava/lang/String;");
- const std::pair<bool, int32> status_or_user_id =
- CallJniMethod0<int32>(env, jmessage, message_class.get(),
- &JNIEnv::CallIntMethod, "getUserId", "I");
- const std::pair<bool, int64> status_or_reference_time = CallJniMethod0<int64>(
- env, jmessage, message_class.get(), &JNIEnv::CallLongMethod,
- "getReferenceTimeMsUtc", "J");
- const std::pair<bool, jobject> status_or_reference_timezone =
- CallJniMethod0<jobject>(env, jmessage, message_class.get(),
- &JNIEnv::CallObjectMethod, "getReferenceTimezone",
- "Ljava/lang/String;");
- const std::pair<bool, jobject> status_or_detected_text_language_tags =
- CallJniMethod0<jobject>(
- env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod,
- "getDetectedTextLanguageTags", "Ljava/lang/String;");
- if (!status_or_text.first || !status_or_user_id.first ||
- !status_or_detected_text_language_tags.first ||
- !status_or_reference_time.first || !status_or_reference_timezone.first) {
- return {};
- }
-
- ConversationMessage message;
- message.text = ToStlString(env, static_cast<jstring>(status_or_text.second));
- message.user_id = status_or_user_id.second;
- message.reference_time_ms_utc = status_or_reference_time.second;
- message.reference_timezone = ToStlString(
- env, static_cast<jstring>(status_or_reference_timezone.second));
- message.detected_text_language_tags = ToStlString(
- env, static_cast<jstring>(status_or_detected_text_language_tags.second));
- return message;
-}
-
-Conversation FromJavaConversation(JNIEnv* env, jobject jconversation) {
- if (!jconversation) {
- return {};
- }
-
- const ScopedLocalRef<jclass> conversation_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
- "$Conversation"),
- env);
-
- const std::pair<bool, jobject> status_or_messages = CallJniMethod0<jobject>(
- env, jconversation, conversation_class.get(), &JNIEnv::CallObjectMethod,
- "getConversationMessages",
- "[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ConversationMessage;");
-
- if (!status_or_messages.first) {
- return {};
- }
-
- const jobjectArray jmessages =
- reinterpret_cast<jobjectArray>(status_or_messages.second);
-
- const int size = env->GetArrayLength(jmessages);
-
- std::vector<ConversationMessage> messages;
- for (int i = 0; i < size; i++) {
- jobject jmessage = env->GetObjectArrayElement(jmessages, i);
- ConversationMessage message = FromJavaConversationMessage(env, jmessage);
- messages.push_back(message);
- }
- Conversation conversation;
- conversation.messages = messages;
- return conversation;
-}
-
-jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
- if (!mmap->handle().ok()) {
- return env->NewStringUTF("");
- }
- const ActionsModel* model = libtextclassifier3::ViewActionsModel(
- mmap->handle().start(), mmap->handle().num_bytes());
- if (!model || !model->locales()) {
- return env->NewStringUTF("");
- }
- return env->NewStringUTF(model->locales()->c_str());
-}
-
-jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
- if (!mmap->handle().ok()) {
- return 0;
- }
- const ActionsModel* model = libtextclassifier3::ViewActionsModel(
- mmap->handle().start(), mmap->handle().num_bytes());
- if (!model) {
- return 0;
- }
- return model->version();
-}
-
-jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
- if (!mmap->handle().ok()) {
- return env->NewStringUTF("");
- }
- const ActionsModel* model = libtextclassifier3::ViewActionsModel(
- mmap->handle().start(), mmap->handle().num_bytes());
- if (!model || !model->name()) {
- return env->NewStringUTF("");
- }
- return env->NewStringUTF(model->name()->c_str());
-}
-} // namespace
-} // namespace libtextclassifier3
-
-using libtextclassifier3::ActionsSuggestionsJniContext;
-using libtextclassifier3::ActionSuggestionsToJObjectArray;
-using libtextclassifier3::FromJavaActionSuggestionOptions;
-using libtextclassifier3::FromJavaConversation;
-
-TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
-(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions) {
- std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
- libtextclassifier3::JniCache::Create(env);
- std::string preconditions;
- if (serialized_preconditions != nullptr &&
- !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
- &preconditions)) {
- TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
- return 0;
- }
-#ifdef TC3_UNILIB_JAVAICU
- return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
- jni_cache,
- ActionsSuggestions::FromFileDescriptor(
- fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions)));
-#else
- return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
- jni_cache, ActionsSuggestions::FromFileDescriptor(fd, /*unilib=*/nullptr,
- preconditions)));
-#endif // TC3_UNILIB_JAVAICU
-}
-
-TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
-(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions) {
- std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
- libtextclassifier3::JniCache::Create(env);
- const std::string path_str = ToStlString(env, path);
- std::string preconditions;
- if (serialized_preconditions != nullptr &&
- !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
- &preconditions)) {
- TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
- return 0;
- }
-#ifdef TC3_UNILIB_JAVAICU
- return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
- jni_cache, ActionsSuggestions::FromPath(
- path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
- preconditions)));
-#else
- return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
- jni_cache, ActionsSuggestions::FromPath(path_str, /*unilib=*/nullptr,
- preconditions)));
-#endif // TC3_UNILIB_JAVAICU
-}
-
-TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
-(JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions,
- jlong annotatorPtr, jobject app_context, jstring device_locales,
- jboolean generate_intents) {
- if (!ptr) {
- return nullptr;
- }
- const Conversation conversation = FromJavaConversation(env, jconversation);
- const ActionSuggestionOptions options =
- FromJavaActionSuggestionOptions(env, joptions);
- const ActionsSuggestionsJniContext* context =
- reinterpret_cast<ActionsSuggestionsJniContext*>(ptr);
- const Annotator* annotator = reinterpret_cast<Annotator*>(annotatorPtr);
-
- const ActionsSuggestionsResponse response =
- context->model()->SuggestActions(conversation, annotator, options);
-
- const reflection::Schema* anntotations_entity_data_schema =
- annotator ? annotator->entity_data_schema() : nullptr;
- return ActionSuggestionsToJObjectArray(
- env, context, app_context, anntotations_entity_data_schema,
- response.actions, conversation, device_locales, generate_intents);
-}
-
-TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
-(JNIEnv* env, jobject clazz, jlong model_ptr) {
- const ActionsSuggestionsJniContext* context =
- reinterpret_cast<ActionsSuggestionsJniContext*>(model_ptr);
- delete context;
-}
-
-TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
-(JNIEnv* env, jobject clazz, jint fd) {
- const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
- new libtextclassifier3::ScopedMmap(fd));
- return libtextclassifier3::GetLocalesFromMmap(env, mmap.get());
-}
-
-TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
-(JNIEnv* env, jobject clazz, jint fd) {
- const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
- new libtextclassifier3::ScopedMmap(fd));
- return libtextclassifier3::GetNameFromMmap(env, mmap.get());
-}
-
-TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
-(JNIEnv* env, jobject clazz, jint fd) {
- const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
- new libtextclassifier3::ScopedMmap(fd));
- return libtextclassifier3::GetVersionFromMmap(env, mmap.get());
-}
diff --git a/actions/actions_jni.h b/actions/actions_jni.h
deleted file mode 100644
index fe2b998..0000000
--- a/actions/actions_jni.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_JNI_H_
-#define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_JNI_H_
-
-#include <jni.h>
-#include <string>
-#include "utils/java/jni-base.h"
-
-#ifndef TC3_ACTIONS_CLASS_NAME
-#define TC3_ACTIONS_CLASS_NAME ActionsSuggestionsModel
-#endif
-
-#define TC3_ACTIONS_CLASS_NAME_STR TC3_ADD_QUOTES(TC3_ACTIONS_CLASS_NAME)
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
-(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions);
-
-TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
-(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions);
-
-TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
-(JNIEnv* env, jobject thiz, jlong ptr, jobject jconversation, jobject joptions,
- jlong annotatorPtr, jobject app_context, jstring device_locales,
- jboolean generate_intents);
-
-TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
-(JNIEnv* env, jobject thiz, jlong ptr);
-
-TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
-(JNIEnv* env, jobject clazz, jint fd);
-
-TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
-(JNIEnv* env, jobject clazz, jint fd);
-
-TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
-(JNIEnv* env, jobject clazz, jint fd);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_JNI_H_
diff --git a/actions/actions_model.fbs b/actions/actions_model.fbs
deleted file mode 100755
index 42c7d88..0000000
--- a/actions/actions_model.fbs
+++ /dev/null
@@ -1,480 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-include "annotator/model.fbs";
-include "utils/codepoint-range.fbs";
-include "utils/flatbuffers.fbs";
-include "utils/intents/intent-config.fbs";
-include "utils/resources.fbs";
-include "utils/tokenizer.fbs";
-include "utils/zlib/buffer.fbs";
-
-file_identifier "TC3A";
-
-// TensorFlow Lite model for suggesting actions.
-namespace libtextclassifier3;
-table TensorflowLiteModelSpec {
- // TensorFlow Lite model for suggesting actions.
- tflite_model:[ubyte] (force_align: 16);
-
- // Input specification.
- // (num messages,) int32 tensor, the user id per message.
- input_user_id:int = 0;
-
- // (num messages,) string tensor, each message of the conversation.
- input_context:int = 1;
-
- // int, the number of messages in the conversation.
- input_context_length:int = 2;
-
- // (num messages,) float tensor, the time difference in seconds of the
- // messages in the conversation.
- input_time_diffs:int = 3;
-
- // int, the number of smart replies to produce.
- input_num_suggestions:int = 4;
-
- // float, the output diversification distance parameter.
- input_diversification_distance:int = -1;
-
- // float, the empirical probability factor parameter.
- input_empirical_probability_factor:int = -1;
-
- // float, the confidence threshold.
- input_confidence_threshold:int = -1;
-
- // Input port for hashed and embedded tokens, a (num messages, max tokens,
- // embedding size) float tensor specifying the embeddings of each token of
- // each message in the conversation.
- input_token_embeddings:int = -1;
-
- // Input port for the number of tokens per message.
- // (num messages) int32 tensor specifying the number of tokens in each message
- // in the conversation.
- input_num_tokens:int = -1;
-
- // Output specification.
- output_replies:int = 0;
-
- output_replies_scores:int = 1;
- output_sensitive_topic_score:int = 3;
- output_triggering_score:int = 4;
- output_actions_scores:int = 5;
-
- // Model setup.
- // When true, the inputs are resized to the concrete input sizes before
- // inference otherwise, it's assumed that the model has the correct input
- // shapes set.
- resize_inputs:bool = false;
-
- // Input port for the hashed, embedded and flattened/concatenated tokens.
- // A (max tokens, embedding_size) float tensor specifying the embeddings of
- // each token.
- input_flattened_token_embeddings:int = -1;
-}
-
-// Configuration for the tokenizer.
-namespace libtextclassifier3;
-table ActionsTokenizerOptions {
- type:TokenizationType = INTERNAL_TOKENIZER;
-
- // If true, white space tokens will be kept when using the icu tokenizer.
- icu_preserve_whitespace_tokens:bool = false;
-
- // Codepoint ranges that determine what role the different codepoints play
- // during tokenized. The ranges must not overlap.
- tokenization_codepoint_config:[TokenizationCodepointRange];
-
- // A set of codepoint ranges to use in the mixed tokenization mode to identify
- // stretches of tokens to re-tokenize using the internal tokenizer.
- internal_tokenizer_codepoint_ranges:[CodepointRange];
-
- // If true, tokens will be also split when the codepoint's script_id changes
- // as defined in TokenizationCodepointRange.
- tokenize_on_script_change:bool = false;
-}
-
-// Configuration for the feature processor.
-namespace libtextclassifier3;
-table ActionsTokenFeatureProcessorOptions {
- // Tokenizer options.
- tokenizer_options:ActionsTokenizerOptions;
-
- // Serialized TensorFlow Lite model with weights for the token embeddings.
- embedding_model:[ubyte] (force_align: 16);
-
- // Size of the embedding.
- embedding_size:int = -1;
-
- // Number of bits for quantization for embeddings.
- embedding_quantization_bits:int = 8;
-
- // Number of buckets used for hashing charactergrams.
- num_buckets:int = -1;
-
- // Orders of charactergrams to extract, e.g. 2 means character bigrams, 3
- // character trigrams etc.
- chargram_orders:[int];
-
- // Whether to extract the token case feature.
- extract_case_feature:bool;
-
- // If true, will use the unicode-aware functionality for extracting features.
- unicode_aware_features:bool;
-
- // Regexp features to extract.
- regexp_features:[string];
-
- // Whether to remap digits to a single number.
- remap_digits:bool;
-
- // Whether to lowercase all tokens.
- lowercase_tokens:bool;
-
- // Maximum length of a word.
- max_token_length:int = 20;
-
- // The `max_num_tokens_per_message` and `min_num_tokens_per_message` are
- // applied when tokens are embedded per message.
- // If set and the number of tokens of a message is bigger than this limit,
- // tokens at the beginning of the message are dropped to fit the limit.
- max_num_tokens_per_message:int = -1;
-
- // If set, the tokens of each message will be padded to this fixed number of
- // tokens.
- min_num_tokens_per_message:int = -1;
-
- // If set and the total number of concatenated tokens is bigger than this
- // limit, tokens at the start of the conversation are dropped.
- max_num_total_tokens:int = -1;
-
- // If set and the total number of concatenaed tokens is smaller than this
- // limit, the conversation is padded with padding tokens.
- min_num_total_tokens:int = -1;
-
- // Id that is used as encoding of the padding token.
- padding_token_id:int = 0;
-
- // Id that is used as encoding of the start of message token.
- start_token_id:int = 1;
-
- // Id that is used as encoding of the end of message token.
- end_token_id:int = 2;
-}
-
-// N-Gram based linear regression model.
-namespace libtextclassifier3;
-table NGramLinearRegressionModel {
- // A flat list of all the hashed n-grams concatenated back to back. Elements
- // should only ever be accessed via the offset table below.
- hashed_ngram_tokens:[uint];
-
- // Offsets to the start of the n-grams in hashed_ngram_tokens. The last
- // element in this array is the length of hashed_ngrams to make it easier to
- // compute n-gram lengths.
- ngram_start_offsets:[ushort];
-
- // Weights of the n-grams.
- ngram_weights:[float];
-
- // The default weight assigned to n-grams that weren't matched.
- default_token_weight:float;
-
- // Maximum n-gram length to consider when calculating the denominatior.
- // This should usually be the same as max_ngram_length but can diverge
- // if additional (longer) n-grams are added to a model as part of a minor
- // update.
- max_denom_ngram_length:int;
-
- // If non-zero, the order of the skip-gram to match.
- max_skips:int;
-
- // The threshold above which the model output is considered positive.
- threshold:float;
-
- // Model specific tokenizer options.
- // If not specified, will reuse the feature processor tokenizer.
- tokenizer_options:ActionsTokenizerOptions;
-}
-
-namespace libtextclassifier3;
-table TriggeringPreconditions {
- // Lower bound thresholds for the smart reply model prediction output.
- min_smart_reply_triggering_score:float;
-
- // Maximum sensitive score for which actions and smart replies are shown.
- max_sensitive_topic_score:float = 1;
-
- // Whether to suppress all model output when a conversation is classified as
- // sensitive.
- suppress_on_sensitive_topic:bool = true;
-
- // Thresholds on the model prediction input.
- // The minimal length of input to consider for prediction.
- min_input_length:int = 0;
-
- // The maximal length of input to consider for prediciton, -1 if unbounded.
- max_input_length:int = -1;
-
- // Minimal fraction of messages in the input conversation that need to match
- // a locale that the model can handle.
- min_locale_match_fraction:float = 0.75;
-
- handle_missing_locale_as_supported:bool = false;
- handle_unknown_locale_as_supported:bool = false;
-
- // Filter input with low-confidence triggers.
- suppress_on_low_confidence_input:bool = true;
-
- // Same as low_confidence_rules in ActionsModel.
- // NOTE: Only fill this when the TriggeringPreconditions are pushed separately
- // as a flag value (i.e. as overlay).
- low_confidence_rules:RulesModel;
-
- // Smart reply thresholds.
- diversification_distance_threshold:float = 0;
-
- confidence_threshold:float = 0;
- empirical_probability_factor:float = 0;
- min_reply_score_threshold:float = 0;
-}
-
-namespace libtextclassifier3;
-table ActionSuggestionSpec {
- // Type of the action suggestion.
- type:string;
-
- // Text of a smart reply action.
- response_text:string;
-
- // Score.
- score:float;
-
- // Serialized entity information.
- serialized_entity_data:string;
-
- // Priority score used for internal conflict resolution.
- priority_score:float = 0;
-}
-
-// Options to specify triggering behaviour per action class.
-namespace libtextclassifier3;
-table ActionTypeOptions {
- // The name of the predicted action.
- name:string;
-
- // Triggering behaviour.
- // Whether the action class is considered in the model output or not.
- enabled:bool = true;
-
- // Minimal output score threshold.
- min_triggering_score:float = 0;
-
- // The action to trigger.
- action:ActionSuggestionSpec;
-}
-
-namespace libtextclassifier3.AnnotationActionsSpec_;
-table AnnotationMapping {
- // The annotation collection.
- annotation_collection:string;
-
- // The action name to use.
- action:ActionSuggestionSpec;
-
- // Whether to use the score of the annotation as the action score.
- use_annotation_score:bool = true;
-
- // Minimum threshold for the annotation score for filtering.
- min_annotation_score:float;
-
- // If set, the text of the annotation will be used to set a field in the
- // action entity data.
- entity_field:FlatbufferFieldPath;
-}
-
-// Configuration for actions based on annotatations.
-namespace libtextclassifier3;
-table AnnotationActionsSpec {
- annotation_mapping:[AnnotationActionsSpec_.AnnotationMapping];
-
- // Whether to deduplicate annotations by type and text prior to generating
- // actions.
- deduplicate_annotations:bool = true;
-
- // Annotation usecase to specify for text annotation.
- annotation_usecase:AnnotationUsecase = ANNOTATION_USECASE_SMART;
-
- // Maximum number of recent messages to consider from any person.
- // We consider at most `max_history_from_any_person` many recent messages if
- // they were received from different users or at most the maximum of this and
- // `max_history_from_last_person` if they are all from the same user.
- max_history_from_any_person:int = 1;
-
- // Maximum number of recent messages to consider from the last person.
- max_history_from_last_person:int = 1;
-
- // Whether to include messages from the local user.
- include_local_user_messages:bool = false;
-
- // Whether to only consider messages up to the last one sent by the local
- // user.
- only_until_last_sent:bool = true;
-
- // If true, annotator would populare serialized_entity_data in the results.
- is_serialized_entity_data_enabled:bool = true;
-}
-
-// Ranking options.
-namespace libtextclassifier3;
-table RankingOptions {
- // When true, actions suggestions are deduplicated by `type`, `response_text`
- // and associated annotations, keeping the higher scoring actions.
- deduplicate_suggestions:bool = true;
-
- // When true, actions are deduplicated by the span they are referring to.
- deduplicate_suggestions_by_span:bool = true;
-
- // Optional script to run for ranking and filtering the action suggestions.
- // The following global variables are available to the script:
- // * input: (optionally deduplicated) action suggestions, via the `actions`
- // global
- // * output: indices of the actions to keep in the provided order.
- lua_ranking_script:string;
-
- compressed_lua_ranking_script:CompressedBuffer;
-
- // If true, suppresses smart replies if other smart actions are suggested.
- suppress_smart_replies_with_actions:bool = false;
-
- // If true, keep actions from the same entities together for ranking.
- group_by_annotations:bool = true;
-}
-
-// Entity data to set from capturing groups.
-namespace libtextclassifier3.RulesModel_.Rule_.RuleActionSpec_;
-table RuleCapturingGroup {
- // The id of group.
- group_id:int;
-
- // If set, the text of the capturing group will be used to set a field
- // in the action entity data.
- entity_field:FlatbufferFieldPath;
-
- // If set, the capturing group will be used to create a text annotation
- // with the given name and type.
- annotation_type:string;
-
- annotation_name:string;
-
- // If set, the capturing group text will be used to create a text
- // reply.
- text_reply:ActionSuggestionSpec;
-}
-
-// The actions to produce upon triggering.
-namespace libtextclassifier3.RulesModel_.Rule_;
-table RuleActionSpec {
- // The action.
- action:ActionSuggestionSpec;
-
- capturing_group:[RuleActionSpec_.RuleCapturingGroup];
-}
-
-// List of regular expression matchers.
-namespace libtextclassifier3.RulesModel_;
-table Rule {
- // The regular expression pattern.
- pattern:string;
-
- compressed_pattern:CompressedBuffer;
- actions:[Rule_.RuleActionSpec];
-
- // Patterns for post-checking the outputs.
- output_pattern:string;
-
- compressed_output_pattern:CompressedBuffer;
-}
-
-// Rule based actions.
-namespace libtextclassifier3;
-table RulesModel {
- rule:[RulesModel_.Rule];
-
- // If true, will compile the regexes only on first use.
- lazy_regex_compilation:bool = true;
-}
-
-namespace libtextclassifier3;
-table ActionsModel {
- // Comma-separated list of locales supported by the model as BCP 47 tags.
- locales:string;
-
- // Version of the actions model.
- version:int;
-
- // A name for the model that can be used e.g. for logging.
- name:string;
-
- tflite_model_spec:TensorflowLiteModelSpec;
-
- // Output classes.
- smart_reply_action_type:string;
-
- action_type:[ActionTypeOptions];
-
- // Triggering conditions of the model.
- preconditions:TriggeringPreconditions;
-
- // Default number of smart reply predictions.
- num_smart_replies:int = 3;
-
- // Length of message history to consider, -1 if unbounded.
- max_conversation_history_length:int = 1;
-
- // Configuration for mapping annotations to action suggestions.
- annotation_actions_spec:AnnotationActionsSpec;
-
- // Configuration for rules.
- rules:RulesModel;
-
- // Configuration for intent generation on Android.
- android_intent_options:IntentFactoryModel;
-
- // Model resources.
- resources:ResourcePool;
-
- // Schema data for handling entity data.
- actions_entity_data_schema:[ubyte];
-
- // Action ranking options.
- ranking_options:RankingOptions;
-
- // Lua based actions.
- lua_actions_script:string;
-
- compressed_lua_actions_script:CompressedBuffer;
-
- // Low confidence classifiers.
- low_confidence_rules:RulesModel;
-
- low_confidence_ngram_model:NGramLinearRegressionModel;
-
- // Feature processor options.
- feature_processor_options:ActionsTokenFeatureProcessorOptions;
-}
-
-root_type libtextclassifier3.ActionsModel;
diff --git a/actions/feature-processor.cc b/actions/feature-processor.cc
deleted file mode 100644
index d0b2072..0000000
--- a/actions/feature-processor.cc
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/feature-processor.h"
-
-namespace libtextclassifier3 {
-namespace {
-TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
- const ActionsTokenFeatureProcessorOptions* const options) {
- TokenFeatureExtractorOptions extractor_options;
- extractor_options.num_buckets = options->num_buckets();
- if (options->chargram_orders() != nullptr) {
- for (int order : *options->chargram_orders()) {
- extractor_options.chargram_orders.push_back(order);
- }
- }
- extractor_options.max_word_length = options->max_token_length();
- extractor_options.extract_case_feature = options->extract_case_feature();
- extractor_options.unicode_aware_features = options->unicode_aware_features();
- extractor_options.extract_selection_mask_feature = false;
- if (options->regexp_features() != nullptr) {
- for (const auto& regexp_feauture : *options->regexp_features()) {
- extractor_options.regexp_features.push_back(regexp_feauture->str());
- }
- }
- extractor_options.remap_digits = options->remap_digits();
- extractor_options.lowercase_tokens = options->lowercase_tokens();
- return extractor_options;
-}
-} // namespace
-
-std::unique_ptr<Tokenizer> CreateTokenizer(
- const ActionsTokenizerOptions* options, const UniLib* unilib) {
- std::vector<const TokenizationCodepointRange*> codepoint_config;
- if (options->tokenization_codepoint_config() != nullptr) {
- codepoint_config.insert(codepoint_config.end(),
- options->tokenization_codepoint_config()->begin(),
- options->tokenization_codepoint_config()->end());
- }
- std::vector<const CodepointRange*> internal_codepoint_config;
- if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
- internal_codepoint_config.insert(
- internal_codepoint_config.end(),
- options->internal_tokenizer_codepoint_ranges()->begin(),
- options->internal_tokenizer_codepoint_ranges()->end());
- }
- const bool tokenize_on_script_change =
- options->tokenization_codepoint_config() != nullptr &&
- options->tokenize_on_script_change();
- return std::unique_ptr<Tokenizer>(new Tokenizer(
- options->type(), unilib, codepoint_config, internal_codepoint_config,
- tokenize_on_script_change, options->icu_preserve_whitespace_tokens()));
-}
-
-ActionsFeatureProcessor::ActionsFeatureProcessor(
- const ActionsTokenFeatureProcessorOptions* options, const UniLib* unilib)
- : options_(options),
- tokenizer_(CreateTokenizer(options->tokenizer_options(), unilib)),
- token_feature_extractor_(BuildTokenFeatureExtractorOptions(options),
- *unilib) {}
-
-int ActionsFeatureProcessor::GetTokenEmbeddingSize() const {
- return options_->embedding_size() +
- token_feature_extractor_.DenseFeaturesCount();
-}
-
-bool ActionsFeatureProcessor::AppendFeatures(
- const std::vector<int>& sparse_features,
- const std::vector<float>& dense_features,
- const EmbeddingExecutor* embedding_executor,
- std::vector<float>* output_features) const {
- // Embed the sparse features, appending them directly to the output.
- const int embedding_size = options_->embedding_size();
- output_features->resize(output_features->size() + embedding_size);
- float* output_features_end =
- output_features->data() + output_features->size();
- if (!embedding_executor->AddEmbedding(
- TensorView<int>(sparse_features.data(),
- {static_cast<int>(sparse_features.size())}),
- /*dest=*/output_features_end - embedding_size,
- /*dest_size=*/embedding_size)) {
- TC3_LOG(ERROR) << "Could not embed token's sparse features.";
- return false;
- }
-
- // Append the dense features to the output.
- output_features->insert(output_features->end(), dense_features.begin(),
- dense_features.end());
- return true;
-}
-
-bool ActionsFeatureProcessor::AppendTokenFeatures(
- const Token& token, const EmbeddingExecutor* embedding_executor,
- std::vector<float>* output_features) const {
- // Extract the sparse and dense features.
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- if (!token_feature_extractor_.Extract(token, /*(unused) is_in_span=*/false,
- &sparse_features, &dense_features)) {
- TC3_LOG(ERROR) << "Could not extract token's features.";
- return false;
- }
- return AppendFeatures(sparse_features, dense_features, embedding_executor,
- output_features);
-}
-
-bool ActionsFeatureProcessor::AppendTokenFeatures(
- const std::vector<Token>& tokens,
- const EmbeddingExecutor* embedding_executor,
- std::vector<float>* output_features) const {
- for (const Token& token : tokens) {
- if (!AppendTokenFeatures(token, embedding_executor, output_features)) {
- return false;
- }
- }
- return true;
-}
-
-} // namespace libtextclassifier3
diff --git a/actions/feature-processor.h b/actions/feature-processor.h
deleted file mode 100644
index e34ccff..0000000
--- a/actions/feature-processor.h
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ACTIONS_FEATURE_PROCESSOR_H_
-#define LIBTEXTCLASSIFIER_ACTIONS_FEATURE_PROCESSOR_H_
-
-#include <memory>
-
-#include "actions/actions_model_generated.h"
-#include "annotator/model-executor.h"
-#include "annotator/types.h"
-#include "utils/token-feature-extractor.h"
-#include "utils/tokenizer.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3 {
-
-// Create tokenizer from options.
-std::unique_ptr<Tokenizer> CreateTokenizer(
- const ActionsTokenizerOptions* options, const UniLib* unilib);
-
-// Feature processor for the actions suggestions model.
-class ActionsFeatureProcessor {
- public:
- ActionsFeatureProcessor(const ActionsTokenFeatureProcessorOptions* options,
- const UniLib* unilib);
-
- // Embeds and appends features to the output vector.
- bool AppendFeatures(const std::vector<int>& sparse_features,
- const std::vector<float>& dense_features,
- const EmbeddingExecutor* embedding_executor,
- std::vector<float>* output_features) const;
-
- // Extracts the features of a token and appends them to the output vector.
- bool AppendTokenFeatures(const Token& token,
- const EmbeddingExecutor* embedding_executor,
- std::vector<float>* output_features) const;
-
- // Extracts the features of a vector of tokens and appends each to the output
- // vector.
- bool AppendTokenFeatures(const std::vector<Token>& tokens,
- const EmbeddingExecutor* embedding_executor,
- std::vector<float>* output_features) const;
-
- int GetTokenEmbeddingSize() const;
-
- const Tokenizer* tokenizer() const { return tokenizer_.get(); }
-
- private:
- const ActionsTokenFeatureProcessorOptions* options_;
- const std::unique_ptr<Tokenizer> tokenizer_;
- const TokenFeatureExtractor token_feature_extractor_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ACTIONS_FEATURE_PROCESSOR_H_
diff --git a/actions/feature-processor_test.cc b/actions/feature-processor_test.cc
deleted file mode 100644
index 0a1e3ac..0000000
--- a/actions/feature-processor_test.cc
+++ /dev/null
@@ -1,130 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/feature-processor.h"
-
-#include "actions/actions_model_generated.h"
-#include "annotator/model-executor.h"
-#include "utils/tensor-view.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::FloatEq;
-
-// EmbeddingExecutor that always returns features based on
-// the id of the sparse features.
-class FakeEmbeddingExecutor : public EmbeddingExecutor {
- public:
- bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
- const int dest_size) const override {
- TC3_CHECK_GE(dest_size, 4);
- EXPECT_EQ(sparse_features.size(), 1);
- dest[0] = sparse_features.data()[0];
- dest[1] = sparse_features.data()[0];
- dest[2] = -sparse_features.data()[0];
- dest[3] = -sparse_features.data()[0];
- return true;
- }
-
- private:
- std::vector<float> storage_;
-};
-
-class FeatureProcessorTest : public ::testing::Test {
- protected:
- FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
-
- flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
- ActionsTokenFeatureProcessorOptionsT* options) const {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(CreateActionsTokenFeatureProcessorOptions(builder, options));
- return builder.Release();
- }
-
- FakeEmbeddingExecutor embedding_executor_;
- UniLib unilib_;
-};
-
-TEST_F(FeatureProcessorTest, TokenEmbeddings) {
- ActionsTokenFeatureProcessorOptionsT options;
- options.embedding_size = 4;
- options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
-
- flatbuffers::DetachedBuffer options_fb =
- PackFeatureProcessorOptions(&options);
- ActionsFeatureProcessor feature_processor(
- flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
- options_fb.data()),
- &unilib_);
-
- Token token("aaa", 0, 3);
- std::vector<float> token_features;
- EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
- &token_features));
- EXPECT_EQ(token_features.size(), 4);
-}
-
-TEST_F(FeatureProcessorTest, TokenEmbeddingsCaseFeature) {
- ActionsTokenFeatureProcessorOptionsT options;
- options.embedding_size = 4;
- options.extract_case_feature = true;
- options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
-
- flatbuffers::DetachedBuffer options_fb =
- PackFeatureProcessorOptions(&options);
- ActionsFeatureProcessor feature_processor(
- flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
- options_fb.data()),
- &unilib_);
-
- Token token("Aaa", 0, 3);
- std::vector<float> token_features;
- EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
- &token_features));
- EXPECT_EQ(token_features.size(), 5);
- EXPECT_THAT(token_features[4], FloatEq(1.0));
-}
-
-TEST_F(FeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
- ActionsTokenFeatureProcessorOptionsT options;
- options.embedding_size = 4;
- options.extract_case_feature = true;
- options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
-
- flatbuffers::DetachedBuffer options_fb =
- PackFeatureProcessorOptions(&options);
- ActionsFeatureProcessor feature_processor(
- flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
- options_fb.data()),
- &unilib_);
-
- const std::vector<Token> tokens = {Token("Aaa", 0, 3), Token("bbb", 4, 7),
- Token("Cccc", 8, 12)};
- std::vector<float> token_features;
- EXPECT_TRUE(feature_processor.AppendTokenFeatures(
- tokens, &embedding_executor_, &token_features));
- EXPECT_EQ(token_features.size(), 15);
- EXPECT_THAT(token_features[4], FloatEq(1.0));
- EXPECT_THAT(token_features[9], FloatEq(-1.0));
- EXPECT_THAT(token_features[14], FloatEq(1.0));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/actions/lua-actions.cc b/actions/lua-actions.cc
deleted file mode 100644
index 5bbba98..0000000
--- a/actions/lua-actions.cc
+++ /dev/null
@@ -1,164 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/lua-actions.h"
-#include "utils/base/logging.h"
-#include "utils/lua-utils.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-#include "lauxlib.h"
-#include "lualib.h"
-#ifdef __cplusplus
-}
-#endif
-
-namespace libtextclassifier3 {
-namespace {
-TensorView<float> GetTensorViewForOutput(
- const TfLiteModelExecutor* model_executor,
- const tflite::Interpreter* interpreter, int output) {
- if (output < 0 || model_executor == nullptr || interpreter == nullptr) {
- return TensorView<float>::Invalid();
- }
- return model_executor->OutputView<float>(output, interpreter);
-}
-} // namespace
-
-int LuaActionsSuggestions::TensorViewIterator::Item(
- const TensorView<float>* tensor, const int64 index,
- lua_State* state) const {
- lua_pushnumber(state, tensor->data()[index]);
- return 1;
-}
-
-std::unique_ptr<LuaActionsSuggestions>
-LuaActionsSuggestions::CreateLuaActionsSuggestions(
- const std::string& snippet, const Conversation& conversation,
- const TfLiteModelExecutor* model_executor,
- const TensorflowLiteModelSpec* model_spec,
- const tflite::Interpreter* interpreter,
- const reflection::Schema* actions_entity_data_schema,
- const reflection::Schema* annotations_entity_data_schema) {
- auto lua_actions =
- std::unique_ptr<LuaActionsSuggestions>(new LuaActionsSuggestions(
- snippet, conversation, model_executor, model_spec, interpreter,
- actions_entity_data_schema, annotations_entity_data_schema));
- if (!lua_actions->Initialize()) {
- TC3_LOG(ERROR)
- << "Could not initialize lua environment for actions suggestions.";
- return nullptr;
- }
- return lua_actions;
-}
-
-LuaActionsSuggestions::LuaActionsSuggestions(
- const std::string& snippet, const Conversation& conversation,
- const TfLiteModelExecutor* model_executor,
- const TensorflowLiteModelSpec* model_spec,
- const tflite::Interpreter* interpreter,
- const reflection::Schema* actions_entity_data_schema,
- const reflection::Schema* annotations_entity_data_schema)
- : snippet_(snippet),
- conversation_(conversation),
- conversation_iterator_(annotations_entity_data_schema, this),
- actions_scores_(
- model_spec == nullptr
- ? TensorView<float>::Invalid()
- : GetTensorViewForOutput(model_executor, interpreter,
- model_spec->output_actions_scores())),
- smart_reply_scores_(
- model_spec == nullptr
- ? TensorView<float>::Invalid()
- : GetTensorViewForOutput(model_executor, interpreter,
- model_spec->output_replies_scores())),
- sensitivity_score_(model_spec == nullptr
- ? TensorView<float>::Invalid()
- : GetTensorViewForOutput(
- model_executor, interpreter,
- model_spec->output_sensitive_topic_score())),
- triggering_score_(
- model_spec == nullptr
- ? TensorView<float>::Invalid()
- : GetTensorViewForOutput(model_executor, interpreter,
- model_spec->output_triggering_score())),
- actions_entity_data_schema_(actions_entity_data_schema),
- annotations_entity_data_schema_(annotations_entity_data_schema) {}
-
-bool LuaActionsSuggestions::Initialize() {
- return RunProtected([this] {
- LoadDefaultLibraries();
-
- // Expose conversation message stream.
- conversation_iterator_.NewIterator("messages",
- &conversation_.messages, state_);
- lua_setglobal(state_, "messages");
-
- // Expose ML model output.
- lua_newtable(state_);
- {
- tensor_iterator_.NewIterator("actions_scores", &actions_scores_,
- state_);
- lua_setfield(state_, /*idx=*/-2, "actions_scores");
- }
- {
- tensor_iterator_.NewIterator("reply_scores", &smart_reply_scores_,
- state_);
- lua_setfield(state_, /*idx=*/-2, "reply_scores");
- }
- {
- tensor_iterator_.NewIterator("sensitivity", &sensitivity_score_,
- state_);
- lua_setfield(state_, /*idx=*/-2, "sensitivity");
- }
- {
- tensor_iterator_.NewIterator("triggering_score",
- &triggering_score_, state_);
- lua_setfield(state_, /*idx=*/-2, "triggering_score");
- }
- lua_setglobal(state_, "model");
-
- return LUA_OK;
- }) == LUA_OK;
-}
-
-bool LuaActionsSuggestions::SuggestActions(
- std::vector<ActionSuggestion>* actions) {
- if (luaL_loadbuffer(state_, snippet_.data(), snippet_.size(),
- /*name=*/nullptr) != LUA_OK) {
- TC3_LOG(ERROR) << "Could not load actions suggestions snippet.";
- return false;
- }
-
- if (lua_pcall(state_, /*nargs=*/0, /*nargs=*/1, /*errfunc=*/0) != LUA_OK) {
- TC3_LOG(ERROR) << "Could not run actions suggestions snippet.";
- return false;
- }
-
- if (RunProtected(
- [this, actions] {
- return ReadActions(actions_entity_data_schema_,
- annotations_entity_data_schema_, this, actions);
- },
- /*num_args=*/1) != LUA_OK) {
- TC3_LOG(ERROR) << "Could not read lua result.";
- return false;
- }
- return true;
-}
-
-} // namespace libtextclassifier3
diff --git a/actions/lua-actions.h b/actions/lua-actions.h
deleted file mode 100644
index 2f82653..0000000
--- a/actions/lua-actions.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_ACTIONS_H_
-#define LIBTEXTCLASSIFIER_ACTIONS_LUA_ACTIONS_H_
-
-#include "actions/actions_model_generated.h"
-#include "actions/lua-utils.h"
-#include "actions/types.h"
-#include "utils/lua-utils.h"
-#include "utils/tensor-view.h"
-#include "utils/tflite-model-executor.h"
-
-namespace libtextclassifier3 {
-
-// Lua backed actions suggestions.
-class LuaActionsSuggestions : public LuaEnvironment {
- public:
- static std::unique_ptr<LuaActionsSuggestions> CreateLuaActionsSuggestions(
- const std::string& snippet, const Conversation& conversation,
- const TfLiteModelExecutor* model_executor,
- const TensorflowLiteModelSpec* model_spec,
- const tflite::Interpreter* interpreter,
- const reflection::Schema* actions_entity_data_schema,
- const reflection::Schema* annotations_entity_data_schema);
-
- bool SuggestActions(std::vector<ActionSuggestion>* actions);
-
- private:
- // Model tensor lua iterator.
- class TensorViewIterator
- : public LuaEnvironment::ItemIterator<TensorView<float>> {
- public:
- explicit TensorViewIterator() {}
- int Item(const TensorView<float>* tensor, const int64 index,
- lua_State* state) const override;
- };
-
- LuaActionsSuggestions(
- const std::string& snippet, const Conversation& conversation,
- const TfLiteModelExecutor* model_executor,
- const TensorflowLiteModelSpec* model_spec,
- const tflite::Interpreter* interpreter,
- const reflection::Schema* actions_entity_data_schema,
- const reflection::Schema* annotations_entity_data_schema);
-
- bool Initialize();
-
- const std::string& snippet_;
- const Conversation& conversation_;
- ConversationIterator conversation_iterator_;
- TensorViewIterator tensor_iterator_;
- TensorView<float> actions_scores_;
- TensorView<float> smart_reply_scores_;
- TensorView<float> sensitivity_score_;
- TensorView<float> triggering_score_;
- const reflection::Schema* actions_entity_data_schema_;
- const reflection::Schema* annotations_entity_data_schema_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ACTIONS_LUA_ACTIONS_H_
diff --git a/actions/lua-actions_test.cc b/actions/lua-actions_test.cc
deleted file mode 100644
index f7b9cd5..0000000
--- a/actions/lua-actions_test.cc
+++ /dev/null
@@ -1,201 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/lua-actions.h"
-
-#include <map>
-#include <string>
-
-#include "actions/test_utils.h"
-#include "actions/types.h"
-#include "utils/tflite-model-executor.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-MATCHER_P2(IsAction, type, response_text, "") {
- return testing::Value(arg.type, type) &&
- testing::Value(arg.response_text, response_text);
-}
-
-MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
-
-TEST(LuaActions, SimpleAction) {
- Conversation conversation;
- const std::string test_snippet = R"(
- return {{ type = "test_action" }}
- )";
- std::vector<ActionSuggestion> actions;
- EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
- test_snippet, conversation,
- /*model_executor=*/nullptr,
- /*model_spec=*/nullptr,
- /*interpreter=*/nullptr,
- /*actions_entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr)
- ->SuggestActions(&actions));
- EXPECT_THAT(actions,
- testing::ElementsAreArray({IsActionType("test_action")}));
-}
-
-TEST(LuaActions, ConversationActions) {
- Conversation conversation;
- conversation.messages.push_back({/*user_id=*/0, "hello there!"});
- conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
- const std::string test_snippet = R"(
- local actions = {}
- for i, message in pairs(messages) do
- if i < #messages then
- if message.text == "hello there!" and
- messages[i+1].text == "general kenobi!" then
- table.insert(actions, {
- type = "text_reply",
- response_text = "you are a bold one!"
- })
- end
- if message.text == "i am the senate!" and
- messages[i+1].text == "not yet!" then
- table.insert(actions, {
- type = "text_reply",
- response_text = "it's treason then"
- })
- end
- end
- end
- return actions;
- )";
- std::vector<ActionSuggestion> actions;
- EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
- test_snippet, conversation,
- /*model_executor=*/nullptr,
- /*model_spec=*/nullptr,
- /*interpreter=*/nullptr,
- /*actions_entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr)
- ->SuggestActions(&actions));
- EXPECT_THAT(actions, testing::ElementsAreArray(
- {IsAction("text_reply", "you are a bold one!")}));
-}
-
-TEST(LuaActions, SimpleModelAction) {
- Conversation conversation;
- const std::string test_snippet = R"(
- if #model.actions_scores == 0 then
- return {{ type = "test_action" }}
- end
- return {}
- )";
- std::vector<ActionSuggestion> actions;
- EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
- test_snippet, conversation,
- /*model_executor=*/nullptr,
- /*model_spec=*/nullptr,
- /*interpreter=*/nullptr,
- /*actions_entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr)
- ->SuggestActions(&actions));
- EXPECT_THAT(actions,
- testing::ElementsAreArray({IsActionType("test_action")}));
-}
-
-TEST(LuaActions, AnnotationActions) {
- AnnotatedSpan annotation;
- annotation.span = {11, 15};
- annotation.classification = {ClassificationResult("address", 1.0)};
- Conversation conversation = {{{/*user_id=*/1, "are you at home?",
- /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"Europe/Zurich",
- /*annotations=*/{annotation},
- /*locales=*/"en"}}};
- const std::string test_snippet = R"(
- local actions = {}
- local last_message = messages[#messages]
- for i, annotation in pairs(last_message.annotation) do
- if #annotation.classification > 0 then
- if annotation.classification[1].collection == "address" then
- local text = string.sub(last_message.text,
- annotation.span["begin"] + 1,
- annotation.span["end"])
- table.insert(actions, {
- type = "text_reply",
- response_text = "i am at " .. text,
- annotation = {{
- name = "location",
- span = {
- text = text
- },
- entity = annotation.classification[1]
- }},
- })
- end
- end
- end
- return actions;
- )";
- std::vector<ActionSuggestion> actions;
- EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
- test_snippet, conversation,
- /*model_executor=*/nullptr,
- /*model_spec=*/nullptr,
- /*interpreter=*/nullptr,
- /*actions_entity_data_schema=*/nullptr,
- /*annotations_entity_data_schema=*/nullptr)
- ->SuggestActions(&actions));
- EXPECT_THAT(actions, testing::ElementsAreArray(
- {IsAction("text_reply", "i am at home")}));
- EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
-}
-
-TEST(LuaActions, EntityData) {
- std::string test_schema = TestEntityDataSchema();
- Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
- const std::string test_snippet = R"(
- return {{
- type = "test",
- entity = {
- greeting = "hello",
- location = "there",
- person = "Kenobi",
- },
- }};
- )";
- std::vector<ActionSuggestion> actions;
- EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
- test_snippet, conversation,
- /*model_executor=*/nullptr,
- /*model_spec=*/nullptr,
- /*interpreter=*/nullptr,
- /*actions_entity_data_schema=*/
- flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
- /*annotations_entity_data_schema=*/nullptr)
- ->SuggestActions(&actions));
- EXPECT_THAT(actions, testing::SizeIs(1));
- EXPECT_EQ("test", actions.front().type);
- const flatbuffers::Table* entity =
- flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
- actions.front().serialized_entity_data.data()));
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
- "hello");
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
- "there");
- EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
- "Kenobi");
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/actions/lua-ranker.cc b/actions/lua-ranker.cc
deleted file mode 100644
index a185b07..0000000
--- a/actions/lua-ranker.cc
+++ /dev/null
@@ -1,117 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/lua-ranker.h"
-#include "utils/base/logging.h"
-#include "utils/lua-utils.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-#include "lauxlib.h"
-#include "lualib.h"
-#ifdef __cplusplus
-}
-#endif
-
-namespace libtextclassifier3 {
-
-std::unique_ptr<ActionsSuggestionsLuaRanker>
-ActionsSuggestionsLuaRanker::Create(
- const Conversation& conversation, const std::string& ranker_code,
- const reflection::Schema* entity_data_schema,
- const reflection::Schema* annotations_entity_data_schema,
- ActionsSuggestionsResponse* response) {
- auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>(
- new ActionsSuggestionsLuaRanker(
- conversation, ranker_code, entity_data_schema,
- annotations_entity_data_schema, response));
- if (!ranker->Initialize()) {
- TC3_LOG(ERROR) << "Could not initialize lua environment for ranker.";
- return nullptr;
- }
- return ranker;
-}
-
-bool ActionsSuggestionsLuaRanker::Initialize() {
- return RunProtected([this] {
- LoadDefaultLibraries();
-
- // Expose generated actions.
- actions_iterator_.NewIterator("actions", &response_->actions,
- state_);
- lua_setglobal(state_, "actions");
-
- // Expose conversation message stream.
- conversation_iterator_.NewIterator("messages",
- &conversation_.messages, state_);
- lua_setglobal(state_, "messages");
- return LUA_OK;
- }) == LUA_OK;
-}
-
-int ActionsSuggestionsLuaRanker::ReadActionsRanking() {
- if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected actions table, got: "
- << lua_type(state_, /*idx=*/-1);
- lua_pop(state_, 1);
- lua_error(state_);
- return LUA_ERRRUN;
- }
- std::vector<ActionSuggestion> ranked_actions;
- lua_pushnil(state_);
- while (lua_next(state_, /*idx=*/-2)) {
- const int action_id =
- static_cast<int>(lua_tointeger(state_, /*idx=*/-1)) - 1;
- lua_pop(state_, 1);
- if (action_id < 0 || action_id >= response_->actions.size()) {
- TC3_LOG(ERROR) << "Invalid action index: " << action_id;
- lua_error(state_);
- return LUA_ERRRUN;
- }
- ranked_actions.push_back(response_->actions[action_id]);
- }
- lua_pop(state_, 1);
- response_->actions = ranked_actions;
- return LUA_OK;
-}
-
-bool ActionsSuggestionsLuaRanker::RankActions() {
- if (response_->actions.empty()) {
- // Nothing to do.
- return true;
- }
-
- if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(),
- /*name=*/nullptr) != LUA_OK) {
- TC3_LOG(ERROR) << "Could not load compiled ranking snippet.";
- return false;
- }
-
- if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
- TC3_LOG(ERROR) << "Could not run ranking snippet.";
- return false;
- }
-
- if (RunProtected([this] { return ReadActionsRanking(); }, /*num_args=*/1) !=
- LUA_OK) {
- TC3_LOG(ERROR) << "Could not read lua result.";
- return false;
- }
- return true;
-}
-
-} // namespace libtextclassifier3
diff --git a/actions/lua-ranker.h b/actions/lua-ranker.h
deleted file mode 100644
index 687f412..0000000
--- a/actions/lua-ranker.h
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_RANKER_H_
-#define LIBTEXTCLASSIFIER_ACTIONS_LUA_RANKER_H_
-
-#include <memory>
-#include <string>
-
-#include "actions/lua-utils.h"
-#include "actions/types.h"
-#include "utils/lua-utils.h"
-
-namespace libtextclassifier3 {
-
-// Lua backed action suggestion ranking.
-class ActionsSuggestionsLuaRanker : public LuaEnvironment {
- public:
- static std::unique_ptr<ActionsSuggestionsLuaRanker> Create(
- const Conversation& conversation, const std::string& ranker_code,
- const reflection::Schema* entity_data_schema,
- const reflection::Schema* annotations_entity_data_schema,
- ActionsSuggestionsResponse* response);
-
- bool RankActions();
-
- private:
- explicit ActionsSuggestionsLuaRanker(
- const Conversation& conversation, const std::string& ranker_code,
- const reflection::Schema* entity_data_schema,
- const reflection::Schema* annotations_entity_data_schema,
- ActionsSuggestionsResponse* response)
- : conversation_(conversation),
- ranker_code_(ranker_code),
- response_(response),
- actions_iterator_(entity_data_schema, annotations_entity_data_schema,
- this),
- conversation_iterator_(annotations_entity_data_schema, this) {}
-
- bool Initialize();
-
- // Reads ranking results from the lua stack.
- int ReadActionsRanking();
-
- const Conversation& conversation_;
- const std::string& ranker_code_;
- ActionsSuggestionsResponse* response_;
- const ActionsIterator actions_iterator_;
- const ConversationIterator conversation_iterator_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ACTIONS_LUA_RANKER_H_
diff --git a/actions/lua-utils.cc b/actions/lua-utils.cc
deleted file mode 100644
index edeadf9..0000000
--- a/actions/lua-utils.cc
+++ /dev/null
@@ -1,354 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/lua-utils.h"
-
-namespace libtextclassifier3 {
-namespace {
-static constexpr const char* kTextKey = "text";
-static constexpr const char* kTimeUsecKey = "parsed_time_ms_utc";
-static constexpr const char* kGranularityKey = "granularity";
-static constexpr const char* kCollectionKey = "collection";
-static constexpr const char* kNameKey = "name";
-static constexpr const char* kScoreKey = "score";
-static constexpr const char* kPriorityScoreKey = "priority_score";
-static constexpr const char* kTypeKey = "type";
-static constexpr const char* kResponseTextKey = "response_text";
-static constexpr const char* kAnnotationKey = "annotation";
-static constexpr const char* kSpanKey = "span";
-static constexpr const char* kMessageKey = "message";
-static constexpr const char* kBeginKey = "begin";
-static constexpr const char* kEndKey = "end";
-static constexpr const char* kClassificationKey = "classification";
-static constexpr const char* kSerializedEntity = "serialized_entity";
-static constexpr const char* kEntityKey = "entity";
-} // namespace
-
-template <>
-int AnnotationIterator<ClassificationResult>::Item(
- const std::vector<ClassificationResult>* annotations, StringPiece key,
- lua_State* state) const {
- // Lookup annotation by collection.
- for (const ClassificationResult& annotation : *annotations) {
- if (key.Equals(annotation.collection)) {
- PushAnnotation(annotation, entity_data_schema_, env_);
- return 1;
- }
- }
- TC3_LOG(ERROR) << "No annotation with collection: " << key.ToString()
- << " found.";
- lua_error(state);
- return 0;
-}
-
-template <>
-int AnnotationIterator<ActionSuggestionAnnotation>::Item(
- const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key,
- lua_State* state) const {
- // Lookup annotation by name.
- for (const ActionSuggestionAnnotation& annotation : *annotations) {
- if (key.Equals(annotation.name)) {
- PushAnnotation(annotation, entity_data_schema_, env_);
- return 1;
- }
- }
- TC3_LOG(ERROR) << "No annotation with name: " << key.ToString() << " found.";
- lua_error(state);
- return 0;
-}
-
-void PushAnnotation(const ClassificationResult& classification,
- const reflection::Schema* entity_data_schema,
- LuaEnvironment* env) {
- if (entity_data_schema == nullptr ||
- classification.serialized_entity_data.empty()) {
- // Empty table.
- lua_newtable(env->state());
- } else {
- env->PushFlatbuffer(entity_data_schema,
- flatbuffers::GetRoot<flatbuffers::Table>(
- classification.serialized_entity_data.data()));
- }
- lua_pushinteger(env->state(),
- classification.datetime_parse_result.time_ms_utc);
- lua_setfield(env->state(), /*idx=*/-2, kTimeUsecKey);
- lua_pushinteger(env->state(),
- classification.datetime_parse_result.granularity);
- lua_setfield(env->state(), /*idx=*/-2, kGranularityKey);
- env->PushString(classification.collection);
- lua_setfield(env->state(), /*idx=*/-2, kCollectionKey);
- lua_pushnumber(env->state(), classification.score);
- lua_setfield(env->state(), /*idx=*/-2, kScoreKey);
- env->PushString(classification.serialized_entity_data);
- lua_setfield(env->state(), /*idx=*/-2, kSerializedEntity);
-}
-
-void PushAnnotation(const ClassificationResult& classification,
- StringPiece text,
- const reflection::Schema* entity_data_schema,
- LuaEnvironment* env) {
- PushAnnotation(classification, entity_data_schema, env);
- env->PushString(text);
- lua_setfield(env->state(), /*idx=*/-2, kTextKey);
-}
-
-void PushAnnotatedSpan(
- const AnnotatedSpan& annotated_span,
- const AnnotationIterator<ClassificationResult>& annotation_iterator,
- LuaEnvironment* env) {
- lua_newtable(env->state());
- {
- lua_newtable(env->state());
- lua_pushinteger(env->state(), annotated_span.span.first);
- lua_setfield(env->state(), /*idx=*/-2, kBeginKey);
- lua_pushinteger(env->state(), annotated_span.span.second);
- lua_setfield(env->state(), /*idx=*/-2, kEndKey);
- }
- lua_setfield(env->state(), /*idx=*/-2, kSpanKey);
- annotation_iterator.NewIterator(kClassificationKey,
- &annotated_span.classification, env->state());
- lua_setfield(env->state(), /*idx=*/-2, kClassificationKey);
-}
-
-MessageTextSpan ReadSpan(LuaEnvironment* env) {
- MessageTextSpan span;
- lua_pushnil(env->state());
- while (lua_next(env->state(), /*idx=*/-2)) {
- const StringPiece key = env->ReadString(/*index=*/-2);
- if (key.Equals(kMessageKey)) {
- span.message_index =
- static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
- } else if (key.Equals(kBeginKey)) {
- span.span.first =
- static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
- } else if (key.Equals(kEndKey)) {
- span.span.second =
- static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
- } else if (key.Equals(kTextKey)) {
- span.text = env->ReadString(/*index=*/-1).ToString();
- } else {
- TC3_LOG(INFO) << "Unknown span field: " << key.ToString();
- }
- lua_pop(env->state(), 1);
- }
- return span;
-}
-
-int ReadAnnotations(const reflection::Schema* entity_data_schema,
- LuaEnvironment* env,
- std::vector<ActionSuggestionAnnotation>* annotations) {
- if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected annotations table, got: "
- << lua_type(env->state(), /*idx=*/-1);
- lua_pop(env->state(), 1);
- lua_error(env->state());
- return LUA_ERRRUN;
- }
-
- // Read actions.
- lua_pushnil(env->state());
- while (lua_next(env->state(), /*idx=*/-2)) {
- if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected annotation table, got: "
- << lua_type(env->state(), /*idx=*/-1);
- lua_pop(env->state(), 1);
- continue;
- }
- annotations->push_back(ReadAnnotation(entity_data_schema, env));
- lua_pop(env->state(), 1);
- }
- return LUA_OK;
-}
-
-ActionSuggestionAnnotation ReadAnnotation(
- const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
- ActionSuggestionAnnotation annotation;
- lua_pushnil(env->state());
- while (lua_next(env->state(), /*idx=*/-2)) {
- const StringPiece key = env->ReadString(/*index=*/-2);
- if (key.Equals(kNameKey)) {
- annotation.name = env->ReadString(/*index=*/-1).ToString();
- } else if (key.Equals(kSpanKey)) {
- annotation.span = ReadSpan(env);
- } else if (key.Equals(kEntityKey)) {
- annotation.entity = ReadClassificationResult(entity_data_schema, env);
- } else {
- TC3_LOG(ERROR) << "Unknown annotation field: " << key.ToString();
- }
- lua_pop(env->state(), 1);
- }
- return annotation;
-}
-
-ClassificationResult ReadClassificationResult(
- const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
- ClassificationResult classification;
- lua_pushnil(env->state());
- while (lua_next(env->state(), /*idx=*/-2)) {
- const StringPiece key = env->ReadString(/*index=*/-2);
- if (key.Equals(kCollectionKey)) {
- classification.collection = env->ReadString(/*index=*/-1).ToString();
- } else if (key.Equals(kScoreKey)) {
- classification.score =
- static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
- } else if (key.Equals(kTimeUsecKey)) {
- classification.datetime_parse_result.time_ms_utc =
- static_cast<int64>(lua_tonumber(env->state(), /*idx=*/-1));
- } else if (key.Equals(kGranularityKey)) {
- classification.datetime_parse_result.granularity =
- static_cast<DatetimeGranularity>(
- lua_tonumber(env->state(), /*idx=*/-1));
- } else if (key.Equals(kSerializedEntity)) {
- classification.serialized_entity_data =
- env->ReadString(/*index=*/-1).ToString();
- } else if (key.Equals(kEntityKey)) {
- auto buffer = ReflectiveFlatbufferBuilder(entity_data_schema).NewRoot();
- env->ReadFlatbuffer(buffer.get());
- classification.serialized_entity_data = buffer->Serialize();
- } else {
- TC3_LOG(INFO) << "Unknown classification result field: "
- << key.ToString();
- }
- lua_pop(env->state(), 1);
- }
- return classification;
-}
-
-void PushAnnotation(const ActionSuggestionAnnotation& annotation,
- const reflection::Schema* entity_data_schema,
- LuaEnvironment* env) {
- PushAnnotation(annotation.entity, annotation.span.text, entity_data_schema,
- env);
- env->PushString(annotation.name);
- lua_setfield(env->state(), /*idx=*/-2, kNameKey);
- {
- lua_newtable(env->state());
- lua_pushinteger(env->state(), annotation.span.message_index);
- lua_setfield(env->state(), /*idx=*/-2, kMessageKey);
- lua_pushinteger(env->state(), annotation.span.span.first);
- lua_setfield(env->state(), /*idx=*/-2, kBeginKey);
- lua_pushinteger(env->state(), annotation.span.span.second);
- lua_setfield(env->state(), /*idx=*/-2, kEndKey);
- }
- lua_setfield(env->state(), /*idx=*/-2, kSpanKey);
-}
-
-void PushAction(
- const ActionSuggestion& action,
- const reflection::Schema* entity_data_schema,
- const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator,
- LuaEnvironment* env) {
- if (entity_data_schema == nullptr || action.serialized_entity_data.empty()) {
- // Empty table.
- lua_newtable(env->state());
- } else {
- env->PushFlatbuffer(entity_data_schema,
- flatbuffers::GetRoot<flatbuffers::Table>(
- action.serialized_entity_data.data()));
- }
- env->PushString(action.type);
- lua_setfield(env->state(), /*idx=*/-2, kTypeKey);
- env->PushString(action.response_text);
- lua_setfield(env->state(), /*idx=*/-2, kResponseTextKey);
- lua_pushnumber(env->state(), action.score);
- lua_setfield(env->state(), /*idx=*/-2, kScoreKey);
- lua_pushnumber(env->state(), action.priority_score);
- lua_setfield(env->state(), /*idx=*/-2, kPriorityScoreKey);
- annotation_iterator.NewIterator(kAnnotationKey, &action.annotations,
- env->state());
- lua_setfield(env->state(), /*idx=*/-2, kAnnotationKey);
-}
-
-ActionSuggestion ReadAction(
- const reflection::Schema* actions_entity_data_schema,
- const reflection::Schema* annotations_entity_data_schema,
- LuaEnvironment* env) {
- ActionSuggestion action;
- lua_pushnil(env->state());
- while (lua_next(env->state(), /*idx=*/-2)) {
- const StringPiece key = env->ReadString(/*index=*/-2);
- if (key.Equals(kResponseTextKey)) {
- action.response_text = env->ReadString(/*index=*/-1).ToString();
- } else if (key.Equals(kTypeKey)) {
- action.type = env->ReadString(/*index=*/-1).ToString();
- } else if (key.Equals(kScoreKey)) {
- action.score = static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
- } else if (key.Equals(kPriorityScoreKey)) {
- action.priority_score =
- static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
- } else if (key.Equals(kAnnotationKey)) {
- ReadAnnotations(actions_entity_data_schema, env, &action.annotations);
- } else if (key.Equals(kEntityKey)) {
- auto buffer =
- ReflectiveFlatbufferBuilder(actions_entity_data_schema).NewRoot();
- env->ReadFlatbuffer(buffer.get());
- action.serialized_entity_data = buffer->Serialize();
- } else {
- TC3_LOG(INFO) << "Unknown action field: " << key.ToString();
- }
- lua_pop(env->state(), 1);
- }
- return action;
-}
-
-int ReadActions(const reflection::Schema* actions_entity_data_schema,
- const reflection::Schema* annotations_entity_data_schema,
- LuaEnvironment* env, std::vector<ActionSuggestion>* actions) {
- if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected actions table, got: "
- << lua_type(env->state(), /*idx=*/-1);
- lua_pop(env->state(), 1);
- lua_error(env->state());
- return LUA_ERRRUN;
- }
-
- // Read actions.
- lua_pushnil(env->state());
- while (lua_next(env->state(), /*idx=*/-2)) {
- if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected action table, got: "
- << lua_type(env->state(), /*idx=*/-1);
- lua_pop(env->state(), 1);
- continue;
- }
- actions->push_back(ReadAction(actions_entity_data_schema,
- annotations_entity_data_schema, env));
- lua_pop(env->state(), /*n=1*/ 1);
- }
- lua_pop(env->state(), /*n=*/1);
-
- return LUA_OK;
-}
-
-int ConversationIterator::Item(const std::vector<ConversationMessage>* messages,
- const int64 pos, lua_State* state) const {
- const ConversationMessage& message = (*messages)[pos];
- lua_newtable(state);
- lua_pushinteger(state, message.user_id);
- lua_setfield(state, /*idx=*/-2, "user_id");
- env_->PushString(message.text);
- lua_setfield(state, /*idx=*/-2, "text");
- lua_pushinteger(state, message.reference_time_ms_utc);
- lua_setfield(state, /*idx=*/-2, "time_ms_utc");
- env_->PushString(message.reference_timezone);
- lua_setfield(state, /*idx=*/-2, "timezone");
- annotated_span_iterator_.NewIterator("annotation", &message.annotations,
- state);
- lua_setfield(state, /*idx=*/-2, "annotation");
- return 1;
-}
-
-} // namespace libtextclassifier3
diff --git a/actions/lua-utils.h b/actions/lua-utils.h
deleted file mode 100644
index 4f06674..0000000
--- a/actions/lua-utils.h
+++ /dev/null
@@ -1,182 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_
-#define LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_
-
-#include "actions/types.h"
-#include "annotator/types.h"
-#include "utils/flatbuffers.h"
-#include "utils/lua-utils.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-#include "lauxlib.h"
-#include "lua.h"
-#include "lualib.h"
-#ifdef __cplusplus
-}
-#endif
-
-// Action specific shared lua utilities.
-namespace libtextclassifier3 {
-
-// Provides an annotation to lua.
-void PushAnnotation(const ClassificationResult& classification,
- const reflection::Schema* entity_data_schema,
- LuaEnvironment* env);
-void PushAnnotation(const ClassificationResult& classification,
- StringPiece text,
- const reflection::Schema* entity_data_schema,
- LuaEnvironment* env);
-void PushAnnotation(const ActionSuggestionAnnotation& annotation,
- const reflection::Schema* entity_data_schema,
- LuaEnvironment* env);
-
-// A lua iterator to enumerate annotation.
-template <typename Annotation>
-class AnnotationIterator
- : public LuaEnvironment::ItemIterator<std::vector<Annotation>> {
- public:
- AnnotationIterator(const reflection::Schema* entity_data_schema,
- LuaEnvironment* env)
- : env_(env), entity_data_schema_(entity_data_schema) {}
- int Item(const std::vector<Annotation>* annotations, const int64 pos,
- lua_State* state) const override {
- PushAnnotation((*annotations)[pos], entity_data_schema_, env_);
- return 1;
- }
- int Item(const std::vector<Annotation>* annotations, StringPiece key,
- lua_State* state) const override;
-
- private:
- LuaEnvironment* env_;
- const reflection::Schema* entity_data_schema_;
-};
-
-template <>
-int AnnotationIterator<ClassificationResult>::Item(
- const std::vector<ClassificationResult>* annotations, StringPiece key,
- lua_State* state) const;
-
-template <>
-int AnnotationIterator<ActionSuggestionAnnotation>::Item(
- const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key,
- lua_State* state) const;
-
-void PushAnnotatedSpan(
- const AnnotatedSpan& annotated_span,
- const AnnotationIterator<ClassificationResult>& annotation_iterator,
- LuaEnvironment* env);
-
-MessageTextSpan ReadSpan(LuaEnvironment* env);
-ActionSuggestionAnnotation ReadAnnotation(
- const reflection::Schema* entity_data_schema, LuaEnvironment* env);
-int ReadAnnotations(const reflection::Schema* entity_data_schema,
- LuaEnvironment* env,
- std::vector<ActionSuggestionAnnotation>* annotations);
-ClassificationResult ReadClassificationResult(
- const reflection::Schema* entity_data_schema, LuaEnvironment* env);
-
-// A lua iterator to enumerate annotated spans.
-class AnnotatedSpanIterator
- : public LuaEnvironment::ItemIterator<std::vector<AnnotatedSpan>> {
- public:
- AnnotatedSpanIterator(
- const AnnotationIterator<ClassificationResult>& annotation_iterator,
- LuaEnvironment* env)
- : env_(env), annotation_iterator_(annotation_iterator) {}
- AnnotatedSpanIterator(const reflection::Schema* entity_data_schema,
- LuaEnvironment* env)
- : env_(env), annotation_iterator_(entity_data_schema, env) {}
-
- int Item(const std::vector<AnnotatedSpan>* spans, const int64 pos,
- lua_State* state) const override {
- PushAnnotatedSpan((*spans)[pos], annotation_iterator_, env_);
- return /*num results=*/1;
- }
-
- private:
- LuaEnvironment* env_;
- AnnotationIterator<ClassificationResult> annotation_iterator_;
-};
-
-// Provides an action to lua.
-void PushAction(
- const ActionSuggestion& action,
- const reflection::Schema* entity_data_schema,
- const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator,
- LuaEnvironment* env);
-
-ActionSuggestion ReadAction(
- const reflection::Schema* actions_entity_data_schema,
- const reflection::Schema* annotations_entity_data_schema,
- LuaEnvironment* env);
-int ReadActions(const reflection::Schema* actions_entity_data_schema,
- const reflection::Schema* annotations_entity_data_schema,
- LuaEnvironment* env, std::vector<ActionSuggestion>* actions);
-
-// A lua iterator to enumerate actions suggestions.
-class ActionsIterator
- : public LuaEnvironment::ItemIterator<std::vector<ActionSuggestion>> {
- public:
- ActionsIterator(const reflection::Schema* entity_data_schema,
- const reflection::Schema* annotations_entity_data_schema,
- LuaEnvironment* env)
- : env_(env),
- entity_data_schema_(entity_data_schema),
- annotation_iterator_(annotations_entity_data_schema, env) {}
- int Item(const std::vector<ActionSuggestion>* actions, const int64 pos,
- lua_State* state) const override {
- PushAction((*actions)[pos], entity_data_schema_, annotation_iterator_,
- env_);
- return /*num results=*/1;
- }
-
- private:
- LuaEnvironment* env_;
- const reflection::Schema* entity_data_schema_;
- AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_;
-};
-
-// Conversation message lua iterator.
-class ConversationIterator
- : public LuaEnvironment::ItemIterator<std::vector<ConversationMessage>> {
- public:
- ConversationIterator(
- const AnnotationIterator<ClassificationResult>& annotation_iterator,
- LuaEnvironment* env)
- : env_(env),
- annotated_span_iterator_(
- AnnotatedSpanIterator(annotation_iterator, env)) {}
- ConversationIterator(const reflection::Schema* entity_data_schema,
- LuaEnvironment* env)
- : env_(env),
- annotated_span_iterator_(
- AnnotatedSpanIterator(entity_data_schema, env)) {}
-
- int Item(const std::vector<ConversationMessage>* messages, const int64 pos,
- lua_State* state) const override;
-
- private:
- LuaEnvironment* env_;
- AnnotatedSpanIterator annotated_span_iterator_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_
diff --git a/actions/ngram-model.cc b/actions/ngram-model.cc
deleted file mode 100644
index 2263617..0000000
--- a/actions/ngram-model.cc
+++ /dev/null
@@ -1,209 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/ngram-model.h"
-
-#include <algorithm>
-
-#include "actions/feature-processor.h"
-#include "utils/hash/farmhash.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-// An iterator to iterate over the initial tokens of the n-grams of a model.
-class FirstTokenIterator
- : public std::iterator<std::random_access_iterator_tag,
- /*value_type=*/uint32, /*difference_type=*/ptrdiff_t,
- /*pointer=*/const uint32*,
- /*reference=*/uint32&> {
- public:
- explicit FirstTokenIterator(const NGramLinearRegressionModel* model,
- int index)
- : model_(model), index_(index) {}
-
- FirstTokenIterator& operator++() {
- index_++;
- return *this;
- }
- FirstTokenIterator& operator+=(ptrdiff_t dist) {
- index_ += dist;
- return *this;
- }
- ptrdiff_t operator-(const FirstTokenIterator& other_it) const {
- return index_ - other_it.index_;
- }
- uint32 operator*() const {
- const uint32 token_offset = (*model_->ngram_start_offsets())[index_];
- return (*model_->hashed_ngram_tokens())[token_offset];
- }
- int index() const { return index_; }
-
- private:
- const NGramLinearRegressionModel* model_;
- int index_;
-};
-
-} // anonymous namespace
-
-std::unique_ptr<NGramModel> NGramModel::Create(
- const NGramLinearRegressionModel* model, const Tokenizer* tokenizer,
- const UniLib* unilib) {
- if (model == nullptr) {
- return nullptr;
- }
- if (tokenizer == nullptr && model->tokenizer_options() == nullptr) {
- TC3_LOG(ERROR) << "No tokenizer options specified.";
- return nullptr;
- }
- return std::unique_ptr<NGramModel>(new NGramModel(model, tokenizer, unilib));
-}
-
-NGramModel::NGramModel(const NGramLinearRegressionModel* model,
- const Tokenizer* tokenizer, const UniLib* unilib)
- : model_(model) {
- // Create new tokenizer if options are specified, reuse feature processor
- // tokenizer otherwise.
- if (model->tokenizer_options() != nullptr) {
- owned_tokenizer_ = CreateTokenizer(model->tokenizer_options(), unilib);
- tokenizer_ = owned_tokenizer_.get();
- } else {
- tokenizer_ = tokenizer;
- }
-}
-
-// Returns whether a given n-gram matches the token stream.
-bool NGramModel::IsNGramMatch(const uint32* tokens, size_t num_tokens,
- const uint32* ngram_tokens,
- size_t num_ngram_tokens, int max_skips) const {
- int token_idx = 0, ngram_token_idx = 0, skip_remain = 0;
- for (; token_idx < num_tokens && ngram_token_idx < num_ngram_tokens;) {
- if (tokens[token_idx] == ngram_tokens[ngram_token_idx]) {
- // Token matches. Advance both and reset the skip budget.
- ++token_idx;
- ++ngram_token_idx;
- skip_remain = max_skips;
- } else if (skip_remain > 0) {
- // No match, but we have skips left, so just advance over the token.
- ++token_idx;
- skip_remain--;
- } else {
- // No match and we're out of skips. Reject.
- return false;
- }
- }
- return ngram_token_idx == num_ngram_tokens;
-}
-
-// Calculates the total number of skip-grams that can be created for a stream
-// with the given number of tokens.
-uint64 NGramModel::GetNumSkipGrams(int num_tokens, int max_ngram_length,
- int max_skips) {
- // Start with unigrams.
- uint64 total = num_tokens;
- for (int ngram_len = 2;
- ngram_len <= max_ngram_length && ngram_len <= num_tokens; ++ngram_len) {
- // We can easily compute the expected length of the n-gram (with skips),
- // but it doesn't account for the fact that they may be longer than the
- // input and should be pruned.
- // Instead, we iterate over the distribution of effective n-gram lengths
- // and add each length individually.
- const int num_gaps = ngram_len - 1;
- const int len_min = ngram_len;
- const int len_max = ngram_len + num_gaps * max_skips;
- const int len_mid = (len_max + len_min) / 2;
- for (int len_i = len_min; len_i <= len_max; ++len_i) {
- if (len_i > num_tokens) continue;
- const int num_configs_of_len_i =
- len_i <= len_mid ? len_i - len_min + 1 : len_max - len_i + 1;
- const int num_start_offsets = num_tokens - len_i + 1;
- total += num_configs_of_len_i * num_start_offsets;
- }
- }
- return total;
-}
-
-std::pair<int, int> NGramModel::GetFirstTokenMatches(uint32 token_hash) const {
- const int num_ngrams = model_->ngram_weights()->size();
- const auto start_it = FirstTokenIterator(model_, 0);
- const auto end_it = FirstTokenIterator(model_, num_ngrams);
- const int start = std::lower_bound(start_it, end_it, token_hash).index();
- const int end = std::upper_bound(start_it, end_it, token_hash).index();
- return std::make_pair(start, end);
-}
-
-bool NGramModel::Eval(const UnicodeText& text, float* score) const {
- const std::vector<Token> raw_tokens = tokenizer_->Tokenize(text);
-
- // If we have no tokens, then just bail early.
- if (raw_tokens.empty()) {
- if (score != nullptr) {
- *score = model_->default_token_weight();
- }
- return false;
- }
-
- // Hash the tokens.
- std::vector<uint32> tokens;
- tokens.reserve(raw_tokens.size());
- for (const Token& raw_token : raw_tokens) {
- tokens.push_back(tc3farmhash::Fingerprint32(raw_token.value.data(),
- raw_token.value.length()));
- }
-
- // Calculate the total number of skip-grams that can be generated for the
- // input text.
- const uint64 num_candidates = GetNumSkipGrams(
- tokens.size(), model_->max_denom_ngram_length(), model_->max_skips());
-
- // For each token, see whether it denotes the start of an n-gram in the model.
- int num_matches = 0;
- float weight_matches = 0.f;
- for (size_t start_i = 0; start_i < tokens.size(); ++start_i) {
- const std::pair<int, int> ngram_range =
- GetFirstTokenMatches(tokens[start_i]);
- for (int ngram_idx = ngram_range.first; ngram_idx < ngram_range.second;
- ++ngram_idx) {
- const uint16 ngram_tokens_begin =
- (*model_->ngram_start_offsets())[ngram_idx];
- const uint16 ngram_tokens_end =
- (*model_->ngram_start_offsets())[ngram_idx + 1];
- if (IsNGramMatch(
- /*tokens=*/tokens.data() + start_i,
- /*num_tokens=*/tokens.size() - start_i,
- /*ngram_tokens=*/model_->hashed_ngram_tokens()->data() +
- ngram_tokens_begin,
- /*num_ngram_tokens=*/ngram_tokens_end - ngram_tokens_begin,
- /*max_skips=*/model_->max_skips())) {
- ++num_matches;
- weight_matches += (*model_->ngram_weights())[ngram_idx];
- }
- }
- }
-
- // Calculate the score.
- const int num_misses = num_candidates - num_matches;
- const float internal_score =
- (weight_matches + (model_->default_token_weight() * num_misses)) /
- num_candidates;
- if (score != nullptr) {
- *score = internal_score;
- }
- return internal_score > model_->threshold();
-}
-
-} // namespace libtextclassifier3
diff --git a/actions/ngram-model.h b/actions/ngram-model.h
deleted file mode 100644
index ec0b606..0000000
--- a/actions/ngram-model.h
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ACTIONS_NGRAM_MODEL_H_
-#define LIBTEXTCLASSIFIER_ACTIONS_NGRAM_MODEL_H_
-
-#include <memory>
-
-#include "actions/actions_model_generated.h"
-#include "utils/tokenizer.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3 {
-
-class NGramModel {
- public:
- static std::unique_ptr<NGramModel> Create(
- const NGramLinearRegressionModel* model, const Tokenizer* tokenizer,
- const UniLib* unilib);
-
- // Evaluates an n-gram linear regression model, and tests against the
- // threshold. Returns true in case of a positive classification. The caller
- // may also optionally query the score.
- bool Eval(const UnicodeText& text, float* score = nullptr) const;
-
- // Exposed for testing only.
- static uint64 GetNumSkipGrams(int num_tokens, int max_ngram_length,
- int max_skips);
-
- private:
- NGramModel(const NGramLinearRegressionModel* model,
- const Tokenizer* tokenizer, const UniLib* unilib);
-
- // Returns the (begin,end] range of n-grams where the first hashed token
- // matches the given value.
- std::pair<int, int> GetFirstTokenMatches(uint32 token_hash) const;
-
- // Returns whether a given n-gram matches the token stream.
- bool IsNGramMatch(const uint32* tokens, size_t num_tokens,
- const uint32* ngram_tokens, size_t num_ngram_tokens,
- int max_skips) const;
-
- const NGramLinearRegressionModel* model_;
- const Tokenizer* tokenizer_;
- std::unique_ptr<Tokenizer> owned_tokenizer_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ACTIONS_NGRAM_MODEL_H_
diff --git a/actions/test_data/actions_suggestions_test.default.model b/actions/test_data/actions_suggestions_test.default.model
deleted file mode 100644
index 60f10e6..0000000
--- a/actions/test_data/actions_suggestions_test.default.model
+++ /dev/null
Binary files differ
diff --git a/actions/test_data/actions_suggestions_test.hashgram.model b/actions/test_data/actions_suggestions_test.hashgram.model
deleted file mode 100644
index cdc6bdc..0000000
--- a/actions/test_data/actions_suggestions_test.hashgram.model
+++ /dev/null
Binary files differ
diff --git a/actions/test_data/actions_suggestions_test.model b/actions/test_data/actions_suggestions_test.model
deleted file mode 100644
index 6cec2b7..0000000
--- a/actions/test_data/actions_suggestions_test.model
+++ /dev/null
Binary files differ
diff --git a/actions/test_utils.cc b/actions/test_utils.cc
deleted file mode 100644
index 187aa67..0000000
--- a/actions/test_utils.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/test_utils.h"
-
-namespace libtextclassifier3 {
-
-std::string TestEntityDataSchema() {
- // Create fake entity data schema meta data.
- // Cannot use object oriented API here as that is not available for the
- // reflection schema.
- flatbuffers::FlatBufferBuilder schema_builder;
- std::vector<flatbuffers::Offset<reflection::Field>> fields = {
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("greeting"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::String),
- /*id=*/0,
- /*offset=*/4),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("location"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::String),
- /*id=*/1,
- /*offset=*/6),
- reflection::CreateField(
- schema_builder,
- /*name=*/schema_builder.CreateString("person"),
- /*type=*/
- reflection::CreateType(schema_builder,
- /*base_type=*/reflection::String),
- /*id=*/2,
- /*offset=*/8)};
- std::vector<flatbuffers::Offset<reflection::Enum>> enums;
- std::vector<flatbuffers::Offset<reflection::Object>> objects = {
- reflection::CreateObject(
- schema_builder,
- /*name=*/schema_builder.CreateString("EntityData"),
- /*fields=*/
- schema_builder.CreateVectorOfSortedTables(&fields))};
- schema_builder.Finish(reflection::CreateSchema(
- schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
- schema_builder.CreateVectorOfSortedTables(&enums),
- /*(unused) file_ident=*/0,
- /*(unused) file_ext=*/0,
- /*root_table*/ objects[0]));
-
- return std::string(
- reinterpret_cast<const char*>(schema_builder.GetBufferPointer()),
- schema_builder.GetSize());
-}
-
-void SetTestEntityDataSchema(ActionsModelT* test_model) {
- const std::string serialized_schema = TestEntityDataSchema();
-
- test_model->actions_entity_data_schema.assign(
- serialized_schema.data(),
- serialized_schema.data() + serialized_schema.size());
-}
-
-} // namespace libtextclassifier3
diff --git a/actions/test_utils.h b/actions/test_utils.h
deleted file mode 100644
index 618523c..0000000
--- a/actions/test_utils.h
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ACTIONS_TEST_UTILS_H_
-#define LIBTEXTCLASSIFIER_ACTIONS_TEST_UTILS_H_
-
-#include <string>
-#include "actions/actions_model_generated.h"
-#include "utils/flatbuffers.h"
-
-namespace libtextclassifier3 {
-
-// Create test entity data schema.
-std::string TestEntityDataSchema();
-void SetTestEntityDataSchema(ActionsModelT* test_model);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ACTIONS_TEST_UTILS_H_
diff --git a/actions/types.h b/actions/types.h
deleted file mode 100644
index 212cfda..0000000
--- a/actions/types.h
+++ /dev/null
@@ -1,145 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ACTIONS_TYPES_H_
-#define LIBTEXTCLASSIFIER_ACTIONS_TYPES_H_
-
-#include <map>
-#include <string>
-#include <vector>
-
-#include "actions/actions-entity-data_generated.h"
-#include "annotator/types.h"
-#include "utils/flatbuffers.h"
-
-namespace libtextclassifier3 {
-
-// A text span in the conversation.
-struct MessageTextSpan {
- // The referenced message.
- // -1 if not referencing a particular message in the provided input.
- int message_index;
-
- // The span within the reference message.
- // (-1, -1) if not referencing a particular location.
- CodepointSpan span;
-
- // The span text.
- std::string text;
-
- explicit MessageTextSpan()
- : message_index(kInvalidIndex), span({kInvalidIndex, kInvalidIndex}) {}
- MessageTextSpan(const int message_index, const CodepointSpan span,
- const std::string& text)
- : message_index(message_index), span(span), text(text) {}
-};
-
-// An entity associated with an action.
-struct ActionSuggestionAnnotation {
- MessageTextSpan span;
- ClassificationResult entity;
-
- // Optional annotation name.
- std::string name;
-};
-
-// Action suggestion that contains a response text and the type of the response.
-struct ActionSuggestion {
- // Text of the action suggestion.
- std::string response_text;
-
- // Type of the action suggestion.
- std::string type;
-
- // Score.
- float score;
-
- // Priority score for internal conflict resolution.
- float priority_score;
-
- // The associated annotations.
- std::vector<ActionSuggestionAnnotation> annotations;
-
- // Extras information.
- std::string serialized_entity_data;
-
- const ActionsEntityData* entity_data() {
- return LoadAndVerifyFlatbuffer<ActionsEntityData>(
- serialized_entity_data.data(), serialized_entity_data.size());
- }
-};
-
-// Actions suggestions result containing meta - information and the suggested
-// actions.
-struct ActionsSuggestionsResponse {
- ActionsSuggestionsResponse()
- : sensitivity_score(-1),
- triggering_score(-1),
- output_filtered_sensitivity(false),
- output_filtered_min_triggering_score(false),
- output_filtered_low_confidence(false),
- output_filtered_locale_mismatch(false) {}
-
- // The sensitivity assessment.
- float sensitivity_score;
- float triggering_score;
-
- // Whether the output was suppressed by the sensitivity threshold.
- bool output_filtered_sensitivity;
-
- // Whether the output was suppressed by the triggering score threshold.
- bool output_filtered_min_triggering_score;
-
- // Whether the output was suppressed by the low confidence patterns.
- bool output_filtered_low_confidence;
-
- // Whether the output was suppressed due to locale mismatch.
- bool output_filtered_locale_mismatch;
-
- // The suggested actions.
- std::vector<ActionSuggestion> actions;
-};
-
-// Represents a single message in the conversation.
-struct ConversationMessage {
- // User ID distinguishing the user from other users in the conversation.
- int user_id;
-
- // Text of the message.
- std::string text;
-
- // Reference time of this message.
- int64 reference_time_ms_utc;
-
- // Timezone in which the input text was written (format as accepted by ICU).
- std::string reference_timezone;
-
- // Annotations on the text.
- std::vector<AnnotatedSpan> annotations;
-
- // Comma-separated list of BCP 47 language tags of the message.
- std::string detected_text_language_tags;
-};
-
-// Conversation between multiple users.
-struct Conversation {
- // Sequence of messages that were exchanged in the conversation.
- std::vector<ConversationMessage> messages;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ACTIONS_TYPES_H_
diff --git a/actions/zlib-utils.cc b/actions/zlib-utils.cc
deleted file mode 100644
index b1d997d..0000000
--- a/actions/zlib-utils.cc
+++ /dev/null
@@ -1,173 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/zlib-utils.h"
-
-#include <memory>
-
-#include "utils/base/logging.h"
-#include "utils/intents/zlib-utils.h"
-#include "utils/resources.h"
-
-namespace libtextclassifier3 {
-
-// Compress rule fields in the model.
-bool CompressActionsModel(ActionsModelT* model) {
- std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance();
- if (!zlib_compressor) {
- TC3_LOG(ERROR) << "Cannot compress model.";
- return false;
- }
-
- // Compress regex rules.
- if (model->rules != nullptr) {
- for (int i = 0; i < model->rules->rule.size(); i++) {
- RulesModel_::RuleT* rule = model->rules->rule[i].get();
- rule->compressed_pattern.reset(new CompressedBufferT);
- zlib_compressor->Compress(rule->pattern, rule->compressed_pattern.get());
- rule->pattern.clear();
- }
- }
-
- if (model->low_confidence_rules != nullptr) {
- for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) {
- RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get();
- if (!rule->pattern.empty()) {
- rule->compressed_pattern.reset(new CompressedBufferT);
- zlib_compressor->Compress(rule->pattern,
- rule->compressed_pattern.get());
- rule->pattern.clear();
- }
- if (!rule->output_pattern.empty()) {
- rule->compressed_output_pattern.reset(new CompressedBufferT);
- zlib_compressor->Compress(rule->pattern,
- rule->compressed_output_pattern.get());
- rule->output_pattern.clear();
- }
- }
- }
-
- if (!model->lua_actions_script.empty()) {
- model->compressed_lua_actions_script.reset(new CompressedBufferT);
- zlib_compressor->Compress(model->lua_actions_script,
- model->compressed_lua_actions_script.get());
- }
-
- if (model->ranking_options != nullptr &&
- !model->ranking_options->lua_ranking_script.empty()) {
- model->ranking_options->compressed_lua_ranking_script.reset(
- new CompressedBufferT);
- zlib_compressor->Compress(
- model->ranking_options->lua_ranking_script,
- model->ranking_options->compressed_lua_ranking_script.get());
- }
-
- // Compress resources.
- if (model->resources != nullptr) {
- CompressResources(model->resources.get());
- }
-
- // Compress intent generator.
- if (model->android_intent_options != nullptr) {
- CompressIntentModel(model->android_intent_options.get());
- }
-
- return true;
-}
-
-bool DecompressActionsModel(ActionsModelT* model) {
- std::unique_ptr<ZlibDecompressor> zlib_decompressor =
- ZlibDecompressor::Instance();
- if (!zlib_decompressor) {
- TC3_LOG(ERROR) << "Cannot initialize decompressor.";
- return false;
- }
-
- // Decompress regex rules.
- if (model->rules != nullptr) {
- for (int i = 0; i < model->rules->rule.size(); i++) {
- RulesModel_::RuleT* rule = model->rules->rule[i].get();
- if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(),
- &rule->pattern)) {
- TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
- return false;
- }
- rule->compressed_pattern.reset(nullptr);
- }
- }
-
- // Decompress low confidence rules.
- if (model->low_confidence_rules != nullptr) {
- for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) {
- RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get();
- if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(),
- &rule->pattern)) {
- TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
- return false;
- }
- if (!zlib_decompressor->MaybeDecompress(
- rule->compressed_output_pattern.get(), &rule->output_pattern)) {
- TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
- return false;
- }
- rule->compressed_pattern.reset(nullptr);
- rule->compressed_output_pattern.reset(nullptr);
- }
- }
-
- if (!zlib_decompressor->MaybeDecompress(
- model->compressed_lua_actions_script.get(),
- &model->lua_actions_script)) {
- TC3_LOG(ERROR) << "Cannot decompress actions script.";
- return false;
- }
-
- if (model->ranking_options != nullptr &&
- !zlib_decompressor->MaybeDecompress(
- model->ranking_options->compressed_lua_ranking_script.get(),
- &model->ranking_options->lua_ranking_script)) {
- TC3_LOG(ERROR) << "Cannot decompress actions script.";
- return false;
- }
-
- return true;
-}
-
-std::string CompressSerializedActionsModel(const std::string& model) {
- std::unique_ptr<ActionsModelT> unpacked_model =
- UnPackActionsModel(model.c_str());
- TC3_CHECK(unpacked_model != nullptr);
- TC3_CHECK(CompressActionsModel(unpacked_model.get()));
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder,
- ActionsModel::Pack(builder, unpacked_model.get()));
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
-bool GetUncompressedString(const flatbuffers::String* uncompressed_buffer,
- const CompressedBuffer* compressed_buffer,
- ZlibDecompressor* decompressor, std::string* out) {
- if (uncompressed_buffer == nullptr && compressed_buffer == nullptr) {
- out->clear();
- return true;
- }
-
- return decompressor->MaybeDecompressOptionallyCompressedBuffer(
- uncompressed_buffer, compressed_buffer, out);
-}
-
-} // namespace libtextclassifier3
diff --git a/actions/zlib-utils_test.cc b/actions/zlib-utils_test.cc
deleted file mode 100644
index 377f344..0000000
--- a/actions/zlib-utils_test.cc
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "actions/zlib-utils.h"
-
-#include <memory>
-
-#include "actions/actions_model_generated.h"
-#include "utils/zlib/zlib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-
-namespace {
-
-TEST(ZlibUtilsTest, CompressModel) {
- ActionsModelT model;
- constexpr char kTestPattern1[] = "this is a test pattern";
- constexpr char kTestPattern2[] = "this is a second test pattern";
- model.rules.reset(new RulesModelT);
- model.rules->rule.emplace_back(new RulesModel_::RuleT);
- model.rules->rule.back()->pattern = kTestPattern1;
- model.rules->rule.emplace_back(new RulesModel_::RuleT);
- model.rules->rule.back()->pattern = kTestPattern2;
-
- // Compress the model.
- EXPECT_TRUE(CompressActionsModel(&model));
-
- // Sanity check that uncompressed field is removed.
- EXPECT_TRUE(model.rules->rule[0]->pattern.empty());
- EXPECT_TRUE(model.rules->rule[1]->pattern.empty());
- // Pack and load the model.
- flatbuffers::FlatBufferBuilder builder;
- FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model));
- const ActionsModel* compressed_model = GetActionsModel(
- reinterpret_cast<const char*>(builder.GetBufferPointer()));
- ASSERT_TRUE(compressed_model != nullptr);
-
- // Decompress the fields again and check that they match the original.
- std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
- ASSERT_TRUE(decompressor != nullptr);
- std::string uncompressed_pattern;
- EXPECT_TRUE(decompressor->MaybeDecompress(
- compressed_model->rules()->rule()->Get(0)->compressed_pattern(),
- &uncompressed_pattern));
- EXPECT_EQ(uncompressed_pattern, kTestPattern1);
- EXPECT_TRUE(decompressor->MaybeDecompress(
- compressed_model->rules()->rule()->Get(1)->compressed_pattern(),
- &uncompressed_pattern));
- EXPECT_EQ(uncompressed_pattern, kTestPattern2);
- EXPECT_TRUE(DecompressActionsModel(&model));
- EXPECT_EQ(model.rules->rule[0]->pattern, kTestPattern1);
- EXPECT_EQ(model.rules->rule[1]->pattern, kTestPattern2);
-}
-
-} // namespace
-
-} // namespace libtextclassifier3
diff --git a/annotator/annotator.cc b/annotator/annotator.cc
deleted file mode 100644
index 53c8d8a..0000000
--- a/annotator/annotator.cc
+++ /dev/null
@@ -1,2375 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/annotator.h"
-
-#include <algorithm>
-#include <cctype>
-#include <cmath>
-#include <iterator>
-#include <numeric>
-#include <unordered_map>
-
-#include "annotator/collections.h"
-#include "annotator/model_generated.h"
-#include "annotator/types.h"
-#include "utils/base/logging.h"
-#include "utils/checksum.h"
-#include "utils/math/softmax.h"
-#include "utils/regex-match.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/zlib/zlib_regex.h"
-
-
-namespace libtextclassifier3 {
-
-using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
-
-const std::string& Annotator::kPhoneCollection =
- *[]() { return new std::string("phone"); }();
-const std::string& Annotator::kAddressCollection =
- *[]() { return new std::string("address"); }();
-const std::string& Annotator::kDateCollection =
- *[]() { return new std::string("date"); }();
-const std::string& Annotator::kUrlCollection =
- *[]() { return new std::string("url"); }();
-const std::string& Annotator::kEmailCollection =
- *[]() { return new std::string("email"); }();
-
-namespace {
-const Model* LoadAndVerifyModel(const void* addr, int size) {
- flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
- if (VerifyModelBuffer(verifier)) {
- return GetModel(addr);
- } else {
- return nullptr;
- }
-}
-
-// If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
-// create a new instance, assign ownership to owned_lib, and return it.
-const UniLib* MaybeCreateUnilib(const UniLib* lib,
- std::unique_ptr<UniLib>* owned_lib) {
- if (lib) {
- return lib;
- } else {
- owned_lib->reset(new UniLib);
- return owned_lib->get();
- }
-}
-
-// As above, but for CalendarLib.
-const CalendarLib* MaybeCreateCalendarlib(
- const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
- if (lib) {
- return lib;
- } else {
- owned_lib->reset(new CalendarLib);
- return owned_lib->get();
- }
-}
-
-} // namespace
-
-tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
- if (!selection_interpreter_) {
- TC3_CHECK(selection_executor_);
- selection_interpreter_ = selection_executor_->CreateInterpreter();
- if (!selection_interpreter_) {
- TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
- }
- }
- return selection_interpreter_.get();
-}
-
-tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
- if (!classification_interpreter_) {
- TC3_CHECK(classification_executor_);
- classification_interpreter_ = classification_executor_->CreateInterpreter();
- if (!classification_interpreter_) {
- TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
- }
- }
- return classification_interpreter_.get();
-}
-
-std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
- const char* buffer, int size, const UniLib* unilib,
- const CalendarLib* calendarlib) {
- const Model* model = LoadAndVerifyModel(buffer, size);
- if (model == nullptr) {
- return nullptr;
- }
-
- auto classifier =
- std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
- if (!classifier->IsInitialized()) {
- return nullptr;
- }
-
- return classifier;
-}
-
-
-std::unique_ptr<Annotator> Annotator::FromScopedMmap(
- std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
- const CalendarLib* calendarlib) {
- if (!(*mmap)->handle().ok()) {
- TC3_VLOG(1) << "Mmap failed.";
- return nullptr;
- }
-
- const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
- (*mmap)->handle().num_bytes());
- if (!model) {
- TC3_LOG(ERROR) << "Model verification failed.";
- return nullptr;
- }
-
- auto classifier = std::unique_ptr<Annotator>(
- new Annotator(mmap, model, unilib, calendarlib));
- if (!classifier->IsInitialized()) {
- return nullptr;
- }
-
- return classifier;
-}
-
-std::unique_ptr<Annotator> Annotator::FromScopedMmap(
- std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
- std::unique_ptr<CalendarLib> calendarlib) {
- if (!(*mmap)->handle().ok()) {
- TC3_VLOG(1) << "Mmap failed.";
- return nullptr;
- }
-
- const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
- (*mmap)->handle().num_bytes());
- if (model == nullptr) {
- TC3_LOG(ERROR) << "Model verification failed.";
- return nullptr;
- }
-
- auto classifier = std::unique_ptr<Annotator>(
- new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
- if (!classifier->IsInitialized()) {
- return nullptr;
- }
-
- return classifier;
-}
-
-std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
- int fd, int offset, int size, const UniLib* unilib,
- const CalendarLib* calendarlib) {
- std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
- return FromScopedMmap(&mmap, unilib, calendarlib);
-}
-
-std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
- int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
- std::unique_ptr<CalendarLib> calendarlib) {
- std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
- return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
-}
-
-std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
- int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
- std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
- return FromScopedMmap(&mmap, unilib, calendarlib);
-}
-
-std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
- int fd, std::unique_ptr<UniLib> unilib,
- std::unique_ptr<CalendarLib> calendarlib) {
- std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
- return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
-}
-
-std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
- const UniLib* unilib,
- const CalendarLib* calendarlib) {
- std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
- return FromScopedMmap(&mmap, unilib, calendarlib);
-}
-
-std::unique_ptr<Annotator> Annotator::FromPath(
- const std::string& path, std::unique_ptr<UniLib> unilib,
- std::unique_ptr<CalendarLib> calendarlib) {
- std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
- return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
-}
-
-Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
- const UniLib* unilib, const CalendarLib* calendarlib)
- : model_(model),
- mmap_(std::move(*mmap)),
- owned_unilib_(nullptr),
- unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
- owned_calendarlib_(nullptr),
- calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
- ValidateAndInitialize();
-}
-
-Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
- std::unique_ptr<UniLib> unilib,
- std::unique_ptr<CalendarLib> calendarlib)
- : model_(model),
- mmap_(std::move(*mmap)),
- owned_unilib_(std::move(unilib)),
- unilib_(owned_unilib_.get()),
- owned_calendarlib_(std::move(calendarlib)),
- calendarlib_(owned_calendarlib_.get()) {
- ValidateAndInitialize();
-}
-
-Annotator::Annotator(const Model* model, const UniLib* unilib,
- const CalendarLib* calendarlib)
- : model_(model),
- owned_unilib_(nullptr),
- unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
- owned_calendarlib_(nullptr),
- calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
- ValidateAndInitialize();
-}
-
-void Annotator::ValidateAndInitialize() {
- initialized_ = false;
-
- if (model_ == nullptr) {
- TC3_LOG(ERROR) << "No model specified.";
- return;
- }
-
- const bool model_enabled_for_annotation =
- (model_->triggering_options() != nullptr &&
- (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
- const bool model_enabled_for_classification =
- (model_->triggering_options() != nullptr &&
- (model_->triggering_options()->enabled_modes() &
- ModeFlag_CLASSIFICATION));
- const bool model_enabled_for_selection =
- (model_->triggering_options() != nullptr &&
- (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
-
- // Annotation requires the selection model.
- if (model_enabled_for_annotation || model_enabled_for_selection) {
- if (!model_->selection_options()) {
- TC3_LOG(ERROR) << "No selection options.";
- return;
- }
- if (!model_->selection_feature_options()) {
- TC3_LOG(ERROR) << "No selection feature options.";
- return;
- }
- if (!model_->selection_feature_options()->bounds_sensitive_features()) {
- TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
- return;
- }
- if (!model_->selection_model()) {
- TC3_LOG(ERROR) << "No selection model.";
- return;
- }
- selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
- if (!selection_executor_) {
- TC3_LOG(ERROR) << "Could not initialize selection executor.";
- return;
- }
- selection_feature_processor_.reset(
- new FeatureProcessor(model_->selection_feature_options(), unilib_));
- }
-
- // Annotation requires the classification model for conflict resolution and
- // scoring.
- // Selection requires the classification model for conflict resolution.
- if (model_enabled_for_annotation || model_enabled_for_classification ||
- model_enabled_for_selection) {
- if (!model_->classification_options()) {
- TC3_LOG(ERROR) << "No classification options.";
- return;
- }
-
- if (!model_->classification_feature_options()) {
- TC3_LOG(ERROR) << "No classification feature options.";
- return;
- }
-
- if (!model_->classification_feature_options()
- ->bounds_sensitive_features()) {
- TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
- return;
- }
- if (!model_->classification_model()) {
- TC3_LOG(ERROR) << "No clf model.";
- return;
- }
-
- classification_executor_ =
- ModelExecutor::FromBuffer(model_->classification_model());
- if (!classification_executor_) {
- TC3_LOG(ERROR) << "Could not initialize classification executor.";
- return;
- }
-
- classification_feature_processor_.reset(new FeatureProcessor(
- model_->classification_feature_options(), unilib_));
- }
-
- // The embeddings need to be specified if the model is to be used for
- // classification or selection.
- if (model_enabled_for_annotation || model_enabled_for_classification ||
- model_enabled_for_selection) {
- if (!model_->embedding_model()) {
- TC3_LOG(ERROR) << "No embedding model.";
- return;
- }
-
- // Check that the embedding size of the selection and classification model
- // matches, as they are using the same embeddings.
- if (model_enabled_for_selection &&
- (model_->selection_feature_options()->embedding_size() !=
- model_->classification_feature_options()->embedding_size() ||
- model_->selection_feature_options()->embedding_quantization_bits() !=
- model_->classification_feature_options()
- ->embedding_quantization_bits())) {
- TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
- return;
- }
-
- embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
- model_->embedding_model(),
- model_->classification_feature_options()->embedding_size(),
- model_->classification_feature_options()->embedding_quantization_bits(),
- model_->embedding_pruning_mask());
- if (!embedding_executor_) {
- TC3_LOG(ERROR) << "Could not initialize embedding executor.";
- return;
- }
- }
-
- std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
- if (model_->regex_model()) {
- if (!InitializeRegexModel(decompressor.get())) {
- TC3_LOG(ERROR) << "Could not initialize regex model.";
- return;
- }
- }
-
- if (model_->datetime_model()) {
- datetime_parser_ = DatetimeParser::Instance(
- model_->datetime_model(), *unilib_, *calendarlib_, decompressor.get());
- if (!datetime_parser_) {
- TC3_LOG(ERROR) << "Could not initialize datetime parser.";
- return;
- }
- }
-
- if (model_->output_options()) {
- if (model_->output_options()->filtered_collections_annotation()) {
- for (const auto collection :
- *model_->output_options()->filtered_collections_annotation()) {
- filtered_collections_annotation_.insert(collection->str());
- }
- }
- if (model_->output_options()->filtered_collections_classification()) {
- for (const auto collection :
- *model_->output_options()->filtered_collections_classification()) {
- filtered_collections_classification_.insert(collection->str());
- }
- }
- if (model_->output_options()->filtered_collections_selection()) {
- for (const auto collection :
- *model_->output_options()->filtered_collections_selection()) {
- filtered_collections_selection_.insert(collection->str());
- }
- }
- }
-
- if (model_->number_annotator_options() &&
- model_->number_annotator_options()->enabled()) {
- if (selection_feature_processor_ == nullptr) {
- TC3_LOG(ERROR)
- << "Could not initialize NumberAnnotator without a feature processor";
- return;
- }
-
- number_annotator_.reset(
- new NumberAnnotator(model_->number_annotator_options(),
- selection_feature_processor_.get()));
- }
-
- if (model_->duration_annotator_options() &&
- model_->duration_annotator_options()->enabled()) {
- duration_annotator_.reset(
- new DurationAnnotator(model_->duration_annotator_options(),
- selection_feature_processor_.get()));
- }
-
- if (model_->entity_data_schema()) {
- entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
- model_->entity_data_schema()->Data(),
- model_->entity_data_schema()->size());
- if (entity_data_schema_ == nullptr) {
- TC3_LOG(ERROR) << "Could not load entity data schema data.";
- return;
- }
-
- entity_data_builder_.reset(
- new ReflectiveFlatbufferBuilder(entity_data_schema_));
- } else {
- entity_data_schema_ = nullptr;
- entity_data_builder_ = nullptr;
- }
-
- if (model_->triggering_locales() &&
- !ParseLocales(model_->triggering_locales()->c_str(),
- &model_triggering_locales_)) {
- TC3_LOG(ERROR) << "Could not parse model supported locales.";
- return;
- }
-
- if (model_->triggering_options() != nullptr &&
- model_->triggering_options()->locales() != nullptr &&
- !ParseLocales(model_->triggering_options()->locales()->c_str(),
- &ml_model_triggering_locales_)) {
- TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
- return;
- }
-
- if (model_->triggering_options() != nullptr &&
- model_->triggering_options()->dictionary_locales() != nullptr &&
- !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
- &dictionary_locales_)) {
- TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
- return;
- }
-
- initialized_ = true;
-}
-
-bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
- if (!model_->regex_model()->patterns()) {
- return true;
- }
-
- // Initialize pattern recognizers.
- int regex_pattern_id = 0;
- for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
- std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
- UncompressMakeRegexPattern(
- *unilib_, regex_pattern->pattern(),
- regex_pattern->compressed_pattern(),
- model_->regex_model()->lazy_regex_compilation(), decompressor);
- if (!compiled_pattern) {
- TC3_LOG(INFO) << "Failed to load regex pattern";
- return false;
- }
-
- if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
- annotation_regex_patterns_.push_back(regex_pattern_id);
- }
- if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
- classification_regex_patterns_.push_back(regex_pattern_id);
- }
- if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
- selection_regex_patterns_.push_back(regex_pattern_id);
- }
- regex_patterns_.push_back({
- regex_pattern,
- std::move(compiled_pattern),
- });
- ++regex_pattern_id;
- }
-
- return true;
-}
-
-bool Annotator::InitializeKnowledgeEngine(
- const std::string& serialized_config) {
- std::unique_ptr<KnowledgeEngine> knowledge_engine(
- new KnowledgeEngine(unilib_));
- if (!knowledge_engine->Initialize(serialized_config)) {
- TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
- return false;
- }
- knowledge_engine_ = std::move(knowledge_engine);
- return true;
-}
-
-bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
- std::unique_ptr<ContactEngine> contact_engine(
- new ContactEngine(selection_feature_processor_.get(), unilib_));
- if (!contact_engine->Initialize(serialized_config)) {
- TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
- return false;
- }
- contact_engine_ = std::move(contact_engine);
- return true;
-}
-
-bool Annotator::InitializeInstalledAppEngine(
- const std::string& serialized_config) {
- std::unique_ptr<InstalledAppEngine> installed_app_engine(
- new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
- if (!installed_app_engine->Initialize(serialized_config)) {
- TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
- return false;
- }
- installed_app_engine_ = std::move(installed_app_engine);
- return true;
-}
-
-namespace {
-
-int CountDigits(const std::string& str, CodepointSpan selection_indices) {
- int count = 0;
- int i = 0;
- const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
- for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
- if (i >= selection_indices.first && i < selection_indices.second &&
- isdigit(*it)) {
- ++count;
- }
- }
- return count;
-}
-
-} // namespace
-
-namespace internal {
-// Helper function, which if the initial 'span' contains only white-spaces,
-// moves the selection to a single-codepoint selection on a left or right side
-// of this space.
-CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
- const UnicodeText& context_unicode,
- const UniLib& unilib) {
- TC3_CHECK(ValidNonEmptySpan(span));
-
- UnicodeText::const_iterator it;
-
- // Check that the current selection is all whitespaces.
- it = context_unicode.begin();
- std::advance(it, span.first);
- for (int i = 0; i < (span.second - span.first); ++i, ++it) {
- if (!unilib.IsWhitespace(*it)) {
- return span;
- }
- }
-
- CodepointSpan result;
-
- // Try moving left.
- result = span;
- it = context_unicode.begin();
- std::advance(it, span.first);
- while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
- --result.first;
- --it;
- }
- result.second = result.first + 1;
- if (!unilib.IsWhitespace(*it)) {
- return result;
- }
-
- // If moving left didn't find a non-whitespace character, just return the
- // original span.
- return span;
-}
-} // namespace internal
-
-bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
- return !span.classification.empty() &&
- filtered_collections_annotation_.find(
- span.classification[0].collection) !=
- filtered_collections_annotation_.end();
-}
-
-bool Annotator::FilteredForClassification(
- const ClassificationResult& classification) const {
- return filtered_collections_classification_.find(classification.collection) !=
- filtered_collections_classification_.end();
-}
-
-bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
- return !span.classification.empty() &&
- filtered_collections_selection_.find(
- span.classification[0].collection) !=
- filtered_collections_selection_.end();
-}
-
-namespace {
-inline bool ClassifiedAsOther(
- const std::vector<ClassificationResult>& classification) {
- return !classification.empty() &&
- classification[0].collection == Collections::Other();
-}
-
-float GetPriorityScore(
- const std::vector<ClassificationResult>& classification) {
- if (!classification.empty() && !ClassifiedAsOther(classification)) {
- return classification[0].priority_score;
- } else {
- return -1.0;
- }
-}
-} // namespace
-
-bool Annotator::VerifyRegexMatchCandidate(
- const std::string& context, const VerificationOptions* verification_options,
- const std::string& match, const UniLib::RegexMatcher* matcher) const {
- if (verification_options == nullptr) {
- return true;
- }
- if (verification_options->verify_luhn_checksum() &&
- !VerifyLuhnChecksum(match)) {
- return false;
- }
- const int lua_verifier = verification_options->lua_verifier();
- if (lua_verifier >= 0) {
- if (model_->regex_model()->lua_verifier() == nullptr ||
- lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
- TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
- return false;
- }
- return VerifyMatch(
- context, matcher,
- model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
- }
- return true;
-}
-
-CodepointSpan Annotator::SuggestSelection(
- const std::string& context, CodepointSpan click_indices,
- const SelectionOptions& options) const {
- CodepointSpan original_click_indices = click_indices;
- if (!initialized_) {
- TC3_LOG(ERROR) << "Not initialized";
- return original_click_indices;
- }
- if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
- return original_click_indices;
- }
-
- std::vector<Locale> detected_text_language_tags;
- if (!ParseLocales(options.detected_text_language_tags,
- &detected_text_language_tags)) {
- TC3_LOG(WARNING)
- << "Failed to parse the detected_text_language_tags in options: "
- << options.detected_text_language_tags;
- }
- if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
- model_triggering_locales_,
- /*default_value=*/true)) {
- return original_click_indices;
- }
-
- const UnicodeText context_unicode = UTF8ToUnicodeText(context,
- /*do_copy=*/false);
-
- if (!context_unicode.is_valid()) {
- return original_click_indices;
- }
-
- const int context_codepoint_size = context_unicode.size_codepoints();
-
- if (click_indices.first < 0 || click_indices.second < 0 ||
- click_indices.first >= context_codepoint_size ||
- click_indices.second > context_codepoint_size ||
- click_indices.first >= click_indices.second) {
- TC3_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
- << click_indices.first << " " << click_indices.second;
- return original_click_indices;
- }
-
- if (model_->snap_whitespace_selections()) {
- // We want to expand a purely white-space selection to a multi-selection it
- // would've been part of. But with this feature disabled we would do a no-
- // op, because no token is found. Therefore, we need to modify the
- // 'click_indices' a bit to include a part of the token, so that the click-
- // finding logic finds the clicked token correctly. This modification is
- // done by the following function. Note, that it's enough to check the left
- // side of the current selection, because if the white-space is a part of a
- // multi-selection, necessarily both tokens - on the left and the right
- // sides need to be selected. Thus snapping only to the left is sufficient
- // (there's a check at the bottom that makes sure that if we snap to the
- // left token but the result does not contain the initial white-space,
- // returns the original indices).
- click_indices = internal::SnapLeftIfWhitespaceSelection(
- click_indices, context_unicode, *unilib_);
- }
-
- std::vector<AnnotatedSpan> candidates;
- InterpreterManager interpreter_manager(selection_executor_.get(),
- classification_executor_.get());
- std::vector<Token> tokens;
- if (!ModelSuggestSelection(context_unicode, click_indices,
- detected_text_language_tags, &interpreter_manager,
- &tokens, &candidates)) {
- TC3_LOG(ERROR) << "Model suggest selection failed.";
- return original_click_indices;
- }
- if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates,
- /*is_serialized_entity_data_enabled=*/false)) {
- TC3_LOG(ERROR) << "Regex suggest selection failed.";
- return original_click_indices;
- }
- if (!DatetimeChunk(
- UTF8ToUnicodeText(context, /*do_copy=*/false),
- /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
- options.locales, ModeFlag_SELECTION, options.annotation_usecase,
- /*is_serialized_entity_data_enabled=*/false, &candidates)) {
- TC3_LOG(ERROR) << "Datetime suggest selection failed.";
- return original_click_indices;
- }
- if (knowledge_engine_ != nullptr &&
- !knowledge_engine_->Chunk(context, &candidates)) {
- TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
- return original_click_indices;
- }
- if (contact_engine_ != nullptr &&
- !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
- TC3_LOG(ERROR) << "Contact suggest selection failed.";
- return original_click_indices;
- }
- if (installed_app_engine_ != nullptr &&
- !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
- TC3_LOG(ERROR) << "Installed app suggest selection failed.";
- return original_click_indices;
- }
- if (number_annotator_ != nullptr &&
- !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
- &candidates)) {
- TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
- return original_click_indices;
- }
- if (duration_annotator_ != nullptr &&
- !duration_annotator_->FindAll(context_unicode, tokens,
- options.annotation_usecase, &candidates)) {
- TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
- return original_click_indices;
- }
-
- // Sort candidates according to their position in the input, so that the next
- // code can assume that any connected component of overlapping spans forms a
- // contiguous block.
- std::sort(candidates.begin(), candidates.end(),
- [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
- return a.span.first < b.span.first;
- });
-
- std::vector<int> candidate_indices;
- if (!ResolveConflicts(candidates, context, tokens,
- detected_text_language_tags, options.annotation_usecase,
- &interpreter_manager, &candidate_indices)) {
- TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
- return original_click_indices;
- }
-
- std::sort(candidate_indices.begin(), candidate_indices.end(),
- [&candidates](int a, int b) {
- return GetPriorityScore(candidates[a].classification) >
- GetPriorityScore(candidates[b].classification);
- });
-
- for (const int i : candidate_indices) {
- if (SpansOverlap(candidates[i].span, click_indices) &&
- SpansOverlap(candidates[i].span, original_click_indices)) {
- // Run model classification if not present but requested and there's a
- // classification collection filter specified.
- if (candidates[i].classification.empty() &&
- model_->selection_options()->always_classify_suggested_selection() &&
- !filtered_collections_selection_.empty()) {
- if (!ModelClassifyText(context, detected_text_language_tags,
- candidates[i].span, &interpreter_manager,
- /*embedding_cache=*/nullptr,
- &candidates[i].classification)) {
- return original_click_indices;
- }
- }
-
- // Ignore if span classification is filtered.
- if (FilteredForSelection(candidates[i])) {
- return original_click_indices;
- }
-
- return candidates[i].span;
- }
- }
-
- return original_click_indices;
-}
-
-namespace {
-// Helper function that returns the index of the first candidate that
-// transitively does not overlap with the candidate on 'start_index'. If the end
-// of 'candidates' is reached, it returns the index that points right behind the
-// array.
-int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
- int start_index) {
- int first_non_overlapping = start_index + 1;
- CodepointSpan conflicting_span = candidates[start_index].span;
- while (
- first_non_overlapping < candidates.size() &&
- SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
- // Grow the span to include the current one.
- conflicting_span.second = std::max(
- conflicting_span.second, candidates[first_non_overlapping].span.second);
-
- ++first_non_overlapping;
- }
- return first_non_overlapping;
-}
-} // namespace
-
-bool Annotator::ResolveConflicts(
- const std::vector<AnnotatedSpan>& candidates, const std::string& context,
- const std::vector<Token>& cached_tokens,
- const std::vector<Locale>& detected_text_language_tags,
- AnnotationUsecase annotation_usecase,
- InterpreterManager* interpreter_manager, std::vector<int>* result) const {
- result->clear();
- result->reserve(candidates.size());
- for (int i = 0; i < candidates.size();) {
- int first_non_overlapping =
- FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
-
- const bool conflict_found = first_non_overlapping != (i + 1);
- if (conflict_found) {
- std::vector<int> candidate_indices;
- if (!ResolveConflict(context, cached_tokens, candidates,
- detected_text_language_tags, i,
- first_non_overlapping, annotation_usecase,
- interpreter_manager, &candidate_indices)) {
- return false;
- }
- result->insert(result->end(), candidate_indices.begin(),
- candidate_indices.end());
- } else {
- result->push_back(i);
- }
-
- // Skip over the whole conflicting group/go to next candidate.
- i = first_non_overlapping;
- }
- return true;
-}
-
-namespace {
-// Returns true, if the given two sources do conflict in given annotation
-// usecase.
-// - In SMART usecase, all sources do conflict, because there's only 1 possible
-// annotation for a given span.
-// - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
-// and duration), while others not (e.g. duration and number).
-bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
- const AnnotatedSpan::Source source1,
- const AnnotatedSpan::Source source2) {
- uint32 source_mask =
- (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
-
- switch (annotation_usecase) {
- case AnnotationUsecase_ANNOTATION_USECASE_SMART:
- // In the SMART mode, all annotations conflict.
- return true;
-
- case AnnotationUsecase_ANNOTATION_USECASE_RAW:
- // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
- // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
- // hours" (duration).
- if ((source_mask &
- (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
- (source_mask &
- (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
- return false;
- }
-
- // A KNOWLEDGE entity does not conflict with anything.
- if ((source_mask &
- (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
- return false;
- }
-
- // Entities from other sources can conflict.
- return true;
- }
-}
-} // namespace
-
-bool Annotator::ResolveConflict(
- const std::string& context, const std::vector<Token>& cached_tokens,
- const std::vector<AnnotatedSpan>& candidates,
- const std::vector<Locale>& detected_text_language_tags, int start_index,
- int end_index, AnnotationUsecase annotation_usecase,
- InterpreterManager* interpreter_manager,
- std::vector<int>* chosen_indices) const {
- std::vector<int> conflicting_indices;
- std::unordered_map<int, float> scores;
- for (int i = start_index; i < end_index; ++i) {
- conflicting_indices.push_back(i);
- if (!candidates[i].classification.empty()) {
- scores[i] = GetPriorityScore(candidates[i].classification);
- continue;
- }
-
- // OPTIMIZATION: So that we don't have to classify all the ML model
- // spans apriori, we wait until we get here, when they conflict with
- // something and we need the actual classification scores. So if the
- // candidate conflicts and comes from the model, we need to run a
- // classification to determine its priority:
- std::vector<ClassificationResult> classification;
- if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
- candidates[i].span, interpreter_manager,
- /*embedding_cache=*/nullptr, &classification)) {
- return false;
- }
-
- if (!classification.empty()) {
- scores[i] = GetPriorityScore(classification);
- }
- }
-
- std::sort(conflicting_indices.begin(), conflicting_indices.end(),
- [&scores](int i, int j) { return scores[i] > scores[j]; });
-
- // Here we keep a set of indices that were chosen, per-source, to enable
- // effective computation.
- std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
- chosen_indices_for_source_map;
-
- // Greedily place the candidates if they don't conflict with the already
- // placed ones.
- for (int i = 0; i < conflicting_indices.size(); ++i) {
- const int considered_candidate = conflicting_indices[i];
-
- // See if there is a conflict between the candidate and all already placed
- // candidates.
- bool conflict = false;
- SortedIntSet* chosen_indices_for_source_ptr = nullptr;
- for (auto& source_set_pair : chosen_indices_for_source_map) {
- if (source_set_pair.first == candidates[considered_candidate].source) {
- chosen_indices_for_source_ptr = &source_set_pair.second;
- }
-
- if (DoSourcesConflict(annotation_usecase, source_set_pair.first,
- candidates[considered_candidate].source) &&
- DoesCandidateConflict(considered_candidate, candidates,
- source_set_pair.second)) {
- conflict = true;
- break;
- }
- }
-
- // Skip the candidate if a conflict was found.
- if (conflict) {
- continue;
- }
-
- // If the set of indices for the current source doesn't exist yet,
- // initialize it.
- if (chosen_indices_for_source_ptr == nullptr) {
- SortedIntSet new_set([&candidates](int a, int b) {
- return candidates[a].span.first < candidates[b].span.first;
- });
- chosen_indices_for_source_map[candidates[considered_candidate].source] =
- std::move(new_set);
- chosen_indices_for_source_ptr =
- &chosen_indices_for_source_map[candidates[considered_candidate]
- .source];
- }
-
- // Place the candidate to the output and to the per-source conflict set.
- chosen_indices->push_back(considered_candidate);
- chosen_indices_for_source_ptr->insert(considered_candidate);
- }
-
- std::sort(chosen_indices->begin(), chosen_indices->end());
-
- return true;
-}
-
-bool Annotator::ModelSuggestSelection(
- const UnicodeText& context_unicode, CodepointSpan click_indices,
- const std::vector<Locale>& detected_text_language_tags,
- InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
- std::vector<AnnotatedSpan>* result) const {
- if (model_->triggering_options() == nullptr ||
- !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
- return true;
- }
-
- if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
- ml_model_triggering_locales_,
- /*default_value=*/true)) {
- return true;
- }
-
- int click_pos;
- *tokens = selection_feature_processor_->Tokenize(context_unicode);
- selection_feature_processor_->RetokenizeAndFindClick(
- context_unicode, click_indices,
- selection_feature_processor_->GetOptions()->only_use_line_with_click(),
- tokens, &click_pos);
- if (click_pos == kInvalidIndex) {
- TC3_VLOG(1) << "Could not calculate the click position.";
- return false;
- }
-
- const int symmetry_context_size =
- model_->selection_options()->symmetry_context_size();
- const FeatureProcessorOptions_::BoundsSensitiveFeatures*
- bounds_sensitive_features = selection_feature_processor_->GetOptions()
- ->bounds_sensitive_features();
-
- // The symmetry context span is the clicked token with symmetry_context_size
- // tokens on either side.
- const TokenSpan symmetry_context_span = IntersectTokenSpans(
- ExpandTokenSpan(SingleTokenSpan(click_pos),
- /*num_tokens_left=*/symmetry_context_size,
- /*num_tokens_right=*/symmetry_context_size),
- {0, tokens->size()});
-
- // Compute the extraction span based on the model type.
- TokenSpan extraction_span;
- if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
- // The extraction span is the symmetry context span expanded to include
- // max_selection_span tokens on either side, which is how far a selection
- // can stretch from the click, plus a relevant number of tokens outside of
- // the bounds of the selection.
- const int max_selection_span =
- selection_feature_processor_->GetOptions()->max_selection_span();
- extraction_span =
- ExpandTokenSpan(symmetry_context_span,
- /*num_tokens_left=*/max_selection_span +
- bounds_sensitive_features->num_tokens_before(),
- /*num_tokens_right=*/max_selection_span +
- bounds_sensitive_features->num_tokens_after());
- } else {
- // The extraction span is the symmetry context span expanded to include
- // context_size tokens on either side.
- const int context_size =
- selection_feature_processor_->GetOptions()->context_size();
- extraction_span = ExpandTokenSpan(symmetry_context_span,
- /*num_tokens_left=*/context_size,
- /*num_tokens_right=*/context_size);
- }
- extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
-
- if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
- *tokens, extraction_span)) {
- return true;
- }
-
- std::unique_ptr<CachedFeatures> cached_features;
- if (!selection_feature_processor_->ExtractFeatures(
- *tokens, extraction_span,
- /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
- embedding_executor_.get(),
- /*embedding_cache=*/nullptr,
- selection_feature_processor_->EmbeddingSize() +
- selection_feature_processor_->DenseFeaturesCount(),
- &cached_features)) {
- TC3_LOG(ERROR) << "Could not extract features.";
- return false;
- }
-
- // Produce selection model candidates.
- std::vector<TokenSpan> chunks;
- if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
- interpreter_manager->SelectionInterpreter(), *cached_features,
- &chunks)) {
- TC3_LOG(ERROR) << "Could not chunk.";
- return false;
- }
-
- for (const TokenSpan& chunk : chunks) {
- AnnotatedSpan candidate;
- candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
- context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
- if (model_->selection_options()->strip_unpaired_brackets()) {
- candidate.span =
- StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
- }
-
- // Only output non-empty spans.
- if (candidate.span.first != candidate.span.second) {
- result->push_back(candidate);
- }
- }
- return true;
-}
-
-bool Annotator::ModelClassifyText(
- const std::string& context,
- const std::vector<Locale>& detected_text_language_tags,
- CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
- FeatureProcessor::EmbeddingCache* embedding_cache,
- std::vector<ClassificationResult>* classification_results) const {
- return ModelClassifyText(context, {}, detected_text_language_tags,
- selection_indices, interpreter_manager,
- embedding_cache, classification_results);
-}
-
-namespace internal {
-std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
- CodepointSpan selection_indices,
- TokenSpan tokens_around_selection_to_copy) {
- const auto first_selection_token = std::upper_bound(
- cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
- [](int selection_start, const Token& token) {
- return selection_start < token.end;
- });
- const auto last_selection_token = std::lower_bound(
- cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
- [](const Token& token, int selection_end) {
- return token.start < selection_end;
- });
-
- const int64 first_token = std::max(
- static_cast<int64>(0),
- static_cast<int64>((first_selection_token - cached_tokens.begin()) -
- tokens_around_selection_to_copy.first));
- const int64 last_token = std::min(
- static_cast<int64>(cached_tokens.size()),
- static_cast<int64>((last_selection_token - cached_tokens.begin()) +
- tokens_around_selection_to_copy.second));
-
- std::vector<Token> tokens;
- tokens.reserve(last_token - first_token);
- for (int i = first_token; i < last_token; ++i) {
- tokens.push_back(cached_tokens[i]);
- }
- return tokens;
-}
-} // namespace internal
-
-TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
- const FeatureProcessorOptions_::BoundsSensitiveFeatures*
- bounds_sensitive_features =
- classification_feature_processor_->GetOptions()
- ->bounds_sensitive_features();
- if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
- // The extraction span is the selection span expanded to include a relevant
- // number of tokens outside of the bounds of the selection.
- return {bounds_sensitive_features->num_tokens_before(),
- bounds_sensitive_features->num_tokens_after()};
- } else {
- // The extraction span is the clicked token with context_size tokens on
- // either side.
- const int context_size =
- selection_feature_processor_->GetOptions()->context_size();
- return {context_size, context_size};
- }
-}
-
-namespace {
-// Sorts the classification results from high score to low score.
-void SortClassificationResults(
- std::vector<ClassificationResult>* classification_results) {
- std::sort(classification_results->begin(), classification_results->end(),
- [](const ClassificationResult& a, const ClassificationResult& b) {
- return a.score > b.score;
- });
-}
-} // namespace
-
-bool Annotator::ModelClassifyText(
- const std::string& context, const std::vector<Token>& cached_tokens,
- const std::vector<Locale>& detected_text_language_tags,
- CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
- FeatureProcessor::EmbeddingCache* embedding_cache,
- std::vector<ClassificationResult>* classification_results) const {
- std::vector<Token> tokens;
- return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
- selection_indices, interpreter_manager,
- embedding_cache, classification_results, &tokens);
-}
-
-bool Annotator::ModelClassifyText(
- const std::string& context, const std::vector<Token>& cached_tokens,
- const std::vector<Locale>& detected_text_language_tags,
- CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
- FeatureProcessor::EmbeddingCache* embedding_cache,
- std::vector<ClassificationResult>* classification_results,
- std::vector<Token>* tokens) const {
- if (model_->triggering_options() == nullptr ||
- !(model_->triggering_options()->enabled_modes() &
- ModeFlag_CLASSIFICATION)) {
- return true;
- }
-
- if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
- ml_model_triggering_locales_,
- /*default_value=*/true)) {
- return true;
- }
-
- if (cached_tokens.empty()) {
- *tokens = classification_feature_processor_->Tokenize(context);
- } else {
- *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
- ClassifyTextUpperBoundNeededTokens());
- }
-
- int click_pos;
- classification_feature_processor_->RetokenizeAndFindClick(
- context, selection_indices,
- classification_feature_processor_->GetOptions()
- ->only_use_line_with_click(),
- tokens, &click_pos);
- const TokenSpan selection_token_span =
- CodepointSpanToTokenSpan(*tokens, selection_indices);
- const int selection_num_tokens = TokenSpanSize(selection_token_span);
- if (model_->classification_options()->max_num_tokens() > 0 &&
- model_->classification_options()->max_num_tokens() <
- selection_num_tokens) {
- *classification_results = {{Collections::Other(), 1.0}};
- return true;
- }
-
- const FeatureProcessorOptions_::BoundsSensitiveFeatures*
- bounds_sensitive_features =
- classification_feature_processor_->GetOptions()
- ->bounds_sensitive_features();
- if (selection_token_span.first == kInvalidIndex ||
- selection_token_span.second == kInvalidIndex) {
- TC3_LOG(ERROR) << "Could not determine span.";
- return false;
- }
-
- // Compute the extraction span based on the model type.
- TokenSpan extraction_span;
- if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
- // The extraction span is the selection span expanded to include a relevant
- // number of tokens outside of the bounds of the selection.
- extraction_span = ExpandTokenSpan(
- selection_token_span,
- /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
- /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
- } else {
- if (click_pos == kInvalidIndex) {
- TC3_LOG(ERROR) << "Couldn't choose a click position.";
- return false;
- }
- // The extraction span is the clicked token with context_size tokens on
- // either side.
- const int context_size =
- classification_feature_processor_->GetOptions()->context_size();
- extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
- /*num_tokens_left=*/context_size,
- /*num_tokens_right=*/context_size);
- }
- extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
-
- if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
- *tokens, extraction_span)) {
- *classification_results = {{Collections::Other(), 1.0}};
- return true;
- }
-
- std::unique_ptr<CachedFeatures> cached_features;
- if (!classification_feature_processor_->ExtractFeatures(
- *tokens, extraction_span, selection_indices,
- embedding_executor_.get(), embedding_cache,
- classification_feature_processor_->EmbeddingSize() +
- classification_feature_processor_->DenseFeaturesCount(),
- &cached_features)) {
- TC3_LOG(ERROR) << "Could not extract features.";
- return false;
- }
-
- std::vector<float> features;
- features.reserve(cached_features->OutputFeaturesSize());
- if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
- cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
- &features);
- } else {
- cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
- }
-
- TensorView<float> logits = classification_executor_->ComputeLogits(
- TensorView<float>(features.data(),
- {1, static_cast<int>(features.size())}),
- interpreter_manager->ClassificationInterpreter());
- if (!logits.is_valid()) {
- TC3_LOG(ERROR) << "Couldn't compute logits.";
- return false;
- }
-
- if (logits.dims() != 2 || logits.dim(0) != 1 ||
- logits.dim(1) != classification_feature_processor_->NumCollections()) {
- TC3_LOG(ERROR) << "Mismatching output";
- return false;
- }
-
- const std::vector<float> scores =
- ComputeSoftmax(logits.data(), logits.dim(1));
-
- if (scores.empty()) {
- *classification_results = {{Collections::Other(), 1.0}};
- return true;
- }
-
- const int best_score_index =
- std::max_element(scores.begin(), scores.end()) - scores.begin();
- const std::string top_collection =
- classification_feature_processor_->LabelToCollection(best_score_index);
-
- // Sanity checks.
- if (top_collection == Collections::Phone()) {
- const int digit_count = CountDigits(context, selection_indices);
- if (digit_count <
- model_->classification_options()->phone_min_num_digits() ||
- digit_count >
- model_->classification_options()->phone_max_num_digits()) {
- *classification_results = {{Collections::Other(), 1.0}};
- return true;
- }
- } else if (top_collection == Collections::Address()) {
- if (selection_num_tokens <
- model_->classification_options()->address_min_num_tokens()) {
- *classification_results = {{Collections::Other(), 1.0}};
- return true;
- }
- } else if (top_collection == Collections::Dictionary()) {
- if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
- dictionary_locales_,
- /*default_value=*/false)) {
- *classification_results = {{Collections::Other(), 1.0}};
- return true;
- }
- }
-
- *classification_results = {{top_collection, 1.0, scores[best_score_index]}};
- return true;
-}
-
-bool Annotator::RegexClassifyText(
- const std::string& context, CodepointSpan selection_indices,
- std::vector<ClassificationResult>* classification_result) const {
- const std::string selection_text =
- UTF8ToUnicodeText(context, /*do_copy=*/false)
- .UTF8Substring(selection_indices.first, selection_indices.second);
- const UnicodeText selection_text_unicode(
- UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
-
- // Check whether any of the regular expressions match.
- for (const int pattern_id : classification_regex_patterns_) {
- const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
- const std::unique_ptr<UniLib::RegexMatcher> matcher =
- regex_pattern.pattern->Matcher(selection_text_unicode);
- int status = UniLib::RegexMatcher::kNoError;
- bool matches;
- if (regex_pattern.config->use_approximate_matching()) {
- matches = matcher->ApproximatelyMatches(&status);
- } else {
- matches = matcher->Matches(&status);
- }
- if (status != UniLib::RegexMatcher::kNoError) {
- return false;
- }
- if (matches && VerifyRegexMatchCandidate(
- context, regex_pattern.config->verification_options(),
- selection_text, matcher.get())) {
- classification_result->push_back(
- {regex_pattern.config->collection_name()->str(),
- regex_pattern.config->target_classification_score(),
- regex_pattern.config->priority_score()});
- if (!SerializedEntityDataFromRegexMatch(
- regex_pattern.config, matcher.get(),
- &classification_result->back().serialized_entity_data)) {
- TC3_LOG(ERROR) << "Could not get entity data.";
- return false;
- }
- }
- }
-
- return true;
-}
-
-namespace {
-std::string PickCollectionForDatetime(
- const DatetimeParseResult& datetime_parse_result) {
- switch (datetime_parse_result.granularity) {
- case GRANULARITY_HOUR:
- case GRANULARITY_MINUTE:
- case GRANULARITY_SECOND:
- return Collections::DateTime();
- default:
- return Collections::Date();
- }
-}
-
-std::string CreateDatetimeSerializedEntityData(
- const DatetimeParseResult& parse_result) {
- EntityDataT entity_data;
- entity_data.datetime.reset(new EntityData_::DatetimeT());
- entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
- entity_data.datetime->granularity =
- static_cast<EntityData_::Datetime_::Granularity>(
- parse_result.granularity);
-
- flatbuffers::FlatBufferBuilder builder;
- FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-} // namespace
-
-bool Annotator::DatetimeClassifyText(
- const std::string& context, CodepointSpan selection_indices,
- const ClassificationOptions& options,
- std::vector<ClassificationResult>* classification_results) const {
- if (!datetime_parser_) {
- return false;
- }
-
- const std::string selection_text =
- UTF8ToUnicodeText(context, /*do_copy=*/false)
- .UTF8Substring(selection_indices.first, selection_indices.second);
-
- std::vector<DatetimeParseResultSpan> datetime_spans;
- if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
- options.reference_timezone, options.locales,
- ModeFlag_CLASSIFICATION,
- options.annotation_usecase,
- /*anchor_start_end=*/true, &datetime_spans)) {
- TC3_LOG(ERROR) << "Error during parsing datetime.";
- return false;
- }
- for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
- // Only consider the result valid if the selection and extracted datetime
- // spans exactly match.
- if (std::make_pair(datetime_span.span.first + selection_indices.first,
- datetime_span.span.second + selection_indices.first) ==
- selection_indices) {
- for (const DatetimeParseResult& parse_result : datetime_span.data) {
- classification_results->emplace_back(
- PickCollectionForDatetime(parse_result),
- datetime_span.target_classification_score);
- classification_results->back().datetime_parse_result = parse_result;
- classification_results->back().serialized_entity_data =
- CreateDatetimeSerializedEntityData(parse_result);
- classification_results->back().priority_score =
- datetime_span.priority_score;
- }
- return true;
- }
- }
- return true;
-}
-
-std::vector<ClassificationResult> Annotator::ClassifyText(
- const std::string& context, CodepointSpan selection_indices,
- const ClassificationOptions& options) const {
- if (!initialized_) {
- TC3_LOG(ERROR) << "Not initialized";
- return {};
- }
-
- if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
- return {};
- }
-
- std::vector<Locale> detected_text_language_tags;
- if (!ParseLocales(options.detected_text_language_tags,
- &detected_text_language_tags)) {
- TC3_LOG(WARNING)
- << "Failed to parse the detected_text_language_tags in options: "
- << options.detected_text_language_tags;
- }
- if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
- model_triggering_locales_,
- /*default_value=*/true)) {
- return {};
- }
-
- if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
- return {};
- }
-
- if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
- TC3_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
- << std::get<0>(selection_indices) << " "
- << std::get<1>(selection_indices);
- return {};
- }
-
- // We'll accumulate a list of candidates, and pick the best candidate in the
- // end.
- std::vector<AnnotatedSpan> candidates;
-
- // Try the knowledge engine.
- // TODO(b/126579108): Propagate error status.
- ClassificationResult knowledge_result;
- if (knowledge_engine_ && knowledge_engine_->ClassifyText(
- context, selection_indices, &knowledge_result)) {
- candidates.push_back({selection_indices, {knowledge_result}});
- candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
- }
-
- // Try the contact engine.
- // TODO(b/126579108): Propagate error status.
- ClassificationResult contact_result;
- if (contact_engine_ && contact_engine_->ClassifyText(
- context, selection_indices, &contact_result)) {
- candidates.push_back({selection_indices, {contact_result}});
- }
-
- // Try the installed app engine.
- // TODO(b/126579108): Propagate error status.
- ClassificationResult installed_app_result;
- if (installed_app_engine_ &&
- installed_app_engine_->ClassifyText(context, selection_indices,
- &installed_app_result)) {
- candidates.push_back({selection_indices, {installed_app_result}});
- }
-
- // Try the regular expression models.
- std::vector<ClassificationResult> regex_results;
- if (!RegexClassifyText(context, selection_indices, ®ex_results)) {
- return {};
- }
- for (const ClassificationResult& result : regex_results) {
- candidates.push_back({selection_indices, {result}});
- }
-
- // Try the date model.
- //
- // DatetimeClassifyText only returns the first result, which can however have
- // more interpretations. They are inserted in the candidates as a single
- // AnnotatedSpan, so that they get treated together by the conflict resolution
- // algorithm.
- std::vector<ClassificationResult> datetime_results;
- if (!DatetimeClassifyText(context, selection_indices, options,
- &datetime_results)) {
- return {};
- }
- if (!datetime_results.empty()) {
- candidates.push_back({selection_indices, std::move(datetime_results)});
- candidates.back().source = AnnotatedSpan::Source::DATETIME;
- }
-
- // Try the number annotator.
- // TODO(b/126579108): Propagate error status.
- ClassificationResult number_annotator_result;
- if (number_annotator_ &&
- number_annotator_->ClassifyText(
- UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
- options.annotation_usecase, &number_annotator_result)) {
- candidates.push_back({selection_indices, {number_annotator_result}});
- }
-
- // Try the duration annotator.
- ClassificationResult duration_annotator_result;
- if (duration_annotator_ &&
- duration_annotator_->ClassifyText(
- UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
- options.annotation_usecase, &duration_annotator_result)) {
- candidates.push_back({selection_indices, {duration_annotator_result}});
- candidates.back().source = AnnotatedSpan::Source::DURATION;
- }
-
- // Try the ML model.
- //
- // The output of the model is considered as an exclusive 1-of-N choice. That's
- // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
- // span for each candidate, like e.g. the regex model.
- InterpreterManager interpreter_manager(selection_executor_.get(),
- classification_executor_.get());
- std::vector<ClassificationResult> model_results;
- std::vector<Token> tokens;
- if (!ModelClassifyText(
- context, /*cached_tokens=*/{}, detected_text_language_tags,
- selection_indices, &interpreter_manager,
- /*embedding_cache=*/nullptr, &model_results, &tokens)) {
- return {};
- }
- if (!model_results.empty()) {
- candidates.push_back({selection_indices, std::move(model_results)});
- }
-
- std::vector<int> candidate_indices;
- if (!ResolveConflicts(candidates, context, tokens,
- detected_text_language_tags, options.annotation_usecase,
- &interpreter_manager, &candidate_indices)) {
- TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
- return {};
- }
-
- std::vector<ClassificationResult> results;
- for (const int i : candidate_indices) {
- for (const ClassificationResult& result : candidates[i].classification) {
- if (!FilteredForClassification(result)) {
- results.push_back(result);
- }
- }
- }
-
- // Sort results according to score.
- std::sort(results.begin(), results.end(),
- [](const ClassificationResult& a, const ClassificationResult& b) {
- return a.score > b.score;
- });
-
- if (results.empty()) {
- results = {{Collections::Other(), 1.0}};
- }
- return results;
-}
-
-bool Annotator::ModelAnnotate(
- const std::string& context,
- const std::vector<Locale>& detected_text_language_tags,
- InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
- std::vector<AnnotatedSpan>* result) const {
- if (model_->triggering_options() == nullptr ||
- !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
- return true;
- }
-
- if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
- ml_model_triggering_locales_,
- /*default_value=*/true)) {
- return true;
- }
-
- const UnicodeText context_unicode = UTF8ToUnicodeText(context,
- /*do_copy=*/false);
- std::vector<UnicodeTextRange> lines;
- if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
- lines.push_back({context_unicode.begin(), context_unicode.end()});
- } else {
- lines = selection_feature_processor_->SplitContext(context_unicode);
- }
-
- const float min_annotate_confidence =
- (model_->triggering_options() != nullptr
- ? model_->triggering_options()->min_annotate_confidence()
- : 0.f);
-
- for (const UnicodeTextRange& line : lines) {
- FeatureProcessor::EmbeddingCache embedding_cache;
- const std::string line_str =
- UnicodeText::UTF8Substring(line.first, line.second);
-
- *tokens = selection_feature_processor_->Tokenize(line_str);
- selection_feature_processor_->RetokenizeAndFindClick(
- line_str, {0, std::distance(line.first, line.second)},
- selection_feature_processor_->GetOptions()->only_use_line_with_click(),
- tokens,
- /*click_pos=*/nullptr);
- const TokenSpan full_line_span = {0, tokens->size()};
-
- // TODO(zilka): Add support for greater granularity of this check.
- if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
- *tokens, full_line_span)) {
- continue;
- }
-
- std::unique_ptr<CachedFeatures> cached_features;
- if (!selection_feature_processor_->ExtractFeatures(
- *tokens, full_line_span,
- /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
- embedding_executor_.get(),
- /*embedding_cache=*/nullptr,
- selection_feature_processor_->EmbeddingSize() +
- selection_feature_processor_->DenseFeaturesCount(),
- &cached_features)) {
- TC3_LOG(ERROR) << "Could not extract features.";
- return false;
- }
-
- std::vector<TokenSpan> local_chunks;
- if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
- interpreter_manager->SelectionInterpreter(),
- *cached_features, &local_chunks)) {
- TC3_LOG(ERROR) << "Could not chunk.";
- return false;
- }
-
- const int offset = std::distance(context_unicode.begin(), line.first);
- for (const TokenSpan& chunk : local_chunks) {
- const CodepointSpan codepoint_span =
- selection_feature_processor_->StripBoundaryCodepoints(
- line_str, TokenSpanToCodepointSpan(*tokens, chunk));
-
- // Skip empty spans.
- if (codepoint_span.first != codepoint_span.second) {
- std::vector<ClassificationResult> classification;
- if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
- codepoint_span, interpreter_manager,
- &embedding_cache, &classification)) {
- TC3_LOG(ERROR) << "Could not classify text: "
- << (codepoint_span.first + offset) << " "
- << (codepoint_span.second + offset);
- return false;
- }
-
- // Do not include the span if it's classified as "other".
- if (!classification.empty() && !ClassifiedAsOther(classification) &&
- classification[0].score >= min_annotate_confidence) {
- AnnotatedSpan result_span;
- result_span.span = {codepoint_span.first + offset,
- codepoint_span.second + offset};
- result_span.classification = std::move(classification);
- result->push_back(std::move(result_span));
- }
- }
- }
- }
- return true;
-}
-
-const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
- return selection_feature_processor_.get();
-}
-
-const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
- const {
- return classification_feature_processor_.get();
-}
-
-const DatetimeParser* Annotator::DatetimeParserForTests() const {
- return datetime_parser_.get();
-}
-
-void Annotator::RemoveNotEnabledEntityTypes(
- const EnabledEntityTypes& is_entity_type_enabled,
- std::vector<AnnotatedSpan>* annotated_spans) const {
- for (AnnotatedSpan& annotated_span : *annotated_spans) {
- std::vector<ClassificationResult>& classifications =
- annotated_span.classification;
- classifications.erase(
- std::remove_if(classifications.begin(), classifications.end(),
- [&is_entity_type_enabled](
- const ClassificationResult& classification_result) {
- return !is_entity_type_enabled(
- classification_result.collection);
- }),
- classifications.end());
- }
- annotated_spans->erase(
- std::remove_if(annotated_spans->begin(), annotated_spans->end(),
- [](const AnnotatedSpan& annotated_span) {
- return annotated_span.classification.empty();
- }),
- annotated_spans->end());
-}
-
-std::vector<AnnotatedSpan> Annotator::Annotate(
- const std::string& context, const AnnotationOptions& options) const {
- std::vector<AnnotatedSpan> candidates;
-
- if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
- return {};
- }
-
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
- if (!context_unicode.is_valid()) {
- return {};
- }
-
- std::vector<Locale> detected_text_language_tags;
- if (!ParseLocales(options.detected_text_language_tags,
- &detected_text_language_tags)) {
- TC3_LOG(WARNING)
- << "Failed to parse the detected_text_language_tags in options: "
- << options.detected_text_language_tags;
- }
- if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
- model_triggering_locales_,
- /*default_value=*/true)) {
- return {};
- }
-
- InterpreterManager interpreter_manager(selection_executor_.get(),
- classification_executor_.get());
-
- // Annotate with the selection model.
- std::vector<Token> tokens;
- if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
- &tokens, &candidates)) {
- TC3_LOG(ERROR) << "Couldn't run ModelAnnotate.";
- return {};
- }
-
- // Annotate with the regular expression models.
- if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
- annotation_regex_patterns_, &candidates,
- options.is_serialized_entity_data_enabled)) {
- TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
- return {};
- }
-
- // Annotate with the datetime model.
- const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
- if ((is_entity_type_enabled(Collections::Date()) ||
- is_entity_type_enabled(Collections::DateTime())) &&
- !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
- options.reference_time_ms_utc, options.reference_timezone,
- options.locales, ModeFlag_ANNOTATION,
- options.annotation_usecase,
- options.is_serialized_entity_data_enabled, &candidates)) {
- TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
- return {};
- }
-
- // Annotate with the knowledge engine.
- if (knowledge_engine_ && !knowledge_engine_->Chunk(context, &candidates)) {
- TC3_LOG(ERROR) << "Couldn't run knowledge engine Chunk.";
- return {};
- }
-
- // Annotate with the contact engine.
- if (contact_engine_ &&
- !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
- TC3_LOG(ERROR) << "Couldn't run contact engine Chunk.";
- return {};
- }
-
- // Annotate with the installed app engine.
- if (installed_app_engine_ &&
- !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
- TC3_LOG(ERROR) << "Couldn't run installed app engine Chunk.";
- return {};
- }
-
- // Annotate with the number annotator.
- if (number_annotator_ != nullptr &&
- !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
- &candidates)) {
- TC3_LOG(ERROR) << "Couldn't run number annotator FindAll.";
- return {};
- }
-
- // Annotate with the duration annotator.
- if (is_entity_type_enabled(Collections::Duration()) &&
- duration_annotator_ != nullptr &&
- !duration_annotator_->FindAll(context_unicode, tokens,
- options.annotation_usecase, &candidates)) {
- TC3_LOG(ERROR) << "Couldn't run duration annotator FindAll.";
- return {};
- }
-
- // Sort candidates according to their position in the input, so that the next
- // code can assume that any connected component of overlapping spans forms a
- // contiguous block.
- std::sort(candidates.begin(), candidates.end(),
- [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
- return a.span.first < b.span.first;
- });
-
- std::vector<int> candidate_indices;
- if (!ResolveConflicts(candidates, context, tokens,
- detected_text_language_tags, options.annotation_usecase,
- &interpreter_manager, &candidate_indices)) {
- TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
- return {};
- }
-
- std::vector<AnnotatedSpan> result;
- result.reserve(candidate_indices.size());
- AnnotatedSpan aggregated_span;
- for (const int i : candidate_indices) {
- if (candidates[i].span != aggregated_span.span) {
- if (!aggregated_span.classification.empty()) {
- result.push_back(std::move(aggregated_span));
- }
- aggregated_span =
- AnnotatedSpan(candidates[i].span, /*arg_classification=*/{});
- }
- if (candidates[i].classification.empty() ||
- ClassifiedAsOther(candidates[i].classification) ||
- FilteredForAnnotation(candidates[i])) {
- continue;
- }
- for (ClassificationResult& classification : candidates[i].classification) {
- aggregated_span.classification.push_back(std::move(classification));
- }
- }
- if (!aggregated_span.classification.empty()) {
- result.push_back(std::move(aggregated_span));
- }
-
- // We generate all candidates and remove them later (with the exception of
- // date/time/duration entities) because there are complex interdependencies
- // between the entity types. E.g., the TLD of an email can be interpreted as a
- // URL, but most likely a user of the API does not want such annotations if
- // "url" is enabled and "email" is not.
- RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
-
- for (AnnotatedSpan& annotated_span : result) {
- SortClassificationResults(&annotated_span.classification);
- }
-
- return result;
-}
-
-CodepointSpan Annotator::ComputeSelectionBoundaries(
- const UniLib::RegexMatcher* match,
- const RegexModel_::Pattern* config) const {
- if (config->capturing_group() == nullptr) {
- // Use first capturing group to specify the selection.
- int status = UniLib::RegexMatcher::kNoError;
- const CodepointSpan result = {match->Start(1, &status),
- match->End(1, &status)};
- if (status != UniLib::RegexMatcher::kNoError) {
- return {kInvalidIndex, kInvalidIndex};
- }
- return result;
- }
-
- CodepointSpan result = {kInvalidIndex, kInvalidIndex};
- const int num_groups = config->capturing_group()->size();
- for (int i = 0; i < num_groups; i++) {
- if (!config->capturing_group()->Get(i)->extend_selection()) {
- continue;
- }
-
- int status = UniLib::RegexMatcher::kNoError;
- // Check match and adjust bounds.
- const int group_start = match->Start(i, &status);
- const int group_end = match->End(i, &status);
- if (status != UniLib::RegexMatcher::kNoError) {
- return {kInvalidIndex, kInvalidIndex};
- }
- if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
- continue;
- }
- if (result.first == kInvalidIndex) {
- result = {group_start, group_end};
- } else {
- result.first = std::min(result.first, group_start);
- result.second = std::max(result.second, group_end);
- }
- }
- return result;
-}
-
-bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
- if (pattern->serialized_entity_data() != nullptr) {
- return true;
- }
- if (pattern->capturing_group() != nullptr) {
- for (const RegexModel_::Pattern_::CapturingGroup* group :
- *pattern->capturing_group()) {
- if (group->entity_field_path() != nullptr) {
- return true;
- }
- }
- }
- return false;
-}
-
-bool Annotator::SerializedEntityDataFromRegexMatch(
- const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
- std::string* serialized_entity_data) const {
- if (!HasEntityData(pattern)) {
- serialized_entity_data->clear();
- return true;
- }
- TC3_CHECK(entity_data_builder_ != nullptr);
-
- std::unique_ptr<ReflectiveFlatbuffer> entity_data =
- entity_data_builder_->NewRoot();
-
- TC3_CHECK(entity_data != nullptr);
-
- // Set static entity data.
- if (pattern->serialized_entity_data() != nullptr) {
- TC3_CHECK(entity_data != nullptr);
- entity_data->MergeFromSerializedFlatbuffer(
- StringPiece(pattern->serialized_entity_data()->c_str(),
- pattern->serialized_entity_data()->size()));
- }
-
- // Add entity data from rule capturing groups.
- if (pattern->capturing_group() != nullptr) {
- const int num_groups = pattern->capturing_group()->size();
- for (int i = 0; i < num_groups; i++) {
- const FlatbufferFieldPath* field_path =
- pattern->capturing_group()->Get(i)->entity_field_path();
- if (field_path == nullptr) {
- continue;
- }
- TC3_CHECK(entity_data != nullptr);
- if (!SetFieldFromCapturingGroup(/*group_id=*/i, field_path, matcher,
- entity_data.get())) {
- TC3_LOG(ERROR)
- << "Could not set entity data from rule capturing group.";
- return false;
- }
- }
- }
-
- *serialized_entity_data = entity_data->Serialize();
- return true;
-}
-
-bool Annotator::RegexChunk(const UnicodeText& context_unicode,
- const std::vector<int>& rules,
- std::vector<AnnotatedSpan>* result,
- bool is_serialized_entity_data_enabled) const {
- for (int pattern_id : rules) {
- const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
- const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
- if (!matcher) {
- TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
- << pattern_id;
- return false;
- }
-
- int status = UniLib::RegexMatcher::kNoError;
- while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
- if (regex_pattern.config->verification_options()) {
- if (!VerifyRegexMatchCandidate(
- context_unicode.ToUTF8String(),
- regex_pattern.config->verification_options(),
- matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
- continue;
- }
- }
-
- std::string serialized_entity_data;
- if (is_serialized_entity_data_enabled) {
- if (!SerializedEntityDataFromRegexMatch(
- regex_pattern.config, matcher.get(), &serialized_entity_data)) {
- TC3_LOG(ERROR) << "Could not get entity data.";
- return false;
- }
- }
-
- result->emplace_back();
-
- // Selection/annotation regular expressions need to specify a capturing
- // group specifying the selection.
- result->back().span =
- ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
-
- result->back().classification = {
- {regex_pattern.config->collection_name()->str(),
- regex_pattern.config->target_classification_score(),
- regex_pattern.config->priority_score()}};
-
- result->back().classification[0].serialized_entity_data =
- serialized_entity_data;
- }
- }
- return true;
-}
-
-bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
- tflite::Interpreter* selection_interpreter,
- const CachedFeatures& cached_features,
- std::vector<TokenSpan>* chunks) const {
- const int max_selection_span =
- selection_feature_processor_->GetOptions()->max_selection_span();
- // The inference span is the span of interest expanded to include
- // max_selection_span tokens on either side, which is how far a selection can
- // stretch from the click.
- const TokenSpan inference_span = IntersectTokenSpans(
- ExpandTokenSpan(span_of_interest,
- /*num_tokens_left=*/max_selection_span,
- /*num_tokens_right=*/max_selection_span),
- {0, num_tokens});
-
- std::vector<ScoredChunk> scored_chunks;
- if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
- selection_feature_processor_->GetOptions()
- ->bounds_sensitive_features()
- ->enabled()) {
- if (!ModelBoundsSensitiveScoreChunks(
- num_tokens, span_of_interest, inference_span, cached_features,
- selection_interpreter, &scored_chunks)) {
- return false;
- }
- } else {
- if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
- cached_features, selection_interpreter,
- &scored_chunks)) {
- return false;
- }
- }
- std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
- [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
- return lhs.score < rhs.score;
- });
-
- // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
- // them greedily as long as they do not overlap with any previously picked
- // chunks.
- std::vector<bool> token_used(TokenSpanSize(inference_span));
- chunks->clear();
- for (const ScoredChunk& scored_chunk : scored_chunks) {
- bool feasible = true;
- for (int i = scored_chunk.token_span.first;
- i < scored_chunk.token_span.second; ++i) {
- if (token_used[i - inference_span.first]) {
- feasible = false;
- break;
- }
- }
-
- if (!feasible) {
- continue;
- }
-
- for (int i = scored_chunk.token_span.first;
- i < scored_chunk.token_span.second; ++i) {
- token_used[i - inference_span.first] = true;
- }
-
- chunks->push_back(scored_chunk.token_span);
- }
-
- std::sort(chunks->begin(), chunks->end());
-
- return true;
-}
-
-namespace {
-// Updates the value at the given key in the map to maximum of the current value
-// and the given value, or simply inserts the value if the key is not yet there.
-template <typename Map>
-void UpdateMax(Map* map, typename Map::key_type key,
- typename Map::mapped_type value) {
- const auto it = map->find(key);
- if (it != map->end()) {
- it->second = std::max(it->second, value);
- } else {
- (*map)[key] = value;
- }
-}
-} // namespace
-
-bool Annotator::ModelClickContextScoreChunks(
- int num_tokens, const TokenSpan& span_of_interest,
- const CachedFeatures& cached_features,
- tflite::Interpreter* selection_interpreter,
- std::vector<ScoredChunk>* scored_chunks) const {
- const int max_batch_size = model_->selection_options()->batch_size();
-
- std::vector<float> all_features;
- std::map<TokenSpan, float> chunk_scores;
- for (int batch_start = span_of_interest.first;
- batch_start < span_of_interest.second; batch_start += max_batch_size) {
- const int batch_end =
- std::min(batch_start + max_batch_size, span_of_interest.second);
-
- // Prepare features for the whole batch.
- all_features.clear();
- all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
- for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
- cached_features.AppendClickContextFeaturesForClick(click_pos,
- &all_features);
- }
-
- // Run batched inference.
- const int batch_size = batch_end - batch_start;
- const int features_size = cached_features.OutputFeaturesSize();
- TensorView<float> logits = selection_executor_->ComputeLogits(
- TensorView<float>(all_features.data(), {batch_size, features_size}),
- selection_interpreter);
- if (!logits.is_valid()) {
- TC3_LOG(ERROR) << "Couldn't compute logits.";
- return false;
- }
- if (logits.dims() != 2 || logits.dim(0) != batch_size ||
- logits.dim(1) !=
- selection_feature_processor_->GetSelectionLabelCount()) {
- TC3_LOG(ERROR) << "Mismatching output.";
- return false;
- }
-
- // Save results.
- for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
- const std::vector<float> scores = ComputeSoftmax(
- logits.data() + logits.dim(1) * (click_pos - batch_start),
- logits.dim(1));
- for (int j = 0;
- j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
- TokenSpan relative_token_span;
- if (!selection_feature_processor_->LabelToTokenSpan(
- j, &relative_token_span)) {
- TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
- return false;
- }
- const TokenSpan candidate_span = ExpandTokenSpan(
- SingleTokenSpan(click_pos), relative_token_span.first,
- relative_token_span.second);
- if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
- UpdateMax(&chunk_scores, candidate_span, scores[j]);
- }
- }
- }
- }
-
- scored_chunks->clear();
- scored_chunks->reserve(chunk_scores.size());
- for (const auto& entry : chunk_scores) {
- scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
- }
-
- return true;
-}
-
-bool Annotator::ModelBoundsSensitiveScoreChunks(
- int num_tokens, const TokenSpan& span_of_interest,
- const TokenSpan& inference_span, const CachedFeatures& cached_features,
- tflite::Interpreter* selection_interpreter,
- std::vector<ScoredChunk>* scored_chunks) const {
- const int max_selection_span =
- selection_feature_processor_->GetOptions()->max_selection_span();
- const int max_chunk_length = selection_feature_processor_->GetOptions()
- ->selection_reduced_output_space()
- ? max_selection_span + 1
- : 2 * max_selection_span + 1;
- const bool score_single_token_spans_as_zero =
- selection_feature_processor_->GetOptions()
- ->bounds_sensitive_features()
- ->score_single_token_spans_as_zero();
-
- scored_chunks->clear();
- if (score_single_token_spans_as_zero) {
- scored_chunks->reserve(TokenSpanSize(span_of_interest));
- }
-
- // Prepare all chunk candidates into one batch:
- // - Are contained in the inference span
- // - Have a non-empty intersection with the span of interest
- // - Are at least one token long
- // - Are not longer than the maximum chunk length
- std::vector<TokenSpan> candidate_spans;
- for (int start = inference_span.first; start < span_of_interest.second;
- ++start) {
- const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
- for (int end = leftmost_end_index;
- end <= inference_span.second && end - start <= max_chunk_length;
- ++end) {
- const TokenSpan candidate_span = {start, end};
- if (score_single_token_spans_as_zero &&
- TokenSpanSize(candidate_span) == 1) {
- // Do not include the single token span in the batch, add a zero score
- // for it directly to the output.
- scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
- } else {
- candidate_spans.push_back(candidate_span);
- }
- }
- }
-
- const int max_batch_size = model_->selection_options()->batch_size();
-
- std::vector<float> all_features;
- scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
- for (int batch_start = 0; batch_start < candidate_spans.size();
- batch_start += max_batch_size) {
- const int batch_end = std::min(batch_start + max_batch_size,
- static_cast<int>(candidate_spans.size()));
-
- // Prepare features for the whole batch.
- all_features.clear();
- all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
- for (int i = batch_start; i < batch_end; ++i) {
- cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
- &all_features);
- }
-
- // Run batched inference.
- const int batch_size = batch_end - batch_start;
- const int features_size = cached_features.OutputFeaturesSize();
- TensorView<float> logits = selection_executor_->ComputeLogits(
- TensorView<float>(all_features.data(), {batch_size, features_size}),
- selection_interpreter);
- if (!logits.is_valid()) {
- TC3_LOG(ERROR) << "Couldn't compute logits.";
- return false;
- }
- if (logits.dims() != 2 || logits.dim(0) != batch_size ||
- logits.dim(1) != 1) {
- TC3_LOG(ERROR) << "Mismatching output.";
- return false;
- }
-
- // Save results.
- for (int i = batch_start; i < batch_end; ++i) {
- scored_chunks->push_back(
- ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
- }
- }
-
- return true;
-}
-
-bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& locales, ModeFlag mode,
- AnnotationUsecase annotation_usecase,
- bool is_serialized_entity_data_enabled,
- std::vector<AnnotatedSpan>* result) const {
- if (!datetime_parser_) {
- return true;
- }
-
- std::vector<DatetimeParseResultSpan> datetime_spans;
- if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
- reference_timezone, locales, mode,
- annotation_usecase,
- /*anchor_start_end=*/false, &datetime_spans)) {
- return false;
- }
- for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
- AnnotatedSpan annotated_span;
- annotated_span.span = datetime_span.span;
- for (const DatetimeParseResult& parse_result : datetime_span.data) {
- annotated_span.classification.emplace_back(
- PickCollectionForDatetime(parse_result),
- datetime_span.target_classification_score,
- datetime_span.priority_score);
- annotated_span.classification.back().datetime_parse_result = parse_result;
- if (is_serialized_entity_data_enabled) {
- annotated_span.classification.back().serialized_entity_data =
- CreateDatetimeSerializedEntityData(parse_result);
- }
- }
- annotated_span.source = AnnotatedSpan::Source::DATETIME;
- result->push_back(std::move(annotated_span));
- }
- return true;
-}
-
-const Model* Annotator::model() const { return model_; }
-const reflection::Schema* Annotator::entity_data_schema() const {
- return entity_data_schema_;
-}
-
-const Model* ViewModel(const void* buffer, int size) {
- if (!buffer) {
- return nullptr;
- }
-
- return LoadAndVerifyModel(buffer, size);
-}
-
-bool Annotator::LookUpKnowledgeEntity(
- const std::string& id, std::string* serialized_knowledge_result) const {
- return knowledge_engine_ &&
- knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
-}
-
-} // namespace libtextclassifier3
diff --git a/annotator/annotator.h b/annotator/annotator.h
deleted file mode 100644
index 0b1c9f9..0000000
--- a/annotator/annotator.h
+++ /dev/null
@@ -1,564 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Inference code for the text classification model.
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
-
-#include <memory>
-#include <set>
-#include <string>
-#include <unordered_set>
-#include <vector>
-
-#include "annotator/contact/contact-engine.h"
-#include "annotator/datetime/parser.h"
-#include "annotator/duration/duration.h"
-#include "annotator/feature-processor.h"
-#include "annotator/installed_app/installed-app-engine.h"
-#include "annotator/knowledge/knowledge-engine.h"
-#include "annotator/model-executor.h"
-#include "annotator/model_generated.h"
-#include "annotator/number/number.h"
-#include "annotator/strip-unpaired-brackets.h"
-#include "annotator/types.h"
-#include "annotator/zlib-utils.h"
-#include "utils/flatbuffers.h"
-#include "utils/i18n/locale.h"
-#include "utils/memory/mmap.h"
-#include "utils/utf8/unilib.h"
-#include "utils/zlib/zlib.h"
-
-namespace libtextclassifier3 {
-
// Short-hand aliases for the verbose generated enum constants, so call sites
// can write ANNOTATION_USECASE_SMART instead of
// AnnotationUsecase_ANNOTATION_USECASE_SMART.
const AnnotationUsecase ANNOTATION_USECASE_SMART =
    AnnotationUsecase_ANNOTATION_USECASE_SMART;
const AnnotationUsecase ANNOTATION_USECASE_RAW =
    AnnotationUsecase_ANNOTATION_USECASE_RAW;
-
-struct SelectionOptions {
- // Comma-separated list of locale specification for the input text (BCP 47
- // tags).
- std::string locales;
-
- // Comma-separated list of BCP 47 language tags.
- std::string detected_text_language_tags;
-
- // Tailors the output annotations according to the specified use-case.
- AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
-
- bool operator==(const SelectionOptions& other) const {
- return this->locales == other.locales &&
- this->annotation_usecase == other.annotation_usecase &&
- this->detected_text_language_tags ==
- other.detected_text_language_tags;
- }
-};
-
-struct ClassificationOptions {
- // For parsing relative datetimes, the reference now time against which the
- // relative datetimes get resolved.
- // UTC milliseconds since epoch.
- int64 reference_time_ms_utc = 0;
-
- // Timezone in which the input text was written (format as accepted by ICU).
- std::string reference_timezone;
-
- // Comma-separated list of locale specification for the input text (BCP 47
- // tags).
- std::string locales;
-
- // Comma-separated list of language tags.
- std::string detected_text_language_tags;
-
- // Tailors the output annotations according to the specified use-case.
- AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
-
- bool operator==(const ClassificationOptions& other) const {
- return this->reference_time_ms_utc == other.reference_time_ms_utc &&
- this->reference_timezone == other.reference_timezone &&
- this->locales == other.locales &&
- this->detected_text_language_tags ==
- other.detected_text_language_tags &&
- this->annotation_usecase == other.annotation_usecase;
- }
-};
-
// Options controlling Annotate().
struct AnnotationOptions {
  // For parsing relative datetimes, the reference now time against which the
  // relative datetimes get resolved.
  // UTC milliseconds since epoch.
  int64 reference_time_ms_utc = 0;

  // Timezone in which the input text was written (format as accepted by ICU).
  std::string reference_timezone;

  // Comma-separated list of locale specification for the input text (BCP 47
  // tags).
  std::string locales;

  // Comma-separated list of language tags.
  std::string detected_text_language_tags;

  // List of entity types that should be used for annotation.
  std::unordered_set<std::string> entity_types;

  // If true, serialized_entity_data in the results is populated.
  bool is_serialized_entity_data_enabled = false;

  // Tailors the output annotations according to the specified use-case.
  AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;

  // Field-wise equality.
  bool operator==(const AnnotationOptions& other) const {
    return this->reference_time_ms_utc == other.reference_time_ms_utc &&
           this->reference_timezone == other.reference_timezone &&
           this->locales == other.locales &&
           this->detected_text_language_tags ==
               other.detected_text_language_tags &&
           this->annotation_usecase == other.annotation_usecase &&
           this->is_serialized_entity_data_enabled ==
               other.is_serialized_entity_data_enabled;
  }
};
-
// Holds TFLite interpreters for selection and classification models.
// NOTE: this class is not thread-safe, thus should NOT be re-used across
// threads.
class InterpreterManager {
 public:
  // The constructor can be called with nullptr for any of the executors, and is
  // a defined behavior, as long as the corresponding *Interpreter() method is
  // not called when the executor is null.
  InterpreterManager(const ModelExecutor* selection_executor,
                     const ModelExecutor* classification_executor)
      : selection_executor_(selection_executor),
        classification_executor_(classification_executor) {}

  // Gets or creates and caches an interpreter for the selection model.
  tflite::Interpreter* SelectionInterpreter();

  // Gets or creates and caches an interpreter for the classification model.
  tflite::Interpreter* ClassificationInterpreter();

 private:
  // Non-owning; must outlive this manager.
  const ModelExecutor* selection_executor_;
  const ModelExecutor* classification_executor_;

  // Lazily created interpreters, cached by the accessors above.
  std::unique_ptr<tflite::Interpreter> selection_interpreter_;
  std::unique_ptr<tflite::Interpreter> classification_interpreter_;
};
-
// Stores entity types enabled for annotation, and provides operator() for
// checking whether a given entity type is enabled.
class EnabledEntityTypes {
 public:
  explicit EnabledEntityTypes(
      const std::unordered_set<std::string>& entity_types)
      : entity_types_(entity_types) {}

  // An empty set means "all entity types are enabled"; otherwise set
  // membership decides.
  bool operator()(const std::string& entity_type) const {
    if (entity_types_.empty()) {
      return true;
    }
    return entity_types_.count(entity_type) > 0;
  }

 private:
  // Non-owning reference; the set must outlive this object.
  const std::unordered_set<std::string>& entity_types_;
};
-
// A text processing model that provides text classification, annotation,
// selection suggestion for various types.
// NOTE: This class is not thread-safe.
class Annotator {
 public:
  // Factories for constructing an Annotator from various model sources.
  static std::unique_ptr<Annotator> FromUnownedBuffer(
      const char* buffer, int size, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);
  // Takes ownership of the mmap.
  static std::unique_ptr<Annotator> FromScopedMmap(
      std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);
  static std::unique_ptr<Annotator> FromScopedMmap(
      std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
      std::unique_ptr<CalendarLib> calendarlib);
  static std::unique_ptr<Annotator> FromFileDescriptor(
      int fd, int offset, int size, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);
  static std::unique_ptr<Annotator> FromFileDescriptor(
      int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
      std::unique_ptr<CalendarLib> calendarlib);
  static std::unique_ptr<Annotator> FromFileDescriptor(
      int fd, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);
  static std::unique_ptr<Annotator> FromFileDescriptor(
      int fd, std::unique_ptr<UniLib> unilib,
      std::unique_ptr<CalendarLib> calendarlib);
  static std::unique_ptr<Annotator> FromPath(
      const std::string& path, const UniLib* unilib = nullptr,
      const CalendarLib* calendarlib = nullptr);
  static std::unique_ptr<Annotator> FromPath(
      const std::string& path, std::unique_ptr<UniLib> unilib,
      std::unique_ptr<CalendarLib> calendarlib);

  // Returns true if the model is ready for use.
  bool IsInitialized() { return initialized_; }

  // Initializes the knowledge engine with the given config.
  bool InitializeKnowledgeEngine(const std::string& serialized_config);

  // Initializes the contact engine with the given config.
  bool InitializeContactEngine(const std::string& serialized_config);

  // Initializes the installed app engine with the given config.
  bool InitializeInstalledAppEngine(const std::string& serialized_config);

  // Runs inference for given a context and current selection (i.e. index
  // of the first and one past last selected characters (utf8 codepoint
  // offsets)). Returns the indices (utf8 codepoint offsets) of the selection
  // beginning character and one past selection end character.
  // Returns the original click_indices if an error occurs.
  // NOTE: The selection indices are passed in and returned in terms of
  // UTF8 codepoints (not bytes).
  // Requires that the model is a smart selection model.
  CodepointSpan SuggestSelection(
      const std::string& context, CodepointSpan click_indices,
      const SelectionOptions& options = SelectionOptions()) const;

  // Classifies the selected text given the context string.
  // Returns an empty result if an error occurs.
  std::vector<ClassificationResult> ClassifyText(
      const std::string& context, CodepointSpan selection_indices,
      const ClassificationOptions& options = ClassificationOptions()) const;

  // Annotates given input text. The annotations are sorted by their position
  // in the context string and exclude spans classified as 'other'.
  std::vector<AnnotatedSpan> Annotate(
      const std::string& context,
      const AnnotationOptions& options = AnnotationOptions()) const;

  // Looks up a knowledge entity by its id. If successful, populates the
  // serialized knowledge result and returns true.
  bool LookUpKnowledgeEntity(const std::string& id,
                             std::string* serialized_knowledge_result) const;

  // Accessors for the underlying model flatbuffer and entity data schema.
  const Model* model() const;
  const reflection::Schema* entity_data_schema() const;

  // Exposes the feature processor for tests and evaluations.
  const FeatureProcessor* SelectionFeatureProcessorForTests() const;
  const FeatureProcessor* ClassificationFeatureProcessorForTests() const;

  // Exposes the date time parser for tests and evaluations.
  const DatetimeParser* DatetimeParserForTests() const;

  // Well-known collection names.
  static const std::string& kPhoneCollection;
  static const std::string& kAddressCollection;
  static const std::string& kDateCollection;
  static const std::string& kUrlCollection;
  static const std::string& kEmailCollection;

 protected:
  // A candidate chunk of tokens together with its score.
  struct ScoredChunk {
    TokenSpan token_span;  // Tokens covered by this chunk.
    float score;           // Score of the chunk.
  };

  // Constructs and initializes text classifier from given model.
  // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'.
  Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
            const UniLib* unilib, const CalendarLib* calendarlib);
  Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
            std::unique_ptr<UniLib> unilib,
            std::unique_ptr<CalendarLib> calendarlib);

  // Constructs, validates and initializes text classifier from given model.
  // Does not own the buffer that backs 'model'.
  Annotator(const Model* model, const UniLib* unilib,
            const CalendarLib* calendarlib);

  // Checks that model contains all required fields, and initializes internal
  // datastructures.
  void ValidateAndInitialize();

  // Initializes regular expressions for the regex model.
  bool InitializeRegexModel(ZlibDecompressor* decompressor);

  // Resolves conflicts in the list of candidates by removing some overlapping
  // ones. Returns indices of the surviving ones.
  // NOTE: Assumes that the candidates are sorted according to their position in
  // the span.
  bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
                        const std::string& context,
                        const std::vector<Token>& cached_tokens,
                        const std::vector<Locale>& detected_text_language_tags,
                        AnnotationUsecase annotation_usecase,
                        InterpreterManager* interpreter_manager,
                        std::vector<int>* result) const;

  // Resolves one conflict between candidates on indices 'start_index'
  // (inclusive) and 'end_index' (exclusive). Assigns the winning candidate
  // indices to 'chosen_indices'. Returns false if a problem arises.
  bool ResolveConflict(const std::string& context,
                       const std::vector<Token>& cached_tokens,
                       const std::vector<AnnotatedSpan>& candidates,
                       const std::vector<Locale>& detected_text_language_tags,
                       int start_index, int end_index,
                       AnnotationUsecase annotation_usecase,
                       InterpreterManager* interpreter_manager,
                       std::vector<int>* chosen_indices) const;

  // Gets selection candidates from the ML model.
  // Provides the tokens produced during tokenization of the context string for
  // reuse.
  bool ModelSuggestSelection(
      const UnicodeText& context_unicode, CodepointSpan click_indices,
      const std::vector<Locale>& detected_text_language_tags,
      InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
      std::vector<AnnotatedSpan>* result) const;

  // Classifies the selected text given the context string with the
  // classification model.
  // Returns true if no error occurred.
  bool ModelClassifyText(
      const std::string& context, const std::vector<Token>& cached_tokens,
      const std::vector<Locale>& locales, CodepointSpan selection_indices,
      InterpreterManager* interpreter_manager,
      FeatureProcessor::EmbeddingCache* embedding_cache,
      std::vector<ClassificationResult>* classification_results,
      std::vector<Token>* tokens) const;

  // Same as above but doesn't output tokens.
  bool ModelClassifyText(
      const std::string& context, const std::vector<Token>& cached_tokens,
      const std::vector<Locale>& detected_text_language_tags,
      CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
      FeatureProcessor::EmbeddingCache* embedding_cache,
      std::vector<ClassificationResult>* classification_results) const;

  // Same as above but doesn't take cached tokens and doesn't output tokens.
  bool ModelClassifyText(
      const std::string& context,
      const std::vector<Locale>& detected_text_language_tags,
      CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
      FeatureProcessor::EmbeddingCache* embedding_cache,
      std::vector<ClassificationResult>* classification_results) const;

  // Returns a relative token span that represents how many tokens on the left
  // from the selection and right from the selection are needed for the
  // classifier input.
  TokenSpan ClassifyTextUpperBoundNeededTokens() const;

  // Classifies the selected text with the regular expressions models.
  // Returns true if no error happened, false otherwise.
  bool RegexClassifyText(
      const std::string& context, CodepointSpan selection_indices,
      std::vector<ClassificationResult>* classification_result) const;

  // Classifies the selected text with the date time model.
  // Returns true if no error happened, false otherwise.
  bool DatetimeClassifyText(
      const std::string& context, CodepointSpan selection_indices,
      const ClassificationOptions& options,
      std::vector<ClassificationResult>* classification_results) const;

  // Chunks given input text with the selection model and classifies the spans
  // with the classification model.
  // The annotations are sorted by their position in the context string and
  // exclude spans classified as 'other'.
  // Provides the tokens produced during tokenization of the context string for
  // reuse.
  bool ModelAnnotate(const std::string& context,
                     const std::vector<Locale>& detected_text_language_tags,
                     InterpreterManager* interpreter_manager,
                     std::vector<Token>* tokens,
                     std::vector<AnnotatedSpan>* result) const;

  // Groups the tokens into chunks. A chunk is a token span that should be the
  // suggested selection when any of its contained tokens is clicked. The chunks
  // are non-overlapping and are sorted by their position in the context string.
  // "num_tokens" is the total number of tokens available (as this method does
  // not need the actual vector of tokens).
  // "span_of_interest" is a span of all the tokens that could be clicked.
  // The resulting chunks all have to overlap with it and they cover this span
  // completely. The first and last chunk might extend beyond it.
  // The chunks vector is cleared before filling.
  bool ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
                  tflite::Interpreter* selection_interpreter,
                  const CachedFeatures& cached_features,
                  std::vector<TokenSpan>* chunks) const;

  // A helper method for ModelChunk(). It generates scored chunk candidates for
  // a click context model.
  // NOTE: The returned chunks can (and most likely do) overlap.
  bool ModelClickContextScoreChunks(
      int num_tokens, const TokenSpan& span_of_interest,
      const CachedFeatures& cached_features,
      tflite::Interpreter* selection_interpreter,
      std::vector<ScoredChunk>* scored_chunks) const;

  // A helper method for ModelChunk(). It generates scored chunk candidates for
  // a bounds-sensitive model.
  // NOTE: The returned chunks can (and most likely do) overlap.
  bool ModelBoundsSensitiveScoreChunks(
      int num_tokens, const TokenSpan& span_of_interest,
      const TokenSpan& inference_span, const CachedFeatures& cached_features,
      tflite::Interpreter* selection_interpreter,
      std::vector<ScoredChunk>* scored_chunks) const;

  // Produces chunks isolated by a set of regular expressions.
  bool RegexChunk(const UnicodeText& context_unicode,
                  const std::vector<int>& rules,
                  std::vector<AnnotatedSpan>* result,
                  bool is_serialized_entity_data_enabled) const;

  // Produces chunks from the datetime parser.
  bool DatetimeChunk(const UnicodeText& context_unicode,
                     int64 reference_time_ms_utc,
                     const std::string& reference_timezone,
                     const std::string& locales, ModeFlag mode,
                     AnnotationUsecase annotation_usecase,
                     bool is_serialized_entity_data_enabled,
                     std::vector<AnnotatedSpan>* result) const;

  // Returns whether a classification should be filtered.
  bool FilteredForAnnotation(const AnnotatedSpan& span) const;
  bool FilteredForClassification(
      const ClassificationResult& classification) const;
  bool FilteredForSelection(const AnnotatedSpan& span) const;

  // Computes the selection boundaries from a regular expression match.
  CodepointSpan ComputeSelectionBoundaries(
      const UniLib::RegexMatcher* match,
      const RegexModel_::Pattern* config) const;

  // Returns whether a regex pattern provides entity data from a match.
  bool HasEntityData(const RegexModel_::Pattern* pattern) const;

  // Constructs and serializes entity data from regex matches.
  bool SerializedEntityDataFromRegexMatch(
      const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
      std::string* serialized_entity_data) const;

  // Verifies a regex match and returns true if verification was successful.
  bool VerifyRegexMatchCandidate(
      const std::string& context,
      const VerificationOptions* verification_options, const std::string& match,
      const UniLib::RegexMatcher* matcher) const;

  // The model flatbuffer; the backing buffer is owned via mmap_ or externally.
  const Model* model_;

  // Executors for the TFLite models.
  std::unique_ptr<const ModelExecutor> selection_executor_;
  std::unique_ptr<const ModelExecutor> classification_executor_;
  std::unique_ptr<const EmbeddingExecutor> embedding_executor_;

  // Feature processors for the selection and classification models.
  std::unique_ptr<const FeatureProcessor> selection_feature_processor_;
  std::unique_ptr<const FeatureProcessor> classification_feature_processor_;

  std::unique_ptr<const DatetimeParser> datetime_parser_;

 private:
  // A compiled regex rule together with its model configuration.
  struct CompiledRegexPattern {
    const RegexModel_::Pattern* config;
    std::unique_ptr<UniLib::RegexPattern> pattern;
  };

  // Removes annotations the entity type of which is not in the set of enabled
  // entity types.
  void RemoveNotEnabledEntityTypes(
      const EnabledEntityTypes& is_entity_type_enabled,
      std::vector<AnnotatedSpan>* annotated_spans) const;

  std::unique_ptr<ScopedMmap> mmap_;
  bool initialized_ = false;
  bool enabled_for_annotation_ = false;
  bool enabled_for_classification_ = false;
  bool enabled_for_selection_ = false;
  std::unordered_set<std::string> filtered_collections_annotation_;
  std::unordered_set<std::string> filtered_collections_classification_;
  std::unordered_set<std::string> filtered_collections_selection_;

  std::vector<CompiledRegexPattern> regex_patterns_;

  // Indices into regex_patterns_ for the different modes.
  std::vector<int> annotation_regex_patterns_, classification_regex_patterns_,
      selection_regex_patterns_;

  // Non-owning pointers; may refer to the owned_* instances above/below when
  // the Annotator was constructed with owning unique_ptr overloads.
  std::unique_ptr<UniLib> owned_unilib_;
  const UniLib* unilib_;
  std::unique_ptr<CalendarLib> owned_calendarlib_;
  const CalendarLib* calendarlib_;

  // Optional auxiliary annotation engines.
  std::unique_ptr<const KnowledgeEngine> knowledge_engine_;
  std::unique_ptr<const ContactEngine> contact_engine_;
  std::unique_ptr<const InstalledAppEngine> installed_app_engine_;
  std::unique_ptr<const NumberAnnotator> number_annotator_;
  std::unique_ptr<const DurationAnnotator> duration_annotator_;

  // Builder for creating extra data.
  const reflection::Schema* entity_data_schema_;
  std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;

  // Locales for which the entire model triggers.
  std::vector<Locale> model_triggering_locales_;

  // Locales for which the ML model triggers.
  std::vector<Locale> ml_model_triggering_locales_;

  // Locales that the dictionary classification support.
  std::vector<Locale> dictionary_locales_;
};
-
// Implementation helpers; not part of the public Annotator API.
namespace internal {

// Helper function, which if the initial 'span' contains only white-spaces,
// moves the selection to a single-codepoint selection on the left side
// of this block of white-space.
CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
                                            const UnicodeText& context_unicode,
                                            const UniLib& unilib);

// Copies tokens from 'cached_tokens' that are
// 'tokens_around_selection_to_copy' (on the left, and right) tokens distant
// from the tokens that correspond to 'selection_indices'.
std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
                                    CodepointSpan selection_indices,
                                    TokenSpan tokens_around_selection_to_copy);
}  // namespace internal
-
// Interprets the buffer as a Model flatbuffer and returns it for reading.
// Returns nullptr when 'buffer' is null.
const Model* ViewModel(const void* buffer, int size);
-
-// Opens model from given path and runs a function, passing the loaded Model
-// flatbuffer as an argument.
-//
-// This is mainly useful if we don't want to pay the cost for the model
-// initialization because we'll be only reading some flatbuffer values from the
-// file.
-template <typename ReturnType, typename Func>
-ReturnType VisitAnnotatorModel(const std::string& path, Func function) {
- ScopedMmap mmap(path);
- if (!mmap.handle().ok()) {
- function(/*model=*/nullptr);
- }
- const Model* model =
- ViewModel(mmap.handle().start(), mmap.handle().num_bytes());
- return function(model);
-}
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
diff --git a/annotator/annotator_jni.cc b/annotator/annotator_jni.cc
deleted file mode 100644
index 9118f30..0000000
--- a/annotator/annotator_jni.cc
+++ /dev/null
@@ -1,694 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// JNI wrapper for the Annotator.
-
-#include "annotator/annotator_jni.h"
-
-#include <jni.h>
-#include <type_traits>
-#include <vector>
-
-#include "annotator/annotator.h"
-#include "annotator/annotator_jni_common.h"
-#include "annotator/types.h"
-#include "utils/base/integral_types.h"
-#include "utils/calendar/calendar.h"
-#include "utils/intents/intent-generator.h"
-#include "utils/intents/jni.h"
-#include "utils/java/jni-cache.h"
-#include "utils/java/scoped_local_ref.h"
-#include "utils/java/string_utils.h"
-#include "utils/memory/mmap.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unilib.h"
-
-#ifdef TC3_UNILIB_JAVAICU
-#ifndef TC3_CALENDAR_JAVAICU
-#error Inconsistent usage of Java ICU components
-#else
-#define TC3_USE_JAVAICU
-#endif
-#endif
-
-using libtextclassifier3::AnnotatedSpan;
-using libtextclassifier3::Annotator;
-using libtextclassifier3::ClassificationResult;
-using libtextclassifier3::CodepointSpan;
-using libtextclassifier3::Model;
-using libtextclassifier3::ScopedLocalRef;
-// When using the Java's ICU, CalendarLib and UniLib need to be instantiated
-// with a JavaVM pointer from JNI. When using a standard ICU the pointer is
-// not needed and the objects are instantiated implicitly.
-#ifdef TC3_USE_JAVAICU
-using libtextclassifier3::CalendarLib;
-using libtextclassifier3::UniLib;
-#endif
-
-namespace libtextclassifier3 {
-
-using libtextclassifier3::CodepointSpan;
-
-namespace {
// Bundles the Annotator together with the JNI helpers (intent generator and
// remote-action template handler) that the JNI entry points use.
class AnnotatorJniContext {
 public:
  // Factory. Returns nullptr if 'jni_cache' or 'model' is null, or if the
  // template handler cannot be created. Note that the intent generator is
  // allowed to be null (callers must check intent_generator() before use).
  // The caller takes ownership of the returned raw pointer.
  static AnnotatorJniContext* Create(
      const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
      std::unique_ptr<Annotator> model) {
    if (jni_cache == nullptr || model == nullptr) {
      return nullptr;
    }
    std::unique_ptr<IntentGenerator> intent_generator =
        IntentGenerator::Create(model->model()->intent_options(),
                                model->model()->resources(), jni_cache);
    std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
        libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
    if (template_handler == nullptr) {
      return nullptr;
    }
    return new AnnotatorJniContext(jni_cache, std::move(model),
                                   std::move(intent_generator),
                                   std::move(template_handler));
  }

  std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const {
    return jni_cache_;
  }

  Annotator* model() const { return model_.get(); }

  // May return nullptr (see Create()).
  IntentGenerator* intent_generator() const { return intent_generator_.get(); }

  RemoteActionTemplatesHandler* template_handler() const {
    return template_handler_.get();
  }

 private:
  AnnotatorJniContext(
      const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
      std::unique_ptr<Annotator> model,
      std::unique_ptr<IntentGenerator> intent_generator,
      std::unique_ptr<RemoteActionTemplatesHandler> template_handler)
      : jni_cache_(jni_cache),
        model_(std::move(model)),
        intent_generator_(std::move(intent_generator)),
        template_handler_(std::move(template_handler)) {}

  std::shared_ptr<libtextclassifier3::JniCache> jni_cache_;
  std::unique_ptr<Annotator> model_;
  std::unique_ptr<IntentGenerator> intent_generator_;
  std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
};
-
// Converts a single ClassificationResult into its Java ClassificationResult
// counterpart, optionally generating RemoteActionTemplates for it when
// 'generate_intents' is true and an intent generator is available.
// NOTE(review): the JNI allocations below (NewStringUTF/NewByteArray/
// NewObject) are not checked for failure; confirm callers tolerate pending
// JNI exceptions.
jobject ClassificationResultWithIntentsToJObject(
    JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context,
    jclass result_class, jmethodID result_class_constructor,
    jclass datetime_parse_class, jmethodID datetime_parse_class_constructor,
    const jstring device_locales, const ClassificationOptions* options,
    const std::string& context, const CodepointSpan& selection_indices,
    const ClassificationResult& classification_result, bool generate_intents) {
  jstring row_string =
      env->NewStringUTF(classification_result.collection.c_str());

  // Datetime sub-result, only populated when the parse result is set.
  jobject row_datetime_parse = nullptr;
  if (classification_result.datetime_parse_result.IsSet()) {
    row_datetime_parse =
        env->NewObject(datetime_parse_class, datetime_parse_class_constructor,
                       classification_result.datetime_parse_result.time_ms_utc,
                       classification_result.datetime_parse_result.granularity);
  }

  // Serialized knowledge result is passed to Java as a raw byte array.
  jbyteArray serialized_knowledge_result = nullptr;
  const std::string& serialized_knowledge_result_string =
      classification_result.serialized_knowledge_result;
  if (!serialized_knowledge_result_string.empty()) {
    serialized_knowledge_result =
        env->NewByteArray(serialized_knowledge_result_string.size());
    env->SetByteArrayRegion(serialized_knowledge_result, 0,
                            serialized_knowledge_result_string.size(),
                            reinterpret_cast<const jbyte*>(
                                serialized_knowledge_result_string.data()));
  }

  // Optional contact fields: left as null when empty so the Java side can
  // distinguish "absent" from "empty string".
  jstring contact_name = nullptr;
  if (!classification_result.contact_name.empty()) {
    contact_name =
        env->NewStringUTF(classification_result.contact_name.c_str());
  }

  jstring contact_given_name = nullptr;
  if (!classification_result.contact_given_name.empty()) {
    contact_given_name =
        env->NewStringUTF(classification_result.contact_given_name.c_str());
  }

  jstring contact_nickname = nullptr;
  if (!classification_result.contact_nickname.empty()) {
    contact_nickname =
        env->NewStringUTF(classification_result.contact_nickname.c_str());
  }

  jstring contact_email_address = nullptr;
  if (!classification_result.contact_email_address.empty()) {
    contact_email_address =
        env->NewStringUTF(classification_result.contact_email_address.c_str());
  }

  jstring contact_phone_number = nullptr;
  if (!classification_result.contact_phone_number.empty()) {
    contact_phone_number =
        env->NewStringUTF(classification_result.contact_phone_number.c_str());
  }

  jstring contact_id = nullptr;
  if (!classification_result.contact_id.empty()) {
    contact_id = env->NewStringUTF(classification_result.contact_id.c_str());
  }

  // Optional installed-app fields.
  jstring app_name = nullptr;
  if (!classification_result.app_name.empty()) {
    app_name = env->NewStringUTF(classification_result.app_name.c_str());
  }

  jstring app_package_name = nullptr;
  if (!classification_result.app_package_name.empty()) {
    app_package_name =
        env->NewStringUTF(classification_result.app_package_name.c_str());
  }

  // Entity data converted to a NamedVariant array; requires a schema.
  jobject extras = nullptr;
  if (model_context->model()->entity_data_schema() != nullptr &&
      !classification_result.serialized_entity_data.empty()) {
    extras = model_context->template_handler()->EntityDataAsNamedVariantArray(
        model_context->model()->entity_data_schema(),
        classification_result.serialized_entity_data);
  }

  jbyteArray serialized_entity_data = nullptr;
  if (!classification_result.serialized_entity_data.empty()) {
    serialized_entity_data =
        env->NewByteArray(classification_result.serialized_entity_data.size());
    env->SetByteArrayRegion(
        serialized_entity_data, 0,
        classification_result.serialized_entity_data.size(),
        reinterpret_cast<const jbyte*>(
            classification_result.serialized_entity_data.data()));
  }

  jobject remote_action_templates_result = nullptr;
  // Only generate RemoteActionTemplate for the top classification result
  // as classifyText does not need RemoteAction from other results anyway.
  if (generate_intents && model_context->intent_generator() != nullptr) {
    std::vector<RemoteActionTemplate> remote_action_templates;
    if (model_context->intent_generator()->GenerateIntents(
            device_locales, classification_result,
            options->reference_time_ms_utc, context, selection_indices,
            app_context, model_context->model()->entity_data_schema(),
            &remote_action_templates)) {
      remote_action_templates_result =
          model_context->template_handler()
              ->RemoteActionTemplatesToJObjectArray(remote_action_templates);
    }
  }

  // Argument order must match the Java ClassificationResult constructor
  // signature looked up by the caller.
  return env->NewObject(
      result_class, result_class_constructor, row_string,
      static_cast<jfloat>(classification_result.score), row_datetime_parse,
      serialized_knowledge_result, contact_name, contact_given_name,
      contact_nickname, contact_email_address, contact_phone_number, contact_id,
      app_name, app_package_name, extras, serialized_entity_data,
      remote_action_templates_result, classification_result.duration_ms,
      classification_result.numeric_value);
}
-
-jobjectArray ClassificationResultsWithIntentsToJObjectArray(
- JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context,
- const jstring device_locales, const ClassificationOptions* options,
- const std::string& context, const CodepointSpan& selection_indices,
- const std::vector<ClassificationResult>& classification_result,
- bool generate_intents) {
- const ScopedLocalRef<jclass> result_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$ClassificationResult"),
- env);
- if (!result_class) {
- TC3_LOG(ERROR) << "Couldn't find ClassificationResult class.";
- return nullptr;
- }
- const ScopedLocalRef<jclass> datetime_parse_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$DatetimeResult"),
- env);
- if (!datetime_parse_class) {
- TC3_LOG(ERROR) << "Couldn't find DatetimeResult class.";
- return nullptr;
- }
-
- const jmethodID result_class_constructor = env->GetMethodID(
- result_class.get(), "<init>",
- "(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/String;"
- "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;"
- "Ljava/lang/String;[L" TC3_PACKAGE_PATH TC3_NAMED_VARIANT_CLASS_NAME_STR
- ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR
- ";JJ)V");
- const jmethodID datetime_parse_class_constructor =
- env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
-
- const jobjectArray results = env->NewObjectArray(classification_result.size(),
- result_class.get(), nullptr);
- for (int i = 0; i < classification_result.size(); i++) {
- jobject result = ClassificationResultWithIntentsToJObject(
- env, model_context, app_context, result_class.get(),
- result_class_constructor, datetime_parse_class.get(),
- datetime_parse_class_constructor, device_locales, options, context,
- selection_indices, classification_result[i],
- generate_intents && (i == 0));
- env->SetObjectArrayElement(results, i, result);
- env->DeleteLocalRef(result);
- }
- return results;
-}
-
-jobjectArray ClassificationResultsToJObjectArray(
- JNIEnv* env, const AnnotatorJniContext* model_context,
- const std::vector<ClassificationResult>& classification_result) {
- return ClassificationResultsWithIntentsToJObjectArray(
- env, model_context,
- /*(unused) app_context=*/nullptr,
- /*(unused) devide_locale=*/nullptr,
- /*(unusued) options=*/nullptr,
- /*(unused) selection_text=*/"",
- /*(unused) selection_indices=*/{kInvalidIndex, kInvalidIndex},
- classification_result,
- /*generate_intents=*/false);
-}
-
-CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
- CodepointSpan orig_indices,
- bool from_utf8) {
- const libtextclassifier3::UnicodeText unicode_str =
- libtextclassifier3::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
-
- int unicode_index = 0;
- int bmp_index = 0;
-
- const int* source_index;
- const int* target_index;
- if (from_utf8) {
- source_index = &unicode_index;
- target_index = &bmp_index;
- } else {
- source_index = &bmp_index;
- target_index = &unicode_index;
- }
-
- CodepointSpan result{-1, -1};
- std::function<void()> assign_indices_fn = [&result, &orig_indices,
- &source_index, &target_index]() {
- if (orig_indices.first == *source_index) {
- result.first = *target_index;
- }
-
- if (orig_indices.second == *source_index) {
- result.second = *target_index;
- }
- };
-
- for (auto it = unicode_str.begin(); it != unicode_str.end();
- ++it, ++unicode_index, ++bmp_index) {
- assign_indices_fn();
-
- // There is 1 extra character in the input for each UTF8 character > 0xFFFF.
- if (*it > 0xFFFF) {
- ++bmp_index;
- }
- }
- assign_indices_fn();
-
- return result;
-}
-
-} // namespace
-
-CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
- CodepointSpan bmp_indices) {
- return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
-}
-
-CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
- CodepointSpan utf8_indices) {
- return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
-}
-
-jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
- if (!mmap->handle().ok()) {
- return env->NewStringUTF("");
- }
- const Model* model = libtextclassifier3::ViewModel(
- mmap->handle().start(), mmap->handle().num_bytes());
- if (!model || !model->locales()) {
- return env->NewStringUTF("");
- }
- return env->NewStringUTF(model->locales()->c_str());
-}
-
-jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
- if (!mmap->handle().ok()) {
- return 0;
- }
- const Model* model = libtextclassifier3::ViewModel(
- mmap->handle().start(), mmap->handle().num_bytes());
- if (!model) {
- return 0;
- }
- return model->version();
-}
-
-jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
- if (!mmap->handle().ok()) {
- return env->NewStringUTF("");
- }
- const Model* model = libtextclassifier3::ViewModel(
- mmap->handle().start(), mmap->handle().num_bytes());
- if (!model || !model->name()) {
- return env->NewStringUTF("");
- }
- return env->NewStringUTF(model->name()->c_str());
-}
-
-} // namespace libtextclassifier3
-
-using libtextclassifier3::AnnotatorJniContext;
-using libtextclassifier3::ClassificationResultsToJObjectArray;
-using libtextclassifier3::ClassificationResultsWithIntentsToJObjectArray;
-using libtextclassifier3::ConvertIndicesBMPToUTF8;
-using libtextclassifier3::ConvertIndicesUTF8ToBMP;
-using libtextclassifier3::FromJavaAnnotationOptions;
-using libtextclassifier3::FromJavaClassificationOptions;
-using libtextclassifier3::FromJavaSelectionOptions;
-using libtextclassifier3::ToStlString;
-
-TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator)
-(JNIEnv* env, jobject thiz, jint fd) {
- std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
- libtextclassifier3::JniCache::Create(env));
-#ifdef TC3_USE_JAVAICU
- return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
- jni_cache,
- Annotator::FromFileDescriptor(
- fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
- std::unique_ptr<CalendarLib>(new CalendarLib(jni_cache)))));
-#else
- return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
- jni_cache, Annotator::FromFileDescriptor(fd)));
-#endif
-}
-
-TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
-(JNIEnv* env, jobject thiz, jstring path) {
- const std::string path_str = ToStlString(env, path);
- std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
- libtextclassifier3::JniCache::Create(env));
-#ifdef TC3_USE_JAVAICU
- return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
- jni_cache,
- Annotator::FromPath(
- path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
- std::unique_ptr<CalendarLib>(new CalendarLib(jni_cache)))));
-#else
- return reinterpret_cast<jlong>(
- AnnotatorJniContext::Create(jni_cache, Annotator::FromPath(path_str)));
-#endif
-}
-
-TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME,
- nativeNewAnnotatorFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
- std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
- libtextclassifier3::JniCache::Create(env));
- const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
-#ifdef TC3_USE_JAVAICU
- return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
- jni_cache,
- Annotator::FromFileDescriptor(
- fd, offset, size, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
- std::unique_ptr<CalendarLib>(new CalendarLib(jni_cache)))));
-#else
- return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
- jni_cache, Annotator::FromFileDescriptor(fd, offset, size)));
-#endif
-}
-
-TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
- nativeInitializeKnowledgeEngine)
-(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
- if (!ptr) {
- return false;
- }
-
- Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
-
- std::string serialized_config_string;
- const int length = env->GetArrayLength(serialized_config);
- serialized_config_string.resize(length);
- env->GetByteArrayRegion(serialized_config, 0, length,
- reinterpret_cast<jbyte*>(const_cast<char*>(
- serialized_config_string.data())));
-
- return model->InitializeKnowledgeEngine(serialized_config_string);
-}
-
-TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
- nativeInitializeContactEngine)
-(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
- if (!ptr) {
- return false;
- }
-
- Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
-
- std::string serialized_config_string;
- const int length = env->GetArrayLength(serialized_config);
- serialized_config_string.resize(length);
- env->GetByteArrayRegion(serialized_config, 0, length,
- reinterpret_cast<jbyte*>(const_cast<char*>(
- serialized_config_string.data())));
-
- return model->InitializeContactEngine(serialized_config_string);
-}
-
-TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
- nativeInitializeInstalledAppEngine)
-(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
- if (!ptr) {
- return false;
- }
-
- Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
-
- std::string serialized_config_string;
- const int length = env->GetArrayLength(serialized_config);
- serialized_config_string.resize(length);
- env->GetByteArrayRegion(serialized_config, 0, length,
- reinterpret_cast<jbyte*>(const_cast<char*>(
- serialized_config_string.data())));
-
- return model->InitializeInstalledAppEngine(serialized_config_string);
-}
-
-TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeGetNativeModelPtr)
-(JNIEnv* env, jobject thiz, jlong ptr) {
- if (!ptr) {
- return 0L;
- }
- return reinterpret_cast<jlong>(
- reinterpret_cast<AnnotatorJniContext*>(ptr)->model());
-}
-
-TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jobject options) {
- if (!ptr) {
- return nullptr;
- }
- const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- const std::string context_utf8 = ToStlString(env, context);
- CodepointSpan input_indices =
- ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
- CodepointSpan selection = model->SuggestSelection(
- context_utf8, input_indices, FromJavaSelectionOptions(env, options));
- selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);
-
- jintArray result = env->NewIntArray(2);
- env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection)));
- env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection)));
- return result;
-}
-
-TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jobject options, jobject app_context,
- jstring device_locales) {
- if (!ptr) {
- return nullptr;
- }
- const AnnotatorJniContext* model_context =
- reinterpret_cast<AnnotatorJniContext*>(ptr);
-
- const std::string context_utf8 = ToStlString(env, context);
- const CodepointSpan input_indices =
- ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
- const libtextclassifier3::ClassificationOptions classification_options =
- FromJavaClassificationOptions(env, options);
- const std::vector<ClassificationResult> classification_result =
- model_context->model()->ClassifyText(context_utf8, input_indices,
- classification_options);
- if (app_context != nullptr) {
- return ClassificationResultsWithIntentsToJObjectArray(
- env, model_context, app_context, device_locales,
- &classification_options, context_utf8, input_indices,
- classification_result,
- /*generate_intents=*/true);
- }
- return ClassificationResultsToJObjectArray(env, model_context,
- classification_result);
-}
-
-TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) {
- if (!ptr) {
- return nullptr;
- }
- const AnnotatorJniContext* model_context =
- reinterpret_cast<AnnotatorJniContext*>(ptr);
- const std::string context_utf8 = ToStlString(env, context);
- const std::vector<AnnotatedSpan> annotations =
- model_context->model()->Annotate(context_utf8,
- FromJavaAnnotationOptions(env, options));
-
- jclass result_class = env->FindClass(
- TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotatedSpan");
- if (!result_class) {
- TC3_LOG(ERROR) << "Couldn't find result class: "
- << TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$AnnotatedSpan";
- return nullptr;
- }
-
- jmethodID result_class_constructor =
- env->GetMethodID(result_class, "<init>",
- "(II[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$ClassificationResult;)V");
-
- jobjectArray results =
- env->NewObjectArray(annotations.size(), result_class, nullptr);
-
- for (int i = 0; i < annotations.size(); ++i) {
- CodepointSpan span_bmp =
- ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
- jobject result = env->NewObject(
- result_class, result_class_constructor,
- static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second),
- ClassificationResultsToJObjectArray(env, model_context,
- annotations[i].classification));
- env->SetObjectArrayElement(results, i, result);
- env->DeleteLocalRef(result);
- }
- env->DeleteLocalRef(result_class);
- return results;
-}
-
-TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
- nativeLookUpKnowledgeEntity)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring id) {
- if (!ptr) {
- return nullptr;
- }
- const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- const std::string id_utf8 = ToStlString(env, id);
- std::string serialized_knowledge_result;
- if (!model->LookUpKnowledgeEntity(id_utf8, &serialized_knowledge_result)) {
- return nullptr;
- }
- jbyteArray result = env->NewByteArray(serialized_knowledge_result.size());
- env->SetByteArrayRegion(
- result, 0, serialized_knowledge_result.size(),
- reinterpret_cast<const jbyte*>(serialized_knowledge_result.data()));
- return result;
-}
-
-TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
-(JNIEnv* env, jobject thiz, jlong ptr) {
- const AnnotatorJniContext* context =
- reinterpret_cast<AnnotatorJniContext*>(ptr);
- delete context;
-}
-
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLanguage)
-(JNIEnv* env, jobject clazz, jint fd) {
- TC3_LOG(WARNING) << "Using deprecated getLanguage().";
- return TC3_JNI_METHOD_NAME(TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)(
- env, clazz, fd);
-}
-
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)
-(JNIEnv* env, jobject clazz, jint fd) {
- const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
- new libtextclassifier3::ScopedMmap(fd));
- return GetLocalesFromMmap(env, mmap.get());
-}
-
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
- nativeGetLocalesFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
- const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
- const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
- new libtextclassifier3::ScopedMmap(fd, offset, size));
- return GetLocalesFromMmap(env, mmap.get());
-}
-
-TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion)
-(JNIEnv* env, jobject clazz, jint fd) {
- const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
- new libtextclassifier3::ScopedMmap(fd));
- return GetVersionFromMmap(env, mmap.get());
-}
-
-TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME,
- nativeGetVersionFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
- const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
- const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
- new libtextclassifier3::ScopedMmap(fd, offset, size));
- return GetVersionFromMmap(env, mmap.get());
-}
-
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName)
-(JNIEnv* env, jobject clazz, jint fd) {
- const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
- new libtextclassifier3::ScopedMmap(fd));
- return GetNameFromMmap(env, mmap.get());
-}
-
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
- nativeGetNameFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
- const jint fd = libtextclassifier3::GetFdFromAssetFileDescriptor(env, afd);
- const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
- new libtextclassifier3::ScopedMmap(fd, offset, size));
- return GetNameFromMmap(env, mmap.get());
-}
diff --git a/annotator/annotator_jni.h b/annotator/annotator_jni.h
deleted file mode 100644
index bca1dcd..0000000
--- a/annotator/annotator_jni.h
+++ /dev/null
@@ -1,119 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_
-
-#include <jni.h>
-#include <string>
-#include "annotator/annotator_jni_common.h"
-#include "annotator/types.h"
-#include "utils/java/jni-base.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-// SmartSelection.
-TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator)
-(JNIEnv* env, jobject thiz, jint fd);
-
-TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
-(JNIEnv* env, jobject thiz, jstring path);
-
-TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME,
- nativeNewAnnotatorFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
-
-TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
- nativeInitializeKnowledgeEngine)
-(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);
-
-TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
- nativeInitializeContactEngine)
-(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);
-
-TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
- nativeInitializeInstalledAppEngine)
-(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);
-
-TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeGetNativeModelPtr)
-(JNIEnv* env, jobject thiz, jlong ptr);
-
-TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jobject options);
-
-TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jobject options, jobject app_context,
- jstring device_locales);
-
-TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options);
-
-TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
- nativeLookUpKnowledgeEntity)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring id);
-
-TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
-(JNIEnv* env, jobject thiz, jlong ptr);
-
-// DEPRECATED. Use nativeGetLocales instead.
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLanguage)
-(JNIEnv* env, jobject clazz, jint fd);
-
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)
-(JNIEnv* env, jobject clazz, jint fd);
-
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
- nativeGetLocalesFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
-
-TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion)
-(JNIEnv* env, jobject clazz, jint fd);
-
-TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME,
- nativeGetVersionFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
-
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName)
-(JNIEnv* env, jobject clazz, jint fd);
-
-TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME,
- nativeGetNameFromAssetFileDescriptor)
-(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size);
-
-#ifdef __cplusplus
-}
-#endif
-
-namespace libtextclassifier3 {
-
-// Given a utf8 string and a span expressed in Java BMP (basic multilingual
-// plane) codepoints, converts it to a span expressed in utf8 codepoints.
-libtextclassifier3::CodepointSpan ConvertIndicesBMPToUTF8(
- const std::string& utf8_str, libtextclassifier3::CodepointSpan bmp_indices);
-
-// Given a utf8 string and a span expressed in utf8 codepoints, converts it to a
-// span expressed in Java BMP (basic multilingual plane) codepoints.
-libtextclassifier3::CodepointSpan ConvertIndicesUTF8ToBMP(
- const std::string& utf8_str,
- libtextclassifier3::CodepointSpan utf8_indices);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_
diff --git a/annotator/annotator_jni_common.cc b/annotator/annotator_jni_common.cc
deleted file mode 100644
index 55f14e6..0000000
--- a/annotator/annotator_jni_common.cc
+++ /dev/null
@@ -1,155 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/annotator_jni_common.h"
-
-#include "utils/java/jni-base.h"
-#include "utils/java/scoped_local_ref.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-std::unordered_set<std::string> EntityTypesFromJObject(JNIEnv* env,
- const jobject& jobject) {
- std::unordered_set<std::string> entity_types;
- jobjectArray jentity_types = reinterpret_cast<jobjectArray>(jobject);
- const int size = env->GetArrayLength(jentity_types);
- for (int i = 0; i < size; ++i) {
- jstring jentity_type =
- reinterpret_cast<jstring>(env->GetObjectArrayElement(jentity_types, i));
- entity_types.insert(ToStlString(env, jentity_type));
- }
- return entity_types;
-}
-
-template <typename T>
-T FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
- const std::string& class_name) {
- if (!joptions) {
- return {};
- }
-
- const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()),
- env);
- if (!options_class) {
- return {};
- }
-
- const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
- env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
- "getLocale", "Ljava/lang/String;");
- const std::pair<bool, jobject> status_or_reference_timezone =
- CallJniMethod0<jobject>(env, joptions, options_class.get(),
- &JNIEnv::CallObjectMethod, "getReferenceTimezone",
- "Ljava/lang/String;");
- const std::pair<bool, int64> status_or_reference_time_ms_utc =
- CallJniMethod0<int64>(env, joptions, options_class.get(),
- &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc",
- "J");
- const std::pair<bool, jobject> status_or_detected_text_language_tags =
- CallJniMethod0<jobject>(
- env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
- "getDetectedTextLanguageTags", "Ljava/lang/String;");
- const std::pair<bool, int> status_or_annotation_usecase =
- CallJniMethod0<int>(env, joptions, options_class.get(),
- &JNIEnv::CallIntMethod, "getAnnotationUsecase", "I");
-
- if (!status_or_locales.first || !status_or_reference_timezone.first ||
- !status_or_reference_time_ms_utc.first ||
- !status_or_detected_text_language_tags.first ||
- !status_or_annotation_usecase.first) {
- return {};
- }
-
- T options;
- options.locales =
- ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
- options.reference_timezone = ToStlString(
- env, reinterpret_cast<jstring>(status_or_reference_timezone.second));
- options.reference_time_ms_utc = status_or_reference_time_ms_utc.second;
- options.detected_text_language_tags = ToStlString(
- env,
- reinterpret_cast<jstring>(status_or_detected_text_language_tags.second));
- options.annotation_usecase =
- static_cast<AnnotationUsecase>(status_or_annotation_usecase.second);
- return options;
-}
-} // namespace
-
-SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) {
- if (!joptions) {
- return {};
- }
-
- const ScopedLocalRef<jclass> options_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$SelectionOptions"),
- env);
- const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
- env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
- "getLocales", "Ljava/lang/String;");
- const std::pair<bool, int> status_or_annotation_usecase =
- CallJniMethod0<int>(env, joptions, options_class.get(),
- &JNIEnv::CallIntMethod, "getAnnotationUsecase", "I");
- if (!status_or_locales.first || !status_or_annotation_usecase.first) {
- return {};
- }
-
- SelectionOptions options;
- options.locales =
- ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
- options.annotation_usecase =
- static_cast<AnnotationUsecase>(status_or_annotation_usecase.second);
-
- return options;
-}
-
-ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
- jobject joptions) {
- return FromJavaOptionsInternal<ClassificationOptions>(
- env, joptions,
- TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$ClassificationOptions");
-}
-
-AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) {
- if (!joptions) return {};
- const ScopedLocalRef<jclass> options_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$AnnotationOptions"),
- env);
- if (!options_class) return {};
- const std::pair<bool, jobject> status_or_entity_types =
- CallJniMethod0<jobject>(env, joptions, options_class.get(),
- &JNIEnv::CallObjectMethod, "getEntityTypes",
- "[Ljava/lang/String;");
- if (!status_or_entity_types.first) return {};
- const std::pair<bool, bool> status_or_enable_serialized_entity_data =
- CallJniMethod0<bool>(env, joptions, options_class.get(),
- &JNIEnv::CallBooleanMethod,
- "isSerializedEntityDataEnabled", "Z");
- if (!status_or_enable_serialized_entity_data.first) return {};
- AnnotationOptions annotation_options =
- FromJavaOptionsInternal<AnnotationOptions>(
- env, joptions,
- TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions");
- annotation_options.entity_types =
- EntityTypesFromJObject(env, status_or_entity_types.second);
- annotation_options.is_serialized_entity_data_enabled =
- status_or_enable_serialized_entity_data.second;
- return annotation_options;
-}
-
-} // namespace libtextclassifier3
diff --git a/annotator/annotator_jni_common.h b/annotator/annotator_jni_common.h
deleted file mode 100644
index b62bb21..0000000
--- a/annotator/annotator_jni_common.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_
-
-#include <jni.h>
-
-#include "annotator/annotator.h"
-
-#ifndef TC3_ANNOTATOR_CLASS_NAME
-#define TC3_ANNOTATOR_CLASS_NAME AnnotatorModel
-#endif
-
-#define TC3_ANNOTATOR_CLASS_NAME_STR TC3_ADD_QUOTES(TC3_ANNOTATOR_CLASS_NAME)
-
-namespace libtextclassifier3 {
-
-SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions);
-
-ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
- jobject joptions);
-
-AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_
diff --git a/annotator/collections.h b/annotator/collections.h
deleted file mode 100644
index a23623e..0000000
--- a/annotator/collections.h
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_COLLECTIONS_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_COLLECTIONS_H_
-
-#include <string>
-
-namespace libtextclassifier3 {
-
-// String collection names for various classes.
-class Collections {
- public:
- static const std::string& Address() {
- static const std::string& value =
- *[]() { return new std::string("address"); }();
- return value;
- }
- static const std::string& App() {
- static const std::string& value =
- *[]() { return new std::string("app"); }();
- return value;
- }
- static const std::string& Contact() {
- static const std::string& value =
- *[]() { return new std::string("contact"); }();
- return value;
- }
- static const std::string& Date() {
- static const std::string& value =
- *[]() { return new std::string("date"); }();
- return value;
- }
- static const std::string& DateTime() {
- static const std::string& value =
- *[]() { return new std::string("datetime"); }();
- return value;
- }
- static const std::string& Dictionary() {
- static const std::string& value =
- *[]() { return new std::string("dictionary"); }();
- return value;
- }
- static const std::string& Duration() {
- static const std::string& value =
- *[]() { return new std::string("duration"); }();
- return value;
- }
- static const std::string& Email() {
- static const std::string& value =
- *[]() { return new std::string("email"); }();
- return value;
- }
- static const std::string& Entity() {
- static const std::string& value =
- *[]() { return new std::string("entity"); }();
- return value;
- }
- static const std::string& Flight() {
- static const std::string& value =
- *[]() { return new std::string("flight"); }();
- return value;
- }
- static const std::string& Iban() {
- static const std::string& value =
- *[]() { return new std::string("iban"); }();
- return value;
- }
- static const std::string& Isbn() {
- static const std::string& value =
- *[]() { return new std::string("isbn"); }();
- return value;
- }
- static const std::string& Money() {
- static const std::string& value =
- *[]() { return new std::string("money"); }();
- return value;
- }
- static const std::string& Number() {
- static const std::string& value =
- *[]() { return new std::string("number"); }();
- return value;
- }
- static const std::string& Other() {
- static const std::string& value =
- *[]() { return new std::string("other"); }();
- return value;
- }
- static const std::string& PaymentCard() {
- static const std::string& value =
- *[]() { return new std::string("payment_card"); }();
- return value;
- }
- static const std::string& Phone() {
- static const std::string& value =
- *[]() { return new std::string("phone"); }();
- return value;
- }
- static const std::string& TrackingNumber() {
- static const std::string& value =
- *[]() { return new std::string("tracking_number"); }();
- return value;
- }
- static const std::string& Url() {
- static const std::string& value =
- *[]() { return new std::string("url"); }();
- return value;
- }
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_COLLECTIONS_H_
diff --git a/annotator/contact/contact-engine-dummy.h b/annotator/contact/contact-engine-dummy.h
deleted file mode 100644
index c7a389d..0000000
--- a/annotator/contact/contact-engine-dummy.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_DUMMY_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_DUMMY_H_
-
-#include <string>
-#include <vector>
-
-#include "annotator/feature-processor.h"
-#include "annotator/types.h"
-#include "utils/base/logging.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3 {
-
-// A dummy implementation of the contact engine.
-class ContactEngine {
- public:
- explicit ContactEngine(const FeatureProcessor* feature_processor,
- const UniLib* unilib) {}
-
- bool Initialize(const std::string& serialized_config) {
- TC3_LOG(ERROR) << "No contact engine to initialize.";
- return false;
- }
-
- bool ClassifyText(const std::string& context, CodepointSpan selection_indices,
- ClassificationResult* classification_result) const {
- return false;
- }
-
- bool Chunk(const UnicodeText& context_unicode,
- const std::vector<Token>& tokens,
- std::vector<AnnotatedSpan>* result) const {
- return true;
- }
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_DUMMY_H_
diff --git a/annotator/datetime/extractor.cc b/annotator/datetime/extractor.cc
deleted file mode 100644
index b9d0c30..0000000
--- a/annotator/datetime/extractor.cc
+++ /dev/null
@@ -1,444 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/datetime/extractor.h"
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-bool DatetimeExtractor::Extract(DateParseData* result,
- CodepointSpan* result_span) const {
- result->field_set_mask = 0;
- *result_span = {kInvalidIndex, kInvalidIndex};
-
- if (rule_.regex->groups() == nullptr) {
- return false;
- }
-
- for (int group_id = 0; group_id < rule_.regex->groups()->size(); group_id++) {
- UnicodeText group_text;
- const int group_type = rule_.regex->groups()->Get(group_id);
- if (group_type == DatetimeGroupType_GROUP_UNUSED) {
- continue;
- }
- if (!GroupTextFromMatch(group_id, &group_text)) {
- TC3_LOG(ERROR) << "Couldn't retrieve group.";
- return false;
- }
- // The pattern can have a group defined in a part that was not matched,
- // e.g. an optional part. In this case we'll get an empty content here.
- if (group_text.empty()) {
- continue;
- }
- switch (group_type) {
- case DatetimeGroupType_GROUP_YEAR: {
- if (!ParseYear(group_text, &(result->year))) {
- TC3_LOG(ERROR) << "Couldn't extract YEAR.";
- return false;
- }
- result->field_set_mask |= DateParseData::YEAR_FIELD;
- break;
- }
- case DatetimeGroupType_GROUP_MONTH: {
- if (!ParseMonth(group_text, &(result->month))) {
- TC3_LOG(ERROR) << "Couldn't extract MONTH.";
- return false;
- }
- result->field_set_mask |= DateParseData::MONTH_FIELD;
- break;
- }
- case DatetimeGroupType_GROUP_DAY: {
- if (!ParseDigits(group_text, &(result->day_of_month))) {
- TC3_LOG(ERROR) << "Couldn't extract DAY.";
- return false;
- }
- result->field_set_mask |= DateParseData::DAY_FIELD;
- break;
- }
- case DatetimeGroupType_GROUP_HOUR: {
- if (!ParseDigits(group_text, &(result->hour))) {
- TC3_LOG(ERROR) << "Couldn't extract HOUR.";
- return false;
- }
- result->field_set_mask |= DateParseData::HOUR_FIELD;
- break;
- }
- case DatetimeGroupType_GROUP_MINUTE: {
- if (!ParseDigits(group_text, &(result->minute))) {
- TC3_LOG(ERROR) << "Couldn't extract MINUTE.";
- return false;
- }
- result->field_set_mask |= DateParseData::MINUTE_FIELD;
- break;
- }
- case DatetimeGroupType_GROUP_SECOND: {
- if (!ParseDigits(group_text, &(result->second))) {
- TC3_LOG(ERROR) << "Couldn't extract SECOND.";
- return false;
- }
- result->field_set_mask |= DateParseData::SECOND_FIELD;
- break;
- }
- case DatetimeGroupType_GROUP_AMPM: {
- if (!ParseAMPM(group_text, &(result->ampm))) {
- TC3_LOG(ERROR) << "Couldn't extract AMPM.";
- return false;
- }
- result->field_set_mask |= DateParseData::AMPM_FIELD;
- break;
- }
- case DatetimeGroupType_GROUP_RELATIONDISTANCE: {
- if (!ParseRelationDistance(group_text, &(result->relation_distance))) {
- TC3_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD.";
- return false;
- }
- result->field_set_mask |= DateParseData::RELATION_DISTANCE_FIELD;
- break;
- }
- case DatetimeGroupType_GROUP_RELATION: {
- if (!ParseRelation(group_text, &(result->relation))) {
- TC3_LOG(ERROR) << "Couldn't extract RELATION_FIELD.";
- return false;
- }
- result->field_set_mask |= DateParseData::RELATION_FIELD;
- break;
- }
- case DatetimeGroupType_GROUP_RELATIONTYPE: {
- if (!ParseRelationType(group_text, &(result->relation_type))) {
- TC3_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
- return false;
- }
- result->field_set_mask |= DateParseData::RELATION_TYPE_FIELD;
- break;
- }
- case DatetimeGroupType_GROUP_DUMMY1:
- case DatetimeGroupType_GROUP_DUMMY2:
- break;
- default:
- TC3_LOG(INFO) << "Unknown group type.";
- continue;
- }
- if (!UpdateMatchSpan(group_id, result_span)) {
- TC3_LOG(ERROR) << "Couldn't update span.";
- return false;
- }
- }
-
- if (result_span->first == kInvalidIndex ||
- result_span->second == kInvalidIndex) {
- *result_span = {kInvalidIndex, kInvalidIndex};
- }
-
- return true;
-}
-
-bool DatetimeExtractor::RuleIdForType(DatetimeExtractorType type,
- int* rule_id) const {
- auto type_it = type_and_locale_to_rule_.find(type);
- if (type_it == type_and_locale_to_rule_.end()) {
- return false;
- }
-
- auto locale_it = type_it->second.find(locale_id_);
- if (locale_it == type_it->second.end()) {
- return false;
- }
- *rule_id = locale_it->second;
- return true;
-}
-
-bool DatetimeExtractor::ExtractType(const UnicodeText& input,
- DatetimeExtractorType extractor_type,
- UnicodeText* match_result) const {
- int rule_id;
- if (!RuleIdForType(extractor_type, &rule_id)) {
- return false;
- }
-
- std::unique_ptr<UniLib::RegexMatcher> matcher =
- rules_[rule_id]->Matcher(input);
- if (!matcher) {
- return false;
- }
-
- int status;
- if (!matcher->Find(&status)) {
- return false;
- }
-
- if (match_result != nullptr) {
- *match_result = matcher->Group(&status);
- if (status != UniLib::RegexMatcher::kNoError) {
- return false;
- }
- }
- return true;
-}
-
-bool DatetimeExtractor::GroupTextFromMatch(int group_id,
- UnicodeText* result) const {
- int status;
- *result = matcher_.Group(group_id, &status);
- if (status != UniLib::RegexMatcher::kNoError) {
- return false;
- }
- return true;
-}
-
-bool DatetimeExtractor::UpdateMatchSpan(int group_id,
- CodepointSpan* span) const {
- int status;
- const int match_start = matcher_.Start(group_id, &status);
- if (status != UniLib::RegexMatcher::kNoError) {
- return false;
- }
- const int match_end = matcher_.End(group_id, &status);
- if (status != UniLib::RegexMatcher::kNoError) {
- return false;
- }
- if (span->first == kInvalidIndex || span->first > match_start) {
- span->first = match_start;
- }
- if (span->second == kInvalidIndex || span->second < match_end) {
- span->second = match_end;
- }
-
- return true;
-}
-
-template <typename T>
-bool DatetimeExtractor::MapInput(
- const UnicodeText& input,
- const std::vector<std::pair<DatetimeExtractorType, T>>& mapping,
- T* result) const {
- for (const auto& type_value_pair : mapping) {
- if (ExtractType(input, type_value_pair.first)) {
- *result = type_value_pair.second;
- return true;
- }
- }
- return false;
-}
-
-bool DatetimeExtractor::ParseWrittenNumber(const UnicodeText& input,
- int* parsed_number) const {
- std::vector<std::pair<int, int>> found_numbers;
- for (const auto& type_value_pair :
- std::vector<std::pair<DatetimeExtractorType, int>>{
- {DatetimeExtractorType_ZERO, 0},
- {DatetimeExtractorType_ONE, 1},
- {DatetimeExtractorType_TWO, 2},
- {DatetimeExtractorType_THREE, 3},
- {DatetimeExtractorType_FOUR, 4},
- {DatetimeExtractorType_FIVE, 5},
- {DatetimeExtractorType_SIX, 6},
- {DatetimeExtractorType_SEVEN, 7},
- {DatetimeExtractorType_EIGHT, 8},
- {DatetimeExtractorType_NINE, 9},
- {DatetimeExtractorType_TEN, 10},
- {DatetimeExtractorType_ELEVEN, 11},
- {DatetimeExtractorType_TWELVE, 12},
- {DatetimeExtractorType_THIRTEEN, 13},
- {DatetimeExtractorType_FOURTEEN, 14},
- {DatetimeExtractorType_FIFTEEN, 15},
- {DatetimeExtractorType_SIXTEEN, 16},
- {DatetimeExtractorType_SEVENTEEN, 17},
- {DatetimeExtractorType_EIGHTEEN, 18},
- {DatetimeExtractorType_NINETEEN, 19},
- {DatetimeExtractorType_TWENTY, 20},
- {DatetimeExtractorType_THIRTY, 30},
- {DatetimeExtractorType_FORTY, 40},
- {DatetimeExtractorType_FIFTY, 50},
- {DatetimeExtractorType_SIXTY, 60},
- {DatetimeExtractorType_SEVENTY, 70},
- {DatetimeExtractorType_EIGHTY, 80},
- {DatetimeExtractorType_NINETY, 90},
- {DatetimeExtractorType_HUNDRED, 100},
- {DatetimeExtractorType_THOUSAND, 1000},
- }) {
- int rule_id;
- if (!RuleIdForType(type_value_pair.first, &rule_id)) {
- return false;
- }
-
- std::unique_ptr<UniLib::RegexMatcher> matcher =
- rules_[rule_id]->Matcher(input);
- if (!matcher) {
- return false;
- }
-
- int status;
- while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
- int span_start = matcher->Start(&status);
- if (status != UniLib::RegexMatcher::kNoError) {
- return false;
- }
- found_numbers.push_back({span_start, type_value_pair.second});
- }
- }
-
- std::sort(found_numbers.begin(), found_numbers.end(),
- [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
- return a.first < b.first;
- });
-
- int sum = 0;
- int running_value = -1;
- // Simple math to make sure we handle written numerical modifiers correctly
- // so that :="fifty one thousand and one" maps to 51001 and not 50 1 1000 1.
- for (const std::pair<int, int> position_number_pair : found_numbers) {
- if (running_value >= 0) {
- if (running_value > position_number_pair.second) {
- sum += running_value;
- running_value = position_number_pair.second;
- } else {
- running_value *= position_number_pair.second;
- }
- } else {
- running_value = position_number_pair.second;
- }
- }
- sum += running_value;
- *parsed_number = sum;
- return true;
-}
-
-bool DatetimeExtractor::ParseDigits(const UnicodeText& input,
- int* parsed_digits) const {
- UnicodeText digit;
- if (!ExtractType(input, DatetimeExtractorType_DIGITS, &digit)) {
- return false;
- }
-
- if (!unilib_.ParseInt32(digit, parsed_digits)) {
- return false;
- }
- return true;
-}
-
-bool DatetimeExtractor::ParseYear(const UnicodeText& input,
- int* parsed_year) const {
- if (!ParseDigits(input, parsed_year)) {
- return false;
- }
-
- if (*parsed_year < 100) {
- if (*parsed_year < 50) {
- *parsed_year += 2000;
- } else {
- *parsed_year += 1900;
- }
- }
-
- return true;
-}
-
-bool DatetimeExtractor::ParseMonth(const UnicodeText& input,
- int* parsed_month) const {
- if (ParseDigits(input, parsed_month)) {
- return true;
- }
-
- if (MapInput(input,
- {
- {DatetimeExtractorType_JANUARY, 1},
- {DatetimeExtractorType_FEBRUARY, 2},
- {DatetimeExtractorType_MARCH, 3},
- {DatetimeExtractorType_APRIL, 4},
- {DatetimeExtractorType_MAY, 5},
- {DatetimeExtractorType_JUNE, 6},
- {DatetimeExtractorType_JULY, 7},
- {DatetimeExtractorType_AUGUST, 8},
- {DatetimeExtractorType_SEPTEMBER, 9},
- {DatetimeExtractorType_OCTOBER, 10},
- {DatetimeExtractorType_NOVEMBER, 11},
- {DatetimeExtractorType_DECEMBER, 12},
- },
- parsed_month)) {
- return true;
- }
-
- return false;
-}
-
-bool DatetimeExtractor::ParseAMPM(const UnicodeText& input,
- DateParseData::AMPM* parsed_ampm) const {
- return MapInput(input,
- {
- {DatetimeExtractorType_AM, DateParseData::AMPM::AM},
- {DatetimeExtractorType_PM, DateParseData::AMPM::PM},
- },
- parsed_ampm);
-}
-
-bool DatetimeExtractor::ParseRelationDistance(const UnicodeText& input,
- int* parsed_distance) const {
- if (ParseDigits(input, parsed_distance)) {
- return true;
- }
- if (ParseWrittenNumber(input, parsed_distance)) {
- return true;
- }
- return false;
-}
-
-bool DatetimeExtractor::ParseRelation(
- const UnicodeText& input, DateParseData::Relation* parsed_relation) const {
- return MapInput(
- input,
- {
- {DatetimeExtractorType_NOW, DateParseData::Relation::NOW},
- {DatetimeExtractorType_YESTERDAY, DateParseData::Relation::YESTERDAY},
- {DatetimeExtractorType_TOMORROW, DateParseData::Relation::TOMORROW},
- {DatetimeExtractorType_NEXT, DateParseData::Relation::NEXT},
- {DatetimeExtractorType_NEXT_OR_SAME,
- DateParseData::Relation::NEXT_OR_SAME},
- {DatetimeExtractorType_LAST, DateParseData::Relation::LAST},
- {DatetimeExtractorType_PAST, DateParseData::Relation::PAST},
- {DatetimeExtractorType_FUTURE, DateParseData::Relation::FUTURE},
- },
- parsed_relation);
-}
-
-bool DatetimeExtractor::ParseRelationType(
- const UnicodeText& input,
- DateParseData::RelationType* parsed_relation_type) const {
- return MapInput(
- input,
- {
- {DatetimeExtractorType_MONDAY, DateParseData::RelationType::MONDAY},
- {DatetimeExtractorType_TUESDAY, DateParseData::RelationType::TUESDAY},
- {DatetimeExtractorType_WEDNESDAY,
- DateParseData::RelationType::WEDNESDAY},
- {DatetimeExtractorType_THURSDAY,
- DateParseData::RelationType::THURSDAY},
- {DatetimeExtractorType_FRIDAY, DateParseData::RelationType::FRIDAY},
- {DatetimeExtractorType_SATURDAY,
- DateParseData::RelationType::SATURDAY},
- {DatetimeExtractorType_SUNDAY, DateParseData::RelationType::SUNDAY},
- {DatetimeExtractorType_SECONDS, DateParseData::RelationType::SECOND},
- {DatetimeExtractorType_MINUTES, DateParseData::RelationType::MINUTE},
- {DatetimeExtractorType_HOURS, DateParseData::RelationType::HOUR},
- {DatetimeExtractorType_DAY, DateParseData::RelationType::DAY},
- {DatetimeExtractorType_WEEK, DateParseData::RelationType::WEEK},
- {DatetimeExtractorType_MONTH, DateParseData::RelationType::MONTH},
- {DatetimeExtractorType_YEAR, DateParseData::RelationType::YEAR},
- },
- parsed_relation_type);
-}
-
-} // namespace libtextclassifier3
diff --git a/annotator/datetime/extractor.h b/annotator/datetime/extractor.h
deleted file mode 100644
index 95e7f7c..0000000
--- a/annotator/datetime/extractor.h
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_
-
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "annotator/model_generated.h"
-#include "annotator/types.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3 {
-
-struct CompiledRule {
- // The compiled regular expression.
- std::unique_ptr<const UniLib::RegexPattern> compiled_regex;
-
- // The uncompiled pattern and information about the pattern groups.
- const DatetimeModelPattern_::Regex* regex;
-
- // DatetimeModelPattern which 'regex' is part of and comes from.
- const DatetimeModelPattern* pattern;
-};
-
-// A helper class for DatetimeParser that extracts structured data
-// (DateParseDate) from the current match of the passed RegexMatcher.
-class DatetimeExtractor {
- public:
- DatetimeExtractor(
- const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
- int locale_id, const UniLib& unilib,
- const std::vector<std::unique_ptr<const UniLib::RegexPattern>>&
- extractor_rules,
- const std::unordered_map<DatetimeExtractorType,
- std::unordered_map<int, int>>&
- type_and_locale_to_extractor_rule)
- : rule_(rule),
- matcher_(matcher),
- locale_id_(locale_id),
- unilib_(unilib),
- rules_(extractor_rules),
- type_and_locale_to_rule_(type_and_locale_to_extractor_rule) {}
- bool Extract(DateParseData* result, CodepointSpan* result_span) const;
-
- private:
- bool RuleIdForType(DatetimeExtractorType type, int* rule_id) const;
-
- // Returns true if the rule for given extractor matched. If it matched,
- // match_result will contain the first group of the rule (if match_result not
- // nullptr).
- bool ExtractType(const UnicodeText& input,
- DatetimeExtractorType extractor_type,
- UnicodeText* match_result = nullptr) const;
-
- bool GroupTextFromMatch(int group_id, UnicodeText* result) const;
-
- // Updates the span to include the current match for the given group.
- bool UpdateMatchSpan(int group_id, CodepointSpan* span) const;
-
- // Returns true if any of the extractors from 'mapping' matched. If it did,
- // will fill 'result' with the associated value from 'mapping'.
- template <typename T>
- bool MapInput(const UnicodeText& input,
- const std::vector<std::pair<DatetimeExtractorType, T>>& mapping,
- T* result) const;
-
- bool ParseDigits(const UnicodeText& input, int* parsed_digits) const;
- bool ParseWrittenNumber(const UnicodeText& input, int* parsed_number) const;
- bool ParseYear(const UnicodeText& input, int* parsed_year) const;
- bool ParseMonth(const UnicodeText& input, int* parsed_month) const;
- bool ParseAMPM(const UnicodeText& input,
- DateParseData::AMPM* parsed_ampm) const;
- bool ParseRelation(const UnicodeText& input,
- DateParseData::Relation* parsed_relation) const;
- bool ParseRelationDistance(const UnicodeText& input,
- int* parsed_distance) const;
- bool ParseTimeUnit(const UnicodeText& input,
- DateParseData::TimeUnit* parsed_time_unit) const;
- bool ParseRelationType(
- const UnicodeText& input,
- DateParseData::RelationType* parsed_relation_type) const;
- bool ParseWeekday(const UnicodeText& input,
- DateParseData::RelationType* parsed_weekday) const;
-
- const CompiledRule& rule_;
- const UniLib::RegexMatcher& matcher_;
- int locale_id_;
- const UniLib& unilib_;
- const std::vector<std::unique_ptr<const UniLib::RegexPattern>>& rules_;
- const std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>&
- type_and_locale_to_rule_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_
diff --git a/annotator/datetime/parser.cc b/annotator/datetime/parser.cc
deleted file mode 100644
index 6d844f4..0000000
--- a/annotator/datetime/parser.cc
+++ /dev/null
@@ -1,425 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/datetime/parser.h"
-
-#include <set>
-#include <unordered_set>
-
-#include "annotator/datetime/extractor.h"
-#include "utils/calendar/calendar.h"
-#include "utils/i18n/locale.h"
-#include "utils/strings/split.h"
-#include "utils/zlib/zlib_regex.h"
-
-namespace libtextclassifier3 {
-std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
- const DatetimeModel* model, const UniLib& unilib,
- const CalendarLib& calendarlib, ZlibDecompressor* decompressor) {
- std::unique_ptr<DatetimeParser> result(
- new DatetimeParser(model, unilib, calendarlib, decompressor));
- if (!result->initialized_) {
- result.reset();
- }
- return result;
-}
-
-DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
- const CalendarLib& calendarlib,
- ZlibDecompressor* decompressor)
- : unilib_(unilib), calendarlib_(calendarlib) {
- initialized_ = false;
-
- if (model == nullptr) {
- return;
- }
-
- if (model->patterns() != nullptr) {
- for (const DatetimeModelPattern* pattern : *model->patterns()) {
- if (pattern->regexes()) {
- for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
- std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- UncompressMakeRegexPattern(
- unilib, regex->pattern(), regex->compressed_pattern(),
- model->lazy_regex_compilation(), decompressor);
- if (!regex_pattern) {
- TC3_LOG(ERROR) << "Couldn't create rule pattern.";
- return;
- }
- rules_.push_back({std::move(regex_pattern), regex, pattern});
- if (pattern->locales()) {
- for (int locale : *pattern->locales()) {
- locale_to_rules_[locale].push_back(rules_.size() - 1);
- }
- }
- }
- }
- }
- }
-
- if (model->extractors() != nullptr) {
- for (const DatetimeModelExtractor* extractor : *model->extractors()) {
- std::unique_ptr<UniLib::RegexPattern> regex_pattern =
- UncompressMakeRegexPattern(
- unilib, extractor->pattern(), extractor->compressed_pattern(),
- model->lazy_regex_compilation(), decompressor);
- if (!regex_pattern) {
- TC3_LOG(ERROR) << "Couldn't create extractor pattern";
- return;
- }
- extractor_rules_.push_back(std::move(regex_pattern));
-
- if (extractor->locales()) {
- for (int locale : *extractor->locales()) {
- type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
- extractor_rules_.size() - 1;
- }
- }
- }
- }
-
- if (model->locales() != nullptr) {
- for (int i = 0; i < model->locales()->Length(); ++i) {
- locale_string_to_id_[model->locales()->Get(i)->str()] = i;
- }
- }
-
- if (model->default_locales() != nullptr) {
- for (const int locale : *model->default_locales()) {
- default_locale_ids_.push_back(locale);
- }
- }
-
- use_extractors_for_locating_ = model->use_extractors_for_locating();
- generate_alternative_interpretations_when_ambiguous_ =
- model->generate_alternative_interpretations_when_ambiguous();
-
- initialized_ = true;
-}
-
-bool DatetimeParser::Parse(
- const std::string& input, const int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const {
- return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
- reference_time_ms_utc, reference_timezone, locales, mode,
- annotation_usecase, anchor_start_end, results);
-}
-
-bool DatetimeParser::FindSpansUsingLocales(
- const std::vector<int>& locale_ids, const UnicodeText& input,
- const int64 reference_time_ms_utc, const std::string& reference_timezone,
- ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
- const std::string& reference_locale,
- std::unordered_set<int>* executed_rules,
- std::vector<DatetimeParseResultSpan>* found_spans) const {
- for (const int locale_id : locale_ids) {
- auto rules_it = locale_to_rules_.find(locale_id);
- if (rules_it == locale_to_rules_.end()) {
- continue;
- }
-
- for (const int rule_id : rules_it->second) {
- // Skip rules that were already executed in previous locales.
- if (executed_rules->find(rule_id) != executed_rules->end()) {
- continue;
- }
-
- if ((rules_[rule_id].pattern->enabled_annotation_usecases() &
- (1 << annotation_usecase)) == 0) {
- continue;
- }
-
- if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
- continue;
- }
-
- executed_rules->insert(rule_id);
-
- if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
- reference_timezone, reference_locale, locale_id,
- anchor_start_end, found_spans)) {
- return false;
- }
- }
- }
- return true;
-}
-
-bool DatetimeParser::Parse(
- const UnicodeText& input, const int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const {
- std::vector<DatetimeParseResultSpan> found_spans;
- std::unordered_set<int> executed_rules;
- std::string reference_locale;
- const std::vector<int> requested_locales =
- ParseAndExpandLocales(locales, &reference_locale);
- if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
- reference_timezone, mode, annotation_usecase,
- anchor_start_end, reference_locale,
- &executed_rules, &found_spans)) {
- return false;
- }
-
- std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
- indexed_found_spans.reserve(found_spans.size());
- for (int i = 0; i < found_spans.size(); i++) {
- indexed_found_spans.push_back({found_spans[i], i});
- }
-
- // Resolve conflicts by always picking the longer span and breaking ties by
- // selecting the earlier entry in the list for a given locale.
- std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
- [](const std::pair<DatetimeParseResultSpan, int>& a,
- const std::pair<DatetimeParseResultSpan, int>& b) {
- if ((a.first.span.second - a.first.span.first) !=
- (b.first.span.second - b.first.span.first)) {
- return (a.first.span.second - a.first.span.first) >
- (b.first.span.second - b.first.span.first);
- } else {
- return a.second < b.second;
- }
- });
-
- found_spans.clear();
- for (auto& span_index_pair : indexed_found_spans) {
- found_spans.push_back(span_index_pair.first);
- }
-
- std::set<int, std::function<bool(int, int)>> chosen_indices_set(
- [&found_spans](int a, int b) {
- return found_spans[a].span.first < found_spans[b].span.first;
- });
- for (int i = 0; i < found_spans.size(); ++i) {
- if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
- chosen_indices_set.insert(i);
- results->push_back(found_spans[i]);
- }
- }
-
- return true;
-}
-
-bool DatetimeParser::HandleParseMatch(
- const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
- int64 reference_time_ms_utc, const std::string& reference_timezone,
- const std::string& reference_locale, int locale_id,
- std::vector<DatetimeParseResultSpan>* result) const {
- int status = UniLib::RegexMatcher::kNoError;
- const int start = matcher.Start(&status);
- if (status != UniLib::RegexMatcher::kNoError) {
- return false;
- }
-
- const int end = matcher.End(&status);
- if (status != UniLib::RegexMatcher::kNoError) {
- return false;
- }
-
- DatetimeParseResultSpan parse_result;
- std::vector<DatetimeParseResult> alternatives;
- if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
- reference_locale, locale_id, &alternatives,
- &parse_result.span)) {
- return false;
- }
-
- if (!use_extractors_for_locating_) {
- parse_result.span = {start, end};
- }
-
- if (parse_result.span.first != kInvalidIndex &&
- parse_result.span.second != kInvalidIndex) {
- parse_result.target_classification_score =
- rule.pattern->target_classification_score();
- parse_result.priority_score = rule.pattern->priority_score();
-
- for (DatetimeParseResult& alternative : alternatives) {
- parse_result.data.push_back(alternative);
- }
- }
- result->push_back(parse_result);
- return true;
-}
-
-bool DatetimeParser::ParseWithRule(
- const CompiledRule& rule, const UnicodeText& input,
- const int64 reference_time_ms_utc, const std::string& reference_timezone,
- const std::string& reference_locale, const int locale_id,
- bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
- std::unique_ptr<UniLib::RegexMatcher> matcher =
- rule.compiled_regex->Matcher(input);
- int status = UniLib::RegexMatcher::kNoError;
- if (anchor_start_end) {
- if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
- if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
- reference_timezone, reference_locale, locale_id,
- result)) {
- return false;
- }
- }
- } else {
- while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
- if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
- reference_timezone, reference_locale, locale_id,
- result)) {
- return false;
- }
- }
- }
- return true;
-}
-
-std::vector<int> DatetimeParser::ParseAndExpandLocales(
- const std::string& locales, std::string* reference_locale) const {
- std::vector<StringPiece> split_locales = strings::Split(locales, ',');
- if (!split_locales.empty()) {
- *reference_locale = split_locales[0].ToString();
- } else {
- *reference_locale = "";
- }
-
- std::vector<int> result;
- for (const StringPiece& locale_str : split_locales) {
- auto locale_it = locale_string_to_id_.find(locale_str.ToString());
- if (locale_it != locale_string_to_id_.end()) {
- result.push_back(locale_it->second);
- }
-
- const Locale locale = Locale::FromBCP47(locale_str.ToString());
- if (!locale.IsValid()) {
- continue;
- }
-
- const std::string language = locale.Language();
- const std::string script = locale.Script();
- const std::string region = locale.Region();
-
- // First, try adding *-region locale.
- if (!region.empty()) {
- locale_it = locale_string_to_id_.find("*-" + region);
- if (locale_it != locale_string_to_id_.end()) {
- result.push_back(locale_it->second);
- }
- }
- // Second, try adding language-script-* locale.
- if (!script.empty()) {
- locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
- if (locale_it != locale_string_to_id_.end()) {
- result.push_back(locale_it->second);
- }
- }
- // Third, try adding language-* locale.
- if (!language.empty()) {
- locale_it = locale_string_to_id_.find(language + "-*");
- if (locale_it != locale_string_to_id_.end()) {
- result.push_back(locale_it->second);
- }
- }
- }
-
- // Add the default locales if they haven't been added already.
- const std::unordered_set<int> result_set(result.begin(), result.end());
- for (const int default_locale_id : default_locale_ids_) {
- if (result_set.find(default_locale_id) == result_set.end()) {
- result.push_back(default_locale_id);
- }
- }
-
- return result;
-}
-
-void DatetimeParser::FillInterpretations(
- const DateParseData& parse,
- std::vector<DateParseData>* interpretations) const {
- DatetimeGranularity granularity = calendarlib_.GetGranularity(parse);
-
- DateParseData modified_parse(parse);
- // If the relation field is not set, but relation_type field *is*, assume
- // the relation field is NEXT_OR_SAME. This is necessary to handle e.g.
- // "monday 3pm" (otherwise only "this monday 3pm" would work).
- if (!(modified_parse.field_set_mask &
- DateParseData::Fields::RELATION_FIELD) &&
- (modified_parse.field_set_mask &
- DateParseData::Fields::RELATION_TYPE_FIELD)) {
- modified_parse.relation = DateParseData::Relation::NEXT_OR_SAME;
- modified_parse.field_set_mask |= DateParseData::Fields::RELATION_FIELD;
- }
-
- // Multiple interpretations of ambiguous datetime expressions are generated
- // here.
- if (granularity > DatetimeGranularity::GRANULARITY_DAY &&
- (modified_parse.field_set_mask & DateParseData::Fields::HOUR_FIELD) &&
- modified_parse.hour <= 12 &&
- !(modified_parse.field_set_mask & DateParseData::Fields::AMPM_FIELD)) {
- // If it's not clear if the time is AM or PM, generate all variants.
- interpretations->push_back(modified_parse);
- interpretations->back().field_set_mask |= DateParseData::Fields::AMPM_FIELD;
- interpretations->back().ampm = DateParseData::AMPM::AM;
-
- interpretations->push_back(modified_parse);
- interpretations->back().field_set_mask |= DateParseData::Fields::AMPM_FIELD;
- interpretations->back().ampm = DateParseData::AMPM::PM;
- } else {
- // Otherwise just generate 1 variant.
- interpretations->push_back(modified_parse);
- }
- // TODO(zilka): Add support for generating alternatives for "monday" -> "this
- // monday", "next monday", "last monday". The previous implementation did not
- // work as expected, because didn't work correctly for this/previous day of
- // week, and resulted sometimes results in the same date being proposed.
-}
-
-bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
- const UniLib::RegexMatcher& matcher,
- const int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale,
- int locale_id,
- std::vector<DatetimeParseResult>* results,
- CodepointSpan* result_span) const {
- DateParseData parse;
- DatetimeExtractor extractor(rule, matcher, locale_id, unilib_,
- extractor_rules_,
- type_and_locale_to_extractor_rule_);
- if (!extractor.Extract(&parse, result_span)) {
- return false;
- }
-
- std::vector<DateParseData> interpretations;
- if (generate_alternative_interpretations_when_ambiguous_) {
- FillInterpretations(parse, &interpretations);
- } else {
- interpretations.push_back(parse);
- }
-
- results->reserve(results->size() + interpretations.size());
- for (const DateParseData& interpretation : interpretations) {
- DatetimeParseResult result;
- if (!calendarlib_.InterpretParseData(
- interpretation, reference_time_ms_utc, reference_timezone,
- reference_locale, &(result.time_ms_utc), &(result.granularity))) {
- return false;
- }
- results->push_back(result);
- }
- return true;
-}
-
-} // namespace libtextclassifier3
diff --git a/annotator/datetime/parser.h b/annotator/datetime/parser.h
deleted file mode 100644
index 3f0c143..0000000
--- a/annotator/datetime/parser.h
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_
-
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-
-#include "annotator/datetime/extractor.h"
-#include "annotator/model_generated.h"
-#include "annotator/types.h"
-#include "utils/base/integral_types.h"
-#include "utils/calendar/calendar.h"
-#include "utils/utf8/unilib.h"
-#include "utils/zlib/zlib.h"
-
-namespace libtextclassifier3 {
-
-// Parses datetime expressions in the input and resolves them to actual absolute
-// time.
-class DatetimeParser {
- public:
- static std::unique_ptr<DatetimeParser> Instance(
- const DatetimeModel* model, const UniLib& unilib,
- const CalendarLib& calendarlib, ZlibDecompressor* decompressor);
-
- // Parses the dates in 'input' and fills result. Makes sure that the results
- // do not overlap.
- // If 'anchor_start_end' is true the extracted results need to start at the
- // beginning of 'input' and end at the end of it.
- bool Parse(const std::string& input, int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase,
- bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const;
-
- // Same as above but takes UnicodeText.
- bool Parse(const UnicodeText& input, int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase,
- bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const;
-
-#ifdef TC3_TEST_ONLY
- void TestOnlySetGenerateAlternativeInterpretationsWhenAmbiguous(bool value) {
- generate_alternative_interpretations_when_ambiguous_ = value;
- }
-#endif // TC3_TEST_ONLY
-
- protected:
- DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
- const CalendarLib& calendarlib,
- ZlibDecompressor* decompressor);
-
- // Returns a list of locale ids for given locale spec string (comma-separated
- // locale names). Assigns the first parsed locale to reference_locale.
- std::vector<int> ParseAndExpandLocales(const std::string& locales,
- std::string* reference_locale) const;
-
- // Helper function that finds datetime spans, only using the rules associated
- // with the given locales.
- bool FindSpansUsingLocales(
- const std::vector<int>& locale_ids, const UnicodeText& input,
- const int64 reference_time_ms_utc, const std::string& reference_timezone,
- ModeFlag mode, AnnotationUsecase annotation_usecase,
- bool anchor_start_end, const std::string& reference_locale,
- std::unordered_set<int>* executed_rules,
- std::vector<DatetimeParseResultSpan>* found_spans) const;
-
- bool ParseWithRule(const CompiledRule& rule, const UnicodeText& input,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale, const int locale_id,
- bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* result) const;
-
- void FillInterpretations(const DateParseData& parse,
- std::vector<DateParseData>* interpretations) const;
-
- // Converts the current match in 'matcher' into DatetimeParseResult.
- bool ExtractDatetime(const CompiledRule& rule,
- const UniLib::RegexMatcher& matcher,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale, int locale_id,
- std::vector<DatetimeParseResult>* results,
- CodepointSpan* result_span) const;
-
- // Parse and extract information from current match in 'matcher'.
- bool HandleParseMatch(const CompiledRule& rule,
- const UniLib::RegexMatcher& matcher,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale, int locale_id,
- std::vector<DatetimeParseResultSpan>* result) const;
-
- private:
- bool initialized_;
- const UniLib& unilib_;
- const CalendarLib& calendarlib_;
- std::vector<CompiledRule> rules_;
- std::unordered_map<int, std::vector<int>> locale_to_rules_;
- std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_;
- std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
- type_and_locale_to_extractor_rule_;
- std::unordered_map<std::string, int> locale_string_to_id_;
- std::vector<int> default_locale_ids_;
- bool use_extractors_for_locating_;
- bool generate_alternative_interpretations_when_ambiguous_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_
diff --git a/annotator/datetime/parser_test.cc b/annotator/datetime/parser_test.cc
deleted file mode 100644
index 8196fa7..0000000
--- a/annotator/datetime/parser_test.cc
+++ /dev/null
@@ -1,538 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <time.h>
-#include <fstream>
-#include <iostream>
-#include <memory>
-#include <string>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-#include "annotator/annotator.h"
-#include "annotator/datetime/parser.h"
-#include "annotator/model_generated.h"
-#include "annotator/types-test-util.h"
-#include "utils/testing/annotator.h"
-
-using testing::ElementsAreArray;
-
-namespace libtextclassifier3 {
-namespace {
-
-std::string GetModelPath() {
- return TC3_TEST_DATA_DIR;
-}
-
-std::string ReadFile(const std::string& file_name) {
- std::ifstream file_stream(file_name);
- return std::string(std::istreambuf_iterator<char>(file_stream), {});
-}
-
-class ParserTest : public testing::Test {
- public:
- void SetUp() override {
- // Loads default unmodified model. Individual tests can call LoadModel to
- // make changes.
- LoadModel([](ModelT* model) {});
- }
-
- template <typename Fn>
- void LoadModel(Fn model_visitor_fn) {
- std::string model_buffer = ReadFile(GetModelPath() + "test_model.fb");
- model_buffer_ = ModifyAnnotatorModel(model_buffer, model_visitor_fn);
- classifier_ = Annotator::FromUnownedBuffer(model_buffer_.data(),
- model_buffer_.size(), &unilib_);
- TC3_CHECK(classifier_);
- parser_ = classifier_->DatetimeParserForTests();
- TC3_CHECK(parser_);
- }
-
- bool HasNoResult(const std::string& text, bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich",
- AnnotationUsecase annotation_usecase =
- AnnotationUsecase_ANNOTATION_USECASE_SMART) {
- std::vector<DatetimeParseResultSpan> results;
- if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
- annotation_usecase, anchor_start_end, &results)) {
- TC3_LOG(ERROR) << text;
- TC3_CHECK(false);
- }
- return results.empty();
- }
-
- bool ParsesCorrectly(const std::string& marked_text,
- const std::vector<int64>& expected_ms_utcs,
- DatetimeGranularity expected_granularity,
- bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich",
- const std::string& locales = "en-US",
- AnnotationUsecase annotation_usecase =
- AnnotationUsecase_ANNOTATION_USECASE_SMART) {
- const UnicodeText marked_text_unicode =
- UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
- auto brace_open_it =
- std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{');
- auto brace_end_it =
- std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}');
- TC3_CHECK(brace_open_it != marked_text_unicode.end());
- TC3_CHECK(brace_end_it != marked_text_unicode.end());
-
- std::string text;
- text +=
- UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it);
- text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it);
- text += UnicodeText::UTF8Substring(std::next(brace_end_it),
- marked_text_unicode.end());
-
- std::vector<DatetimeParseResultSpan> results;
-
- if (!parser_->Parse(text, 0, timezone, locales, ModeFlag_ANNOTATION,
- annotation_usecase, anchor_start_end, &results)) {
- TC3_LOG(ERROR) << text;
- TC3_CHECK(false);
- }
- if (results.empty()) {
- TC3_LOG(ERROR) << "No results.";
- return false;
- }
-
- const int expected_start_index =
- std::distance(marked_text_unicode.begin(), brace_open_it);
- // The -1 bellow is to account for the opening bracket character.
- const int expected_end_index =
- std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
-
- std::vector<DatetimeParseResultSpan> filtered_results;
- for (const DatetimeParseResultSpan& result : results) {
- if (SpansOverlap(result.span,
- {expected_start_index, expected_end_index})) {
- filtered_results.push_back(result);
- }
- }
-
- std::vector<DatetimeParseResultSpan> expected{
- {{expected_start_index, expected_end_index},
- {},
- /*target_classification_score=*/1.0,
- /*priority_score=*/0.1}};
- expected[0].data.resize(expected_ms_utcs.size());
- for (int i = 0; i < expected_ms_utcs.size(); i++) {
- expected[0].data[i] = {expected_ms_utcs[i], expected_granularity};
- }
-
- const bool matches =
- testing::Matches(ElementsAreArray(expected))(filtered_results);
- if (!matches) {
- TC3_LOG(ERROR) << "Expected: " << expected[0];
- if (filtered_results.empty()) {
- TC3_LOG(ERROR) << "But got no results.";
- }
- TC3_LOG(ERROR) << "Actual: " << filtered_results[0];
- }
-
- return matches;
- }
-
- bool ParsesCorrectly(const std::string& marked_text,
- const int64 expected_ms_utc,
- DatetimeGranularity expected_granularity,
- bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich",
- const std::string& locales = "en-US",
- AnnotationUsecase annotation_usecase =
- AnnotationUsecase_ANNOTATION_USECASE_SMART) {
- return ParsesCorrectly(marked_text, std::vector<int64>{expected_ms_utc},
- expected_granularity, anchor_start_end, timezone,
- locales, annotation_usecase);
- }
-
- bool ParsesCorrectlyGerman(const std::string& marked_text,
- const std::vector<int64>& expected_ms_utcs,
- DatetimeGranularity expected_granularity) {
- return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"de");
- }
-
- bool ParsesCorrectlyGerman(const std::string& marked_text,
- const int64 expected_ms_utc,
- DatetimeGranularity expected_granularity) {
- return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"de");
- }
-
- protected:
- std::string model_buffer_;
- std::unique_ptr<Annotator> classifier_;
- const DatetimeParser* parser_;
- UniLib unilib_;
-};
-
-// Test with just a few cases to make debugging of general failures easier.
-TEST_F(ParserTest, ParseShort) {
- EXPECT_TRUE(
- ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
-}
-
-TEST_F(ParserTest, Parse) {
- EXPECT_TRUE(
- ParsesCorrectly("{January 1, 1988}", 567990000000, GRANULARITY_DAY));
- EXPECT_TRUE(
- ParsesCorrectly("{january 31 2018}", 1517353200000, GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
- GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectly("{09/Mar/2004 22:02:40}", 1078866160000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{Dec 2, 2010 2:39:58 AM}", 1291253998000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{Jun 09 2011 15:28:14}", 1307626094000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{Mar 16 08:12:04}", {6419524000, 6462724000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29}",
- {1277512289000, 1277555489000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{2006/01/22 04:11:05}",
- {1137899465000, 1137942665000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(
- ParsesCorrectly("{11:42:35}", {38555000, 81755000}, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23/Apr 11:42:35}", {9715355000, 9758555000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{04/23/15 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{9/28/2011 2:23:15 PM}", 1317212595000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly(
- "Are sentiments apartments decisively the especially alteration. "
- "Thrown shy denote ten ladies though ask saw. Or by to he going "
- "think order event music. Incommode so intention defective at "
- "convinced. Led income months itself and houses you. After nor "
- "you leave might share court balls. {19/apr/2010 06:36:15} Are "
- "sentiments apartments decisively the especially alteration. "
- "Thrown shy denote ten ladies though ask saw. Or by to he going "
- "think order event music. Incommode so intention defective at "
- "convinced. Led income months itself and houses you. After nor "
- "you leave might share court balls. ",
- {1271651775000, 1271694975000}, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}",
- {1514777400000, 1514820600000},
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30 am}", 1514777400000,
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4pm}", 1514818800000,
- GRANULARITY_HOUR));
-
- EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", {-3600000, 39600000},
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly(
- "{today at 0:00}", {-57600000, -14400000}, GRANULARITY_MINUTE,
- /*anchor_start_end=*/false, "America/Los_Angeles"));
- EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4:00}", {97200000, 140400000},
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4am}", 97200000, GRANULARITY_HOUR));
- EXPECT_TRUE(
- ParsesCorrectly("{wednesday at 4am}", 529200000, GRANULARITY_HOUR));
- EXPECT_TRUE(ParsesCorrectly("last seen {today at 9:01 PM}", 72060000,
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("set an alarm for {7am tomorrow}", 108000000,
- GRANULARITY_HOUR));
- EXPECT_TRUE(
- ParsesCorrectly("set an alarm for {7 a.m}", 21600000, GRANULARITY_HOUR));
-}
-
-TEST_F(ParserTest, ParseWithAnchor) {
- EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000,
- GRANULARITY_DAY, /*anchor_start_end=*/false));
- EXPECT_TRUE(ParsesCorrectly("{January 1, 1988}", 567990000000,
- GRANULARITY_DAY, /*anchor_start_end=*/true));
- EXPECT_TRUE(ParsesCorrectly("lorem {1 january 2018} ipsum", 1514761200000,
- GRANULARITY_DAY, /*anchor_start_end=*/false));
- EXPECT_TRUE(HasNoResult("lorem 1 january 2018 ipsum",
- /*anchor_start_end=*/true));
-}
-
-TEST_F(ParserTest, ParseWithRawUsecase) {
- // Annotated for RAW usecase.
- EXPECT_TRUE(ParsesCorrectly(
- "{tomorrow}", 82800000, GRANULARITY_DAY, /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
- /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
-
- EXPECT_TRUE(ParsesCorrectly(
- "call me {in two hours}", 7200000, GRANULARITY_HOUR,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
- /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
-
- EXPECT_TRUE(ParsesCorrectly(
- "call me {next month}", 2674800000, GRANULARITY_MONTH,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
- /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
- EXPECT_TRUE(ParsesCorrectly(
- "what's the time {now}", -3600000, GRANULARITY_DAY,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
- /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
-
- EXPECT_TRUE(ParsesCorrectly(
- "call me on {Saturday}", 169200000, GRANULARITY_DAY,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"en-US",
- /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
-
- // Not annotated for Smart usecase.
- EXPECT_TRUE(HasNoResult(
- "{tomorrow}", /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich",
- /*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_SMART));
-}
-
-TEST_F(ParserTest, ParsesNoonAndMidnightCorrectly) {
- EXPECT_TRUE(ParsesCorrectly("{January 1, 1988 12:30am}", 567991800000,
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("{January 1, 1988 12:30pm}", 568035000000,
- GRANULARITY_MINUTE));
-}
-
-TEST_F(ParserTest, ParseGerman) {
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{Januar 1 2018}", 1514761200000, GRANULARITY_DAY));
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{1 2 2018}", 1517439600000, GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectlyGerman("lorem {1 Januar 2018} ipsum",
- 1514761200000, GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectlyGerman("{19/Apr/2010:06:36:15}",
- {1271651775000, 1271694975000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{09/März/2004 22:02:40}", 1078866160000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{Dez 2, 2010 2:39:58}",
- {1291253998000, 1291297198000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{Juni 09 2011 15:28:14}", 1307626094000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{März 16 08:12:04}", {6419524000, 6462724000}, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{2010-06-26 02:31:29}",
- {1277512289000, 1277555489000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{2006/01/22 04:11:05}",
- {1137899465000, 1137942665000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{11:42:35}", {38555000, 81755000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman(
- "{23/Apr 11:42:35}", {9715355000, 9758555000}, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015:11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/15 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}",
- {1429782155000, 1429825355000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{19/apr/2010:06:36:15}",
- {1271651775000, 1271694975000},
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30}",
- {1514777400000, 1514820600000},
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30 nachm}",
- 1514820600000, GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4 nachm}", 1514818800000,
- GRANULARITY_HOUR));
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{14.03.2017}", 1489446000000, GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectlyGerman("{morgen 0:00}", {82800000, 126000000},
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectlyGerman("{morgen um 4:00}", {97200000, 140400000},
- GRANULARITY_MINUTE));
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{morgen um 4 vorm}", 97200000, GRANULARITY_HOUR));
-}
-
-TEST_F(ParserTest, ParseNonUs) {
- EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich",
- /*locales=*/"en-GB"));
- EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1430431200000, GRANULARITY_DAY,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"en"));
-}
-
-TEST_F(ParserTest, ParseUs) {
- EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich",
- /*locales=*/"en-US"));
- EXPECT_TRUE(ParsesCorrectly("{1/5/15}", 1420412400000, GRANULARITY_DAY,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich",
- /*locales=*/"es-US"));
-}
-
-TEST_F(ParserTest, ParseUnknownLanguage) {
- EXPECT_TRUE(ParsesCorrectly("bylo to {31. 12. 2015} v 6 hodin", 1451516400000,
- GRANULARITY_DAY,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
-}
-
-TEST_F(ParserTest, WhenAlternativesEnabledGeneratesAlternatives) {
- LoadModel([](ModelT* model) {
- model->datetime_model->generate_alternative_interpretations_when_ambiguous =
- true;
- });
-
- EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}",
- {1514777400000, 1514820600000},
- GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("{monday 3pm}", 396000000, GRANULARITY_HOUR));
- EXPECT_TRUE(ParsesCorrectly("{monday 3:00}", {352800000, 396000000},
- GRANULARITY_MINUTE));
-}
-
-TEST_F(ParserTest, WhenAlternativesDisabledDoesNotGenerateAlternatives) {
- LoadModel([](ModelT* model) {
- model->datetime_model->generate_alternative_interpretations_when_ambiguous =
- false;
- });
-
- EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}", 1514777400000,
- GRANULARITY_MINUTE));
-}
-
-class ParserLocaleTest : public testing::Test {
- public:
- void SetUp() override;
- bool HasResult(const std::string& input, const std::string& locales);
-
- protected:
- UniLib unilib_;
- CalendarLib calendarlib_;
- flatbuffers::FlatBufferBuilder builder_;
- std::unique_ptr<DatetimeParser> parser_;
-};
-
-void AddPattern(const std::string& regex, int locale,
- std::vector<std::unique_ptr<DatetimeModelPatternT>>* patterns) {
- patterns->emplace_back(new DatetimeModelPatternT);
- patterns->back()->regexes.emplace_back(new DatetimeModelPattern_::RegexT);
- patterns->back()->regexes.back()->pattern = regex;
- patterns->back()->regexes.back()->groups.push_back(
- DatetimeGroupType_GROUP_UNUSED);
- patterns->back()->locales.push_back(locale);
-}
-
-void ParserLocaleTest::SetUp() {
- DatetimeModelT model;
- model.use_extractors_for_locating = false;
- model.locales.clear();
- model.locales.push_back("en-US");
- model.locales.push_back("en-CH");
- model.locales.push_back("zh-Hant");
- model.locales.push_back("en-*");
- model.locales.push_back("zh-Hant-*");
- model.locales.push_back("*-CH");
- model.locales.push_back("default");
- model.default_locales.push_back(6);
-
- AddPattern(/*regex=*/"en-US", /*locale=*/0, &model.patterns);
- AddPattern(/*regex=*/"en-CH", /*locale=*/1, &model.patterns);
- AddPattern(/*regex=*/"zh-Hant", /*locale=*/2, &model.patterns);
- AddPattern(/*regex=*/"en-all", /*locale=*/3, &model.patterns);
- AddPattern(/*regex=*/"zh-Hant-all", /*locale=*/4, &model.patterns);
- AddPattern(/*regex=*/"all-CH", /*locale=*/5, &model.patterns);
- AddPattern(/*regex=*/"default", /*locale=*/6, &model.patterns);
-
- builder_.Finish(DatetimeModel::Pack(builder_, &model));
- const DatetimeModel* model_fb =
- flatbuffers::GetRoot<DatetimeModel>(builder_.GetBufferPointer());
- ASSERT_TRUE(model_fb);
-
- parser_ = DatetimeParser::Instance(model_fb, unilib_, calendarlib_,
- /*decompressor=*/nullptr);
- ASSERT_TRUE(parser_);
-}
-
-bool ParserLocaleTest::HasResult(const std::string& input,
- const std::string& locales) {
- std::vector<DatetimeParseResultSpan> results;
- EXPECT_TRUE(parser_->Parse(
- input, /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"", locales, ModeFlag_ANNOTATION,
- AnnotationUsecase_ANNOTATION_USECASE_SMART, false, &results));
- return results.size() == 1;
-}
-
-TEST_F(ParserLocaleTest, English) {
- EXPECT_TRUE(HasResult("en-US", /*locales=*/"en-US"));
- EXPECT_FALSE(HasResult("en-CH", /*locales=*/"en-US"));
- EXPECT_FALSE(HasResult("en-US", /*locales=*/"en-CH"));
- EXPECT_TRUE(HasResult("en-CH", /*locales=*/"en-CH"));
- EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
-}
-
-TEST_F(ParserLocaleTest, TraditionalChinese) {
- EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant"));
- EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-TW"));
- EXPECT_TRUE(HasResult("zh-Hant-all", /*locales=*/"zh-Hant-SG"));
- EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh-SG"));
- EXPECT_FALSE(HasResult("zh-Hant-all", /*locales=*/"zh"));
- EXPECT_TRUE(HasResult("default", /*locales=*/"zh"));
- EXPECT_TRUE(HasResult("default", /*locales=*/"zh-Hant-SG"));
-}
-
-TEST_F(ParserLocaleTest, SwissEnglish) {
- EXPECT_TRUE(HasResult("all-CH", /*locales=*/"de-CH"));
- EXPECT_TRUE(HasResult("all-CH", /*locales=*/"en-CH"));
- EXPECT_TRUE(HasResult("en-all", /*locales=*/"en-CH"));
- EXPECT_FALSE(HasResult("all-CH", /*locales=*/"de-DE"));
- EXPECT_TRUE(HasResult("default", /*locales=*/"de-CH"));
- EXPECT_TRUE(HasResult("default", /*locales=*/"en-CH"));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/annotator/duration/duration.cc b/annotator/duration/duration.cc
deleted file mode 100644
index d442dc6..0000000
--- a/annotator/duration/duration.cc
+++ /dev/null
@@ -1,290 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/duration/duration.h"
-
-#include <climits>
-#include <cstdlib>
-
-#include "annotator/collections.h"
-#include "annotator/types.h"
-#include "utils/base/logging.h"
-#include "utils/strings/numbers.h"
-
-namespace libtextclassifier3 {
-
-using DurationUnit = internal::DurationUnit;
-
-namespace internal {
-
-namespace {
-void FillDurationUnitMap(
- const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
- expressions,
- DurationUnit duration_unit,
- std::unordered_map<std::string, DurationUnit>* target_map) {
- if (expressions == nullptr) {
- return;
- }
-
- for (const flatbuffers::String* expression_string : *expressions) {
- (*target_map)[expression_string->c_str()] = duration_unit;
- }
-}
-} // namespace
-
-std::unordered_map<std::string, DurationUnit> BuildTokenToDurationUnitMapping(
- const DurationAnnotatorOptions* options) {
- std::unordered_map<std::string, DurationUnit> mapping;
- FillDurationUnitMap(options->week_expressions(), DurationUnit::WEEK,
- &mapping);
- FillDurationUnitMap(options->day_expressions(), DurationUnit::DAY, &mapping);
- FillDurationUnitMap(options->hour_expressions(), DurationUnit::HOUR,
- &mapping);
- FillDurationUnitMap(options->minute_expressions(), DurationUnit::MINUTE,
- &mapping);
- FillDurationUnitMap(options->second_expressions(), DurationUnit::SECOND,
- &mapping);
- return mapping;
-}
-
-std::unordered_set<std::string> BuildStringSet(
- const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
- strings) {
- std::unordered_set<std::string> result;
- if (strings == nullptr) {
- return result;
- }
-
- for (const flatbuffers::String* string_value : *strings) {
- result.insert(string_value->c_str());
- }
-
- return result;
-}
-
-} // namespace internal
-
-bool DurationAnnotator::ClassifyText(
- const UnicodeText& context, CodepointSpan selection_indices,
- AnnotationUsecase annotation_usecase,
- ClassificationResult* classification_result) const {
- if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
- (1 << annotation_usecase))) == 0) {
- return false;
- }
-
- const UnicodeText selection =
- UnicodeText::Substring(context, selection_indices.first,
- selection_indices.second, /*do_copy=*/false);
- const std::vector<Token> tokens = feature_processor_->Tokenize(selection);
-
- AnnotatedSpan annotated_span;
- if (FindDurationStartingAt(context, tokens, 0, &annotated_span) !=
- tokens.size()) {
- return false;
- }
-
- TC3_CHECK(!annotated_span.classification.empty());
-
- *classification_result = annotated_span.classification[0];
- return true;
-}
-
-bool DurationAnnotator::FindAll(const UnicodeText& context,
- const std::vector<Token>& tokens,
- AnnotationUsecase annotation_usecase,
- std::vector<AnnotatedSpan>* results) const {
- if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
- (1 << annotation_usecase))) == 0) {
- return true;
- }
-
- for (int i = 0; i < tokens.size();) {
- AnnotatedSpan span;
- const int next_i = FindDurationStartingAt(context, tokens, i, &span);
- if (next_i != i) {
- results->push_back(span);
- i = next_i;
- } else {
- i++;
- }
- }
- return true;
-}
-
-int DurationAnnotator::FindDurationStartingAt(const UnicodeText& context,
- const std::vector<Token>& tokens,
- int start_token_index,
- AnnotatedSpan* result) const {
- CodepointIndex start_index = kInvalidIndex;
- CodepointIndex end_index = kInvalidIndex;
-
- bool has_quantity = false;
- ParsedDurationAtom parsed_duration;
-
- std::vector<ParsedDurationAtom> parsed_duration_atoms;
-
- // This is the core algorithm for finding the duration expressions. It
- // basically iterates over tokens and changes the state variables above as it
- // goes.
- int token_index;
- for (token_index = start_token_index; token_index < tokens.size();
- token_index++) {
- const Token& token = tokens[token_index];
-
- if (ParseQuantityToken(token, &parsed_duration)) {
- has_quantity = true;
- if (start_index == kInvalidIndex) {
- start_index = token.start;
- }
- end_index = token.end;
- } else if (ParseDurationUnitToken(token, &parsed_duration.unit)) {
- if (start_index == kInvalidIndex) {
- start_index = token.start;
- }
- end_index = token.end;
- parsed_duration_atoms.push_back(parsed_duration);
- has_quantity = false;
- parsed_duration = ParsedDurationAtom();
- } else if (ParseFillerToken(token)) {
- } else {
- break;
- }
- }
-
- if (parsed_duration_atoms.empty()) {
- return start_token_index;
- }
-
- const bool parse_ended_without_unit_for_last_mentioned_quantity =
- has_quantity;
-
- ClassificationResult classification{Collections::Duration(),
- options_->score()};
- classification.priority_score = options_->priority_score();
- classification.duration_ms =
- ParsedDurationAtomsToMillis(parsed_duration_atoms);
-
- // Process suffix expressions like "and half" that don't have the
- // duration_unit explicitly mentioned.
- if (parse_ended_without_unit_for_last_mentioned_quantity &&
- parsed_duration.plus_half) {
- ParsedDurationAtom atom = ParsedDurationAtom::Half();
- atom.unit = parsed_duration_atoms.rbegin()->unit;
- classification.duration_ms += ParsedDurationAtomsToMillis({atom});
- }
-
- result->span = feature_processor_->StripBoundaryCodepoints(
- context, {start_index, end_index});
- result->classification.push_back(classification);
- result->source = AnnotatedSpan::Source::DURATION;
-
- return token_index;
-}
-
-int64 DurationAnnotator::ParsedDurationAtomsToMillis(
- const std::vector<ParsedDurationAtom>& atoms) const {
- int64 result = 0;
- for (auto atom : atoms) {
- int multiplier;
- switch (atom.unit) {
- case DurationUnit::WEEK:
- multiplier = 7 * 24 * 60 * 60 * 1000;
- break;
- case DurationUnit::DAY:
- multiplier = 24 * 60 * 60 * 1000;
- break;
- case DurationUnit::HOUR:
- multiplier = 60 * 60 * 1000;
- break;
- case DurationUnit::MINUTE:
- multiplier = 60 * 1000;
- break;
- case DurationUnit::SECOND:
- multiplier = 1000;
- break;
- case DurationUnit::UNKNOWN:
- TC3_LOG(ERROR) << "Requesting parse of UNKNOWN duration duration_unit.";
- return -1;
- break;
- }
-
- int value = atom.value;
- // This condition handles expressions like "an hour", where the quantity is
- // not specified. In this case we assume quantity 1. Except for cases like
- // "half hour".
- if (value == 0 && !atom.plus_half) {
- value = 1;
- }
- result += value * multiplier;
- result += atom.plus_half * multiplier / 2;
- }
- return result;
-}
-
-bool DurationAnnotator::ParseQuantityToken(const Token& token,
- ParsedDurationAtom* value) const {
- if (token.value.empty()) {
- return false;
- }
-
- std::string token_value_buffer;
- const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
- token.value, &token_value_buffer);
-
- if (half_expressions_.find(token_value) != half_expressions_.end()) {
- value->plus_half = true;
- return true;
- }
-
- int32 parsed_value;
- if (ParseInt32(token_value.c_str(), &parsed_value)) {
- value->value = parsed_value;
- return true;
- }
-
- return false;
-}
-
-bool DurationAnnotator::ParseDurationUnitToken(
- const Token& token, DurationUnit* duration_unit) const {
- std::string token_value_buffer;
- const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
- token.value, &token_value_buffer);
-
- const auto it = token_value_to_duration_unit_.find(token_value);
- if (it == token_value_to_duration_unit_.end()) {
- return false;
- }
-
- *duration_unit = it->second;
- return true;
-}
-
-bool DurationAnnotator::ParseFillerToken(const Token& token) const {
- std::string token_value_buffer;
- const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
- token.value, &token_value_buffer);
-
- if (filler_expressions_.find(token_value) == filler_expressions_.end()) {
- return false;
- }
-
- return true;
-}
-
-} // namespace libtextclassifier3
diff --git a/annotator/duration/duration.h b/annotator/duration/duration.h
deleted file mode 100644
index 4311afc..0000000
--- a/annotator/duration/duration.h
+++ /dev/null
@@ -1,128 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DURATION_DURATION_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_DURATION_DURATION_H_
-
-#include <string>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-
-#include "annotator/feature-processor.h"
-#include "annotator/model_generated.h"
-#include "annotator/types.h"
-#include "utils/utf8/unicodetext.h"
-
-namespace libtextclassifier3 {
-
-namespace internal {
-enum class DurationUnit {
- UNKNOWN = -1,
- WEEK = 0,
- DAY = 1,
- HOUR = 2,
- MINUTE = 3,
- SECOND = 4
-
- // NOTE: If we want to add MONTH and YEAR we'll have to think of different
- // parsing format, because MONTH and YEAR don't have a fixed number of
- // milliseconds, unlike week/day/hour/minute/second. We ignore the daylight
- // savings time and assume the day is always 24 hours.
-};
-
-// Prepares the mapping between token values and duration unit types.
-std::unordered_map<std::string, internal::DurationUnit>
-BuildTokenToDurationUnitMapping(const DurationAnnotatorOptions* options);
-
-// Creates a set of strings from a flatbuffer string vector.
-std::unordered_set<std::string> BuildStringSet(
- const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*);
-
-} // namespace internal
-
-// Annotator of duration expressions like "3 minutes 30 seconds".
-class DurationAnnotator {
- public:
- explicit DurationAnnotator(const DurationAnnotatorOptions* options,
- const FeatureProcessor* feature_processor)
- : options_(options),
- feature_processor_(feature_processor),
- token_value_to_duration_unit_(
- internal::BuildTokenToDurationUnitMapping(options)),
- filler_expressions_(
- internal::BuildStringSet(options->filler_expressions())),
- half_expressions_(
- internal::BuildStringSet(options->half_expressions())) {}
-
- // Classifies given text, and if it is a duration, it passes the result in
- // 'classification_result' and returns true, otherwise returns false.
- bool ClassifyText(const UnicodeText& context, CodepointSpan selection_indices,
- AnnotationUsecase annotation_usecase,
- ClassificationResult* classification_result) const;
-
- // Finds all duration instances in the input text.
- bool FindAll(const UnicodeText& context, const std::vector<Token>& tokens,
- AnnotationUsecase annotation_usecase,
- std::vector<AnnotatedSpan>* results) const;
-
- private:
- // Represents a component of duration parsed from text (e.g. "3 hours" from
- // the expression "3 hours and 20 minutes").
- struct ParsedDurationAtom {
- // Unit of the duration.
- internal::DurationUnit unit = internal::DurationUnit::UNKNOWN;
-
- // Quantity of the duration unit.
- int value = 0;
-
- // True, if half an unit was specified (either in addition, or exclusively).
- // E.g. "hour and a half".
- // NOTE: Quarter, three-quarters etc. is not supported.
- bool plus_half = false;
-
- static ParsedDurationAtom Half() {
- ParsedDurationAtom result;
- result.plus_half = true;
- return result;
- }
- };
-
- // Starts consuming tokens and returns the index past the last consumed token.
- int FindDurationStartingAt(const UnicodeText& context,
- const std::vector<Token>& tokens,
- int start_token_index,
- AnnotatedSpan* result) const;
-
- bool ParseQuantityToken(const Token& token, ParsedDurationAtom* value) const;
- bool ParseDurationUnitToken(const Token& token,
- internal::DurationUnit* duration_unit) const;
- bool ParseFillerToken(const Token& token) const;
-
- int64 ParsedDurationAtomsToMillis(
- const std::vector<ParsedDurationAtom>& atoms) const;
-
- const DurationAnnotatorOptions* options_;
- const FeatureProcessor* feature_processor_;
- const std::unordered_map<std::string, internal::DurationUnit>
- token_value_to_duration_unit_;
- const std::unordered_set<std::string> filler_expressions_;
- const std::unordered_set<std::string> half_expressions_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DURATION_DURATION_H_
diff --git a/annotator/duration/duration_test.cc b/annotator/duration/duration_test.cc
deleted file mode 100644
index 78548fe..0000000
--- a/annotator/duration/duration_test.cc
+++ /dev/null
@@ -1,320 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/duration/duration.h"
-
-#include <string>
-#include <vector>
-
-#include "annotator/collections.h"
-#include "annotator/model_generated.h"
-#include "annotator/types-test-util.h"
-#include "annotator/types.h"
-#include "utils/test-utils.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::AllOf;
-using testing::ElementsAre;
-using testing::Field;
-
-const DurationAnnotatorOptions* TestingDurationAnnotatorOptions() {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- DurationAnnotatorOptionsT options;
- options.enabled = true;
-
- options.week_expressions.push_back("week");
- options.week_expressions.push_back("weeks");
-
- options.day_expressions.push_back("day");
- options.day_expressions.push_back("days");
-
- options.hour_expressions.push_back("hour");
- options.hour_expressions.push_back("hours");
-
- options.minute_expressions.push_back("minute");
- options.minute_expressions.push_back("minutes");
-
- options.second_expressions.push_back("second");
- options.second_expressions.push_back("seconds");
-
- options.filler_expressions.push_back("and");
- options.filler_expressions.push_back("a");
- options.filler_expressions.push_back("an");
- options.filler_expressions.push_back("one");
-
- options.half_expressions.push_back("half");
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
-}
-
-FeatureProcessor BuildFeatureProcessor(const UniLib* unilib) {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- FeatureProcessorOptionsT options;
- options.context_size = 1;
- options.max_selection_span = 1;
- options.snap_label_span_boundaries_to_containing_tokens = false;
- options.ignored_span_boundary_codepoints.push_back(',');
-
- options.tokenization_codepoint_config.emplace_back(
- new TokenizationCodepointRangeT());
- auto& config = options.tokenization_codepoint_config.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(FeatureProcessorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- const FeatureProcessorOptions* feature_processor_options =
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_data->data());
-
- return FeatureProcessor(feature_processor_options, unilib);
-}
-
-class DurationAnnotatorTest : public ::testing::Test {
- protected:
- DurationAnnotatorTest()
- : INIT_UNILIB_FOR_TESTING(unilib_),
- feature_processor_(BuildFeatureProcessor(&unilib_)),
- duration_annotator_(TestingDurationAnnotatorOptions(),
- &feature_processor_) {}
-
- std::vector<Token> Tokenize(const UnicodeText& text) {
- return feature_processor_.Tokenize(text);
- }
-
- UniLib unilib_;
- FeatureProcessor feature_processor_;
- DurationAnnotator duration_annotator_;
-};
-
-TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
- ClassificationResult classification;
- EXPECT_TRUE(duration_annotator_.ClassifyText(
- UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
-
- EXPECT_THAT(classification,
- AllOf(Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
-}
-
-TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) {
- ClassificationResult classification;
- EXPECT_TRUE(duration_annotator_.ClassifyText(
- UTF8ToUnicodeText("Wake me up in15 minutesok?"), {13, 23},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
-
- EXPECT_THAT(classification,
- AllOf(Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
-}
-
-TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
- const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 15 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) {
- const UnicodeText text =
- UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 3.5 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, FindsComposedDuration) {
- const UnicodeText text =
- UTF8ToUnicodeText("Wake me up in 3 hours and 5 seconds ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 35)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 3 * 60 * 60 * 1000 + 5 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, FindsHalfAnHour) {
- const UnicodeText text = UTF8ToUnicodeText("Set a timer for half an hour");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 28)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 0.5 * 60 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) {
- const UnicodeText text =
- UTF8ToUnicodeText("Set a timer for 1 hour and a half");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 33)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 1.5 * 60 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) {
- const UnicodeText text =
- UTF8ToUnicodeText("Set a timer for an hour and a half");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(19, 34)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 1.5 * 60 * 60 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest,
- FindsCorrectlyWhenSecondsComeSecondAndDontHaveNumber) {
- const UnicodeText text =
- UTF8ToUnicodeText("Set a timer for 10 minutes and a second ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 39)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 10 * 60 * 1000 + 1 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) {
- const UnicodeText text = UTF8ToUnicodeText(
- "Set a timer for a a a 10 minutes and 2 seconds an and an ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(22, 46)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 10 * 60 * 1000 + 2 * 1000)))))));
-}
-
-TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) {
- const UnicodeText text = UTF8ToUnicodeText("Set a timer for half ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- ASSERT_EQ(result.size(), 0);
-}
-
-TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) {
- const UnicodeText text =
- UTF8ToUnicodeText("Set a timer for 10 ,minutes, ,and, ,2, seconds, ok?");
- std::vector<Token> tokens = Tokenize(text);
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 46)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "duration"),
- Field(&ClassificationResult::duration_ms,
- 10 * 60 * 1000 + 2 * 1000)))))));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/annotator/entity-data.fbs b/annotator/entity-data.fbs
deleted file mode 100755
index 2143e28..0000000
--- a/annotator/entity-data.fbs
+++ /dev/null
@@ -1,69 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-namespace libtextclassifier3.EntityData_.Datetime_;
-enum Granularity : int {
- GRANULARITY_UNKNOWN = -1,
- GRANULARITY_YEAR = 0,
- GRANULARITY_MONTH = 1,
- GRANULARITY_WEEK = 2,
- GRANULARITY_DAY = 3,
- GRANULARITY_HOUR = 4,
- GRANULARITY_MINUTE = 5,
- GRANULARITY_SECOND = 6,
-}
-
-namespace libtextclassifier3.EntityData_;
-table Datetime {
- time_ms_utc:long;
- granularity:Datetime_.Granularity = GRANULARITY_UNKNOWN;
-}
-
-namespace libtextclassifier3.EntityData_;
-table Contact {
- name:string;
- given_name:string;
- nickname:string;
- email_address:string;
- phone_number:string;
- contact_id:string;
-}
-
-namespace libtextclassifier3.EntityData_;
-table App {
- name:string;
- package_name:string;
-}
-
-// Represents an entity annotated in text.
-namespace libtextclassifier3;
-table EntityData {
- // Codepoint indices of the annotation, start is inclusive, end is
- // exclusive.
- start:int;
-
- end:int;
-
- // The entity type, as in the TextClassifier APIs.
- type:string;
-
- datetime:EntityData_.Datetime;
- reserved_5:int (deprecated);
- contact:EntityData_.Contact;
- app:EntityData_.App;
-}
-
-root_type libtextclassifier3.EntityData;
diff --git a/annotator/feature-processor.cc b/annotator/feature-processor.cc
deleted file mode 100644
index c0f5c82..0000000
--- a/annotator/feature-processor.cc
+++ /dev/null
@@ -1,863 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/feature-processor.h"
-
-#include <iterator>
-#include <set>
-#include <vector>
-
-#include "utils/base/logging.h"
-#include "utils/strings/utf8.h"
-#include "utils/utf8/unicodetext.h"
-
-namespace libtextclassifier3 {
-
-namespace internal {
-
-Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
- const UniLib* unilib) {
- std::vector<const TokenizationCodepointRange*> codepoint_config;
- if (options->tokenization_codepoint_config() != nullptr) {
- codepoint_config.insert(codepoint_config.end(),
- options->tokenization_codepoint_config()->begin(),
- options->tokenization_codepoint_config()->end());
- }
- std::vector<const CodepointRange*> internal_codepoint_config;
- if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
- internal_codepoint_config.insert(
- internal_codepoint_config.end(),
- options->internal_tokenizer_codepoint_ranges()->begin(),
- options->internal_tokenizer_codepoint_ranges()->end());
- }
- const bool tokenize_on_script_change =
- options->tokenization_codepoint_config() != nullptr &&
- options->tokenize_on_script_change();
- return Tokenizer(options->tokenization_type(), unilib, codepoint_config,
- internal_codepoint_config, tokenize_on_script_change,
- options->icu_preserve_whitespace_tokens());
-}
-
-TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
- const FeatureProcessorOptions* const options) {
- TokenFeatureExtractorOptions extractor_options;
-
- extractor_options.num_buckets = options->num_buckets();
- if (options->chargram_orders() != nullptr) {
- for (int order : *options->chargram_orders()) {
- extractor_options.chargram_orders.push_back(order);
- }
- }
- extractor_options.max_word_length = options->max_word_length();
- extractor_options.extract_case_feature = options->extract_case_feature();
- extractor_options.unicode_aware_features = options->unicode_aware_features();
- extractor_options.extract_selection_mask_feature =
- options->extract_selection_mask_feature();
- if (options->regexp_feature() != nullptr) {
- for (const auto& regexp_feauture : *options->regexp_feature()) {
- extractor_options.regexp_features.push_back(regexp_feauture->str());
- }
- }
- extractor_options.remap_digits = options->remap_digits();
- extractor_options.lowercase_tokens = options->lowercase_tokens();
-
- if (options->allowed_chargrams() != nullptr) {
- for (const auto& chargram : *options->allowed_chargrams()) {
- extractor_options.allowed_chargrams.insert(chargram->str());
- }
- }
- return extractor_options;
-}
-
-void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
- std::vector<Token>* tokens) {
- for (auto it = tokens->begin(); it != tokens->end(); ++it) {
- const UnicodeText token_word =
- UTF8ToUnicodeText(it->value, /*do_copy=*/false);
-
- auto last_start = token_word.begin();
- int last_start_index = it->start;
- std::vector<UnicodeText::const_iterator> split_points;
-
- // Selection start split point.
- if (selection.first > it->start && selection.first < it->end) {
- std::advance(last_start, selection.first - last_start_index);
- split_points.push_back(last_start);
- last_start_index = selection.first;
- }
-
- // Selection end split point.
- if (selection.second > it->start && selection.second < it->end) {
- std::advance(last_start, selection.second - last_start_index);
- split_points.push_back(last_start);
- }
-
- if (!split_points.empty()) {
- // Add a final split for the rest of the token unless it's been all
- // consumed already.
- if (split_points.back() != token_word.end()) {
- split_points.push_back(token_word.end());
- }
-
- std::vector<Token> replacement_tokens;
- last_start = token_word.begin();
- int current_pos = it->start;
- for (const auto& split_point : split_points) {
- Token new_token(token_word.UTF8Substring(last_start, split_point),
- current_pos,
- current_pos + std::distance(last_start, split_point));
-
- last_start = split_point;
- current_pos = new_token.end;
-
- replacement_tokens.push_back(new_token);
- }
-
- it = tokens->erase(it);
- it = tokens->insert(it, replacement_tokens.begin(),
- replacement_tokens.end());
- std::advance(it, replacement_tokens.size() - 1);
- }
- }
-}
-
-} // namespace internal
-
-void FeatureProcessor::StripTokensFromOtherLines(
- const std::string& context, CodepointSpan span,
- std::vector<Token>* tokens) const {
- const UnicodeText context_unicode = UTF8ToUnicodeText(context,
- /*do_copy=*/false);
- StripTokensFromOtherLines(context_unicode, span, tokens);
-}
-
-void FeatureProcessor::StripTokensFromOtherLines(
- const UnicodeText& context_unicode, CodepointSpan span,
- std::vector<Token>* tokens) const {
- std::vector<UnicodeTextRange> lines = SplitContext(context_unicode);
-
- auto span_start = context_unicode.begin();
- if (span.first > 0) {
- std::advance(span_start, span.first);
- }
- auto span_end = context_unicode.begin();
- if (span.second > 0) {
- std::advance(span_end, span.second);
- }
- for (const UnicodeTextRange& line : lines) {
- // Find the line that completely contains the span.
- if (line.first <= span_start && line.second >= span_end) {
- const CodepointIndex last_line_begin_index =
- std::distance(context_unicode.begin(), line.first);
- const CodepointIndex last_line_end_index =
- last_line_begin_index + std::distance(line.first, line.second);
-
- for (auto token = tokens->begin(); token != tokens->end();) {
- if (token->start >= last_line_begin_index &&
- token->end <= last_line_end_index) {
- ++token;
- } else {
- token = tokens->erase(token);
- }
- }
- }
- }
-}
-
-std::string FeatureProcessor::GetDefaultCollection() const {
- if (options_->default_collection() < 0 ||
- options_->collections() == nullptr ||
- options_->default_collection() >= options_->collections()->size()) {
- TC3_LOG(ERROR)
- << "Invalid or missing default collection. Returning empty string.";
- return "";
- }
- return (*options_->collections())[options_->default_collection()]->str();
-}
-
-std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const {
- return tokenizer_.Tokenize(text);
-}
-
-std::vector<Token> FeatureProcessor::Tokenize(
- const UnicodeText& text_unicode) const {
- return tokenizer_.Tokenize(text_unicode);
-}
-
-bool FeatureProcessor::LabelToSpan(
- const int label, const VectorSpan<Token>& tokens,
- std::pair<CodepointIndex, CodepointIndex>* span) const {
- if (tokens.size() != GetNumContextTokens()) {
- return false;
- }
-
- TokenSpan token_span;
- if (!LabelToTokenSpan(label, &token_span)) {
- return false;
- }
-
- const int result_begin_token_index = token_span.first;
- const Token& result_begin_token =
- tokens[options_->context_size() - result_begin_token_index];
- const int result_begin_codepoint = result_begin_token.start;
- const int result_end_token_index = token_span.second;
- const Token& result_end_token =
- tokens[options_->context_size() + result_end_token_index];
- const int result_end_codepoint = result_end_token.end;
-
- if (result_begin_codepoint == kInvalidIndex ||
- result_end_codepoint == kInvalidIndex) {
- *span = CodepointSpan({kInvalidIndex, kInvalidIndex});
- } else {
- const UnicodeText token_begin_unicode =
- UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
- UnicodeText::const_iterator token_begin = token_begin_unicode.begin();
- const UnicodeText token_end_unicode =
- UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false);
- UnicodeText::const_iterator token_end = token_end_unicode.end();
-
- const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
- token_begin, token_begin_unicode.end(),
- /*count_from_beginning=*/true);
- const int end_ignored =
- CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end,
- /*count_from_beginning=*/false);
- // In case everything would be stripped, set the span to the original
- // beginning and zero length.
- if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
- *span = {result_begin_codepoint, result_begin_codepoint};
- } else {
- *span = CodepointSpan({result_begin_codepoint + begin_ignored,
- result_end_codepoint - end_ignored});
- }
- }
- return true;
-}
-
-bool FeatureProcessor::LabelToTokenSpan(const int label,
- TokenSpan* token_span) const {
- if (label >= 0 && label < label_to_selection_.size()) {
- *token_span = label_to_selection_[label];
- return true;
- } else {
- return false;
- }
-}
-
-bool FeatureProcessor::SpanToLabel(
- const std::pair<CodepointIndex, CodepointIndex>& span,
- const std::vector<Token>& tokens, int* label) const {
- if (tokens.size() != GetNumContextTokens()) {
- return false;
- }
-
- const int click_position =
- options_->context_size(); // Click is always in the middle.
- const int padding = options_->context_size() - options_->max_selection_span();
-
- int span_left = 0;
- for (int i = click_position - 1; i >= padding; i--) {
- if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) {
- ++span_left;
- } else {
- break;
- }
- }
-
- int span_right = 0;
- for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
- if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) {
- ++span_right;
- } else {
- break;
- }
- }
-
- // Check that the spanned tokens cover the whole span.
- bool tokens_match_span;
- const CodepointIndex tokens_start = tokens[click_position - span_left].start;
- const CodepointIndex tokens_end = tokens[click_position + span_right].end;
- if (options_->snap_label_span_boundaries_to_containing_tokens()) {
- tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
- } else {
- const UnicodeText token_left_unicode = UTF8ToUnicodeText(
- tokens[click_position - span_left].value, /*do_copy=*/false);
- const UnicodeText token_right_unicode = UTF8ToUnicodeText(
- tokens[click_position + span_right].value, /*do_copy=*/false);
-
- UnicodeText::const_iterator span_begin = token_left_unicode.begin();
- UnicodeText::const_iterator span_end = token_right_unicode.end();
-
- const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
- span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
- const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
- token_right_unicode.begin(), span_end,
- /*count_from_beginning=*/false);
-
- tokens_match_span = tokens_start <= span.first &&
- tokens_start + num_punctuation_start >= span.first &&
- tokens_end >= span.second &&
- tokens_end - num_punctuation_end <= span.second;
- }
-
- if (tokens_match_span) {
- *label = TokenSpanToLabel({span_left, span_right});
- } else {
- *label = kInvalidLabel;
- }
-
- return true;
-}
-
-int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
- auto it = selection_to_label_.find(span);
- if (it != selection_to_label_.end()) {
- return it->second;
- } else {
- return kInvalidLabel;
- }
-}
-
-TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
- CodepointSpan codepoint_span,
- bool snap_boundaries_to_containing_tokens) {
- const int codepoint_start = std::get<0>(codepoint_span);
- const int codepoint_end = std::get<1>(codepoint_span);
-
- TokenIndex start_token = kInvalidIndex;
- TokenIndex end_token = kInvalidIndex;
- for (int i = 0; i < selectable_tokens.size(); ++i) {
- bool is_token_in_span;
- if (snap_boundaries_to_containing_tokens) {
- is_token_in_span = codepoint_start < selectable_tokens[i].end &&
- codepoint_end > selectable_tokens[i].start;
- } else {
- is_token_in_span = codepoint_start <= selectable_tokens[i].start &&
- codepoint_end >= selectable_tokens[i].end;
- }
- if (is_token_in_span && !selectable_tokens[i].is_padding) {
- if (start_token == kInvalidIndex) {
- start_token = i;
- }
- end_token = i + 1;
- }
- }
- return {start_token, end_token};
-}
-
-CodepointSpan TokenSpanToCodepointSpan(
- const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
- return {selectable_tokens[token_span.first].start,
- selectable_tokens[token_span.second - 1].end};
-}
-
-namespace {
-
-// Finds a single token that completely contains the given span.
-int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
- CodepointSpan codepoint_span) {
- const int codepoint_start = std::get<0>(codepoint_span);
- const int codepoint_end = std::get<1>(codepoint_span);
-
- for (int i = 0; i < selectable_tokens.size(); ++i) {
- if (codepoint_start >= selectable_tokens[i].start &&
- codepoint_end <= selectable_tokens[i].end) {
- return i;
- }
- }
- return kInvalidIndex;
-}
-
-} // namespace
-
-namespace internal {
-
-int CenterTokenFromClick(CodepointSpan span,
- const std::vector<Token>& selectable_tokens) {
- int range_begin;
- int range_end;
- std::tie(range_begin, range_end) =
- CodepointSpanToTokenSpan(selectable_tokens, span);
-
- // If no exact match was found, try finding a token that completely contains
- // the click span. This is useful e.g. when Android builds the selection
- // using ICU tokenization, and ends up with only a portion of our space-
- // separated token. E.g. for "(857)" Android would select "857".
- if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
- int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
- if (token_index != kInvalidIndex) {
- range_begin = token_index;
- range_end = token_index + 1;
- }
- }
-
- // We only allow clicks that are exactly 1 selectable token.
- if (range_end - range_begin == 1) {
- return range_begin;
- } else {
- return kInvalidIndex;
- }
-}
-
-int CenterTokenFromMiddleOfSelection(
- CodepointSpan span, const std::vector<Token>& selectable_tokens) {
- int range_begin;
- int range_end;
- std::tie(range_begin, range_end) =
- CodepointSpanToTokenSpan(selectable_tokens, span);
-
- // Center the clicked token in the selection range.
- if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
- return (range_begin + range_end - 1) / 2;
- } else {
- return kInvalidIndex;
- }
-}
-
-} // namespace internal
-
-int FeatureProcessor::FindCenterToken(CodepointSpan span,
- const std::vector<Token>& tokens) const {
- if (options_->center_token_selection_method() ==
- FeatureProcessorOptions_::
- CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) {
- return internal::CenterTokenFromClick(span, tokens);
- } else if (options_->center_token_selection_method() ==
- FeatureProcessorOptions_::
- CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) {
- return internal::CenterTokenFromMiddleOfSelection(span, tokens);
- } else if (options_->center_token_selection_method() ==
- FeatureProcessorOptions_::
- CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) {
- // TODO(zilka): Remove once we have new models on the device.
- // It uses the fact that sharing model use
- // split_tokens_on_selection_boundaries and selection not. So depending on
- // this we select the right way of finding the click location.
- if (!options_->split_tokens_on_selection_boundaries()) {
- // SmartSelection model.
- return internal::CenterTokenFromClick(span, tokens);
- } else {
- // SmartSharing model.
- return internal::CenterTokenFromMiddleOfSelection(span, tokens);
- }
- } else {
- TC3_LOG(ERROR) << "Invalid center token selection method.";
- return kInvalidIndex;
- }
-}
-
-bool FeatureProcessor::SelectionLabelSpans(
- const VectorSpan<Token> tokens,
- std::vector<CodepointSpan>* selection_label_spans) const {
- for (int i = 0; i < label_to_selection_.size(); ++i) {
- CodepointSpan span;
- if (!LabelToSpan(i, tokens, &span)) {
- TC3_LOG(ERROR) << "Could not convert label to span: " << i;
- return false;
- }
- selection_label_spans->push_back(span);
- }
- return true;
-}
-
-void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
- if (options_->ignored_span_boundary_codepoints() != nullptr) {
- for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
- ignored_span_boundary_codepoints_.insert(codepoint);
- }
- }
-}
-
-int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
- const UnicodeText::const_iterator& span_start,
- const UnicodeText::const_iterator& span_end,
- bool count_from_beginning) const {
- if (span_start == span_end) {
- return 0;
- }
-
- UnicodeText::const_iterator it;
- UnicodeText::const_iterator it_last;
- if (count_from_beginning) {
- it = span_start;
- it_last = span_end;
- // We can assume that the string is non-zero length because of the check
- // above, thus the decrement is always valid here.
- --it_last;
- } else {
- it = span_end;
- it_last = span_start;
- // We can assume that the string is non-zero length because of the check
- // above, thus the decrement is always valid here.
- --it;
- }
-
- // Move until we encounter a non-ignored character.
- int num_ignored = 0;
- while (ignored_span_boundary_codepoints_.find(*it) !=
- ignored_span_boundary_codepoints_.end()) {
- ++num_ignored;
-
- if (it == it_last) {
- break;
- }
-
- if (count_from_beginning) {
- ++it;
- } else {
- --it;
- }
- }
-
- return num_ignored;
-}
-
-namespace {
-
-void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
- std::vector<UnicodeTextRange>* ranges) {
- UnicodeText::const_iterator start = t.begin();
- UnicodeText::const_iterator curr = start;
- UnicodeText::const_iterator end = t.end();
- for (; curr != end; ++curr) {
- if (codepoints.find(*curr) != codepoints.end()) {
- if (start != curr) {
- ranges->push_back(std::make_pair(start, curr));
- }
- start = curr;
- ++start;
- }
- }
- if (start != end) {
- ranges->push_back(std::make_pair(start, end));
- }
-}
-
-} // namespace
-
-std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
- const UnicodeText& context_unicode) const {
- std::vector<UnicodeTextRange> lines;
- const std::set<char32> codepoints{{'\n', '|'}};
- FindSubstrings(context_unicode, codepoints, &lines);
- return lines;
-}
-
-CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
- const std::string& context, CodepointSpan span) const {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
- return StripBoundaryCodepoints(context_unicode, span);
-}
-
-CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
- const UnicodeText& context_unicode, CodepointSpan span) const {
- if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
- return span;
- }
-
- UnicodeText::const_iterator span_begin = context_unicode.begin();
- std::advance(span_begin, span.first);
- UnicodeText::const_iterator span_end = context_unicode.begin();
- std::advance(span_end, span.second);
-
- return StripBoundaryCodepoints(span_begin, span_end, span);
-}
-
-CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
- const UnicodeText::const_iterator& span_begin,
- const UnicodeText::const_iterator& span_end, CodepointSpan span) const {
- if (!ValidNonEmptySpan(span) || span_begin == span_end) {
- return span;
- }
-
- const int start_offset = CountIgnoredSpanBoundaryCodepoints(
- span_begin, span_end, /*count_from_beginning=*/true);
- const int end_offset = CountIgnoredSpanBoundaryCodepoints(
- span_begin, span_end, /*count_from_beginning=*/false);
-
- if (span.first + start_offset < span.second - end_offset) {
- return {span.first + start_offset, span.second - end_offset};
- } else {
- return {span.first, span.first};
- }
-}
-
-float FeatureProcessor::SupportedCodepointsRatio(
- const TokenSpan& token_span, const std::vector<Token>& tokens) const {
- int num_supported = 0;
- int num_total = 0;
- for (int i = token_span.first; i < token_span.second; ++i) {
- const UnicodeText value =
- UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
- for (auto codepoint : value) {
- if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
- ++num_supported;
- }
- ++num_total;
- }
- }
- return static_cast<float>(num_supported) / static_cast<float>(num_total);
-}
-
-const std::string& FeatureProcessor::StripBoundaryCodepoints(
- const std::string& value, std::string* buffer) const {
- const UnicodeText value_unicode = UTF8ToUnicodeText(value, /*do_copy=*/false);
- const CodepointSpan initial_span{0, value_unicode.size_codepoints()};
- const CodepointSpan stripped_span =
- StripBoundaryCodepoints(value_unicode, initial_span);
-
- if (initial_span != stripped_span) {
- const UnicodeText stripped_token_value =
- UnicodeText::Substring(value_unicode, stripped_span.first,
- stripped_span.second, /*do_copy=*/false);
- *buffer = stripped_token_value.ToUTF8String();
- return *buffer;
- }
- return value;
-}
-
-int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
- const auto it = collection_to_label_.find(collection);
- if (it == collection_to_label_.end()) {
- return options_->default_collection();
- } else {
- return it->second;
- }
-}
-
-std::string FeatureProcessor::LabelToCollection(int label) const {
- if (label >= 0 && label < collection_to_label_.size()) {
- return (*options_->collections())[label]->str();
- } else {
- return GetDefaultCollection();
- }
-}
-
-void FeatureProcessor::MakeLabelMaps() {
- if (options_->collections() != nullptr) {
- for (int i = 0; i < options_->collections()->size(); ++i) {
- collection_to_label_[(*options_->collections())[i]->str()] = i;
- }
- }
-
- int selection_label_id = 0;
- for (int l = 0; l < (options_->max_selection_span() + 1); ++l) {
- for (int r = 0; r < (options_->max_selection_span() + 1); ++r) {
- if (!options_->selection_reduced_output_space() ||
- r + l <= options_->max_selection_span()) {
- TokenSpan token_span{l, r};
- selection_to_label_[token_span] = selection_label_id;
- label_to_selection_.push_back(token_span);
- ++selection_label_id;
- }
- }
- }
-}
-
-void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
- CodepointSpan input_span,
- bool only_use_line_with_click,
- std::vector<Token>* tokens,
- int* click_pos) const {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
- RetokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click,
- tokens, click_pos);
-}
-
-void FeatureProcessor::RetokenizeAndFindClick(
- const UnicodeText& context_unicode, CodepointSpan input_span,
- bool only_use_line_with_click, std::vector<Token>* tokens,
- int* click_pos) const {
- TC3_CHECK(tokens != nullptr);
-
- if (options_->split_tokens_on_selection_boundaries()) {
- internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
- }
-
- if (only_use_line_with_click) {
- StripTokensFromOtherLines(context_unicode, input_span, tokens);
- }
-
- int local_click_pos;
- if (click_pos == nullptr) {
- click_pos = &local_click_pos;
- }
- *click_pos = FindCenterToken(input_span, *tokens);
- if (*click_pos == kInvalidIndex) {
- // If the default click method failed, let's try to do sub-token matching
- // before we fail.
- *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
- }
-}
-
-namespace internal {
-
-void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
- std::vector<Token>* tokens, int* click_pos) {
- int right_context_needed = relative_click_span.second + context_size;
- if (*click_pos + right_context_needed + 1 >= tokens->size()) {
- // Pad max the context size.
- const int num_pad_tokens = std::min(
- context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
- tokens->size()));
- std::vector<Token> pad_tokens(num_pad_tokens);
- tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
- } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
- // Strip unused tokens.
- auto it = tokens->begin();
- std::advance(it, *click_pos + right_context_needed + 1);
- tokens->erase(it, tokens->end());
- }
-
- int left_context_needed = relative_click_span.first + context_size;
- if (*click_pos < left_context_needed) {
- // Pad max the context size.
- const int num_pad_tokens =
- std::min(context_size, left_context_needed - *click_pos);
- std::vector<Token> pad_tokens(num_pad_tokens);
- tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
- *click_pos += num_pad_tokens;
- } else if (*click_pos > left_context_needed) {
- // Strip unused tokens.
- auto it = tokens->begin();
- std::advance(it, *click_pos - left_context_needed);
- *click_pos -= it - tokens->begin();
- tokens->erase(tokens->begin(), it);
- }
-}
-
-} // namespace internal
-
-bool FeatureProcessor::HasEnoughSupportedCodepoints(
- const std::vector<Token>& tokens, TokenSpan token_span) const {
- if (options_->min_supported_codepoint_ratio() > 0) {
- const float supported_codepoint_ratio =
- SupportedCodepointsRatio(token_span, tokens);
- if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) {
- TC3_VLOG(1) << "Not enough supported codepoints in the context: "
- << supported_codepoint_ratio;
- return false;
- }
- }
- return true;
-}
-
-bool FeatureProcessor::ExtractFeatures(
- const std::vector<Token>& tokens, TokenSpan token_span,
- CodepointSpan selection_span_for_feature,
- const EmbeddingExecutor* embedding_executor,
- EmbeddingCache* embedding_cache, int feature_vector_size,
- std::unique_ptr<CachedFeatures>* cached_features) const {
- std::unique_ptr<std::vector<float>> features(new std::vector<float>());
- features->reserve(feature_vector_size * TokenSpanSize(token_span));
- for (int i = token_span.first; i < token_span.second; ++i) {
- if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature,
- embedding_executor, embedding_cache,
- features.get())) {
- TC3_LOG(ERROR) << "Could not get token features.";
- return false;
- }
- }
-
- std::unique_ptr<std::vector<float>> padding_features(
- new std::vector<float>());
- padding_features->reserve(feature_vector_size);
- if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature,
- embedding_executor, embedding_cache,
- padding_features.get())) {
- TC3_LOG(ERROR) << "Count not get padding token features.";
- return false;
- }
-
- *cached_features = CachedFeatures::Create(token_span, std::move(features),
- std::move(padding_features),
- options_, feature_vector_size);
- if (!*cached_features) {
- TC3_LOG(ERROR) << "Cound not create cached features.";
- return false;
- }
-
- return true;
-}
-
-bool FeatureProcessor::AppendTokenFeaturesWithCache(
- const Token& token, CodepointSpan selection_span_for_feature,
- const EmbeddingExecutor* embedding_executor,
- EmbeddingCache* embedding_cache,
- std::vector<float>* output_features) const {
- // Look for the embedded features for the token in the cache, if there is one.
- if (embedding_cache) {
- const auto it = embedding_cache->find({token.start, token.end});
- if (it != embedding_cache->end()) {
- // The embedded features were found in the cache, extract only the dense
- // features.
- std::vector<float> dense_features;
- if (!feature_extractor_.Extract(
- token, token.IsContainedInSpan(selection_span_for_feature),
- /*sparse_features=*/nullptr, &dense_features)) {
- TC3_LOG(ERROR) << "Could not extract token's dense features.";
- return false;
- }
-
- // Append both embedded and dense features to the output and return.
- output_features->insert(output_features->end(), it->second.begin(),
- it->second.end());
- output_features->insert(output_features->end(), dense_features.begin(),
- dense_features.end());
- return true;
- }
- }
-
- // Extract the sparse and dense features.
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- if (!feature_extractor_.Extract(
- token, token.IsContainedInSpan(selection_span_for_feature),
- &sparse_features, &dense_features)) {
- TC3_LOG(ERROR) << "Could not extract token's features.";
- return false;
- }
-
- // Embed the sparse features, appending them directly to the output.
- const int embedding_size = GetOptions()->embedding_size();
- output_features->resize(output_features->size() + embedding_size);
- float* output_features_end =
- output_features->data() + output_features->size();
- if (!embedding_executor->AddEmbedding(
- TensorView<int>(sparse_features.data(),
- {static_cast<int>(sparse_features.size())}),
- /*dest=*/output_features_end - embedding_size,
- /*dest_size=*/embedding_size)) {
- TC3_LOG(ERROR) << "Cound not embed token's sparse features.";
- return false;
- }
-
- // If there is a cache, the embedded features for the token were not in it,
- // so insert them.
- if (embedding_cache) {
- (*embedding_cache)[{token.start, token.end}] = std::vector<float>(
- output_features_end - embedding_size, output_features_end);
- }
-
- // Append the dense features to the output.
- output_features->insert(output_features->end(), dense_features.begin(),
- dense_features.end());
- return true;
-}
-
-} // namespace libtextclassifier3
diff --git a/annotator/feature-processor.h b/annotator/feature-processor.h
deleted file mode 100644
index 4a753b0..0000000
--- a/annotator/feature-processor.h
+++ /dev/null
@@ -1,290 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Feature processing for FFModel (feed-forward SmartSelection model).
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
-
-#include <map>
-#include <memory>
-#include <set>
-#include <string>
-#include <vector>
-
-#include "annotator/cached-features.h"
-#include "annotator/model_generated.h"
-#include "annotator/types.h"
-#include "utils/base/integral_types.h"
-#include "utils/base/logging.h"
-#include "utils/token-feature-extractor.h"
-#include "utils/tokenizer.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3 {
-
-constexpr int kInvalidLabel = -1;
-
-namespace internal {
-
-Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
- const UniLib* unilib);
-
-TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
- const FeatureProcessorOptions* options);
-
-// Splits tokens that contain the selection boundary inside them.
-// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
-void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
- std::vector<Token>* tokens);
-
-// Returns the index of token that corresponds to the codepoint span.
-int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
-
-// Returns the index of token that corresponds to the middle of the codepoint
-// span.
-int CenterTokenFromMiddleOfSelection(
- CodepointSpan span, const std::vector<Token>& selectable_tokens);
-
-// Strips the tokens from the tokens vector that are not used for feature
-// extraction because they are out of scope, or pads them so that there is
-// enough tokens in the required context_size for all inferences with a click
-// in relative_click_span.
-void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
- std::vector<Token>* tokens, int* click_pos);
-
-} // namespace internal
-
-// Converts a codepoint span to a token span in the given list of tokens.
-// If snap_boundaries_to_containing_tokens is set to true, it is enough for a
-// token to overlap with the codepoint range to be considered part of it.
-// Otherwise it must be fully included in the range.
-TokenSpan CodepointSpanToTokenSpan(
- const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
- bool snap_boundaries_to_containing_tokens = false);
-
-// Converts a token span to a codepoint span in the given list of tokens.
-CodepointSpan TokenSpanToCodepointSpan(
- const std::vector<Token>& selectable_tokens, TokenSpan token_span);
-
-// Takes care of preparing features for the span prediction model.
-class FeatureProcessor {
- public:
- // A cache mapping codepoint spans to embedded tokens features. An instance
- // can be provided to multiple calls to ExtractFeatures() operating on the
- // same context (the same codepoint spans corresponding to the same tokens),
- // as an optimization. Note that the tokenizations do not have to be
- // identical.
- typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;
-
- FeatureProcessor(const FeatureProcessorOptions* options, const UniLib* unilib)
- : feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
- *unilib),
- options_(options),
- tokenizer_(internal::BuildTokenizer(options, unilib)) {
- MakeLabelMaps();
- if (options->supported_codepoint_ranges() != nullptr) {
- SortCodepointRanges({options->supported_codepoint_ranges()->begin(),
- options->supported_codepoint_ranges()->end()},
- &supported_codepoint_ranges_);
- }
- PrepareIgnoredSpanBoundaryCodepoints();
- }
-
- // Tokenizes the input string using the selected tokenization method.
- std::vector<Token> Tokenize(const std::string& text) const;
-
- // Same as above but takes UnicodeText.
- std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
-
- // Converts a label into a token span.
- bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
-
- // Gets the total number of selection labels.
- int GetSelectionLabelCount() const { return label_to_selection_.size(); }
-
- // Gets the string value for given collection label.
- std::string LabelToCollection(int label) const;
-
- // Gets the total number of collections of the model.
- int NumCollections() const { return collection_to_label_.size(); }
-
- // Gets the name of the default collection.
- std::string GetDefaultCollection() const;
-
- const FeatureProcessorOptions* GetOptions() const { return options_; }
-
- // Retokenizes the context and input span, and finds the click position.
- // Depending on the options, might modify tokens (split them or remove them).
- void RetokenizeAndFindClick(const std::string& context,
- CodepointSpan input_span,
- bool only_use_line_with_click,
- std::vector<Token>* tokens, int* click_pos) const;
-
- // Same as above but takes UnicodeText.
- void RetokenizeAndFindClick(const UnicodeText& context_unicode,
- CodepointSpan input_span,
- bool only_use_line_with_click,
- std::vector<Token>* tokens, int* click_pos) const;
-
- // Returns true if the token span has enough supported codepoints (as defined
- // in the model config) or not and model should not run.
- bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
- TokenSpan token_span) const;
-
- // Extracts features as a CachedFeatures object that can be used for repeated
- // inference over token spans in the given context.
- bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
- CodepointSpan selection_span_for_feature,
- const EmbeddingExecutor* embedding_executor,
- EmbeddingCache* embedding_cache, int feature_vector_size,
- std::unique_ptr<CachedFeatures>* cached_features) const;
-
- // Fills selection_label_spans with CodepointSpans that correspond to the
- // selection labels. The CodepointSpans are based on the codepoint ranges of
- // given tokens.
- bool SelectionLabelSpans(
- VectorSpan<Token> tokens,
- std::vector<CodepointSpan>* selection_label_spans) const;
-
- int DenseFeaturesCount() const {
- return feature_extractor_.DenseFeaturesCount();
- }
-
- int EmbeddingSize() const { return options_->embedding_size(); }
-
- // Splits context to several segments.
- std::vector<UnicodeTextRange> SplitContext(
- const UnicodeText& context_unicode) const;
-
- // Strips boundary codepoints from the span in context and returns the new
- // start and end indices. If the span comprises entirely of boundary
- // codepoints, the first index of span is returned for both indices.
- CodepointSpan StripBoundaryCodepoints(const std::string& context,
- CodepointSpan span) const;
-
- // Same as above but takes UnicodeText.
- CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
- CodepointSpan span) const;
-
- // Same as above but takes a pair of iterators for the span, for efficiency.
- CodepointSpan StripBoundaryCodepoints(
- const UnicodeText::const_iterator& span_begin,
- const UnicodeText::const_iterator& span_end, CodepointSpan span) const;
-
- // Same as above, but takes an optional buffer for saving the modified value.
- // As an optimization, returns pointer to 'value' if nothing was stripped, or
- // pointer to 'buffer' if something was stripped.
- const std::string& StripBoundaryCodepoints(const std::string& value,
- std::string* buffer) const;
-
- protected:
- // Returns the class id corresponding to the given string collection
- // identifier. There is a catch-all class id that the function returns for
- // unknown collections.
- int CollectionToLabel(const std::string& collection) const;
-
- // Prepares mapping from collection names to labels.
- void MakeLabelMaps();
-
- // Gets the number of spannable tokens for the model.
- //
- // Spannable tokens are those tokens of context, which the model predicts
- // selection spans over (i.e., there is 1:1 correspondence between the output
- // classes of the model and each of the spannable tokens).
- int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }
-
- // Converts a label into a span of codepoint indices corresponding to it
- // given output_tokens.
- bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
- CodepointSpan* span) const;
-
- // Converts a span to the corresponding label given output_tokens.
- bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
- const std::vector<Token>& output_tokens, int* label) const;
-
- // Converts a token span to the corresponding label.
- int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
-
- // Returns the ratio of supported codepoints to total number of codepoints in
- // the given token span.
- float SupportedCodepointsRatio(const TokenSpan& token_span,
- const std::vector<Token>& tokens) const;
-
- void PrepareIgnoredSpanBoundaryCodepoints();
-
- // Counts the number of span boundary codepoints. If count_from_beginning is
- // True, the counting will start at the span_start iterator (inclusive) and at
- // maximum end at span_end (exclusive). If count_from_beginning is True, the
- // counting will start from span_end (exclusive) and end at span_start
- // (inclusive).
- int CountIgnoredSpanBoundaryCodepoints(
- const UnicodeText::const_iterator& span_start,
- const UnicodeText::const_iterator& span_end,
- bool count_from_beginning) const;
-
- // Finds the center token index in tokens vector, using the method defined
- // in options_.
- int FindCenterToken(CodepointSpan span,
- const std::vector<Token>& tokens) const;
-
- // Removes all tokens from tokens that are not on a line (defined by calling
- // SplitContext on the context) to which span points.
- void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
- std::vector<Token>* tokens) const;
-
- // Same as above but takes UnicodeText.
- void StripTokensFromOtherLines(const UnicodeText& context_unicode,
- CodepointSpan span,
- std::vector<Token>* tokens) const;
-
- // Extracts the features of a token and appends them to the output vector.
- // Uses the embedding cache to to avoid re-extracting the re-embedding the
- // sparse features for the same token.
- bool AppendTokenFeaturesWithCache(const Token& token,
- CodepointSpan selection_span_for_feature,
- const EmbeddingExecutor* embedding_executor,
- EmbeddingCache* embedding_cache,
- std::vector<float>* output_features) const;
-
- protected:
- const TokenFeatureExtractor feature_extractor_;
-
- // Codepoint ranges that define what codepoints are supported by the model.
- // NOTE: Must be sorted.
- std::vector<CodepointRangeStruct> supported_codepoint_ranges_;
-
- private:
- // Set of codepoints that will be stripped from beginning and end of
- // predicted spans.
- std::set<int32> ignored_span_boundary_codepoints_;
-
- const FeatureProcessorOptions* const options_;
-
- // Mapping between token selection spans and labels ids.
- std::map<TokenSpan, int> selection_to_label_;
- std::vector<TokenSpan> label_to_selection_;
-
- // Mapping between collections and labels.
- std::map<std::string, int> collection_to_label_;
-
- Tokenizer tokenizer_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
diff --git a/annotator/feature-processor_test.cc b/annotator/feature-processor_test.cc
deleted file mode 100644
index 5337776..0000000
--- a/annotator/feature-processor_test.cc
+++ /dev/null
@@ -1,975 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/feature-processor.h"
-
-#include "annotator/model-executor.h"
-#include "utils/tensor-view.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::ElementsAreArray;
-using testing::FloatEq;
-using testing::Matcher;
-
-flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
- const FeatureProcessorOptionsT& options) {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(CreateFeatureProcessorOptions(builder, &options));
- return builder.Release();
-}
-
-template <typename T>
-std::vector<T> Subvector(const std::vector<T>& vector, int start, int end) {
- return std::vector<T>(vector.begin() + start, vector.begin() + end);
-}
-
-Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
- std::vector<Matcher<float>> matchers;
- for (const float value : values) {
- matchers.push_back(FloatEq(value));
- }
- return ElementsAreArray(matchers);
-}
-
-class TestingFeatureProcessor : public FeatureProcessor {
- public:
- using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints;
- using FeatureProcessor::FeatureProcessor;
- using FeatureProcessor::SpanToLabel;
- using FeatureProcessor::StripTokensFromOtherLines;
- using FeatureProcessor::supported_codepoint_ranges_;
- using FeatureProcessor::SupportedCodepointsRatio;
-};
-
-// EmbeddingExecutor that always returns features based on
-class FakeEmbeddingExecutor : public EmbeddingExecutor {
- public:
- bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
- int dest_size) const override {
- TC3_CHECK_GE(dest_size, 4);
- EXPECT_EQ(sparse_features.size(), 1);
- dest[0] = sparse_features.data()[0];
- dest[1] = sparse_features.data()[0];
- dest[2] = -sparse_features.data()[0];
- dest[3] = -sparse_features.data()[0];
- return true;
- }
-
- private:
- std::vector<float> storage_;
-};
-
-class FeatureProcessorTest : public ::testing::Test {
- protected:
- FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
- UniLib unilib_;
-};
-
-TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) {
- std::vector<Token> tokens{Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)};
-
- internal::SplitTokensOnSelectionBoundaries({9, 12}, &tokens);
-
- // clang-format off
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hělló", 0, 5),
- Token("fěě", 6, 9),
- Token("bař", 9, 12),
- Token("@google.com", 12, 23),
- Token("heře!", 24, 29)}));
- // clang-format on
-}
-
-TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) {
- std::vector<Token> tokens{Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)};
-
- internal::SplitTokensOnSelectionBoundaries({6, 12}, &tokens);
-
- // clang-format off
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hělló", 0, 5),
- Token("fěěbař", 6, 12),
- Token("@google.com", 12, 23),
- Token("heře!", 24, 29)}));
- // clang-format on
-}
-
-TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) {
- std::vector<Token> tokens{Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)};
-
- internal::SplitTokensOnSelectionBoundaries({9, 23}, &tokens);
-
- // clang-format off
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hělló", 0, 5),
- Token("fěě", 6, 9),
- Token("bař@google.com", 9, 23),
- Token("heře!", 24, 29)}));
- // clang-format on
-}
-
-TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) {
- std::vector<Token> tokens{Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)};
-
- internal::SplitTokensOnSelectionBoundaries({6, 23}, &tokens);
-
- // clang-format off
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)}));
- // clang-format on
-}
-
-TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesCrossToken) {
- std::vector<Token> tokens{Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)};
-
- internal::SplitTokensOnSelectionBoundaries({2, 9}, &tokens);
-
- // clang-format off
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hě", 0, 2),
- Token("lló", 2, 5),
- Token("fěě", 6, 9),
- Token("bař@google.com", 9, 23),
- Token("heře!", 24, 29)}));
- // clang-format on
-}
-
-TEST_F(FeatureProcessorTest, KeepLineWithClickFirst) {
- FeatureProcessorOptionsT options;
- options.only_use_line_with_click = true;
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
- const CodepointSpan span = {0, 5};
- // clang-format off
- std::vector<Token> tokens = {Token("Fiřst", 0, 5),
- Token("Lině", 6, 10),
- Token("Sěcond", 11, 17),
- Token("Lině", 18, 22),
- Token("Thiřd", 23, 28),
- Token("Lině", 29, 33)};
- // clang-format on
-
- // Keeps the first line.
- feature_processor.StripTokensFromOtherLines(context, span, &tokens);
- EXPECT_THAT(tokens,
- ElementsAreArray({Token("Fiřst", 0, 5), Token("Lině", 6, 10)}));
-}
-
-TEST_F(FeatureProcessorTest, KeepLineWithClickSecond) {
- FeatureProcessorOptionsT options;
- options.only_use_line_with_click = true;
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
- const CodepointSpan span = {18, 22};
- // clang-format off
- std::vector<Token> tokens = {Token("Fiřst", 0, 5),
- Token("Lině", 6, 10),
- Token("Sěcond", 11, 17),
- Token("Lině", 18, 22),
- Token("Thiřd", 23, 28),
- Token("Lině", 29, 33)};
- // clang-format on
-
- // Keeps the first line.
- feature_processor.StripTokensFromOtherLines(context, span, &tokens);
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
-}
-
-TEST_F(FeatureProcessorTest, KeepLineWithClickThird) {
- FeatureProcessorOptionsT options;
- options.only_use_line_with_click = true;
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
- const CodepointSpan span = {24, 33};
- // clang-format off
- std::vector<Token> tokens = {Token("Fiřst", 0, 5),
- Token("Lině", 6, 10),
- Token("Sěcond", 11, 17),
- Token("Lině", 18, 22),
- Token("Thiřd", 23, 28),
- Token("Lině", 29, 33)};
- // clang-format on
-
- // Keeps the first line.
- feature_processor.StripTokensFromOtherLines(context, span, &tokens);
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
-}
-
-TEST_F(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
- FeatureProcessorOptionsT options;
- options.only_use_line_with_click = true;
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
- const CodepointSpan span = {18, 22};
- // clang-format off
- std::vector<Token> tokens = {Token("Fiřst", 0, 5),
- Token("Lině", 6, 10),
- Token("Sěcond", 11, 17),
- Token("Lině", 18, 22),
- Token("Thiřd", 23, 28),
- Token("Lině", 29, 33)};
- // clang-format on
-
- // Keeps the first line.
- feature_processor.StripTokensFromOtherLines(context, span, &tokens);
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
-}
-
-TEST_F(FeatureProcessorTest, KeepLineWithCrosslineClick) {
- FeatureProcessorOptionsT options;
- options.only_use_line_with_click = true;
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
- const CodepointSpan span = {5, 23};
- // clang-format off
- std::vector<Token> tokens = {Token("Fiřst", 0, 5),
- Token("Lině", 6, 10),
- Token("Sěcond", 18, 23),
- Token("Lině", 19, 23),
- Token("Thiřd", 23, 28),
- Token("Lině", 29, 33)};
- // clang-format on
-
- // Keeps the first line.
- feature_processor.StripTokensFromOtherLines(context, span, &tokens);
- EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Fiřst", 0, 5), Token("Lině", 6, 10),
- Token("Sěcond", 18, 23), Token("Lině", 19, 23),
- Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
-}
-
-TEST_F(FeatureProcessorTest, SpanToLabel) {
- FeatureProcessorOptionsT options;
- options.context_size = 1;
- options.max_selection_span = 1;
- options.snap_label_span_boundaries_to_containing_tokens = false;
-
- options.tokenization_codepoint_config.emplace_back(
- new TokenizationCodepointRangeT());
- auto& config = options.tokenization_codepoint_config.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
- std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
- ASSERT_EQ(3, tokens.size());
- int label;
- ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
- EXPECT_EQ(kInvalidLabel, label);
- ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
- EXPECT_NE(kInvalidLabel, label);
- TokenSpan token_span;
- feature_processor.LabelToTokenSpan(label, &token_span);
- EXPECT_EQ(0, token_span.first);
- EXPECT_EQ(0, token_span.second);
-
- // Reconfigure with snapping enabled.
- options.snap_label_span_boundaries_to_containing_tokens = true;
- flatbuffers::DetachedBuffer options2_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor2(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
- &unilib_);
- int label2;
- ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
- EXPECT_EQ(label, label2);
- ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
- EXPECT_EQ(label, label2);
- ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
- EXPECT_EQ(label, label2);
-
- // Cross a token boundary.
- ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
- EXPECT_EQ(kInvalidLabel, label2);
- ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
- EXPECT_EQ(kInvalidLabel, label2);
-
- // Multiple tokens.
- options.context_size = 2;
- options.max_selection_span = 2;
- flatbuffers::DetachedBuffer options3_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor3(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
- &unilib_);
- tokens = feature_processor3.Tokenize("zero, one, two, three, four");
- ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
- EXPECT_NE(kInvalidLabel, label2);
- feature_processor3.LabelToTokenSpan(label2, &token_span);
- EXPECT_EQ(1, token_span.first);
- EXPECT_EQ(0, token_span.second);
-
- int label3;
- ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
- EXPECT_EQ(label2, label3);
- ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
- EXPECT_EQ(label2, label3);
- ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
- EXPECT_EQ(label2, label3);
-}
-
-TEST_F(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) {
- FeatureProcessorOptionsT options;
- options.context_size = 1;
- options.max_selection_span = 1;
- options.snap_label_span_boundaries_to_containing_tokens = false;
-
- options.tokenization_codepoint_config.emplace_back(
- new TokenizationCodepointRangeT());
- auto& config = options.tokenization_codepoint_config.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
- std::vector<Token> tokens = feature_processor.Tokenize("one, two, three");
- ASSERT_EQ(3, tokens.size());
- int label;
- ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label));
- EXPECT_EQ(kInvalidLabel, label);
- ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label));
- EXPECT_NE(kInvalidLabel, label);
- TokenSpan token_span;
- feature_processor.LabelToTokenSpan(label, &token_span);
- EXPECT_EQ(0, token_span.first);
- EXPECT_EQ(0, token_span.second);
-
- // Reconfigure with snapping enabled.
- options.snap_label_span_boundaries_to_containing_tokens = true;
- flatbuffers::DetachedBuffer options2_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor2(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
- &unilib_);
- int label2;
- ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2));
- EXPECT_EQ(label, label2);
- ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2));
- EXPECT_EQ(label, label2);
- ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2));
- EXPECT_EQ(label, label2);
-
- // Cross a token boundary.
- ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2));
- EXPECT_EQ(kInvalidLabel, label2);
- ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2));
- EXPECT_EQ(kInvalidLabel, label2);
-
- // Multiple tokens.
- options.context_size = 2;
- options.max_selection_span = 2;
- flatbuffers::DetachedBuffer options3_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor3(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
- &unilib_);
- tokens = feature_processor3.Tokenize("zero, one, two, three, four");
- ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2));
- EXPECT_NE(kInvalidLabel, label2);
- feature_processor3.LabelToTokenSpan(label2, &token_span);
- EXPECT_EQ(1, token_span.first);
- EXPECT_EQ(0, token_span.second);
-
- int label3;
- ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3));
- EXPECT_EQ(label2, label3);
- ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3));
- EXPECT_EQ(label2, label3);
- ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3));
- EXPECT_EQ(label2, label3);
-}
-
-TEST_F(FeatureProcessorTest, CenterTokenFromClick) {
- int token_index;
-
- // Exactly aligned indices.
- token_index = internal::CenterTokenFromClick(
- {6, 11},
- {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
- EXPECT_EQ(token_index, 1);
-
- // Click is contained in a token.
- token_index = internal::CenterTokenFromClick(
- {13, 17},
- {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
- EXPECT_EQ(token_index, 2);
-
- // Click spans two tokens.
- token_index = internal::CenterTokenFromClick(
- {6, 17},
- {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
- EXPECT_EQ(token_index, kInvalidIndex);
-}
-
-TEST_F(FeatureProcessorTest, CenterTokenFromMiddleOfSelection) {
- int token_index;
-
- // Selection of length 3. Exactly aligned indices.
- token_index = internal::CenterTokenFromMiddleOfSelection(
- {7, 27},
- {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
- Token("Token4", 21, 27), Token("Token5", 28, 34)});
- EXPECT_EQ(token_index, 2);
-
- // Selection of length 1 token. Exactly aligned indices.
- token_index = internal::CenterTokenFromMiddleOfSelection(
- {21, 27},
- {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
- Token("Token4", 21, 27), Token("Token5", 28, 34)});
- EXPECT_EQ(token_index, 3);
-
- // Selection marks sub-token range, with no tokens in it.
- token_index = internal::CenterTokenFromMiddleOfSelection(
- {29, 33},
- {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
- Token("Token4", 21, 27), Token("Token5", 28, 34)});
- EXPECT_EQ(token_index, kInvalidIndex);
-
- // Selection of length 2. Sub-token indices.
- token_index = internal::CenterTokenFromMiddleOfSelection(
- {3, 25},
- {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
- Token("Token4", 21, 27), Token("Token5", 28, 34)});
- EXPECT_EQ(token_index, 1);
-
- // Selection of length 1. Sub-token indices.
- token_index = internal::CenterTokenFromMiddleOfSelection(
- {22, 34},
- {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
- Token("Token4", 21, 27), Token("Token5", 28, 34)});
- EXPECT_EQ(token_index, 4);
-
- // Some invalid ones.
- token_index = internal::CenterTokenFromMiddleOfSelection({7, 27}, {});
- EXPECT_EQ(token_index, -1);
-}
-
-TEST_F(FeatureProcessorTest, SupportedCodepointsRatio) {
- FeatureProcessorOptionsT options;
- options.context_size = 2;
- options.max_selection_span = 2;
- options.snap_label_span_boundaries_to_containing_tokens = false;
- options.feature_version = 2;
- options.embedding_size = 4;
- options.bounds_sensitive_features.reset(
- new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
- options.bounds_sensitive_features->enabled = true;
- options.bounds_sensitive_features->num_tokens_before = 5;
- options.bounds_sensitive_features->num_tokens_inside_left = 3;
- options.bounds_sensitive_features->num_tokens_inside_right = 3;
- options.bounds_sensitive_features->num_tokens_after = 5;
- options.bounds_sensitive_features->include_inside_bag = true;
- options.bounds_sensitive_features->include_inside_length = true;
-
- options.tokenization_codepoint_config.emplace_back(
- new TokenizationCodepointRangeT());
- auto& config = options.tokenization_codepoint_config.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- {
- options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
- auto& range = options.supported_codepoint_ranges.back();
- range->start = 0;
- range->end = 128;
- }
-
- {
- options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
- auto& range = options.supported_codepoint_ranges.back();
- range->start = 10000;
- range->end = 10001;
- }
-
- {
- options.supported_codepoint_ranges.emplace_back(new CodepointRangeT());
- auto& range = options.supported_codepoint_ranges.back();
- range->start = 20000;
- range->end = 30000;
- }
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
- EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
- {0, 3}, feature_processor.Tokenize("aaa bbb ccc")),
- FloatEq(1.0));
- EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
- {0, 3}, feature_processor.Tokenize("aaa bbb ěěě")),
- FloatEq(2.0 / 3));
- EXPECT_THAT(feature_processor.SupportedCodepointsRatio(
- {0, 3}, feature_processor.Tokenize("ěěě řřř ěěě")),
- FloatEq(0.0));
- EXPECT_FALSE(
- IsCodepointInRanges(-1, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(
- IsCodepointInRanges(0, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(
- IsCodepointInRanges(10, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(
- IsCodepointInRanges(127, feature_processor.supported_codepoint_ranges_));
- EXPECT_FALSE(
- IsCodepointInRanges(128, feature_processor.supported_codepoint_ranges_));
- EXPECT_FALSE(
- IsCodepointInRanges(9999, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(IsCodepointInRanges(
- 10000, feature_processor.supported_codepoint_ranges_));
- EXPECT_FALSE(IsCodepointInRanges(
- 10001, feature_processor.supported_codepoint_ranges_));
- EXPECT_TRUE(IsCodepointInRanges(
- 25000, feature_processor.supported_codepoint_ranges_));
-
- const std::vector<Token> tokens = {Token("ěěě", 0, 3), Token("řřř", 4, 7),
- Token("eee", 8, 11)};
-
- options.min_supported_codepoint_ratio = 0.0;
- flatbuffers::DetachedBuffer options2_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor2(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()),
- &unilib_);
- EXPECT_TRUE(feature_processor2.HasEnoughSupportedCodepoints(
- tokens, /*token_span=*/{0, 3}));
-
- options.min_supported_codepoint_ratio = 0.2;
- flatbuffers::DetachedBuffer options3_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor3(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()),
- &unilib_);
- EXPECT_TRUE(feature_processor3.HasEnoughSupportedCodepoints(
- tokens, /*token_span=*/{0, 3}));
-
- options.min_supported_codepoint_ratio = 0.5;
- flatbuffers::DetachedBuffer options4_fb =
- PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor4(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()),
- &unilib_);
- EXPECT_FALSE(feature_processor4.HasEnoughSupportedCodepoints(
- tokens, /*token_span=*/{0, 3}));
-}
-
-TEST_F(FeatureProcessorTest, InSpanFeature) {
- FeatureProcessorOptionsT options;
- options.context_size = 2;
- options.max_selection_span = 2;
- options.snap_label_span_boundaries_to_containing_tokens = false;
- options.feature_version = 2;
- options.embedding_size = 4;
- options.extract_selection_mask_feature = true;
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- std::unique_ptr<CachedFeatures> cached_features;
-
- FakeEmbeddingExecutor embedding_executor;
-
- const std::vector<Token> tokens = {Token("aaa", 0, 3), Token("bbb", 4, 7),
- Token("ccc", 8, 11), Token("ddd", 12, 15)};
-
- EXPECT_TRUE(feature_processor.ExtractFeatures(
- tokens, /*token_span=*/{0, 4},
- /*selection_span_for_feature=*/{4, 11}, &embedding_executor,
- /*embedding_cache=*/nullptr, /*feature_vector_size=*/5,
- &cached_features));
- std::vector<float> features;
- cached_features->AppendClickContextFeaturesForClick(1, &features);
- ASSERT_EQ(features.size(), 25);
- EXPECT_THAT(features[4], FloatEq(0.0));
- EXPECT_THAT(features[9], FloatEq(0.0));
- EXPECT_THAT(features[14], FloatEq(1.0));
- EXPECT_THAT(features[19], FloatEq(1.0));
- EXPECT_THAT(features[24], FloatEq(0.0));
-}
-
-TEST_F(FeatureProcessorTest, EmbeddingCache) {
- FeatureProcessorOptionsT options;
- options.context_size = 2;
- options.max_selection_span = 2;
- options.snap_label_span_boundaries_to_containing_tokens = false;
- options.feature_version = 2;
- options.embedding_size = 4;
- options.bounds_sensitive_features.reset(
- new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
- options.bounds_sensitive_features->enabled = true;
- options.bounds_sensitive_features->num_tokens_before = 3;
- options.bounds_sensitive_features->num_tokens_inside_left = 2;
- options.bounds_sensitive_features->num_tokens_inside_right = 2;
- options.bounds_sensitive_features->num_tokens_after = 3;
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- std::unique_ptr<CachedFeatures> cached_features;
-
- FakeEmbeddingExecutor embedding_executor;
-
- const std::vector<Token> tokens = {
- Token("aaa", 0, 3), Token("bbb", 4, 7), Token("ccc", 8, 11),
- Token("ddd", 12, 15), Token("eee", 16, 19), Token("fff", 20, 23)};
-
- // We pre-populate the cache with dummy embeddings, to make sure they are
- // used when populating the features vector.
- const std::vector<float> cached_padding_features = {10.0, -10.0, 10.0, -10.0};
- const std::vector<float> cached_features1 = {1.0, 2.0, 3.0, 4.0};
- const std::vector<float> cached_features2 = {5.0, 6.0, 7.0, 8.0};
- FeatureProcessor::EmbeddingCache embedding_cache = {
- {{kInvalidIndex, kInvalidIndex}, cached_padding_features},
- {{4, 7}, cached_features1},
- {{12, 15}, cached_features2},
- };
-
- EXPECT_TRUE(feature_processor.ExtractFeatures(
- tokens, /*token_span=*/{0, 6},
- /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
- &embedding_executor, &embedding_cache, /*feature_vector_size=*/4,
- &cached_features));
- std::vector<float> features;
- cached_features->AppendBoundsSensitiveFeaturesForSpan({2, 4}, &features);
- ASSERT_EQ(features.size(), 40);
- // Check that the dummy embeddings were used.
- EXPECT_THAT(Subvector(features, 0, 4),
- ElementsAreFloat(cached_padding_features));
- EXPECT_THAT(Subvector(features, 8, 12), ElementsAreFloat(cached_features1));
- EXPECT_THAT(Subvector(features, 16, 20), ElementsAreFloat(cached_features2));
- EXPECT_THAT(Subvector(features, 24, 28), ElementsAreFloat(cached_features2));
- EXPECT_THAT(Subvector(features, 36, 40),
- ElementsAreFloat(cached_padding_features));
- // Check that the real embeddings were cached.
- EXPECT_EQ(embedding_cache.size(), 7);
- EXPECT_THAT(Subvector(features, 4, 8),
- ElementsAreFloat(embedding_cache.at({0, 3})));
- EXPECT_THAT(Subvector(features, 12, 16),
- ElementsAreFloat(embedding_cache.at({8, 11})));
- EXPECT_THAT(Subvector(features, 20, 24),
- ElementsAreFloat(embedding_cache.at({8, 11})));
- EXPECT_THAT(Subvector(features, 28, 32),
- ElementsAreFloat(embedding_cache.at({16, 19})));
- EXPECT_THAT(Subvector(features, 32, 36),
- ElementsAreFloat(embedding_cache.at({20, 23})));
-}
-
-TEST_F(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) {
- std::vector<Token> tokens_orig{
- Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
- Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
- Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
- Token("12", 0, 0)};
-
- std::vector<Token> tokens;
- int click_index;
-
- // Try to click first token and see if it gets padded from left.
- tokens = tokens_orig;
- click_index = 0;
- internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
- // clang-format off
- EXPECT_EQ(tokens, std::vector<Token>({Token(),
- Token(),
- Token("0", 0, 0),
- Token("1", 0, 0),
- Token("2", 0, 0)}));
- // clang-format on
- EXPECT_EQ(click_index, 2);
-
- // When we click the second token nothing should get padded.
- tokens = tokens_orig;
- click_index = 2;
- internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
- // clang-format off
- EXPECT_EQ(tokens, std::vector<Token>({Token("0", 0, 0),
- Token("1", 0, 0),
- Token("2", 0, 0),
- Token("3", 0, 0),
- Token("4", 0, 0)}));
- // clang-format on
- EXPECT_EQ(click_index, 2);
-
- // When we click the last token tokens should get padded from the right.
- tokens = tokens_orig;
- click_index = 12;
- internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
- // clang-format off
- EXPECT_EQ(tokens, std::vector<Token>({Token("10", 0, 0),
- Token("11", 0, 0),
- Token("12", 0, 0),
- Token(),
- Token()}));
- // clang-format on
- EXPECT_EQ(click_index, 2);
-}
-
-TEST_F(FeatureProcessorTest, StripUnusedTokensWithRelativeClick) {
- std::vector<Token> tokens_orig{
- Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
- Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
- Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
- Token("12", 0, 0)};
-
- std::vector<Token> tokens;
- int click_index;
-
- // Try to click first token and see if it gets padded from left to maximum
- // context_size.
- tokens = tokens_orig;
- click_index = 0;
- internal::StripOrPadTokens({2, 3}, 2, &tokens, &click_index);
- // clang-format off
- EXPECT_EQ(tokens, std::vector<Token>({Token(),
- Token(),
- Token("0", 0, 0),
- Token("1", 0, 0),
- Token("2", 0, 0),
- Token("3", 0, 0),
- Token("4", 0, 0),
- Token("5", 0, 0)}));
- // clang-format on
- EXPECT_EQ(click_index, 2);
-
- // Clicking to the middle with enough context should not produce any padding.
- tokens = tokens_orig;
- click_index = 6;
- internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
- // clang-format off
- EXPECT_EQ(tokens, std::vector<Token>({Token("1", 0, 0),
- Token("2", 0, 0),
- Token("3", 0, 0),
- Token("4", 0, 0),
- Token("5", 0, 0),
- Token("6", 0, 0),
- Token("7", 0, 0),
- Token("8", 0, 0),
- Token("9", 0, 0)}));
- // clang-format on
- EXPECT_EQ(click_index, 5);
-
- // Clicking at the end should pad right to maximum context_size.
- tokens = tokens_orig;
- click_index = 11;
- internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
- // clang-format off
- EXPECT_EQ(tokens, std::vector<Token>({Token("6", 0, 0),
- Token("7", 0, 0),
- Token("8", 0, 0),
- Token("9", 0, 0),
- Token("10", 0, 0),
- Token("11", 0, 0),
- Token("12", 0, 0),
- Token(),
- Token()}));
- // clang-format on
- EXPECT_EQ(click_index, 5);
-}
-
-TEST_F(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
- FeatureProcessorOptionsT options;
- options.ignored_span_boundary_codepoints.push_back('.');
- options.ignored_span_boundary_codepoints.push_back(',');
- options.ignored_span_boundary_codepoints.push_back('[');
- options.ignored_span_boundary_codepoints.push_back(']');
-
- flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
- TestingFeatureProcessor feature_processor(
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
- &unilib_);
-
- const std::string text1_utf8 = "ěščř";
- const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text1.begin(), text1.end(),
- /*count_from_beginning=*/true),
- 0);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text1.begin(), text1.end(),
- /*count_from_beginning=*/false),
- 0);
-
- const std::string text2_utf8 = ".,abčd";
- const UnicodeText text2 = UTF8ToUnicodeText(text2_utf8, /*do_copy=*/false);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text2.begin(), text2.end(),
- /*count_from_beginning=*/true),
- 2);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text2.begin(), text2.end(),
- /*count_from_beginning=*/false),
- 0);
-
- const std::string text3_utf8 = ".,abčd[]";
- const UnicodeText text3 = UTF8ToUnicodeText(text3_utf8, /*do_copy=*/false);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text3.begin(), text3.end(),
- /*count_from_beginning=*/true),
- 2);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text3.begin(), text3.end(),
- /*count_from_beginning=*/false),
- 2);
-
- const std::string text4_utf8 = "[abčd]";
- const UnicodeText text4 = UTF8ToUnicodeText(text4_utf8, /*do_copy=*/false);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text4.begin(), text4.end(),
- /*count_from_beginning=*/true),
- 1);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text4.begin(), text4.end(),
- /*count_from_beginning=*/false),
- 1);
-
- const std::string text5_utf8 = "";
- const UnicodeText text5 = UTF8ToUnicodeText(text5_utf8, /*do_copy=*/false);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text5.begin(), text5.end(),
- /*count_from_beginning=*/true),
- 0);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text5.begin(), text5.end(),
- /*count_from_beginning=*/false),
- 0);
-
- const std::string text6_utf8 = "012345ěščř";
- const UnicodeText text6 = UTF8ToUnicodeText(text6_utf8, /*do_copy=*/false);
- UnicodeText::const_iterator text6_begin = text6.begin();
- std::advance(text6_begin, 6);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text6_begin, text6.end(),
- /*count_from_beginning=*/true),
- 0);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text6_begin, text6.end(),
- /*count_from_beginning=*/false),
- 0);
-
- const std::string text7_utf8 = "012345.,ěščř";
- const UnicodeText text7 = UTF8ToUnicodeText(text7_utf8, /*do_copy=*/false);
- UnicodeText::const_iterator text7_begin = text7.begin();
- std::advance(text7_begin, 6);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text7_begin, text7.end(),
- /*count_from_beginning=*/true),
- 2);
- UnicodeText::const_iterator text7_end = text7.begin();
- std::advance(text7_end, 8);
- EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
- text7.begin(), text7_end,
- /*count_from_beginning=*/false),
- 2);
-
- // Test not stripping.
- EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
- "Hello [[[Wořld]] or not?", {0, 24}),
- std::make_pair(0, 24));
- // Test basic stripping.
- EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
- "Hello [[[Wořld]] or not?", {6, 16}),
- std::make_pair(9, 14));
- // Test stripping when everything is stripped.
- EXPECT_EQ(
- feature_processor.StripBoundaryCodepoints("Hello [[[]] or not?", {6, 11}),
- std::make_pair(6, 6));
- // Test stripping empty string.
- EXPECT_EQ(feature_processor.StripBoundaryCodepoints("", {0, 0}),
- std::make_pair(0, 0));
-}
-
-TEST_F(FeatureProcessorTest, CodepointSpanToTokenSpan) {
- const std::vector<Token> tokens{Token("Hělló", 0, 5),
- Token("fěěbař@google.com", 6, 23),
- Token("heře!", 24, 29)};
-
- // Spans matching the tokens exactly.
- EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}));
- EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}));
- EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}));
- EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}));
- EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}));
- EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}));
-
- // Snapping to containing tokens has no effect.
- EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}, true));
- EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}, true));
- EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}, true));
- EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}, true));
- EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}, true));
- EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}, true));
-
- // Span boundaries inside tokens.
- EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {1, 28}));
- EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {1, 28}, true));
-
- // Tokens adjacent to the span, but not overlapping.
- EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}));
- EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}, true));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/annotator/knowledge/knowledge-engine-dummy.h b/annotator/knowledge/knowledge-engine-dummy.h
deleted file mode 100644
index 96d77c5..0000000
--- a/annotator/knowledge/knowledge-engine-dummy.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_
-
-#include <string>
-
-#include "annotator/types.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3 {
-
-// A dummy implementation of the knowledge engine.
-class KnowledgeEngine {
- public:
- explicit KnowledgeEngine(const UniLib* unilib) {}
-
- bool Initialize(const std::string& serialized_config) { return true; }
-
- bool ClassifyText(const std::string& context, CodepointSpan selection_indices,
- ClassificationResult* classification_result) const {
- return false;
- }
-
- bool Chunk(const std::string& context,
- std::vector<AnnotatedSpan>* result) const {
- return true;
- }
-
- bool LookUpEntity(const std::string& id,
- std::string* serialized_knowledge_result) const {
- return false;
- }
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_
diff --git a/annotator/model-executor.h b/annotator/model-executor.h
deleted file mode 100644
index bcc318b..0000000
--- a/annotator/model-executor.h
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Contains classes that can execute different models/parts of a model.
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
-
-#include <memory>
-
-#include "annotator/types.h"
-#include "utils/base/logging.h"
-#include "utils/tensor-view.h"
-#include "utils/tflite-model-executor.h"
-
-namespace libtextclassifier3 {
-
-// Executor for the text selection prediction and classification models.
-class ModelExecutor : public TfLiteModelExecutor {
- public:
- static std::unique_ptr<ModelExecutor> FromModelSpec(
- const tflite::Model* model_spec) {
- auto model = TfLiteModelFromModelSpec(model_spec);
- if (!model) {
- return nullptr;
- }
- return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
- }
-
- static std::unique_ptr<ModelExecutor> FromBuffer(
- const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
- auto model = TfLiteModelFromBuffer(model_spec_buffer);
- if (!model) {
- return nullptr;
- }
- return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
- }
-
- TensorView<float> ComputeLogits(const TensorView<float>& features,
- tflite::Interpreter* interpreter) const;
-
- protected:
- explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)
- : TfLiteModelExecutor(std::move(model)) {}
-
- static const int kInputIndexFeatures = 0;
- static const int kOutputIndexLogits = 0;
-};
-
-// Executor for embedding sparse features into a dense vector.
-class EmbeddingExecutor {
- public:
- virtual ~EmbeddingExecutor() {}
-
- // Embeds the sparse_features into a dense embedding and adds (+) it
- // element-wise to the dest vector.
- virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
- int dest_size) const = 0;
-
- // Returns true when the model is ready to be used, false otherwise.
- virtual bool IsReady() const { return true; }
-};
-
-class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
- public:
- static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer(
- const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
- int quantization_bits,
- const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
-
- // Embeds the sparse_features into a dense embedding and adds (+) it
- // element-wise to the dest vector.
- bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
- int dest_size) const;
-
- // Auxiliary function for computing prefixes used in implementation of
- // efficient mask indexing data structure.
- void ComputePrefixCounts();
-
- // Function implementing mask indexing based on efficient data structure
- int PruneBucketId(int bucket_id) const;
-
- protected:
- explicit TFLiteEmbeddingExecutor(
- std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
- int num_buckets, int bytes_per_embedding, int output_embedding_size,
- const TfLiteTensor* scales, const TfLiteTensor* embeddings,
- std::unique_ptr<tflite::Interpreter> interpreter,
- const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
-
- std::unique_ptr<TfLiteModelExecutor> executor_;
-
- int quantization_bits_;
- int num_buckets_ = -1;
- int bytes_per_embedding_ = -1;
- int output_embedding_size_ = -1;
- const TfLiteTensor* scales_ = nullptr;
- const TfLiteTensor* embeddings_ = nullptr;
-
- // NOTE: This interpreter is used in a read-only way (as a storage for the
- // model params), thus is still thread-safe.
- std::unique_ptr<tflite::Interpreter> interpreter_;
-
- std::vector<uint64> pruning_mask_;
- std::vector<uint16> prefix_counts_;
- int full_num_buckets_ = -1;
-
- // Index of row of embedding table corresponding to all pruned buckets.
- int pruned_row_bucket_id_ = -1;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
diff --git a/annotator/model.fbs b/annotator/model.fbs
deleted file mode 100755
index 9d18779..0000000
--- a/annotator/model.fbs
+++ /dev/null
@@ -1,689 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-include "utils/codepoint-range.fbs";
-include "utils/flatbuffers.fbs";
-include "utils/intents/intent-config.fbs";
-include "utils/resources.fbs";
-include "utils/tokenizer.fbs";
-include "utils/zlib/buffer.fbs";
-
-file_identifier "TC2 ";
-
-// The possible model modes, represents a bit field.
-namespace libtextclassifier3;
-enum ModeFlag : int {
- NONE = 0,
- ANNOTATION = 1,
- CLASSIFICATION = 2,
- ANNOTATION_AND_CLASSIFICATION = 3,
- SELECTION = 4,
- ANNOTATION_AND_SELECTION = 5,
- CLASSIFICATION_AND_SELECTION = 6,
- ALL = 7,
-}
-
-// Enum for specifying the annotation usecase.
-namespace libtextclassifier3;
-enum AnnotationUsecase : int {
- // Results are optimized for Smart{Select,Share,Linkify}.
- ANNOTATION_USECASE_SMART = 0,
-
- // Results are optimized for using TextClassifier as an infrastructure that
- // annotates as much as possible.
- ANNOTATION_USECASE_RAW = 1,
-}
-
-namespace libtextclassifier3;
-enum DatetimeExtractorType : int {
- UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0,
- AM = 1,
- PM = 2,
- JANUARY = 3,
- FEBRUARY = 4,
- MARCH = 5,
- APRIL = 6,
- MAY = 7,
- JUNE = 8,
- JULY = 9,
- AUGUST = 10,
- SEPTEMBER = 11,
- OCTOBER = 12,
- NOVEMBER = 13,
- DECEMBER = 14,
- NEXT = 15,
- NEXT_OR_SAME = 16,
- LAST = 17,
- NOW = 18,
- TOMORROW = 19,
- YESTERDAY = 20,
- PAST = 21,
- FUTURE = 22,
- DAY = 23,
- WEEK = 24,
- MONTH = 25,
- YEAR = 26,
- MONDAY = 27,
- TUESDAY = 28,
- WEDNESDAY = 29,
- THURSDAY = 30,
- FRIDAY = 31,
- SATURDAY = 32,
- SUNDAY = 33,
- DAYS = 34,
- WEEKS = 35,
- MONTHS = 36,
-
- // TODO(zilka): Make the following 3 values singular for consistency.
- HOURS = 37,
-
- MINUTES = 38,
- SECONDS = 39,
- YEARS = 40,
- DIGITS = 41,
- SIGNEDDIGITS = 42,
- ZERO = 43,
- ONE = 44,
- TWO = 45,
- THREE = 46,
- FOUR = 47,
- FIVE = 48,
- SIX = 49,
- SEVEN = 50,
- EIGHT = 51,
- NINE = 52,
- TEN = 53,
- ELEVEN = 54,
- TWELVE = 55,
- THIRTEEN = 56,
- FOURTEEN = 57,
- FIFTEEN = 58,
- SIXTEEN = 59,
- SEVENTEEN = 60,
- EIGHTEEN = 61,
- NINETEEN = 62,
- TWENTY = 63,
- THIRTY = 64,
- FORTY = 65,
- FIFTY = 66,
- SIXTY = 67,
- SEVENTY = 68,
- EIGHTY = 69,
- NINETY = 70,
- HUNDRED = 71,
- THOUSAND = 72,
-}
-
-namespace libtextclassifier3;
-enum DatetimeGroupType : int {
- GROUP_UNKNOWN = 0,
- GROUP_UNUSED = 1,
- GROUP_YEAR = 2,
- GROUP_MONTH = 3,
- GROUP_DAY = 4,
- GROUP_HOUR = 5,
- GROUP_MINUTE = 6,
- GROUP_SECOND = 7,
- GROUP_AMPM = 8,
- GROUP_RELATIONDISTANCE = 9,
- GROUP_RELATION = 10,
- GROUP_RELATIONTYPE = 11,
-
- // Dummy groups serve just as an inflator of the selection. E.g. we might want
- // to select more text than was contained in an envelope of all extractor
- // spans.
- GROUP_DUMMY1 = 12,
-
- GROUP_DUMMY2 = 13,
-}
-
-// Options for the model that predicts text selection.
-namespace libtextclassifier3;
-table SelectionModelOptions {
- // If true, before the selection is returned, the unpaired brackets contained
- // in the predicted selection are stripped from the both selection ends.
- // The bracket codepoints are defined in the Unicode standard:
- // http://www.unicode.org/Public/UNIDATA/BidiBrackets.txt
- strip_unpaired_brackets:bool = true;
-
- // Number of hypothetical click positions on either side of the actual click
- // to consider in order to enforce symmetry.
- symmetry_context_size:int;
-
- // Number of examples to bundle in one batch for inference.
- batch_size:int = 1024;
-
- // Whether to always classify a suggested selection or only on demand.
- always_classify_suggested_selection:bool = false;
-}
-
-// Options for the model that classifies a text selection.
-namespace libtextclassifier3;
-table ClassificationModelOptions {
- // Limits for phone numbers.
- phone_min_num_digits:int = 7;
-
- phone_max_num_digits:int = 15;
-
- // Limits for addresses.
- address_min_num_tokens:int;
-
- // Maximum number of tokens to attempt a classification (-1 is unlimited).
- max_num_tokens:int = -1;
-}
-
-// Options for post-checks, checksums and verification to apply on a match.
-namespace libtextclassifier3;
-table VerificationOptions {
- verify_luhn_checksum:bool = false;
-
- // Lua verifier to use.
- // Index of the lua verifier in the model.
- lua_verifier:int = -1;
-}
-
-// Behaviour of capturing groups.
-namespace libtextclassifier3.RegexModel_.Pattern_;
-table CapturingGroup {
- // If true, the span of the capturing group will be used to
- // extend the selection.
- extend_selection:bool = true;
-
- // If set, the text of the capturing group will be used to set a field in
- // the classfication result entity data.
- entity_field_path:FlatbufferFieldPath;
-}
-
-// List of regular expression matchers to check.
-namespace libtextclassifier3.RegexModel_;
-table Pattern {
- // The name of the collection of a match.
- collection_name:string;
-
- // The pattern to check.
- pattern:string;
-
- // The modes for which to apply the patterns.
- enabled_modes:ModeFlag = ALL;
-
- // The final score to assign to the results of this pattern.
- target_classification_score:float = 1;
-
- // Priority score used for conflict resolution with the other models.
- priority_score:float = 0;
-
- // If true, will use an approximate matching implementation implemented
- // using Find() instead of the true Match(). This approximate matching will
- // use the first Find() result and then check that it spans the whole input.
- use_approximate_matching:bool = false;
-
- compressed_pattern:CompressedBuffer;
-
- // Verification to apply on a match.
- verification_options:VerificationOptions;
-
- capturing_group:[Pattern_.CapturingGroup];
-
- // Serialized entity data to set for a match.
- serialized_entity_data:string;
-}
-
-namespace libtextclassifier3;
-table RegexModel {
- patterns:[RegexModel_.Pattern];
-
- // If true, will compile the regexes only on first use.
- lazy_regex_compilation:bool = true;
-
- // Lua scripts for match verification.
- // The verifier can access:
- // * `context`: The context as a string.
- // * `match`: The groups of the regex match as an array, each group gives
- // * `begin`: span start
- // * `end`: span end
- // * `text`: the text
- // The verifier is expected to return a boolean, indicating whether the
- // verification succeeded or not.
- lua_verifier:[string];
-}
-
-// List of regex patterns.
-namespace libtextclassifier3.DatetimeModelPattern_;
-table Regex {
- pattern:string;
-
- // The ith entry specifies the type of the ith capturing group.
- // This is used to decide how the matched content has to be parsed.
- groups:[DatetimeGroupType];
-
- compressed_pattern:CompressedBuffer;
-}
-
-namespace libtextclassifier3;
-table DatetimeModelPattern {
- regexes:[DatetimeModelPattern_.Regex];
-
- // List of locale indices in DatetimeModel that represent the locales that
- // these patterns should be used for. If empty, can be used for all locales.
- locales:[int];
-
- // The final score to assign to the results of this pattern.
- target_classification_score:float = 1;
-
- // Priority score used for conflict resolution with the other models.
- priority_score:float = 0;
-
- // The modes for which to apply the patterns.
- enabled_modes:ModeFlag = ALL;
-
- // The annotation usecases for which to apply the patterns.
- // This is a flag field for values of AnnotationUsecase.
- enabled_annotation_usecases:uint = 4294967295;
-}
-
-namespace libtextclassifier3;
-table DatetimeModelExtractor {
- extractor:DatetimeExtractorType;
- pattern:string;
- locales:[int];
- compressed_pattern:CompressedBuffer;
-}
-
-namespace libtextclassifier3;
-table DatetimeModel {
- // List of BCP 47 locale strings representing all locales supported by the
- // model. The individual patterns refer back to them using an index.
- locales:[string];
-
- patterns:[DatetimeModelPattern];
- extractors:[DatetimeModelExtractor];
-
- // If true, will use the extractors for determining the match location as
- // opposed to using the location where the global pattern matched.
- use_extractors_for_locating:bool = true;
-
- // List of locale ids, rules of whose are always run, after the requested
- // ones.
- default_locales:[int];
-
- // If true, will generate the alternative interpretations for ambiguous
- // datetime expressions.
- generate_alternative_interpretations_when_ambiguous:bool = false;
-
- // If true, will compile the regexes only on first use.
- lazy_regex_compilation:bool = true;
-}
-
-namespace libtextclassifier3.DatetimeModelLibrary_;
-table Item {
- key:string;
- value:DatetimeModel;
-}
-
-// A set of named DateTime models.
-namespace libtextclassifier3;
-table DatetimeModelLibrary {
- models:[DatetimeModelLibrary_.Item];
-}
-
-// Options controlling the output of the Tensorflow Lite models.
-namespace libtextclassifier3;
-table ModelTriggeringOptions {
- // Lower bound threshold for filtering annotation model outputs.
- min_annotate_confidence:float = 0;
-
- // The modes for which to enable the models.
- enabled_modes:ModeFlag = ALL;
-
- // Comma-separated list of locales (BCP 47 tags) that dictionary
- // classification supports.
- dictionary_locales:string;
-
- // Comma-separated list of locales (BCP 47 tags) that the model supports, that
- // are used to prevent triggering on input in unsupported languages. If
- // empty, the model will trigger on all inputs.
- locales:string;
-}
-
-// Options controlling the output of the classifier.
-namespace libtextclassifier3;
-table OutputOptions {
- // Lists of collection names that will be filtered out at the output:
- // - For annotation, the spans of given collection are simply dropped.
- // - For classification, the result is mapped to the class "other".
- // - For selection, the spans of given class are returned as
- // single-selection.
- filtered_collections_annotation:[string];
-
- filtered_collections_classification:[string];
- filtered_collections_selection:[string];
-}
-
-namespace libtextclassifier3.Model_;
-table EmbeddingPruningMask {
- // If true, use pruning mask. In this case, we use mask
- // pruning_mask to determine the mapping of hashed-charactergrams.
- enabled:bool;
-
- // Packing of the binary pruning mask into uint64 values.
- pruning_mask:[ulong] (force_align: 16);
-
- // Number of buckets before pruning.
- full_num_buckets:int;
-
- // Index of row of compressed embedding matrix to which all pruned buckets
- // are mapped.
- pruned_row_bucket_id:int;
-}
-
-namespace libtextclassifier3;
-table Model {
- // Comma-separated list of locales supported by the model as BCP 47 tags.
- locales:string;
-
- version:int;
-
- // A name for the model that can be used for e.g. logging.
- name:string;
-
- selection_feature_options:FeatureProcessorOptions;
- classification_feature_options:FeatureProcessorOptions;
-
- // Tensorflow Lite models.
- selection_model:[ubyte] (force_align: 16);
-
- classification_model:[ubyte] (force_align: 16);
- embedding_model:[ubyte] (force_align: 16);
-
- // Options for the different models.
- selection_options:SelectionModelOptions;
-
- classification_options:ClassificationModelOptions;
- regex_model:RegexModel;
- datetime_model:DatetimeModel;
-
- // Options controlling the output of the models.
- triggering_options:ModelTriggeringOptions;
-
- // Global switch that controls if SuggestSelection(), ClassifyText() and
- // Annotate() will run. If a mode is disabled it returns empty/no-op results.
- enabled_modes:ModeFlag = ALL;
-
- // If true, will snap the selections that consist only of whitespaces to the
- // containing suggested span. Otherwise, no suggestion is proposed, since the
- // selections are not part of any token.
- snap_whitespace_selections:bool = true;
-
- // Global configuration for the output of SuggestSelection(), ClassifyText()
- // and Annotate().
- output_options:OutputOptions;
-
- // Configures how Intents should be generated on Android.
- android_intent_options:AndroidIntentFactoryOptions;
-
- intent_options:IntentFactoryModel;
-
- // Model resources.
- resources:ResourcePool;
-
- // Schema data for handling entity data.
- entity_data_schema:[ubyte];
-
- number_annotator_options:NumberAnnotatorOptions;
- duration_annotator_options:DurationAnnotatorOptions;
-
- // Comma-separated list of locales (BCP 47 tags) that the model supports, that
- // are used to prevent triggering on input in unsupported languages. If
- // empty, the model will trigger on all inputs.
- triggering_locales:string;
-
- embedding_pruning_mask:Model_.EmbeddingPruningMask;
-}
-
-// Method for selecting the center token.
-namespace libtextclassifier3.FeatureProcessorOptions_;
-enum CenterTokenSelectionMethod : int {
- DEFAULT_CENTER_TOKEN_METHOD = 0,
-
- // Use click indices to determine the center token.
- CENTER_TOKEN_FROM_CLICK = 1,
-
- // Use selection indices to get a token range, and select the middle of it
- // as the center token.
- CENTER_TOKEN_MIDDLE_OF_SELECTION = 2,
-}
-
-// Bounds-sensitive feature extraction configuration.
-namespace libtextclassifier3.FeatureProcessorOptions_;
-table BoundsSensitiveFeatures {
- // Enables the extraction of bounds-sensitive features, instead of the click
- // context features.
- enabled:bool;
-
- // The numbers of tokens to extract in specific locations relative to the
- // bounds.
- // Immediately before the span.
- num_tokens_before:int;
-
- // Inside the span, aligned with the beginning.
- num_tokens_inside_left:int;
-
- // Inside the span, aligned with the end.
- num_tokens_inside_right:int;
-
- // Immediately after the span.
- num_tokens_after:int;
-
- // If true, also extracts the tokens of the entire span and adds up their
- // features forming one "token" to include in the extracted features.
- include_inside_bag:bool;
-
- // If true, includes the selection length (in the number of tokens) as a
- // feature.
- include_inside_length:bool;
-
- // If true, for selection, single token spans are not run through the model
- // and their score is assumed to be zero.
- score_single_token_spans_as_zero:bool;
-}
-
-namespace libtextclassifier3;
-table FeatureProcessorOptions {
- // Number of buckets used for hashing charactergrams.
- num_buckets:int = -1;
-
- // Size of the embedding.
- embedding_size:int = -1;
-
- // Number of bits for quantization for embeddings.
- embedding_quantization_bits:int = 8;
-
- // Context size defines the number of words to the left and to the right of
- // the selected word to be used as context. For example, if context size is
- // N, then we take N words to the left and N words to the right of the
- // selected word as its context.
- context_size:int = -1;
-
- // Maximum number of words of the context to select in total.
- max_selection_span:int = -1;
-
- // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3
- // character trigrams etc.
- chargram_orders:[int];
-
- // Maximum length of a word, in codepoints.
- max_word_length:int = 20;
-
- // If true, will use the unicode-aware functionality for extracting features.
- unicode_aware_features:bool = false;
-
- // Whether to extract the token case feature.
- extract_case_feature:bool = false;
-
- // Whether to extract the selection mask feature.
- extract_selection_mask_feature:bool = false;
-
- // List of regexps to run over each token. For each regexp, if there is a
- // match, a dense feature of 1.0 is emitted. Otherwise -1.0 is used.
- regexp_feature:[string];
-
- // Whether to remap all digits to a single number.
- remap_digits:bool = false;
-
- // Whether to lower-case each token before generating hashgrams.
- lowercase_tokens:bool;
-
- // If true, the selection classifier output will contain only the selections
- // that are feasible (e.g., those that are shorter than max_selection_span),
- // if false, the output will be a complete cross-product of possible
- // selections to the left and possible selections to the right, including the
- // infeasible ones.
- // NOTE: Exists mainly for compatibility with older models that were trained
- // with the non-reduced output space.
- selection_reduced_output_space:bool = true;
-
- // Collection names.
- collections:[string];
-
- // An index of collection in collections to be used if a collection name can't
- // be mapped to an id.
- default_collection:int = -1;
-
- // If true, will split the input by lines, and only use the line that contains
- // the clicked token.
- only_use_line_with_click:bool = false;
-
- // If true, will split tokens that contain the selection boundary, at the
- // position of the boundary.
- // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
- split_tokens_on_selection_boundaries:bool = false;
-
- // Codepoint ranges that determine how different codepoints are tokenized.
- // The ranges must not overlap.
- tokenization_codepoint_config:[TokenizationCodepointRange];
-
- center_token_selection_method:FeatureProcessorOptions_.CenterTokenSelectionMethod;
-
- // If true, span boundaries will be snapped to containing tokens and not
- // required to exactly match token boundaries.
- snap_label_span_boundaries_to_containing_tokens:bool;
-
- // A set of codepoint ranges supported by the model.
- supported_codepoint_ranges:[CodepointRange];
-
- // A set of codepoint ranges to use in the mixed tokenization mode to identify
- // stretches of tokens to re-tokenize using the internal tokenizer.
- internal_tokenizer_codepoint_ranges:[CodepointRange];
-
- // Minimum ratio of supported codepoints in the input context. If the ratio
- // is lower than this, the feature computation will fail.
- min_supported_codepoint_ratio:float = 0;
-
- // Used for versioning the format of features the model expects.
- // - feature_version == 0:
- // For each token the features consist of:
- // - chargram embeddings
- // - dense features
- // Chargram embeddings for tokens are concatenated first together,
- // and at the end, the dense features for the tokens are concatenated
- // to it. So the resulting feature vector has two regions.
- feature_version:int = 0;
-
- tokenization_type:TokenizationType = INTERNAL_TOKENIZER;
- icu_preserve_whitespace_tokens:bool = false;
-
- // List of codepoints that will be stripped from beginning and end of
- // predicted spans.
- ignored_span_boundary_codepoints:[int];
-
- bounds_sensitive_features:FeatureProcessorOptions_.BoundsSensitiveFeatures;
-
- // List of allowed charactergrams. The extracted charactergrams are filtered
- // using this list, and charactergrams that are not present are interpreted as
- // out-of-vocabulary.
- // If no allowed_chargrams are specified, all charactergrams are allowed.
- // The field is typed as bytes type to allow non-UTF8 chargrams.
- allowed_chargrams:[string];
-
- // If true, tokens will be also split when the codepoint's script_id changes
- // as defined in TokenizationCodepointRange.
- tokenize_on_script_change:bool = false;
-}
-
-namespace libtextclassifier3;
-table NumberAnnotatorOptions {
- // If true, number annotations will be produced.
- enabled:bool = false;
-
- // Score to assign to the annotated numbers from the annotator.
- score:float = 1;
-
- // Priority score used for conflict resolution with the other models.
- priority_score:float = 0;
-
- // The modes in which to enable number annotations.
- enabled_modes:ModeFlag = ALL;
-
- // The annotation usecases for which to produce number annotations.
- // This is a flag field for values of AnnotationUsecase.
- enabled_annotation_usecases:uint = 4294967295;
-
- // A list of codepoints that can form a prefix of a valid number.
- allowed_prefix_codepoints:[int];
-
- // A list of codepoints that can form a suffix of a valid number.
- allowed_suffix_codepoints:[int];
-}
-
-// DurationAnnotator is so far tailored for English only.
-namespace libtextclassifier3;
-table DurationAnnotatorOptions {
- // If true, duration annotations will be produced.
- enabled:bool = false;
-
- // Score to assign to the annotated durations from the annotator.
- score:float = 1;
-
- // Priority score used for conflict resolution with the other models.
- priority_score:float = 0;
-
- // The modes in which to enable duration annotations.
- enabled_modes:ModeFlag = ALL;
-
- // The annotation usecases for which to produce duration annotations.
- enabled_annotation_usecases:uint = 4294967295;
-
- // Durations typically look like XX hours and XX minutes etc... The list of
- // strings below enumerate variants of "hours", "minutes", etc. in these
- // expressions. These are verbatim strings that are matched against tokens in
- // the input.
- week_expressions:[string];
-
- day_expressions:[string];
- hour_expressions:[string];
- minute_expressions:[string];
- second_expressions:[string];
-
- // List of expressions that doesn't break a duration expression (can become
- // a part of it) but has not semantic meaning.
- filler_expressions:[string];
-
- // List of expressions that mean half of a unit of duration (e.g. "half an
- // hour").
- half_expressions:[string];
-}
-
-root_type libtextclassifier3.Model;
diff --git a/annotator/number/number.cc b/annotator/number/number.cc
deleted file mode 100644
index bc3a2fe..0000000
--- a/annotator/number/number.cc
+++ /dev/null
@@ -1,187 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/number/number.h"
-
-#include <climits>
-#include <cstdlib>
-
-#include "annotator/collections.h"
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-bool NumberAnnotator::ClassifyText(
- const UnicodeText& context, CodepointSpan selection_indices,
- AnnotationUsecase annotation_usecase,
- ClassificationResult* classification_result) const {
- int64 parsed_value;
- int num_prefix_codepoints;
- int num_suffix_codepoints;
- if (ParseNumber(UnicodeText::Substring(context, selection_indices.first,
- selection_indices.second),
- &parsed_value, &num_prefix_codepoints,
- &num_suffix_codepoints)) {
- ClassificationResult classification{Collections::Number(), 1.0};
- TC3_CHECK(classification_result != nullptr);
- classification_result->collection = Collections::Number();
- classification_result->score = options_->score();
- classification_result->priority_score = options_->priority_score();
- classification_result->numeric_value = parsed_value;
- return true;
- }
- return false;
-}
-
-bool NumberAnnotator::FindAll(const UnicodeText& context,
- AnnotationUsecase annotation_usecase,
- std::vector<AnnotatedSpan>* result) const {
- if (!options_->enabled() || ((1 << annotation_usecase) &
- options_->enabled_annotation_usecases()) == 0) {
- return true;
- }
-
- const std::vector<Token> tokens = feature_processor_->Tokenize(context);
- for (const Token& token : tokens) {
- const UnicodeText token_text =
- UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- int64 parsed_value;
- int num_prefix_codepoints;
- int num_suffix_codepoints;
- if (ParseNumber(token_text, &parsed_value, &num_prefix_codepoints,
- &num_suffix_codepoints)) {
- ClassificationResult classification{Collections::Number(),
- options_->score()};
- classification.numeric_value = parsed_value;
- classification.priority_score = options_->priority_score();
-
- AnnotatedSpan annotated_span;
- annotated_span.span = {token.start + num_prefix_codepoints,
- token.end - num_suffix_codepoints};
- annotated_span.classification.push_back(classification);
-
- result->push_back(annotated_span);
- }
- }
-
- return true;
-}
-
-std::unordered_set<int> NumberAnnotator::FlatbuffersVectorToSet(
- const flatbuffers::Vector<int32_t>* codepoints) {
- if (codepoints == nullptr) {
- return std::unordered_set<int>{};
- }
-
- std::unordered_set<int> result;
- for (const int codepoint : *codepoints) {
- result.insert(codepoint);
- }
- return result;
-}
-
-namespace {
-UnicodeText::const_iterator ConsumeAndParseNumber(
- const UnicodeText::const_iterator& it_begin,
- const UnicodeText::const_iterator& it_end, int64* result) {
- *result = 0;
-
- // See if there's a sign in the beginning of the number.
- int sign = 1;
- auto it = it_begin;
- if (it != it_end) {
- if (*it == '-') {
- ++it;
- sign = -1;
- } else if (*it == '+') {
- ++it;
- sign = 1;
- }
- }
-
- while (it != it_end) {
- if (*it >= '0' && *it <= '9') {
- // When overflow is imminent we'll fail to parse the number.
- if (*result > INT64_MAX / 10) {
- return it_begin;
- }
- *result *= 10;
- *result += *it - '0';
- } else {
- *result *= sign;
- return it;
- }
-
- ++it;
- }
-
- *result *= sign;
- return it_end;
-}
-} // namespace
-
-bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* result,
- int* num_prefix_codepoints,
- int* num_suffix_codepoints) const {
- TC3_CHECK(result != nullptr && num_prefix_codepoints != nullptr &&
- num_suffix_codepoints != nullptr);
- auto it = text.begin();
- auto it_end = text.end();
-
- // Strip boundary codepoints from both ends.
- const CodepointSpan original_span{0, text.size_codepoints()};
- const CodepointSpan stripped_span =
- feature_processor_->StripBoundaryCodepoints(text, original_span);
- const int num_stripped_end = (original_span.second - stripped_span.second);
- std::advance(it, stripped_span.first);
- std::advance(it_end, -num_stripped_end);
-
- // Consume prefix codepoints.
- *num_prefix_codepoints = stripped_span.first;
- while (it != text.end()) {
- if (allowed_prefix_codepoints_.find(*it) ==
- allowed_prefix_codepoints_.end()) {
- break;
- }
-
- ++it;
- ++(*num_prefix_codepoints);
- }
-
- auto it_start = it;
- it = ConsumeAndParseNumber(it, text.end(), result);
- if (it == it_start) {
- return false;
- }
-
- // Consume suffix codepoints.
- bool valid_suffix = true;
- *num_suffix_codepoints = 0;
- while (it != it_end) {
- if (allowed_suffix_codepoints_.find(*it) ==
- allowed_suffix_codepoints_.end()) {
- valid_suffix = false;
- break;
- }
-
- ++it;
- ++(*num_suffix_codepoints);
- }
- *num_suffix_codepoints += num_stripped_end;
- return valid_suffix;
-}
-
-} // namespace libtextclassifier3
diff --git a/annotator/number/number.h b/annotator/number/number.h
deleted file mode 100644
index 488f5ea..0000000
--- a/annotator/number/number.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_H_
-
-#include <string>
-#include <unordered_set>
-#include <vector>
-
-#include "annotator/feature-processor.h"
-#include "annotator/model_generated.h"
-#include "annotator/types.h"
-#include "utils/utf8/unicodetext.h"
-
-namespace libtextclassifier3 {
-
-// Annotator of numbers in text.
-//
-// Only supports values in range [-999 999 999, 999 999 999] (inclusive).
-//
-// TODO(zilka): Add support for non-ASCII digits.
-// TODO(zilka): Add support for written-out numbers.
-class NumberAnnotator {
- public:
- explicit NumberAnnotator(const NumberAnnotatorOptions* options,
- const FeatureProcessor* feature_processor)
- : options_(options),
- feature_processor_(feature_processor),
- allowed_prefix_codepoints_(
- FlatbuffersVectorToSet(options->allowed_prefix_codepoints())),
- allowed_suffix_codepoints_(
- FlatbuffersVectorToSet(options->allowed_suffix_codepoints())) {}
-
- // Classifies given text, and if it is a number, it passes the result in
- // 'classification_result' and returns true, otherwise returns false.
- bool ClassifyText(const UnicodeText& context, CodepointSpan selection_indices,
- AnnotationUsecase annotation_usecase,
- ClassificationResult* classification_result) const;
-
- // Finds all number instances in the input text.
- bool FindAll(const UnicodeText& context_unicode,
- AnnotationUsecase annotation_usecase,
- std::vector<AnnotatedSpan>* result) const;
-
- private:
- static std::unordered_set<int> FlatbuffersVectorToSet(
- const flatbuffers::Vector<int32_t>* codepoints);
-
- // Parses the text to an int64 value and returns true if succeeded, otherwise
- // false. Also returns the number of prefix/suffix codepoints that were
- // stripped from the number.
- bool ParseNumber(const UnicodeText& text, int64* result,
- int* num_prefix_codepoints,
- int* num_suffix_codepoints) const;
-
- const NumberAnnotatorOptions* options_;
- const FeatureProcessor* feature_processor_;
- const std::unordered_set<int> allowed_prefix_codepoints_;
- const std::unordered_set<int> allowed_suffix_codepoints_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_H_
diff --git a/annotator/number/number_test.cc b/annotator/number/number_test.cc
deleted file mode 100644
index d3b2e8c..0000000
--- a/annotator/number/number_test.cc
+++ /dev/null
@@ -1,258 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/number/number.h"
-
-#include <string>
-#include <vector>
-
-#include "annotator/collections.h"
-#include "annotator/model_generated.h"
-#include "annotator/types-test-util.h"
-#include "annotator/types.h"
-#include "utils/test-utils.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::AllOf;
-using testing::ElementsAre;
-using testing::Field;
-
-const NumberAnnotatorOptions* TestingNumberAnnotatorOptions() {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- NumberAnnotatorOptionsT options;
- options.enabled = true;
- options.allowed_prefix_codepoints.push_back('$');
- options.allowed_suffix_codepoints.push_back('%');
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(NumberAnnotatorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- return flatbuffers::GetRoot<NumberAnnotatorOptions>(options_data->data());
-}
-
-FeatureProcessor BuildFeatureProcessor(const UniLib* unilib) {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- FeatureProcessorOptionsT options;
- options.context_size = 1;
- options.max_selection_span = 1;
- options.snap_label_span_boundaries_to_containing_tokens = false;
- options.ignored_span_boundary_codepoints.push_back(',');
-
- options.tokenization_codepoint_config.emplace_back(
- new TokenizationCodepointRangeT());
- auto& config = options.tokenization_codepoint_config.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(FeatureProcessorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
-
- const FeatureProcessorOptions* feature_processor_options =
- flatbuffers::GetRoot<FeatureProcessorOptions>(options_data->data());
-
- return FeatureProcessor(feature_processor_options, unilib);
-}
-
-class NumberAnnotatorTest : public ::testing::Test {
- protected:
- NumberAnnotatorTest()
- : INIT_UNILIB_FOR_TESTING(unilib_),
- feature_processor_(BuildFeatureProcessor(&unilib_)),
- number_annotator_(TestingNumberAnnotatorOptions(),
- &feature_processor_) {}
-
- UniLib unilib_;
- FeatureProcessor feature_processor_;
- NumberAnnotator number_annotator_;
-};
-
-TEST_F(NumberAnnotatorTest, ClassifiesAndParsesNumberCorrectly) {
- ClassificationResult classification_result;
- EXPECT_TRUE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("... 12345 ..."), {4, 9},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-
- EXPECT_EQ(classification_result.collection, "number");
- EXPECT_EQ(classification_result.numeric_value, 12345);
-}
-
-TEST_F(NumberAnnotatorTest, ClassifiesNonNumberCorrectly) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("... 123a45 ..."), {4, 10},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, FindsAllNumbersInText) {
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(number_annotator_.FindAll(
- UTF8ToUnicodeText("... 12345 ... 9 is my number and I paid $99 and "
- "sometimes 27% but not 68# nor #68"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- ASSERT_EQ(result.size(), 4);
- ASSERT_EQ(result[0].classification.size(), 1);
- EXPECT_EQ(result[0].classification[0].collection, "number");
- EXPECT_EQ(result[0].classification[0].numeric_value, 12345);
- ASSERT_EQ(result[1].classification.size(), 1);
- EXPECT_EQ(result[1].classification[0].collection, "number");
- EXPECT_EQ(result[1].classification[0].numeric_value, 9);
- ASSERT_EQ(result[2].classification.size(), 1);
- EXPECT_EQ(result[2].classification[0].collection, "number");
- EXPECT_EQ(result[2].classification[0].numeric_value, 99);
- ASSERT_EQ(result[3].classification.size(), 1);
- EXPECT_EQ(result[3].classification[0].collection, "number");
- EXPECT_EQ(result[3].classification[0].numeric_value, 27);
-}
-
-TEST_F(NumberAnnotatorTest, FindsNumberWithPunctuation) {
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(number_annotator_.FindAll(
- UTF8ToUnicodeText("Come at 9, ok?"),
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(8, 9)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "number"),
- Field(&ClassificationResult::numeric_value, 9)))))));
-}
-
-TEST_F(NumberAnnotatorTest, HandlesNumbersAtBeginning) {
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(number_annotator_.FindAll(
- UTF8ToUnicodeText("-5"), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
-
- EXPECT_THAT(
- result,
- ElementsAre(
- AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 2)),
- Field(&AnnotatedSpan::classification,
- ElementsAre(AllOf(
- Field(&ClassificationResult::collection, "number"),
- Field(&ClassificationResult::numeric_value, -5)))))));
-}
-
-TEST_F(NumberAnnotatorTest, WhenLowestSupportedNumberParsesIt) {
- ClassificationResult classification_result;
- EXPECT_TRUE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("-999999999999999999"), {0, 19},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-
- EXPECT_THAT(
- classification_result,
- AllOf(Field(&ClassificationResult::collection, "number"),
- Field(&ClassificationResult::numeric_value, -999999999999999999L)));
-}
-
-TEST_F(NumberAnnotatorTest, WhenLargestSupportedNumberParsesIt) {
- ClassificationResult classification_result;
- EXPECT_TRUE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("999999999999999999"), {0, 18},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-
- EXPECT_THAT(
- classification_result,
- AllOf(Field(&ClassificationResult::collection, "number"),
- Field(&ClassificationResult::numeric_value, 999999999999999999L)));
-}
-
-TEST_F(NumberAnnotatorTest, WhenFirstLowestNonSupportedNumberDoesNotParseIt) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("-10000000000000000000"), {0, 21},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, WhenFirstLargestNonSupportedNumberDoesNotParseIt) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("10000000000000000000"), {0, 20},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, WhenLargeNumberDoesNotParseIt) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("1234567890123456789012345678901234567890"), {0, 40},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, WhenMultipleMinusSignsDoesNotParseIt) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("--10"), {0, 4},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, WhenMinusSignSuffixDoesNotParseIt) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("10-"), {0, 3},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, WhenMinusInTheMiddleDoesNotParseIt) {
- ClassificationResult classification_result;
- EXPECT_FALSE(number_annotator_.ClassifyText(
- UTF8ToUnicodeText("2016-2017"), {0, 9},
- AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification_result));
-}
-
-TEST_F(NumberAnnotatorTest, WhenSuffixWithoutNumberDoesNotParseIt) {
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(number_annotator_.FindAll(
- UTF8ToUnicodeText("... % ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
-
- ASSERT_EQ(result.size(), 0);
-}
-
-TEST_F(NumberAnnotatorTest, WhenPrefixWithoutNumberDoesNotParseIt) {
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(number_annotator_.FindAll(
- UTF8ToUnicodeText("... $ ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
-
- ASSERT_EQ(result.size(), 0);
-}
-
-TEST_F(NumberAnnotatorTest, WhenPrefixAndSuffixWithoutNumberDoesNotParseIt) {
- std::vector<AnnotatedSpan> result;
- EXPECT_TRUE(number_annotator_.FindAll(
- UTF8ToUnicodeText("... $% ..."), AnnotationUsecase_ANNOTATION_USECASE_RAW,
- &result));
-
- ASSERT_EQ(result.size(), 0);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/annotator/test_data/test_model.fb b/annotator/test_data/test_model.fb
deleted file mode 100644
index 0f2ec16..0000000
--- a/annotator/test_data/test_model.fb
+++ /dev/null
Binary files differ
diff --git a/annotator/test_data/wrong_embeddings.fb b/annotator/test_data/wrong_embeddings.fb
deleted file mode 100644
index 5439623..0000000
--- a/annotator/test_data/wrong_embeddings.fb
+++ /dev/null
Binary files differ
diff --git a/annotator/types-test-util.h b/annotator/types-test-util.h
deleted file mode 100644
index c0b0980..0000000
--- a/annotator/types-test-util.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_
-
-#include <ostream>
-
-#include "annotator/types.h"
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-#define TC3_DECLARE_PRINT_OPERATOR(TYPE_NAME) \
- inline std::ostream& operator<<(std::ostream& stream, \
- const TYPE_NAME& value) { \
- logging::LoggingStringStream tmp_stream; \
- tmp_stream << value; \
- return stream << tmp_stream.message; \
- }
-
-TC3_DECLARE_PRINT_OPERATOR(AnnotatedSpan)
-TC3_DECLARE_PRINT_OPERATOR(ClassificationResult)
-TC3_DECLARE_PRINT_OPERATOR(DateParseData)
-TC3_DECLARE_PRINT_OPERATOR(DatetimeParseResultSpan)
-TC3_DECLARE_PRINT_OPERATOR(Token)
-
-#undef TC3_DECLARE_PRINT_OPERATOR
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_
diff --git a/annotator/types.cc b/annotator/types.cc
deleted file mode 100644
index ee150c8..0000000
--- a/annotator/types.cc
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/types.h"
-
-namespace libtextclassifier3 {
-
-logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const Token& token) {
- if (!token.is_padding) {
- return stream << "Token(\"" << token.value << "\", " << token.start << ", "
- << token.end << ")";
- } else {
- return stream << "Token()";
- }
-}
-
-namespace {
-std::string FormatMillis(int64 time_ms_utc) {
- long time_seconds = time_ms_utc / 1000; // NOLINT
- char buffer[512];
- strftime(buffer, sizeof(buffer), "%a %Y-%m-%d %H:%M:%S %Z",
- localtime(&time_seconds));
- return std::string(buffer);
-}
-} // namespace
-
-logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const DatetimeParseResultSpan& value) {
- stream << "DatetimeParseResultSpan({" << value.span.first << ", "
- << value.span.second << "}, {";
- for (const DatetimeParseResult& data : value.data) {
- stream << "{/*time_ms_utc=*/ " << data.time_ms_utc << " /* "
- << FormatMillis(data.time_ms_utc) << " */, /*granularity=*/ "
- << data.granularity << "}, ";
- }
- stream << "})";
- return stream;
-}
-
-logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const ClassificationResult& result) {
- return stream << "ClassificationResult(" << result.collection
- << ", /*score=*/ " << result.score << ", /*priority_score=*/ "
- << result.priority_score << ")";
-}
-
-logging::LoggingStringStream& operator<<(
- logging::LoggingStringStream& stream,
- const std::vector<ClassificationResult>& results) {
- stream = stream << "{\n";
- for (const ClassificationResult& result : results) {
- stream = stream << " " << result << "\n";
- }
- stream = stream << "}";
- return stream;
-}
-
-logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const AnnotatedSpan& span) {
- std::string best_class;
- float best_score = -1;
- if (!span.classification.empty()) {
- best_class = span.classification[0].collection;
- best_score = span.classification[0].score;
- }
- return stream << "Span(" << span.span.first << ", " << span.span.second
- << ", " << best_class << ", " << best_score << ")";
-}
-
-logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const DateParseData& data) {
- // TODO(zilka): Add human-readable form of field_set_mask and the enum fields.
- stream = stream << "DateParseData {\n";
- stream = stream << " field_set_mask: " << data.field_set_mask << "\n";
- stream = stream << " year: " << data.year << "\n";
- stream = stream << " month: " << data.month << "\n";
- stream = stream << " day_of_month: " << data.day_of_month << "\n";
- stream = stream << " hour: " << data.hour << "\n";
- stream = stream << " minute: " << data.minute << "\n";
- stream = stream << " second: " << data.second << "\n";
- stream = stream << " ampm: " << static_cast<int>(data.ampm) << "\n";
- stream = stream << " zone_offset: " << data.zone_offset << "\n";
- stream = stream << " dst_offset: " << data.dst_offset << "\n";
- stream = stream << " relation: " << static_cast<int>(data.relation) << "\n";
- stream = stream << " relation_type: " << static_cast<int>(data.relation_type)
- << "\n";
- stream = stream << " relation_distance: " << data.relation_distance << "\n";
- stream = stream << "}";
- return stream;
-}
-
-} // namespace libtextclassifier3
diff --git a/annotator/types.h b/annotator/types.h
deleted file mode 100644
index 48fefe4..0000000
--- a/annotator/types.h
+++ /dev/null
@@ -1,423 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
-
-#include <time.h>
-#include <algorithm>
-#include <cmath>
-#include <functional>
-#include <map>
-#include <set>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "annotator/entity-data_generated.h"
-#include "utils/base/integral_types.h"
-#include "utils/base/logging.h"
-#include "utils/flatbuffers.h"
-#include "utils/variant.h"
-
-namespace libtextclassifier3 {
-
-constexpr int kInvalidIndex = -1;
-
-// Index for a 0-based array of tokens.
-using TokenIndex = int;
-
-// Index for a 0-based array of codepoints.
-using CodepointIndex = int;
-
-// Marks a span in a sequence of codepoints. The first element is the index of
-// the first codepoint of the span, and the second element is the index of the
-// codepoint one past the end of the span.
-// TODO(b/71982294): Make it a struct.
-using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
-
-inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
- return a.first < b.second && b.first < a.second;
-}
-
-inline bool ValidNonEmptySpan(const CodepointSpan& span) {
- return span.first < span.second && span.first >= 0 && span.second >= 0;
-}
-
-template <typename T>
-bool DoesCandidateConflict(
- const int considered_candidate, const std::vector<T>& candidates,
- const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
- if (chosen_indices_set.empty()) {
- return false;
- }
-
- auto conflicting_it = chosen_indices_set.lower_bound(considered_candidate);
- // Check conflict on the right.
- if (conflicting_it != chosen_indices_set.end() &&
- SpansOverlap(candidates[considered_candidate].span,
- candidates[*conflicting_it].span)) {
- return true;
- }
-
- // Check conflict on the left.
- // If we can't go more left, there can't be a conflict:
- if (conflicting_it == chosen_indices_set.begin()) {
- return false;
- }
- // Otherwise move one span left and insert if it doesn't overlap with the
- // candidate.
- --conflicting_it;
- if (!SpansOverlap(candidates[considered_candidate].span,
- candidates[*conflicting_it].span)) {
- return false;
- }
-
- return true;
-}
-
-// Marks a span in a sequence of tokens. The first element is the index of the
-// first token in the span, and the second element is the index of the token one
-// past the end of the span.
-// TODO(b/71982294): Make it a struct.
-using TokenSpan = std::pair<TokenIndex, TokenIndex>;
-
-// Returns the size of the token span. Assumes that the span is valid.
-inline int TokenSpanSize(const TokenSpan& token_span) {
- return token_span.second - token_span.first;
-}
-
-// Returns a token span consisting of one token.
-inline TokenSpan SingleTokenSpan(int token_index) {
- return {token_index, token_index + 1};
-}
-
-// Returns an intersection of two token spans. Assumes that both spans are valid
-// and overlapping.
-inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
- const TokenSpan& token_span2) {
- return {std::max(token_span1.first, token_span2.first),
- std::min(token_span1.second, token_span2.second)};
-}
-
-// Returns and expanded token span by adding a certain number of tokens on its
-// left and on its right.
-inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span,
- int num_tokens_left, int num_tokens_right) {
- return {token_span.first - num_tokens_left,
- token_span.second + num_tokens_right};
-}
-
-// Token holds a token, its position in the original string and whether it was
-// part of the input span.
-struct Token {
- std::string value;
- CodepointIndex start;
- CodepointIndex end;
-
- // Whether the token is a padding token.
- bool is_padding;
-
- // Default constructor constructs the padding-token.
- Token()
- : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {}
-
- Token(const std::string& arg_value, CodepointIndex arg_start,
- CodepointIndex arg_end)
- : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {}
-
- bool operator==(const Token& other) const {
- return value == other.value && start == other.start && end == other.end &&
- is_padding == other.is_padding;
- }
-
- bool IsContainedInSpan(CodepointSpan span) const {
- return start >= span.first && end <= span.second;
- }
-};
-
-// Pretty-printing function for Token.
-logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const Token& token);
-
-enum DatetimeGranularity {
- GRANULARITY_UNKNOWN = -1, // GRANULARITY_UNKNOWN is used as a proxy for this
- // structure being uninitialized.
- GRANULARITY_YEAR = 0,
- GRANULARITY_MONTH = 1,
- GRANULARITY_WEEK = 2,
- GRANULARITY_DAY = 3,
- GRANULARITY_HOUR = 4,
- GRANULARITY_MINUTE = 5,
- GRANULARITY_SECOND = 6
-};
-
-struct DatetimeParseResult {
- // The absolute time in milliseconds since the epoch in UTC.
- int64 time_ms_utc;
-
- // The precision of the estimate then in to calculating the milliseconds
- DatetimeGranularity granularity;
-
- DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {}
-
- DatetimeParseResult(int64 arg_time_ms_utc,
- DatetimeGranularity arg_granularity)
- : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {}
-
- bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; }
-
- bool operator==(const DatetimeParseResult& other) const {
- return granularity == other.granularity && time_ms_utc == other.time_ms_utc;
- }
-};
-
-const float kFloatCompareEpsilon = 1e-5;
-
-struct DatetimeParseResultSpan {
- CodepointSpan span;
- std::vector<DatetimeParseResult> data;
- float target_classification_score;
- float priority_score;
-
- bool operator==(const DatetimeParseResultSpan& other) const {
- return span == other.span && data == other.data &&
- std::abs(target_classification_score -
- other.target_classification_score) < kFloatCompareEpsilon &&
- std::abs(priority_score - other.priority_score) <
- kFloatCompareEpsilon;
- }
-};
-
-// Pretty-printing function for DatetimeParseResultSpan.
-logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const DatetimeParseResultSpan& value);
-
-struct ClassificationResult {
- std::string collection;
- float score;
- DatetimeParseResult datetime_parse_result;
- std::string serialized_knowledge_result;
- std::string contact_name, contact_given_name, contact_nickname,
- contact_email_address, contact_phone_number, contact_id;
- std::string app_name, app_package_name;
- int64 numeric_value;
-
- // Length of the parsed duration in milliseconds.
- int64 duration_ms;
-
- // Internal score used for conflict resolution.
- float priority_score;
-
-
- // Entity data information.
- std::string serialized_entity_data;
- const EntityData* entity_data() {
- return LoadAndVerifyFlatbuffer<EntityData>(serialized_entity_data.data(),
- serialized_entity_data.size());
- }
-
- explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {}
-
- ClassificationResult(const std::string& arg_collection, float arg_score)
- : collection(arg_collection),
- score(arg_score),
- priority_score(arg_score) {}
-
- ClassificationResult(const std::string& arg_collection, float arg_score,
- float arg_priority_score)
- : collection(arg_collection),
- score(arg_score),
- priority_score(arg_priority_score) {}
-};
-
-// Pretty-printing function for ClassificationResult.
-logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const ClassificationResult& result);
-
-// Pretty-printing function for std::vector<ClassificationResult>.
-logging::LoggingStringStream& operator<<(
- logging::LoggingStringStream& stream,
- const std::vector<ClassificationResult>& results);
-
-// Represents a result of Annotate call.
-struct AnnotatedSpan {
- enum class Source { OTHER, KNOWLEDGE, DURATION, DATETIME };
-
- // Unicode codepoint indices in the input string.
- CodepointSpan span = {kInvalidIndex, kInvalidIndex};
-
- // Classification result for the span.
- std::vector<ClassificationResult> classification;
-
- // The source of the annotation, used in conflict resolution.
- Source source = Source::OTHER;
-
- AnnotatedSpan() = default;
-
- AnnotatedSpan(CodepointSpan arg_span,
- std::vector<ClassificationResult> arg_classification)
- : span(arg_span), classification(std::move(arg_classification)) {}
-};
-
-// Pretty-printing function for AnnotatedSpan.
-logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const AnnotatedSpan& span);
-
-// StringPiece analogue for std::vector<T>.
-template <class T>
-class VectorSpan {
- public:
- VectorSpan() : begin_(), end_() {}
- VectorSpan(const std::vector<T>& v) // NOLINT(runtime/explicit)
- : begin_(v.begin()), end_(v.end()) {}
- VectorSpan(typename std::vector<T>::const_iterator begin,
- typename std::vector<T>::const_iterator end)
- : begin_(begin), end_(end) {}
-
- const T& operator[](typename std::vector<T>::size_type i) const {
- return *(begin_ + i);
- }
-
- int size() const { return end_ - begin_; }
- typename std::vector<T>::const_iterator begin() const { return begin_; }
- typename std::vector<T>::const_iterator end() const { return end_; }
- const float* data() const { return &(*begin_); }
-
- private:
- typename std::vector<T>::const_iterator begin_;
- typename std::vector<T>::const_iterator end_;
-};
-
-struct DateParseData {
- enum class Relation {
- UNSPECIFIED = 0,
- NEXT = 1,
- NEXT_OR_SAME = 2,
- LAST = 3,
- NOW = 4,
- TOMORROW = 5,
- YESTERDAY = 6,
- PAST = 7,
- FUTURE = 8
- };
-
- enum class RelationType {
- UNSPECIFIED = 0,
- SUNDAY = 1,
- MONDAY = 2,
- TUESDAY = 3,
- WEDNESDAY = 4,
- THURSDAY = 5,
- FRIDAY = 6,
- SATURDAY = 7,
- DAY = 8,
- WEEK = 9,
- MONTH = 10,
- YEAR = 11,
- HOUR = 12,
- MINUTE = 13,
- SECOND = 14,
- };
-
- enum Fields {
- YEAR_FIELD = 1 << 0,
- MONTH_FIELD = 1 << 1,
- DAY_FIELD = 1 << 2,
- HOUR_FIELD = 1 << 3,
- MINUTE_FIELD = 1 << 4,
- SECOND_FIELD = 1 << 5,
- AMPM_FIELD = 1 << 6,
- ZONE_OFFSET_FIELD = 1 << 7,
- DST_OFFSET_FIELD = 1 << 8,
- RELATION_FIELD = 1 << 9,
- RELATION_TYPE_FIELD = 1 << 10,
- RELATION_DISTANCE_FIELD = 1 << 11
- };
-
- enum class AMPM { AM = 0, PM = 1 };
-
- enum class TimeUnit {
- DAYS = 1,
- WEEKS = 2,
- MONTHS = 3,
- HOURS = 4,
- MINUTES = 5,
- SECONDS = 6,
- YEARS = 7
- };
-
- // Bit mask of fields which have been set on the struct
- int field_set_mask = 0;
-
- // Fields describing absolute date fields.
- // Year of the date seen in the text match.
- int year = 0;
- // Month of the year starting with January = 1.
- int month = 0;
- // Day of the month starting with 1.
- int day_of_month = 0;
- // Hour of the day with a range of 0-23,
- // values less than 12 need the AMPM field below or heuristics
- // to definitively determine the time.
- int hour = 0;
- // Hour of the day with a range of 0-59.
- int minute = 0;
- // Hour of the day with a range of 0-59.
- int second = 0;
- // 0 == AM, 1 == PM
- AMPM ampm = AMPM::AM;
- // Number of hours offset from UTC this date time is in.
- int zone_offset = 0;
- // Number of hours offest for DST
- int dst_offset = 0;
-
- // The permutation from now that was made to find the date time.
- Relation relation = Relation::UNSPECIFIED;
- // The unit of measure of the change to the date time.
- RelationType relation_type = RelationType::UNSPECIFIED;
- // The number of units of change that were made.
- int relation_distance = 0;
-
- DateParseData() = default;
-
- DateParseData(int field_set_mask, int year, int month, int day_of_month,
- int hour, int minute, int second, AMPM ampm, int zone_offset,
- int dst_offset, Relation relation, RelationType relation_type,
- int relation_distance) {
- this->field_set_mask = field_set_mask;
- this->year = year;
- this->month = month;
- this->day_of_month = day_of_month;
- this->hour = hour;
- this->minute = minute;
- this->second = second;
- this->ampm = ampm;
- this->zone_offset = zone_offset;
- this->dst_offset = dst_offset;
- this->relation = relation;
- this->relation_type = relation_type;
- this->relation_distance = relation_distance;
- }
-};
-
-// Pretty-printing function for DateParseData.
-logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const DateParseData& data);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
diff --git a/annotator/zlib-utils.cc b/annotator/zlib-utils.cc
deleted file mode 100644
index ec2392b..0000000
--- a/annotator/zlib-utils.cc
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/zlib-utils.h"
-
-#include <memory>
-
-#include "utils/base/logging.h"
-#include "utils/intents/zlib-utils.h"
-#include "utils/resources.h"
-#include "utils/zlib/zlib.h"
-
-namespace libtextclassifier3 {
-
-// Compress rule fields in the model.
-bool CompressModel(ModelT* model) {
- std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance();
- if (!zlib_compressor) {
- TC3_LOG(ERROR) << "Cannot compress model.";
- return false;
- }
-
- // Compress regex rules.
- if (model->regex_model != nullptr) {
- for (int i = 0; i < model->regex_model->patterns.size(); i++) {
- RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get();
- pattern->compressed_pattern.reset(new CompressedBufferT);
- zlib_compressor->Compress(pattern->pattern,
- pattern->compressed_pattern.get());
- pattern->pattern.clear();
- }
- }
-
- // Compress date-time rules.
- if (model->datetime_model != nullptr) {
- for (int i = 0; i < model->datetime_model->patterns.size(); i++) {
- DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get();
- for (int j = 0; j < pattern->regexes.size(); j++) {
- DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get();
- regex->compressed_pattern.reset(new CompressedBufferT);
- zlib_compressor->Compress(regex->pattern,
- regex->compressed_pattern.get());
- regex->pattern.clear();
- }
- }
- for (int i = 0; i < model->datetime_model->extractors.size(); i++) {
- DatetimeModelExtractorT* extractor =
- model->datetime_model->extractors[i].get();
- extractor->compressed_pattern.reset(new CompressedBufferT);
- zlib_compressor->Compress(extractor->pattern,
- extractor->compressed_pattern.get());
- extractor->pattern.clear();
- }
- }
-
- // Compress resources.
- if (model->resources != nullptr) {
- CompressResources(model->resources.get());
- }
-
- // Compress intent generator.
- if (model->intent_options != nullptr) {
- CompressIntentModel(model->intent_options.get());
- }
-
- return true;
-}
-
-bool DecompressModel(ModelT* model) {
- std::unique_ptr<ZlibDecompressor> zlib_decompressor =
- ZlibDecompressor::Instance();
- if (!zlib_decompressor) {
- TC3_LOG(ERROR) << "Cannot initialize decompressor.";
- return false;
- }
-
- // Decompress regex rules.
- if (model->regex_model != nullptr) {
- for (int i = 0; i < model->regex_model->patterns.size(); i++) {
- RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get();
- if (!zlib_decompressor->MaybeDecompress(pattern->compressed_pattern.get(),
- &pattern->pattern)) {
- TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
- return false;
- }
- pattern->compressed_pattern.reset(nullptr);
- }
- }
-
- // Decompress date-time rules.
- if (model->datetime_model != nullptr) {
- for (int i = 0; i < model->datetime_model->patterns.size(); i++) {
- DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get();
- for (int j = 0; j < pattern->regexes.size(); j++) {
- DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get();
- if (!zlib_decompressor->MaybeDecompress(regex->compressed_pattern.get(),
- &regex->pattern)) {
- TC3_LOG(ERROR) << "Cannot decompress pattern: " << i << " " << j;
- return false;
- }
- regex->compressed_pattern.reset(nullptr);
- }
- }
- for (int i = 0; i < model->datetime_model->extractors.size(); i++) {
- DatetimeModelExtractorT* extractor =
- model->datetime_model->extractors[i].get();
- if (!zlib_decompressor->MaybeDecompress(
- extractor->compressed_pattern.get(), &extractor->pattern)) {
- TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
- return false;
- }
- extractor->compressed_pattern.reset(nullptr);
- }
- }
- return true;
-}
-
-std::string CompressSerializedModel(const std::string& model) {
- std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str());
- TC3_CHECK(unpacked_model != nullptr);
- TC3_CHECK(CompressModel(unpacked_model.get()));
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
-} // namespace libtextclassifier3
diff --git a/annotator/zlib-utils_test.cc b/annotator/zlib-utils_test.cc
deleted file mode 100644
index 7a8d775..0000000
--- a/annotator/zlib-utils_test.cc
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/zlib-utils.h"
-
-#include <memory>
-
-#include "annotator/model_generated.h"
-#include "utils/zlib/zlib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-
-TEST(ZlibUtilsTest, CompressModel) {
- ModelT model;
- model.regex_model.reset(new RegexModelT);
- model.regex_model->patterns.emplace_back(new RegexModel_::PatternT);
- model.regex_model->patterns.back()->pattern = "this is a test pattern";
- model.regex_model->patterns.emplace_back(new RegexModel_::PatternT);
- model.regex_model->patterns.back()->pattern = "this is a second test pattern";
-
- model.datetime_model.reset(new DatetimeModelT);
- model.datetime_model->patterns.emplace_back(new DatetimeModelPatternT);
- model.datetime_model->patterns.back()->regexes.emplace_back(
- new DatetimeModelPattern_::RegexT);
- model.datetime_model->patterns.back()->regexes.back()->pattern =
- "an example datetime pattern";
- model.datetime_model->extractors.emplace_back(new DatetimeModelExtractorT);
- model.datetime_model->extractors.back()->pattern =
- "an example datetime extractor";
-
- // Compress the model.
- EXPECT_TRUE(CompressModel(&model));
-
- // Sanity check that uncompressed field is removed.
- EXPECT_TRUE(model.regex_model->patterns[0]->pattern.empty());
- EXPECT_TRUE(model.regex_model->patterns[1]->pattern.empty());
- EXPECT_TRUE(model.datetime_model->patterns[0]->regexes[0]->pattern.empty());
- EXPECT_TRUE(model.datetime_model->extractors[0]->pattern.empty());
-
- // Pack and load the model.
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, &model));
- const Model* compressed_model =
- GetModel(reinterpret_cast<const char*>(builder.GetBufferPointer()));
- ASSERT_TRUE(compressed_model != nullptr);
-
- // Decompress the fields again and check that they match the original.
- std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
- ASSERT_TRUE(decompressor != nullptr);
- std::string uncompressed_pattern;
- EXPECT_TRUE(decompressor->MaybeDecompress(
- compressed_model->regex_model()->patterns()->Get(0)->compressed_pattern(),
- &uncompressed_pattern));
- EXPECT_EQ(uncompressed_pattern, "this is a test pattern");
- EXPECT_TRUE(decompressor->MaybeDecompress(
- compressed_model->regex_model()->patterns()->Get(1)->compressed_pattern(),
- &uncompressed_pattern));
- EXPECT_EQ(uncompressed_pattern, "this is a second test pattern");
- EXPECT_TRUE(decompressor->MaybeDecompress(compressed_model->datetime_model()
- ->patterns()
- ->Get(0)
- ->regexes()
- ->Get(0)
- ->compressed_pattern(),
- &uncompressed_pattern));
- EXPECT_EQ(uncompressed_pattern, "an example datetime pattern");
- EXPECT_TRUE(decompressor->MaybeDecompress(compressed_model->datetime_model()
- ->extractors()
- ->Get(0)
- ->compressed_pattern(),
- &uncompressed_pattern));
- EXPECT_EQ(uncompressed_pattern, "an example datetime extractor");
-
- EXPECT_TRUE(DecompressModel(&model));
- EXPECT_EQ(model.regex_model->patterns[0]->pattern, "this is a test pattern");
- EXPECT_EQ(model.regex_model->patterns[1]->pattern,
- "this is a second test pattern");
- EXPECT_EQ(model.datetime_model->patterns[0]->regexes[0]->pattern,
- "an example datetime pattern");
- EXPECT_EQ(model.datetime_model->extractors[0]->pattern,
- "an example datetime extractor");
-}
-
-} // namespace libtextclassifier3
diff --git a/java/Android.bp b/java/Android.bp
new file mode 100644
index 0000000..26efacd
--- /dev/null
+++ b/java/Android.bp
@@ -0,0 +1,70 @@
+//
+// Copyright (C) 2019 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// A standalone TextClassifierService app for testing.
+android_app {
+ name: "TextClassifierService",
+ static_libs: ["TextClassifierServiceLib"],
+ jni_libs: ["libtextclassifier"],
+ sdk_version: "system_current",
+ min_sdk_version: "28",
+ certificate: "platform",
+ optimize: {
+ proguard_flags_files: ["proguard.flags"],
+ },
+ use_embedded_native_libs: true,
+}
+
+// A library that contains all java classes with the AndroidManifest.
+android_library {
+ name: "TextClassifierServiceLib",
+ static_libs: ["TextClassifierServiceLibNoManifest"],
+ sdk_version: "system_current",
+ min_sdk_version: "28",
+ manifest: "AndroidManifest.xml",
+}
+
+// Similar to TextClassifierServiceLib, but without the AndroidManifest.
+android_library {
+ name: "TextClassifierServiceLibNoManifest",
+ srcs: ["src/**/*.java"],
+ manifest: "LibNoManifest_AndroidManifest.xml",
+ static_libs: [
+ "androidx.core_core",
+ "libtextclassifier-java",
+ "androidx.annotation_annotation",
+ "guava",
+ "textclassifier-statsd",
+ "error_prone_annotations",
+ ],
+ sdk_version: "system_current",
+ min_sdk_version: "28",
+}
+
+java_library {
+ name: "textclassifier-statsd",
+ sdk_version: "system_current",
+ srcs: [
+ ":statslog-textclassifier-java-gen",
+ ],
+}
+
+genrule {
+ name: "statslog-textclassifier-java-gen",
+ tools: ["stats-log-api-gen"],
+ cmd: "$(location stats-log-api-gen) --java $(out) --module textclassifier --javaPackage com.android.textclassifier --javaClass TextClassifierStatsLog",
+ out: ["com/android/textclassifier/TextClassifierStatsLog.java"],
+}
diff --git a/java/AndroidManifest.xml b/java/AndroidManifest.xml
new file mode 100644
index 0000000..9f02689
--- /dev/null
+++ b/java/AndroidManifest.xml
@@ -0,0 +1,47 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-->
+
+<!--
+ This manifest file is for the standalone TCS used for testing.
+ The TCS is typically shipped as part of ExtServices and is configured
+ in ExtServices's manifest.
+-->
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.android.textclassifier"
+ android:versionCode="1"
+ android:versionName="1.0.0">
+
+ <uses-sdk android:minSdkVersion="29" android:targetSdkVersion="29"/>
+
+ <uses-permission android:name="android.permission.ACCESS_COARSE_LOCATION" />
+
+ <application android:label="@string/tcs_app_name"
+ android:icon="@drawable/tcs_app_icon"
+ android:extractNativeLibs="false">
+ <service
+ android:exported="true"
+ android:name=".DefaultTextClassifierService"
+ android:permission="android.permission.BIND_TEXTCLASSIFIER_SERVICE">
+ <intent-filter>
+ <action android:name="android.service.textclassifier.TextClassifierService"/>
+ </intent-filter>
+ </service>
+
+ </application>
+</manifest>
diff --git a/java/LibNoManifest_AndroidManifest.xml b/java/LibNoManifest_AndroidManifest.xml
new file mode 100644
index 0000000..184bbf0
--- /dev/null
+++ b/java/LibNoManifest_AndroidManifest.xml
@@ -0,0 +1,30 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-->
+
+<!--
+ This is for the AndroidManifest.xml for the TextClassifierServiceLibNoManifest library.
+ The user of this library should explicitly put the necessary components in their own
+ AndroidManifest.xml, see AndroidManifest.xml under the same folder.
+-->
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.android.textclassifier">
+
+ <uses-sdk android:minSdkVersion="29" android:targetSdkVersion="29"/>
+
+</manifest>
diff --git a/java/com/google/android/textclassifier/ActionsSuggestionsModel.java b/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
deleted file mode 100644
index 9132b1f..0000000
--- a/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
+++ /dev/null
@@ -1,265 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.android.textclassifier;
-
-import java.util.concurrent.atomic.AtomicBoolean;
-
-/**
- * Java wrapper for ActionsSuggestions native library interface. This library is used to suggest
- * actions and replies in a given conversation.
- *
- * @hide
- */
-public final class ActionsSuggestionsModel implements AutoCloseable {
- private final AtomicBoolean isClosed = new AtomicBoolean(false);
-
- static {
- System.loadLibrary("textclassifier");
- }
-
- private long actionsModelPtr;
-
- /**
- * Creates a new instance of Actions predictor, using the provided model image, given as a file
- * descriptor.
- */
- public ActionsSuggestionsModel(int fileDescriptor, byte[] serializedPreconditions) {
- actionsModelPtr = nativeNewActionsModel(fileDescriptor, serializedPreconditions);
- if (actionsModelPtr == 0L) {
- throw new IllegalArgumentException("Couldn't initialize actions model from file descriptor.");
- }
- }
-
- public ActionsSuggestionsModel(int fileDescriptor) {
- this(fileDescriptor, /* serializedPreconditions= */ null);
- }
-
- /**
- * Creates a new instance of Actions predictor, using the provided model image, given as a file
- * path.
- */
- public ActionsSuggestionsModel(String path, byte[] serializedPreconditions) {
- actionsModelPtr = nativeNewActionsModelFromPath(path, serializedPreconditions);
- if (actionsModelPtr == 0L) {
- throw new IllegalArgumentException("Couldn't initialize actions model from given file.");
- }
- }
-
- public ActionsSuggestionsModel(String path) {
- this(path, /* serializedPreconditions= */ null);
- }
-
- /** Suggests actions / replies to the given conversation. */
- public ActionSuggestion[] suggestActions(
- Conversation conversation, ActionSuggestionOptions options, AnnotatorModel annotator) {
- return nativeSuggestActions(
- actionsModelPtr,
- conversation,
- options,
- (annotator != null ? annotator.getNativeAnnotator() : 0),
- /* appContext= */ null,
- /* deviceLocales= */ null,
- /* generateAndroidIntents= */ false);
- }
-
- public ActionSuggestion[] suggestActionsWithIntents(
- Conversation conversation,
- ActionSuggestionOptions options,
- Object appContext,
- String deviceLocales,
- AnnotatorModel annotator) {
- return nativeSuggestActions(
- actionsModelPtr,
- conversation,
- options,
- (annotator != null ? annotator.getNativeAnnotator() : 0),
- appContext,
- deviceLocales,
- /* generateAndroidIntents= */ true);
- }
-
- /** Frees up the allocated memory. */
- @Override
- public void close() {
- if (isClosed.compareAndSet(false, true)) {
- nativeCloseActionsModel(actionsModelPtr);
- actionsModelPtr = 0L;
- }
- }
-
- @Override
- protected void finalize() throws Throwable {
- try {
- close();
- } finally {
- super.finalize();
- }
- }
-
- /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
- public static String getLocales(int fd) {
- return nativeGetLocales(fd);
- }
-
- /** Returns the version of the model. */
- public static int getVersion(int fd) {
- return nativeGetVersion(fd);
- }
-
- /** Returns the name of the model. */
- public static String getName(int fd) {
- return nativeGetName(fd);
- }
-
- /** Action suggestion that contains a response text and the type of the response. */
- public static final class ActionSuggestion {
- private final String responseText;
- private final String actionType;
- private final float score;
- private final NamedVariant[] entityData;
- private final byte[] serializedEntityData;
- private final RemoteActionTemplate[] remoteActionTemplates;
-
- public ActionSuggestion(
- String responseText,
- String actionType,
- float score,
- NamedVariant[] entityData,
- byte[] serializedEntityData,
- RemoteActionTemplate[] remoteActionTemplates) {
- this.responseText = responseText;
- this.actionType = actionType;
- this.score = score;
- this.entityData = entityData;
- this.serializedEntityData = serializedEntityData;
- this.remoteActionTemplates = remoteActionTemplates;
- }
-
- public String getResponseText() {
- return responseText;
- }
-
- public String getActionType() {
- return actionType;
- }
-
- /** Confidence score between 0 and 1 */
- public float getScore() {
- return score;
- }
-
- public NamedVariant[] getEntityData() {
- return entityData;
- }
-
- public byte[] getSerializedEntityData() {
- return serializedEntityData;
- }
-
- public RemoteActionTemplate[] getRemoteActionTemplates() {
- return remoteActionTemplates;
- }
- }
-
- /** Represents a single message in the conversation. */
- public static final class ConversationMessage {
- private final int userId;
- private final String text;
- private final long referenceTimeMsUtc;
- private final String referenceTimezone;
- private final String detectedTextLanguageTags;
-
- public ConversationMessage(
- int userId,
- String text,
- long referenceTimeMsUtc,
- String referenceTimezone,
- String detectedTextLanguageTags) {
- this.userId = userId;
- this.text = text;
- this.referenceTimeMsUtc = referenceTimeMsUtc;
- this.referenceTimezone = referenceTimezone;
- this.detectedTextLanguageTags = detectedTextLanguageTags;
- }
-
- /** The identifier of the sender */
- public int getUserId() {
- return userId;
- }
-
- public String getText() {
- return text;
- }
-
- /**
- * Return the reference time of the message, for example, it could be compose time or send time.
- * {@code 0} means unspecified.
- */
- public long getReferenceTimeMsUtc() {
- return referenceTimeMsUtc;
- }
-
- public String getReferenceTimezone() {
- return referenceTimezone;
- }
-
- /** Returns a comma separated list of BCP 47 language tags. */
- public String getDetectedTextLanguageTags() {
- return detectedTextLanguageTags;
- }
- }
-
- /** Represents conversation between multiple users. */
- public static final class Conversation {
- public final ConversationMessage[] conversationMessages;
-
- public Conversation(ConversationMessage[] conversationMessages) {
- this.conversationMessages = conversationMessages;
- }
-
- public ConversationMessage[] getConversationMessages() {
- return conversationMessages;
- }
- }
-
- /** Represents options for the SuggestActions call. */
- public static final class ActionSuggestionOptions {
- public ActionSuggestionOptions() {}
- }
-
- private static native long nativeNewActionsModel(int fd, byte[] serializedPreconditions);
-
- private static native long nativeNewActionsModelFromPath(
- String path, byte[] preconditionsOverwrite);
-
- private static native String nativeGetLocales(int fd);
-
- private static native int nativeGetVersion(int fd);
-
- private static native String nativeGetName(int fd);
-
- private native ActionSuggestion[] nativeSuggestActions(
- long context,
- Conversation conversation,
- ActionSuggestionOptions options,
- long annotatorPtr,
- Object appContext,
- String deviceLocales,
- boolean generateAndroidIntents);
-
- private native void nativeCloseActionsModel(long ptr);
-}
diff --git a/java/com/google/android/textclassifier/AnnotatorModel.java b/java/com/google/android/textclassifier/AnnotatorModel.java
deleted file mode 100644
index 5f99f74..0000000
--- a/java/com/google/android/textclassifier/AnnotatorModel.java
+++ /dev/null
@@ -1,591 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.android.textclassifier;
-
-import java.util.Collection;
-import java.util.concurrent.atomic.AtomicBoolean;
-
-/**
- * Java wrapper for Annotator native library interface. This library is used for detecting entities
- * in text.
- *
- * @hide
- */
-public final class AnnotatorModel implements AutoCloseable {
- private final AtomicBoolean isClosed = new AtomicBoolean(false);
-
- static {
- System.loadLibrary("textclassifier");
- }
-
- // Keep these in sync with the constants defined in AOSP.
- static final String TYPE_UNKNOWN = "";
- static final String TYPE_OTHER = "other";
- static final String TYPE_EMAIL = "email";
- static final String TYPE_PHONE = "phone";
- static final String TYPE_ADDRESS = "address";
- static final String TYPE_URL = "url";
- static final String TYPE_DATE = "date";
- static final String TYPE_DATE_TIME = "datetime";
- static final String TYPE_FLIGHT_NUMBER = "flight";
-
- private long annotatorPtr;
-
- /** Enumeration for specifying the usecase of the annotations. */
- public static enum AnnotationUsecase {
- /** Results are optimized for Smart{Select,Share,Linkify}. */
- SMART(0),
-
- /**
- * Results are optimized for using TextClassifier as an infrastructure that annotates as much as
- * possible.
- */
- RAW(1);
-
- private final int value;
-
- AnnotationUsecase(int value) {
- this.value = value;
- }
-
- public int getValue() {
- return value;
- }
- };
-
- /**
- * Creates a new instance of SmartSelect predictor, using the provided model image, given as a
- * file descriptor.
- */
- public AnnotatorModel(int fileDescriptor) {
- annotatorPtr = nativeNewAnnotator(fileDescriptor);
- if (annotatorPtr == 0L) {
- throw new IllegalArgumentException("Couldn't initialize TC from file descriptor.");
- }
- }
-
- /**
- * Creates a new instance of SmartSelect predictor, using the provided model image, given as a
- * file path.
- */
- public AnnotatorModel(String path) {
- annotatorPtr = nativeNewAnnotatorFromPath(path);
- if (annotatorPtr == 0L) {
- throw new IllegalArgumentException("Couldn't initialize TC from given file.");
- }
- }
-
- /** Initializes the knowledge engine, passing the given serialized config to it. */
- public void initializeKnowledgeEngine(byte[] serializedConfig) {
- if (!nativeInitializeKnowledgeEngine(annotatorPtr, serializedConfig)) {
- throw new IllegalArgumentException("Couldn't initialize the KG engine");
- }
- }
-
- /** Initializes the contact engine, passing the given serialized config to it. */
- public void initializeContactEngine(byte[] serializedConfig) {
- if (!nativeInitializeContactEngine(annotatorPtr, serializedConfig)) {
- throw new IllegalArgumentException("Couldn't initialize the contact engine");
- }
- }
-
- /** Initializes the installed app engine, passing the given serialized config to it. */
- public void initializeInstalledAppEngine(byte[] serializedConfig) {
- if (!nativeInitializeInstalledAppEngine(annotatorPtr, serializedConfig)) {
- throw new IllegalArgumentException("Couldn't initialize the installed app engine");
- }
- }
-
- /**
- * Given a string context and current selection, computes the selection suggestion.
- *
- * <p>The begin and end are character indices into the context UTF8 string. selectionBegin is the
- * character index where the selection begins, and selectionEnd is the index of one character past
- * the selection span.
- *
- * <p>The return value is an array of two ints: suggested selection beginning and end, with the
- * same semantics as the input selectionBeginning and selectionEnd.
- */
- public int[] suggestSelection(
- String context, int selectionBegin, int selectionEnd, SelectionOptions options) {
- return nativeSuggestSelection(annotatorPtr, context, selectionBegin, selectionEnd, options);
- }
-
- /**
- * Given a string context and current selection, classifies the type of the selected text.
- *
- * <p>The begin and end params are character indices in the context string.
- *
- * <p>Returns an array of ClassificationResult objects with the probability scores for different
- * collections.
- */
- public ClassificationResult[] classifyText(
- String context, int selectionBegin, int selectionEnd, ClassificationOptions options) {
- return classifyText(
- context,
- selectionBegin,
- selectionEnd,
- options,
- /*appContext=*/ null,
- /*deviceLocales=*/ null);
- }
-
- public ClassificationResult[] classifyText(
- String context,
- int selectionBegin,
- int selectionEnd,
- ClassificationOptions options,
- Object appContext,
- String deviceLocales) {
- return nativeClassifyText(
- annotatorPtr, context, selectionBegin, selectionEnd, options, appContext, deviceLocales);
- }
-
- /**
- * Annotates given input text. The annotations should cover the whole input context except for
- * whitespaces, and are sorted by their position in the context string.
- */
- public AnnotatedSpan[] annotate(String text, AnnotationOptions options) {
- return nativeAnnotate(annotatorPtr, text, options);
- }
-
- /**
- * Looks up a knowledge entity by its identifier. Returns null if the entity is not found or on
- * error.
- */
- public byte[] lookUpKnowledgeEntity(String id) {
- return nativeLookUpKnowledgeEntity(annotatorPtr, id);
- }
-
- /** Frees up the allocated memory. */
- @Override
- public void close() {
- if (isClosed.compareAndSet(false, true)) {
- nativeCloseAnnotator(annotatorPtr);
- annotatorPtr = 0L;
- }
- }
-
- @Override
- protected void finalize() throws Throwable {
- try {
- close();
- } finally {
- super.finalize();
- }
- }
-
- /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
- public static String getLocales(int fd) {
- return nativeGetLocales(fd);
- }
-
- /** Returns the version of the model. */
- public static int getVersion(int fd) {
- return nativeGetVersion(fd);
- }
-
- /** Returns the name of the model. */
- public static String getName(int fd) {
- return nativeGetName(fd);
- }
-
- /** Information about a parsed time/date. */
- public static final class DatetimeResult {
-
- public static final int GRANULARITY_YEAR = 0;
- public static final int GRANULARITY_MONTH = 1;
- public static final int GRANULARITY_WEEK = 2;
- public static final int GRANULARITY_DAY = 3;
- public static final int GRANULARITY_HOUR = 4;
- public static final int GRANULARITY_MINUTE = 5;
- public static final int GRANULARITY_SECOND = 6;
-
- private final long timeMsUtc;
- private final int granularity;
-
- public DatetimeResult(long timeMsUtc, int granularity) {
- this.timeMsUtc = timeMsUtc;
- this.granularity = granularity;
- }
-
- public long getTimeMsUtc() {
- return timeMsUtc;
- }
-
- public int getGranularity() {
- return granularity;
- }
- }
-
- /** Classification result for classifyText method. */
- public static final class ClassificationResult {
- private final String collection;
- private final float score;
- private final DatetimeResult datetimeResult;
- private final byte[] serializedKnowledgeResult;
- private final String contactName;
- private final String contactGivenName;
- private final String contactNickname;
- private final String contactEmailAddress;
- private final String contactPhoneNumber;
- private final String contactId;
- private final String appName;
- private final String appPackageName;
- private final NamedVariant[] entityData;
- private final byte[] serializedEntityData;
- private final RemoteActionTemplate[] remoteActionTemplates;
- private final long durationMs;
- private final long numericValue;
-
- public ClassificationResult(
- String collection,
- float score,
- DatetimeResult datetimeResult,
- byte[] serializedKnowledgeResult,
- String contactName,
- String contactGivenName,
- String contactNickname,
- String contactEmailAddress,
- String contactPhoneNumber,
- String contactId,
- String appName,
- String appPackageName,
- NamedVariant[] entityData,
- byte[] serializedEntityData,
- RemoteActionTemplate[] remoteActionTemplates,
- long durationMs,
- long numericValue) {
- this.collection = collection;
- this.score = score;
- this.datetimeResult = datetimeResult;
- this.serializedKnowledgeResult = serializedKnowledgeResult;
- this.contactName = contactName;
- this.contactGivenName = contactGivenName;
- this.contactNickname = contactNickname;
- this.contactEmailAddress = contactEmailAddress;
- this.contactPhoneNumber = contactPhoneNumber;
- this.contactId = contactId;
- this.appName = appName;
- this.appPackageName = appPackageName;
- this.entityData = entityData;
- this.serializedEntityData = serializedEntityData;
- this.remoteActionTemplates = remoteActionTemplates;
- this.durationMs = durationMs;
- this.numericValue = numericValue;
- }
-
- /** Returns the classified entity type. */
- public String getCollection() {
- return collection;
- }
-
- /** Confidence score between 0 and 1. */
- public float getScore() {
- return score;
- }
-
- public DatetimeResult getDatetimeResult() {
- return datetimeResult;
- }
-
- public byte[] getSerializedKnowledgeResult() {
- return serializedKnowledgeResult;
- }
-
- public String getContactName() {
- return contactName;
- }
-
- public String getContactGivenName() {
- return contactGivenName;
- }
-
- public String getContactNickname() {
- return contactNickname;
- }
-
- public String getContactEmailAddress() {
- return contactEmailAddress;
- }
-
- public String getContactPhoneNumber() {
- return contactPhoneNumber;
- }
-
- public String getContactId() {
- return contactId;
- }
-
- public String getAppName() {
- return appName;
- }
-
- public String getAppPackageName() {
- return appPackageName;
- }
-
- public NamedVariant[] getEntityData() {
- return entityData;
- }
-
- public byte[] getSerializedEntityData() {
- return serializedEntityData;
- }
-
- public RemoteActionTemplate[] getRemoteActionTemplates() {
- return remoteActionTemplates;
- }
-
- public long getDurationMs() {
- return durationMs;
- }
-
- public long getNumericValue() {
- return numericValue;
- }
- }
-
- /** Represents a result of Annotate call. */
- public static final class AnnotatedSpan {
- private final int startIndex;
- private final int endIndex;
- private final ClassificationResult[] classification;
-
- AnnotatedSpan(int startIndex, int endIndex, ClassificationResult[] classification) {
- this.startIndex = startIndex;
- this.endIndex = endIndex;
- this.classification = classification;
- }
-
- public int getStartIndex() {
- return startIndex;
- }
-
- public int getEndIndex() {
- return endIndex;
- }
-
- public ClassificationResult[] getClassification() {
- return classification;
- }
- }
-
- /** Represents options for the suggestSelection call. */
- public static final class SelectionOptions {
- private final String locales;
- private final String detectedTextLanguageTags;
- private final int annotationUsecase;
-
- public SelectionOptions(
- String locales, String detectedTextLanguageTags, int annotationUsecase) {
- this.locales = locales;
- this.detectedTextLanguageTags = detectedTextLanguageTags;
- this.annotationUsecase = annotationUsecase;
- }
-
- public SelectionOptions(String locales, String detectedTextLanguageTags) {
- this(locales, detectedTextLanguageTags, AnnotationUsecase.SMART.getValue());
- }
-
- public String getLocales() {
- return locales;
- }
-
- /** Returns a comma separated list of BCP 47 language tags. */
- public String getDetectedTextLanguageTags() {
- return detectedTextLanguageTags;
- }
-
- public int getAnnotationUsecase() {
- return annotationUsecase;
- }
- }
-
- /** Represents options for the classifyText call. */
- public static final class ClassificationOptions {
- private final long referenceTimeMsUtc;
- private final String referenceTimezone;
- private final String locales;
- private final String detectedTextLanguageTags;
- private final int annotationUsecase;
-
- public ClassificationOptions(
- long referenceTimeMsUtc,
- String referenceTimezone,
- String locales,
- String detectedTextLanguageTags,
- int annotationUsecase) {
- this.referenceTimeMsUtc = referenceTimeMsUtc;
- this.referenceTimezone = referenceTimezone;
- this.locales = locales;
- this.detectedTextLanguageTags = detectedTextLanguageTags;
- this.annotationUsecase = annotationUsecase;
- }
-
- public ClassificationOptions(
- long referenceTimeMsUtc,
- String referenceTimezone,
- String locales,
- String detectedTextLanguageTags) {
- this(
- referenceTimeMsUtc,
- referenceTimezone,
- locales,
- detectedTextLanguageTags,
- AnnotationUsecase.SMART.getValue());
- }
-
- public long getReferenceTimeMsUtc() {
- return referenceTimeMsUtc;
- }
-
- public String getReferenceTimezone() {
- return referenceTimezone;
- }
-
- public String getLocale() {
- return locales;
- }
-
- /** Returns a comma separated list of BCP 47 language tags. */
- public String getDetectedTextLanguageTags() {
- return detectedTextLanguageTags;
- }
-
- public int getAnnotationUsecase() {
- return annotationUsecase;
- }
- }
-
- /** Represents options for the annotate call. */
- public static final class AnnotationOptions {
- private final long referenceTimeMsUtc;
- private final String referenceTimezone;
- private final String locales;
- private final String detectedTextLanguageTags;
- private final String[] entityTypes;
- private final int annotationUsecase;
- private final boolean isSerializedEntityDataEnabled;
-
- public AnnotationOptions(
- long referenceTimeMsUtc,
- String referenceTimezone,
- String locales,
- String detectedTextLanguageTags,
- Collection<String> entityTypes,
- int annotationUsecase,
- boolean isSerializedEntityDataEnabled) {
- this.referenceTimeMsUtc = referenceTimeMsUtc;
- this.referenceTimezone = referenceTimezone;
- this.locales = locales;
- this.detectedTextLanguageTags = detectedTextLanguageTags;
- this.entityTypes = entityTypes == null ? new String[0] : entityTypes.toArray(new String[0]);
- this.annotationUsecase = annotationUsecase;
- this.isSerializedEntityDataEnabled = isSerializedEntityDataEnabled;
- }
-
- public AnnotationOptions(
- long referenceTimeMsUtc,
- String referenceTimezone,
- String locales,
- String detectedTextLanguageTags) {
- this(
- referenceTimeMsUtc,
- referenceTimezone,
- locales,
- detectedTextLanguageTags,
- null,
- AnnotationUsecase.SMART.getValue(),
- /* isSerializedEntityDataEnabled */ false);
- }
-
- public long getReferenceTimeMsUtc() {
- return referenceTimeMsUtc;
- }
-
- public String getReferenceTimezone() {
- return referenceTimezone;
- }
-
- public String getLocale() {
- return locales;
- }
-
- /** Returns a comma separated list of BCP 47 language tags. */
- public String getDetectedTextLanguageTags() {
- return detectedTextLanguageTags;
- }
-
- public String[] getEntityTypes() {
- return entityTypes;
- }
-
- public int getAnnotationUsecase() {
- return annotationUsecase;
- }
-
- public boolean isSerializedEntityDataEnabled() {
- return isSerializedEntityDataEnabled;
- }
- }
-
- /**
- * Retrieves the pointer to the native object. Note: Need to keep the AnnotatorModel alive as long
- * as the pointer is used.
- */
- long getNativeAnnotator() {
- return nativeGetNativeModelPtr(annotatorPtr);
- }
-
- private static native long nativeNewAnnotator(int fd);
-
- private static native long nativeNewAnnotatorFromPath(String path);
-
- private static native String nativeGetLocales(int fd);
-
- private static native int nativeGetVersion(int fd);
-
- private static native String nativeGetName(int fd);
-
- private native long nativeGetNativeModelPtr(long context);
-
- private native boolean nativeInitializeKnowledgeEngine(long context, byte[] serializedConfig);
-
- private native boolean nativeInitializeContactEngine(long context, byte[] serializedConfig);
-
- private native boolean nativeInitializeInstalledAppEngine(long context, byte[] serializedConfig);
-
- private native int[] nativeSuggestSelection(
- long context, String text, int selectionBegin, int selectionEnd, SelectionOptions options);
-
- private native ClassificationResult[] nativeClassifyText(
- long context,
- String text,
- int selectionBegin,
- int selectionEnd,
- ClassificationOptions options,
- Object appContext,
- String deviceLocales);
-
- private native AnnotatedSpan[] nativeAnnotate(
- long context, String text, AnnotationOptions options);
-
- private native byte[] nativeLookUpKnowledgeEntity(long context, String id);
-
- private native void nativeCloseAnnotator(long context);
-}
diff --git a/java/com/google/android/textclassifier/LangIdModel.java b/java/com/google/android/textclassifier/LangIdModel.java
deleted file mode 100644
index d3e166f..0000000
--- a/java/com/google/android/textclassifier/LangIdModel.java
+++ /dev/null
@@ -1,119 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.android.textclassifier;
-
-import java.util.concurrent.atomic.AtomicBoolean;
-
-/**
- * Java wrapper for LangId native library interface. This class is used to detect languages in text.
- *
- * @hide
- */
-public final class LangIdModel implements AutoCloseable {
- private final AtomicBoolean isClosed = new AtomicBoolean(false);
-
- static {
- System.loadLibrary("textclassifier");
- }
-
- private long modelPtr;
-
- /** Creates a new instance of LangId predictor, using the provided model image. */
- public LangIdModel(int fd) {
- modelPtr = nativeNew(fd);
- if (modelPtr == 0L) {
- throw new IllegalArgumentException("Couldn't initialize LangId from given file descriptor.");
- }
- }
-
- /** Creates a new instance of LangId predictor, using the provided model image. */
- public LangIdModel(String modelPath) {
- modelPtr = nativeNewFromPath(modelPath);
- if (modelPtr == 0L) {
- throw new IllegalArgumentException("Couldn't initialize LangId from given file.");
- }
- }
-
- /** Detects the languages for given text. */
- public LanguageResult[] detectLanguages(String text) {
- return nativeDetectLanguages(modelPtr, text);
- }
-
- /** Frees up the allocated memory. */
- @Override
- public void close() {
- if (isClosed.compareAndSet(false, true)) {
- nativeClose(modelPtr);
- modelPtr = 0L;
- }
- }
-
- @Override
- protected void finalize() throws Throwable {
- try {
- close();
- } finally {
- super.finalize();
- }
- }
-
- /** Result for detectLanguages method. */
- public static final class LanguageResult {
- final String mLanguage;
- final float mScore;
-
- LanguageResult(String language, float score) {
- mLanguage = language;
- mScore = score;
- }
-
- public final String getLanguage() {
- return mLanguage;
- }
-
- public final float getScore() {
- return mScore;
- }
- }
-
- /** Returns the version of the LangId model used. */
- public int getVersion() {
- return nativeGetVersion(modelPtr);
- }
-
- public float getLangIdThreshold() {
- return nativeGetLangIdThreshold(modelPtr);
- }
-
- public static int getVersion(int fd) {
- return nativeGetVersionFromFd(fd);
- }
-
- private static native long nativeNew(int fd);
-
- private static native long nativeNewFromPath(String path);
-
- private native LanguageResult[] nativeDetectLanguages(long nativePtr, String text);
-
- private native void nativeClose(long nativePtr);
-
- private native int nativeGetVersion(long nativePtr);
-
- private static native int nativeGetVersionFromFd(int fd);
-
- private native float nativeGetLangIdThreshold(long nativePtr);
-}
diff --git a/java/com/google/android/textclassifier/NamedVariant.java b/java/com/google/android/textclassifier/NamedVariant.java
deleted file mode 100644
index d04bb11..0000000
--- a/java/com/google/android/textclassifier/NamedVariant.java
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.android.textclassifier;
-
-/**
- * Represents a union of different basic types.
- *
- * @hide
- */
-public final class NamedVariant {
- public static final int TYPE_EMPTY = 0;
- public static final int TYPE_INT = 1;
- public static final int TYPE_LONG = 2;
- public static final int TYPE_FLOAT = 3;
- public static final int TYPE_DOUBLE = 4;
- public static final int TYPE_BOOL = 5;
- public static final int TYPE_STRING = 6;
-
- public NamedVariant(String name, int value) {
- this.name = name;
- this.intValue = value;
- this.type = TYPE_INT;
- }
-
- public NamedVariant(String name, long value) {
- this.name = name;
- this.longValue = value;
- this.type = TYPE_LONG;
- }
-
- public NamedVariant(String name, float value) {
- this.name = name;
- this.floatValue = value;
- this.type = TYPE_FLOAT;
- }
-
- public NamedVariant(String name, double value) {
- this.name = name;
- this.doubleValue = value;
- this.type = TYPE_DOUBLE;
- }
-
- public NamedVariant(String name, boolean value) {
- this.name = name;
- this.boolValue = value;
- this.type = TYPE_BOOL;
- }
-
- public NamedVariant(String name, String value) {
- this.name = name;
- this.stringValue = value;
- this.type = TYPE_STRING;
- }
-
- public String getName() {
- return name;
- }
-
- public int getType() {
- return type;
- }
-
- public int getInt() {
- assert (type == TYPE_INT);
- return intValue;
- }
-
- public long getLong() {
- assert (type == TYPE_LONG);
- return longValue;
- }
-
- public float getFloat() {
- assert (type == TYPE_FLOAT);
- return floatValue;
- }
-
- public double getDouble() {
- assert (type == TYPE_DOUBLE);
- return doubleValue;
- }
-
- public boolean getBool() {
- assert (type == TYPE_BOOL);
- return boolValue;
- }
-
- public String getString() {
- assert (type == TYPE_STRING);
- return stringValue;
- }
-
- private final String name;
- private final int type;
- private int intValue;
- private long longValue;
- private float floatValue;
- private double doubleValue;
- private boolean boolValue;
- private String stringValue;
-}
diff --git a/java/proguard.flags b/java/proguard.flags
new file mode 100644
index 0000000..fd6c544
--- /dev/null
+++ b/java/proguard.flags
@@ -0,0 +1,6 @@
+# Jni classes
+-keep class com.google.android.textclassifier.** { *; }
+
+# compileOnly dependency, optional in runtime.
+-dontwarn androidx.paging.PositionalDataSource
+-dontwarn androidx.lifecycle.LiveData
\ No newline at end of file
diff --git a/java/res/drawable/tcs_app_icon.xml b/java/res/drawable/tcs_app_icon.xml
new file mode 100644
index 0000000..8cce7ca
--- /dev/null
+++ b/java/res/drawable/tcs_app_icon.xml
@@ -0,0 +1,11 @@
+<?xml version="1.0" encoding="utf-8"?>
+<vector xmlns:android="http://schemas.android.com/apk/res/android"
+ android:width="24dp"
+ android:height="24dp"
+ android:viewportWidth="24"
+ android:viewportHeight="24">
+
+ <path
+ android:fillColor="#000000"
+ android:pathData="M2.5 4v3h5v12h3V7h5V4h-13zm19 5h-9v3h3v7h3v-7h3V9z" />
+</vector>
\ No newline at end of file
diff --git a/java/res/values-af/strings.xml b/java/res/values-af/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-af/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-am/strings.xml b/java/res/values-am/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-am/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-ar/strings.xml b/java/res/values-ar/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-ar/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-as/strings.xml b/java/res/values-as/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-as/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-az/strings.xml b/java/res/values-az/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-az/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-b+es+419/strings.xml b/java/res/values-b+es+419/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-b+es+419/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-b+sr+Latn/strings.xml b/java/res/values-b+sr+Latn/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-b+sr+Latn/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-be/strings.xml b/java/res/values-be/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-be/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-bg/strings.xml b/java/res/values-bg/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-bg/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-bn/strings.xml b/java/res/values-bn/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-bn/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-bs/strings.xml b/java/res/values-bs/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-bs/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-ca/strings.xml b/java/res/values-ca/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-ca/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-cs/strings.xml b/java/res/values-cs/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-cs/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-da/strings.xml b/java/res/values-da/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-da/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-de/strings.xml b/java/res/values-de/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-de/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-el/strings.xml b/java/res/values-el/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-el/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-en-rGB/strings.xml b/java/res/values-en-rGB/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-en-rGB/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-es/strings.xml b/java/res/values-es/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-es/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-et/strings.xml b/java/res/values-et/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-et/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-eu/strings.xml b/java/res/values-eu/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-eu/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-fa/strings.xml b/java/res/values-fa/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-fa/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-fi/strings.xml b/java/res/values-fi/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-fi/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-fr-rCA/strings.xml b/java/res/values-fr-rCA/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-fr-rCA/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-fr/strings.xml b/java/res/values-fr/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-fr/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-gl/strings.xml b/java/res/values-gl/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-gl/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-gu/strings.xml b/java/res/values-gu/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-gu/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-hi/strings.xml b/java/res/values-hi/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-hi/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-hr/strings.xml b/java/res/values-hr/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-hr/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-hu/strings.xml b/java/res/values-hu/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-hu/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-hy/strings.xml b/java/res/values-hy/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-hy/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-id/strings.xml b/java/res/values-id/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-id/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-is/strings.xml b/java/res/values-is/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-is/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-it/strings.xml b/java/res/values-it/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-it/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-iw/strings.xml b/java/res/values-iw/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-iw/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-ja/strings.xml b/java/res/values-ja/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-ja/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-ka/strings.xml b/java/res/values-ka/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-ka/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-kk/strings.xml b/java/res/values-kk/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-kk/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-km/strings.xml b/java/res/values-km/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-km/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-kn/strings.xml b/java/res/values-kn/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-kn/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-ko/strings.xml b/java/res/values-ko/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-ko/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-ky/strings.xml b/java/res/values-ky/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-ky/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-lo/strings.xml b/java/res/values-lo/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-lo/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-lt/strings.xml b/java/res/values-lt/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-lt/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-lv/strings.xml b/java/res/values-lv/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-lv/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-mk/strings.xml b/java/res/values-mk/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-mk/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-ml/strings.xml b/java/res/values-ml/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-ml/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-mn/strings.xml b/java/res/values-mn/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-mn/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-mr/strings.xml b/java/res/values-mr/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-mr/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-ms/strings.xml b/java/res/values-ms/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-ms/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-my/strings.xml b/java/res/values-my/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-my/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-ne/strings.xml b/java/res/values-ne/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-ne/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-nl/strings.xml b/java/res/values-nl/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-nl/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-no/strings.xml b/java/res/values-no/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-no/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-or/strings.xml b/java/res/values-or/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-or/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-pa/strings.xml b/java/res/values-pa/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-pa/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-pl/strings.xml b/java/res/values-pl/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-pl/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-pt-rBR/strings.xml b/java/res/values-pt-rBR/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-pt-rBR/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-pt-rPT/strings.xml b/java/res/values-pt-rPT/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-pt-rPT/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-ro/strings.xml b/java/res/values-ro/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-ro/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-ru/strings.xml b/java/res/values-ru/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-ru/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-si/strings.xml b/java/res/values-si/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-si/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-sk/strings.xml b/java/res/values-sk/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-sk/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-sl/strings.xml b/java/res/values-sl/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-sl/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-sq/strings.xml b/java/res/values-sq/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-sq/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-sr/strings.xml b/java/res/values-sr/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-sr/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-sv/strings.xml b/java/res/values-sv/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-sv/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-sw/strings.xml b/java/res/values-sw/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-sw/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-ta/strings.xml b/java/res/values-ta/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-ta/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-te/strings.xml b/java/res/values-te/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-te/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-th/strings.xml b/java/res/values-th/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-th/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-tl/strings.xml b/java/res/values-tl/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-tl/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-tr/strings.xml b/java/res/values-tr/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-tr/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-uk/strings.xml b/java/res/values-uk/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-uk/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-ur/strings.xml b/java/res/values-ur/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-ur/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-uz/strings.xml b/java/res/values-uz/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-uz/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-vi/strings.xml b/java/res/values-vi/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-vi/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-zh-rHK/strings.xml b/java/res/values-zh-rHK/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-zh-rHK/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-zh-rTW/strings.xml b/java/res/values-zh-rTW/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-zh-rTW/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-zh/strings.xml b/java/res/values-zh/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-zh/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values-zu/strings.xml b/java/res/values-zu/strings.xml
new file mode 100755
index 0000000..56d9f67
--- /dev/null
+++ b/java/res/values-zu/strings.xml
@@ -0,0 +1,3 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+</resources>
diff --git a/java/res/values/strings.xml b/java/res/values/strings.xml
new file mode 100644
index 0000000..f681d9d
--- /dev/null
+++ b/java/res/values/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <!-- Label for this app [CHAR LIMIT=30] -->
+ <string name="tcs_app_name" translatable="false">Text classifier</string>
+</resources>
diff --git a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
new file mode 100644
index 0000000..a51c95d
--- /dev/null
+++ b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
@@ -0,0 +1,244 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.app.Person;
+import android.app.RemoteAction;
+import android.content.ComponentName;
+import android.content.Context;
+import android.content.Intent;
+import android.text.TextUtils;
+import android.util.ArrayMap;
+import android.util.Pair;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.ConversationActions.Message;
+import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.intent.LabeledIntent;
+import com.android.textclassifier.common.intent.TemplateIntentFactory;
+import com.android.textclassifier.common.logging.ResultIdUtils;
+import com.google.android.textclassifier.ActionsSuggestionsModel;
+import com.google.android.textclassifier.RemoteActionTemplate;
+import com.google.common.base.Equivalence;
+import com.google.common.base.Equivalence.Wrapper;
+import com.google.common.base.Optional;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+
+/** Helper class for action suggestions. */
+final class ActionsSuggestionsHelper {
+ private static final String TAG = "ActionsSuggestions";
+ private static final int USER_LOCAL = 0;
+ private static final int FIRST_NON_LOCAL_USER = 1;
+
+ private ActionsSuggestionsHelper() {}
+
+ /**
+ * Converts the messages to a list of native messages object that the model can understand.
+ *
+ * <p>User id encoding - local user is represented as 0, Other users are numbered according to how
+ * far before they spoke last time in the conversation. For example, considering this
+ * conversation:
+ *
+ * <ul>
+ * <li>User A: xxx
+ * <li>Local user: yyy
+ * <li>User B: zzz
+ * </ul>
+ *
+ * User A will be encoded as 2, user B will be encoded as 1 and local user will be encoded as 0.
+ */
+ public static ActionsSuggestionsModel.ConversationMessage[] toNativeMessages(
+ List<ConversationActions.Message> messages,
+ Function<CharSequence, List<String>> languageDetector) {
+ List<ConversationActions.Message> messagesWithText =
+ messages.stream()
+ .filter(message -> !TextUtils.isEmpty(message.getText()))
+ .collect(Collectors.toCollection(ArrayList::new));
+ if (messagesWithText.isEmpty()) {
+ return new ActionsSuggestionsModel.ConversationMessage[0];
+ }
+ Deque<ActionsSuggestionsModel.ConversationMessage> nativeMessages = new ArrayDeque<>();
+ PersonEncoder personEncoder = new PersonEncoder();
+ int size = messagesWithText.size();
+ for (int i = size - 1; i >= 0; i--) {
+ ConversationActions.Message message = messagesWithText.get(i);
+ long referenceTime =
+ message.getReferenceTime() == null
+ ? 0
+ : message.getReferenceTime().toInstant().toEpochMilli();
+ String timeZone =
+ message.getReferenceTime() == null ? null : message.getReferenceTime().getZone().getId();
+ nativeMessages.push(
+ new ActionsSuggestionsModel.ConversationMessage(
+ personEncoder.encode(message.getAuthor()),
+ message.getText().toString(),
+ referenceTime,
+ timeZone,
+ String.join(",", languageDetector.apply(message.getText()))));
+ }
+ return nativeMessages.toArray(
+ new ActionsSuggestionsModel.ConversationMessage[nativeMessages.size()]);
+ }
+
+ /** Returns the result id for logging. */
+ public static String createResultId(
+ Context context,
+ List<ConversationActions.Message> messages,
+ Optional<ModelFile> actionsModel,
+ Optional<ModelFile> annotatorModel,
+ Optional<ModelFile> langIdModel) {
+ int hash =
+ Objects.hash(
+ messages.stream().mapToInt(ActionsSuggestionsHelper::hashMessage),
+ context.getPackageName(),
+ System.currentTimeMillis());
+ return ResultIdUtils.createId(
+ hash, ModelFile.toModelInfos(actionsModel, annotatorModel, langIdModel));
+ }
+
+ /** Generated labeled intent from an action suggestion and return the resolved result. */
+ @Nullable
+ public static LabeledIntent.Result createLabeledIntentResult(
+ Context context,
+ TemplateIntentFactory templateIntentFactory,
+ ActionsSuggestionsModel.ActionSuggestion nativeSuggestion) {
+ RemoteActionTemplate[] remoteActionTemplates = nativeSuggestion.getRemoteActionTemplates();
+ if (remoteActionTemplates == null) {
+ TcLog.w(
+ TAG, "createRemoteAction: Missing template for type " + nativeSuggestion.getActionType());
+ return null;
+ }
+ List<LabeledIntent> labeledIntents = templateIntentFactory.create(remoteActionTemplates);
+ if (labeledIntents.isEmpty()) {
+ return null;
+ }
+ // Given that we only support implicit intent here, we should expect there is just one
+ // intent for each action type.
+ LabeledIntent.TitleChooser titleChooser =
+ ActionsSuggestionsHelper.createTitleChooser(nativeSuggestion.getActionType());
+ return labeledIntents.get(0).resolve(context, titleChooser);
+ }
+
+ /** Returns a {@link LabeledIntent.TitleChooser} for conversation actions use case. */
+ @Nullable
+ public static LabeledIntent.TitleChooser createTitleChooser(String actionType) {
+ if (ConversationAction.TYPE_OPEN_URL.equals(actionType)) {
+ return (labeledIntent, resolveInfo) -> {
+ if (resolveInfo.handleAllWebDataURI) {
+ return labeledIntent.titleWithEntity;
+ }
+ if ("android".equals(resolveInfo.activityInfo.packageName)) {
+ return labeledIntent.titleWithEntity;
+ }
+ return labeledIntent.titleWithoutEntity;
+ };
+ }
+ return null;
+ }
+
+ /**
+ * Returns a list of {@link ConversationAction}s that have 0 duplicates. Two actions are
+ * duplicates if they may look the same to users. This function assumes every ConversationActions
+ * with a non-null RemoteAction also have a non-null intent in the extras.
+ */
+ public static List<ConversationAction> removeActionsWithDuplicates(
+ List<ConversationAction> conversationActions) {
+ // Ideally, we should compare title and icon here, but comparing icon is expensive and thus
+ // we use the component name of the target handler as the heuristic.
+ Map<Pair<String, String>, Integer> counter = new ArrayMap<>();
+ for (ConversationAction conversationAction : conversationActions) {
+ Pair<String, String> representation = getRepresentation(conversationAction);
+ if (representation == null) {
+ continue;
+ }
+ Integer existingCount = counter.getOrDefault(representation, 0);
+ counter.put(representation, existingCount + 1);
+ }
+ List<ConversationAction> result = new ArrayList<>();
+ for (ConversationAction conversationAction : conversationActions) {
+ Pair<String, String> representation = getRepresentation(conversationAction);
+ if (representation == null || counter.getOrDefault(representation, 0) == 1) {
+ result.add(conversationAction);
+ }
+ }
+ return result;
+ }
+
+ @Nullable
+ private static Pair<String, String> getRepresentation(ConversationAction conversationAction) {
+ RemoteAction remoteAction = conversationAction.getAction();
+ if (remoteAction == null) {
+ return null;
+ }
+ Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras());
+ ComponentName componentName = actionIntent.getComponent();
+ // Action without a component name will be considered as from the same app.
+ String packageName = componentName == null ? null : componentName.getPackageName();
+ return new Pair<>(conversationAction.getAction().getTitle().toString(), packageName);
+ }
+
+ private static final class PersonEncoder {
+ private static final Equivalence<Person> EQUIVALENCE = new PersonEquivalence();
+ private static final Equivalence.Wrapper<Person> PERSON_USER_SELF =
+ EQUIVALENCE.wrap(Message.PERSON_USER_SELF);
+
+ private final Map<Equivalence.Wrapper<Person>, Integer> personToUserIdMap = new ArrayMap<>();
+ private int nextUserId = FIRST_NON_LOCAL_USER;
+
+ private int encode(Person person) {
+ Wrapper<Person> personWrapper = EQUIVALENCE.wrap(person);
+ if (PERSON_USER_SELF.equals(personWrapper)) {
+ return USER_LOCAL;
+ }
+ Integer result = personToUserIdMap.get(personWrapper);
+ if (result == null) {
+ personToUserIdMap.put(personWrapper, nextUserId);
+ result = nextUserId;
+ nextUserId++;
+ }
+ return result;
+ }
+
+ private static final class PersonEquivalence extends Equivalence<Person> {
+
+ @Override
+ protected boolean doEquivalent(Person a, Person b) {
+ return Objects.equals(a.getKey(), b.getKey())
+ && TextUtils.equals(a.getName(), b.getName())
+ && Objects.equals(a.getUri(), b.getUri());
+ }
+
+ @Override
+ protected int doHash(Person person) {
+ return Objects.hash(person.getKey(), person.getName(), person.getUri());
+ }
+ }
+ }
+
+ private static int hashMessage(ConversationActions.Message message) {
+ return Objects.hash(message.getAuthor(), message.getText(), message.getReferenceTime());
+ }
+}
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
new file mode 100644
index 0000000..d2c1e38
--- /dev/null
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -0,0 +1,178 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.os.CancellationSignal;
+import android.service.textclassifier.TextClassifierService;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.SelectionEvent;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassificationSessionId;
+import android.view.textclassifier.TextClassifierEvent;
+import android.view.textclassifier.TextLanguage;
+import android.view.textclassifier.TextLinks;
+import android.view.textclassifier.TextSelection;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import java.io.FileDescriptor;
+import java.io.PrintWriter;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executors;
+
+/** An implementation of a TextClassifierService. */
+public final class DefaultTextClassifierService extends TextClassifierService {
+ private static final String TAG = "default_tcs";
+
+ // TODO: Figure out do we need more concurrency.
+ private final ListeningExecutorService normPriorityExecutor =
+ MoreExecutors.listeningDecorator(
+ Executors.newFixedThreadPool(
+ /* nThreads= */ 2,
+ new ThreadFactoryBuilder()
+ .setNameFormat("tcs-norm-prio-executor")
+ .setPriority(Thread.NORM_PRIORITY)
+ .build()));
+
+ private final ListeningExecutorService lowPriorityExecutor =
+ MoreExecutors.listeningDecorator(
+ Executors.newSingleThreadExecutor(
+ new ThreadFactoryBuilder()
+ .setNameFormat("tcs-low-prio-executor")
+ .setPriority(Thread.NORM_PRIORITY - 1)
+ .build()));
+
+ private TextClassifierImpl textClassifier;
+
+ @Override
+ public void onCreate() {
+ super.onCreate();
+ textClassifier = new TextClassifierImpl(this, new TextClassifierSettings());
+ }
+
+ @Override
+ public void onSuggestSelection(
+ TextClassificationSessionId sessionId,
+ TextSelection.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<TextSelection> callback) {
+ handleRequestAsync(
+ () -> textClassifier.suggestSelection(request), callback, cancellationSignal);
+ }
+
+ @Override
+ public void onClassifyText(
+ TextClassificationSessionId sessionId,
+ TextClassification.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<TextClassification> callback) {
+ handleRequestAsync(() -> textClassifier.classifyText(request), callback, cancellationSignal);
+ }
+
+ @Override
+ public void onGenerateLinks(
+ TextClassificationSessionId sessionId,
+ TextLinks.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<TextLinks> callback) {
+ handleRequestAsync(() -> textClassifier.generateLinks(request), callback, cancellationSignal);
+ }
+
+ @Override
+ public void onSuggestConversationActions(
+ TextClassificationSessionId sessionId,
+ ConversationActions.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<ConversationActions> callback) {
+ handleRequestAsync(
+ () -> textClassifier.suggestConversationActions(request), callback, cancellationSignal);
+ }
+
+ @Override
+ public void onDetectLanguage(
+ TextClassificationSessionId sessionId,
+ TextLanguage.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<TextLanguage> callback) {
+ handleRequestAsync(() -> textClassifier.detectLanguage(request), callback, cancellationSignal);
+ }
+
+ @Override
+ public void onSelectionEvent(TextClassificationSessionId sessionId, SelectionEvent event) {
+ handleEvent(() -> textClassifier.onSelectionEvent(event));
+ }
+
+ @Override
+ public void onTextClassifierEvent(
+ TextClassificationSessionId sessionId, TextClassifierEvent event) {
+ handleEvent(() -> textClassifier.onTextClassifierEvent(sessionId, event));
+ }
+
+ @Override
+ protected void dump(FileDescriptor fd, PrintWriter writer, String[] args) {
+ IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer);
+ textClassifier.dump(indentingPrintWriter);
+ indentingPrintWriter.flush();
+ }
+
+ private <T> void handleRequestAsync(
+ Callable<T> callable, Callback<T> callback, CancellationSignal cancellationSignal) {
+ ListenableFuture<T> result = normPriorityExecutor.submit(callable);
+ Futures.addCallback(
+ result,
+ new FutureCallback<T>() {
+ @Override
+ public void onSuccess(T result) {
+ callback.onSuccess(result);
+ }
+
+ @Override
+ public void onFailure(Throwable t) {
+ TcLog.e(TAG, "onFailure: ", t);
+ callback.onFailure(t.getMessage());
+ }
+ },
+ MoreExecutors.directExecutor());
+ cancellationSignal.setOnCancelListener(() -> result.cancel(/* mayInterruptIfRunning= */ true));
+ }
+
+ private void handleEvent(Runnable runnable) {
+ ListenableFuture<Void> result =
+ lowPriorityExecutor.submit(
+ () -> {
+ runnable.run();
+ return null;
+ });
+ Futures.addCallback(
+ result,
+ new FutureCallback<Void>() {
+ @Override
+ public void onSuccess(Void result) {}
+
+ @Override
+ public void onFailure(Throwable t) {
+ TcLog.e(TAG, "onFailure: ", t);
+ }
+ },
+ MoreExecutors.directExecutor());
+ }
+}
diff --git a/java/src/com/android/textclassifier/Entity.java b/java/src/com/android/textclassifier/Entity.java
new file mode 100644
index 0000000..6410a3e
--- /dev/null
+++ b/java/src/com/android/textclassifier/Entity.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import androidx.annotation.FloatRange;
+import com.google.common.base.Objects;
+import com.google.common.base.Preconditions;
+
+/** A representation of an identified entity with the confidence score */
+public final class Entity implements Comparable<Entity> {
+
+ private final String entityType;
+ private final float score;
+
+ public Entity(String entityType, float score) {
+ this.entityType = Preconditions.checkNotNull(entityType);
+ this.score = score;
+ }
+
+ public String getEntityType() {
+ return entityType;
+ }
+
+ /**
+ * Returns the confidence score of the entity, which ranged from 0.0 (low confidence) to 1.0 (high
+ * confidence).
+ */
+ @FloatRange(from = 0.0, to = 1.0)
+ public Float getScore() {
+ return score;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(entityType, score);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ Entity entity = (Entity) o;
+ return Float.compare(entity.score, score) == 0
+ && java.util.Objects.equals(entityType, entity.entityType);
+ }
+
+ @Override
+ public String toString() {
+ return "Entity{" + entityType + ": " + score + "}";
+ }
+
+ @Override
+ public int compareTo(Entity entity) {
+ // This method is implemented for sorting Entity. Sort the entities by the confidence score
+ // in descending order firstly. If the scores are the same, then sort them by the entity
+ // type in ascending order.
+ int result = Float.compare(entity.getScore(), score);
+ if (result == 0) {
+ return entityType.compareTo(entity.getEntityType());
+ }
+ return result;
+ }
+}
diff --git a/java/src/com/android/textclassifier/EntityConfidence.java b/java/src/com/android/textclassifier/EntityConfidence.java
new file mode 100644
index 0000000..ef8ff05
--- /dev/null
+++ b/java/src/com/android/textclassifier/EntityConfidence.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import androidx.annotation.FloatRange;
+import androidx.collection.ArrayMap;
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+/** Helper object for setting and getting entity scores for classified text. */
+final class EntityConfidence {
+
+ static final EntityConfidence EMPTY = new EntityConfidence(Collections.emptyMap());
+
+ private final ArrayMap<String, Float> entityConfidence = new ArrayMap<>();
+ private final ArrayList<String> sortedEntities = new ArrayList<>();
+
+ /**
+ * Constructs an EntityConfidence from a map of entity to confidence.
+ *
+ * <p>Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1.
+ *
+ * @param source a map from entity to a confidence value in the range 0 (low confidence) to 1
+ * (high confidence).
+ */
+ EntityConfidence(Map<String, Float> source) {
+ Preconditions.checkNotNull(source);
+
+ // Prune non-existent entities and clamp to 1.
+ entityConfidence.ensureCapacity(source.size());
+ for (Map.Entry<String, Float> it : source.entrySet()) {
+ if (it.getValue() <= 0) {
+ continue;
+ }
+ entityConfidence.put(it.getKey(), Math.min(1, it.getValue()));
+ }
+ resetSortedEntitiesFromMap();
+ }
+
+ /**
+ * Returns an immutable list of entities found in the classified text ordered from high confidence
+ * to low confidence.
+ */
+ public List<String> getEntities() {
+ return Collections.unmodifiableList(sortedEntities);
+ }
+
+ /**
+ * Returns the confidence score for the specified entity. The value ranges from 0 (low confidence)
+ * to 1 (high confidence). 0 indicates that the entity was not found for the classified text.
+ */
+ @FloatRange(from = 0.0, to = 1.0)
+ public float getConfidenceScore(String entity) {
+ return entityConfidence.getOrDefault(entity, 0f);
+ }
+
+ @Override
+ public String toString() {
+ return entityConfidence.toString();
+ }
+
+ private void resetSortedEntitiesFromMap() {
+ sortedEntities.clear();
+ sortedEntities.ensureCapacity(entityConfidence.size());
+ sortedEntities.addAll(entityConfidence.keySet());
+ sortedEntities.sort(
+ (e1, e2) -> {
+ float score1 = entityConfidence.get(e1);
+ float score2 = entityConfidence.get(e2);
+ return Float.compare(score2, score1);
+ });
+ }
+}
diff --git a/java/src/com/android/textclassifier/ExtrasUtils.java b/java/src/com/android/textclassifier/ExtrasUtils.java
new file mode 100644
index 0000000..fd64581
--- /dev/null
+++ b/java/src/com/android/textclassifier/ExtrasUtils.java
@@ -0,0 +1,308 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.app.RemoteAction;
+import android.content.Intent;
+import android.os.Bundle;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextLinks;
+import androidx.core.util.Pair;
+import com.google.android.textclassifier.AnnotatorModel;
+import com.google.common.annotations.VisibleForTesting;
+import java.util.ArrayList;
+import java.util.List;
+import javax.annotation.Nullable;
+
+/** Utility class for inserting and retrieving data in TextClassifier request/response extras. */
+// TODO: Make this a TestApi for CTS testing.
+public final class ExtrasUtils {
+
+  // Keys for response objects.
+  private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data";
+  private static final String ENTITIES_EXTRAS = "entities-extras";
+  private static final String ACTION_INTENT = "action-intent";
+  private static final String ACTIONS_INTENTS = "actions-intents";
+  private static final String FOREIGN_LANGUAGE = "foreign-language";
+  private static final String ENTITY_TYPE = "entity-type";
+  private static final String SCORE = "score";
+  private static final String MODEL_VERSION = "model-version";
+  private static final String MODEL_NAME = "model-name";
+  private static final String TEXT_LANGUAGES = "text-languages";
+  private static final String ENTITIES = "entities";
+
+  // Keys for request objects.
+  private static final String IS_SERIALIZED_ENTITY_DATA_ENABLED =
+      "is-serialized-entity-data-enabled";
+
+  // Non-instantiable utility class.
+  private ExtrasUtils() {}
+
+  /** Bundles and returns foreign language detection information for TextClassifier responses. */
+  static Bundle createForeignLanguageExtra(String language, float score, int modelVersion) {
+    final Bundle bundle = new Bundle();
+    bundle.putString(ENTITY_TYPE, language);
+    bundle.putFloat(SCORE, score);
+    bundle.putInt(MODEL_VERSION, modelVersion);
+    // Synthesized model name mirrors the LangId model version, e.g. "langId_v123".
+    bundle.putString(MODEL_NAME, "langId_v" + modelVersion);
+    return bundle;
+  }
+
+  /**
+   * Stores {@code extra} as foreign language information in TextClassifier response object's extras
+   * {@code container}.
+   *
+   * @see #getForeignLanguageExtra(TextClassification)
+   */
+  static void putForeignLanguageExtra(Bundle container, Bundle extra) {
+    container.putParcelable(FOREIGN_LANGUAGE, extra);
+  }
+
+  /**
+   * Returns foreign language detection information contained in the TextClassification object's
+   * extras, or {@code null} if absent.
+   *
+   * @see #putForeignLanguageExtra(Bundle, Bundle)
+   */
+  @Nullable
+  @VisibleForTesting
+  public static Bundle getForeignLanguageExtra(@Nullable TextClassification classification) {
+    if (classification == null) {
+      return null;
+    }
+    return classification.getExtras().getBundle(FOREIGN_LANGUAGE);
+  }
+
+  /** @see #getTopLanguage(Intent) */
+  static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) {
+    // Only the top three languages (by confidence) are persisted.
+    final int maxSize = Math.min(3, languageScores.getEntities().size());
+    final String[] languages =
+        languageScores.getEntities().subList(0, maxSize).toArray(new String[0]);
+    final float[] scores = new float[languages.length];
+    for (int i = 0; i < languages.length; i++) {
+      scores[i] = languageScores.getConfidenceScore(languages[i]);
+    }
+    // Stored as two parallel arrays under ENTITY_TYPE / SCORE.
+    container.putStringArray(ENTITY_TYPE, languages);
+    container.putFloatArray(SCORE, scores);
+  }
+
+  /** See {@link #putTopLanguageScores(Bundle, EntityConfidence)}. */
+  @Nullable
+  static Pair<String, Float> getTopLanguage(@Nullable Intent intent) {
+    if (intent == null) {
+      return null;
+    }
+    final Bundle tcBundle = intent.getBundleExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER);
+    if (tcBundle == null) {
+      return null;
+    }
+    final Bundle textLanguagesExtra = tcBundle.getBundle(TEXT_LANGUAGES);
+    if (textLanguagesExtra == null) {
+      return null;
+    }
+    final String[] languages = textLanguagesExtra.getStringArray(ENTITY_TYPE);
+    final float[] scores = textLanguagesExtra.getFloatArray(SCORE);
+    // The two arrays are parallel; bail out on any malformed/mismatched payload.
+    if (languages == null
+        || scores == null
+        || languages.length == 0
+        || languages.length != scores.length) {
+      return null;
+    }
+    // Linear scan for the highest-scoring language.
+    int highestScoringIndex = 0;
+    for (int i = 1; i < languages.length; i++) {
+      if (scores[highestScoringIndex] < scores[i]) {
+        highestScoringIndex = i;
+      }
+    }
+    return Pair.create(languages[highestScoringIndex], scores[highestScoringIndex]);
+  }
+
+  /** Stores the text-languages bundle {@code extra} in the extras {@code container}. */
+  public static void putTextLanguagesExtra(Bundle container, Bundle extra) {
+    container.putBundle(TEXT_LANGUAGES, extra);
+  }
+
+  /**
+   * Stores {@code actionsIntents} information in TextClassifier response object's extras {@code
+   * container}.
+   */
+  static void putActionsIntents(Bundle container, ArrayList<Intent> actionsIntents) {
+    container.putParcelableArrayList(ACTIONS_INTENTS, actionsIntents);
+  }
+
+  /**
+   * Stores {@code actionIntent} information in TextClassifier response object's extras {@code
+   * container}.
+   */
+  public static void putActionIntent(Bundle container, @Nullable Intent actionIntent) {
+    container.putParcelable(ACTION_INTENT, actionIntent);
+  }
+
+  /** Returns {@code actionIntent} information contained in a TextClassifier response object. */
+  @Nullable
+  public static Intent getActionIntent(Bundle container) {
+    return container.getParcelable(ACTION_INTENT);
+  }
+
+  /**
+   * Stores serialized entity data information in TextClassifier response object's extras {@code
+   * container}.
+   */
+  public static void putSerializedEntityData(
+      Bundle container, @Nullable byte[] serializedEntityData) {
+    container.putByteArray(SERIALIZED_ENTITIES_DATA, serializedEntityData);
+  }
+
+  /** Returns serialized entity data information contained in a TextClassifier response object. */
+  @Nullable
+  public static byte[] getSerializedEntityData(Bundle container) {
+    return container.getByteArray(SERIALIZED_ENTITIES_DATA);
+  }
+
+  /**
+   * Stores {@code entities} information in TextClassifier response object's extras {@code
+   * container}.
+   *
+   * @see #getCopyText(Bundle)
+   */
+  public static void putEntitiesExtras(Bundle container, @Nullable Bundle entitiesExtras) {
+    container.putParcelable(ENTITIES_EXTRAS, entitiesExtras);
+  }
+
+  /**
+   * Returns the "text" string of the {@code entities} bundle in a TextClassifier response object,
+   * or {@code null} if no entities bundle was stored.
+   *
+   * @see #putEntitiesExtras(Bundle, Bundle)
+   */
+  @Nullable
+  public static String getCopyText(Bundle container) {
+    Bundle entitiesExtras = container.getParcelable(ENTITIES_EXTRAS);
+    if (entitiesExtras == null) {
+      return null;
+    }
+    return entitiesExtras.getString("text");
+  }
+
+  /** Returns {@code actionIntents} information contained in the TextClassification object. */
+  @Nullable
+  public static ArrayList<Intent> getActionsIntents(@Nullable TextClassification classification) {
+    if (classification == null) {
+      return null;
+    }
+    return classification.getExtras().getParcelableArrayList(ACTIONS_INTENTS);
+  }
+
+  /**
+   * Returns the first action found in the {@code classification} object with an intent action
+   * string, {@code intentAction}, or {@code null} if there is no match.
+   */
+  @Nullable
+  @VisibleForTesting
+  public static RemoteAction findAction(
+      @Nullable TextClassification classification, @Nullable String intentAction) {
+    if (classification == null || intentAction == null) {
+      return null;
+    }
+    final ArrayList<Intent> actionIntents = getActionsIntents(classification);
+    if (actionIntents != null) {
+      final int size = actionIntents.size();
+      for (int i = 0; i < size; i++) {
+        final Intent intent = actionIntents.get(i);
+        // NOTE(review): assumes actionIntents is index-aligned with classification.getActions();
+        // confirm producers keep the two lists the same length and order.
+        if (intent != null && intentAction.equals(intent.getAction())) {
+          return classification.getActions().get(i);
+        }
+      }
+    }
+    return null;
+  }
+
+  /** Returns the first "translate" action found in the {@code classification} object. */
+  @Nullable
+  @VisibleForTesting
+  public static RemoteAction findTranslateAction(@Nullable TextClassification classification) {
+    return findAction(classification, Intent.ACTION_TRANSLATE);
+  }
+
+  /** Returns the entity type contained in the {@code extra}, or {@code null} if absent. */
+  @Nullable
+  @VisibleForTesting
+  public static String getEntityType(@Nullable Bundle extra) {
+    if (extra == null) {
+      return null;
+    }
+    return extra.getString(ENTITY_TYPE);
+  }
+
+  /**
+   * Returns the score contained in the {@code extra}. Valid scores are within [0, 1]; -1 signals
+   * that no score was stored (or that {@code extra} is null).
+   */
+  @VisibleForTesting
+  public static float getScore(Bundle extra) {
+    final int defaultValue = -1;
+    if (extra == null) {
+      return defaultValue;
+    }
+    return extra.getFloat(SCORE, defaultValue);
+  }
+
+  /** Returns the model name contained in the {@code extra}, or {@code null} if absent. */
+  @Nullable
+  public static String getModelName(@Nullable Bundle extra) {
+    if (extra == null) {
+      return null;
+    }
+    return extra.getString(MODEL_NAME);
+  }
+
+  /** Stores the entities from {@link AnnotatorModel.ClassificationResult} in {@code container}. */
+  public static void putEntities(
+      Bundle container, @Nullable AnnotatorModel.ClassificationResult[] classifications) {
+    if (classifications == null || classifications.length == 0) {
+      return;
+    }
+    ArrayList<Bundle> entitiesBundle = new ArrayList<>();
+    for (AnnotatorModel.ClassificationResult classification : classifications) {
+      // Null array elements are silently skipped.
+      if (classification == null) {
+        continue;
+      }
+      Bundle entityBundle = new Bundle();
+      entityBundle.putString(ENTITY_TYPE, classification.getCollection());
+      entityBundle.putByteArray(SERIALIZED_ENTITIES_DATA, classification.getSerializedEntityData());
+      entitiesBundle.add(entityBundle);
+    }
+    if (!entitiesBundle.isEmpty()) {
+      container.putParcelableArrayList(ENTITIES, entitiesBundle);
+    }
+  }
+
+  /** Returns a list of entities contained in the {@code extra}, or {@code null} if absent. */
+  @Nullable
+  @VisibleForTesting
+  public static List<Bundle> getEntities(Bundle container) {
+    return container.getParcelableArrayList(ENTITIES);
+  }
+
+  /** Whether the annotator should populate serialized entity data into the result object. */
+  public static boolean isSerializedEntityDataEnabled(TextLinks.Request request) {
+    // Defaults to false when the key is absent.
+    return request.getExtras().getBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED);
+  }
+
+  /**
+   * To indicate whether the annotator should populate serialized entity data in the result object.
+   */
+  @VisibleForTesting
+  public static void putIsSerializedEntityDataEnabled(Bundle bundle, boolean isEnabled) {
+    bundle.putBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED, isEnabled);
+  }
+}
diff --git a/java/src/com/android/textclassifier/ModelFileManager.java b/java/src/com/android/textclassifier/ModelFileManager.java
new file mode 100644
index 0000000..a6f64d8
--- /dev/null
+++ b/java/src/com/android/textclassifier/ModelFileManager.java
@@ -0,0 +1,311 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.os.LocaleList;
+import android.os.ParcelFileDescriptor;
+import android.text.TextUtils;
+import androidx.annotation.GuardedBy;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.google.common.base.Optional;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableList;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+
+/** Manages model files that are listed by the model files supplier. */
+final class ModelFileManager {
+  private static final String TAG = "ModelFileManager";
+
+  private final Supplier<ImmutableList<ModelFile>> modelFileSupplier;
+
+  public ModelFileManager(Supplier<ImmutableList<ModelFile>> modelFileSupplier) {
+    this.modelFileSupplier = Preconditions.checkNotNull(modelFileSupplier);
+  }
+
+  /** Returns an immutable list of model files listed by the given model files supplier. */
+  public ImmutableList<ModelFile> listModelFiles() {
+    return modelFileSupplier.get();
+  }
+
+  /**
+   * Returns the best model file for the given localelist, {@code null} if nothing is found.
+   *
+   * @param localeList the required locales, use {@code null} if there is no preference.
+   */
+  @Nullable
+  public ModelFile findBestModelFile(@Nullable LocaleList localeList) {
+    // Fall back to the system default locales when none are requested.
+    final String languages =
+        localeList == null || localeList.isEmpty()
+            ? LocaleList.getDefault().toLanguageTags()
+            : localeList.toLanguageTags();
+    final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
+
+    // Pick the most preferred model among those supporting any requested language.
+    ModelFile bestModel = null;
+    for (ModelFile model : listModelFiles()) {
+      if (model.isAnyLanguageSupported(languageRangeList)) {
+        if (model.isPreferredTo(bestModel)) {
+          bestModel = model;
+        }
+      }
+    }
+    return bestModel;
+  }
+
+  /** Default implementation of the model file supplier. */
+  public static final class ModelFileSupplierImpl implements Supplier<ImmutableList<ModelFile>> {
+    private final File updatedModelFile;
+    private final File factoryModelDir;
+    private final Pattern modelFilenamePattern;
+    // Maps a model file descriptor (int fd) to the model's version.
+    private final Function<Integer, Integer> versionSupplier;
+    // Maps a model file descriptor (int fd) to its supported locales string.
+    private final Function<Integer, String> supportedLocalesSupplier;
+    private final Object lock = new Object();
+
+    // Factory models are immutable on disk, so they are scanned once and cached.
+    @GuardedBy("lock")
+    private ImmutableList<ModelFile> factoryModels;
+
+    public ModelFileSupplierImpl(
+        File factoryModelDir,
+        String factoryModelFileNameRegex,
+        File updatedModelFile,
+        Function<Integer, Integer> versionSupplier,
+        Function<Integer, String> supportedLocalesSupplier) {
+      this.updatedModelFile = Preconditions.checkNotNull(updatedModelFile);
+      this.factoryModelDir = Preconditions.checkNotNull(factoryModelDir);
+      modelFilenamePattern = Pattern.compile(Preconditions.checkNotNull(factoryModelFileNameRegex));
+      this.versionSupplier = Preconditions.checkNotNull(versionSupplier);
+      this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
+    }
+
+    @Override
+    public ImmutableList<ModelFile> get() {
+      final List<ModelFile> modelFiles = new ArrayList<>();
+      // The update model has the highest precedence.
+      if (updatedModelFile.exists()) {
+        final ModelFile updatedModel = createModelFile(updatedModelFile);
+        if (updatedModel != null) {
+          modelFiles.add(updatedModel);
+        }
+      }
+      // Factory models should never have overlapping locales, so the order doesn't matter.
+      synchronized (lock) {
+        if (factoryModels == null) {
+          factoryModels = getFactoryModels();
+        }
+        modelFiles.addAll(factoryModels);
+      }
+      return ImmutableList.copyOf(modelFiles);
+    }
+
+    /** Scans the factory model directory for files matching the configured filename pattern. */
+    private ImmutableList<ModelFile> getFactoryModels() {
+      List<ModelFile> factoryModelFiles = new ArrayList<>();
+      if (factoryModelDir.exists() && factoryModelDir.isDirectory()) {
+        final File[] files = factoryModelDir.listFiles();
+        // listFiles() returns null on an I/O error or if the directory disappeared after the
+        // exists() check above; guard against a NullPointerException in that case.
+        if (files != null) {
+          for (File file : files) {
+            final Matcher matcher = modelFilenamePattern.matcher(file.getName());
+            if (matcher.matches() && file.isFile()) {
+              final ModelFile model = createModelFile(file);
+              if (model != null) {
+                factoryModelFiles.add(model);
+              }
+            }
+          }
+        }
+      }
+      return ImmutableList.copyOf(factoryModelFiles);
+    }
+
+    /** Returns null if the path did not point to a compatible model. */
+    @Nullable
+    private ModelFile createModelFile(File file) {
+      if (!file.exists()) {
+        return null;
+      }
+      ParcelFileDescriptor modelFd = null;
+      try {
+        modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+        if (modelFd == null) {
+          return null;
+        }
+        // Version and locales are read out of the model file via the injected suppliers.
+        final int modelFdInt = modelFd.getFd();
+        final int version = versionSupplier.apply(modelFdInt);
+        final String supportedLocalesStr = supportedLocalesSupplier.apply(modelFdInt);
+        if (supportedLocalesStr.isEmpty()) {
+          // A model with no locales is unusable; skip it.
+          TcLog.d(TAG, "Ignoring " + file.getAbsolutePath());
+          return null;
+        }
+        final List<Locale> supportedLocales = new ArrayList<>();
+        for (String langTag : Splitter.on(',').split(supportedLocalesStr)) {
+          supportedLocales.add(Locale.forLanguageTag(langTag));
+        }
+        return new ModelFile(
+            file,
+            version,
+            supportedLocales,
+            supportedLocalesStr,
+            ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr));
+      } catch (FileNotFoundException e) {
+        TcLog.e(TAG, "Failed to find " + file.getAbsolutePath(), e);
+        return null;
+      } finally {
+        maybeCloseAndLogError(modelFd);
+      }
+    }
+
+    /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
+    private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
+      if (fd == null) {
+        return;
+      }
+      try {
+        fd.close();
+      } catch (IOException e) {
+        TcLog.e(TAG, "Error closing file.", e);
+      }
+    }
+  }
+
+  /** Describes TextClassifier model files on disk. */
+  public static final class ModelFile {
+    // Sentinel supported-locales string meaning the model works for any language.
+    public static final String LANGUAGE_INDEPENDENT = "*";
+
+    private final File file;
+    private final int version;
+    private final List<Locale> supportedLocales;
+    private final String supportedLocalesStr;
+    private final boolean languageIndependent;
+
+    public ModelFile(
+        File file,
+        int version,
+        List<Locale> supportedLocales,
+        String supportedLocalesStr,
+        boolean languageIndependent) {
+      this.file = Preconditions.checkNotNull(file);
+      this.version = version;
+      this.supportedLocales = Preconditions.checkNotNull(supportedLocales);
+      this.supportedLocalesStr = Preconditions.checkNotNull(supportedLocalesStr);
+      this.languageIndependent = languageIndependent;
+    }
+
+    /** Returns the absolute path to the model file. */
+    public String getPath() {
+      return file.getAbsolutePath();
+    }
+
+    /** Returns a name to use for id generation, effectively the name of the model file. */
+    public String getName() {
+      return file.getName();
+    }
+
+    /** Returns the version tag in the model's metadata. */
+    public int getVersion() {
+      return version;
+    }
+
+    /** Returns whether the language supports any language in the given ranges. */
+    public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
+      Preconditions.checkNotNull(languageRanges);
+      return languageIndependent || Locale.lookup(languageRanges, supportedLocales) != null;
+    }
+
+    /** Returns an immutable lists of supported locales. */
+    public List<Locale> getSupportedLocales() {
+      return Collections.unmodifiableList(supportedLocales);
+    }
+
+    /** Returns the original supported locals string read from the model file. */
+    public String getSupportedLocalesStr() {
+      return supportedLocalesStr;
+    }
+
+    /** Returns if this model file is preferred to the given one. */
+    public boolean isPreferredTo(@Nullable ModelFile model) {
+      // A model is preferred to no model.
+      if (model == null) {
+        return true;
+      }
+
+      // A language-specific model is preferred to a language independent
+      // model.
+      if (!languageIndependent && model.languageIndependent) {
+        return true;
+      }
+      if (languageIndependent && !model.languageIndependent) {
+        return false;
+      }
+
+      // A higher-version model is preferred.
+      if (version > model.getVersion()) {
+        return true;
+      }
+      return false;
+    }
+
+    // equals/hashCode consider only the file path, matching equals() below.
+    @Override
+    public int hashCode() {
+      return Objects.hash(getPath());
+    }
+
+    @Override
+    public boolean equals(Object other) {
+      if (this == other) {
+        return true;
+      }
+      if (other instanceof ModelFile) {
+        final ModelFile otherModel = (ModelFile) other;
+        return TextUtils.equals(getPath(), otherModel.getPath());
+      }
+      return false;
+    }
+
+    public ModelInfo toModelInfo() {
+      return new ModelInfo(getVersion(), supportedLocalesStr);
+    }
+
+    @Override
+    public String toString() {
+      return String.format(
+          Locale.US,
+          "ModelFile { path=%s name=%s version=%d locales=%s }",
+          getPath(),
+          getName(),
+          version,
+          supportedLocalesStr);
+    }
+
+    /** Maps each optional model file to its optional {@link ModelInfo}, preserving order. */
+    public static ImmutableList<Optional<ModelInfo>> toModelInfos(
+        Optional<ModelFile>... modelFiles) {
+      return Arrays.stream(modelFiles)
+          .map(modelFile -> modelFile.transform(ModelFile::toModelInfo))
+          .collect(Collectors.collectingAndThen(Collectors.toList(), ImmutableList::copyOf));
+    }
+  }
+}
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
new file mode 100644
index 0000000..5c028ef
--- /dev/null
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -0,0 +1,834 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static java.util.stream.Collectors.toCollection;
+
+import android.app.PendingIntent;
+import android.app.RemoteAction;
+import android.content.Context;
+import android.content.Intent;
+import android.icu.util.ULocale;
+import android.os.Bundle;
+import android.os.LocaleList;
+import android.os.Looper;
+import android.os.ParcelFileDescriptor;
+import android.util.ArrayMap;
+import android.view.View.OnClickListener;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.SelectionEvent;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassificationSessionId;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextClassifierEvent;
+import android.view.textclassifier.TextLanguage;
+import android.view.textclassifier.TextLinks;
+import android.view.textclassifier.TextSelection;
+import androidx.annotation.GuardedBy;
+import androidx.annotation.WorkerThread;
+import androidx.core.util.Pair;
+import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.intent.LabeledIntent;
+import com.android.textclassifier.common.intent.TemplateIntentFactory;
+import com.android.textclassifier.common.logging.ResultIdUtils;
+import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.android.textclassifier.common.statsd.GenerateLinksLogger;
+import com.android.textclassifier.common.statsd.SelectionEventConverter;
+import com.android.textclassifier.common.statsd.TextClassificationSessionIdConverter;
+import com.android.textclassifier.common.statsd.TextClassifierEventConverter;
+import com.android.textclassifier.common.statsd.TextClassifierEventLogger;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.android.textclassifier.ActionsSuggestionsModel;
+import com.google.android.textclassifier.AnnotatorModel;
+import com.google.android.textclassifier.LangIdModel;
+import com.google.common.base.Optional;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.FluentIterable;
+import com.google.common.collect.ImmutableList;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.time.ZoneId;
+import java.time.ZonedDateTime;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import javax.annotation.Nullable;
+
+/**
+ * A text classifier that is running locally.
+ *
+ * <p>This class uses machine learning to recognize entities in text. Unless otherwise stated,
+ * methods of this class are blocking operations and should most likely not be called on the UI
+ * thread.
+ */
+final class TextClassifierImpl {
+
+ private static final String TAG = "TextClassifierImpl";
+
+ private static final File FACTORY_MODEL_DIR = new File("/etc/textclassifier/");
+ // Annotator
+ private static final String ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX =
+ "textclassifier\\.(.*)\\.model";
+ private static final File ANNOTATOR_UPDATED_MODEL_FILE =
+ new File("/data/misc/textclassifier/textclassifier.model");
+
+ // LangIdModel
+ private static final String LANG_ID_FACTORY_MODEL_FILENAME_REGEX = "lang_id.model";
+ private static final File UPDATED_LANG_ID_MODEL_FILE =
+ new File("/data/misc/textclassifier/lang_id.model");
+
+ // Actions
+ private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX =
+ "actions_suggestions\\.(.*)\\.model";
+ private static final File UPDATED_ACTIONS_MODEL =
+ new File("/data/misc/textclassifier/actions_suggestions.model");
+
+ private final Context context;
+ private final TextClassifier fallback;
+ private final GenerateLinksLogger generateLinksLogger;
+
+ private final Object lock = new Object();
+
+ @GuardedBy("lock")
+ private ModelFileManager.ModelFile annotatorModelInUse;
+
+ @GuardedBy("lock")
+ private AnnotatorModel annotatorImpl;
+
+ @GuardedBy("lock")
+ private ModelFileManager.ModelFile langIdModelInUse;
+
+ @GuardedBy("lock")
+ private LangIdModel langIdImpl;
+
+ @GuardedBy("lock")
+ private ModelFileManager.ModelFile actionModelInUse;
+
+ @GuardedBy("lock")
+ private ActionsSuggestionsModel actionsImpl;
+
+ private final TextClassifierEventLogger textClassifierEventLogger =
+ new TextClassifierEventLogger();
+
+ private final TextClassifierSettings settings;
+
+ private final ModelFileManager annotatorModelFileManager;
+ private final ModelFileManager langIdModelFileManager;
+ private final ModelFileManager actionsModelFileManager;
+ private final TemplateIntentFactory templateIntentFactory;
+
+  /**
+   * Creates a local text classifier.
+   *
+   * @param context application context used for intent generation and resource lookups
+   * @param settings runtime-tunable configuration (sample rates, max lengths, default entities)
+   * @param fallback classifier consulted when the local models fail or produce unusable results
+   */
+  TextClassifierImpl(Context context, TextClassifierSettings settings, TextClassifier fallback) {
+    this.context = Preconditions.checkNotNull(context);
+    this.fallback = Preconditions.checkNotNull(fallback);
+    this.settings = Preconditions.checkNotNull(settings);
+    generateLinksLogger = new GenerateLinksLogger(this.settings.getGenerateLinksLogSampleRate());
+    // Each manager lists the factory models in FACTORY_MODEL_DIR plus the updated model
+    // (if present) under /data/misc/textclassifier.
+    annotatorModelFileManager =
+        new ModelFileManager(
+            new ModelFileManager.ModelFileSupplierImpl(
+                FACTORY_MODEL_DIR,
+                ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX,
+                ANNOTATOR_UPDATED_MODEL_FILE,
+                AnnotatorModel::getVersion,
+                AnnotatorModel::getLocales));
+    langIdModelFileManager =
+        new ModelFileManager(
+            new ModelFileManager.ModelFileSupplierImpl(
+                FACTORY_MODEL_DIR,
+                LANG_ID_FACTORY_MODEL_FILENAME_REGEX,
+                UPDATED_LANG_ID_MODEL_FILE,
+                LangIdModel::getVersion,
+                // LangId models are not locale-specific, so report language independence.
+                fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT));
+    actionsModelFileManager =
+        new ModelFileManager(
+            new ModelFileManager.ModelFileSupplierImpl(
+                FACTORY_MODEL_DIR,
+                ACTIONS_FACTORY_MODEL_FILENAME_REGEX,
+                UPDATED_ACTIONS_MODEL,
+                ActionsSuggestionsModel::getVersion,
+                ActionsSuggestionsModel::getLocales));
+
+    templateIntentFactory = new TemplateIntentFactory();
+  }
+
+  /** Convenience constructor that uses {@link TextClassifier#NO_OP} as the fallback. */
+  TextClassifierImpl(Context context, TextClassifierSettings settings) {
+    this(context, settings, TextClassifier.NO_OP);
+  }
+
+  /**
+   * Suggests an expanded selection range for the text in {@code request}, falling back to the
+   * configured fallback classifier when the model fails or returns unusable indices.
+   *
+   * <p>Blocking: must not be called on the main thread (enforced by {@code checkMainThread()}).
+   */
+  @WorkerThread
+  TextSelection suggestSelection(TextSelection.Request request) {
+    Preconditions.checkNotNull(request);
+    checkMainThread();
+    try {
+      final int rangeLength = request.getEndIndex() - request.getStartIndex();
+      final String string = request.getText().toString();
+      // Skip the model entirely for empty text or over-long ranges (settings-controlled cap).
+      if (string.length() > 0 && rangeLength <= settings.getSuggestSelectionMaxRangeLength()) {
+        final String localesString = concatenateLocales(request.getDefaultLocales());
+        final Optional<LangIdModel> langIdModel = getLangIdImpl();
+        final String detectLanguageTags =
+            String.join(",", detectLanguageTags(langIdModel, request.getText()));
+        final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
+        final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
+        // Native call: returns the suggested [start, end) pair.
+        final int[] startEnd =
+            annotatorImpl.suggestSelection(
+                string,
+                request.getStartIndex(),
+                request.getEndIndex(),
+                new AnnotatorModel.SelectionOptions(localesString, detectLanguageTags));
+        final int start = startEnd[0];
+        final int end = startEnd[1];
+        // Sanity-check the indices: within bounds and enclosing the requested range.
+        if (start < end
+            && start >= 0
+            && end <= string.length()
+            && start <= request.getStartIndex()
+            && end >= request.getEndIndex()) {
+          final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
+          // Classify the suggested span so entity types/scores can be attached to the result.
+          final AnnotatorModel.ClassificationResult[] results =
+              annotatorImpl.classifyText(
+                  string,
+                  start,
+                  end,
+                  new AnnotatorModel.ClassificationOptions(
+                      refTime.toInstant().toEpochMilli(),
+                      refTime.getZone().getId(),
+                      localesString,
+                      detectLanguageTags),
+                  // Passing null here to suppress intent generation
+                  // TODO: Use an explicit flag to suppress it.
+                  /* appContext */ null,
+                  /* deviceLocales */ null);
+          final int size = results.length;
+          for (int i = 0; i < size; i++) {
+            tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore());
+          }
+          // The id is derived from the original requested range, not the suggested one.
+          final String resultId =
+              createAnnotatorId(string, request.getStartIndex(), request.getEndIndex());
+          return tsBuilder.setId(resultId).build();
+        } else {
+          // We can not trust the result. Log the issue and ignore the result.
+          TcLog.d(TAG, "Got bad indices for input text. Ignoring result.");
+        }
+      }
+    } catch (Throwable t) {
+      // Avoid throwing from this method. Log the error.
+      TcLog.e(TAG, "Error suggesting selection for text. No changes to selection suggested.", t);
+    }
+    // Getting here means something went wrong, return a NO_OP result.
+    return fallback.suggestSelection(request);
+  }
+
+  /**
+   * Classifies the requested text span into entities (and, via the native model, actions),
+   * falling back to the configured fallback classifier on any failure or empty result.
+   *
+   * <p>Blocking: must not be called on the main thread (enforced by {@code checkMainThread()}).
+   */
+  @WorkerThread
+  TextClassification classifyText(TextClassification.Request request) {
+    Preconditions.checkNotNull(request);
+    checkMainThread();
+    try {
+      Optional<LangIdModel> langId = getLangIdImpl();
+      List<String> detectLanguageTags = detectLanguageTags(langId, request.getText());
+      final int rangeLength = request.getEndIndex() - request.getStartIndex();
+      final String string = request.getText().toString();
+      // Skip the model for empty text or over-long ranges (settings-controlled cap).
+      if (string.length() > 0 && rangeLength <= settings.getClassifyTextMaxRangeLength()) {
+        final String localesString = concatenateLocales(request.getDefaultLocales());
+        // Use the caller-supplied reference time when available (affects date/time entities).
+        final ZonedDateTime refTime =
+            request.getReferenceTime() != null
+                ? request.getReferenceTime()
+                : ZonedDateTime.now(ZoneId.systemDefault());
+        final AnnotatorModel.ClassificationResult[] results =
+            getAnnotatorImpl(request.getDefaultLocales())
+                .classifyText(
+                    string,
+                    request.getStartIndex(),
+                    request.getEndIndex(),
+                    new AnnotatorModel.ClassificationOptions(
+                        refTime.toInstant().toEpochMilli(),
+                        refTime.getZone().getId(),
+                        localesString,
+                        String.join(",", detectLanguageTags),
+                        AnnotatorModel.AnnotationUsecase.SMART.getValue(),
+                        LocaleList.getDefault().toLanguageTags()),
+                    context,
+                    getResourceLocalesString());
+        if (results.length > 0) {
+          return createClassificationResult(
+              results, string, request.getStartIndex(), request.getEndIndex(), langId);
+        }
+      }
+    } catch (Throwable t) {
+      // Avoid throwing from this method. Log the error.
+      TcLog.e(TAG, "Error getting text classification info.", t);
+    }
+    // Getting here means something went wrong, return a NO_OP result.
+    return fallback.classifyText(request);
+  }
+
+ @WorkerThread
+ TextLinks generateLinks(TextLinks.Request request) {
+ Preconditions.checkNotNull(request);
+ Preconditions.checkArgument(
+ request.getText().length() <= getMaxGenerateLinksTextLength(),
+ "text.length() cannot be greater than %s",
+ getMaxGenerateLinksTextLength());
+ checkMainThread();
+
+ final String textString = request.getText().toString();
+ final TextLinks.Builder builder = new TextLinks.Builder(textString);
+
+ try {
+ final long startTimeMs = System.currentTimeMillis();
+ final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
+ final Collection<String> entitiesToIdentify =
+ request.getEntityConfig() != null
+ ? request
+ .getEntityConfig()
+ .resolveEntityListModifications(
+ getEntitiesForHints(request.getEntityConfig().getHints()))
+ : settings.getEntityListDefault();
+ final String localesString = concatenateLocales(request.getDefaultLocales());
+ Optional<LangIdModel> langId = getLangIdImpl();
+ ImmutableList<String> detectLanguageTags = detectLanguageTags(langId, request.getText());
+ final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
+ final boolean isSerializedEntityDataEnabled =
+ ExtrasUtils.isSerializedEntityDataEnabled(request);
+ final AnnotatorModel.AnnotatedSpan[] annotations =
+ annotatorImpl.annotate(
+ textString,
+ new AnnotatorModel.AnnotationOptions(
+ refTime.toInstant().toEpochMilli(),
+ refTime.getZone().getId(),
+ localesString,
+ String.join(",", detectLanguageTags),
+ entitiesToIdentify,
+ AnnotatorModel.AnnotationUsecase.SMART.getValue(),
+ isSerializedEntityDataEnabled));
+ for (AnnotatorModel.AnnotatedSpan span : annotations) {
+ final AnnotatorModel.ClassificationResult[] results = span.getClassification();
+ if (results.length == 0 || !entitiesToIdentify.contains(results[0].getCollection())) {
+ continue;
+ }
+ final Map<String, Float> entityScores = new ArrayMap<>();
+ for (int i = 0; i < results.length; i++) {
+ entityScores.put(results[i].getCollection(), results[i].getScore());
+ }
+ Bundle extras = new Bundle();
+ if (isSerializedEntityDataEnabled) {
+ ExtrasUtils.putEntities(extras, results);
+ }
+ builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
+ }
+ final TextLinks links = builder.build();
+ final long endTimeMs = System.currentTimeMillis();
+ final String callingPackageName =
+ request.getCallingPackageName() == null
+ ? context.getPackageName() // local (in process) TC.
+ : request.getCallingPackageName();
+ Optional<ModelInfo> annotatorModelInfo;
+ Optional<ModelInfo> langIdModelInfo;
+ synchronized (lock) {
+ annotatorModelInfo =
+ Optional.fromNullable(annotatorModelInUse).transform(ModelFile::toModelInfo);
+ langIdModelInfo = Optional.fromNullable(langIdModelInUse).transform(ModelFile::toModelInfo);
+ }
+ generateLinksLogger.logGenerateLinks(
+ request.getText(),
+ links,
+ callingPackageName,
+ endTimeMs - startTimeMs,
+ annotatorModelInfo,
+ langIdModelInfo);
+ return links;
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ TcLog.e(TAG, "Error getting links info.", t);
+ }
+ return fallback.generateLinks(request);
+ }
+
+ int getMaxGenerateLinksTextLength() {
+ return settings.getGenerateLinksMaxTextLength();
+ }
+
+ private Collection<String> getEntitiesForHints(Collection<String> hints) {
+ final boolean editable = hints.contains(TextClassifier.HINT_TEXT_IS_EDITABLE);
+ final boolean notEditable = hints.contains(TextClassifier.HINT_TEXT_IS_NOT_EDITABLE);
+
+ // Use the default if there is no hint, or conflicting ones.
+ final boolean useDefault = editable == notEditable;
+ if (useDefault) {
+ return settings.getEntityListDefault();
+ } else if (editable) {
+ return settings.getEntityListEditable();
+ } else { // notEditable
+ return settings.getEntityListNotEditable();
+ }
+ }
+
+ void onSelectionEvent(SelectionEvent event) {
+ TextClassifierEvent textClassifierEvent = SelectionEventConverter.toTextClassifierEvent(event);
+ if (textClassifierEvent == null) {
+ return;
+ }
+ onTextClassifierEvent(event.getSessionId(), textClassifierEvent);
+ }
+
+ void onTextClassifierEvent(
+ @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) {
+ textClassifierEventLogger.writeEvent(
+ TextClassificationSessionIdConverter.fromPlatform(sessionId),
+ TextClassifierEventConverter.fromPlatform(event));
+ }
+
+ TextLanguage detectLanguage(TextLanguage.Request request) {
+ Preconditions.checkNotNull(request);
+ checkMainThread();
+ try {
+ final TextLanguage.Builder builder = new TextLanguage.Builder();
+ Optional<LangIdModel> langIdImpl = getLangIdImpl();
+ if (langIdImpl.isPresent()) {
+ final LangIdModel.LanguageResult[] langResults =
+ langIdImpl.get().detectLanguages(request.getText().toString());
+ for (int i = 0; i < langResults.length; i++) {
+ builder.putLocale(
+ ULocale.forLanguageTag(langResults[i].getLanguage()), langResults[i].getScore());
+ }
+ return builder.build();
+ }
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ TcLog.e(TAG, "Error detecting text language.", t);
+ }
+ return fallback.detectLanguage(request);
+ }
+
+ ConversationActions suggestConversationActions(ConversationActions.Request request) {
+ Preconditions.checkNotNull(request);
+ checkMainThread();
+ try {
+ ActionsSuggestionsModel actionsImpl = getActionsImpl();
+ if (actionsImpl == null) {
+ // Actions model is optional, fallback if it is not available.
+ return fallback.suggestConversationActions(request);
+ }
+ Optional<LangIdModel> langId = getLangIdImpl();
+ ActionsSuggestionsModel.ConversationMessage[] nativeMessages =
+ ActionsSuggestionsHelper.toNativeMessages(
+ request.getConversation(), text -> detectLanguageTags(langId, text));
+ if (nativeMessages.length == 0) {
+ return fallback.suggestConversationActions(request);
+ }
+ ActionsSuggestionsModel.Conversation nativeConversation =
+ new ActionsSuggestionsModel.Conversation(nativeMessages);
+
+ ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions =
+ actionsImpl.suggestActionsWithIntents(
+ nativeConversation,
+ null,
+ context,
+ getResourceLocalesString(),
+ getAnnotatorImpl(LocaleList.getDefault()));
+ return createConversationActionResult(request, nativeSuggestions);
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ TcLog.e(TAG, "Error suggesting conversation actions.", t);
+ }
+ return fallback.suggestConversationActions(request);
+ }
+
+ /**
+ * Returns the {@link ConversationAction} result, with a non-null extras.
+ *
+   * <p>Whenever the RemoteAction is non-null, you can expect that its corresponding intent,
+   * with a non-null component name, is present in the extras.
+ */
+ private ConversationActions createConversationActionResult(
+ ConversationActions.Request request,
+ ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions) {
+ Collection<String> expectedTypes = resolveActionTypesFromRequest(request);
+ List<ConversationAction> conversationActions = new ArrayList<>();
+ for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion : nativeSuggestions) {
+ String actionType = nativeSuggestion.getActionType();
+ if (!expectedTypes.contains(actionType)) {
+ continue;
+ }
+ LabeledIntent.Result labeledIntentResult =
+ ActionsSuggestionsHelper.createLabeledIntentResult(
+ context, templateIntentFactory, nativeSuggestion);
+ RemoteAction remoteAction = null;
+ Bundle extras = new Bundle();
+ if (labeledIntentResult != null) {
+ remoteAction = labeledIntentResult.remoteAction.toRemoteAction();
+ ExtrasUtils.putActionIntent(
+ extras, stripPackageInfoFromIntent(labeledIntentResult.resolvedIntent));
+ }
+ ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData());
+ ExtrasUtils.putEntitiesExtras(
+ extras, TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData()));
+ conversationActions.add(
+ new ConversationAction.Builder(actionType)
+ .setConfidenceScore(nativeSuggestion.getScore())
+ .setTextReply(nativeSuggestion.getResponseText())
+ .setAction(remoteAction)
+ .setExtras(extras)
+ .build());
+ }
+ conversationActions = ActionsSuggestionsHelper.removeActionsWithDuplicates(conversationActions);
+ if (request.getMaxSuggestions() >= 0
+ && conversationActions.size() > request.getMaxSuggestions()) {
+ conversationActions = conversationActions.subList(0, request.getMaxSuggestions());
+ }
+ synchronized (lock) {
+ String resultId =
+ ActionsSuggestionsHelper.createResultId(
+ context,
+ request.getConversation(),
+ Optional.fromNullable(actionModelInUse),
+ Optional.fromNullable(annotatorModelInUse),
+ Optional.fromNullable(langIdModelInUse));
+ return new ConversationActions(conversationActions, resultId);
+ }
+ }
+
+ private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) {
+ List<String> defaultActionTypes =
+ request.getHints().contains(ConversationActions.Request.HINT_FOR_NOTIFICATION)
+ ? settings.getNotificationConversationActionTypes()
+ : settings.getInAppConversationActionTypes();
+ return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes);
+ }
+
+ private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws FileNotFoundException {
+ synchronized (lock) {
+ localeList = localeList == null ? LocaleList.getDefault() : localeList;
+ final ModelFileManager.ModelFile bestModel =
+ annotatorModelFileManager.findBestModelFile(localeList);
+ if (bestModel == null) {
+ throw new FileNotFoundException("No annotator model for " + localeList.toLanguageTags());
+ }
+ if (annotatorImpl == null || !Objects.equals(annotatorModelInUse, bestModel)) {
+ TcLog.d(TAG, "Loading " + bestModel);
+ final ParcelFileDescriptor pfd =
+ ParcelFileDescriptor.open(
+ new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
+ try {
+ if (pfd != null) {
+            // The current annotator model may still be in use by another thread / model.
+            // Do not call close() here; let the GC clean it up when no one else
+            // is using it.
+ annotatorImpl = new AnnotatorModel(pfd.getFd());
+ Optional<LangIdModel> langIdModel = getLangIdImpl();
+ if (langIdModel.isPresent()) {
+ annotatorImpl.setLangIdModel(langIdModel.get());
+ }
+ annotatorModelInUse = bestModel;
+ }
+ } finally {
+ maybeCloseAndLogError(pfd);
+ }
+ }
+ return annotatorImpl;
+ }
+ }
+
+ private Optional<LangIdModel> getLangIdImpl() {
+ synchronized (lock) {
+ final ModelFileManager.ModelFile bestModel = langIdModelFileManager.findBestModelFile(null);
+ if (bestModel == null) {
+ return Optional.absent();
+ }
+ if (langIdImpl == null || !Objects.equals(langIdModelInUse, bestModel)) {
+ TcLog.d(TAG, "Loading " + bestModel);
+ final ParcelFileDescriptor pfd;
+ try {
+ pfd =
+ ParcelFileDescriptor.open(
+ new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
+ } catch (FileNotFoundException e) {
+ TcLog.e(TAG, "Failed to open the LangID model file", e);
+ return Optional.absent();
+ }
+ try {
+ if (pfd != null) {
+ langIdImpl = new LangIdModel(pfd.getFd());
+ langIdModelInUse = bestModel;
+ }
+ } finally {
+ maybeCloseAndLogError(pfd);
+ }
+ }
+ return Optional.of(langIdImpl);
+ }
+ }
+
+ @Nullable
+ private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException {
+ synchronized (lock) {
+ // TODO: Use LangID to determine the locale we should use here?
+ final ModelFileManager.ModelFile bestModel =
+ actionsModelFileManager.findBestModelFile(LocaleList.getDefault());
+ if (bestModel == null) {
+ return null;
+ }
+ if (actionsImpl == null || !Objects.equals(actionModelInUse, bestModel)) {
+ TcLog.d(TAG, "Loading " + bestModel);
+ final ParcelFileDescriptor pfd =
+ ParcelFileDescriptor.open(
+ new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
+ try {
+ if (pfd == null) {
+ TcLog.d(TAG, "Failed to read the model file: " + bestModel.getPath());
+ return null;
+ }
+ actionsImpl = new ActionsSuggestionsModel(pfd.getFd());
+ actionModelInUse = bestModel;
+ } finally {
+ maybeCloseAndLogError(pfd);
+ }
+ }
+ return actionsImpl;
+ }
+ }
+
+ private String createAnnotatorId(String text, int start, int end) {
+ synchronized (lock) {
+ return ResultIdUtils.createId(
+ context,
+ text,
+ start,
+ end,
+ ModelFile.toModelInfos(
+ Optional.fromNullable(annotatorModelInUse), Optional.fromNullable(langIdModelInUse)));
+ }
+ }
+
+ private static String concatenateLocales(@Nullable LocaleList locales) {
+ return (locales == null) ? "" : locales.toLanguageTags();
+ }
+
+ private TextClassification createClassificationResult(
+ AnnotatorModel.ClassificationResult[] classifications,
+ String text,
+ int start,
+ int end,
+ Optional<LangIdModel> langId) {
+ final String classifiedText = text.substring(start, end);
+ final TextClassification.Builder builder =
+ new TextClassification.Builder().setText(classifiedText);
+
+ final int typeCount = classifications.length;
+ AnnotatorModel.ClassificationResult highestScoringResult =
+ typeCount > 0 ? classifications[0] : null;
+ for (int i = 0; i < typeCount; i++) {
+ builder.setEntityType(classifications[i].getCollection(), classifications[i].getScore());
+ if (classifications[i].getScore() > highestScoringResult.getScore()) {
+ highestScoringResult = classifications[i];
+ }
+ }
+
+ boolean isPrimaryAction = true;
+ final ImmutableList<LabeledIntent> labeledIntents =
+ highestScoringResult == null
+ ? ImmutableList.of()
+ : templateIntentFactory.create(highestScoringResult.getRemoteActionTemplates());
+ final LabeledIntent.TitleChooser titleChooser =
+ (labeledIntent, resolveInfo) -> labeledIntent.titleWithoutEntity;
+
+ ArrayList<Intent> actionIntents = new ArrayList<>();
+ for (LabeledIntent labeledIntent : labeledIntents) {
+ final LabeledIntent.Result result = labeledIntent.resolve(context, titleChooser);
+ if (result == null) {
+ continue;
+ }
+
+ final Intent intent = result.resolvedIntent;
+ final RemoteAction action = result.remoteAction.toRemoteAction();
+ if (isPrimaryAction) {
+ // For O backwards compatibility, the first RemoteAction is also written to the
+ // legacy API fields.
+ builder.setIcon(action.getIcon().loadDrawable(context));
+ builder.setLabel(action.getTitle().toString());
+ builder.setIntent(intent);
+ builder.setOnClickListener(
+ createIntentOnClickListener(
+ createPendingIntent(context, intent, labeledIntent.requestCode)));
+ isPrimaryAction = false;
+ }
+ builder.addAction(action);
+ actionIntents.add(intent);
+ }
+ Bundle extras = new Bundle();
+ Optional<Bundle> foreignLanguageExtra =
+ langId
+ .transform(model -> maybeCreateExtrasForTranslate(actionIntents, model))
+ .or(Optional.<Bundle>absent());
+ if (foreignLanguageExtra.isPresent()) {
+ ExtrasUtils.putForeignLanguageExtra(extras, foreignLanguageExtra.get());
+ }
+ if (actionIntents.stream().anyMatch(Objects::nonNull)) {
+ ArrayList<Intent> strippedIntents =
+ actionIntents.stream()
+ .map(TextClassifierImpl::stripPackageInfoFromIntent)
+ .collect(toCollection(ArrayList::new));
+ ExtrasUtils.putActionsIntents(extras, strippedIntents);
+ }
+ ExtrasUtils.putEntities(extras, classifications);
+ builder.setExtras(extras);
+ String resultId = createAnnotatorId(text, start, end);
+ return builder.setId(resultId).build();
+ }
+
+ private static OnClickListener createIntentOnClickListener(final PendingIntent intent) {
+ Preconditions.checkNotNull(intent);
+ return v -> {
+ try {
+ intent.send();
+ } catch (PendingIntent.CanceledException e) {
+ TcLog.e(TAG, "Error sending PendingIntent", e);
+ }
+ };
+ }
+
+ private static Optional<Bundle> maybeCreateExtrasForTranslate(
+ List<Intent> intents, LangIdModel langId) {
+ Optional<Intent> translateIntent =
+ FluentIterable.from(intents)
+ .filter(Objects::nonNull)
+ .filter(intent -> Intent.ACTION_TRANSLATE.equals(intent.getAction()))
+ .first();
+ if (!translateIntent.isPresent()) {
+ return Optional.absent();
+ }
+ Pair<String, Float> topLanguageWithScore = ExtrasUtils.getTopLanguage(translateIntent.get());
+ if (topLanguageWithScore == null) {
+ return Optional.absent();
+ }
+ return Optional.of(
+ ExtrasUtils.createForeignLanguageExtra(
+ topLanguageWithScore.first, topLanguageWithScore.second, langId.getVersion()));
+ }
+
+ private ImmutableList<String> detectLanguageTags(
+ Optional<LangIdModel> langId, CharSequence text) {
+ return langId
+ .transform(
+ model -> {
+ float threshold = getLangIdThreshold(model);
+ EntityConfidence languagesConfidence = detectLanguages(model, text, threshold);
+ return ImmutableList.copyOf(languagesConfidence.getEntities());
+ })
+ .or(ImmutableList.of());
+ }
+
+ /**
+   * Detects languages for the specified text. Only returns languages whose score is higher than
+ * or equal to the specified threshold.
+ */
+ private static EntityConfidence detectLanguages(
+ LangIdModel langId, CharSequence text, float threshold) {
+ final LangIdModel.LanguageResult[] langResults = langId.detectLanguages(text.toString());
+ final Map<String, Float> languagesMap = new ArrayMap<>();
+ for (LangIdModel.LanguageResult langResult : langResults) {
+ if (langResult.getScore() >= threshold) {
+ languagesMap.put(langResult.getLanguage(), langResult.getScore());
+ }
+ }
+ return new EntityConfidence(languagesMap);
+ }
+
+ private float getLangIdThreshold(LangIdModel langId) {
+ return settings.getLangIdThresholdOverride() >= 0
+ ? settings.getLangIdThresholdOverride()
+ : langId.getLangIdThreshold();
+ }
+
+ void dump(IndentingPrintWriter printWriter) {
+ synchronized (lock) {
+ printWriter.println("TextClassifierImpl:");
+ printWriter.increaseIndent();
+ printWriter.println("Annotator model file(s):");
+ printWriter.increaseIndent();
+ for (ModelFileManager.ModelFile modelFile : annotatorModelFileManager.listModelFiles()) {
+ printWriter.println(modelFile.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.println("LangID model file(s):");
+ printWriter.increaseIndent();
+ for (ModelFileManager.ModelFile modelFile : langIdModelFileManager.listModelFiles()) {
+ printWriter.println(modelFile.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.println("Actions model file(s):");
+ printWriter.increaseIndent();
+ for (ModelFileManager.ModelFile modelFile : actionsModelFileManager.listModelFiles()) {
+ printWriter.println(modelFile.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.printPair("mFallback", fallback);
+ printWriter.decreaseIndent();
+ printWriter.println();
+ settings.dump(printWriter);
+ }
+ }
+
+ /** Returns the locales string for the current resources configuration. */
+ private String getResourceLocalesString() {
+ try {
+ return context.getResources().getConfiguration().getLocales().toLanguageTags();
+ } catch (NullPointerException e) {
+
+ // NPE is unexpected. Erring on the side of caution.
+ return LocaleList.getDefault().toLanguageTags();
+ }
+ }
+
+ /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
+ private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
+ if (fd == null) {
+ return;
+ }
+
+ try {
+ fd.close();
+ } catch (IOException e) {
+ TcLog.e(TAG, "Error closing file.", e);
+ }
+ }
+
+ private static void checkMainThread() {
+ if (Looper.myLooper() == Looper.getMainLooper()) {
+ TcLog.e(TAG, "TextClassifier called on main thread", new Exception());
+ }
+ }
+
+ private static PendingIntent createPendingIntent(
+ final Context context, final Intent intent, int requestCode) {
+ return PendingIntent.getActivity(
+ context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
+ }
+
+ @Nullable
+ private static Intent stripPackageInfoFromIntent(@Nullable Intent intent) {
+ if (intent == null) {
+ return null;
+ }
+ Intent strippedIntent = new Intent(intent);
+ strippedIntent.setPackage(null);
+ strippedIntent.setComponent(null);
+ return strippedIntent;
+ }
+}
diff --git a/java/src/com/android/textclassifier/TextClassifierSettings.java b/java/src/com/android/textclassifier/TextClassifierSettings.java
new file mode 100644
index 0000000..3decd38
--- /dev/null
+++ b/java/src/com/android/textclassifier/TextClassifierSettings.java
@@ -0,0 +1,324 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.provider.DeviceConfig;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.TextClassifier;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import javax.annotation.Nullable;
+
+/**
+ * TextClassifier specific settings.
+ *
+ * <p>Currently, this class does not guarantee co-diverted flags are updated atomically.
+ *
+ * <p>Example of setting the values for testing.
+ *
+ * <pre>
+ * adb shell cmd device_config put textclassifier system_textclassifier_enabled true
+ * </pre>
+ *
+ * @see android.provider.DeviceConfig#NAMESPACE_TEXTCLASSIFIER
+ */
+public final class TextClassifierSettings {
+ private static final String DELIMITER = ":";
+
+ /** Whether the user language profile feature is enabled. */
+ private static final String USER_LANGUAGE_PROFILE_ENABLED = "user_language_profile_enabled";
+ /** Max length of text that suggestSelection can accept. */
+ @VisibleForTesting
+ static final String SUGGEST_SELECTION_MAX_RANGE_LENGTH = "suggest_selection_max_range_length";
+ /** Max length of text that classifyText can accept. */
+ private static final String CLASSIFY_TEXT_MAX_RANGE_LENGTH = "classify_text_max_range_length";
+ /** Max length of text that generateLinks can accept. */
+ private static final String GENERATE_LINKS_MAX_TEXT_LENGTH = "generate_links_max_text_length";
+ /** Sampling rate for generateLinks logging. */
+ private static final String GENERATE_LINKS_LOG_SAMPLE_RATE = "generate_links_log_sample_rate";
+  /*
+   * Orphaned doc comment (the flag it described is not defined here): extra count added to some
+   * languages, e.g. system languages, when deducing the frequent languages in
+   * com.android.textclassifier.ulp.LanguageProfileAnalyzer#getFrequentLanguages(int).
+   */
+
+ /**
+ * A colon(:) separated string that specifies the default entities types for generateLinks when
+ * hint is not given.
+ */
+ @VisibleForTesting static final String ENTITY_LIST_DEFAULT = "entity_list_default";
+ /**
+ * A colon(:) separated string that specifies the default entities types for generateLinks when
+ * the text is in a not editable UI widget.
+ */
+ private static final String ENTITY_LIST_NOT_EDITABLE = "entity_list_not_editable";
+ /**
+ * A colon(:) separated string that specifies the default entities types for generateLinks when
+ * the text is in an editable UI widget.
+ */
+ private static final String ENTITY_LIST_EDITABLE = "entity_list_editable";
+ /**
+ * A colon(:) separated string that specifies the default action types for
+ * suggestConversationActions when the suggestions are used in an app.
+ */
+ private static final String IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT =
+ "in_app_conversation_action_types_default";
+ /**
+ * A colon(:) separated string that specifies the default action types for
+ * suggestConversationActions when the suggestions are used in a notification.
+ */
+ private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT =
+ "notification_conversation_action_types_default";
+ /** Threshold to accept a suggested language from LangID model. */
+ @VisibleForTesting static final String LANG_ID_THRESHOLD_OVERRIDE = "lang_id_threshold_override";
+ /** Whether to enable {@link com.android.textclassifier.intent.TemplateIntentFactory}. */
+ @VisibleForTesting
+ static final String TEMPLATE_INTENT_FACTORY_ENABLED = "template_intent_factory_enabled";
+ /** Whether to enable "translate" action in classifyText. */
+ private static final String TRANSLATE_IN_CLASSIFICATION_ENABLED =
+ "translate_in_classification_enabled";
+ /**
+ * Whether to detect the languages of the text in request by using langId for the native model.
+ */
+ private static final String DETECT_LANGUAGES_FROM_TEXT_ENABLED =
+ "detect_languages_from_text_enabled";
+ /**
+ * A colon(:) separated string that specifies the configuration to use when including surrounding
+ * context text in language detection queries.
+ *
+ * <p>Format= minimumTextSize<int>:penalizeRatio<float>:textScoreRatio<float>
+ *
+ * <p>e.g. 20:1.0:0.4
+ *
+ * <p>Accept all text lengths with minimumTextSize=0
+ *
+ * <p>Reject all text less than minimumTextSize with penalizeRatio=0
+ *
+ * @see {@code TextClassifierImpl#detectLanguages(String, int, int)} for reference.
+ */
+ @VisibleForTesting static final String LANG_ID_CONTEXT_SETTINGS = "lang_id_context_settings";
+ /** Default threshold to translate the language of the context the user selects */
+ private static final String TRANSLATE_ACTION_THRESHOLD = "translate_action_threshold";
+
+ // Sync this with ConversationAction.TYPE_ADD_CONTACT;
+ public static final String TYPE_ADD_CONTACT = "add_contact";
+  // Sync this with ConversationAction.TYPE_COPY;
+ public static final String TYPE_COPY = "copy";
+
+ private static final int SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
+ private static final int CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
+ private static final int GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT = 100 * 1000;
+ private static final int GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT = 100;
+
+ private static final ImmutableList<String> ENTITY_LIST_DEFAULT_VALUE =
+ ImmutableList.of(
+ TextClassifier.TYPE_ADDRESS,
+ TextClassifier.TYPE_EMAIL,
+ TextClassifier.TYPE_PHONE,
+ TextClassifier.TYPE_URL,
+ TextClassifier.TYPE_DATE,
+ TextClassifier.TYPE_DATE_TIME,
+ TextClassifier.TYPE_FLIGHT_NUMBER);
+ private static final ImmutableList<String> CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES =
+ ImmutableList.of(
+ ConversationAction.TYPE_TEXT_REPLY,
+ ConversationAction.TYPE_CREATE_REMINDER,
+ ConversationAction.TYPE_CALL_PHONE,
+ ConversationAction.TYPE_OPEN_URL,
+ ConversationAction.TYPE_SEND_EMAIL,
+ ConversationAction.TYPE_SEND_SMS,
+ ConversationAction.TYPE_TRACK_FLIGHT,
+ ConversationAction.TYPE_VIEW_CALENDAR,
+ ConversationAction.TYPE_VIEW_MAP,
+ TYPE_ADD_CONTACT,
+ TYPE_COPY);
+ /**
+   * A value &lt; 0 means not set (use the value from the LangId model); 0 - 1 overrides the model.
+ *
+ * @see EntityConfidence
+ */
+ private static final float LANG_ID_THRESHOLD_OVERRIDE_DEFAULT = -1f;
+
+ private static final float TRANSLATE_ACTION_THRESHOLD_DEFAULT = 0.5f;
+
+ private static final boolean USER_LANGUAGE_PROFILE_ENABLED_DEFAULT = true;
+ private static final boolean TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT = true;
+ private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
+ private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
+ private static final float[] LANG_ID_CONTEXT_SETTINGS_DEFAULT = new float[] {20f, 1.0f, 0.4f};
+
+ public int getSuggestSelectionMaxRangeLength() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ SUGGEST_SELECTION_MAX_RANGE_LENGTH,
+ SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT);
+ }
+
+ public int getClassifyTextMaxRangeLength() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CLASSIFY_TEXT_MAX_RANGE_LENGTH,
+ CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT);
+ }
+
+ public int getGenerateLinksMaxTextLength() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ GENERATE_LINKS_MAX_TEXT_LENGTH,
+ GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT);
+ }
+
+ public int getGenerateLinksLogSampleRate() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ GENERATE_LINKS_LOG_SAMPLE_RATE,
+ GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT);
+ }
+
+ public List<String> getEntityListDefault() {
+ return getDeviceConfigStringList(ENTITY_LIST_DEFAULT, ENTITY_LIST_DEFAULT_VALUE);
+ }
+
+ public List<String> getEntityListNotEditable() {
+ return getDeviceConfigStringList(ENTITY_LIST_NOT_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
+ }
+
+ public List<String> getEntityListEditable() {
+ return getDeviceConfigStringList(ENTITY_LIST_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
+ }
+
+ public List<String> getInAppConversationActionTypes() {
+ return getDeviceConfigStringList(
+ IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
+ }
+
+ public List<String> getNotificationConversationActionTypes() {
+ return getDeviceConfigStringList(
+ NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
+ }
+
+ public float getLangIdThresholdOverride() {
+ return DeviceConfig.getFloat(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ LANG_ID_THRESHOLD_OVERRIDE,
+ LANG_ID_THRESHOLD_OVERRIDE_DEFAULT);
+ }
+
+ public float getTranslateActionThreshold() {
+ return DeviceConfig.getFloat(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ TRANSLATE_ACTION_THRESHOLD,
+ TRANSLATE_ACTION_THRESHOLD_DEFAULT);
+ }
+
+ public boolean isUserLanguageProfileEnabled() {
+ return DeviceConfig.getBoolean(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ USER_LANGUAGE_PROFILE_ENABLED,
+ USER_LANGUAGE_PROFILE_ENABLED_DEFAULT);
+ }
+
+ public boolean isTemplateIntentFactoryEnabled() {
+ return DeviceConfig.getBoolean(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ TEMPLATE_INTENT_FACTORY_ENABLED,
+ TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT);
+ }
+
+ public boolean isTranslateInClassificationEnabled() {
+ return DeviceConfig.getBoolean(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ TRANSLATE_IN_CLASSIFICATION_ENABLED,
+ TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT);
+ }
+
+ public boolean isDetectLanguagesFromTextEnabled() {
+ return DeviceConfig.getBoolean(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ DETECT_LANGUAGES_FROM_TEXT_ENABLED,
+ DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT);
+ }
+
+ public float[] getLangIdContextSettings() {
+ return getDeviceConfigFloatArray(LANG_ID_CONTEXT_SETTINGS, LANG_ID_CONTEXT_SETTINGS_DEFAULT);
+ }
+
+ void dump(IndentingPrintWriter pw) {
+ pw.println("TextClassifierSettings:");
+ pw.increaseIndent();
+ pw.printPair("classify_text_max_range_length", getClassifyTextMaxRangeLength());
+ pw.printPair("detect_language_from_text_enabled", isDetectLanguagesFromTextEnabled());
+ pw.printPair("entity_list_default", getEntityListDefault());
+ pw.printPair("entity_list_editable", getEntityListEditable());
+ pw.printPair("entity_list_not_editable", getEntityListNotEditable());
+ pw.printPair("generate_links_log_sample_rate", getGenerateLinksLogSampleRate());
+ pw.printPair("generate_links_max_text_length", getGenerateLinksMaxTextLength());
+ pw.printPair("in_app_conversation_action_types_default", getInAppConversationActionTypes());
+ pw.printPair("lang_id_context_settings", Arrays.toString(getLangIdContextSettings()));
+ pw.printPair("lang_id_threshold_override", getLangIdThresholdOverride());
+ pw.printPair("translate_action_threshold", getTranslateActionThreshold());
+ pw.printPair(
+ "notification_conversation_action_types_default", getNotificationConversationActionTypes());
+ pw.printPair("suggest_selection_max_range_length", getSuggestSelectionMaxRangeLength());
+ pw.printPair("user_language_profile_enabled", isUserLanguageProfileEnabled());
+ pw.printPair("template_intent_factory_enabled", isTemplateIntentFactoryEnabled());
+ pw.printPair("translate_in_classification_enabled", isTranslateInClassificationEnabled());
+ pw.decreaseIndent();
+ }
+
+ private static List<String> getDeviceConfigStringList(String key, List<String> defaultValue) {
+ return parse(
+ DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null), defaultValue);
+ }
+
+ private static float[] getDeviceConfigFloatArray(String key, float[] defaultValue) {
+ return parse(
+ DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null), defaultValue);
+ }
+
+ private static List<String> parse(@Nullable String listStr, List<String> defaultValue) {
+ if (listStr != null) {
+ return Collections.unmodifiableList(Arrays.asList(listStr.split(DELIMITER)));
+ }
+ return defaultValue;
+ }
+
+ private static float[] parse(@Nullable String arrayStr, float[] defaultValue) {
+ if (arrayStr != null) {
+ final List<String> split = Splitter.onPattern(DELIMITER).splitToList(arrayStr);
+ if (split.size() != defaultValue.length) {
+ return defaultValue;
+ }
+ final float[] result = new float[split.size()];
+ for (int i = 0; i < split.size(); i++) {
+ try {
+ result[i] = Float.parseFloat(split.get(i));
+ } catch (NumberFormatException e) {
+ return defaultValue;
+ }
+ }
+ return result;
+ } else {
+ return defaultValue;
+ }
+ }
+}
diff --git a/java/src/com/android/textclassifier/common/base/LocaleCompat.java b/java/src/com/android/textclassifier/common/base/LocaleCompat.java
new file mode 100644
index 0000000..baaaf67
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/base/LocaleCompat.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.base;
+
+import android.content.Context;
+import android.os.Build;
+import java.util.Locale;
+
+/** Helper for accessing locale related stuff that works across different platform versions. */
+public final class LocaleCompat {
+
+ private LocaleCompat() {}
+
+ /**
+ * Returns a well-formed IETF BCP 47 language tag representing this locale. In older platforms,
+ * only the ISO 639 language code will be returned.
+ *
+ * @see Locale#toLanguageTag()
+ */
+ public static String toLanguageTag(Locale locale) {
+ if (Build.VERSION.SDK_INT >= 24) {
+ return Api24Impl.toLanguageTag(locale);
+ }
+ return ApiBaseImpl.toLanguageTag(locale);
+ }
+
+ /** Returns the language tags in string for the current resources configuration. */
+ public static String getResourceLanguageTags(Context context) {
+ if (Build.VERSION.SDK_INT >= 24) {
+ return Api24Impl.getResourceLanguageTags(context);
+ } else if (Build.VERSION.SDK_INT >= 21) {
+ return Api21Impl.getResourceLanguageTags(context);
+ }
+ return ApiBaseImpl.getResourceLanguageTags(context);
+ }
+
+ private static class Api24Impl {
+ private Api24Impl() {}
+
+ static String toLanguageTag(Locale locale) {
+ return locale.toLanguageTag();
+ }
+
+ static String getResourceLanguageTags(Context context) {
+ return context.getResources().getConfiguration().getLocales().toLanguageTags();
+ }
+ }
+
+ private static class Api21Impl {
+ private Api21Impl() {}
+
+ static String getResourceLanguageTags(Context context) {
+ return context.getResources().getConfiguration().locale.toLanguageTag();
+ }
+ }
+
+ private static class ApiBaseImpl {
+ private ApiBaseImpl() {}
+
+ static String toLanguageTag(Locale locale) {
+ return locale.getLanguage();
+ }
+
+ static String getResourceLanguageTags(Context context) {
+ return context.getResources().getConfiguration().locale.getLanguage();
+ }
+ }
+}
diff --git a/java/src/com/android/textclassifier/common/base/TcLog.java b/java/src/com/android/textclassifier/common/base/TcLog.java
new file mode 100644
index 0000000..87f1187
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/base/TcLog.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.base;
+
+/**
+ * Logging for android.view.textclassifier package.
+ *
+ * <p>To enable full log:
+ *
+ * <ul>
+ * <li>adb shell setprop log.tag.androidtc VERBOSE
+ * <li>adb shell stop && adb shell start
+ * </ul>
+ */
+public final class TcLog {
+ private static final boolean USE_TC_TAG = true;
+ public static final String TAG = "androidtc";
+
+ /** true: Enables full logging. false: Limits logging to debug level. */
+ public static final boolean ENABLE_FULL_LOGGING =
+ android.util.Log.isLoggable(TAG, android.util.Log.VERBOSE);
+
+ private TcLog() {}
+
+ public static void v(String tag, String msg) {
+ if (ENABLE_FULL_LOGGING) {
+ android.util.Log.v(getTag(tag), msg);
+ }
+ }
+
+ public static void d(String tag, String msg) {
+ android.util.Log.d(getTag(tag), msg);
+ }
+
+ public static void w(String tag, String msg) {
+ android.util.Log.w(getTag(tag), msg);
+ }
+
+ public static void e(String tag, String msg, Throwable tr) {
+ android.util.Log.e(getTag(tag), msg, tr);
+ }
+
+ private static String getTag(String customTag) {
+ return USE_TC_TAG ? TAG : customTag;
+ }
+}
diff --git a/java/src/com/android/textclassifier/common/intent/LabeledIntent.java b/java/src/com/android/textclassifier/common/intent/LabeledIntent.java
new file mode 100644
index 0000000..b56d0bb
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/intent/LabeledIntent.java
@@ -0,0 +1,219 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.intent;
+
+import android.app.PendingIntent;
+import android.content.ComponentName;
+import android.content.Context;
+import android.content.Intent;
+import android.content.pm.ActivityInfo;
+import android.content.pm.PackageManager;
+import android.content.pm.ResolveInfo;
+import android.text.TextUtils;
+import androidx.annotation.DrawableRes;
+import androidx.core.app.RemoteActionCompat;
+import androidx.core.content.ContextCompat;
+import androidx.core.graphics.drawable.IconCompat;
+import com.android.textclassifier.common.base.TcLog;
+import com.google.common.base.Preconditions;
+import javax.annotation.Nullable;
+
+/** Helper class to store the information from which RemoteActions are built. */
+public final class LabeledIntent {
+ private static final String TAG = "LabeledIntent";
+ public static final int DEFAULT_REQUEST_CODE = 0;
+ private static final TitleChooser DEFAULT_TITLE_CHOOSER =
+ (labeledIntent, resolveInfo) -> {
+ if (!TextUtils.isEmpty(labeledIntent.titleWithEntity)) {
+ return labeledIntent.titleWithEntity;
+ }
+ return labeledIntent.titleWithoutEntity;
+ };
+
+ @Nullable public final String titleWithoutEntity;
+ @Nullable public final String titleWithEntity;
+ public final String description;
+ @Nullable public final String descriptionWithAppName;
+ // Do not update this intent.
+ public final Intent intent;
+ public final int requestCode;
+
+ /**
+ * Initializes a LabeledIntent.
+ *
+ * <p>NOTE: {@code requestCode} is required to not be {@link #DEFAULT_REQUEST_CODE} if
+ * distinguishing info (e.g. the classified text) is represented in intent extras only. In such
+ * circumstances, the request code should represent the distinguishing info (e.g. by generating a
+ * hashcode) so that the generated PendingIntent is (somewhat) unique. To be correct, the
+ * PendingIntent should be definitely unique but we try a best effort approach that avoids
+ * spamming the system with PendingIntents.
+ */
+ // TODO: Fix the issue mentioned above so the behaviour is correct.
+ public LabeledIntent(
+ @Nullable String titleWithoutEntity,
+ @Nullable String titleWithEntity,
+ String description,
+ @Nullable String descriptionWithAppName,
+ Intent intent,
+ int requestCode) {
+ if (TextUtils.isEmpty(titleWithEntity) && TextUtils.isEmpty(titleWithoutEntity)) {
+ throw new IllegalArgumentException(
+ "titleWithEntity and titleWithoutEntity should not be both null");
+ }
+ this.titleWithoutEntity = titleWithoutEntity;
+ this.titleWithEntity = titleWithEntity;
+ this.description = Preconditions.checkNotNull(description);
+ this.descriptionWithAppName = descriptionWithAppName;
+ this.intent = Preconditions.checkNotNull(intent);
+ this.requestCode = requestCode;
+ }
+
+ /**
+ * Return the resolved result.
+ *
+ * @param context the context to resolve the result's intent and action
+ * @param titleChooser for choosing an action title
+ */
+ @Nullable
+ public Result resolve(Context context, @Nullable TitleChooser titleChooser) {
+ final PackageManager pm = context.getPackageManager();
+ final ResolveInfo resolveInfo = pm.resolveActivity(intent, 0);
+
+ if (resolveInfo == null || resolveInfo.activityInfo == null) {
+ TcLog.w(TAG, "resolveInfo or activityInfo is null");
+ return null;
+ }
+ if (!hasPermission(context, resolveInfo.activityInfo)) {
+ TcLog.d(TAG, "No permission to access: " + resolveInfo.activityInfo);
+ return null;
+ }
+
+ final String packageName = resolveInfo.activityInfo.packageName;
+ final String className = resolveInfo.activityInfo.name;
+ if (packageName == null || className == null) {
+ TcLog.w(TAG, "packageName or className is null");
+ return null;
+ }
+ Intent resolvedIntent = new Intent(intent);
+ boolean shouldShowIcon = false;
+ IconCompat icon = null;
+ if (!"android".equals(packageName)) {
+ // We only set the component name when the package name is not resolved to "android"
+ // to workaround a bug that explicit intent with component name == ResolverActivity
+ // can't be launched on keyguard.
+ resolvedIntent.setComponent(new ComponentName(packageName, className));
+ if (resolveInfo.activityInfo.getIconResource() != 0) {
+ icon =
+ createIconFromPackage(context, packageName, resolveInfo.activityInfo.getIconResource());
+ shouldShowIcon = true;
+ }
+ }
+ if (icon == null) {
+ // RemoteAction requires that there be an icon.
+ icon = IconCompat.createWithResource(context, android.R.drawable.ic_menu_more);
+ }
+ final PendingIntent pendingIntent = createPendingIntent(context, resolvedIntent, requestCode);
+ titleChooser = titleChooser == null ? DEFAULT_TITLE_CHOOSER : titleChooser;
+ CharSequence title = titleChooser.chooseTitle(this, resolveInfo);
+ if (TextUtils.isEmpty(title)) {
+ TcLog.w(TAG, "Custom titleChooser return null, fallback to the default titleChooser");
+ title = DEFAULT_TITLE_CHOOSER.chooseTitle(this, resolveInfo);
+ }
+ final RemoteActionCompat action =
+ new RemoteActionCompat(icon, title, resolveDescription(resolveInfo, pm), pendingIntent);
+ action.setShouldShowIcon(shouldShowIcon);
+ return new Result(resolvedIntent, action);
+ }
+
+ private String resolveDescription(ResolveInfo resolveInfo, PackageManager packageManager) {
+ if (!TextUtils.isEmpty(descriptionWithAppName)) {
+ // Example string format of descriptionWithAppName: "Use %1$s to open map".
+ String applicationName = getApplicationName(resolveInfo, packageManager);
+ if (!TextUtils.isEmpty(applicationName)) {
+ return String.format(descriptionWithAppName, applicationName);
+ }
+ }
+ return description;
+ }
+
+ @Nullable
+ private static IconCompat createIconFromPackage(
+ Context context, String packageName, @DrawableRes int iconRes) {
+ try {
+ Context packageContext = context.createPackageContext(packageName, 0);
+ return IconCompat.createWithResource(packageContext, iconRes);
+ } catch (PackageManager.NameNotFoundException e) {
+ TcLog.e(TAG, "createIconFromPackage: failed to create package context", e);
+ }
+ return null;
+ }
+
+ private static PendingIntent createPendingIntent(
+ final Context context, final Intent intent, int requestCode) {
+ return PendingIntent.getActivity(
+ context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
+ }
+
+ @Nullable
+ private static String getApplicationName(ResolveInfo resolveInfo, PackageManager packageManager) {
+ if (resolveInfo.activityInfo == null) {
+ return null;
+ }
+ if ("android".equals(resolveInfo.activityInfo.packageName)) {
+ return null;
+ }
+ if (resolveInfo.activityInfo.applicationInfo == null) {
+ return null;
+ }
+ return packageManager.getApplicationLabel(resolveInfo.activityInfo.applicationInfo).toString();
+ }
+
+ private static boolean hasPermission(Context context, ActivityInfo info) {
+ if (!info.exported) {
+ return false;
+ }
+ if (info.permission == null) {
+ return true;
+ }
+ return ContextCompat.checkSelfPermission(context, info.permission)
+ == PackageManager.PERMISSION_GRANTED;
+ }
+
+ /** Data class that holds the result. */
+ public static final class Result {
+ public final Intent resolvedIntent;
+ public final RemoteActionCompat remoteAction;
+
+ public Result(Intent resolvedIntent, RemoteActionCompat remoteAction) {
+ this.resolvedIntent = Preconditions.checkNotNull(resolvedIntent);
+ this.remoteAction = Preconditions.checkNotNull(remoteAction);
+ }
+ }
+
+ /**
+ * An object to choose a title from resolved info. If {@code null} is returned, {@link
+ * #titleWithEntity} will be used if it exists, {@link #titleWithoutEntity} otherwise.
+ */
+ public interface TitleChooser {
+ /**
+ * Picks a title from a {@link LabeledIntent} by looking into resolved info. {@code resolveInfo}
+ * is guaranteed to have a non-null {@code activityInfo}.
+ */
+ @Nullable
+ CharSequence chooseTitle(LabeledIntent labeledIntent, ResolveInfo resolveInfo);
+ }
+}
diff --git a/java/src/com/android/textclassifier/common/intent/TemplateIntentFactory.java b/java/src/com/android/textclassifier/common/intent/TemplateIntentFactory.java
new file mode 100644
index 0000000..b4f361a
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/intent/TemplateIntentFactory.java
@@ -0,0 +1,165 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.intent;
+
+import android.content.Intent;
+import android.net.Uri;
+import android.os.Bundle;
+import android.text.TextUtils;
+import android.view.textclassifier.TextClassifier;
+import com.android.textclassifier.common.base.TcLog;
+import com.google.android.textclassifier.NamedVariant;
+import com.google.android.textclassifier.RemoteActionTemplate;
+import com.google.common.collect.ImmutableList;
+import javax.annotation.Nullable;
+
+/** Creates intents based on {@link RemoteActionTemplate} objects. */
+public final class TemplateIntentFactory {
+ private static final String TAG = "TemplateIntentFactory";
+
+ /** Constructs and returns a list of {@link LabeledIntent} based on the given templates. */
+ public ImmutableList<LabeledIntent> create(
+ @Nullable RemoteActionTemplate[] remoteActionTemplates) {
+ if (remoteActionTemplates == null || remoteActionTemplates.length == 0) {
+ return ImmutableList.of();
+ }
+ final ImmutableList.Builder<LabeledIntent> labeledIntents = ImmutableList.builder();
+ for (RemoteActionTemplate remoteActionTemplate : remoteActionTemplates) {
+ if (!isValidTemplate(remoteActionTemplate)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate skipped.");
+ continue;
+ }
+ labeledIntents.add(
+ new LabeledIntent(
+ remoteActionTemplate.titleWithoutEntity,
+ remoteActionTemplate.titleWithEntity,
+ remoteActionTemplate.description,
+ remoteActionTemplate.descriptionWithAppName,
+ createIntent(remoteActionTemplate),
+ remoteActionTemplate.requestCode == null
+ ? LabeledIntent.DEFAULT_REQUEST_CODE
+ : remoteActionTemplate.requestCode));
+ }
+ return labeledIntents.build();
+ }
+
+ private static boolean isValidTemplate(@Nullable RemoteActionTemplate remoteActionTemplate) {
+ if (remoteActionTemplate == null) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: is null");
+ return false;
+ }
+ if (TextUtils.isEmpty(remoteActionTemplate.titleWithEntity)
+ && TextUtils.isEmpty(remoteActionTemplate.titleWithoutEntity)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: title is null");
+ return false;
+ }
+ if (TextUtils.isEmpty(remoteActionTemplate.description)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: description is null");
+ return false;
+ }
+ if (!TextUtils.isEmpty(remoteActionTemplate.packageName)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: package name is set");
+ return false;
+ }
+ if (TextUtils.isEmpty(remoteActionTemplate.action)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: intent action not set");
+ return false;
+ }
+ return true;
+ }
+
+ private static Intent createIntent(RemoteActionTemplate remoteActionTemplate) {
+ final Intent intent = new Intent(remoteActionTemplate.action);
+ final Uri uri =
+ TextUtils.isEmpty(remoteActionTemplate.data)
+ ? null
+ : Uri.parse(remoteActionTemplate.data).normalizeScheme();
+ final String type =
+ TextUtils.isEmpty(remoteActionTemplate.type)
+ ? null
+ : Intent.normalizeMimeType(remoteActionTemplate.type);
+ intent.setDataAndType(uri, type);
+ intent.setFlags(remoteActionTemplate.flags == null ? 0 : remoteActionTemplate.flags);
+ if (!TextUtils.isEmpty(remoteActionTemplate.packageName)) {
+ intent.setPackage(remoteActionTemplate.packageName);
+ }
+ if (remoteActionTemplate.category != null) {
+ for (String category : remoteActionTemplate.category) {
+ if (category != null) {
+ intent.addCategory(category);
+ }
+ }
+ }
+ intent.putExtras(nameVariantsToBundle(remoteActionTemplate.extras));
+ // If the template does not have EXTRA_FROM_TEXT_CLASSIFIER, create one to indicate the result
+ // is from the text classifier, so that client can handle the intent differently.
+ if (!intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER)) {
+ intent.putExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER, Bundle.EMPTY);
+ }
+ return intent;
+ }
+
+ /** Converts an array of {@link NamedVariant} to a Bundle and returns it. */
+ public static Bundle nameVariantsToBundle(@Nullable NamedVariant[] namedVariants) {
+ if (namedVariants == null) {
+ return Bundle.EMPTY;
+ }
+ Bundle bundle = new Bundle();
+ for (NamedVariant namedVariant : namedVariants) {
+ if (namedVariant == null) {
+ continue;
+ }
+ switch (namedVariant.getType()) {
+ case NamedVariant.TYPE_INT:
+ bundle.putInt(namedVariant.getName(), namedVariant.getInt());
+ break;
+ case NamedVariant.TYPE_LONG:
+ bundle.putLong(namedVariant.getName(), namedVariant.getLong());
+ break;
+ case NamedVariant.TYPE_FLOAT:
+ bundle.putFloat(namedVariant.getName(), namedVariant.getFloat());
+ break;
+ case NamedVariant.TYPE_DOUBLE:
+ bundle.putDouble(namedVariant.getName(), namedVariant.getDouble());
+ break;
+ case NamedVariant.TYPE_BOOL:
+ bundle.putBoolean(namedVariant.getName(), namedVariant.getBool());
+ break;
+ case NamedVariant.TYPE_STRING:
+ bundle.putString(namedVariant.getName(), namedVariant.getString());
+ break;
+ case NamedVariant.TYPE_STRING_ARRAY:
+ bundle.putStringArray(namedVariant.getName(), namedVariant.getStringArray());
+ break;
+ case NamedVariant.TYPE_FLOAT_ARRAY:
+ bundle.putFloatArray(namedVariant.getName(), namedVariant.getFloatArray());
+ break;
+ case NamedVariant.TYPE_INT_ARRAY:
+ bundle.putIntArray(namedVariant.getName(), namedVariant.getIntArray());
+ break;
+ case NamedVariant.TYPE_NAMED_VARIANT_ARRAY:
+ bundle.putBundle(
+ namedVariant.getName(), nameVariantsToBundle(namedVariant.getNamedVariantArray()));
+ break;
+ default:
+ TcLog.w(
+ TAG, "Unsupported type found in nameVariantsToBundle : " + namedVariant.getType());
+ }
+ }
+ return bundle;
+ }
+}
diff --git a/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
new file mode 100644
index 0000000..dae0442
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
@@ -0,0 +1,123 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.logging;
+
+import android.content.Context;
+import android.text.TextUtils;
+import com.android.textclassifier.common.base.LocaleCompat;
+import com.google.common.base.Joiner;
+import com.google.common.base.Objects;
+import com.google.common.base.Optional;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import javax.annotation.Nullable;
+
+/** Provide utils to generate and parse the result id. */
+public final class ResultIdUtils {
+ private static final String CLASSIFIER_ID = "androidtc";
+ private static final String SEPARATOR_MODEL_NAME = ";";
+ private static final String SEPARATOR_LOCALES = ",";
+ private static final Pattern EXTRACT_MODEL_NAME_FROM_RESULT_ID =
+ Pattern.compile("^[^|]*\\|([^|]*)\\|[^|]*$");
+
+ /** Creates a string id that may be used to identify a TextClassifier result. */
+ public static String createId(
+ Context context, String text, int start, int end, List<Optional<ModelInfo>> modelInfos) {
+ Preconditions.checkNotNull(text);
+ Preconditions.checkNotNull(context);
+ Preconditions.checkNotNull(modelInfos);
+ final int hash = Objects.hashCode(text, start, end, context.getPackageName());
+ return createId(hash, modelInfos);
+ }
+
+ /** Creates a string id that may be used to identify a TextClassifier result. */
+ public static String createId(int hash, List<Optional<ModelInfo>> modelInfos) {
+ Preconditions.checkNotNull(modelInfos);
+ final List<String> modelNames = new ArrayList<>();
+ for (Optional<ModelInfo> modelInfo : modelInfos) {
+ modelNames.add(modelInfo.transform(ModelInfo::toModelName).or(""));
+ }
+ return String.format(
+ Locale.US,
+ "%s|%s|%d",
+ CLASSIFIER_ID,
+ Joiner.on(SEPARATOR_MODEL_NAME).join(modelNames),
+ hash);
+ }
+
+ /** Returns if the result id was generated from the default text classifier. */
+ public static boolean isFromDefaultTextClassifier(String resultId) {
+ return resultId.startsWith(CLASSIFIER_ID + '|');
+ }
+
+ /** Returns all the model names encoded in the signature. */
+ public static ImmutableList<String> getModelNames(@Nullable String signature) {
+ if (TextUtils.isEmpty(signature)) {
+ return ImmutableList.of();
+ }
+ Matcher matcher = EXTRACT_MODEL_NAME_FROM_RESULT_ID.matcher(signature);
+ if (!matcher.find()) {
+ return ImmutableList.of();
+ }
+ return ImmutableList.copyOf(Splitter.on(SEPARATOR_MODEL_NAME).splitToList(matcher.group(1)));
+ }
+
+ private ResultIdUtils() {}
+
+ /** Model information of a model file. */
+ public static class ModelInfo {
+ private final String modelName;
+
+ public ModelInfo(int version, List<Locale> locales) {
+ this(version, createSupportedLanguageTagsString(locales));
+ }
+
+ /**
+ * Creates a {@link ModelInfo} object.
+ *
+ * @param version model version
+ * @param supportedLanguageTags a comma-separated string of bcp47 language tags of supported
+ * languages
+ */
+ public ModelInfo(int version, String supportedLanguageTags) {
+ this.modelName = createModelName(version, supportedLanguageTags);
+ }
+
+ private static String createSupportedLanguageTagsString(List<Locale> locales) {
+ List<String> languageTags = new ArrayList<>();
+ for (Locale locale : locales) {
+ languageTags.add(LocaleCompat.toLanguageTag(locale));
+ }
+ return Joiner.on(SEPARATOR_LOCALES).join(languageTags);
+ }
+
+ private static String createModelName(int version, String supportedLanguageTags) {
+ return String.format(Locale.US, "%s_v%d", supportedLanguageTags, version);
+ }
+
+ /** Returns a string representation of the model info. */
+ public String toModelName() {
+ return modelName;
+ }
+ }
+}
diff --git a/java/src/com/android/textclassifier/common/logging/TextClassificationContext.java b/java/src/com/android/textclassifier/common/logging/TextClassificationContext.java
new file mode 100644
index 0000000..e729201
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/logging/TextClassificationContext.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.logging;
+
+import androidx.annotation.NonNull;
+import com.google.common.base.Preconditions;
+import java.util.Locale;
+import javax.annotation.Nullable;
+
+/** A representation of the context in which text classification would be performed. */
+public final class TextClassificationContext {
+
+ private final String packageName;
+ private final String widgetType;
+ @Nullable private final String widgetVersion;
+
+ private TextClassificationContext(
+ String packageName, String widgetType, @Nullable String widgetVersion) {
+ this.packageName = Preconditions.checkNotNull(packageName);
+ this.widgetType = Preconditions.checkNotNull(widgetType);
+ this.widgetVersion = widgetVersion;
+ }
+
+ /** Returns the package name for the calling package. */
+ public String getPackageName() {
+ return packageName;
+ }
+
+ /** Returns the widget type for this classification context. */
+ public String getWidgetType() {
+ return widgetType;
+ }
+
+ /**
+ * Returns a custom version string for the widget type.
+ *
+ * @see #getWidgetType()
+ */
+ @Nullable
+ public String getWidgetVersion() {
+ return widgetVersion;
+ }
+
+ @Override
+ public String toString() {
+ return String.format(
+ Locale.US,
+ "TextClassificationContext{" + "packageName=%s, widgetType=%s, widgetVersion=%s}",
+ packageName,
+ widgetType,
+ widgetVersion);
+ }
+
+ /** A builder for building a TextClassification context. */
+ public static final class Builder {
+
+ private final String packageName;
+ private final String widgetType;
+
+ @Nullable private String widgetVersion;
+
+ /**
+ * Initializes a new builder for text classification context objects.
+ *
+ * @param packageName the name of the calling package
+ * @param widgetType the type of widget e.g. {@link
+ * android.view.textclassifier.TextClassifier#WIDGET_TYPE_TEXTVIEW}
+ * @return this builder
+ */
+ public Builder(String packageName, String widgetType) {
+ this.packageName = Preconditions.checkNotNull(packageName);
+ this.widgetType = Preconditions.checkNotNull(widgetType);
+ }
+
+ /**
+ * Sets an optional custom version string for the widget type.
+ *
+ * @return this builder
+ */
+ public Builder setWidgetVersion(@Nullable String widgetVersion) {
+ this.widgetVersion = widgetVersion;
+ return this;
+ }
+
+ /**
+ * Builds the text classification context object.
+ *
+ * @return the built TextClassificationContext object
+ */
+ @NonNull
+ public TextClassificationContext build() {
+ return new TextClassificationContext(packageName, this.widgetType, widgetVersion);
+ }
+ }
+}
diff --git a/java/src/com/android/textclassifier/common/logging/TextClassificationSessionId.java b/java/src/com/android/textclassifier/common/logging/TextClassificationSessionId.java
new file mode 100644
index 0000000..abb6f7f
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/logging/TextClassificationSessionId.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.logging;
+
+import com.google.common.base.Preconditions;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.UUID;
+
+/** This class represents the id of a text classification session. */
+public final class TextClassificationSessionId {
+ private final String value;
+
+ /** Creates a new instance. */
+ public TextClassificationSessionId() {
+ this(UUID.randomUUID().toString());
+ }
+
+ private TextClassificationSessionId(String value) {
+ this.value = Preconditions.checkNotNull(value);
+ }
+
+ @Override
+ public String toString() {
+ return String.format(Locale.US, "TextClassificationSessionId {%s}", value);
+ }
+
+ @Override
+ public int hashCode() {
+ return value.hashCode();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ TextClassificationSessionId that = (TextClassificationSessionId) o;
+ return Objects.equals(value, that.value);
+ }
+
+ /**
+ * Flattens this id to a string.
+ *
+ * @return The flattened id.
+ */
+ public String getValue() {
+ return value;
+ }
+
+ /**
+ * Recovers a TextClassificationSessionId from a string of the form returned by {@link
+ * #getValue()}.
+ */
+ public static TextClassificationSessionId unflattenFromString(String value) {
+ return new TextClassificationSessionId(value);
+ }
+}
diff --git a/java/src/com/android/textclassifier/common/logging/TextClassifierEvent.java b/java/src/com/android/textclassifier/common/logging/TextClassifierEvent.java
new file mode 100644
index 0000000..f34fb3d
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/logging/TextClassifierEvent.java
@@ -0,0 +1,816 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.logging;
+
+import android.os.Bundle;
+import androidx.annotation.IntDef;
+import com.google.common.base.Preconditions;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.util.Arrays;
+import java.util.Locale;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+/**
+ * This class represents events that are sent by components to a TextClassifier to report something
+ * of note that relates to a feature powered by the TextClassifier. The TextClassifier may log these
+ * events or use them to improve future responses to queries.
+ *
+ * <p>Each category of events has its own subclass. Events of each type have an associated set
+ * of related properties. You can find their specification in the subclasses.
+ */
+public abstract class TextClassifierEvent {
+
+ /** Category of the event */
+ @Retention(RetentionPolicy.SOURCE)
+ @IntDef({
+ CATEGORY_SELECTION,
+ CATEGORY_LINKIFY,
+ CATEGORY_CONVERSATION_ACTIONS,
+ CATEGORY_LANGUAGE_DETECTION
+ })
+ public @interface Category {
+ // For custom event categories, use range 1000+.
+ }
+
+ /**
+ * Smart selection
+ *
+ * @see TextSelectionEvent
+ */
+ public static final int CATEGORY_SELECTION = 1;
+ /**
+ * Linkify
+ *
+ * @see TextLinkifyEvent
+ */
+ public static final int CATEGORY_LINKIFY = 2;
+ /**
+ * Conversation actions
+ *
+ * @see ConversationActionsEvent
+ */
+ public static final int CATEGORY_CONVERSATION_ACTIONS = 3;
+ /**
+ * Language detection
+ *
+ * @see LanguageDetectionEvent
+ */
+ public static final int CATEGORY_LANGUAGE_DETECTION = 4;
+
+ /** Type of the event */
+ @Retention(RetentionPolicy.SOURCE)
+ @IntDef({
+ TYPE_SELECTION_STARTED,
+ TYPE_SELECTION_MODIFIED,
+ TYPE_SMART_SELECTION_SINGLE,
+ TYPE_SMART_SELECTION_MULTI,
+ TYPE_AUTO_SELECTION,
+ TYPE_ACTIONS_SHOWN,
+ TYPE_LINK_CLICKED,
+ TYPE_OVERTYPE,
+ TYPE_COPY_ACTION,
+ TYPE_PASTE_ACTION,
+ TYPE_CUT_ACTION,
+ TYPE_SHARE_ACTION,
+ TYPE_SMART_ACTION,
+ TYPE_SELECTION_DRAG,
+ TYPE_SELECTION_DESTROYED,
+ TYPE_OTHER_ACTION,
+ TYPE_SELECT_ALL,
+ TYPE_SELECTION_RESET,
+ TYPE_MANUAL_REPLY,
+ TYPE_ACTIONS_GENERATED,
+ TYPE_LINKS_GENERATED
+ })
+ public @interface Type {
+ // For custom event types, use range 1,000,000+.
+ }
+
+ // All these event type constants are required to match with those defined in
+ // textclassifier_enums.proto.
+ /** User started a new selection. */
+ public static final int TYPE_SELECTION_STARTED = 1;
+ /** User modified an existing selection. */
+ public static final int TYPE_SELECTION_MODIFIED = 2;
+ /** Smart selection triggered for a single token (word). */
+ public static final int TYPE_SMART_SELECTION_SINGLE = 3;
+ /** Smart selection triggered spanning multiple tokens (words). */
+ public static final int TYPE_SMART_SELECTION_MULTI = 4;
+ /** Something else other than user or the default TextClassifier triggered a selection. */
+ public static final int TYPE_AUTO_SELECTION = 5;
+ /** Smart actions shown to the user. */
+ public static final int TYPE_ACTIONS_SHOWN = 6;
+ /** User clicked a link. */
+ public static final int TYPE_LINK_CLICKED = 7;
+ /** User typed over the selection. */
+ public static final int TYPE_OVERTYPE = 8;
+ /** User clicked on Copy action. */
+ public static final int TYPE_COPY_ACTION = 9;
+ /** User clicked on Paste action. */
+ public static final int TYPE_PASTE_ACTION = 10;
+ /** User clicked on Cut action. */
+ public static final int TYPE_CUT_ACTION = 11;
+ /** User clicked on Share action. */
+ public static final int TYPE_SHARE_ACTION = 12;
+ /** User clicked on a Smart action. */
+ public static final int TYPE_SMART_ACTION = 13;
+ /** User dragged+dropped the selection. */
+ public static final int TYPE_SELECTION_DRAG = 14;
+ /** Selection is destroyed. */
+ public static final int TYPE_SELECTION_DESTROYED = 15;
+ /** User clicked on a custom action. */
+ public static final int TYPE_OTHER_ACTION = 16;
+ /** User clicked on Select All action */
+ public static final int TYPE_SELECT_ALL = 17;
+ /** User reset the smart selection. */
+ public static final int TYPE_SELECTION_RESET = 18;
+ /** User composed a reply. */
+ public static final int TYPE_MANUAL_REPLY = 19;
+ /** TextClassifier generated some actions */
+ public static final int TYPE_ACTIONS_GENERATED = 20;
+ /** Some text links were generated. */
+ public static final int TYPE_LINKS_GENERATED = 21;
+
+ @Category private final int eventCategory;
+ @Type private final int eventType;
+ @Nullable private final String[] entityTypes;
+ @Nullable private TextClassificationContext eventContext;
+ @Nullable private final String resultId;
+ private final int eventIndex;
+ private final float[] scores;
+ @Nullable private final String modelName;
+ private final int[] actionIndices;
+ @Nullable private final Locale locale;
+ private final Bundle extras;
+
+ private TextClassifierEvent(Builder<?> builder) {
+ eventCategory = builder.eventCategory;
+ eventType = builder.eventType;
+ entityTypes = builder.entityTypes;
+ eventContext = builder.eventContext;
+ resultId = builder.resultId;
+ eventIndex = builder.eventIndex;
+ scores = builder.scores;
+ modelName = builder.modelName;
+ actionIndices = builder.actionIndices;
+ locale = builder.locale;
+ extras = builder.extras == null ? Bundle.EMPTY : builder.extras;
+ }
+
+ /** Returns the event category. e.g. {@link #CATEGORY_SELECTION}. */
+ @Category
+ public int getEventCategory() {
+ return eventCategory;
+ }
+
+ /** Returns the event type. e.g. {@link #TYPE_SELECTION_STARTED}. */
+ @Type
+ public int getEventType() {
+ return eventType;
+ }
+
+ /**
+ * Returns an array of entity types. e.g. {@link android.view.textclassifier.TextClassifier#TYPE_ADDRESS}.
+ *
+ * @see Builder#setEntityTypes(String...) for supported types.
+ */
+ @Nonnull
+ public String[] getEntityTypes() {
+ return entityTypes;
+ }
+
+ /** Returns the event context. */
+ @Nullable
+ public TextClassificationContext getEventContext() {
+ return eventContext;
+ }
+
+ /**
+ * Sets the event context.
+ *
+ * <p>Package-private for SystemTextClassifier's use.
+ */
+ void setEventContext(@Nullable TextClassificationContext eventContext) {
+ this.eventContext = eventContext;
+ }
+
+ /** Returns the id of the text classifier result related to this event. */
+ @Nullable
+ public String getResultId() {
+ return resultId;
+ }
+
+ /** Returns the index of this event in the series of events it belongs to. */
+ public int getEventIndex() {
+ return eventIndex;
+ }
+
+ /** Returns the scores of the suggestions. */
+ public float[] getScores() {
+ return scores;
+ }
+
+ /** Returns the model name. */
+ @Nullable
+ public String getModelName() {
+ return modelName;
+ }
+
+ /**
+ * Returns the indices of the actions relating to this event. Actions are usually returned by the
+ * text classifier in priority order with the most preferred action at index 0. This list gives an
+ * indication of the position of the actions that are being reported.
+ *
+ * @see Builder#setActionIndices(int...)
+ */
+ public int[] getActionIndices() {
+ return actionIndices;
+ }
+
+ /** Returns the detected locale. */
+ @Nullable
+ public Locale getLocale() {
+ return locale;
+ }
+
+ /**
+ * Returns a bundle containing non-structured extra information about this event.
+ *
+ * <p><b>NOTE: </b>Do not modify this bundle.
+ */
+ @Nonnull
+ public Bundle getExtras() {
+ return extras;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder out = new StringBuilder(128);
+ out.append(this.getClass().getSimpleName());
+ out.append("{");
+ out.append("mEventCategory=").append(eventCategory);
+ out.append(", mEventType=").append(eventType);
+ out.append(", mEntityTypes=").append(Arrays.toString(entityTypes));
+ out.append(", mEventContext=").append(eventContext);
+ out.append(", mResultId=").append(resultId);
+ out.append(", mEventIndex=").append(eventIndex);
+ out.append(", mExtras=").append(extras);
+ out.append(", mScores=").append(Arrays.toString(scores));
+ out.append(", mModelName=").append(modelName);
+ out.append(", mActionIndices=").append(Arrays.toString(actionIndices));
+ toString(out);
+ out.append("}");
+ return out.toString();
+ }
+
+ /**
+ * Overrides this to append extra fields to the output of {@link #toString()}.
+ *
+ * <p>Extra fields should be formatted like this: ", {field_name}={field_value}".
+ */
+ void toString(StringBuilder out) {}
+
+ /**
+ * Builder to build a text classifier event.
+ *
+ * @param <T> The subclass to be built.
+ */
+ public abstract static class Builder<T extends Builder<T>> {
+
+ private final int eventCategory;
+ private final int eventType;
+ private String[] entityTypes = new String[0];
+ @Nullable private TextClassificationContext eventContext;
+ @Nullable private String resultId;
+ private int eventIndex;
+ private float[] scores = new float[0];
+ @Nullable private String modelName;
+ private int[] actionIndices = new int[0];
+ @Nullable private Locale locale;
+ @Nullable private Bundle extras;
+
+ /**
+ * Creates a builder for building {@link TextClassifierEvent}s.
+ *
+ * @param eventCategory The event category. e.g. {@link #CATEGORY_SELECTION}
+ * @param eventType The event type. e.g. {@link #TYPE_SELECTION_STARTED}
+ */
+ private Builder(@Category int eventCategory, @Type int eventType) {
+ this.eventCategory = eventCategory;
+ this.eventType = eventType;
+ }
+
+ /**
+ * Sets the entity types. e.g. {@link android.view.textclassifier.TextClassifier#TYPE_ADDRESS}.
+ *
+ * <p>Supported types:
+ *
+ * <p>See {@link android.view.textclassifier.TextClassifier.EntityType}
+ *
+ * <p>See {@link android.view.textclassifier.ConversationAction.ActionType}
+ *
+ * <p>See {@link Locale#toLanguageTag()}
+ */
+ public T setEntityTypes(String... entityTypes) {
+ Preconditions.checkNotNull(entityTypes);
+ this.entityTypes = new String[entityTypes.length];
+ System.arraycopy(entityTypes, 0, this.entityTypes, 0, entityTypes.length);
+ return self();
+ }
+
+ /** Sets the event context. */
+ public T setEventContext(@Nullable TextClassificationContext eventContext) {
+ this.eventContext = eventContext;
+ return self();
+ }
+
+ /** Sets the id of the text classifier result related to this event. */
+ @Nonnull
+ public T setResultId(@Nullable String resultId) {
+ this.resultId = resultId;
+ return self();
+ }
+
+ /** Sets the index of this event in the series of events it belongs to. */
+ @Nonnull
+ public T setEventIndex(int eventIndex) {
+ this.eventIndex = eventIndex;
+ return self();
+ }
+
+ /** Sets the scores of the suggestions. */
+ @Nonnull
+ public T setScores(@Nonnull float... scores) {
+ Preconditions.checkNotNull(scores);
+ this.scores = new float[scores.length];
+ System.arraycopy(scores, 0, this.scores, 0, scores.length);
+ return self();
+ }
+
+ /** Sets the model name string. */
+ @Nonnull
+ public T setModelName(@Nullable String modelVersion) {
+ modelName = modelVersion;
+ return self();
+ }
+
+ /**
+ * Sets the indices of the actions involved in this event. Actions are usually returned by the
+ * text classifier in priority order with the most preferred action at index 0. These indices
+ * give an indication of the position of the actions that are being reported.
+ *
+ * <p>E.g.
+ *
+ * <pre>
+ * // 3 smart actions are shown at index 0, 1, 2 respectively in response to a link click.
+ * new TextClassifierEvent.Builder(CATEGORY_LINKIFY, TYPE_ACTIONS_SHOWN)
+ * .setActionIndices(0, 1, 2)
+ * ...
+ * .build();
+ *
+ * ...
+ *
+ * // Smart action at index 1 is activated.
+ * new TextClassifierEvent.Builder(CATEGORY_LINKIFY, TYPE_SMART_ACTION)
+ * .setEventIndex(1)
+ * ...
+ * .build();
+ * </pre>
+ *
+ * @see android.view.textclassifier.TextClassification#getActions()
+ */
+ @Nonnull
+ public T setActionIndices(@Nonnull int... actionIndices) {
+ this.actionIndices = new int[actionIndices.length];
+ System.arraycopy(actionIndices, 0, this.actionIndices, 0, actionIndices.length);
+ return self();
+ }
+
+ /** Sets the detected locale. */
+ @Nonnull
+ public T setLocale(@Nullable Locale locale) {
+ this.locale = locale;
+ return self();
+ }
+
+ /**
+ * Sets a bundle containing non-structured extra information about the event.
+ *
+ * <p><b>NOTE: </b>Prefer to set only immutable values on the bundle otherwise, avoid updating
+ * the internals of this bundle as it may have unexpected consequences on the clients of the
+ * built event object. For similar reasons, avoid depending on mutable objects in this bundle.
+ */
+ @Nonnull
+ public T setExtras(@Nonnull Bundle extras) {
+ this.extras = Preconditions.checkNotNull(extras);
+ return self();
+ }
+
+ abstract T self();
+ }
+
+ /**
+ * This class represents events that are related to the smart text selection feature.
+ *
+ * <p>
+ *
+ * <pre>
+ * // User started a selection. e.g. "York" in text "New York City, NY".
+ * new TextSelectionEvent.Builder(TYPE_SELECTION_STARTED)
+ * .setEventContext(classificationContext)
+ * .setEventIndex(0)
+ * .build();
+ *
+ * // System smart-selects a recognized entity. e.g. "New York City".
+ * new TextSelectionEvent.Builder(TYPE_SMART_SELECTION_MULTI)
+ * .setEventContext(classificationContext)
+ * .setResultId(textSelection.getId())
+ * .setRelativeWordStartIndex(-1) // Goes back one word to "New" from "York".
+ * .setRelativeWordEndIndex(2) // Goes forward 2 words from "York" to start of ",".
+ * .setEntityTypes(textClassification.getEntity(0))
+ * .setScore(textClassification.getConfidenceScore(entityType))
+ * .setEventIndex(1)
+ * .build();
+ *
+ * // User resets the selection to the original selection. i.e. "York".
+ * new TextSelectionEvent.Builder(TYPE_SELECTION_RESET)
+ * .setEventContext(classificationContext)
+ * .setResultId(textSelection.getId())
+ * .setRelativeSuggestedWordStartIndex(-1) // Repeated from above.
+ * .setRelativeSuggestedWordEndIndex(2) // Repeated from above.
+ * .setRelativeWordStartIndex(0) // Original selection is always at (0, 1].
+ * .setRelativeWordEndIndex(1)
+ * .setEntityTypes(textClassification.getEntity(0))
+ * .setScore(textClassification.getConfidenceScore(entityType))
+ * .setEventIndex(2)
+ * .build();
+ *
+ * // User modified the selection. e.g. "New".
+ * new TextSelectionEvent.Builder(TYPE_SELECTION_MODIFIED)
+ * .setEventContext(classificationContext)
+ * .setResultId(textSelection.getId())
+ * .setRelativeSuggestedWordStartIndex(-1) // Repeated from above.
+ * .setRelativeSuggestedWordEndIndex(2) // Repeated from above.
+ * .setRelativeWordStartIndex(-1) // Goes backward one word from "York" to
+ * "New".
+ * .setRelativeWordEndIndex(0) // Goes backward one word to exclude "York".
+ * .setEntityTypes(textClassification.getEntity(0))
+ * .setScore(textClassification.getConfidenceScore(entityType))
+ * .setEventIndex(3)
+ * .build();
+ *
+ * // Smart (contextual) actions (at indices, 0, 1, 2) presented to the user.
+ * // e.g. "Map", "Ride share", "Explore".
+ * new TextSelectionEvent.Builder(TYPE_ACTIONS_SHOWN)
+ * .setEventContext(classificationContext)
+ * .setResultId(textClassification.getId())
+ * .setEntityTypes(textClassification.getEntity(0))
+ * .setScore(textClassification.getConfidenceScore(entityType))
+ * .setActionIndices(0, 1, 2)
+ * .setEventIndex(4)
+ * .build();
+ *
+ * // User chooses the "Copy" action.
+ * new TextSelectionEvent.Builder(TYPE_COPY_ACTION)
+ * .setEventContext(classificationContext)
+ * .setResultId(textClassification.getId())
+ * .setEntityTypes(textClassification.getEntity(0))
+ * .setScore(textClassification.getConfidenceScore(entityType))
+ * .setEventIndex(5)
+ * .build();
+ *
+ * // User chooses smart action at index 1. i.e. "Ride share".
+ * new TextSelectionEvent.Builder(TYPE_SMART_ACTION)
+ * .setEventContext(classificationContext)
+ * .setResultId(textClassification.getId())
+ * .setEntityTypes(textClassification.getEntity(0))
+ * .setScore(textClassification.getConfidenceScore(entityType))
+ * .setActionIndices(1)
+ * .setEventIndex(5)
+ * .build();
+ *
+ * // Selection dismissed.
+ * new TextSelectionEvent.Builder(TYPE_SELECTION_DESTROYED)
+ * .setEventContext(classificationContext)
+ * .setResultId(textClassification.getId())
+ * .setEntityTypes(textClassification.getEntity(0))
+ * .setScore(textClassification.getConfidenceScore(entityType))
+ * .setEventIndex(6)
+ * .build();
+ * </pre>
+ *
+ * <p>
+ */
+ public static final class TextSelectionEvent extends TextClassifierEvent {
+
+ final int relativeWordStartIndex;
+ final int relativeWordEndIndex;
+ final int relativeSuggestedWordStartIndex;
+ final int relativeSuggestedWordEndIndex;
+
+ private TextSelectionEvent(TextSelectionEvent.Builder builder) {
+ super(builder);
+ relativeWordStartIndex = builder.relativeWordStartIndex;
+ relativeWordEndIndex = builder.relativeWordEndIndex;
+ relativeSuggestedWordStartIndex = builder.relativeSuggestedWordStartIndex;
+ relativeSuggestedWordEndIndex = builder.relativeSuggestedWordEndIndex;
+ }
+
+ /** Returns the relative word index of the start of the selection. */
+ public int getRelativeWordStartIndex() {
+ return relativeWordStartIndex;
+ }
+
+ /** Returns the relative word (exclusive) index of the end of the selection. */
+ public int getRelativeWordEndIndex() {
+ return relativeWordEndIndex;
+ }
+
+ /** Returns the relative word index of the start of the smart selection. */
+ public int getRelativeSuggestedWordStartIndex() {
+ return relativeSuggestedWordStartIndex;
+ }
+
+ /** Returns the relative word (exclusive) index of the end of the smart selection. */
+ public int getRelativeSuggestedWordEndIndex() {
+ return relativeSuggestedWordEndIndex;
+ }
+
+ @Override
+ void toString(StringBuilder out) {
+ out.append(", getRelativeWordStartIndex=").append(relativeWordStartIndex);
+ out.append(", getRelativeWordEndIndex=").append(relativeWordEndIndex);
+ out.append(", getRelativeSuggestedWordStartIndex=").append(relativeSuggestedWordStartIndex);
+ out.append(", getRelativeSuggestedWordEndIndex=").append(relativeSuggestedWordEndIndex);
+ }
+
+ /** Builder class for {@link TextSelectionEvent}. */
+ public static final class Builder
+ extends TextClassifierEvent.Builder<TextSelectionEvent.Builder> {
+ int relativeWordStartIndex;
+ int relativeWordEndIndex;
+ int relativeSuggestedWordStartIndex;
+ int relativeSuggestedWordEndIndex;
+
+ /**
+ * Creates a builder for building {@link TextSelectionEvent}s.
+ *
+ * @param eventType The event type. e.g. {@link #TYPE_SELECTION_STARTED}
+ */
+ public Builder(@Type int eventType) {
+ super(CATEGORY_SELECTION, eventType);
+ }
+
+ /** Sets the relative word index of the start of the selection. */
+ @Nonnull
+ public Builder setRelativeWordStartIndex(int relativeWordStartIndex) {
+ this.relativeWordStartIndex = relativeWordStartIndex;
+ return this;
+ }
+
+ /** Sets the relative word (exclusive) index of the end of the selection. */
+ @Nonnull
+ public Builder setRelativeWordEndIndex(int relativeWordEndIndex) {
+ this.relativeWordEndIndex = relativeWordEndIndex;
+ return this;
+ }
+
+ /** Sets the relative word index of the start of the smart selection. */
+ @Nonnull
+ public Builder setRelativeSuggestedWordStartIndex(int relativeSuggestedWordStartIndex) {
+ this.relativeSuggestedWordStartIndex = relativeSuggestedWordStartIndex;
+ return this;
+ }
+
+ /** Sets the relative word (exclusive) index of the end of the smart selection. */
+ @Nonnull
+ public Builder setRelativeSuggestedWordEndIndex(int relativeSuggestedWordEndIndex) {
+ this.relativeSuggestedWordEndIndex = relativeSuggestedWordEndIndex;
+ return this;
+ }
+
+ @Override
+ TextSelectionEvent.Builder self() {
+ return this;
+ }
+
+ /** Builds and returns a {@link TextSelectionEvent}. */
+ @Nonnull
+ public TextSelectionEvent build() {
+ return new TextSelectionEvent(this);
+ }
+ }
+ }
+
+ /**
+ * This class represents events that are related to the smart linkify feature.
+ *
+ * <p>
+ *
+ * <pre>
+ * // User clicked on a link.
+ * new TextLinkifyEvent.Builder(TYPE_LINK_CLICKED)
+ * .setEventContext(classificationContext)
+ * .setResultId(textClassification.getId())
+ * .setEntityTypes(textClassification.getEntity(0))
+ * .setScore(textClassification.getConfidenceScore(entityType))
+ * .setEventIndex(0)
+ * .build();
+ *
+ * // Smart (contextual) actions presented to the user in response to a link click.
+ * new TextLinkifyEvent.Builder(TYPE_ACTIONS_SHOWN)
+ * .setEventContext(classificationContext)
+ * .setResultId(textClassification.getId())
+ * .setEntityTypes(textClassification.getEntity(0))
+ * .setScore(textClassification.getConfidenceScore(entityType))
+ * .setActionIndices(range(textClassification.getActions().size()))
+ * .setEventIndex(1)
+ * .build();
+ *
+ * // User chooses smart action at index 0.
+ * new TextLinkifyEvent.Builder(TYPE_SMART_ACTION)
+ * .setEventContext(classificationContext)
+ * .setResultId(textClassification.getId())
+ * .setEntityTypes(textClassification.getEntity(0))
+ * .setScore(textClassification.getConfidenceScore(entityType))
+ * .setActionIndices(0)
+ * .setEventIndex(2)
+ * .build();
+ * </pre>
+ */
+ public static final class TextLinkifyEvent extends TextClassifierEvent {
+
+ private TextLinkifyEvent(TextLinkifyEvent.Builder builder) {
+ super(builder);
+ }
+
+ /** Builder class for {@link TextLinkifyEvent}. */
+ public static final class Builder
+ extends TextClassifierEvent.Builder<TextLinkifyEvent.Builder> {
+ /**
+ * Creates a builder for building {@link TextLinkifyEvent}s.
+ *
+ * @param eventType The event type. e.g. {@link #TYPE_SMART_ACTION}
+ */
+ public Builder(@Type int eventType) {
+ super(TextClassifierEvent.CATEGORY_LINKIFY, eventType);
+ }
+
+ @Override
+ Builder self() {
+ return this;
+ }
+
+ /** Builds and returns a {@link TextLinkifyEvent}. */
+ @Nonnull
+ public TextLinkifyEvent build() {
+ return new TextLinkifyEvent(this);
+ }
+ }
+ }
+
+ /**
+ * This class represents events that are related to the language detection feature.
+ * <p>
+ * <pre>
+ * // Translate action shown for foreign text.
+ * new LanguageDetectionEvent.Builder(TYPE_ACTIONS_SHOWN)
+ * .setEventContext(classificationContext)
+ * .setResultId(textClassification.getId())
+ * .setEntityTypes(language)
+ * .setScore(score)
+ * .setActionIndices(textClassification.getActions().indexOf(translateAction))
+ * .setEventIndex(0)
+ * .build();
+ *
+ * // Translate action selected.
+ * new LanguageDetectionEvent.Builder(TYPE_SMART_ACTION)
+ * .setEventContext(classificationContext)
+ * .setResultId(textClassification.getId())
+ * .setEntityTypes(language)
+ * .setScore(score)
+ * .setActionIndices(textClassification.getActions().indexOf(translateAction))
+ * .setEventIndex(1)
+ * .build();
+ */
+ public static final class LanguageDetectionEvent extends TextClassifierEvent {
+
+ private LanguageDetectionEvent(LanguageDetectionEvent.Builder builder) {
+ super(builder);
+ }
+
+ /** Builder class for {@link LanguageDetectionEvent}. */
+ public static final class Builder
+ extends TextClassifierEvent.Builder<LanguageDetectionEvent.Builder> {
+
+ /**
+ * Creates a builder for building {@link LanguageDetectionEvent}s.
+ *
+ * @param eventType The event type. e.g. {@link #TYPE_SMART_ACTION}
+ */
+ public Builder(@Type int eventType) {
+ super(TextClassifierEvent.CATEGORY_LANGUAGE_DETECTION, eventType);
+ }
+
+ @Override
+ Builder self() {
+ return this;
+ }
+
+ /** Builds and returns a {@link LanguageDetectionEvent}. */
+ @Nonnull
+ public LanguageDetectionEvent build() {
+ return new LanguageDetectionEvent(this);
+ }
+ }
+ }
+
+ /**
+ * This class represents events that are related to the conversation actions feature.
+ *
+ * <p>
+ *
+ * <pre>
+ * // Conversation (contextual) actions/replies generated.
+ * new ConversationActionsEvent.Builder(TYPE_ACTIONS_GENERATED)
+ * .setEventContext(classificationContext)
+ * .setResultId(conversationActions.getId())
+ * .setEntityTypes(getTypes(conversationActions))
+ * .setActionIndices(range(conversationActions.getActions().size()))
+ * .setEventIndex(0)
+ * .build();
+ *
+ * // Conversation actions/replies presented to user.
+ * new ConversationActionsEvent.Builder(TYPE_ACTIONS_SHOWN)
+ * .setEventContext(classificationContext)
+ * .setResultId(conversationActions.getId())
+ * .setEntityTypes(getTypes(conversationActions))
+ * .setActionIndices(range(conversationActions.getActions().size()))
+ * .setEventIndex(1)
+ * .build();
+ *
+ * // User clicked the "Reply" button to compose their custom reply.
+ * new ConversationActionsEvent.Builder(TYPE_MANUAL_REPLY)
+ * .setEventContext(classificationContext)
+ * .setResultId(conversationActions.getId())
+ * .setEventIndex(2)
+ * .build();
+ *
+ * // User selected a smart (contextual) action/reply.
+ * new ConversationActionsEvent.Builder(TYPE_SMART_ACTION)
+ * .setEventContext(classificationContext)
+ * .setResultId(conversationActions.getId())
+ * .setEntityTypes(conversationActions.get(1).getType())
+ * .setScore(conversationAction.get(1).getConfidenceScore())
+ * .setActionIndices(1)
+ * .setEventIndex(2)
+ * .build();
+ * </pre>
+ */
+ public static final class ConversationActionsEvent extends TextClassifierEvent {
+
+ private ConversationActionsEvent(ConversationActionsEvent.Builder builder) {
+ super(builder);
+ }
+
+ /** Builder class for {@link ConversationActionsEvent}. */
+ public static final class Builder
+ extends TextClassifierEvent.Builder<ConversationActionsEvent.Builder> {
+ /**
+ * Creates a builder for building {@link ConversationActionsEvent}s.
+ *
+ * @param eventType The event type. e.g. {@link #TYPE_SMART_ACTION}
+ */
+ public Builder(@Type int eventType) {
+ super(TextClassifierEvent.CATEGORY_CONVERSATION_ACTIONS, eventType);
+ }
+
+ @Override
+ Builder self() {
+ return this;
+ }
+
+ /** Builds and returns a {@link ConversationActionsEvent}. */
+ @Nonnull
+ public ConversationActionsEvent build() {
+ return new ConversationActionsEvent(this);
+ }
+ }
+ }
+}
diff --git a/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
new file mode 100644
index 0000000..c132749
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/statsd/GenerateLinksLogger.java
@@ -0,0 +1,192 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import android.util.StatsEvent;
+import android.util.StatsLog;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextLinks;
+import androidx.collection.ArrayMap;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.android.textclassifier.common.logging.TextClassifierEvent;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Optional;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Random;
+import java.util.UUID;
+import java.util.function.Supplier;
+import javax.annotation.Nullable;
+
+/** A helper for logging calls to generateLinks. */
+public final class GenerateLinksLogger {
+
+ private static final String LOG_TAG = "GenerateLinksLogger";
+
+ private final Random random;
+ private final int sampleRate;
+ private final Supplier<String> randomUuidSupplier;
+
+ /**
+ * @param sampleRate the rate at which log events are written. (e.g. 100 means there is a 0.01
+ * chance that a call to logGenerateLinks results in an event being written). To write all
+ * events, pass 1.
+ */
+ public GenerateLinksLogger(int sampleRate) {
+ this(sampleRate, () -> UUID.randomUUID().toString());
+ }
+
+ /**
+ * @param sampleRate the rate at which log events are written. (e.g. 100 means there is a 0.01
+ * chance that a call to logGenerateLinks results in an event being written). To write all
+ * events, pass 1.
+ * @param randomUuidSupplier supplies random UUIDs.
+ */
+ @VisibleForTesting
+ GenerateLinksLogger(int sampleRate, Supplier<String> randomUuidSupplier) {
+ this.sampleRate = sampleRate;
+ random = new Random();
+ this.randomUuidSupplier = Preconditions.checkNotNull(randomUuidSupplier);
+ }
+
+ /** Logs statistics about a call to generateLinks. */
+ public void logGenerateLinks(
+ CharSequence text,
+ TextLinks links,
+ String callingPackageName,
+ long latencyMs,
+ Optional<ModelInfo> annotatorModel,
+ Optional<ModelInfo> langIdModel) {
+ Preconditions.checkNotNull(text);
+ Preconditions.checkNotNull(links);
+ Preconditions.checkNotNull(callingPackageName);
+ if (!shouldLog()) {
+ return;
+ }
+
+ // Always populate the total stats, and per-entity stats for each entity type detected.
+ final LinkifyStats totalStats = new LinkifyStats();
+ final Map<String, LinkifyStats> perEntityTypeStats = new ArrayMap<>();
+ for (TextLinks.TextLink link : links.getLinks()) {
+ if (link.getEntityCount() == 0) {
+ continue;
+ }
+ final String entityType = link.getEntity(0);
+ if (entityType == null
+ || TextClassifier.TYPE_OTHER.equals(entityType)
+ || TextClassifier.TYPE_UNKNOWN.equals(entityType)) {
+ continue;
+ }
+ totalStats.countLink(link);
+ perEntityTypeStats.computeIfAbsent(entityType, k -> new LinkifyStats()).countLink(link);
+ }
+
+ final String callId = randomUuidSupplier.get();
+ writeStats(
+ callId, callingPackageName, null, totalStats, text, latencyMs, annotatorModel, langIdModel);
+ // Sort the entity types to ensure the logging order is deterministic.
+ ImmutableList<String> sortedEntityTypes =
+ ImmutableList.sortedCopyOf(perEntityTypeStats.keySet());
+ for (String entityType : sortedEntityTypes) {
+ writeStats(
+ callId,
+ callingPackageName,
+ entityType,
+ perEntityTypeStats.get(entityType),
+ text,
+ latencyMs,
+ annotatorModel,
+ langIdModel);
+ }
+ }
+
+ /**
+ * Returns whether this particular event should be logged.
+ *
+ * <p>Sampling is used to reduce the amount of logging data generated.
+ */
+ private boolean shouldLog() {
+ if (sampleRate <= 1) {
+ return true;
+ } else {
+ return random.nextInt(sampleRate) == 0;
+ }
+ }
+
+ /** Writes a log event for the given stats. */
+ private static void writeStats(
+ String callId,
+ String callingPackageName,
+ @Nullable String entityType,
+ LinkifyStats stats,
+ CharSequence text,
+ long latencyMs,
+ Optional<ModelInfo> annotatorModel,
+ Optional<ModelInfo> langIdModel) {
+ String annotatorModelName = annotatorModel.transform(ModelInfo::toModelName).or("");
+ String langIdModelName = langIdModel.transform(ModelInfo::toModelName).or("");
+ StatsEvent statsEvent =
+ StatsEvent.newBuilder()
+ .setAtomId(TextClassifierEventLogger.TEXT_LINKIFY_EVENT_ATOM_ID)
+ .writeString(callId)
+ .writeInt(TextClassifierEvent.TYPE_LINKS_GENERATED)
+ .writeString(annotatorModelName)
+ .writeInt(TextClassifierEventLogger.WidgetType.WIDGET_TYPE_UNKNOWN)
+ .writeInt(/* eventIndex */ 0)
+ .writeString(entityType)
+ .writeInt(stats.numLinks)
+ .writeInt(stats.numLinksTextLength)
+ .writeInt(text.length())
+ .writeLong(latencyMs)
+ .writeString(callingPackageName)
+ .writeString(langIdModelName)
+ .usePooledBuffer()
+ .build();
+ StatsLog.write(statsEvent);
+
+ if (TcLog.ENABLE_FULL_LOGGING) {
+ TcLog.v(
+ LOG_TAG,
+ String.format(
+ Locale.US,
+ "%s:%s %d links (%d/%d chars) %dms %s annotator=%s langid=%s",
+ callId,
+ entityType,
+ stats.numLinks,
+ stats.numLinksTextLength,
+ text.length(),
+ latencyMs,
+ callingPackageName,
+ annotatorModelName,
+ langIdModelName));
+ }
+ }
+
+ /** Helper class for storing per-entity type statistics. */
+ private static final class LinkifyStats {
+ int numLinks;
+ int numLinksTextLength;
+
+ void countLink(TextLinks.TextLink link) {
+ numLinks += 1;
+ numLinksTextLength += link.getEnd() - link.getStart();
+ }
+ }
+}
diff --git a/java/src/com/android/textclassifier/common/statsd/SelectionEventConverter.java b/java/src/com/android/textclassifier/common/statsd/SelectionEventConverter.java
new file mode 100644
index 0000000..72ed968
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/statsd/SelectionEventConverter.java
@@ -0,0 +1,103 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import android.view.textclassifier.SelectionEvent;
+import android.view.textclassifier.TextClassificationContext;
+import android.view.textclassifier.TextClassifierEvent;
+import javax.annotation.Nullable;
+
+/** Helper class to convert a {@link SelectionEvent} to a {@link TextClassifierEvent}. */
+public final class SelectionEventConverter {
+
+  /** Converts a {@link SelectionEvent} to a {@link TextClassifierEvent}. */
+  @Nullable
+  public static TextClassifierEvent toTextClassifierEvent(SelectionEvent selectionEvent) {
+    TextClassificationContext textClassificationContext = null;
+    if (selectionEvent.getPackageName() != null && selectionEvent.getWidgetType() != null) {
+      textClassificationContext =
+          new TextClassificationContext.Builder(
+                  selectionEvent.getPackageName(), selectionEvent.getWidgetType())
+              .setWidgetVersion(selectionEvent.getWidgetVersion())
+              .build();
+    }
+    if (selectionEvent.getInvocationMethod() == SelectionEvent.INVOCATION_LINK) {
+      return new TextClassifierEvent.TextLinkifyEvent.Builder(
+              convertEventType(selectionEvent.getEventType()))
+          .setEventContext(textClassificationContext)
+          .setResultId(selectionEvent.getResultId())
+          .setEventIndex(selectionEvent.getEventIndex())
+          .setEntityTypes(selectionEvent.getEntityType())
+          .build();
+    }
+    if (selectionEvent.getInvocationMethod() == SelectionEvent.INVOCATION_MANUAL) {
+      return new TextClassifierEvent.TextSelectionEvent.Builder(
+              convertEventType(selectionEvent.getEventType()))
+          .setEventContext(textClassificationContext)
+          .setResultId(selectionEvent.getResultId())
+          .setEventIndex(selectionEvent.getEventIndex())
+          .setEntityTypes(selectionEvent.getEntityType())
+          .setRelativeWordStartIndex(selectionEvent.getStart())
+          .setRelativeWordEndIndex(selectionEvent.getEnd())
+          .setRelativeSuggestedWordStartIndex(selectionEvent.getSmartStart())
+          .setRelativeSuggestedWordEndIndex(selectionEvent.getSmartEnd())
+          .build();
+    }
+    return null; // other invocation methods are not converted
+  }
+
+  private static int convertEventType(int eventType) { // SelectionEvent type -> TextClassifierEvent type
+    switch (eventType) {
+      case SelectionEvent.EVENT_SELECTION_STARTED:
+        return TextClassifierEvent.TYPE_SELECTION_STARTED;
+      case SelectionEvent.EVENT_SELECTION_MODIFIED:
+        return TextClassifierEvent.TYPE_SELECTION_MODIFIED;
+      case SelectionEvent.EVENT_SMART_SELECTION_SINGLE:
+        return TextClassifierEvent.TYPE_SMART_SELECTION_SINGLE; // same value as the SelectionEvent constant; use TYPE_* for consistency
+      case SelectionEvent.EVENT_SMART_SELECTION_MULTI:
+        return TextClassifierEvent.TYPE_SMART_SELECTION_MULTI; // same value as the SelectionEvent constant; use TYPE_* for consistency
+      case SelectionEvent.EVENT_AUTO_SELECTION:
+        return TextClassifierEvent.TYPE_AUTO_SELECTION; // same value as the SelectionEvent constant; use TYPE_* for consistency
+      case SelectionEvent.ACTION_OVERTYPE:
+        return TextClassifierEvent.TYPE_OVERTYPE;
+      case SelectionEvent.ACTION_COPY:
+        return TextClassifierEvent.TYPE_COPY_ACTION;
+      case SelectionEvent.ACTION_PASTE:
+        return TextClassifierEvent.TYPE_PASTE_ACTION;
+      case SelectionEvent.ACTION_CUT:
+        return TextClassifierEvent.TYPE_CUT_ACTION;
+      case SelectionEvent.ACTION_SHARE:
+        return TextClassifierEvent.TYPE_SHARE_ACTION;
+      case SelectionEvent.ACTION_SMART_SHARE:
+        return TextClassifierEvent.TYPE_SMART_ACTION;
+      case SelectionEvent.ACTION_DRAG:
+        return TextClassifierEvent.TYPE_SELECTION_DRAG;
+      case SelectionEvent.ACTION_ABANDON:
+        return TextClassifierEvent.TYPE_SELECTION_DESTROYED;
+      case SelectionEvent.ACTION_OTHER:
+        return TextClassifierEvent.TYPE_OTHER_ACTION;
+      case SelectionEvent.ACTION_SELECT_ALL:
+        return TextClassifierEvent.TYPE_SELECT_ALL;
+      case SelectionEvent.ACTION_RESET:
+        return TextClassifierEvent.TYPE_SELECTION_RESET;
+      default:
+        return 0; // unknown event type
+    }
+  }
+
+  private SelectionEventConverter() {}
+}
diff --git a/java/src/com/android/textclassifier/common/statsd/TextClassificationSessionIdConverter.java b/java/src/com/android/textclassifier/common/statsd/TextClassificationSessionIdConverter.java
new file mode 100644
index 0000000..4364c0a
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/statsd/TextClassificationSessionIdConverter.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import com.android.textclassifier.common.logging.TextClassificationSessionId;
+import javax.annotation.Nullable;
+
+/**
+ * Converts between {@link TextClassificationSessionId} and {@link
+ * android.view.textclassifier.TextClassificationSessionId}.
+ */
+public final class TextClassificationSessionIdConverter {
+
+  private TextClassificationSessionIdConverter() {} // static utility; no instances
+
+  @Nullable // null in -> null out
+  public static TextClassificationSessionId fromPlatform(
+      @Nullable android.view.textclassifier.TextClassificationSessionId sessionId) {
+    if (sessionId == null) {
+      return null;
+    }
+    return TextClassificationSessionId.unflattenFromString(sessionId.getValue()); // rebuild from the flattened string form
+  }
+}
diff --git a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventConverter.java b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventConverter.java
new file mode 100644
index 0000000..bcfb012
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventConverter.java
@@ -0,0 +1,135 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.logging.TextClassificationContext;
+import com.android.textclassifier.common.logging.TextClassifierEvent;
+import com.android.textclassifier.common.logging.TextClassifierEvent.ConversationActionsEvent;
+import com.android.textclassifier.common.logging.TextClassifierEvent.LanguageDetectionEvent;
+import com.android.textclassifier.common.logging.TextClassifierEvent.TextLinkifyEvent;
+import com.android.textclassifier.common.logging.TextClassifierEvent.TextSelectionEvent;
+import javax.annotation.Nullable;
+
+/**
+ * Converts between {@link TextClassifierEvent} and {@link
+ * android.view.textclassifier.TextClassifierEvent}.
+ */
+public final class TextClassifierEventConverter {
+  private static final String TAG = "TextClassifierEventConv";
+
+  /**
+   * Converts a {@link android.view.textclassifier.TextClassifierEvent} object to a {@link
+   * TextClassifierEvent}. Returns {@code null} if conversion fails.
+   */
+  @Nullable
+  public static TextClassifierEvent fromPlatform(
+      @Nullable android.view.textclassifier.TextClassifierEvent textClassifierEvent) {
+    if (textClassifierEvent == null) {
+      return null;
+    }
+    if (textClassifierEvent
+        instanceof android.view.textclassifier.TextClassifierEvent.TextSelectionEvent) { // dispatch to the subtype-specific overloads below
+      return fromPlatform(
+          (android.view.textclassifier.TextClassifierEvent.TextSelectionEvent) textClassifierEvent);
+    } else if (textClassifierEvent
+        instanceof android.view.textclassifier.TextClassifierEvent.TextLinkifyEvent) {
+      return fromPlatform(
+          (android.view.textclassifier.TextClassifierEvent.TextLinkifyEvent) textClassifierEvent);
+    } else if (textClassifierEvent
+        instanceof android.view.textclassifier.TextClassifierEvent.ConversationActionsEvent) {
+      return fromPlatform(
+          (android.view.textclassifier.TextClassifierEvent.ConversationActionsEvent)
+              textClassifierEvent);
+    } else if (textClassifierEvent
+        instanceof android.view.textclassifier.TextClassifierEvent.LanguageDetectionEvent) {
+      return fromPlatform(
+          (android.view.textclassifier.TextClassifierEvent.LanguageDetectionEvent)
+              textClassifierEvent);
+    }
+    TcLog.w(TAG, "Unexpected event: " + textClassifierEvent); // unknown subtype -> null
+    return null;
+  }
+
+  private static TextSelectionEvent fromPlatform(
+      android.view.textclassifier.TextClassifierEvent.TextSelectionEvent textSelectionEvent) {
+    TextSelectionEvent.Builder builder =
+        new TextSelectionEvent.Builder(textSelectionEvent.getEventType());
+    copyCommonFields(textSelectionEvent, builder);
+    return builder
+        .setRelativeWordStartIndex(textSelectionEvent.getRelativeWordStartIndex())
+        .setRelativeWordEndIndex(textSelectionEvent.getRelativeWordEndIndex())
+        .setRelativeSuggestedWordStartIndex(textSelectionEvent.getRelativeSuggestedWordStartIndex())
+        .setRelativeSuggestedWordEndIndex(textSelectionEvent.getRelativeSuggestedWordEndIndex())
+        .build();
+  }
+
+  private static TextLinkifyEvent fromPlatform(
+      android.view.textclassifier.TextClassifierEvent.TextLinkifyEvent textLinkifyEvent) {
+    TextLinkifyEvent.Builder builder =
+        new TextLinkifyEvent.Builder(textLinkifyEvent.getEventType()); // no subtype-specific fields beyond the common ones
+    copyCommonFields(textLinkifyEvent, builder);
+    return builder.build();
+  }
+
+  private static ConversationActionsEvent fromPlatform(
+      android.view.textclassifier.TextClassifierEvent.ConversationActionsEvent
+          conversationActionsEvent) {
+    ConversationActionsEvent.Builder builder =
+        new ConversationActionsEvent.Builder(conversationActionsEvent.getEventType());
+    copyCommonFields(conversationActionsEvent, builder);
+    return builder.build();
+  }
+
+  private static LanguageDetectionEvent fromPlatform(
+      android.view.textclassifier.TextClassifierEvent.LanguageDetectionEvent
+          languageDetectionEvent) {
+    LanguageDetectionEvent.Builder builder =
+        new LanguageDetectionEvent.Builder(languageDetectionEvent.getEventType());
+    copyCommonFields(languageDetectionEvent, builder);
+    return builder.build();
+  }
+
+  @Nullable
+  private static TextClassificationContext fromPlatform(
+      @Nullable android.view.textclassifier.TextClassificationContext textClassificationContext) {
+    if (textClassificationContext == null) {
+      return null;
+    }
+    return new TextClassificationContext.Builder(
+            textClassificationContext.getPackageName(), textClassificationContext.getWidgetType())
+        .setWidgetVersion(textClassificationContext.getWidgetVersion())
+        .build();
+  }
+
+  private static void copyCommonFields( // fields shared by every event subtype
+      android.view.textclassifier.TextClassifierEvent sourceEvent,
+      TextClassifierEvent.Builder<?> destBuilder) {
+    destBuilder
+        .setActionIndices(sourceEvent.getActionIndices())
+        .setEventContext(fromPlatform(sourceEvent.getEventContext()))
+        .setEntityTypes(sourceEvent.getEntityTypes())
+        .setEventIndex(sourceEvent.getEventIndex())
+        .setExtras(sourceEvent.getExtras())
+        .setLocale(sourceEvent.getLocale() == null ? null : sourceEvent.getLocale().toLocale()) // ULocale -> java.util.Locale
+        .setModelName(sourceEvent.getModelName())
+        .setResultId(sourceEvent.getResultId())
+        .setScores(sourceEvent.getScores());
+  }
+
+  private TextClassifierEventConverter() {}
+}
diff --git a/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
new file mode 100644
index 0000000..41f546c
--- /dev/null
+++ b/java/src/com/android/textclassifier/common/statsd/TextClassifierEventLogger.java
@@ -0,0 +1,278 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import static com.google.common.base.Charsets.UTF_8;
+import static com.google.common.base.Strings.nullToEmpty;
+
+import android.util.StatsEvent;
+import android.util.StatsLog;
+import android.view.textclassifier.TextClassifier;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.common.logging.ResultIdUtils;
+import com.android.textclassifier.common.logging.TextClassificationContext;
+import com.android.textclassifier.common.logging.TextClassificationSessionId;
+import com.android.textclassifier.common.logging.TextClassifierEvent;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.hash.Hashing;
+import java.util.List;
+import javax.annotation.Nullable;
+
+/** Logs {@link android.view.textclassifier.TextClassifierEvent}. */
+public final class TextClassifierEventLogger {
+ private static final String TAG = "TCEventLogger";
+ // These constants are defined in atoms.proto.
+ private static final int TEXT_SELECTION_EVENT_ATOM_ID = 219;
+ static final int TEXT_LINKIFY_EVENT_ATOM_ID = 220;
+ private static final int CONVERSATION_ACTIONS_EVENT_ATOM_ID = 221;
+ private static final int LANGUAGE_DETECTION_EVENT_ATOM_ID = 222;
+
+  /** Routes the event to the statsd logger matching its subtype; unknown subtypes are dropped with a warning. */
+  public void writeEvent(
+      @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) {
+    Preconditions.checkNotNull(event); // sessionId may be null, event may not
+    if (TcLog.ENABLE_FULL_LOGGING) {
+      TcLog.v(
+          TAG,
+          String.format(
+              "TextClassifierEventLogger.writeEvent: sessionId=%s,event=%s", sessionId, event));
+    }
+    if (event instanceof TextClassifierEvent.TextSelectionEvent) {
+      logTextSelectionEvent(sessionId, (TextClassifierEvent.TextSelectionEvent) event);
+    } else if (event instanceof TextClassifierEvent.TextLinkifyEvent) {
+      logTextLinkifyEvent(sessionId, (TextClassifierEvent.TextLinkifyEvent) event);
+    } else if (event instanceof TextClassifierEvent.ConversationActionsEvent) {
+      logConversationActionsEvent(sessionId, (TextClassifierEvent.ConversationActionsEvent) event);
+    } else if (event instanceof TextClassifierEvent.LanguageDetectionEvent) {
+      logLanguageDetectionEvent(sessionId, (TextClassifierEvent.LanguageDetectionEvent) event);
+    } else {
+      TcLog.w(TAG, "Unexpected events, category=" + event.getEventCategory()); // unknown subtype: warn and drop
+    }
+  }
+
+  private static void logTextSelectionEvent( // writes the TEXT_SELECTION_EVENT atom
+      @Nullable TextClassificationSessionId sessionId,
+      TextClassifierEvent.TextSelectionEvent event) {
+    ImmutableList<String> modelNames = getModelNames(event);
+    StatsEvent statsEvent =
+        StatsEvent.newBuilder()
+            .setAtomId(TEXT_SELECTION_EVENT_ATOM_ID)
+            .writeString(sessionId == null ? null : sessionId.getValue())
+            .writeInt(getEventType(event))
+            .writeString(getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null)) // NOTE(review): index 0/1 presumed annotator/langid models — confirm against atoms.proto
+            .writeInt(getWidgetType(event))
+            .writeInt(event.getEventIndex())
+            .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
+            .writeInt(event.getRelativeWordStartIndex())
+            .writeInt(event.getRelativeWordEndIndex())
+            .writeInt(event.getRelativeSuggestedWordStartIndex())
+            .writeInt(event.getRelativeSuggestedWordEndIndex())
+            .writeString(getPackageName(event))
+            .writeString(getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null))
+            .usePooledBuffer()
+            .build();
+    StatsLog.write(statsEvent);
+  }
+
+  private static int getEventType(TextClassifierEvent.TextSelectionEvent event) { // refines TYPE_AUTO_SELECTION into smart-selection single/multi
+    if (event.getEventType() == TextClassifierEvent.TYPE_AUTO_SELECTION) {
+      if (ResultIdUtils.isFromDefaultTextClassifier(event.getResultId())) {
+        return event.getRelativeWordEndIndex() - event.getRelativeWordStartIndex() > 1 // selection spans more than one word
+            ? TextClassifierEvent.TYPE_SMART_SELECTION_MULTI
+            : TextClassifierEvent.TYPE_SMART_SELECTION_SINGLE;
+      }
+    }
+    return event.getEventType(); // all other types pass through unchanged
+  }
+
+  private static void logTextLinkifyEvent( // writes the TEXT_LINKIFY_EVENT atom
+      @Nullable TextClassificationSessionId sessionId, TextClassifierEvent.TextLinkifyEvent event) { // @Nullable added: null is handled below, matching the sibling loggers
+    ImmutableList<String> modelNames = getModelNames(event);
+    StatsEvent statsEvent =
+        StatsEvent.newBuilder()
+            .setAtomId(TEXT_LINKIFY_EVENT_ATOM_ID)
+            .writeString(sessionId == null ? null : sessionId.getValue())
+            .writeInt(event.getEventType())
+            .writeString(getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null))
+            .writeInt(getWidgetType(event))
+            .writeInt(event.getEventIndex())
+            .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
+            .writeInt(/* numOfLinks */ 0)
+            .writeInt(/* linkedTextLength */ 0)
+            .writeInt(/* textLength */ 0)
+            .writeLong(/* latencyInMillis */ 0L)
+            .writeString(getPackageName(event))
+            .writeString(getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null))
+            .usePooledBuffer()
+            .build();
+    StatsLog.write(statsEvent);
+  }
+
+  private static void logConversationActionsEvent( // writes the CONVERSATION_ACTIONS_EVENT atom
+      @Nullable TextClassificationSessionId sessionId,
+      TextClassifierEvent.ConversationActionsEvent event) {
+    String resultId = nullToEmpty(event.getResultId());
+    ImmutableList<String> modelNames = ResultIdUtils.getModelNames(resultId);
+    StatsEvent statsEvent =
+        StatsEvent.newBuilder()
+            .setAtomId(CONVERSATION_ACTIONS_EVENT_ATOM_ID)
+            // TODO: Update ExtServices to set the session id.
+            .writeString(
+                sessionId == null
+                    ? Hashing.goodFastHash(64).hashString(resultId, UTF_8).toString() // NOTE(review): goodFastHash is seeded per process, so this fallback id is not stable across runs — confirm intended
+                    : sessionId.getValue())
+            .writeInt(event.getEventType())
+            .writeString(getItemAt(modelNames, /* index= */ 0, /* defaultValue= */ null))
+            .writeInt(getWidgetType(event))
+            .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0))
+            .writeString(getItemAt(event.getEntityTypes(), /* index= */ 1))
+            .writeString(getItemAt(event.getEntityTypes(), /* index= */ 2))
+            .writeFloat(getFloatAt(event.getScores(), /* index= */ 0))
+            .writeString(getPackageName(event))
+            .writeString(getItemAt(modelNames, /* index= */ 1, /* defaultValue= */ null))
+            .writeString(getItemAt(modelNames, /* index= */ 2, /* defaultValue= */ null))
+            .usePooledBuffer()
+            .build();
+    StatsLog.write(statsEvent);
+  }
+
+  private static void logLanguageDetectionEvent( // writes the LANGUAGE_DETECTION_EVENT atom
+      @Nullable TextClassificationSessionId sessionId,
+      TextClassifierEvent.LanguageDetectionEvent event) {
+    StatsEvent statsEvent =
+        StatsEvent.newBuilder()
+            .setAtomId(LANGUAGE_DETECTION_EVENT_ATOM_ID)
+            .writeString(sessionId == null ? null : sessionId.getValue())
+            .writeInt(event.getEventType())
+            .writeString(getItemAt(getModelNames(event), /* index= */ 0, /* defaultValue= */ null))
+            .writeInt(getWidgetType(event))
+            .writeString(getItemAt(event.getEntityTypes(), /* index= */ 0)) // top detected language
+            .writeFloat(getFloatAt(event.getScores(), /* index= */ 0)) // score of the top language
+            .writeInt(getIntAt(event.getActionIndices(), /* index= */ 0))
+            .writeString(getPackageName(event))
+            .usePooledBuffer()
+            .build();
+    StatsLog.write(statsEvent);
+  }
+
+  @Nullable
+  private static <T> T getItemAt(List<T> list, int index, T defaultValue) {
+    // A null list or an out-of-range index both mean "absent": log the default.
+    if (list == null || index >= list.size()) {
+      return defaultValue;
+    }
+    // Call sites only pass small literal indices (0/1/2), so the single
+    // upper-bound check above is sufficient.
+    return list.get(index);
+  }
+
+  @Nullable
+  private static <T> T getItemAt(@Nullable T[] array, int index) {
+    // Entity-type arrays may be absent; there is no default-value overload
+    // for arrays, so "absent" is simply represented by null.
+    if (array == null || index >= array.length) {
+      return null;
+    }
+    // Bounds checked above; callers only pass small literal indices.
+    return array[index];
+  }
+
+  private static float getFloatAt(@Nullable float[] array, int index) {
+    // Score arrays may be absent; log 0f for missing data instead of throwing.
+    if (array == null || index >= array.length) {
+      return 0f;
+    }
+    // Bounds checked above; all call sites pass the literal index 0,
+    // so no separate negative-index guard is needed.
+    return array[index];
+  }
+
+  private static int getIntAt(@Nullable int[] array, int index) {
+    // Action-index arrays may be absent; log 0 for missing data.
+    if (array == null || index >= array.length) {
+      return 0;
+    }
+    // Bounds checked above; all call sites pass the literal index 0,
+    // so no separate negative-index guard is needed.
+    return array[index];
+  }
+
+  private static ImmutableList<String> getModelNames(TextClassifierEvent event) {
+    if (event.getModelName() != null) {
+      return ImmutableList.of(event.getModelName()); // an explicitly set model name wins
+    }
+    return ResultIdUtils.getModelNames(event.getResultId()); // otherwise parse names out of the result id
+  }
+
+  @Nullable
+  private static String getPackageName(TextClassifierEvent event) {
+    TextClassificationContext eventContext = event.getEventContext();
+    if (eventContext == null) {
+      return null; // no event context -> package unknown
+    }
+    return eventContext.getPackageName();
+  }
+
+  private static int getWidgetType(TextClassifierEvent event) { // platform widget-type string -> WidgetType logging constant
+    TextClassificationContext eventContext = event.getEventContext();
+    if (eventContext == null) {
+      return WidgetType.WIDGET_TYPE_UNKNOWN;
+    }
+    switch (eventContext.getWidgetType()) {
+      case TextClassifier.WIDGET_TYPE_UNKNOWN:
+        return WidgetType.WIDGET_TYPE_UNKNOWN;
+      case TextClassifier.WIDGET_TYPE_TEXTVIEW:
+        return WidgetType.WIDGET_TYPE_TEXTVIEW;
+      case TextClassifier.WIDGET_TYPE_EDITTEXT:
+        return WidgetType.WIDGET_TYPE_EDITTEXT;
+      case TextClassifier.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW:
+        return WidgetType.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW;
+      case TextClassifier.WIDGET_TYPE_WEBVIEW:
+        return WidgetType.WIDGET_TYPE_WEBVIEW;
+      case TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW:
+        return WidgetType.WIDGET_TYPE_EDIT_WEBVIEW;
+      case TextClassifier.WIDGET_TYPE_CUSTOM_TEXTVIEW:
+        return WidgetType.WIDGET_TYPE_CUSTOM_TEXTVIEW;
+      case TextClassifier.WIDGET_TYPE_CUSTOM_EDITTEXT:
+        return WidgetType.WIDGET_TYPE_CUSTOM_EDITTEXT;
+      case TextClassifier.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW:
+        return WidgetType.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW;
+      case TextClassifier.WIDGET_TYPE_NOTIFICATION:
+        return WidgetType.WIDGET_TYPE_NOTIFICATION;
+      default: // fall out
+    }
+    return WidgetType.WIDGET_TYPE_UNKNOWN; // unrecognized widget-type string
+  }
+
+  /** Widget type constants for logging. */
+  public static final class WidgetType { // values are written into statsd atoms; do not renumber
+    // Sync these constants with textclassifier_enums.proto.
+    public static final int WIDGET_TYPE_UNKNOWN = 0;
+    public static final int WIDGET_TYPE_TEXTVIEW = 1;
+    public static final int WIDGET_TYPE_EDITTEXT = 2;
+    public static final int WIDGET_TYPE_UNSELECTABLE_TEXTVIEW = 3;
+    public static final int WIDGET_TYPE_WEBVIEW = 4;
+    public static final int WIDGET_TYPE_EDIT_WEBVIEW = 5;
+    public static final int WIDGET_TYPE_CUSTOM_TEXTVIEW = 6;
+    public static final int WIDGET_TYPE_CUSTOM_EDITTEXT = 7;
+    public static final int WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW = 8;
+    public static final int WIDGET_TYPE_NOTIFICATION = 9;
+
+    private WidgetType() {} // constants only; no instances
+  }
+}
diff --git a/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java b/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java
new file mode 100644
index 0000000..bd48c22
--- /dev/null
+++ b/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.utils;
+
+import com.google.common.base.Preconditions;
+import java.io.PrintWriter;
+
+/**
+ * A print writer that supports indentation.
+ *
+ * @see PrintWriter
+ */
+public final class IndentingPrintWriter {
+  static final String SINGLE_INDENT = "  ";
+
+  private final PrintWriter writer;
+  private final StringBuilder indentBuilder = new StringBuilder(); // holds N copies of SINGLE_INDENT
+  private String currentIndent = ""; // cached indentBuilder.toString()
+
+  public IndentingPrintWriter(PrintWriter writer) {
+    this.writer = Preconditions.checkNotNull(writer);
+  }
+
+  /** Prints the string on its own line, prefixed with the current indent. */
+  public IndentingPrintWriter println(String string) {
+    writer.print(currentIndent);
+    writer.print(string);
+    writer.println();
+    return this;
+  }
+
+  /** Prints an empty line (no indent). */
+  public IndentingPrintWriter println() {
+    writer.println();
+    return this;
+  }
+
+  /** Increases indents for subsequent texts. */
+  public IndentingPrintWriter increaseIndent() {
+    indentBuilder.append(SINGLE_INDENT);
+    currentIndent = indentBuilder.toString();
+    return this;
+  }
+
+  /** Decreases indents for subsequent texts. */
+  public IndentingPrintWriter decreaseIndent() {
+    indentBuilder.delete(0, SINGLE_INDENT.length()); // StringBuilder.delete clamps, so this is a no-op at zero indent
+    currentIndent = indentBuilder.toString();
+    return this;
+  }
+
+  /** Prints an indented "key=value" line. */
+  public IndentingPrintWriter printPair(String key, Object value) {
+    println(String.format("%s=%s", key, String.valueOf(value)));
+    return this;
+  }
+
+  /** Flushes the stream. */
+  public void flush() {
+    writer.flush();
+  }
+}
diff --git a/java/tests/instrumentation/Android.bp b/java/tests/instrumentation/Android.bp
new file mode 100644
index 0000000..15ec570
--- /dev/null
+++ b/java/tests/instrumentation/Android.bp
@@ -0,0 +1,56 @@
+//
+// Copyright (C) 2019 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+android_test {
+ name: "TextClassifierServiceTest",
+
+ manifest: "AndroidManifest.xml",
+
+ srcs: [
+ "src/**/*.java",
+ ],
+
+ static_libs: [
+ "androidx.test.ext.junit",
+ "androidx.test.rules",
+ "androidx.test.espresso.core",
+ "androidx.test.ext.truth",
+ "mockito-target-minus-junit4",
+ "ub-uiautomator",
+ "testng",
+ "compatibility-device-util-axt",
+ "androidx.room_room-runtime",
+ "TextClassifierServiceLib",
+ "statsdprotolite",
+ "textclassifierprotoslite",
+ ],
+
+ jni_libs: [
+ "libtextclassifier",
+ "libdexmakerjvmtiagent"
+ ],
+
+ test_suites: [
+ "device-tests", "mts"
+ ],
+
+ plugins: ["androidx.room_room-compiler-plugin",],
+ platform_apis: true,
+ use_embedded_native_libs: true,
+ compile_multilib: "both",
+
+ instrumentation_for: "TextClassifierService",
+}
\ No newline at end of file
diff --git a/java/tests/instrumentation/AndroidManifest.xml b/java/tests/instrumentation/AndroidManifest.xml
new file mode 100644
index 0000000..4964caf
--- /dev/null
+++ b/java/tests/instrumentation/AndroidManifest.xml
@@ -0,0 +1,15 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.android.textclassifier.tests">
+
+ <uses-sdk android:minSdkVersion="29" android:targetSdkVersion="30"/>
+ <uses-permission android:name="android.permission.QUERY_ALL_PACKAGES" />
+
+ <application>
+ <uses-library android:name="android.test.runner"/>
+ </application>
+
+ <instrumentation
+ android:name="androidx.test.runner.AndroidJUnitRunner"
+ android:targetPackage="com.android.textclassifier.tests"/>
+</manifest>
diff --git a/java/tests/instrumentation/AndroidTest.xml b/java/tests/instrumentation/AndroidTest.xml
new file mode 100644
index 0000000..e02a338
--- /dev/null
+++ b/java/tests/instrumentation/AndroidTest.xml
@@ -0,0 +1,33 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- Copyright (C) 2020 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<!-- This test config file is auto-generated. -->
+<configuration description="Runs TextClassifierServiceTest.">
+ <option name="test-suite-tag" value="apct" />
+ <option name="test-suite-tag" value="apct-instrumentation" />
+ <target_preparer class="com.android.tradefed.targetprep.suite.SuiteApkInstaller">
+ <option name="cleanup-apks" value="true" />
+ <option name="test-file-name" value="TextClassifierServiceTest.apk" />
+ </target_preparer>
+
+ <test class="com.android.tradefed.testtype.AndroidJUnitTest" >
+ <option name="package" value="com.android.textclassifier.tests" />
+ <option name="runner" value="androidx.test.runner.AndroidJUnitRunner" />
+ </test>
+
+ <object type="module_controller" class="com.android.tradefed.testtype.suite.module.MainlineTestModuleController">
+ <option name="mainline-module-package-name" value="com.google.android.extservices" />
+ </object>
+</configuration>
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
new file mode 100644
index 0000000..59dc41a
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
@@ -0,0 +1,306 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static android.view.textclassifier.ConversationActions.Message.PERSON_USER_OTHERS;
+import static android.view.textclassifier.ConversationActions.Message.PERSON_USER_SELF;
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.PendingIntent;
+import android.app.Person;
+import android.app.RemoteAction;
+import android.content.ComponentName;
+import android.content.Intent;
+import android.graphics.drawable.Icon;
+import android.net.Uri;
+import android.os.Bundle;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.common.intent.LabeledIntent;
+import com.android.textclassifier.common.intent.TemplateIntentFactory;
+import com.google.android.textclassifier.ActionsSuggestionsModel;
+import com.google.android.textclassifier.RemoteActionTemplate;
+import com.google.common.collect.ImmutableList;
+import java.time.Instant;
+import java.time.ZoneId;
+import java.time.ZonedDateTime;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.function.Function;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class ActionsSuggestionsHelperTest {
+ private static final String LOCALE_TAG = Locale.US.toLanguageTag();
+ private static final Function<CharSequence, List<String>> LANGUAGE_DETECTOR =
+ charSequence -> Collections.singletonList(LOCALE_TAG);
+
+ @Test
+ public void testToNativeMessages_emptyInput() {
+ ActionsSuggestionsModel.ConversationMessage[] conversationMessages =
+ ActionsSuggestionsHelper.toNativeMessages(ImmutableList.of(), LANGUAGE_DETECTOR);
+
+ assertThat(conversationMessages).isEmpty();
+ }
+
+ @Test
+ public void testToNativeMessages_noTextMessages() {
+ ConversationActions.Message messageWithoutText =
+ new ConversationActions.Message.Builder(PERSON_USER_OTHERS).build();
+
+ ActionsSuggestionsModel.ConversationMessage[] conversationMessages =
+ ActionsSuggestionsHelper.toNativeMessages(
+ ImmutableList.of(messageWithoutText), LANGUAGE_DETECTOR);
+
+ assertThat(conversationMessages).isEmpty();
+ }
+
+ @Test
+ public void testToNativeMessages_userIdEncoding() {
+ Person.Builder userA = new Person.Builder().setName("userA").setKey("A");
+ Person.Builder userB = new Person.Builder().setName("userB").setKey("B");
+
+ ConversationActions.Message firstMessage =
+ new ConversationActions.Message.Builder(userB.build()).setText("first").build();
+ ConversationActions.Message secondMessage =
+ new ConversationActions.Message.Builder(userA.build()).setText("second").build();
+ ConversationActions.Message thirdMessage =
+ new ConversationActions.Message.Builder(PERSON_USER_SELF).setText("third").build();
+ ConversationActions.Message fourthMessage =
+ new ConversationActions.Message.Builder(userA.build()).setText("fourth").build();
+
+ ActionsSuggestionsModel.ConversationMessage[] conversationMessages =
+ ActionsSuggestionsHelper.toNativeMessages(
+ Arrays.asList(firstMessage, secondMessage, thirdMessage, fourthMessage),
+ LANGUAGE_DETECTOR);
+
+ assertThat(conversationMessages).hasLength(4);
+ assertNativeMessage(conversationMessages[0], firstMessage.getText(), 2, 0);
+ assertNativeMessage(conversationMessages[1], secondMessage.getText(), 1, 0);
+ assertNativeMessage(conversationMessages[2], thirdMessage.getText(), 0, 0);
+ assertNativeMessage(conversationMessages[3], fourthMessage.getText(), 1, 0);
+ }
+
+ @Test
+ public void testToNativeMessages_referenceTime() {
+ ConversationActions.Message firstMessage =
+ new ConversationActions.Message.Builder(PERSON_USER_OTHERS)
+ .setText("first")
+ .setReferenceTime(createZonedDateTimeFromMsUtc(1000))
+ .build();
+ ConversationActions.Message secondMessage =
+ new ConversationActions.Message.Builder(PERSON_USER_OTHERS).setText("second").build();
+ ConversationActions.Message thirdMessage =
+ new ConversationActions.Message.Builder(PERSON_USER_OTHERS)
+ .setText("third")
+ .setReferenceTime(createZonedDateTimeFromMsUtc(2000))
+ .build();
+
+ ActionsSuggestionsModel.ConversationMessage[] conversationMessages =
+ ActionsSuggestionsHelper.toNativeMessages(
+ Arrays.asList(firstMessage, secondMessage, thirdMessage), LANGUAGE_DETECTOR);
+
+ assertThat(conversationMessages).hasLength(3);
+ assertNativeMessage(conversationMessages[0], firstMessage.getText(), 1, 1000);
+ assertNativeMessage(conversationMessages[1], secondMessage.getText(), 1, 0);
+ assertNativeMessage(conversationMessages[2], thirdMessage.getText(), 1, 2000);
+ }
+
+ @Test
+ public void testDeduplicateActions() {
+ Bundle phoneExtras = new Bundle();
+ Intent phoneIntent = new Intent();
+ phoneIntent.setComponent(new ComponentName("phone", "intent"));
+ ExtrasUtils.putActionIntent(phoneExtras, phoneIntent);
+
+ Bundle anotherPhoneExtras = new Bundle();
+ Intent anotherPhoneIntent = new Intent();
+ anotherPhoneIntent.setComponent(new ComponentName("phone", "another.intent"));
+ ExtrasUtils.putActionIntent(anotherPhoneExtras, anotherPhoneIntent);
+
+ Bundle urlExtras = new Bundle();
+ Intent urlIntent = new Intent();
+ urlIntent.setComponent(new ComponentName("url", "intent"));
+ ExtrasUtils.putActionIntent(urlExtras, urlIntent);
+
+ PendingIntent pendingIntent =
+ PendingIntent.getActivity(ApplicationProvider.getApplicationContext(), 0, phoneIntent, 0);
+ Icon icon = Icon.createWithData(new byte[0], 0, 0);
+ ConversationAction action =
+ new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
+ .setAction(new RemoteAction(icon, "label", "1", pendingIntent))
+ .setExtras(phoneExtras)
+ .build();
+ ConversationAction actionWithSameLabel =
+ new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
+ .setAction(new RemoteAction(icon, "label", "2", pendingIntent))
+ .setExtras(phoneExtras)
+ .build();
+ ConversationAction actionWithSamePackageButDifferentClass =
+ new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
+ .setAction(new RemoteAction(icon, "label", "3", pendingIntent))
+ .setExtras(anotherPhoneExtras)
+ .build();
+ ConversationAction actionWithDifferentLabel =
+ new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
+ .setAction(new RemoteAction(icon, "another_label", "4", pendingIntent))
+ .setExtras(phoneExtras)
+ .build();
+ ConversationAction actionWithDifferentPackage =
+ new ConversationAction.Builder(ConversationAction.TYPE_OPEN_URL)
+ .setAction(new RemoteAction(icon, "label", "5", pendingIntent))
+ .setExtras(urlExtras)
+ .build();
+ ConversationAction actionWithoutRemoteAction =
+ new ConversationAction.Builder(ConversationAction.TYPE_CREATE_REMINDER).build();
+
+ List<ConversationAction> conversationActions =
+ ActionsSuggestionsHelper.removeActionsWithDuplicates(
+ Arrays.asList(
+ action,
+ actionWithSameLabel,
+ actionWithSamePackageButDifferentClass,
+ actionWithDifferentLabel,
+ actionWithDifferentPackage,
+ actionWithoutRemoteAction));
+
+ assertThat(conversationActions).hasSize(3);
+ assertThat(conversationActions.get(0).getAction().getContentDescription().toString())
+ .isEqualTo("4");
+ assertThat(conversationActions.get(1).getAction().getContentDescription().toString())
+ .isEqualTo("5");
+ assertThat(conversationActions.get(2).getAction()).isNull();
+ }
+
+ @Test
+ public void testDeduplicateActions_nullComponent() {
+ Bundle phoneExtras = new Bundle();
+ Intent phoneIntent = new Intent(Intent.ACTION_DIAL);
+ ExtrasUtils.putActionIntent(phoneExtras, phoneIntent);
+ PendingIntent pendingIntent =
+ PendingIntent.getActivity(ApplicationProvider.getApplicationContext(), 0, phoneIntent, 0);
+ Icon icon = Icon.createWithData(new byte[0], 0, 0);
+ ConversationAction action =
+ new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
+ .setAction(new RemoteAction(icon, "label", "1", pendingIntent))
+ .setExtras(phoneExtras)
+ .build();
+ ConversationAction actionWithSameLabel =
+ new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
+ .setAction(new RemoteAction(icon, "label", "2", pendingIntent))
+ .setExtras(phoneExtras)
+ .build();
+
+ List<ConversationAction> conversationActions =
+ ActionsSuggestionsHelper.removeActionsWithDuplicates(
+ Arrays.asList(action, actionWithSameLabel));
+
+ assertThat(conversationActions).isEmpty();
+ }
+
+ @Test
+ public void createLabeledIntentResult_null() {
+ ActionsSuggestionsModel.ActionSuggestion nativeSuggestion =
+ new ActionsSuggestionsModel.ActionSuggestion(
+ "text", ConversationAction.TYPE_OPEN_URL, 1.0f, null, null, null);
+
+ LabeledIntent.Result labeledIntentResult =
+ ActionsSuggestionsHelper.createLabeledIntentResult(
+ ApplicationProvider.getApplicationContext(),
+ new TemplateIntentFactory(),
+ nativeSuggestion);
+
+ assertThat(labeledIntentResult).isNull();
+ }
+
+ @Test
+ public void createLabeledIntentResult_emptyList() {
+ ActionsSuggestionsModel.ActionSuggestion nativeSuggestion =
+ new ActionsSuggestionsModel.ActionSuggestion(
+ "text",
+ ConversationAction.TYPE_OPEN_URL,
+ 1.0f,
+ null,
+ null,
+ new RemoteActionTemplate[0]);
+
+ LabeledIntent.Result labeledIntentResult =
+ ActionsSuggestionsHelper.createLabeledIntentResult(
+ ApplicationProvider.getApplicationContext(),
+ new TemplateIntentFactory(),
+ nativeSuggestion);
+
+ assertThat(labeledIntentResult).isNull();
+ }
+
+ @Test
+ public void createLabeledIntentResult() {
+ ActionsSuggestionsModel.ActionSuggestion nativeSuggestion =
+ new ActionsSuggestionsModel.ActionSuggestion(
+ "text",
+ ConversationAction.TYPE_OPEN_URL,
+ 1.0f,
+ null,
+ null,
+ new RemoteActionTemplate[] {
+ new RemoteActionTemplate(
+ "title",
+ null,
+ "description",
+ null,
+ Intent.ACTION_VIEW,
+ Uri.parse("http://www.android.com").toString(),
+ null,
+ 0,
+ null,
+ null,
+ null,
+ 0)
+ });
+
+ LabeledIntent.Result labeledIntentResult =
+ ActionsSuggestionsHelper.createLabeledIntentResult(
+ ApplicationProvider.getApplicationContext(),
+ new TemplateIntentFactory(),
+ nativeSuggestion);
+
+ assertThat(labeledIntentResult.remoteAction.getTitle().toString()).isEqualTo("title");
+ assertThat(labeledIntentResult.resolvedIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
+ }
+
+ private static ZonedDateTime createZonedDateTimeFromMsUtc(long msUtc) {
+ return ZonedDateTime.ofInstant(Instant.ofEpochMilli(msUtc), ZoneId.of("UTC"));
+ }
+
+ private static void assertNativeMessage(
+ ActionsSuggestionsModel.ConversationMessage nativeMessage,
+ CharSequence text,
+ int userId,
+ long referenceTimeInMsUtc) {
+ assertThat(nativeMessage.getText()).isEqualTo(text.toString());
+ assertThat(nativeMessage.getUserId()).isEqualTo(userId);
+ assertThat(nativeMessage.getDetectedTextLanguageTags()).isEqualTo(LOCALE_TAG);
+ assertThat(nativeMessage.getReferenceTimeMsUtc()).isEqualTo(referenceTimeInMsUtc);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
new file mode 100644
index 0000000..06d47d6
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
@@ -0,0 +1,385 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.when;
+
+import android.os.LocaleList;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.google.common.base.Optional;
+import com.google.common.collect.ImmutableList;
+import java.io.File;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class ModelFileManagerTest {
+ private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
+ @Mock private Supplier<ImmutableList<ModelFile>> modelFileSupplier;
+ private ModelFileManager.ModelFileSupplierImpl modelFileSupplierImpl;
+ private ModelFileManager modelFileManager;
+ private File rootTestDir;
+ private File factoryModelDir;
+ private File updatedModelFile;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+ modelFileManager = new ModelFileManager(modelFileSupplier);
+ rootTestDir = ApplicationProvider.getApplicationContext().getCacheDir();
+ factoryModelDir = new File(rootTestDir, "factory");
+ updatedModelFile = new File(rootTestDir, "updated.model");
+
+ modelFileSupplierImpl =
+ new ModelFileManager.ModelFileSupplierImpl(
+ factoryModelDir,
+ "test\\d.model",
+ updatedModelFile,
+ fd -> 1,
+ fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT);
+
+ rootTestDir.mkdirs();
+ factoryModelDir.mkdirs();
+
+ Locale.setDefault(DEFAULT_LOCALE);
+ }
+
+ @After
+ public void removeTestDir() {
+ recursiveDelete(rootTestDir);
+ }
+
+ @Test
+ public void get() {
+ ModelFileManager.ModelFile modelFile =
+ new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
+ when(modelFileSupplier.get()).thenReturn(ImmutableList.of(modelFile));
+
+ List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles();
+
+ assertThat(modelFiles).hasSize(1);
+ assertThat(modelFiles.get(0)).isEqualTo(modelFile);
+ }
+
+ @Test
+ public void findBestModel_versionCode() {
+ ModelFileManager.ModelFile olderModelFile =
+ new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
+
+ ModelFileManager.ModelFile newerModelFile =
+ new ModelFileManager.ModelFile(new File("/path/b"), 2, ImmutableList.of(), "", true);
+ when(modelFileSupplier.get()).thenReturn(ImmutableList.of(olderModelFile, newerModelFile));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList());
+
+ assertThat(bestModelFile).isEqualTo(newerModelFile);
+ }
+
+ @Test
+ public void findBestModel_languageDependentModelIsPreferred() {
+ Locale locale = Locale.forLanguageTag("ja");
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
+
+ ModelFileManager.ModelFile languageDependentModelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"),
+ 1,
+ Collections.singletonList(locale),
+ locale.toLanguageTag(),
+ false);
+ when(modelFileSupplier.get())
+ .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(LocaleList.forLanguageTags(locale.toLanguageTag()));
+ assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_noMatchedLanguageModel() {
+ Locale locale = Locale.forLanguageTag("ja");
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ new ModelFileManager.ModelFile(new File("/path/a"), 1, Collections.emptyList(), "", true);
+
+ ModelFileManager.ModelFile languageDependentModelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"),
+ 1,
+ Collections.singletonList(locale),
+ locale.toLanguageTag(),
+ false);
+
+ when(modelFileSupplier.get())
+ .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(LocaleList.forLanguageTags("zh-hk"));
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_noMatchedLanguageModel_defaultLocaleModelExists() {
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
+
+ ModelFileManager.ModelFile languageDependentModelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"),
+ 1,
+ Collections.singletonList(DEFAULT_LOCALE),
+ DEFAULT_LOCALE.toLanguageTag(),
+ false);
+
+ when(modelFileSupplier.get())
+ .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(LocaleList.forLanguageTags("zh-hk"));
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_languageIsMoreImportantThanVersion() {
+ ModelFileManager.ModelFile matchButOlderModel =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("fr")),
+ "fr",
+ false);
+
+ ModelFileManager.ModelFile mismatchButNewerModel =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"),
+ 2,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ when(modelFileSupplier.get())
+ .thenReturn(ImmutableList.of(matchButOlderModel, mismatchButNewerModel));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(LocaleList.forLanguageTags("fr"));
+ assertThat(bestModelFile).isEqualTo(matchButOlderModel);
+ }
+
+ @Test
+ public void findBestModel_languageIsMoreImportantThanVersion_bestModelComesFirst() {
+ ModelFileManager.ModelFile matchLocaleModel =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ ModelFileManager.ModelFile languageIndependentModel =
+ new ModelFileManager.ModelFile(new File("/path/a"), 2, ImmutableList.of(), "", true);
+ when(modelFileSupplier.get())
+ .thenReturn(ImmutableList.of(matchLocaleModel, languageIndependentModel));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(LocaleList.forLanguageTags("ja"));
+
+ assertThat(bestModelFile).isEqualTo(matchLocaleModel);
+ }
+
+ @Test
+ public void modelFileEquals() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ assertThat(modelA).isEqualTo(modelB);
+ }
+
+ @Test
+ public void modelFile_different() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ assertThat(modelA).isNotEqualTo(modelB);
+ }
+
+ @Test
+ public void modelFile_getPath() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ assertThat(modelA.getPath()).isEqualTo("/path/a");
+ }
+
+ @Test
+ public void modelFile_getName() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ assertThat(modelA.getName()).isEqualTo("a");
+ }
+
+ @Test
+ public void modelFile_isPreferredTo_languageDependentIsBetter() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(new File("/path/b"), 2, ImmutableList.of(), "", true);
+
+ assertThat(modelA.isPreferredTo(modelB)).isTrue();
+ }
+
+ @Test
+ public void modelFile_isPreferredTo_version() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 2,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(new File("/path/b"), 1, Collections.emptyList(), "", false);
+
+ assertThat(modelA.isPreferredTo(modelB)).isTrue();
+ }
+
+ @Test
+ public void modelFile_toModelInfo() {
+ ModelFileManager.ModelFile modelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
+
+ ModelInfo modelInfo = modelFile.toModelInfo();
+
+ assertThat(modelInfo.toModelName()).isEqualTo("ja_v2");
+ }
+
+ @Test
+ public void modelFile_toModelInfos() {
+ ModelFile englishModelFile =
+ new ModelFile(new File("/path/a"), 1, ImmutableList.of(Locale.ENGLISH), "en", false);
+ ModelFile japaneseModelFile =
+ new ModelFile(new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
+
+ ImmutableList<Optional<ModelInfo>> modelInfos =
+ ModelFileManager.ModelFile.toModelInfos(
+ Optional.of(englishModelFile), Optional.of(japaneseModelFile));
+
+ assertThat(
+ modelInfos.stream()
+ .map(modelFile -> modelFile.transform(ModelInfo::toModelName).or(""))
+ .collect(Collectors.toList()))
+ .containsExactly("en_v1", "ja_v2")
+ .inOrder();
+ }
+
+ @Test
+ public void testFileSupplierImpl_updatedFileOnly() throws IOException {
+ updatedModelFile.createNewFile();
+ File model1 = new File(factoryModelDir, "test1.model");
+ model1.createNewFile();
+ File model2 = new File(factoryModelDir, "test2.model");
+ model2.createNewFile();
+ new File(factoryModelDir, "not_match_regex.model").createNewFile();
+
+ List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
+ List<String> modelFilePaths =
+ modelFiles.stream().map(modelFile -> modelFile.getPath()).collect(Collectors.toList());
+
+ assertThat(modelFiles).hasSize(3);
+ assertThat(modelFilePaths)
+ .containsExactly(
+ updatedModelFile.getAbsolutePath(), model1.getAbsolutePath(), model2.getAbsolutePath());
+ }
+
+ @Test
+ public void testFileSupplierImpl_empty() {
+ factoryModelDir.delete();
+ List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
+
+ assertThat(modelFiles).hasSize(0);
+ }
+
+ private static void recursiveDelete(File f) {
+ if (f.isDirectory()) {
+ for (File innerFile : f.listFiles()) {
+ recursiveDelete(innerFile);
+ }
+ }
+ f.delete();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
new file mode 100644
index 0000000..6d80673
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -0,0 +1,653 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.hamcrest.CoreMatchers.not;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.testng.Assert.assertThrows;
+
+import android.app.RemoteAction;
+import android.content.Context;
+import android.content.Intent;
+import android.net.Uri;
+import android.os.Bundle;
+import android.os.LocaleList;
+import android.text.Spannable;
+import android.text.SpannableString;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextLanguage;
+import android.view.textclassifier.TextLinks;
+import android.view.textclassifier.TextSelection;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.testing.FakeContextBuilder;
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import org.hamcrest.BaseMatcher;
+import org.hamcrest.Description;
+import org.hamcrest.Matcher;
+import org.junit.Before;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassifierImplTest {
+
+ private static final String TYPE_COPY = "copy";
+ private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US");
+ private static final String NO_TYPE = null;
+
+ private TextClassifierImpl classifier;
+
+ @Before
+ public void setup() {
+ Context context =
+ new FakeContextBuilder()
+ .setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT)
+ .setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app")
+ .build();
+ classifier = new TextClassifierImpl(context, new TextClassifierSettings());
+ }
+
+ @Test
+ public void testSuggestSelection() {
+ String text = "Contact me at droid@android.com";
+ String selected = "droid";
+ String suggested = "droid@android.com";
+ int startIndex = text.indexOf(selected);
+ int endIndex = startIndex + selected.length();
+ int smartStartIndex = text.indexOf(suggested);
+ int smartEndIndex = smartStartIndex + suggested.length();
+ TextSelection.Request request =
+ new TextSelection.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextSelection selection = classifier.suggestSelection(request);
+ assertThat(
+ selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL));
+ }
+
+ @Test
+ public void testSuggestSelection_url() {
+ String text = "Visit http://www.android.com for more information";
+ String selected = "http";
+ String suggested = "http://www.android.com";
+ int startIndex = text.indexOf(selected);
+ int endIndex = startIndex + selected.length();
+ int smartStartIndex = text.indexOf(suggested);
+ int smartEndIndex = smartStartIndex + suggested.length();
+ TextSelection.Request request =
+ new TextSelection.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextSelection selection = classifier.suggestSelection(request);
+ assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
+ }
+
+ @Test
+ public void testSmartSelection_withEmoji() {
+ String text = "\uD83D\uDE02 Hello.";
+ String selected = "Hello";
+ int startIndex = text.indexOf(selected);
+ int endIndex = startIndex + selected.length();
+ TextSelection.Request request =
+ new TextSelection.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextSelection selection = classifier.suggestSelection(request);
+ assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE));
+ }
+
+ @Test
+ public void testClassifyText() {
+ String text = "Contact me at droid@android.com";
+ String classifiedText = "droid@android.com";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = classifier.classifyText(request);
+ assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL));
+ }
+
+ @Test
+ public void testClassifyText_url() {
+ String text = "Visit www.android.com for more information";
+ String classifiedText = "www.android.com";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = classifier.classifyText(request);
+ assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
+ assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
+ }
+
+ @Test
+ public void testClassifyText_address() {
+ String text = "Brandschenkestrasse 110, Zürich, Switzerland";
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, 0, text.length())
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = classifier.classifyText(request);
+ assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
+ }
+
+ @Test
+ public void testClassifyText_url_inCaps() {
+ String text = "Visit HTTP://ANDROID.COM for more information";
+ String classifiedText = "HTTP://ANDROID.COM";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = classifier.classifyText(request);
+ assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
+ assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
+ }
+
+ @Test
+ public void testClassifyText_date() {
+ String text = "Let's meet on January 9, 2018.";
+ String classifiedText = "January 9, 2018";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = classifier.classifyText(request);
+ assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
+ Bundle extras = classification.getExtras();
+ List<Bundle> entities = ExtrasUtils.getEntities(extras);
+ assertThat(entities).hasSize(1);
+ assertThat(ExtrasUtils.getEntityType(entities.get(0))).isEqualTo(TextClassifier.TYPE_DATE);
+ ArrayList<Intent> actionsIntents = ExtrasUtils.getActionsIntents(classification);
+ actionsIntents.forEach(TextClassifierImplTest::assertNoPackageInfoInExtras);
+ }
+
+ @Test
+ public void testClassifyText_datetime() {
+ String text = "Let's meet 2018/01/01 10:30:20.";
+ String classifiedText = "2018/01/01 10:30:20";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = classifier.classifyText(request);
+ assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
+ }
+
+ // TODO(tonymak): Enable it once we drop the v8 image to Android. I have already run this test
+ // after pushing a test model to a device manually.
+ @Ignore
+ @Test
+ public void testClassifyText_foreignText() {
+ LocaleList originalLocales = LocaleList.getDefault();
+ LocaleList.setDefault(LocaleList.forLanguageTags("en"));
+ String japaneseText = "これは日本語のテキストです";
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length())
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = classifier.classifyText(request);
+ RemoteAction translateAction = classification.getActions().get(0);
+ assertEquals(1, classification.getActions().size());
+ assertEquals("Translate", translateAction.getTitle().toString());
+
+ assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification));
+ Intent intent = ExtrasUtils.getActionsIntents(classification).get(0);
+ assertNoPackageInfoInExtras(intent);
+ assertEquals(Intent.ACTION_TRANSLATE, intent.getAction());
+ Bundle foreignLanguageInfo = ExtrasUtils.getForeignLanguageExtra(classification);
+ assertEquals("ja", ExtrasUtils.getEntityType(foreignLanguageInfo));
+ assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) >= 0);
+ assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) <= 1);
+ assertTrue(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER));
+ assertEquals("ja", ExtrasUtils.getTopLanguage(intent).first);
+
+ LocaleList.setDefault(originalLocales);
+ }
+
+ @Test
+ public void testGenerateLinks_phone() {
+ String text = "The number is +12122537077. See you tonight!";
+ TextLinks.Request request = new TextLinks.Request.Builder(text).build();
+ assertThat(
+ classifier.generateLinks(request),
+ isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE));
+ }
+
+ @Test
+ public void testGenerateLinks_exclude() {
+ String text = "You want apple@banana.com. See you tonight!";
+ List<String> hints = ImmutableList.of();
+ List<String> included = ImmutableList.of();
+ List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
+ TextLinks.Request request =
+ new TextLinks.Request.Builder(text)
+ .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
+ .setDefaultLocales(LOCALES)
+ .build();
+ assertThat(
+ classifier.generateLinks(request),
+ not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
+ }
+
+ @Test
+ public void testGenerateLinks_explicit_address() {
+ String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
+ List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS);
+ TextLinks.Request request =
+ new TextLinks.Request.Builder(text)
+ .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit))
+ .setDefaultLocales(LOCALES)
+ .build();
+ assertThat(
+ classifier.generateLinks(request),
+ isTextLinksContaining(
+ text, "1600 Amphitheater Parkway, Mountain View, CA", TextClassifier.TYPE_ADDRESS));
+ }
+
+ @Test
+ public void testGenerateLinks_exclude_override() {
+ String text = "You want apple@banana.com. See you tonight!";
+ List<String> hints = ImmutableList.of();
+ List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL);
+ List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
+ TextLinks.Request request =
+ new TextLinks.Request.Builder(text)
+ .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
+ .setDefaultLocales(LOCALES)
+ .build();
+ assertThat(
+ classifier.generateLinks(request),
+ not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
+ }
+
+ @Test
+ public void testGenerateLinks_maxLength() {
+ char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()];
+ Arrays.fill(manySpaces, ' ');
+ TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
+ TextLinks links = classifier.generateLinks(request);
+ assertTrue(links.getLinks().isEmpty());
+ }
+
+ @Test
+ public void testApplyLinks_unsupportedCharacter() {
+ Spannable url = new SpannableString("\u202Emoc.diordna.com");
+ TextLinks.Request request = new TextLinks.Request.Builder(url).build();
+ assertEquals(
+ TextLinks.STATUS_UNSUPPORTED_CHARACTER,
+ classifier.generateLinks(request).apply(url, 0, null));
+ }
+
+ @Test
+ public void testGenerateLinks_tooLong() {
+ char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1];
+ Arrays.fill(manySpaces, ' ');
+ TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
+ assertThrows(IllegalArgumentException.class, () -> classifier.generateLinks(request));
+ }
+
+ @Test
+ public void testGenerateLinks_entityData() {
+ String text = "The number is +12122537077.";
+ Bundle extras = new Bundle();
+ ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
+ TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build();
+
+ TextLinks textLinks = classifier.generateLinks(request);
+
+ assertThat(textLinks.getLinks()).hasSize(1);
+ TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
+ List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
+ assertThat(entities).hasSize(1);
+ Bundle entity = entities.get(0);
+ assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE);
+ }
+
+ @Test
+ public void testGenerateLinks_entityData_disabled() {
+ String text = "The number is +12122537077.";
+ TextLinks.Request request = new TextLinks.Request.Builder(text).build();
+
+ TextLinks textLinks = classifier.generateLinks(request);
+
+ assertThat(textLinks.getLinks()).hasSize(1);
+ TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
+ List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
+ assertThat(entities).isNull();
+ }
+
+ @Test
+ public void testDetectLanguage() {
+ String text = "This is English text";
+ TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
+ TextLanguage textLanguage = classifier.detectLanguage(request);
+ assertThat(textLanguage, isTextLanguage("en"));
+ }
+
+ @Test
+ public void testDetectLanguage_japanese() {
+ String text = "これは日本語のテキストです";
+ TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
+ TextLanguage textLanguage = classifier.detectLanguage(request);
+ assertThat(textLanguage, isTextLanguage("ja"));
+ }
+
+ @Ignore // Doesn't work without a language-based model.
+ @Test
+ public void testSuggestConversationActions_textReplyOnly_maxOne() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Where are you?")
+ .build();
+ TextClassifier.EntityConfig typeConfig =
+ new TextClassifier.EntityConfig.Builder()
+ .includeTypesFromTextClassifier(false)
+ .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(Collections.singletonList(message))
+ .setMaxSuggestions(1)
+ .setTypeConfig(typeConfig)
+ .build();
+
+ ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ assertThat(conversationActions.getConversationActions()).hasSize(1);
+ ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
+ assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_TEXT_REPLY);
+ assertThat(conversationAction.getTextReply()).isNotNull();
+ }
+
+ @Ignore // Doesn't work without a language-based model.
+ @Test
+ public void testSuggestConversationActions_textReplyOnly_noMax() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Where are you?")
+ .build();
+ TextClassifier.EntityConfig typeConfig =
+ new TextClassifier.EntityConfig.Builder()
+ .includeTypesFromTextClassifier(false)
+ .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(Collections.singletonList(message))
+ .setTypeConfig(typeConfig)
+ .build();
+
+ ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ assertTrue(conversationActions.getConversationActions().size() > 1);
+ for (ConversationAction conversationAction : conversationActions.getConversationActions()) {
+ assertThat(conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY));
+ }
+ }
+
+ @Test
+ public void testSuggestConversationActions_openUrl() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Check this out: https://www.android.com")
+ .build();
+ TextClassifier.EntityConfig typeConfig =
+ new TextClassifier.EntityConfig.Builder()
+ .includeTypesFromTextClassifier(false)
+ .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(Collections.singletonList(message))
+ .setMaxSuggestions(1)
+ .setTypeConfig(typeConfig)
+ .build();
+
+ ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ assertThat(conversationActions.getConversationActions()).hasSize(1);
+ ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
+ assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
+ Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras());
+ assertThat(actionIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
+ assertThat(actionIntent.getData()).isEqualTo(Uri.parse("https://www.android.com"));
+ assertNoPackageInfoInExtras(actionIntent);
+ }
+
+ @Ignore // Doesn't work without a language-based model.
+ @Test
+ public void testSuggestConversationActions_copy() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Authentication code: 12345")
+ .build();
+ TextClassifier.EntityConfig typeConfig =
+ new TextClassifier.EntityConfig.Builder()
+ .includeTypesFromTextClassifier(false)
+ .setIncludedTypes(Collections.singletonList(TYPE_COPY))
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(Collections.singletonList(message))
+ .setMaxSuggestions(1)
+ .setTypeConfig(typeConfig)
+ .build();
+
+ ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ assertThat(conversationActions.getConversationActions()).hasSize(1);
+ ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
+ assertThat(conversationAction.getType()).isEqualTo(TYPE_COPY);
+ assertThat(conversationAction.getTextReply()).isAnyOf(null, "");
+ assertThat(conversationAction.getAction()).isNull();
+ String code = ExtrasUtils.getCopyText(conversationAction.getExtras());
+ assertThat(code).isEqualTo("12345");
+ assertThat(ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty();
+ }
+
+ @Test
+ public void testSuggestConversationActions_deduplicate() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("a@android.com b@android.com")
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(Collections.singletonList(message))
+ .setMaxSuggestions(3)
+ .build();
+
+ ConversationActions conversationActions = classifier.suggestConversationActions(request);
+
+ assertThat(conversationActions.getConversationActions()).isEmpty();
+ }
+
+ private static void assertNoPackageInfoInExtras(Intent intent) {
+ assertThat(intent.getComponent()).isNull();
+ assertThat(intent.getPackage()).isNull();
+ }
+
+ private static Matcher<TextSelection> isTextSelection(
+ final int startIndex, final int endIndex, final String type) {
+ return new BaseMatcher<TextSelection>() {
+ @Override
+ public boolean matches(Object o) {
+ if (o instanceof TextSelection) {
+ TextSelection selection = (TextSelection) o;
+ return startIndex == selection.getSelectionStartIndex()
+ && endIndex == selection.getSelectionEndIndex()
+ && typeMatches(selection, type);
+ }
+ return false;
+ }
+
+ private boolean typeMatches(TextSelection selection, String type) {
+ return type == null
+ || (selection.getEntityCount() > 0
+ && type.trim().equalsIgnoreCase(selection.getEntity(0)));
+ }
+
+ @Override
+ public void describeTo(Description description) {
+ description.appendValue(String.format("%d, %d, %s", startIndex, endIndex, type));
+ }
+ };
+ }
+
+ private static Matcher<TextLinks> isTextLinksContaining(
+ final String text, final String substring, final String type) {
+ return new BaseMatcher<TextLinks>() {
+
+ @Override
+ public void describeTo(Description description) {
+ description
+ .appendText("text=")
+ .appendValue(text)
+ .appendText(", substring=")
+ .appendValue(substring)
+ .appendText(", type=")
+ .appendValue(type);
+ }
+
+ @Override
+ public boolean matches(Object o) {
+ if (o instanceof TextLinks) {
+ for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) {
+ if (text.subSequence(link.getStart(), link.getEnd()).toString().equals(substring)) {
+ return type.equals(link.getEntity(0));
+ }
+ }
+ }
+ return false;
+ }
+ };
+ }
+
+ private static Matcher<TextClassification> isTextClassification(
+ final String text, final String type) {
+ return new BaseMatcher<TextClassification>() {
+ @Override
+ public boolean matches(Object o) {
+ if (o instanceof TextClassification) {
+ TextClassification result = (TextClassification) o;
+ return text.equals(result.getText())
+ && result.getEntityCount() > 0
+ && type.equals(result.getEntity(0));
+ }
+ return false;
+ }
+
+ @Override
+ public void describeTo(Description description) {
+ description.appendText("text=").appendValue(text).appendText(", type=").appendValue(type);
+ }
+ };
+ }
+
+ private static Matcher<TextClassification> containsIntentWithAction(final String action) {
+ return new BaseMatcher<TextClassification>() {
+ @Override
+ public boolean matches(Object o) {
+ if (o instanceof TextClassification) {
+ TextClassification result = (TextClassification) o;
+ return ExtrasUtils.findAction(result, action) != null;
+ }
+ return false;
+ }
+
+ @Override
+ public void describeTo(Description description) {
+ description.appendText("intent action=").appendValue(action);
+ }
+ };
+ }
+
+ private static Matcher<TextLanguage> isTextLanguage(final String languageTag) {
+ return new BaseMatcher<TextLanguage>() {
+ @Override
+ public boolean matches(Object o) {
+ if (o instanceof TextLanguage) {
+ TextLanguage result = (TextLanguage) o;
+ return result.getLocaleHypothesisCount() > 0
+ && languageTag.equals(result.getLocale(0).toLanguageTag());
+ }
+ return false;
+ }
+
+ @Override
+ public void describeTo(Description description) {
+ description.appendText("locale=").appendValue(languageTag);
+ }
+ };
+ }
+
+ private static Matcher<ConversationAction> isConversationAction(String actionType) {
+ return new BaseMatcher<ConversationAction>() {
+ @Override
+ public boolean matches(Object o) {
+ if (!(o instanceof ConversationAction)) {
+ return false;
+ }
+ ConversationAction conversationAction = (ConversationAction) o;
+ if (!actionType.equals(conversationAction.getType())) {
+ return false;
+ }
+ if (ConversationAction.TYPE_TEXT_REPLY.equals(actionType)) {
+ if (conversationAction.getTextReply() == null) {
+ return false;
+ }
+ }
+ if (conversationAction.getConfidenceScore() < 0
+ || conversationAction.getConfidenceScore() > 1) {
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public void describeTo(Description description) {
+ description.appendText("actionType=").appendValue(actionType);
+ }
+ };
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
new file mode 100644
index 0000000..21ed0b6
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
@@ -0,0 +1,112 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.provider.DeviceConfig;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import androidx.test.platform.app.InstrumentationRegistry;
+import java.util.function.Consumer;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassifierSettingsTest {
+ private static final String WRITE_DEVICE_CONFIG_PERMISSION =
+ "android.permission.WRITE_DEVICE_CONFIG";
+ private static final float EPSILON = 0.0001f;
+
+ @Before
+ public void setup() {
+ InstrumentationRegistry.getInstrumentation()
+ .getUiAutomation()
+ .adoptShellPermissionIdentity(WRITE_DEVICE_CONFIG_PERMISSION);
+ }
+
+ @After
+ public void tearDown() {
+ InstrumentationRegistry.getInstrumentation().getUiAutomation().dropShellPermissionIdentity();
+ }
+
+ @Test
+ public void booleanSetting() {
+ assertSettings(
+ TextClassifierSettings.TEMPLATE_INTENT_FACTORY_ENABLED,
+ "false",
+ settings -> assertThat(settings.isTemplateIntentFactoryEnabled()).isFalse());
+ }
+
+ @Test
+ public void intSetting() {
+ assertSettings(
+ TextClassifierSettings.SUGGEST_SELECTION_MAX_RANGE_LENGTH,
+ "8",
+ settings -> assertThat(settings.getSuggestSelectionMaxRangeLength()).isEqualTo(8));
+ }
+
+ @Test
+ public void floatSetting() {
+ assertSettings(
+ TextClassifierSettings.LANG_ID_THRESHOLD_OVERRIDE,
+ "3.14",
+ settings -> assertThat(settings.getLangIdThresholdOverride()).isWithin(EPSILON).of(3.14f));
+ }
+
+ @Test
+ public void stringListSetting() {
+ assertSettings(
+ TextClassifierSettings.ENTITY_LIST_DEFAULT,
+ "email:url",
+ settings ->
+ assertThat(settings.getEntityListDefault()).containsExactly("email", "url").inOrder());
+ }
+
+ @Test
+ public void floatListSetting() {
+ assertSettings(
+ TextClassifierSettings.LANG_ID_CONTEXT_SETTINGS,
+ "30:0.5:0.3",
+ settings ->
+ assertThat(settings.getLangIdContextSettings())
+ .usingTolerance(EPSILON)
+ .containsExactly(30f, 0.5f, 0.3f)
+ .inOrder());
+ }
+
+ private static void assertSettings(
+ String key, String value, Consumer<TextClassifierSettings> settingsConsumer) {
+ final String originalValue =
+ DeviceConfig.getProperty(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key);
+ TextClassifierSettings settings = new TextClassifierSettings();
+ try {
+ setDeviceConfig(key, value);
+ settingsConsumer.accept(settings);
+ } finally {
+ setDeviceConfig(key, originalValue);
+ }
+ }
+
+ private static void setDeviceConfig(String key, String value) {
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, value, /* makeDefault */ false);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/base/LocaleCompatTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/base/LocaleCompatTest.java
new file mode 100644
index 0000000..9e1f5a8
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/base/LocaleCompatTest.java
@@ -0,0 +1,92 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.base;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.os.LocaleList;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SdkSuppress;
+import androidx.test.filters.SmallTest;
+import java.util.Locale;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class LocaleCompatTest {
+ @SdkSuppress(minSdkVersion = 24)
+ @Test
+ public void toLanguageTag_minApi24() {
+ Locale locale = Locale.TRADITIONAL_CHINESE;
+
+ String languageTags = LocaleCompat.toLanguageTag(locale);
+
+ assertThat(languageTags).isEqualTo("zh-TW");
+ }
+
+ @SdkSuppress(maxSdkVersion = 23)
+ @Test
+ public void toLanguageTag_base() {
+ Locale locale = Locale.TRADITIONAL_CHINESE;
+
+ String languageTags = LocaleCompat.toLanguageTag(locale);
+
+ assertThat(languageTags).isEqualTo("zh");
+ }
+
+ @SdkSuppress(minSdkVersion = 24)
+ @Test
+ public void getResourceLanguageTags_minApi24() {
+ ApplicationProvider.getApplicationContext()
+ .getResources()
+ .getConfiguration()
+ .setLocales(LocaleList.forLanguageTags("zh-TW"));
+
+ String resourceLanguageTags =
+ LocaleCompat.getResourceLanguageTags(ApplicationProvider.getApplicationContext());
+
+ assertThat(resourceLanguageTags).isEqualTo("zh-TW");
+ }
+
+ @SdkSuppress(minSdkVersion = 21, maxSdkVersion = 23)
+ @Test
+ public void getResourceLanguageTags_minApi21() {
+ ApplicationProvider.getApplicationContext()
+ .getResources()
+ .getConfiguration()
+ .setLocale(Locale.TAIWAN);
+
+ String resourceLanguageTags =
+ LocaleCompat.getResourceLanguageTags(ApplicationProvider.getApplicationContext());
+
+ assertThat(resourceLanguageTags).isEqualTo("zh-TW");
+ }
+
+ @SdkSuppress(maxSdkVersion = 20)
+ @Test
+ public void getResourceLanguageTags_base() {
+ ApplicationProvider.getApplicationContext().getResources().getConfiguration().locale =
+ Locale.TAIWAN;
+
+ String resourceLanguageTags =
+ LocaleCompat.getResourceLanguageTags(ApplicationProvider.getApplicationContext());
+
+ assertThat(resourceLanguageTags).isEqualTo("zh");
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
new file mode 100644
index 0000000..a1d9dcf
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
@@ -0,0 +1,157 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.intent;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.testng.Assert.assertThrows;
+
+import android.content.ComponentName;
+import android.content.Context;
+import android.content.Intent;
+import android.net.Uri;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.testing.FakeContextBuilder;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public final class LabeledIntentTest {
+ private static final String TITLE_WITHOUT_ENTITY = "Map";
+ private static final String TITLE_WITH_ENTITY = "Map NW14D1";
+ private static final String DESCRIPTION = "Check the map";
+ private static final String DESCRIPTION_WITH_APP_NAME = "Use %1$s to open map";
+ private static final Intent INTENT =
+ new Intent(Intent.ACTION_VIEW).setDataAndNormalize(Uri.parse("http://www.android.com"));
+ private static final int REQUEST_CODE = 42;
+ private static final String APP_LABEL = "fake";
+
+ private Context context;
+
+ @Before
+ public void setup() {
+ final ComponentName component = FakeContextBuilder.DEFAULT_COMPONENT;
+ context =
+ new FakeContextBuilder()
+ .setIntentComponent(Intent.ACTION_VIEW, component)
+ .setAppLabel(component.getPackageName(), APP_LABEL)
+ .build();
+ }
+
+  @Test
+  public void resolve_preferTitleWithEntity() {
+    LabeledIntent labeledIntent =
+        new LabeledIntent(
+            TITLE_WITHOUT_ENTITY, TITLE_WITH_ENTITY, DESCRIPTION, null, INTENT, REQUEST_CODE);
+
+    LabeledIntent.Result result = labeledIntent.resolve(context, /*titleChooser*/ null);
+
+    assertThat(result).isNotNull();
+    assertThat(result.remoteAction.getTitle().toString()).isEqualTo(TITLE_WITH_ENTITY);
+    assertThat(result.remoteAction.getContentDescription().toString()).isEqualTo(DESCRIPTION);
+    Intent intent = result.resolvedIntent;
+    assertThat(intent.getAction()).isEqualTo(Intent.ACTION_VIEW); // was self-comparison (always true)
+    assertThat(intent.getComponent()).isNotNull();
+  }
+
+  @Test
+  public void resolve_useAvailableTitle() {
+    LabeledIntent labeledIntent =
+        new LabeledIntent(TITLE_WITHOUT_ENTITY, null, DESCRIPTION, null, INTENT, REQUEST_CODE);
+
+    LabeledIntent.Result result = labeledIntent.resolve(context, /*titleChooser*/ null);
+
+    assertThat(result).isNotNull();
+    assertThat(result.remoteAction.getTitle().toString()).isEqualTo(TITLE_WITHOUT_ENTITY);
+    assertThat(result.remoteAction.getContentDescription().toString()).isEqualTo(DESCRIPTION);
+    Intent intent = result.resolvedIntent;
+    assertThat(intent.getAction()).isEqualTo(Intent.ACTION_VIEW); // was self-comparison (always true)
+    assertThat(intent.getComponent()).isNotNull();
+  }
+
+  @Test
+  public void resolve_titleChooser() {
+    LabeledIntent labeledIntent =
+        new LabeledIntent(TITLE_WITHOUT_ENTITY, null, DESCRIPTION, null, INTENT, REQUEST_CODE);
+
+    LabeledIntent.Result result =
+        labeledIntent.resolve(context, (labeledIntent1, resolveInfo) -> "chooser");
+
+    assertThat(result).isNotNull();
+    assertThat(result.remoteAction.getTitle().toString()).isEqualTo("chooser");
+    assertThat(result.remoteAction.getContentDescription().toString()).isEqualTo(DESCRIPTION);
+    Intent intent = result.resolvedIntent;
+    assertThat(intent.getAction()).isEqualTo(Intent.ACTION_VIEW); // was self-comparison (always true)
+    assertThat(intent.getComponent()).isNotNull();
+  }
+
+  @Test
+  public void resolve_titleChooserReturnsNull() {
+    LabeledIntent labeledIntent =
+        new LabeledIntent(TITLE_WITHOUT_ENTITY, null, DESCRIPTION, null, INTENT, REQUEST_CODE);
+
+    LabeledIntent.Result result =
+        labeledIntent.resolve(context, (labeledIntent1, resolveInfo) -> null);
+
+    assertThat(result).isNotNull();
+    assertThat(result.remoteAction.getTitle().toString()).isEqualTo(TITLE_WITHOUT_ENTITY);
+    assertThat(result.remoteAction.getContentDescription().toString()).isEqualTo(DESCRIPTION);
+    Intent intent = result.resolvedIntent;
+    assertThat(intent.getAction()).isEqualTo(Intent.ACTION_VIEW); // was self-comparison (always true)
+    assertThat(intent.getComponent()).isNotNull();
+  }
+
+ @Test
+ public void resolve_missingTitle() {
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> new LabeledIntent(null, null, DESCRIPTION, null, INTENT, REQUEST_CODE));
+ }
+
+ @Test
+ public void resolve_noIntentHandler() {
+ // See setup(). context can only resolve Intent.ACTION_VIEW.
+ Intent unresolvableIntent = new Intent(Intent.ACTION_TRANSLATE);
+ LabeledIntent labeledIntent =
+ new LabeledIntent(
+ TITLE_WITHOUT_ENTITY, null, DESCRIPTION, null, unresolvableIntent, REQUEST_CODE);
+
+ LabeledIntent.Result result = labeledIntent.resolve(context, null);
+
+ assertThat(result).isNull();
+ }
+
+ @Test
+ public void resolve_descriptionWithAppName() {
+ LabeledIntent labeledIntent =
+ new LabeledIntent(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ DESCRIPTION_WITH_APP_NAME,
+ INTENT,
+ REQUEST_CODE);
+
+ LabeledIntent.Result result = labeledIntent.resolve(context, /*titleChooser*/ null);
+
+ assertThat(result).isNotNull();
+ assertThat(result.remoteAction.getContentDescription().toString())
+ .isEqualTo("Use fake to open map");
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
new file mode 100644
index 0000000..ab241c5
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
@@ -0,0 +1,280 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.intent;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.Intent;
+import android.net.Uri;
+import android.view.textclassifier.TextClassifier;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.google.android.textclassifier.NamedVariant;
+import com.google.android.textclassifier.RemoteActionTemplate;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.MockitoAnnotations;
+
+/**
+ * Unit tests for TemplateIntentFactory, which converts RemoteActionTemplate structs
+ * (produced by the native text-classifier) into LabeledIntent objects. Covers a fully
+ * populated template, a minimal template, URI-scheme normalization, and the rejection
+ * (empty result list) of invalid templates.
+ */
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TemplateIntentFactoryTest {
+
+ private static final String TITLE_WITHOUT_ENTITY = "Map";
+ private static final String TITLE_WITH_ENTITY = "Map NW14D1";
+ private static final String DESCRIPTION = "Check the map";
+ private static final String DESCRIPTION_WITH_APP_NAME = "Use %1$s to open map";
+ private static final String ACTION = Intent.ACTION_VIEW;
+ private static final String DATA = Uri.parse("http://www.android.com").toString();
+ private static final String TYPE = "text/html";
+ private static final Integer FLAG = Intent.FLAG_ACTIVITY_NEW_TASK;
+ private static final String[] CATEGORY =
+ new String[] {Intent.CATEGORY_DEFAULT, Intent.CATEGORY_APP_BROWSER};
+ private static final String PACKAGE_NAME = "pkg.name";
+ // Extras of each supported NamedVariant type: string, int, string[], float[], int[].
+ private static final String KEY_ONE = "key1";
+ private static final String VALUE_ONE = "value1";
+ private static final String KEY_TWO = "key2";
+ private static final int VALUE_TWO = 42;
+ private static final String KEY_STRING_ARRAY = "string_array_key";
+ private static final String[] VALUE_STRING_ARRAY = new String[] {"a", "b"};
+ private static final String KEY_FLOAT_ARRAY = "float_array_key";
+ private static final float[] VALUE_FLOAT_ARRAY = new float[] {3.14f, 2.718f};
+ private static final String KEY_INT_ARRAY = "int_array_key";
+ private static final int[] VALUE_INT_ARRAY = new int[] {7, 2, 1};
+
+ private static final NamedVariant[] NAMED_VARIANTS =
+ new NamedVariant[] {
+ new NamedVariant(KEY_ONE, VALUE_ONE),
+ new NamedVariant(KEY_TWO, VALUE_TWO),
+ new NamedVariant(KEY_STRING_ARRAY, VALUE_STRING_ARRAY),
+ new NamedVariant(KEY_FLOAT_ARRAY, VALUE_FLOAT_ARRAY),
+ new NamedVariant(KEY_INT_ARRAY, VALUE_INT_ARRAY)
+ };
+ private static final Integer REQUEST_CODE = 10;
+
+ private TemplateIntentFactory templateIntentFactory;
+
+ @Before
+ public void setup() {
+ // NOTE(review): this class declares no @Mock fields, so initMocks() appears to be
+ // a no-op here — consider removing it (cannot be changed inside this patch).
+ MockitoAnnotations.initMocks(this);
+ templateIntentFactory = new TemplateIntentFactory();
+ }
+
+ // Every populated template field must be carried over onto the LabeledIntent and
+ // its underlying Intent, including all NamedVariant extras and the
+ // EXTRA_FROM_TEXT_CLASSIFIER marker.
+ @Test
+ public void create_full() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ DESCRIPTION_WITH_APP_NAME,
+ ACTION,
+ DATA,
+ TYPE,
+ FLAG,
+ CATEGORY,
+ /* packageName */ null,
+ NAMED_VARIANTS,
+ REQUEST_CODE);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).hasSize(1);
+ LabeledIntent labeledIntent = intents.get(0);
+ assertThat(labeledIntent.titleWithoutEntity).isEqualTo(TITLE_WITHOUT_ENTITY);
+ assertThat(labeledIntent.titleWithEntity).isEqualTo(TITLE_WITH_ENTITY);
+ assertThat(labeledIntent.description).isEqualTo(DESCRIPTION);
+ assertThat(labeledIntent.descriptionWithAppName).isEqualTo(DESCRIPTION_WITH_APP_NAME);
+ assertThat(labeledIntent.requestCode).isEqualTo(REQUEST_CODE);
+ Intent intent = labeledIntent.intent;
+ assertThat(intent.getAction()).isEqualTo(ACTION);
+ assertThat(intent.getData().toString()).isEqualTo(DATA);
+ assertThat(intent.getType()).isEqualTo(TYPE);
+ assertThat(intent.getFlags()).isEqualTo(FLAG);
+ assertThat(intent.getCategories()).containsExactly((Object[]) CATEGORY);
+ assertThat(intent.getPackage()).isNull();
+ assertThat(intent.getStringExtra(KEY_ONE)).isEqualTo(VALUE_ONE);
+ assertThat(intent.getIntExtra(KEY_TWO, 0)).isEqualTo(VALUE_TWO);
+ assertThat(intent.getStringArrayExtra(KEY_STRING_ARRAY)).isEqualTo(VALUE_STRING_ARRAY);
+ assertThat(intent.getFloatArrayExtra(KEY_FLOAT_ARRAY)).isEqualTo(VALUE_FLOAT_ARRAY);
+ assertThat(intent.getIntArrayExtra(KEY_INT_ARRAY)).isEqualTo(VALUE_INT_ARRAY);
+ assertThat(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER)).isTrue();
+ }
+
+ // The data URI's scheme must be normalized to lower case ("HTTp" -> "http").
+ @Test
+ public void normalizesScheme() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ DESCRIPTION_WITH_APP_NAME,
+ ACTION,
+ "HTTp://www.android.com",
+ TYPE,
+ FLAG,
+ CATEGORY,
+ /* packageName */ null,
+ NAMED_VARIANTS,
+ REQUEST_CODE);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ String data = intents.get(0).intent.getData().toString();
+ assertThat(data).isEqualTo("http://www.android.com");
+ }
+
+ // A template with only the required fields (titles/description/action) yields an
+ // intent with defaults: no data, no type, zero flags, no categories, and the
+ // factory's DEFAULT_REQUEST_CODE.
+ @Test
+ public void create_minimal() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ TITLE_WITHOUT_ENTITY,
+ null,
+ DESCRIPTION,
+ null,
+ ACTION,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).hasSize(1);
+ LabeledIntent labeledIntent = intents.get(0);
+ assertThat(labeledIntent.titleWithoutEntity).isEqualTo(TITLE_WITHOUT_ENTITY);
+ assertThat(labeledIntent.titleWithEntity).isNull();
+ assertThat(labeledIntent.description).isEqualTo(DESCRIPTION);
+ assertThat(labeledIntent.requestCode).isEqualTo(LabeledIntent.DEFAULT_REQUEST_CODE);
+ Intent intent = labeledIntent.intent;
+ assertThat(intent.getAction()).isEqualTo(ACTION);
+ assertThat(intent.getData()).isNull();
+ assertThat(intent.getType()).isNull();
+ assertThat(intent.getFlags()).isEqualTo(0);
+ assertThat(intent.getCategories()).isNull();
+ assertThat(intent.getPackage()).isNull();
+ }
+
+ // A null template entry is skipped, producing no intents.
+ @Test
+ public void invalidTemplate_nullTemplate() {
+ RemoteActionTemplate remoteActionTemplate = null;
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).isEmpty();
+ }
+
+ // Templates that pin an explicit target package are rejected (the factory only
+ // creates implicit intents).
+ @Test
+ public void invalidTemplate_nonEmptyPackageName() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ DESCRIPTION_WITH_APP_NAME,
+ ACTION,
+ DATA,
+ TYPE,
+ FLAG,
+ CATEGORY,
+ PACKAGE_NAME,
+ NAMED_VARIANTS,
+ REQUEST_CODE);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).isEmpty();
+ }
+
+ // A template without any title is invalid and is dropped.
+ @Test
+ public void invalidTemplate_emptyTitle() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ null,
+ null,
+ DESCRIPTION,
+ DESCRIPTION_WITH_APP_NAME,
+ ACTION,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).isEmpty();
+ }
+
+ // A template without a description is invalid and is dropped.
+ @Test
+ public void invalidTemplate_emptyDescription() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ null,
+ null,
+ ACTION,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).isEmpty();
+ }
+
+ // A template without an intent action is invalid and is dropped.
+ @Test
+ public void invalidTemplate_emptyIntentAction() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ DESCRIPTION_WITH_APP_NAME,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).isEmpty();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/logging/ResultIdUtilsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/logging/ResultIdUtilsTest.java
new file mode 100644
index 0000000..3a85061
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/logging/ResultIdUtilsTest.java
@@ -0,0 +1,108 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.logging;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.google.common.base.Optional;
+import com.google.common.collect.ImmutableList;
+import java.util.Locale;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/**
+ * Unit tests for ResultIdUtils: building text-classifier result IDs of the form
+ * "androidtc|model-names|hash", parsing model names back out, and recognizing IDs
+ * produced by the default ("androidtc") text classifier.
+ */
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class ResultIdUtilsTest {
+ private static final int MODEL_VERSION = 703;
+ private static final int HASH = 12345;
+
+ // Absent model slots serialize as empty segments between ';' separators;
+ // present models serialize as "<locales>_v<version>".
+ @Test
+ public void createId_customHash() {
+ ImmutableList<Optional<ModelInfo>> modelInfos =
+ ImmutableList.of(
+ Optional.absent(),
+ Optional.of(
+ new ModelInfo(/* version= */ 1, ImmutableList.of(Locale.ENGLISH, Locale.FRENCH))),
+ Optional.absent(),
+ Optional.of(new ModelInfo(/* version= */ 2, ImmutableList.of(Locale.CHINESE))),
+ Optional.absent());
+
+ String resultId = ResultIdUtils.createId(HASH, modelInfos);
+
+ assertThat(resultId).isEqualTo("androidtc|;en,fr_v1;;zh_v2;|12345");
+ }
+
+ // The context-based overload hashes (text, start, end) itself, so only the shape
+ // of the trailing hash can be asserted.
+ @Test
+ public void createId_selection() {
+ String resultId =
+ ResultIdUtils.createId(
+ ApplicationProvider.getApplicationContext(),
+ "text",
+ 1,
+ 2,
+ ImmutableList.of(
+ Optional.of(new ModelInfo(MODEL_VERSION, ImmutableList.of(Locale.ENGLISH)))));
+
+ assertThat(resultId).matches("androidtc\\|en_v703\\|-?\\d+");
+ }
+
+ // NOTE(review): this duplicates the first assertion of getModelNames_invalid below;
+ // consider dropping one of the two (cannot be changed inside this patch).
+ @Test
+ public void getModelName_invalid() {
+ assertThat(ResultIdUtils.getModelNames("a|b")).isEmpty();
+ }
+
+ // Empty segments between ';' round-trip as empty model-name strings, in order.
+ @Test
+ public void getModelNames() {
+ assertThat(ResultIdUtils.getModelNames("androidtc|;en_v703;;zh_v101;|12344"))
+ .containsExactly("", "en_v703", "", "zh_v101", "")
+ .inOrder();
+ }
+
+ // IDs without exactly three '|'-separated parts yield no model names.
+ @Test
+ public void getModelNames_invalid() {
+ assertThat(ResultIdUtils.getModelNames("a|b")).isEmpty();
+ assertThat(ResultIdUtils.getModelNames("a|b|c|d")).isEmpty();
+ }
+
+ @Test
+ public void modelInfo_toModelName() {
+ ModelInfo modelInfo = new ModelInfo(700, ImmutableList.of(Locale.ENGLISH));
+
+ assertThat(modelInfo.toModelName()).isEqualTo("en_v700");
+ }
+
+ // ModelInfo also accepts a pre-joined language-tag string, used verbatim.
+ @Test
+ public void modelInfo_toModelName_supportedLanguageTags() {
+ ModelInfo modelInfo = new ModelInfo(700, "en,fr");
+
+ assertThat(modelInfo.toModelName()).isEqualTo("en,fr_v700");
+ }
+
+ // Only IDs prefixed with "androidtc" count as coming from the default classifier.
+ @Test
+ public void isFromDefaultTextClassifier_true() {
+ assertThat(ResultIdUtils.isFromDefaultTextClassifier("androidtc|en_v703|12344")).isTrue();
+ }
+
+ @Test
+ public void isFromDefaultTextClassifier_false() {
+ assertThat(ResultIdUtils.isFromDefaultTextClassifier("aiai|en_v703|12344")).isFalse();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/logging/TextClassificationContextTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/logging/TextClassificationContextTest.java
new file mode 100644
index 0000000..37a0a83
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/logging/TextClassificationContextTest.java
@@ -0,0 +1,53 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.logging;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.view.textclassifier.TextClassifier;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/**
+ * Unit tests for the TextClassificationContext builder: package name and widget type
+ * are required; widget version is optional.
+ */
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassificationContextTest {
+
+ // Builder with only the required arguments: both must round-trip via the getters.
+ @Test
+ public void minimumObject() {
+ TextClassificationContext textClassificationContext =
+ new TextClassificationContext.Builder("pkg", TextClassifier.WIDGET_TYPE_EDITTEXT).build();
+
+ assertThat(textClassificationContext.getPackageName()).isEqualTo("pkg");
+ assertThat(textClassificationContext.getWidgetType())
+ .isEqualTo(TextClassifier.WIDGET_TYPE_EDITTEXT);
+ }
+
+ // Builder with the optional widget version set as well.
+ @Test
+ public void fullObject() {
+ TextClassificationContext textClassificationContext =
+ new TextClassificationContext.Builder("pkg", TextClassifier.WIDGET_TYPE_EDITTEXT)
+ .setWidgetVersion("v1")
+ .build();
+
+ assertThat(textClassificationContext.getPackageName()).isEqualTo("pkg");
+ assertThat(textClassificationContext.getWidgetType())
+ .isEqualTo(TextClassifier.WIDGET_TYPE_EDITTEXT);
+ assertThat(textClassificationContext.getWidgetVersion()).isEqualTo("v1");
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/logging/TextClassificationEventTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/logging/TextClassificationEventTest.java
new file mode 100644
index 0000000..3656512
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/logging/TextClassificationEventTest.java
@@ -0,0 +1,185 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.logging;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.os.Bundle;
+import android.view.textclassifier.TextClassifier;
+import android.widget.TextView;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import java.util.Locale;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/**
+ * Unit tests for the TextClassifierEvent hierarchy. Each event subtype is built twice:
+ * once with only the mandatory event type (minimal) and once with every common field
+ * populated (full), then the getters and the fixed per-subtype event category are
+ * verified via the shared assert helpers below.
+ */
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassificationEventTest {
+ @Test
+ public void testTextSelectionEvent_minimal() {
+ final TextClassifierEvent.TextSelectionEvent event =
+ new TextClassifierEvent.TextSelectionEvent.Builder(TextClassifierEvent.TYPE_ACTIONS_SHOWN)
+ .build();
+
+ assertThat(event.getEventCategory()).isEqualTo(TextClassifierEvent.CATEGORY_SELECTION);
+ assertMinimumCommonFields(event);
+ // Selection-specific indices default to 0.
+ assertThat(event.getRelativeWordStartIndex()).isEqualTo(0);
+ assertThat(event.getRelativeWordEndIndex()).isEqualTo(0);
+ assertThat(event.getRelativeSuggestedWordStartIndex()).isEqualTo(0);
+ assertThat(event.getRelativeSuggestedWordEndIndex()).isEqualTo(0);
+ }
+
+ @Test
+ public void testTextSelectionEvent_full() {
+ final TextClassifierEvent.TextSelectionEvent.Builder builder =
+ new TextClassifierEvent.TextSelectionEvent.Builder(TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ setFullCommonFields(builder);
+ TextClassifierEvent.TextSelectionEvent event =
+ builder
+ .setRelativeWordStartIndex(1)
+ .setRelativeWordEndIndex(2)
+ .setRelativeSuggestedWordStartIndex(-1)
+ .setRelativeSuggestedWordEndIndex(3)
+ .build();
+
+ assertThat(event.getEventCategory()).isEqualTo(TextClassifierEvent.CATEGORY_SELECTION);
+ assertFullCommonFields(event);
+ // Negative indices are allowed (positions relative to the original selection).
+ assertThat(event.getRelativeWordStartIndex()).isEqualTo(1);
+ assertThat(event.getRelativeWordEndIndex()).isEqualTo(2);
+ assertThat(event.getRelativeSuggestedWordStartIndex()).isEqualTo(-1);
+ assertThat(event.getRelativeSuggestedWordEndIndex()).isEqualTo(3);
+ }
+
+ @Test
+ public void testTextLinkifyEvent_minimal() {
+ TextClassifierEvent.TextLinkifyEvent event =
+ new TextClassifierEvent.TextLinkifyEvent.Builder(TextClassifierEvent.TYPE_ACTIONS_SHOWN)
+ .build();
+
+ assertThat(event.getEventCategory()).isEqualTo(TextClassifierEvent.CATEGORY_LINKIFY);
+ assertMinimumCommonFields(event);
+ }
+
+ @Test
+ public void testTextLinkifyEvent_full() {
+ TextClassifierEvent.TextLinkifyEvent.Builder builder =
+ new TextClassifierEvent.TextLinkifyEvent.Builder(TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ setFullCommonFields(builder);
+ TextClassifierEvent.TextLinkifyEvent event = builder.build();
+
+ assertThat(event.getEventCategory()).isEqualTo(TextClassifierEvent.CATEGORY_LINKIFY);
+ assertFullCommonFields(event);
+ }
+
+ @Test
+ public void testConversationActionsEvent_minimal() {
+ TextClassifierEvent.ConversationActionsEvent event =
+ new TextClassifierEvent.ConversationActionsEvent.Builder(
+ TextClassifierEvent.TYPE_ACTIONS_SHOWN)
+ .build();
+
+ assertThat(event.getEventCategory())
+ .isEqualTo(TextClassifierEvent.CATEGORY_CONVERSATION_ACTIONS);
+ assertMinimumCommonFields(event);
+ }
+
+ @Test
+ public void testConversationActionsEvent_full() {
+ TextClassifierEvent.ConversationActionsEvent.Builder builder =
+ new TextClassifierEvent.ConversationActionsEvent.Builder(
+ TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ setFullCommonFields(builder);
+ TextClassifierEvent.ConversationActionsEvent event = builder.build();
+
+ assertThat(event.getEventCategory())
+ .isEqualTo(TextClassifierEvent.CATEGORY_CONVERSATION_ACTIONS);
+ assertFullCommonFields(event);
+ }
+
+ @Test
+ public void testLanguageDetectionEventEvent_minimal() {
+ TextClassifierEvent.LanguageDetectionEvent event =
+ new TextClassifierEvent.LanguageDetectionEvent.Builder(
+ TextClassifierEvent.TYPE_ACTIONS_SHOWN)
+ .build();
+
+ assertThat(event.getEventCategory()).isEqualTo(TextClassifierEvent.CATEGORY_LANGUAGE_DETECTION);
+ assertMinimumCommonFields(event);
+ }
+
+ @Test
+ public void testLanguageDetectionEvent_full() {
+ TextClassifierEvent.LanguageDetectionEvent.Builder builder =
+ new TextClassifierEvent.LanguageDetectionEvent.Builder(
+ TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ setFullCommonFields(builder);
+ TextClassifierEvent.LanguageDetectionEvent event = builder.build();
+
+ assertThat(event.getEventCategory()).isEqualTo(TextClassifierEvent.CATEGORY_LANGUAGE_DETECTION);
+ assertFullCommonFields(event);
+ }
+
+ // Populates every builder field shared by all event subtypes; paired with
+ // assertFullCommonFields below.
+ private static void setFullCommonFields(TextClassifierEvent.Builder<?> builder) {
+ Bundle extra = new Bundle();
+ extra.putString("key", "value");
+ builder
+ .setEventIndex(2)
+ // NOTE(review): this setEntityTypes call is overwritten by the two-arg call
+ // below — the single-TYPE_ADDRESS value is never observed; consider removing
+ // (cannot be changed inside this patch).
+ .setEntityTypes(TextClassifier.TYPE_ADDRESS)
+ .setResultId("androidtc-en-v606-1234")
+ .setActionIndices(1, 2, 5)
+ .setExtras(extra)
+ .setEventContext(
+ new TextClassificationContext.Builder("pkg", TextClassifier.WIDGET_TYPE_TEXTVIEW)
+ .setWidgetVersion(TextView.class.getName())
+ .build())
+ .setScores(0.5f)
+ .setEntityTypes(TextClassifier.TYPE_ADDRESS, TextClassifier.TYPE_DATE)
+ .setLocale(Locale.US);
+ }
+
+ // Verifies every common field set by setFullCommonFields.
+ private static void assertFullCommonFields(TextClassifierEvent event) {
+ assertThat(event.getEventType()).isEqualTo(TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ assertThat(event.getEventIndex()).isEqualTo(2);
+ assertThat(event.getEntityTypes())
+ .asList()
+ .containsExactly(TextClassifier.TYPE_ADDRESS, TextClassifier.TYPE_DATE);
+ assertThat(event.getResultId()).isEqualTo("androidtc-en-v606-1234");
+ assertThat(event.getActionIndices()).asList().containsExactly(1, 2, 5);
+ assertThat(event.getExtras().get("key")).isEqualTo("value");
+ assertThat(event.getEventContext().getPackageName()).isEqualTo("pkg");
+ assertThat(event.getEventContext().getWidgetType())
+ .isEqualTo(TextClassifier.WIDGET_TYPE_TEXTVIEW);
+ assertThat(event.getEventContext().getWidgetVersion()).isEqualTo(TextView.class.getName());
+ assertThat(event.getScores()).hasLength(1);
+ assertThat(event.getScores()[0]).isEqualTo(0.5f);
+ assertThat(event.getLocale().getLanguage()).isEqualTo("en");
+ assertThat(event.getLocale().getCountry()).isEqualTo("US");
+ }
+
+ // Verifies the defaults of a builder given only an event type.
+ private static void assertMinimumCommonFields(TextClassifierEvent event) {
+ assertThat(event.getEventType()).isEqualTo(TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ assertThat(event.getEventIndex()).isEqualTo(0);
+ assertThat(event.getEntityTypes()).isEmpty();
+ assertThat(event.getResultId()).isNull();
+ assertThat(event.getActionIndices()).isEmpty();
+ assertThat(event.getExtras().size()).isEqualTo(0);
+ assertThat(event.getEventContext()).isNull();
+ // NOTE(review): duplicate of the getEntityTypes() assertion above — consider
+ // removing one (cannot be changed inside this patch).
+ assertThat(event.getEntityTypes()).isEmpty();
+ assertThat(event.getLocale()).isNull();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/logging/TextClassificationSessionIdTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/logging/TextClassificationSessionIdTest.java
new file mode 100644
index 0000000..6dc45c8
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/logging/TextClassificationSessionIdTest.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.logging;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/** Unit test for TextClassificationSessionId. */
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassificationSessionIdTest {
+
+ // A freshly constructed session ID must carry a non-empty value.
+ @Test
+ public void getValue() {
+ TextClassificationSessionId sessionId = new TextClassificationSessionId();
+
+ assertThat(sessionId.getValue()).isNotEmpty();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java
new file mode 100644
index 0000000..c2a911a
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/GenerateLinksLoggerTest.java
@@ -0,0 +1,183 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.stats.textclassifier.EventType;
+import android.stats.textclassifier.WidgetType;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextLinks;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.LargeTest;
+import com.android.internal.os.StatsdConfigProto.StatsdConfig;
+import com.android.os.AtomsProto;
+import com.android.os.AtomsProto.Atom;
+import com.android.os.AtomsProto.TextLinkifyEvent;
+import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.google.common.base.Optional;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import java.util.Locale;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/**
+ * End-to-end tests for GenerateLinksLogger: a statsd config that matches
+ * TEXT_LINKIFY_EVENT atoms is pushed in setup(), logGenerateLinks() is invoked, and
+ * the atoms actually logged to statsd are pulled back and compared. For each call the
+ * logger is expected to emit one summary event (empty entity type, aggregate counts)
+ * followed by one event per entity type.
+ */
+@RunWith(AndroidJUnit4.class)
+@LargeTest
+public class GenerateLinksLoggerTest {
+ private static final String PACKAGE_NAME = "package.name";
+ private static final int LATENCY_MS = 123;
+ /** A statsd config ID, which is arbitrary. */
+ private static final long CONFIG_ID = 689777;
+
+ // Model fixtures; "und" (from the "*" language tag) appears in the logged
+ // langid model name below.
+ private static final ModelInfo ANNOTATOR_MODEL =
+ new ModelInfo(1, ImmutableList.of(Locale.ENGLISH));
+ private static final ModelInfo LANGID_MODEL =
+ new ModelInfo(2, ImmutableList.of(Locale.forLanguageTag("*")));
+
+ @Before
+ public void setup() throws Exception {
+ // Remove any stale config from a previous (possibly aborted) run first.
+ StatsdTestUtils.cleanup(CONFIG_ID);
+
+ StatsdConfig.Builder builder =
+ StatsdConfig.newBuilder()
+ .setId(CONFIG_ID)
+ .addAllowedLogSource(ApplicationProvider.getApplicationContext().getPackageName());
+ StatsdTestUtils.addAtomMatcher(builder, Atom.TEXT_LINKIFY_EVENT_FIELD_NUMBER);
+ StatsdTestUtils.pushConfig(builder.build());
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ StatsdTestUtils.cleanup(CONFIG_ID);
+ }
+
+ // Single phone link: both the summary event and the per-entity event must carry
+ // every field, compared against fully built expected protos.
+ @Test
+ public void logGenerateLinks_allFieldsAreSet() throws Exception {
+ String phoneText = "+12122537077";
+ String testText = "The number is " + phoneText;
+ int phoneOffset = testText.indexOf(phoneText);
+ Map<String, Float> phoneEntityScores = ImmutableMap.of(TextClassifier.TYPE_PHONE, 1.0f);
+ TextLinks links =
+ new TextLinks.Builder(testText)
+ .addLink(phoneOffset, phoneOffset + phoneText.length(), phoneEntityScores)
+ .build();
+ String uuid = "uuid";
+
+ // sampleRate 1 => every call is logged; fixed uuid makes the session ID stable.
+ GenerateLinksLogger generateLinksLogger =
+ new GenerateLinksLogger(/* sampleRate= */ 1, () -> uuid);
+ generateLinksLogger.logGenerateLinks(
+ testText,
+ links,
+ PACKAGE_NAME,
+ LATENCY_MS,
+ Optional.of(ANNOTATOR_MODEL),
+ Optional.of(LANGID_MODEL));
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+
+ ImmutableList<TextLinkifyEvent> loggedEvents =
+ ImmutableList.copyOf(
+ loggedAtoms.stream().map(Atom::getTextLinkifyEvent).collect(Collectors.toList()));
+
+ assertThat(loggedEvents).hasSize(2);
+ // Summary event: empty entity type, aggregate link/text counts.
+ TextLinkifyEvent summaryEvent =
+ AtomsProto.TextLinkifyEvent.newBuilder()
+ .setSessionId(uuid)
+ .setEventIndex(0)
+ .setModelName("en_v1")
+ .setWidgetType(WidgetType.WIDGET_TYPE_UNKNOWN)
+ .setEventType(EventType.LINKS_GENERATED)
+ .setPackageName(PACKAGE_NAME)
+ .setEntityType("")
+ .setNumLinks(1)
+ .setTextLength(testText.length())
+ .setLinkedTextLength(phoneText.length())
+ .setLatencyMillis(LATENCY_MS)
+ .setLangidModelName("und_v2")
+ .build();
+ // Per-entity event: identical except for the phone entity type.
+ TextLinkifyEvent phoneEvent =
+ AtomsProto.TextLinkifyEvent.newBuilder()
+ .setSessionId(uuid)
+ .setEventIndex(0)
+ .setModelName("en_v1")
+ .setWidgetType(WidgetType.WIDGET_TYPE_UNKNOWN)
+ .setEventType(EventType.LINKS_GENERATED)
+ .setPackageName(PACKAGE_NAME)
+ .setEntityType(TextClassifier.TYPE_PHONE)
+ .setNumLinks(1)
+ .setTextLength(testText.length())
+ .setLinkedTextLength(phoneText.length())
+ .setLatencyMillis(LATENCY_MS)
+ .setLangidModelName("und_v2")
+ .build();
+ assertThat(loggedEvents).containsExactly(summaryEvent, phoneEvent).inOrder();
+ }
+
+ // Two links of different entity types: expect one summary event plus one event per
+ // entity type, with the summary aggregating link count and linked-text length.
+ @Test
+ public void logGenerateLinks_multipleLinks() throws Exception {
+ String phoneText = "+12122537077";
+ String addressText = "1600 Amphitheater Parkway, Mountain View, CA";
+ String testText = "The number is " + phoneText + ", the address is " + addressText;
+ int phoneOffset = testText.indexOf(phoneText);
+ int addressOffset = testText.indexOf(addressText);
+ Map<String, Float> phoneEntityScores = ImmutableMap.of(TextClassifier.TYPE_PHONE, 1.0f);
+ Map<String, Float> addressEntityScores = ImmutableMap.of(TextClassifier.TYPE_ADDRESS, 1.0f);
+ TextLinks links =
+ new TextLinks.Builder(testText)
+ .addLink(phoneOffset, phoneOffset + phoneText.length(), phoneEntityScores)
+ .addLink(addressOffset, addressOffset + addressText.length(), addressEntityScores)
+ .build();
+ String uuid = "uuid";
+
+ GenerateLinksLogger generateLinksLogger =
+ new GenerateLinksLogger(/* sampleRate= */ 1, () -> uuid);
+ generateLinksLogger.logGenerateLinks(
+ testText,
+ links,
+ PACKAGE_NAME,
+ LATENCY_MS,
+ Optional.of(ANNOTATOR_MODEL),
+ Optional.of(LANGID_MODEL));
+ ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+
+ ImmutableList<TextLinkifyEvent> loggedEvents =
+ ImmutableList.copyOf(
+ loggedAtoms.stream().map(Atom::getTextLinkifyEvent).collect(Collectors.toList()));
+ assertThat(loggedEvents).hasSize(3);
+
+ TextLinkifyEvent summaryEvent = loggedEvents.get(0);
+ assertThat(summaryEvent.getEntityType()).isEmpty();
+ assertThat(summaryEvent.getNumLinks()).isEqualTo(2);
+ assertThat(summaryEvent.getLinkedTextLength())
+ .isEqualTo(phoneText.length() + addressText.length());
+
+ // Per-entity events follow; address precedes phone in the logged order here.
+ TextLinkifyEvent addressEvent = loggedEvents.get(1);
+ assertThat(addressEvent.getEntityType()).isEqualTo(TextClassifier.TYPE_ADDRESS);
+ assertThat(addressEvent.getNumLinks()).isEqualTo(1);
+ assertThat(addressEvent.getLinkedTextLength()).isEqualTo(addressText.length());
+
+ TextLinkifyEvent phoneEvent = loggedEvents.get(2);
+ assertThat(phoneEvent.getEntityType()).isEqualTo(TextClassifier.TYPE_PHONE);
+ assertThat(phoneEvent.getNumLinks()).isEqualTo(1);
+ assertThat(phoneEvent.getLinkedTextLength()).isEqualTo(phoneText.length());
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/SelectionEventConverterTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/SelectionEventConverterTest.java
new file mode 100644
index 0000000..ecdc1f4
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/SelectionEventConverterTest.java
@@ -0,0 +1,198 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.view.textclassifier.SelectionEvent;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassificationContext;
+import android.view.textclassifier.TextClassificationManager;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextClassifierEvent;
+import android.view.textclassifier.TextSelection;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.common.logging.ResultIdUtils;
+import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.google.common.base.Optional;
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.Locale;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class SelectionEventConverterTest {
+ private static final String PKG_NAME = "com.pkg";
+ private static final String WIDGET_TYPE = TextClassifier.WIDGET_TYPE_EDITTEXT;
+ private static final int START = 2;
+ private static final int SMART_START = 1;
+ private static final int SMART_END = 3;
+ private TestTextClassifier testTextClassifier;
+ private TextClassifier session;
+
+ @Before
+ public void setup() {
+ TextClassificationManager textClassificationManager =
+ ApplicationProvider.getApplicationContext()
+ .getSystemService(TextClassificationManager.class);
+ testTextClassifier = new TestTextClassifier();
+ textClassificationManager.setTextClassifier(testTextClassifier);
+ session = textClassificationManager.createTextClassificationSession(createEventContext());
+ }
+
+ @Test
+ public void convert_started() {
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionStartedEvent(SelectionEvent.INVOCATION_MANUAL, START));
+
+ SelectionEvent interceptedEvent = testTextClassifier.popLastSelectionEvent();
+ TextClassifierEvent textClassifierEvent =
+ SelectionEventConverter.toTextClassifierEvent(interceptedEvent);
+
+ assertEventContext(textClassifierEvent.getEventContext());
+ assertThat(textClassifierEvent.getEventIndex()).isEqualTo(0);
+ assertThat(textClassifierEvent.getEventType())
+ .isEqualTo(TextClassifierEvent.TYPE_SELECTION_STARTED);
+ }
+
+ @Test
+ public void convert_smartSelection() {
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionStartedEvent(SelectionEvent.INVOCATION_MANUAL, START));
+ String resultId = createResultId();
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionActionEvent(
+ SMART_START,
+ SMART_END,
+ SelectionEvent.ACTION_SMART_SHARE,
+ new TextClassification.Builder()
+ .setEntityType(TextClassifier.TYPE_ADDRESS, 1.0f)
+ .setId(resultId)
+ .build()));
+
+ SelectionEvent interceptedEvent = testTextClassifier.popLastSelectionEvent();
+ TextClassifierEvent.TextSelectionEvent textSelectionEvent =
+ (TextClassifierEvent.TextSelectionEvent)
+ SelectionEventConverter.toTextClassifierEvent(interceptedEvent);
+
+ assertEventContext(textSelectionEvent.getEventContext());
+ assertThat(textSelectionEvent.getRelativeWordStartIndex()).isEqualTo(-1);
+ assertThat(textSelectionEvent.getRelativeWordEndIndex()).isEqualTo(1);
+ assertThat(textSelectionEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_SMART_ACTION);
+ assertThat(textSelectionEvent.getEventIndex()).isEqualTo(1);
+ assertThat(textSelectionEvent.getEntityTypes())
+ .asList()
+ .containsExactly(TextClassifier.TYPE_ADDRESS);
+ assertThat(textSelectionEvent.getResultId()).isEqualTo(resultId);
+ }
+
+ @Test
+ public void convert_smartShare() {
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionStartedEvent(SelectionEvent.INVOCATION_MANUAL, START));
+ String resultId = createResultId();
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionModifiedEvent(
+ SMART_START,
+ SMART_END,
+ new TextSelection.Builder(SMART_START, SMART_END)
+ .setEntityType(TextClassifier.TYPE_ADDRESS, 1.0f)
+ .setId(resultId)
+ .build()));
+
+ SelectionEvent interceptedEvent = testTextClassifier.popLastSelectionEvent();
+ TextClassifierEvent.TextSelectionEvent textSelectionEvent =
+ (TextClassifierEvent.TextSelectionEvent)
+ SelectionEventConverter.toTextClassifierEvent(interceptedEvent);
+
+ assertEventContext(textSelectionEvent.getEventContext());
+ assertThat(textSelectionEvent.getRelativeSuggestedWordStartIndex()).isEqualTo(-1);
+ assertThat(textSelectionEvent.getRelativeSuggestedWordEndIndex()).isEqualTo(1);
+ assertThat(textSelectionEvent.getEventType())
+ .isEqualTo(TextClassifierEvent.TYPE_SMART_SELECTION_MULTI);
+ assertThat(textSelectionEvent.getEventIndex()).isEqualTo(1);
+ assertThat(textSelectionEvent.getEntityTypes())
+ .asList()
+ .containsExactly(TextClassifier.TYPE_ADDRESS);
+ assertThat(textSelectionEvent.getResultId()).isEqualTo(resultId);
+ }
+
+ @Test
+ public void convert_smartLinkify() {
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionStartedEvent(SelectionEvent.INVOCATION_LINK, START));
+ String resultId = createResultId();
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionModifiedEvent(
+ SMART_START,
+ SMART_END,
+ new TextSelection.Builder(SMART_START, SMART_END)
+ .setEntityType(TextClassifier.TYPE_ADDRESS, 1.0f)
+ .setId(resultId)
+ .build()));
+
+ SelectionEvent interceptedEvent = testTextClassifier.popLastSelectionEvent();
+ TextClassifierEvent.TextLinkifyEvent textLinkifyEvent =
+ (TextClassifierEvent.TextLinkifyEvent)
+ SelectionEventConverter.toTextClassifierEvent(interceptedEvent);
+
+ assertEventContext(textLinkifyEvent.getEventContext());
+ assertThat(textLinkifyEvent.getEventType())
+ .isEqualTo(TextClassifierEvent.TYPE_SMART_SELECTION_MULTI);
+ assertThat(textLinkifyEvent.getEventIndex()).isEqualTo(1);
+ assertThat(textLinkifyEvent.getEntityTypes())
+ .asList()
+ .containsExactly(TextClassifier.TYPE_ADDRESS);
+ assertThat(textLinkifyEvent.getResultId()).isEqualTo(resultId);
+ }
+
+ private static TextClassificationContext createEventContext() {
+ return new TextClassificationContext.Builder(PKG_NAME, TextClassifier.WIDGET_TYPE_EDITTEXT)
+ .build();
+ }
+
+ private static void assertEventContext(TextClassificationContext eventContext) {
+ assertThat(eventContext.getPackageName()).isEqualTo(PKG_NAME);
+ assertThat(eventContext.getWidgetType()).isEqualTo(WIDGET_TYPE);
+ }
+
+ private static String createResultId() {
+ return ResultIdUtils.createId(
+ /*hash=*/ 12345,
+ ImmutableList.of(
+ Optional.of(new ModelInfo(/* version= */ 702, ImmutableList.of(Locale.ENGLISH)))));
+ }
+
+ private static class TestTextClassifier implements TextClassifier {
+ private final Deque<SelectionEvent> selectionEvents = new ArrayDeque<>();
+
+ @Override
+ public void onSelectionEvent(SelectionEvent event) {
+ selectionEvents.push(event);
+ }
+
+ SelectionEvent popLastSelectionEvent() {
+ return selectionEvents.pop();
+ }
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
new file mode 100644
index 0000000..f2b8223
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
@@ -0,0 +1,134 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.Instrumentation;
+import android.app.UiAutomation;
+import android.os.ParcelFileDescriptor;
+import android.util.Log;
+import androidx.test.platform.app.InstrumentationRegistry;
+import com.android.internal.os.StatsdConfigProto.AtomMatcher;
+import com.android.internal.os.StatsdConfigProto.EventMetric;
+import com.android.internal.os.StatsdConfigProto.SimpleAtomMatcher;
+import com.android.internal.os.StatsdConfigProto.StatsdConfig;
+import com.android.os.AtomsProto.Atom;
+import com.android.os.StatsLog.ConfigMetricsReport;
+import com.android.os.StatsLog.ConfigMetricsReportList;
+import com.android.os.StatsLog.EventMetricData;
+import com.android.os.StatsLog.StatsLogReport;
+import com.google.common.collect.ImmutableList;
+import com.google.common.io.ByteStreams;
+import java.io.ByteArrayInputStream;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.lang.reflect.Method;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+
+/** Utility functions that simplify statsd testing by using "adb shell cmd stats" commands. */
+public class StatsdTestUtils {
+ private static final String TAG = "StatsdTestUtils";
+ private static final long SHORT_WAIT_MS = 1000;
+
+ private StatsdTestUtils() {}
+
+ /** Pushes a config that specifies which atoms we are interested in logging. */
+ public static void pushConfig(StatsdConfig config) throws Exception {
+ String command = String.format("cmd stats config update %s", config.getId());
+ Log.v(TAG, "pushConfig: " + config);
+ String output = new String(runShellCommand(command, config.toByteArray()));
+ assertThat(output).isEmpty();
+ }
+
+ /** Adds an atom matcher to capture logs with the given atom tag. */
+ public static void addAtomMatcher(StatsdConfig.Builder builder, int atomTag) {
+ final String atomName = "Atom" + atomTag;
+ final String eventName = "Event" + atomTag;
+ SimpleAtomMatcher simpleAtomMatcher = SimpleAtomMatcher.newBuilder().setAtomId(atomTag).build();
+ builder.addAtomMatcher(
+ AtomMatcher.newBuilder()
+ .setId(atomName.hashCode())
+ .setSimpleAtomMatcher(simpleAtomMatcher));
+ builder.addEventMetric(
+ EventMetric.newBuilder().setId(eventName.hashCode()).setWhat(atomName.hashCode()));
+ }
+
+ /**
+ * Extracts logged atoms from the report, sorted by logging time, and deletes the saved report.
+ */
+ public static ImmutableList<Atom> getLoggedAtoms(long configId) throws Exception {
+ // There is no callback to notify us when the log has been collected, so we do a short wait here.
+ Thread.sleep(SHORT_WAIT_MS);
+
+ ConfigMetricsReportList reportList = getAndRemoveReportList(configId);
+ assertThat(reportList.getReportsCount()).isEqualTo(1);
+ ConfigMetricsReport report = reportList.getReports(0);
+ List<StatsLogReport> metricsList = report.getMetricsList();
+ return ImmutableList.copyOf(
+ metricsList.stream()
+ .flatMap(statsLogReport -> statsLogReport.getEventMetrics().getDataList().stream())
+ .sorted(Comparator.comparing(EventMetricData::getElapsedTimestampNanos))
+ .map(EventMetricData::getAtom)
+ .collect(Collectors.toList()));
+ }
+
+ /** Removes the pushed config file and existing reports. */
+ public static void cleanup(long configId) throws Exception {
+ runShellCommand(String.format("cmd stats config remove %d", configId), /* input= */ null);
+ // Remove existing reports.
+ getAndRemoveReportList(configId);
+ }
+
+ /**
+ * Runs an adb shell command with the provided input and returns the command line output.
+ *
+ * @param cmd the shell command
+ * @param input the content that will be piped to the command stdin.
+ * @return the command output
+ */
+ private static byte[] runShellCommand(String cmd, @Nullable byte[] input) throws Exception {
+ Log.v(TAG, "run shell command: " + cmd);
+ Instrumentation instrumentation = InstrumentationRegistry.getInstrumentation();
+ UiAutomation uiAutomation = instrumentation.getUiAutomation();
+ Method method =
+ uiAutomation.getClass().getDeclaredMethod("executeShellCommandRw", String.class);
+ ParcelFileDescriptor[] pipes = (ParcelFileDescriptor[]) method.invoke(uiAutomation, cmd);
+ // Write to the input pipe.
+ try (FileOutputStream fos = new ParcelFileDescriptor.AutoCloseOutputStream(pipes[1])) {
+ if (input != null) {
+ fos.write(input);
+ }
+ }
+ // Read from the output pipe.
+ try (FileInputStream inputStream = new ParcelFileDescriptor.AutoCloseInputStream(pipes[0])) {
+ return ByteStreams.toByteArray(inputStream);
+ }
+ }
+
+ /** Gets the statsd report. Note that this also deletes that report from statsd. */
+ private static ConfigMetricsReportList getAndRemoveReportList(long configId) throws Exception {
+ byte[] output =
+ runShellCommand(
+ String.format("cmd stats dump-report %d --include_current_bucket --proto", configId),
+ /*input=*/ null);
+ return ConfigMetricsReportList.parser().parseFrom(new ByteArrayInputStream(output));
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassificationSessionIdConverterTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassificationSessionIdConverterTest.java
new file mode 100644
index 0000000..32a0591
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassificationSessionIdConverterTest.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.view.textclassifier.SelectionEvent;
+import android.view.textclassifier.TextClassificationContext;
+import android.view.textclassifier.TextClassificationManager;
+import android.view.textclassifier.TextClassifier;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.common.logging.TextClassificationSessionId;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassificationSessionIdConverterTest {
+
+ @Test
+ public void testTextSelectionEvent_minimal() {
+ TextClassificationManager textClassificationManager =
+ ApplicationProvider.getApplicationContext()
+ .getSystemService(TextClassificationManager.class);
+ textClassificationManager.setTextClassifier(TextClassifier.NO_OP);
+ TextClassifier textClassifier =
+ textClassificationManager.createTextClassificationSession(
+ new TextClassificationContext.Builder("com.pkg", TextClassifier.WIDGET_TYPE_TEXTVIEW)
+ .build());
+ SelectionEvent startedEvent =
+ SelectionEvent.createSelectionStartedEvent(SelectionEvent.INVOCATION_LINK, /* start= */ 10);
+
+ textClassifier.onSelectionEvent(startedEvent);
+ android.view.textclassifier.TextClassificationSessionId platformSessionId =
+ startedEvent.getSessionId();
+ TextClassificationSessionId textClassificationSessionId =
+ TextClassificationSessionIdConverter.fromPlatform(platformSessionId);
+
+ assertThat(textClassificationSessionId).isNotNull();
+ assertThat(textClassificationSessionId.getValue()).isEqualTo(platformSessionId.getValue());
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventConverterTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventConverterTest.java
new file mode 100644
index 0000000..87bb2ad
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventConverterTest.java
@@ -0,0 +1,225 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.icu.util.ULocale;
+import android.os.Bundle;
+import android.view.textclassifier.TextClassificationContext;
+import android.view.textclassifier.TextClassifier;
+import android.widget.TextView;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.common.logging.TextClassifierEvent;
+import com.android.textclassifier.common.logging.TextClassifierEvent.TextSelectionEvent;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassifierEventConverterTest {
+ private static final float TOLERANCE = 0.000001f;
+
+ @Test
+ public void testTextSelectionEvent_minimal() {
+ final android.view.textclassifier.TextClassifierEvent.TextSelectionEvent event =
+ new android.view.textclassifier.TextClassifierEvent.TextSelectionEvent.Builder(
+ android.view.textclassifier.TextClassifierEvent.TYPE_ACTIONS_SHOWN)
+ .build();
+
+ TextSelectionEvent result =
+ (TextSelectionEvent) TextClassifierEventConverter.fromPlatform(event);
+
+ assertThat(result.getEventCategory()).isEqualTo(TextClassifierEvent.CATEGORY_SELECTION);
+ assertMinimumCommonFields(result);
+ assertThat(result.getRelativeWordStartIndex()).isEqualTo(0);
+ assertThat(result.getRelativeWordEndIndex()).isEqualTo(0);
+ assertThat(result.getRelativeSuggestedWordStartIndex()).isEqualTo(0);
+ assertThat(result.getRelativeSuggestedWordEndIndex()).isEqualTo(0);
+ }
+
+ @Test
+ public void testTextSelectionEvent_full() {
+ final android.view.textclassifier.TextClassifierEvent.TextSelectionEvent.Builder builder =
+ new android.view.textclassifier.TextClassifierEvent.TextSelectionEvent.Builder(
+ android.view.textclassifier.TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ setFullCommonFields(builder);
+ android.view.textclassifier.TextClassifierEvent.TextSelectionEvent event =
+ builder
+ .setRelativeWordStartIndex(1)
+ .setRelativeWordEndIndex(2)
+ .setRelativeSuggestedWordStartIndex(-1)
+ .setRelativeSuggestedWordEndIndex(3)
+ .build();
+
+ TextClassifierEvent.TextSelectionEvent result =
+ (TextClassifierEvent.TextSelectionEvent) TextClassifierEventConverter.fromPlatform(event);
+
+ assertThat(result.getEventCategory()).isEqualTo(TextClassifierEvent.CATEGORY_SELECTION);
+ assertFullCommonFields(result);
+ assertThat(result.getRelativeWordStartIndex()).isEqualTo(1);
+ assertThat(result.getRelativeWordEndIndex()).isEqualTo(2);
+ assertThat(result.getRelativeSuggestedWordStartIndex()).isEqualTo(-1);
+ assertThat(result.getRelativeSuggestedWordEndIndex()).isEqualTo(3);
+ }
+
+ @Test
+ public void testTextLinkifyEvent_minimal() {
+ android.view.textclassifier.TextClassifierEvent.TextLinkifyEvent event =
+ new android.view.textclassifier.TextClassifierEvent.TextLinkifyEvent.Builder(
+ TextClassifierEvent.TYPE_ACTIONS_SHOWN)
+ .build();
+
+ TextClassifierEvent.TextLinkifyEvent result =
+ (TextClassifierEvent.TextLinkifyEvent) TextClassifierEventConverter.fromPlatform(event);
+
+ assertThat(result.getEventCategory()).isEqualTo(TextClassifierEvent.CATEGORY_LINKIFY);
+ assertMinimumCommonFields(result);
+ }
+
+ @Test
+ public void testTextLinkifyEvent_full() {
+ android.view.textclassifier.TextClassifierEvent.TextLinkifyEvent.Builder builder =
+ new android.view.textclassifier.TextClassifierEvent.TextLinkifyEvent.Builder(
+ android.view.textclassifier.TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ setFullCommonFields(builder);
+ android.view.textclassifier.TextClassifierEvent.TextLinkifyEvent event = builder.build();
+
+ TextClassifierEvent.TextLinkifyEvent result =
+ (TextClassifierEvent.TextLinkifyEvent) TextClassifierEventConverter.fromPlatform(event);
+
+ assertThat(result.getEventCategory()).isEqualTo(TextClassifierEvent.CATEGORY_LINKIFY);
+ assertFullCommonFields(result);
+ }
+
+ @Test
+ public void testConversationActionsEvent_minimal() {
+ android.view.textclassifier.TextClassifierEvent.ConversationActionsEvent event =
+ new android.view.textclassifier.TextClassifierEvent.ConversationActionsEvent.Builder(
+ TextClassifierEvent.TYPE_ACTIONS_SHOWN)
+ .build();
+
+ TextClassifierEvent.ConversationActionsEvent result =
+ (TextClassifierEvent.ConversationActionsEvent)
+ TextClassifierEventConverter.fromPlatform(event);
+
+ assertThat(result.getEventCategory())
+ .isEqualTo(TextClassifierEvent.CATEGORY_CONVERSATION_ACTIONS);
+ assertMinimumCommonFields(result);
+ }
+
+ @Test
+ public void testConversationActionsEvent_full() {
+ android.view.textclassifier.TextClassifierEvent.ConversationActionsEvent.Builder builder =
+ new android.view.textclassifier.TextClassifierEvent.ConversationActionsEvent.Builder(
+ android.view.textclassifier.TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ setFullCommonFields(builder);
+ android.view.textclassifier.TextClassifierEvent.ConversationActionsEvent event =
+ builder.build();
+
+ TextClassifierEvent.ConversationActionsEvent result =
+ (TextClassifierEvent.ConversationActionsEvent)
+ TextClassifierEventConverter.fromPlatform(event);
+
+ assertThat(result.getEventCategory())
+ .isEqualTo(TextClassifierEvent.CATEGORY_CONVERSATION_ACTIONS);
+ assertFullCommonFields(result);
+ }
+
+ @Test
+ public void testLanguageDetectionEventEvent_minimal() {
+ android.view.textclassifier.TextClassifierEvent.LanguageDetectionEvent event =
+ new android.view.textclassifier.TextClassifierEvent.LanguageDetectionEvent.Builder(
+ TextClassifierEvent.TYPE_ACTIONS_SHOWN)
+ .build();
+
+ TextClassifierEvent.LanguageDetectionEvent result =
+ (TextClassifierEvent.LanguageDetectionEvent)
+ TextClassifierEventConverter.fromPlatform(event);
+
+ assertThat(result.getEventCategory())
+ .isEqualTo(TextClassifierEvent.CATEGORY_LANGUAGE_DETECTION);
+ assertMinimumCommonFields(result);
+ }
+
+ @Test
+ public void testLanguageDetectionEvent_full() {
+ android.view.textclassifier.TextClassifierEvent.LanguageDetectionEvent.Builder builder =
+ new android.view.textclassifier.TextClassifierEvent.LanguageDetectionEvent.Builder(
+ android.view.textclassifier.TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ setFullCommonFields(builder);
+ android.view.textclassifier.TextClassifierEvent.LanguageDetectionEvent event = builder.build();
+
+ TextClassifierEvent.LanguageDetectionEvent result =
+ (TextClassifierEvent.LanguageDetectionEvent)
+ TextClassifierEventConverter.fromPlatform(event);
+
+ assertThat(result.getEventCategory())
+ .isEqualTo(TextClassifierEvent.CATEGORY_LANGUAGE_DETECTION);
+ assertFullCommonFields(result);
+ }
+
+ private static void setFullCommonFields(
+ android.view.textclassifier.TextClassifierEvent.Builder<?> builder) {
+ Bundle extra = new Bundle();
+ extra.putString("key", "value");
+ builder
+ .setEventIndex(2)
+ .setEntityTypes(TextClassifier.TYPE_ADDRESS)
+ .setResultId("androidtc-en-v606-1234")
+ .setActionIndices(1, 2, 5)
+ .setExtras(extra)
+ .setEventContext(
+ new TextClassificationContext.Builder("pkg", TextClassifier.WIDGET_TYPE_TEXTVIEW)
+ .setWidgetVersion(TextView.class.getName())
+ .build())
+ .setScores(0.5f)
+ .setEntityTypes(TextClassifier.TYPE_ADDRESS, TextClassifier.TYPE_DATE)
+ .setLocale(ULocale.US);
+ }
+
+ private static void assertFullCommonFields(TextClassifierEvent event) {
+ assertThat(event.getEventType()).isEqualTo(TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ assertThat(event.getEventIndex()).isEqualTo(2);
+ assertThat(event.getEntityTypes())
+ .asList()
+ .containsExactly(TextClassifier.TYPE_ADDRESS, TextClassifier.TYPE_DATE);
+ assertThat(event.getResultId()).isEqualTo("androidtc-en-v606-1234");
+ assertThat(event.getActionIndices()).asList().containsExactly(1, 2, 5);
+ assertThat(event.getExtras().get("key")).isEqualTo("value");
+ assertThat(event.getEventContext().getPackageName()).isEqualTo("pkg");
+ assertThat(event.getEventContext().getWidgetType())
+ .isEqualTo(TextClassifier.WIDGET_TYPE_TEXTVIEW);
+ assertThat(event.getEventContext().getWidgetVersion()).isEqualTo(TextView.class.getName());
+ assertThat(event.getScores()).hasLength(1);
+ assertThat(event.getScores()[0]).isWithin(TOLERANCE).of(0.5f);
+ assertThat(event.getLocale().toLanguageTag()).isEqualTo("en-US");
+ }
+
+ private static void assertMinimumCommonFields(TextClassifierEvent event) {
+ assertThat(event.getEventType()).isEqualTo(TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ assertThat(event.getEventIndex()).isEqualTo(0);
+ assertThat(event.getEntityTypes()).isEmpty();
+ assertThat(event.getResultId()).isNull();
+ assertThat(event.getActionIndices()).isEmpty();
+ assertThat(event.getExtras().size()).isEqualTo(0);
+ assertThat(event.getEventContext()).isNull();
+ assertThat(event.getEntityTypes()).isEmpty();
+ assertThat(event.getLocale()).isNull();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventLoggerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventLoggerTest.java
new file mode 100644
index 0000000..719fc31
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierEventLoggerTest.java
@@ -0,0 +1,265 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.common.statsd;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.stats.textclassifier.EventType;
+import android.stats.textclassifier.WidgetType;
+import android.view.textclassifier.TextClassifier;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.LargeTest;
+import com.android.internal.os.StatsdConfigProto.StatsdConfig;
+import com.android.os.AtomsProto;
+import com.android.os.AtomsProto.Atom;
+import com.android.textclassifier.common.logging.TextClassificationContext;
+import com.android.textclassifier.common.logging.TextClassificationSessionId;
+import com.android.textclassifier.common.logging.TextClassifierEvent;
+import com.google.common.collect.ImmutableList;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@RunWith(AndroidJUnit4.class)
+@LargeTest
+public class TextClassifierEventLoggerTest {
+ private static final String PKG_NAME = "pkg.name";
+ private static final String WIDGET_TYPE = TextClassifier.WIDGET_TYPE_WEBVIEW;
+ private static final String MODEL_NAME = "model_name";
+ /** A statsd config ID, which is arbitrary. */
+ private static final long CONFIG_ID = 689777;
+
+ private TextClassifierEventLogger textClassifierEventLogger;
+
+ @Before
+ public void setup() throws Exception {
+ StatsdTestUtils.cleanup(CONFIG_ID);
+
+ StatsdConfig.Builder builder =
+ StatsdConfig.newBuilder()
+ .setId(CONFIG_ID)
+ .addAllowedLogSource(ApplicationProvider.getApplicationContext().getPackageName());
+ StatsdTestUtils.addAtomMatcher(builder, Atom.TEXT_SELECTION_EVENT_FIELD_NUMBER);
+ StatsdTestUtils.addAtomMatcher(builder, Atom.TEXT_LINKIFY_EVENT_FIELD_NUMBER);
+ StatsdTestUtils.addAtomMatcher(builder, Atom.CONVERSATION_ACTIONS_EVENT_FIELD_NUMBER);
+ StatsdTestUtils.addAtomMatcher(builder, Atom.LANGUAGE_DETECTION_EVENT_FIELD_NUMBER);
+ StatsdTestUtils.pushConfig(builder.build());
+
+ textClassifierEventLogger = new TextClassifierEventLogger();
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ StatsdTestUtils.cleanup(CONFIG_ID);
+ }
+
+ @Test
+ public void writeEvent_textSelectionEvent() throws Exception {
+ TextClassificationSessionId sessionId = new TextClassificationSessionId();
+ TextClassifierEvent.TextSelectionEvent textSelectionEvent =
+ new TextClassifierEvent.TextSelectionEvent.Builder(
+ TextClassifierEvent.TYPE_SELECTION_STARTED)
+ .setEventContext(createTextClassificationContext())
+ .setResultId("androidtc|en_v705;und_v1|12345")
+ .setEventIndex(1)
+ .setEntityTypes(TextClassifier.TYPE_ADDRESS)
+ .setRelativeWordStartIndex(2)
+ .setRelativeWordEndIndex(3)
+ .setRelativeSuggestedWordStartIndex(1)
+ .setRelativeSuggestedWordEndIndex(4)
+ .build();
+
+ textClassifierEventLogger.writeEvent(sessionId, textSelectionEvent);
+
+ AtomsProto.TextSelectionEvent event =
+ AtomsProto.TextSelectionEvent.newBuilder()
+ .setSessionId(sessionId.getValue())
+ .setEventType(EventType.SELECTION_STARTED)
+ .setModelName("en_v705")
+ .setWidgetType(WidgetType.WIDGET_TYPE_WEBVIEW)
+ .setEventIndex(1)
+ .setEntityType(TextClassifier.TYPE_ADDRESS)
+ .setRelativeWordStartIndex(2)
+ .setRelativeWordEndIndex(3)
+ .setRelativeSuggestedWordStartIndex(1)
+ .setRelativeSuggestedWordEndIndex(4)
+ .setPackageName(PKG_NAME)
+ .setLangidModelName("und_v1")
+ .build();
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ assertThat(atoms).hasSize(1);
+ assertThat(atoms.get(0).getTextSelectionEvent()).isEqualTo(event);
+ }
+
+ @Test
+ public void writeEvent_textSelectionEvent_autoToSingle() throws Exception {
+ TextClassificationSessionId sessionId = new TextClassificationSessionId();
+ TextClassifierEvent.TextSelectionEvent textSelectionEvent =
+ new TextClassifierEvent.TextSelectionEvent.Builder(TextClassifierEvent.TYPE_AUTO_SELECTION)
+ .setResultId("androidtc|en_v705;und_v1|12345")
+ .setRelativeWordStartIndex(2)
+ .setRelativeWordEndIndex(3)
+ .build();
+
+ textClassifierEventLogger.writeEvent(sessionId, textSelectionEvent);
+
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ assertThat(atoms).hasSize(1);
+ assertThat(atoms.get(0).getTextSelectionEvent().getEventType())
+ .isEqualTo(EventType.SMART_SELECTION_SINGLE);
+ }
+
+ @Test
+ public void writeEvent_textSelectionEvent_autoToMulti() throws Exception {
+ TextClassificationSessionId sessionId = new TextClassificationSessionId();
+ TextClassifierEvent.TextSelectionEvent textSelectionEvent =
+ new TextClassifierEvent.TextSelectionEvent.Builder(TextClassifierEvent.TYPE_AUTO_SELECTION)
+ .setResultId("androidtc|en_v705;und_v1|12345")
+ .setRelativeWordStartIndex(2)
+ .setRelativeWordEndIndex(4)
+ .build();
+
+ textClassifierEventLogger.writeEvent(sessionId, textSelectionEvent);
+
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ assertThat(atoms).hasSize(1);
+ assertThat(atoms.get(0).getTextSelectionEvent().getEventType())
+ .isEqualTo(EventType.SMART_SELECTION_MULTI);
+ }
+
+ @Test
+ public void writeEvent_textSelectionEvent_keepAuto() throws Exception {
+ TextClassificationSessionId sessionId = new TextClassificationSessionId();
+ TextClassifierEvent.TextSelectionEvent textSelectionEvent =
+ new TextClassifierEvent.TextSelectionEvent.Builder(TextClassifierEvent.TYPE_AUTO_SELECTION)
+ .setResultId("aiai|en_v705;und_v1|12345")
+ .setRelativeWordStartIndex(2)
+ .setRelativeWordEndIndex(4)
+ .build();
+
+ textClassifierEventLogger.writeEvent(sessionId, textSelectionEvent);
+
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ assertThat(atoms).hasSize(1);
+ assertThat(atoms.get(0).getTextSelectionEvent().getEventType())
+ .isEqualTo(EventType.AUTO_SELECTION);
+ }
+
+ @Test
+ public void writeEvent_textLinkifyEvent() throws Exception {
+ TextClassificationSessionId sessionId = new TextClassificationSessionId();
+ TextClassifierEvent.TextLinkifyEvent textLinkifyEvent =
+ new TextClassifierEvent.TextLinkifyEvent.Builder(TextClassifierEvent.TYPE_SELECTION_STARTED)
+ .setEventContext(createTextClassificationContext())
+ .setResultId("androidtc|en_v705;und_v1|12345")
+ .setEventIndex(1)
+ .setEntityTypes(TextClassifier.TYPE_ADDRESS)
+ .build();
+
+ textClassifierEventLogger.writeEvent(sessionId, textLinkifyEvent);
+
+ AtomsProto.TextLinkifyEvent event =
+ AtomsProto.TextLinkifyEvent.newBuilder()
+ .setSessionId(sessionId.getValue())
+ .setEventType(EventType.SELECTION_STARTED)
+ .setModelName("en_v705")
+ .setWidgetType(WidgetType.WIDGET_TYPE_WEBVIEW)
+ .setEventIndex(1)
+ .setEntityType(TextClassifier.TYPE_ADDRESS)
+ .setNumLinks(0)
+ .setLinkedTextLength(0)
+ .setTextLength(0)
+ .setLatencyMillis(0)
+ .setPackageName(PKG_NAME)
+ .setLangidModelName("und_v1")
+ .build();
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ assertThat(atoms).hasSize(1);
+ assertThat(atoms.get(0).getTextLinkifyEvent()).isEqualTo(event);
+ }
+
+ @Test
+ public void writeEvent_textConversationActionEvent() throws Exception {
+ TextClassificationSessionId sessionId = new TextClassificationSessionId();
+ TextClassifierEvent.ConversationActionsEvent conversationActionsEvent =
+ new TextClassifierEvent.ConversationActionsEvent.Builder(
+ TextClassifierEvent.TYPE_SELECTION_STARTED)
+ .setEventContext(createTextClassificationContext())
+ .setResultId("android_tc|en_v1;zh_v2;und_v3|12345")
+ .setEventIndex(1)
+ .setEntityTypes("first", "second", "third", "fourth")
+ .setScores(0.5f)
+ .build();
+
+ textClassifierEventLogger.writeEvent(sessionId, conversationActionsEvent);
+
+ AtomsProto.ConversationActionsEvent event =
+ AtomsProto.ConversationActionsEvent.newBuilder()
+ .setSessionId(sessionId.getValue())
+ .setEventType(EventType.SELECTION_STARTED)
+ .setModelName("en_v1")
+ .setWidgetType(WidgetType.WIDGET_TYPE_WEBVIEW)
+ .setFirstEntityType("first")
+ .setSecondEntityType("second")
+ .setThirdEntityType("third")
+ .setScore(0.5f)
+ .setPackageName(PKG_NAME)
+ .setAnnotatorModelName("zh_v2")
+ .setLangidModelName("und_v3")
+ .build();
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ assertThat(atoms).hasSize(1);
+ assertThat(atoms.get(0).getConversationActionsEvent()).isEqualTo(event);
+ }
+
+ @Test
+ public void writeEvent_languageDetectionEvent() throws Exception {
+ TextClassificationSessionId sessionId = new TextClassificationSessionId();
+ TextClassifierEvent.LanguageDetectionEvent languageDetectionEvent =
+ new TextClassifierEvent.LanguageDetectionEvent.Builder(
+ TextClassifierEvent.TYPE_SELECTION_STARTED)
+ .setEventContext(createTextClassificationContext())
+ .setModelName(MODEL_NAME)
+ .setEventIndex(1)
+ .setEntityTypes("en")
+ .setScores(0.5f)
+ .setActionIndices(1)
+ .build();
+
+ textClassifierEventLogger.writeEvent(sessionId, languageDetectionEvent);
+ AtomsProto.LanguageDetectionEvent event =
+ AtomsProto.LanguageDetectionEvent.newBuilder()
+ .setSessionId(sessionId.getValue())
+ .setEventType(EventType.SELECTION_STARTED)
+ .setModelName(MODEL_NAME)
+ .setWidgetType(WidgetType.WIDGET_TYPE_WEBVIEW)
+ .setLanguageTag("en")
+ .setScore(0.5f)
+ .setActionIndex(1)
+ .setPackageName(PKG_NAME)
+ .build();
+ ImmutableList<Atom> atoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID);
+ assertThat(atoms).hasSize(1);
+ assertThat(atoms.get(0).getLanguageDetectionEvent()).isEqualTo(event);
+ }
+
+ private static TextClassificationContext createTextClassificationContext() {
+ return new TextClassificationContext.Builder(PKG_NAME, WIDGET_TYPE).build();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java b/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java
new file mode 100644
index 0000000..6d01a64
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.subjects;
+
+import static com.google.common.truth.Truth.assertAbout;
+
+import com.android.textclassifier.Entity;
+import com.google.common.truth.FailureMetadata;
+import com.google.common.truth.MathUtil;
+import com.google.common.truth.Subject;
+import javax.annotation.Nullable;
+
+/** Test helper for checking {@link com.android.textclassifier.Entity} results. */
+public final class EntitySubject extends Subject<EntitySubject, Entity> {
+
+ private static final float TOLERANCE = 0.0001f;
+
+ private final Entity entity;
+
+ public static EntitySubject assertThat(@Nullable Entity entity) {
+ return assertAbout(EntitySubject::new).that(entity);
+ }
+
+ private EntitySubject(FailureMetadata failureMetadata, @Nullable Entity entity) {
+ super(failureMetadata, entity);
+ this.entity = entity;
+ }
+
+ public void isMatchWithinTolerance(@Nullable Entity entity) {
+ if (!entity.getEntityType().equals(this.entity.getEntityType())) {
+ failWithActual("expected to have type", entity.getEntityType());
+ }
+ if (!MathUtil.equalWithinTolerance(entity.getScore(), this.entity.getScore(), TOLERANCE)) {
+ failWithActual("expected to have confidence score", entity.getScore());
+ }
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java
new file mode 100644
index 0000000..f3ad833
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/FakeContextBuilder.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.testing;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import android.content.ComponentName;
+import android.content.Context;
+import android.content.ContextWrapper;
+import android.content.Intent;
+import android.content.pm.ActivityInfo;
+import android.content.pm.ApplicationInfo;
+import android.content.pm.PackageManager;
+import android.content.pm.ResolveInfo;
+import androidx.test.core.app.ApplicationProvider;
+import com.google.common.base.Preconditions;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+import javax.annotation.Nullable;
+import org.mockito.stubbing.Answer;
+
+/** A builder used to build a fake context for testing. */
+public final class FakeContextBuilder {
+
+ /** A component name that can be used for tests. */
+ public static final ComponentName DEFAULT_COMPONENT = new ComponentName("pkg", "cls");
+
+ private final PackageManager packageManager;
+ private final ContextWrapper context;
+ private final Map<String, ComponentName> components = new HashMap<>();
+ private final Map<String, CharSequence> appLabels = new HashMap<>();
+ @Nullable private ComponentName allIntentComponent;
+
+ public FakeContextBuilder() {
+ packageManager = mock(PackageManager.class);
+ when(packageManager.resolveActivity(any(Intent.class), anyInt())).thenReturn(null);
+ context =
+ new ContextWrapper(ApplicationProvider.getApplicationContext()) {
+ @Override
+ public PackageManager getPackageManager() {
+ return packageManager;
+ }
+ };
+ }
+
+ /**
+ * Sets the component name of an activity to handle the specified intent action.
+ *
+ * <p><strong>NOTE: </strong>By default, no component is set to handle any intent.
+ */
+ public FakeContextBuilder setIntentComponent(
+ String intentAction, @Nullable ComponentName component) {
+ Preconditions.checkNotNull(intentAction);
+ components.put(intentAction, component);
+ return this;
+ }
+
+ /** Sets the app label for the specified package. */
+ public FakeContextBuilder setAppLabel(String packageName, @Nullable CharSequence appLabel) {
+ Preconditions.checkNotNull(packageName);
+ appLabels.put(packageName, appLabel);
+ return this;
+ }
+
+ /**
+ * Sets the component name of an activity to handle all intents.
+ *
+ * <p><strong>NOTE: </strong>By default, no component is set to handle any intent.
+ */
+ public FakeContextBuilder setAllIntentComponent(@Nullable ComponentName component) {
+ allIntentComponent = component;
+ return this;
+ }
+
+ /** Builds and returns a fake context. */
+ public Context build() {
+ when(packageManager.resolveActivity(any(Intent.class), anyInt()))
+ .thenAnswer(
+ (Answer<ResolveInfo>)
+ invocation -> {
+ final String action = ((Intent) invocation.getArgument(0)).getAction();
+ final ComponentName component =
+ components.containsKey(action) ? components.get(action) : allIntentComponent;
+ return getResolveInfo(component);
+ });
+ when(packageManager.getApplicationLabel(any(ApplicationInfo.class)))
+ .thenAnswer(
+ (Answer<CharSequence>)
+ invocation -> {
+ ApplicationInfo applicationInfo = invocation.getArgument(0);
+ return appLabels.get(applicationInfo.packageName);
+ });
+ return context;
+ }
+
+ /** Returns a component name with random package and class names. */
+ public static ComponentName newComponent() {
+ return new ComponentName(UUID.randomUUID().toString(), UUID.randomUUID().toString());
+ }
+
+ private static ResolveInfo getResolveInfo(ComponentName component) {
+ final ResolveInfo info;
+ if (component == null) {
+ info = null;
+ } else {
+ // NOTE: If something breaks in TextClassifier because we expect more fields to be set
+ // in here, just add them.
+ info = new ResolveInfo();
+ info.activityInfo = new ActivityInfo();
+ info.activityInfo.packageName = component.getPackageName();
+ info.activityInfo.name = component.getClassName();
+ info.activityInfo.exported = true;
+ info.activityInfo.applicationInfo = new ApplicationInfo();
+ info.activityInfo.applicationInfo.packageName = component.getPackageName();
+ info.activityInfo.applicationInfo.icon = 0;
+ }
+ return info;
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/utils/IndentingPrintWriterTest.java b/java/tests/instrumentation/src/com/android/textclassifier/utils/IndentingPrintWriterTest.java
new file mode 100644
index 0000000..c60942e
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/utils/IndentingPrintWriterTest.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.utils;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public final class IndentingPrintWriterTest {
+
+ private static final String TEST_STRING = "sense";
+ private static final String TEST_KEY = "key";
+ private static final String TEST_VALUE = "value";
+
+ private StringWriter stringWriter;
+ private IndentingPrintWriter indentingPrintWriter;
+
+ @Before
+ public void setUp() {
+ stringWriter = new StringWriter();
+ indentingPrintWriter =
+ new IndentingPrintWriter(new PrintWriter(stringWriter, /* autoFlush= */ true));
+ }
+
+ @Test
+ public void println_printString_noIndent() throws Exception {
+ indentingPrintWriter.println(TEST_STRING);
+
+ assertThat(stringWriter.toString()).isEqualTo(TEST_STRING + "\n");
+ }
+
+ @Test
+ public void println_printString_withIndent() throws Exception {
+ indentingPrintWriter.increaseIndent().println(TEST_STRING);
+
+ assertThat(stringWriter.toString())
+ .isEqualTo(IndentingPrintWriter.SINGLE_INDENT + TEST_STRING + "\n");
+ }
+
+ @Test
+ public void decreaseIndent_noIndent() throws Exception {
+ indentingPrintWriter.decreaseIndent().println(TEST_STRING);
+
+ assertThat(stringWriter.toString()).isEqualTo(TEST_STRING + "\n");
+ }
+
+ @Test
+ public void decreaseIndent_withIndent() throws Exception {
+ indentingPrintWriter.increaseIndent().decreaseIndent().println(TEST_STRING);
+
+ assertThat(stringWriter.toString()).isEqualTo(TEST_STRING + "\n");
+ }
+
+ @Test
+ public void printPair_singlePair() throws Exception {
+ indentingPrintWriter.printPair(TEST_KEY, TEST_VALUE);
+
+ assertThat(stringWriter.toString()).isEqualTo(TEST_KEY + "=" + TEST_VALUE + "\n");
+ }
+}
diff --git a/jni/Android.bp b/jni/Android.bp
new file mode 100644
index 0000000..569368e
--- /dev/null
+++ b/jni/Android.bp
@@ -0,0 +1,19 @@
+// Copyright (C) 2019 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+java_library_static {
+ name: "libtextclassifier-java",
+ sdk_version: "core_current",
+ srcs: ["**/*.java"],
+}
\ No newline at end of file
diff --git a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
new file mode 100644
index 0000000..3af04e8
--- /dev/null
+++ b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
@@ -0,0 +1,274 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Java wrapper for ActionsSuggestions native library interface. This library is used to suggest
+ * actions and replies in a given conversation.
+ *
+ * @hide
+ */
+public final class ActionsSuggestionsModel implements AutoCloseable {
+ private final AtomicBoolean isClosed = new AtomicBoolean(false);
+
+ static {
+ System.loadLibrary("textclassifier");
+ }
+
+ private long actionsModelPtr;
+
+ /**
+ * Creates a new instance of Actions predictor, using the provided model image, given as a file
+ * descriptor.
+ */
+ public ActionsSuggestionsModel(int fileDescriptor, byte[] serializedPreconditions) {
+ actionsModelPtr = nativeNewActionsModel(fileDescriptor, serializedPreconditions);
+ if (actionsModelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize actions model from file descriptor.");
+ }
+ }
+
+ public ActionsSuggestionsModel(int fileDescriptor) {
+ this(fileDescriptor, /* serializedPreconditions= */ null);
+ }
+
+ /**
+ * Creates a new instance of Actions predictor, using the provided model image, given as a file
+ * path.
+ */
+ public ActionsSuggestionsModel(String path, byte[] serializedPreconditions) {
+ actionsModelPtr = nativeNewActionsModelFromPath(path, serializedPreconditions);
+ if (actionsModelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize actions model from given file.");
+ }
+ }
+
+ public ActionsSuggestionsModel(String path) {
+ this(path, /* serializedPreconditions= */ null);
+ }
+
+ /** Suggests actions / replies to the given conversation. */
+ public ActionSuggestion[] suggestActions(
+ Conversation conversation, ActionSuggestionOptions options, AnnotatorModel annotator) {
+ return nativeSuggestActions(
+ actionsModelPtr,
+ conversation,
+ options,
+ (annotator != null ? annotator.getNativeAnnotatorPointer() : 0),
+ /* appContext= */ null,
+ /* deviceLocales= */ null,
+ /* generateAndroidIntents= */ false);
+ }
+
+ public ActionSuggestion[] suggestActionsWithIntents(
+ Conversation conversation,
+ ActionSuggestionOptions options,
+ Object appContext,
+ String deviceLocales,
+ AnnotatorModel annotator) {
+ return nativeSuggestActions(
+ actionsModelPtr,
+ conversation,
+ options,
+ (annotator != null ? annotator.getNativeAnnotatorPointer() : 0),
+ appContext,
+ deviceLocales,
+ /* generateAndroidIntents= */ true);
+ }
+
+ /** Frees up the allocated memory. */
+ @Override
+ public void close() {
+ if (isClosed.compareAndSet(false, true)) {
+ nativeCloseActionsModel(actionsModelPtr);
+ actionsModelPtr = 0L;
+ }
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ close();
+ } finally {
+ super.finalize();
+ }
+ }
+
+ /** Returns a comma-separated list of locales supported by the model as BCP 47 tags. */
+ public static String getLocales(int fd) {
+ return nativeGetLocales(fd);
+ }
+
+ /** Returns the version of the model. */
+ public static int getVersion(int fd) {
+ return nativeGetVersion(fd);
+ }
+
+ /** Returns the name of the model. */
+ public static String getName(int fd) {
+ return nativeGetName(fd);
+ }
+
+ /** Action suggestion that contains a response text and the type of the response. */
+ public static final class ActionSuggestion {
+ private final String responseText;
+ private final String actionType;
+ private final float score;
+ private final NamedVariant[] entityData;
+ private final byte[] serializedEntityData;
+ private final RemoteActionTemplate[] remoteActionTemplates;
+
+ public ActionSuggestion(
+ String responseText,
+ String actionType,
+ float score,
+ NamedVariant[] entityData,
+ byte[] serializedEntityData,
+ RemoteActionTemplate[] remoteActionTemplates) {
+ this.responseText = responseText;
+ this.actionType = actionType;
+ this.score = score;
+ this.entityData = entityData;
+ this.serializedEntityData = serializedEntityData;
+ this.remoteActionTemplates = remoteActionTemplates;
+ }
+
+ public String getResponseText() {
+ return responseText;
+ }
+
+ public String getActionType() {
+ return actionType;
+ }
+
+ /** Returns the confidence score, between 0 and 1. */
+ public float getScore() {
+ return score;
+ }
+
+ public NamedVariant[] getEntityData() {
+ return entityData;
+ }
+
+ public byte[] getSerializedEntityData() {
+ return serializedEntityData;
+ }
+
+ public RemoteActionTemplate[] getRemoteActionTemplates() {
+ return remoteActionTemplates;
+ }
+ }
+
+ /** Represents a single message in the conversation. */
+ public static final class ConversationMessage {
+ private final int userId;
+ private final String text;
+ private final long referenceTimeMsUtc;
+ private final String referenceTimezone;
+ private final String detectedTextLanguageTags;
+
+ public ConversationMessage(
+ int userId,
+ String text,
+ long referenceTimeMsUtc,
+ String referenceTimezone,
+ String detectedTextLanguageTags) {
+ this.userId = userId;
+ this.text = text;
+ this.referenceTimeMsUtc = referenceTimeMsUtc;
+ this.referenceTimezone = referenceTimezone;
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ }
+
+ /** The identifier of the sender */
+ public int getUserId() {
+ return userId;
+ }
+
+ public String getText() {
+ return text;
+ }
+
+ /**
+ * Returns the reference time of the message, for example, it could be compose time or send time.
+ * {@code 0} means unspecified.
+ */
+ public long getReferenceTimeMsUtc() {
+ return referenceTimeMsUtc;
+ }
+
+ public String getReferenceTimezone() {
+ return referenceTimezone;
+ }
+
+ /** Returns a comma separated list of BCP 47 language tags. */
+ public String getDetectedTextLanguageTags() {
+ return detectedTextLanguageTags;
+ }
+ }
+
+ /** Represents conversation between multiple users. */
+ public static final class Conversation {
+ public final ConversationMessage[] conversationMessages;
+
+ public Conversation(ConversationMessage[] conversationMessages) {
+ this.conversationMessages = conversationMessages;
+ }
+
+ public ConversationMessage[] getConversationMessages() {
+ return conversationMessages;
+ }
+ }
+
+ /** Represents options for the SuggestActions call. */
+ public static final class ActionSuggestionOptions {
+ public ActionSuggestionOptions() {}
+ }
+
+ private static native long nativeNewActionsModel(int fd, byte[] serializedPreconditions);
+
+ private static native long nativeNewActionsModelFromPath(
+ String path, byte[] preconditionsOverwrite);
+
+ private static native long nativeNewActionsModelWithOffset(
+ int fd, long offset, long size, byte[] preconditionsOverwrite);
+
+ private static native String nativeGetLocales(int fd);
+
+ private static native String nativeGetLocalesWithOffset(int fd, long offset, long size);
+
+ private static native int nativeGetVersion(int fd);
+
+ private static native int nativeGetVersionWithOffset(int fd, long offset, long size);
+
+ private static native String nativeGetName(int fd);
+
+ private static native String nativeGetNameWithOffset(int fd, long offset, long size);
+
+ private native ActionSuggestion[] nativeSuggestActions(
+ long context,
+ Conversation conversation,
+ ActionSuggestionOptions options,
+ long annotatorPtr,
+ Object appContext,
+ String deviceLocales,
+ boolean generateAndroidIntents);
+
+ private native void nativeCloseActionsModel(long ptr);
+}
diff --git a/jni/com/google/android/textclassifier/AnnotatorModel.java b/jni/com/google/android/textclassifier/AnnotatorModel.java
new file mode 100644
index 0000000..7658bf5
--- /dev/null
+++ b/jni/com/google/android/textclassifier/AnnotatorModel.java
@@ -0,0 +1,824 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier;
+
+import java.util.Collection;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Java wrapper for Annotator native library interface. This library is used for detecting entities
+ * in text.
+ *
+ * @hide
+ */
+public final class AnnotatorModel implements AutoCloseable {
+ private final AtomicBoolean isClosed = new AtomicBoolean(false);
+
+ static {
+ System.loadLibrary("textclassifier");
+ }
+
+ // Keep these in sync with the constants defined in AOSP.
+ static final String TYPE_UNKNOWN = "";
+ static final String TYPE_OTHER = "other";
+ static final String TYPE_EMAIL = "email";
+ static final String TYPE_PHONE = "phone";
+ static final String TYPE_ADDRESS = "address";
+ static final String TYPE_URL = "url";
+ static final String TYPE_DATE = "date";
+ static final String TYPE_DATE_TIME = "datetime";
+ static final String TYPE_FLIGHT_NUMBER = "flight";
+
+ public static final double INVALID_LATITUDE = 180;
+ public static final double INVALID_LONGITUDE = 360;
+ public static final float INVALID_LOCATION_ACCURACY_METERS = 0;
+
+ private long annotatorPtr;
+ // To tell GC to keep the LangID model alive at least as long as this object.
+ private LangIdModel langIdModel;
+
+ /** Enumeration for specifying the usecase of the annotations. */
+ public static enum AnnotationUsecase {
+ /** Results are optimized for Smart{Select,Share,Linkify}. */
+ SMART(0),
+
+ /**
+ * Results are optimized for using TextClassifier as an infrastructure that annotates as much as
+ * possible.
+ */
+ RAW(1);
+
+ private final int value;
+
+ AnnotationUsecase(int value) {
+ this.value = value;
+ }
+
+ public int getValue() {
+ return value;
+ }
+ };
+
+ /**
+ * Creates a new instance of SmartSelect predictor, using the provided model image, given as a
+ * file descriptor.
+ */
+ public AnnotatorModel(int fileDescriptor) {
+ annotatorPtr = nativeNewAnnotator(fileDescriptor);
+ if (annotatorPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize TC from file descriptor.");
+ }
+ }
+
+ /**
+ * Creates a new instance of SmartSelect predictor, using the provided model image, given as a
+ * file path.
+ */
+ public AnnotatorModel(String path) {
+ annotatorPtr = nativeNewAnnotatorFromPath(path);
+ if (annotatorPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize TC from given file.");
+ }
+ }
+
+ /** Initializes the knowledge engine, passing the given serialized config to it. */
+ public void initializeKnowledgeEngine(byte[] serializedConfig) {
+ if (!nativeInitializeKnowledgeEngine(annotatorPtr, serializedConfig)) {
+ throw new IllegalArgumentException("Couldn't initialize the KG engine");
+ }
+ }
+
+ /** Initializes the contact engine, passing the given serialized config to it. */
+ public void initializeContactEngine(byte[] serializedConfig) {
+ if (!nativeInitializeContactEngine(annotatorPtr, serializedConfig)) {
+ throw new IllegalArgumentException("Couldn't initialize the contact engine");
+ }
+ }
+
+ /** Initializes the installed app engine, passing the given serialized config to it. */
+ public void initializeInstalledAppEngine(byte[] serializedConfig) {
+ if (!nativeInitializeInstalledAppEngine(annotatorPtr, serializedConfig)) {
+ throw new IllegalArgumentException("Couldn't initialize the installed app engine");
+ }
+ }
+
+ /**
+ * Sets the LangId model to the annotator. Do not call close on the given LangIdModel object
+ * before this object is closed. Also, this object does not take the memory ownership of the given
+ * LangIdModel object.
+ */
+ public void setLangIdModel(LangIdModel langIdModel) {
+ this.langIdModel = langIdModel;
+ nativeSetLangId(annotatorPtr, langIdModel == null ? 0 : langIdModel.getNativePointer());
+ }
+
+ /**
+ * Given a string context and current selection, computes the selection suggestion.
+ *
+ * <p>The begin and end are character indices into the context UTF8 string. selectionBegin is the
+ * character index where the selection begins, and selectionEnd is the index of one character past
+ * the selection span.
+ *
+ * <p>The return value is an array of two ints: suggested selection beginning and end, with the
+ * same semantics as the input selectionBeginning and selectionEnd.
+ */
+ public int[] suggestSelection(
+ String context, int selectionBegin, int selectionEnd, SelectionOptions options) {
+ return nativeSuggestSelection(annotatorPtr, context, selectionBegin, selectionEnd, options);
+ }
+
+ /**
+ * Given a string context and current selection, classifies the type of the selected text.
+ *
+ * <p>The begin and end params are character indices in the context string.
+ *
+ * <p>Returns an array of ClassificationResult objects with the probability scores for different
+ * collections.
+ */
+ public ClassificationResult[] classifyText(
+ String context, int selectionBegin, int selectionEnd, ClassificationOptions options) {
+ return classifyText(
+ context,
+ selectionBegin,
+ selectionEnd,
+ options,
+ /*appContext=*/ null,
+ /*resourcesLocale=*/ null);
+ }
+
+ public ClassificationResult[] classifyText(
+ String context,
+ int selectionBegin,
+ int selectionEnd,
+ ClassificationOptions options,
+ Object appContext,
+ String resourcesLocale) {
+ return nativeClassifyText(
+ annotatorPtr, context, selectionBegin, selectionEnd, options, appContext, resourcesLocale);
+ }
+
+ /**
+ * Annotates given input text. The annotations should cover the whole input context except for
+ * whitespaces, and are sorted by their position in the context string.
+ */
+ public AnnotatedSpan[] annotate(String text, AnnotationOptions options) {
+ return nativeAnnotate(annotatorPtr, text, options);
+ }
+
+ /**
+ * Annotates multiple fragments of text at once. There will be one AnnotatedSpan array for each
+ * input fragment to annotate.
+ */
+ public AnnotatedSpan[][] annotateStructuredInput(
+ InputFragment[] fragments, AnnotationOptions options) {
+ return nativeAnnotateStructuredInput(annotatorPtr, fragments, options);
+ }
+
+ /**
+ * Looks up a knowledge entity by its identifier. Returns null if the entity is not found or on
+ * error.
+ */
+ public byte[] lookUpKnowledgeEntity(String id) {
+ return nativeLookUpKnowledgeEntity(annotatorPtr, id);
+ }
+
+ /** Frees up the allocated memory. */
+ @Override
+ public void close() {
+ if (isClosed.compareAndSet(false, true)) {
+ nativeCloseAnnotator(annotatorPtr);
+ annotatorPtr = 0L;
+ }
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ close();
+ } finally {
+ super.finalize();
+ }
+ }
+
+ /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
+ public static String getLocales(int fd) {
+ return nativeGetLocales(fd);
+ }
+
+ /** Returns the version of the model. */
+ public static int getVersion(int fd) {
+ return nativeGetVersion(fd);
+ }
+
+ /** Returns the name of the model. */
+ public static String getName(int fd) {
+ return nativeGetName(fd);
+ }
+
+ /** Information about a parsed time/date. */
+ public static final class DatetimeResult {
+
+ public static final int GRANULARITY_YEAR = 0;
+ public static final int GRANULARITY_MONTH = 1;
+ public static final int GRANULARITY_WEEK = 2;
+ public static final int GRANULARITY_DAY = 3;
+ public static final int GRANULARITY_HOUR = 4;
+ public static final int GRANULARITY_MINUTE = 5;
+ public static final int GRANULARITY_SECOND = 6;
+
+ private final long timeMsUtc;
+ private final int granularity;
+
+ public DatetimeResult(long timeMsUtc, int granularity) {
+ this.timeMsUtc = timeMsUtc;
+ this.granularity = granularity;
+ }
+
+ public long getTimeMsUtc() {
+ return timeMsUtc;
+ }
+
+ public int getGranularity() {
+ return granularity;
+ }
+ }
+
+ /** Classification result for classifyText method. */
+ public static final class ClassificationResult {
+ private final String collection;
+ private final float score;
+ private final DatetimeResult datetimeResult;
+ private final byte[] serializedKnowledgeResult;
+ private final String contactName;
+ private final String contactGivenName;
+ private final String contactFamilyName;
+ private final String contactNickname;
+ private final String contactEmailAddress;
+ private final String contactPhoneNumber;
+ private final String contactId;
+ private final String appName;
+ private final String appPackageName;
+ private final NamedVariant[] entityData;
+ private final byte[] serializedEntityData;
+ private final RemoteActionTemplate[] remoteActionTemplates;
+ private final long durationMs;
+ private final long numericValue;
+ private final double numericDoubleValue;
+
+ public ClassificationResult(
+ String collection,
+ float score,
+ DatetimeResult datetimeResult,
+ byte[] serializedKnowledgeResult,
+ String contactName,
+ String contactGivenName,
+ String contactFamilyName,
+ String contactNickname,
+ String contactEmailAddress,
+ String contactPhoneNumber,
+ String contactId,
+ String appName,
+ String appPackageName,
+ NamedVariant[] entityData,
+ byte[] serializedEntityData,
+ RemoteActionTemplate[] remoteActionTemplates,
+ long durationMs,
+ long numericValue,
+ double numericDoubleValue) {
+ this.collection = collection;
+ this.score = score;
+ this.datetimeResult = datetimeResult;
+ this.serializedKnowledgeResult = serializedKnowledgeResult;
+ this.contactName = contactName;
+ this.contactGivenName = contactGivenName;
+ this.contactFamilyName = contactFamilyName;
+ this.contactNickname = contactNickname;
+ this.contactEmailAddress = contactEmailAddress;
+ this.contactPhoneNumber = contactPhoneNumber;
+ this.contactId = contactId;
+ this.appName = appName;
+ this.appPackageName = appPackageName;
+ this.entityData = entityData;
+ this.serializedEntityData = serializedEntityData;
+ this.remoteActionTemplates = remoteActionTemplates;
+ this.durationMs = durationMs;
+ this.numericValue = numericValue;
+ this.numericDoubleValue = numericDoubleValue;
+ }
+
+ /** Returns the classified entity type. */
+ public String getCollection() {
+ return collection;
+ }
+
+ /** Confidence score between 0 and 1. */
+ public float getScore() {
+ return score;
+ }
+
+ public DatetimeResult getDatetimeResult() {
+ return datetimeResult;
+ }
+
+ public byte[] getSerializedKnowledgeResult() {
+ return serializedKnowledgeResult;
+ }
+
+ public String getContactName() {
+ return contactName;
+ }
+
+ public String getContactGivenName() {
+ return contactGivenName;
+ }
+
+ public String getContactFamilyName() {
+ return contactFamilyName;
+ }
+
+ public String getContactNickname() {
+ return contactNickname;
+ }
+
+ public String getContactEmailAddress() {
+ return contactEmailAddress;
+ }
+
+ public String getContactPhoneNumber() {
+ return contactPhoneNumber;
+ }
+
+ public String getContactId() {
+ return contactId;
+ }
+
+ public String getAppName() {
+ return appName;
+ }
+
+ public String getAppPackageName() {
+ return appPackageName;
+ }
+
+ public NamedVariant[] getEntityData() {
+ return entityData;
+ }
+
+ public byte[] getSerializedEntityData() {
+ return serializedEntityData;
+ }
+
+ public RemoteActionTemplate[] getRemoteActionTemplates() {
+ return remoteActionTemplates;
+ }
+
+ public long getDurationMs() {
+ return durationMs;
+ }
+
+ public long getNumericValue() {
+ return numericValue;
+ }
+
+ public double getNumericDoubleValue() {
+ return numericDoubleValue;
+ }
+ }
+
+ /** Represents a result of Annotate call. */
+ public static final class AnnotatedSpan {
+ private final int startIndex;
+ private final int endIndex;
+ private final ClassificationResult[] classification;
+
+ AnnotatedSpan(int startIndex, int endIndex, ClassificationResult[] classification) {
+ this.startIndex = startIndex;
+ this.endIndex = endIndex;
+ this.classification = classification;
+ }
+
+ public int getStartIndex() {
+ return startIndex;
+ }
+
+ public int getEndIndex() {
+ return endIndex;
+ }
+
+ public ClassificationResult[] getClassification() {
+ return classification;
+ }
+ }
+
+ /** Represents a fragment of text to the AnnotateStructuredInput call. */
+ public static final class InputFragment {
+
+ /** Encapsulates the data required to set the relative time of an InputFragment. */
+ public static final class DatetimeOptions {
+ private final String referenceTimezone;
+ private final Long referenceTimeMsUtc;
+
+ public DatetimeOptions(String referenceTimezone, Long referenceTimeMsUtc) {
+ this.referenceTimeMsUtc = referenceTimeMsUtc;
+ this.referenceTimezone = referenceTimezone;
+ }
+ }
+
+ public InputFragment(String text) {
+ this.text = text;
+ this.datetimeOptionsNullable = null;
+ }
+
+ public InputFragment(String text, DatetimeOptions datetimeOptions) {
+ this.text = text;
+ this.datetimeOptionsNullable = datetimeOptions;
+ }
+
+ private final String text;
+ // The DatetimeOptions can't be Optional because the _api16 build of the TCLib SDK does not
+ // support java.util.Optional.
+ private final DatetimeOptions datetimeOptionsNullable;
+
+ public String getText() {
+ return text;
+ }
+
+ public boolean hasDatetimeOptions() {
+ return datetimeOptionsNullable != null;
+ }
+
+ public long getReferenceTimeMsUtc() {
+ return datetimeOptionsNullable.referenceTimeMsUtc;
+ }
+
+ public String getReferenceTimezone() {
+ return datetimeOptionsNullable.referenceTimezone;
+ }
+ }
+
+ /**
+ * Represents options for the suggestSelection call. TODO(b/63427420): Use location with Selection
+ * options.
+ */
+ public static final class SelectionOptions {
+ private final String locales;
+ private final String detectedTextLanguageTags;
+ private final int annotationUsecase;
+ private final double userLocationLat;
+ private final double userLocationLng;
+ private final float userLocationAccuracyMeters;
+
+ public SelectionOptions(
+ String locales, String detectedTextLanguageTags, int annotationUsecase) {
+ this.locales = locales;
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ this.annotationUsecase = annotationUsecase;
+ this.userLocationLat = INVALID_LATITUDE;
+ this.userLocationLng = INVALID_LONGITUDE;
+ this.userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
+ }
+
+ public SelectionOptions(String locales, String detectedTextLanguageTags) {
+ this(locales, detectedTextLanguageTags, AnnotationUsecase.SMART.getValue());
+ }
+
+ public String getLocales() {
+ return locales;
+ }
+
+ /** Returns a comma separated list of BCP 47 language tags. */
+ public String getDetectedTextLanguageTags() {
+ return detectedTextLanguageTags;
+ }
+
+ public int getAnnotationUsecase() {
+ return annotationUsecase;
+ }
+
+ public double getUserLocationLat() {
+ return userLocationLat;
+ }
+
+ public double getUserLocationLng() {
+ return userLocationLng;
+ }
+
+ public float getUserLocationAccuracyMeters() {
+ return userLocationAccuracyMeters;
+ }
+ }
+
+ /**
+ * Represents options for the classifyText call. TODO(b/63427420): Use location with
+ * Classification options.
+ */
+ public static final class ClassificationOptions {
+ private final long referenceTimeMsUtc;
+ private final String referenceTimezone;
+ private final String locales;
+ private final String detectedTextLanguageTags;
+ private final int annotationUsecase;
+ private final double userLocationLat;
+ private final double userLocationLng;
+ private final float userLocationAccuracyMeters;
+ private final String userFamiliarLanguageTags;
+
+ public ClassificationOptions(
+ long referenceTimeMsUtc,
+ String referenceTimezone,
+ String locales,
+ String detectedTextLanguageTags,
+ int annotationUsecase,
+ String userFamiliarLanguageTags) {
+ this.referenceTimeMsUtc = referenceTimeMsUtc;
+ this.referenceTimezone = referenceTimezone;
+ this.locales = locales;
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ this.annotationUsecase = annotationUsecase;
+ this.userLocationLat = INVALID_LATITUDE;
+ this.userLocationLng = INVALID_LONGITUDE;
+ this.userLocationAccuracyMeters = INVALID_LOCATION_ACCURACY_METERS;
+ this.userFamiliarLanguageTags = userFamiliarLanguageTags;
+ }
+
+ public ClassificationOptions(
+ long referenceTimeMsUtc,
+ String referenceTimezone,
+ String locales,
+ String detectedTextLanguageTags) {
+ this(
+ referenceTimeMsUtc,
+ referenceTimezone,
+ locales,
+ detectedTextLanguageTags,
+ AnnotationUsecase.SMART.getValue(),
+ "");
+ }
+
+ public long getReferenceTimeMsUtc() {
+ return referenceTimeMsUtc;
+ }
+
+ public String getReferenceTimezone() {
+ return referenceTimezone;
+ }
+
+ public String getLocale() {
+ return locales;
+ }
+
+ /** Returns a comma separated list of BCP 47 language tags. */
+ public String getDetectedTextLanguageTags() {
+ return detectedTextLanguageTags;
+ }
+
+ public int getAnnotationUsecase() {
+ return annotationUsecase;
+ }
+
+ public double getUserLocationLat() {
+ return userLocationLat;
+ }
+
+ public double getUserLocationLng() {
+ return userLocationLng;
+ }
+
+ public float getUserLocationAccuracyMeters() {
+ return userLocationAccuracyMeters;
+ }
+
+ public String getUserFamiliarLanguageTags() {
+ return userFamiliarLanguageTags;
+ }
+ }
+
+ /** Represents options for the annotate call. */
+ public static final class AnnotationOptions {
+ private final long referenceTimeMsUtc;
+ private final String referenceTimezone;
+ private final String locales;
+ private final String detectedTextLanguageTags;
+ private final String[] entityTypes;
+ private final int annotationUsecase;
+ private final boolean hasLocationPermission;
+ private final boolean hasPersonalizationPermission;
+ private final boolean isSerializedEntityDataEnabled;
+ private final double userLocationLat;
+ private final double userLocationLng;
+ private final float userLocationAccuracyMeters;
+
+ public AnnotationOptions(
+ long referenceTimeMsUtc,
+ String referenceTimezone,
+ String locales,
+ String detectedTextLanguageTags,
+ Collection<String> entityTypes,
+ int annotationUsecase,
+ boolean hasLocationPermission,
+ boolean hasPersonalizationPermission,
+ boolean isSerializedEntityDataEnabled,
+ double userLocationLat,
+ double userLocationLng,
+ float userLocationAccuracyMeters) {
+ this.referenceTimeMsUtc = referenceTimeMsUtc;
+ this.referenceTimezone = referenceTimezone;
+ this.locales = locales;
+ this.detectedTextLanguageTags = detectedTextLanguageTags;
+ this.entityTypes = entityTypes == null ? new String[0] : entityTypes.toArray(new String[0]);
+ this.annotationUsecase = annotationUsecase;
+ this.isSerializedEntityDataEnabled = isSerializedEntityDataEnabled;
+ this.userLocationLat = userLocationLat;
+ this.userLocationLng = userLocationLng;
+ this.userLocationAccuracyMeters = userLocationAccuracyMeters;
+ this.hasLocationPermission = hasLocationPermission;
+ this.hasPersonalizationPermission = hasPersonalizationPermission;
+ }
+
+ public AnnotationOptions(
+ long referenceTimeMsUtc,
+ String referenceTimezone,
+ String locales,
+ String detectedTextLanguageTags,
+ Collection<String> entityTypes,
+ int annotationUsecase,
+ boolean isSerializedEntityDataEnabled,
+ double userLocationLat,
+ double userLocationLng,
+ float userLocationAccuracyMeters) {
+ this(
+ referenceTimeMsUtc,
+ referenceTimezone,
+ locales,
+ detectedTextLanguageTags,
+ entityTypes,
+ annotationUsecase,
+ /* hasLocationPermission */ true,
+ /* hasPersonalizationPermission */ true,
+ isSerializedEntityDataEnabled,
+ userLocationLat,
+ userLocationLng,
+ userLocationAccuracyMeters);
+ }
+
+ public AnnotationOptions(
+ long referenceTimeMsUtc,
+ String referenceTimezone,
+ String locales,
+ String detectedTextLanguageTags,
+ Collection<String> entityTypes,
+ int annotationUsecase,
+ boolean isSerializedEntityDataEnabled) {
+ this(
+ referenceTimeMsUtc,
+ referenceTimezone,
+ locales,
+ detectedTextLanguageTags,
+ entityTypes,
+ annotationUsecase,
+ isSerializedEntityDataEnabled,
+ INVALID_LATITUDE,
+ INVALID_LONGITUDE,
+ INVALID_LOCATION_ACCURACY_METERS);
+ }
+
+ public AnnotationOptions(
+ long referenceTimeMsUtc,
+ String referenceTimezone,
+ String locales,
+ String detectedTextLanguageTags) {
+ this(
+ referenceTimeMsUtc,
+ referenceTimezone,
+ locales,
+ detectedTextLanguageTags,
+ null,
+ AnnotationUsecase.SMART.getValue(),
+ /* isSerializedEntityDataEnabled */ false);
+ }
+
+ public long getReferenceTimeMsUtc() {
+ return referenceTimeMsUtc;
+ }
+
+ public String getReferenceTimezone() {
+ return referenceTimezone;
+ }
+
+ public String getLocale() {
+ return locales;
+ }
+
+ /** Returns a comma separated list of BCP 47 language tags. */
+ public String getDetectedTextLanguageTags() {
+ return detectedTextLanguageTags;
+ }
+
+ public String[] getEntityTypes() {
+ return entityTypes;
+ }
+
+ public int getAnnotationUsecase() {
+ return annotationUsecase;
+ }
+
+ public boolean isSerializedEntityDataEnabled() {
+ return isSerializedEntityDataEnabled;
+ }
+
+ public double getUserLocationLat() {
+ return userLocationLat;
+ }
+
+ public double getUserLocationLng() {
+ return userLocationLng;
+ }
+
+ public float getUserLocationAccuracyMeters() {
+ return userLocationAccuracyMeters;
+ }
+
+ public boolean hasLocationPermission() {
+ return hasLocationPermission;
+ }
+
+ public boolean hasPersonalizationPermission() {
+ return hasPersonalizationPermission;
+ }
+ }
+
+ /**
+ * Retrieves the pointer to the native object. Note: Need to keep the AnnotatorModel alive as long
+ * as the pointer is used.
+ */
+ long getNativeAnnotatorPointer() {
+ return nativeGetNativeModelPtr(annotatorPtr);
+ }
+
+ private static native long nativeNewAnnotator(int fd);
+
+ private static native long nativeNewAnnotatorFromPath(String path);
+
+ private static native long nativeNewAnnotatorWithOffset(int fd, long offset, long size);
+
+ private static native String nativeGetLocales(int fd);
+
+ private static native String nativeGetLocalesWithOffset(int fd, long offset, long size);
+
+ private static native int nativeGetVersion(int fd);
+
+ private static native int nativeGetVersionWithOffset(int fd, long offset, long size);
+
+ private static native String nativeGetName(int fd);
+
+ private static native String nativeGetNameWithOffset(int fd, long offset, long size);
+
+ private native long nativeGetNativeModelPtr(long context);
+
+ private native boolean nativeInitializeKnowledgeEngine(long context, byte[] serializedConfig);
+
+ private native boolean nativeInitializeContactEngine(long context, byte[] serializedConfig);
+
+ private native boolean nativeInitializeInstalledAppEngine(long context, byte[] serializedConfig);
+
+ private native boolean nativeInitializePersonNameEngine(
+ long context, int fd, long offset, long size);
+
+ private native void nativeSetLangId(long annotatorPtr, long langIdPtr);
+
+ private native int[] nativeSuggestSelection(
+ long context, String text, int selectionBegin, int selectionEnd, SelectionOptions options);
+
+ private native ClassificationResult[] nativeClassifyText(
+ long context,
+ String text,
+ int selectionBegin,
+ int selectionEnd,
+ ClassificationOptions options,
+ Object appContext,
+ String resourceLocales);
+
+ private native AnnotatedSpan[] nativeAnnotate(
+ long context, String text, AnnotationOptions options);
+
+ private native AnnotatedSpan[][] nativeAnnotateStructuredInput(
+ long context, InputFragment[] inputFragments, AnnotationOptions options);
+
+ private native byte[] nativeLookUpKnowledgeEntity(long context, String id);
+
+ private native void nativeCloseAnnotator(long context);
+}
diff --git a/jni/com/google/android/textclassifier/LangIdModel.java b/jni/com/google/android/textclassifier/LangIdModel.java
new file mode 100644
index 0000000..0015826
--- /dev/null
+++ b/jni/com/google/android/textclassifier/LangIdModel.java
@@ -0,0 +1,146 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Java wrapper for LangId native library interface. This class is used to detect languages in text.
+ *
+ * @hide
+ */
+public final class LangIdModel implements AutoCloseable {
+ private final AtomicBoolean isClosed = new AtomicBoolean(false);
+
+ static {
+ System.loadLibrary("textclassifier");
+ }
+
+ private long modelPtr;
+
+ /** Creates a new instance of LangId predictor, using the provided model image. */
+ public LangIdModel(int fd) {
+ modelPtr = nativeNew(fd);
+ if (modelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize LangId from given file descriptor.");
+ }
+ }
+
+ /** Creates a new instance of LangId predictor, using the provided model image. */
+ public LangIdModel(String modelPath) {
+ modelPtr = nativeNewFromPath(modelPath);
+ if (modelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize LangId from given file.");
+ }
+ }
+
+ /** Detects the languages for given text. */
+ public LanguageResult[] detectLanguages(String text) {
+ return nativeDetectLanguages(modelPtr, text);
+ }
+
+ /** Frees up the allocated memory. */
+ @Override
+ public void close() {
+ if (isClosed.compareAndSet(false, true)) {
+ nativeClose(modelPtr);
+ modelPtr = 0L;
+ }
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ close();
+ } finally {
+ super.finalize();
+ }
+ }
+
+ /** Result for detectLanguages method. */
+ public static final class LanguageResult {
+ final String mLanguage;
+ final float mScore;
+
+ LanguageResult(String language, float score) {
+ mLanguage = language;
+ mScore = score;
+ }
+
+ public final String getLanguage() {
+ return mLanguage;
+ }
+
+ public final float getScore() {
+ return mScore;
+ }
+ }
+
+ /** Returns the version of the LangId model used. */
+ public int getVersion() {
+ return nativeGetVersion(modelPtr);
+ }
+
+ public float getLangIdThreshold() {
+ return nativeGetLangIdThreshold(modelPtr);
+ }
+
+ public static int getVersion(int fd) {
+ return nativeGetVersionFromFd(fd);
+ }
+
+ /** Retrieves the pointer to the native object. */
+ long getNativePointer() {
+ return modelPtr;
+ }
+
+ // Visible for testing.
+ float getLangIdNoiseThreshold() {
+ return nativeGetLangIdNoiseThreshold(modelPtr);
+ }
+
+ // Visible for testing.
+ int getMinTextSizeInBytes() {
+ return nativeGetMinTextSizeInBytes(modelPtr);
+ }
+
+ /**
+ * Returns the pointer to the native object. Note: Need to keep the LangIdModel alive as long as
+ * the pointer is used.
+ */
+ long getNativeLangIdPointer() {
+ return modelPtr;
+ }
+
+ private static native long nativeNew(int fd);
+
+ private static native long nativeNewFromPath(String path);
+
+ private native LanguageResult[] nativeDetectLanguages(long nativePtr, String text);
+
+ private native void nativeClose(long nativePtr);
+
+ private native int nativeGetVersion(long nativePtr);
+
+ private static native int nativeGetVersionFromFd(int fd);
+
+ private native float nativeGetLangIdThreshold(long nativePtr);
+
+ private native float nativeGetLangIdNoiseThreshold(long nativePtr);
+
+ private native int nativeGetMinTextSizeInBytes(long nativePtr);
+}
diff --git a/jni/com/google/android/textclassifier/NamedVariant.java b/jni/com/google/android/textclassifier/NamedVariant.java
new file mode 100644
index 0000000..5d3bd7b
--- /dev/null
+++ b/jni/com/google/android/textclassifier/NamedVariant.java
@@ -0,0 +1,167 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier;
+
+/**
+ * Represents a union of different basic types.
+ *
+ * @hide
+ */
+public final class NamedVariant {
+ public static final int TYPE_EMPTY = 0;
+ public static final int TYPE_INT = 1;
+ public static final int TYPE_LONG = 2;
+ public static final int TYPE_FLOAT = 3;
+ public static final int TYPE_DOUBLE = 4;
+ public static final int TYPE_BOOL = 5;
+ public static final int TYPE_STRING = 6;
+ public static final int TYPE_STRING_ARRAY = 7;
+ public static final int TYPE_FLOAT_ARRAY = 8;
+ public static final int TYPE_INT_ARRAY = 9;
+ public static final int TYPE_NAMED_VARIANT_ARRAY = 10;
+
+ private final String name;
+ private final int type;
+ private int intValue;
+ private long longValue;
+ private float floatValue;
+ private double doubleValue;
+ private boolean boolValue;
+ private String stringValue;
+ private String[] stringArrValue;
+ private float[] floatArrValue;
+ private int[] intArrValue;
+ private NamedVariant[] namedVariantArray;
+
+ public NamedVariant(String name, int value) {
+ this.name = name;
+ this.intValue = value;
+ this.type = TYPE_INT;
+ }
+
+ public NamedVariant(String name, long value) {
+ this.name = name;
+ this.longValue = value;
+ this.type = TYPE_LONG;
+ }
+
+ public NamedVariant(String name, float value) {
+ this.name = name;
+ this.floatValue = value;
+ this.type = TYPE_FLOAT;
+ }
+
+ public NamedVariant(String name, double value) {
+ this.name = name;
+ this.doubleValue = value;
+ this.type = TYPE_DOUBLE;
+ }
+
+ public NamedVariant(String name, boolean value) {
+ this.name = name;
+ this.boolValue = value;
+ this.type = TYPE_BOOL;
+ }
+
+ public NamedVariant(String name, String value) {
+ this.name = name;
+ this.stringValue = value;
+ this.type = TYPE_STRING;
+ }
+
+ public NamedVariant(String name, String[] value) {
+ this.name = name;
+ this.stringArrValue = value;
+ this.type = TYPE_STRING_ARRAY;
+ }
+
+ public NamedVariant(String name, float[] value) {
+ this.name = name;
+ this.floatArrValue = value;
+ this.type = TYPE_FLOAT_ARRAY;
+ }
+
+ public NamedVariant(String name, int[] value) {
+ this.name = name;
+ this.intArrValue = value;
+ this.type = TYPE_INT_ARRAY;
+ }
+
+ public NamedVariant(String name, NamedVariant[] value) {
+ this.name = name;
+ this.namedVariantArray = value;
+ this.type = TYPE_NAMED_VARIANT_ARRAY;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ public int getType() {
+ return type;
+ }
+
+ public int getInt() {
+ assert (type == TYPE_INT);
+ return intValue;
+ }
+
+ public long getLong() {
+ assert (type == TYPE_LONG);
+ return longValue;
+ }
+
+ public float getFloat() {
+ assert (type == TYPE_FLOAT);
+ return floatValue;
+ }
+
+ public double getDouble() {
+ assert (type == TYPE_DOUBLE);
+ return doubleValue;
+ }
+
+ public boolean getBool() {
+ assert (type == TYPE_BOOL);
+ return boolValue;
+ }
+
+ public String getString() {
+ assert (type == TYPE_STRING);
+ return stringValue;
+ }
+
+ public String[] getStringArray() {
+ assert (type == TYPE_STRING_ARRAY);
+ return stringArrValue;
+ }
+
+ public float[] getFloatArray() {
+ assert (type == TYPE_FLOAT_ARRAY);
+ return floatArrValue;
+ }
+
+ public int[] getIntArray() {
+ assert (type == TYPE_INT_ARRAY);
+ return intArrValue;
+ }
+
+ public NamedVariant[] getNamedVariantArray() {
+ assert (type == TYPE_NAMED_VARIANT_ARRAY);
+ return namedVariantArray;
+ }
+}
diff --git a/java/com/google/android/textclassifier/RemoteActionTemplate.java b/jni/com/google/android/textclassifier/RemoteActionTemplate.java
similarity index 100%
rename from java/com/google/android/textclassifier/RemoteActionTemplate.java
rename to jni/com/google/android/textclassifier/RemoteActionTemplate.java
diff --git a/lang_id/common/embedding-feature-extractor.cc b/lang_id/common/embedding-feature-extractor.cc
deleted file mode 100644
index 6235f89..0000000
--- a/lang_id/common/embedding-feature-extractor.cc
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/common/embedding-feature-extractor.h"
-
-#include <stddef.h>
-
-#include <string>
-#include <vector>
-
-#include "lang_id/common/fel/feature-extractor.h"
-#include "lang_id/common/fel/feature-types.h"
-#include "lang_id/common/fel/task-context.h"
-#include "lang_id/common/lite_base/integral-types.h"
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/lite_strings/numbers.h"
-#include "lang_id/common/lite_strings/str-split.h"
-#include "lang_id/common/lite_strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-bool GenericEmbeddingFeatureExtractor::Setup(TaskContext *context) {
- // Don't use version to determine how to get feature FML.
- const string features = context->Get(GetParamName("features"), "");
- const string embedding_names =
- context->Get(GetParamName("embedding_names"), "");
- const string embedding_dims =
- context->Get(GetParamName("embedding_dims"), "");
-
- // NOTE: unfortunately, LiteStrSplit returns a vector of StringPieces pointing
- // to the original string, in this case |features|, which is local to this
- // method. We need to explicitly create new strings.
- for (StringPiece sp : LiteStrSplit(features, ';')) {
- embedding_fml_.emplace_back(sp);
- }
-
- // Same here.
- for (StringPiece sp : LiteStrSplit(embedding_names, ';')) {
- embedding_names_.emplace_back(sp);
- }
-
- std::vector<StringPiece> dim_strs = LiteStrSplit(embedding_dims, ';');
- for (const auto &dim_str : dim_strs) {
- int dim = 0;
- if (!LiteAtoi(dim_str, &dim)) {
- SAFTM_LOG(ERROR) << "Unable to parse " << dim_str;
- return false;
- }
- embedding_dims_.push_back(dim);
- }
- return true;
-}
-
-bool GenericEmbeddingFeatureExtractor::Init(TaskContext *context) {
- return true;
-}
-
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/common/embedding-feature-extractor.h b/lang_id/common/embedding-feature-extractor.h
deleted file mode 100644
index f51b6e5..0000000
--- a/lang_id/common/embedding-feature-extractor.h
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "lang_id/common/fel/feature-extractor.h"
-#include "lang_id/common/fel/task-context.h"
-#include "lang_id/common/fel/workspace.h"
-#include "lang_id/common/lite_base/attributes.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-// An EmbeddingFeatureExtractor manages the extraction of features for
-// embedding-based models. It wraps a sequence of underlying classes of feature
-// extractors, along with associated predicate maps. Each class of feature
-// extractors is associated with a name, e.g., "words", "labels", "tags".
-//
-// The class is split between a generic abstract version,
-// GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
-// signature of the ExtractFeatures method) and a typed version.
-//
-// The predicate maps must be initialized before use: they can be loaded using
-// Read() or updated via UpdateMapsForExample.
-class GenericEmbeddingFeatureExtractor {
- public:
- // Constructs this GenericEmbeddingFeatureExtractor.
- //
- // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
- // avoid name clashes. See GetParamName().
- explicit GenericEmbeddingFeatureExtractor(const string &arg_prefix)
- : arg_prefix_(arg_prefix) {}
-
- virtual ~GenericEmbeddingFeatureExtractor() {}
-
- // Sets/inits up predicate maps and embedding space names that are common for
- // all embedding based feature extractors.
- //
- // Returns true on success, false otherwise.
- SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context);
- SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context);
-
- // Requests workspace for the underlying feature extractors. This is
- // implemented in the typed class.
- virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
-
- // Returns number of embedding spaces.
- int NumEmbeddings() const { return embedding_dims_.size(); }
-
- const std::vector<string> &embedding_fml() const { return embedding_fml_; }
-
- // Get parameter name by concatenating the prefix and the original name.
- string GetParamName(const string ¶m_name) const {
- string full_name = arg_prefix_;
- full_name.push_back('_');
- full_name.append(param_name);
- return full_name;
- }
-
- private:
- // Prefix for TaskContext parameters.
- const string arg_prefix_;
-
- // Embedding space names for parameter sharing.
- std::vector<string> embedding_names_;
-
- // FML strings for each feature extractor.
- std::vector<string> embedding_fml_;
-
- // Size of each of the embedding spaces (maximum predicate id).
- std::vector<int> embedding_sizes_;
-
- // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
- std::vector<int> embedding_dims_;
-};
-
-// Templated, object-specific implementation of the
-// EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
-// ARGS...> class that has the appropriate FeatureTraits() to ensure that
-// locator type features work.
-//
-// Note: for backwards compatibility purposes, this always reads the FML spec
-// from "<prefix>_features".
-template <class EXTRACTOR, class OBJ, class... ARGS>
-class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
- public:
- // Constructs this EmbeddingFeatureExtractor.
- //
- // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
- // avoid name clashes. See GetParamName().
- explicit EmbeddingFeatureExtractor(const string &arg_prefix)
- : GenericEmbeddingFeatureExtractor(arg_prefix) {}
-
- // Sets up all predicate maps, feature extractors, and flags.
- SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override {
- if (!GenericEmbeddingFeatureExtractor::Setup(context)) {
- return false;
- }
- feature_extractors_.resize(embedding_fml().size());
- for (int i = 0; i < embedding_fml().size(); ++i) {
- feature_extractors_[i].reset(new EXTRACTOR());
- if (!feature_extractors_[i]->Parse(embedding_fml()[i])) return false;
- if (!feature_extractors_[i]->Setup(context)) return false;
- }
- return true;
- }
-
- // Initializes resources needed by the feature extractors.
- SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override {
- if (!GenericEmbeddingFeatureExtractor::Init(context)) return false;
- for (auto &feature_extractor : feature_extractors_) {
- if (!feature_extractor->Init(context)) return false;
- }
- return true;
- }
-
- // Requests workspaces from the registry. Must be called after Init(), and
- // before Preprocess().
- void RequestWorkspaces(WorkspaceRegistry *registry) override {
- for (auto &feature_extractor : feature_extractors_) {
- feature_extractor->RequestWorkspaces(registry);
- }
- }
-
- // Must be called on the object one state for each sentence, before any
- // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures).
- void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
- for (auto &feature_extractor : feature_extractors_) {
- feature_extractor->Preprocess(workspaces, obj);
- }
- }
-
- // Extracts features using the extractors. Note that features must already
- // be initialized to the correct number of feature extractors. No predicate
- // mapping is applied.
- void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
- ARGS... args,
- std::vector<FeatureVector> *features) const {
- // DCHECK(features != nullptr);
- // DCHECK_EQ(features->size(), feature_extractors_.size());
- for (int i = 0; i < feature_extractors_.size(); ++i) {
- (*features)[i].clear();
- feature_extractors_[i]->ExtractFeatures(workspaces, obj, args...,
- &(*features)[i]);
- }
- }
-
- private:
- // Templated feature extractor class.
- std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_;
-};
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
diff --git a/lang_id/common/embedding-feature-interface.h b/lang_id/common/embedding-feature-interface.h
deleted file mode 100644
index 87576c6..0000000
--- a/lang_id/common/embedding-feature-interface.h
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
-
-#include <string>
-#include <vector>
-
-#include "lang_id/common/embedding-feature-extractor.h"
-#include "lang_id/common/fel/feature-extractor.h"
-#include "lang_id/common/fel/task-context.h"
-#include "lang_id/common/fel/workspace.h"
-#include "lang_id/common/lite_base/attributes.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-template <class EXTRACTOR, class OBJ, class... ARGS>
-class EmbeddingFeatureInterface {
- public:
- // Constructs this EmbeddingFeatureInterface.
- //
- // |arg_prefix| is a string prefix for the TaskContext parameters, passed to
- // |the underlying EmbeddingFeatureExtractor.
- explicit EmbeddingFeatureInterface(const string &arg_prefix)
- : feature_extractor_(arg_prefix) {}
-
- // Sets up feature extractors and flags for processing (inference).
- SAFTM_MUST_USE_RESULT bool SetupForProcessing(TaskContext *context) {
- return feature_extractor_.Setup(context);
- }
-
- // Initializes feature extractor resources for processing (inference)
- // including requesting a workspace for caching extracted features.
- SAFTM_MUST_USE_RESULT bool InitForProcessing(TaskContext *context) {
- if (!feature_extractor_.Init(context)) return false;
- feature_extractor_.RequestWorkspaces(&workspace_registry_);
- return true;
- }
-
- // Preprocesses *obj using the internal workspace registry.
- void Preprocess(WorkspaceSet *workspace, OBJ *obj) const {
- workspace->Reset(workspace_registry_);
- feature_extractor_.Preprocess(workspace, obj);
- }
-
- // Extract features from |obj|. On return, FeatureVector features[i]
- // contains the features for the embedding space #i.
- //
- // This function uses the precomputed info from |workspace|. Usage pattern:
- //
- // EmbeddingFeatureInterface<...> feature_interface;
- // ...
- // OBJ obj;
- // WorkspaceSet workspace;
- // feature_interface.Preprocess(&workspace, &obj);
- //
- // // For the same obj, but with different args:
- // std::vector<FeatureVector> features;
- // feature_interface.GetFeatures(obj, args, workspace, &features);
- //
- // This pattern is useful (more efficient) if you can pre-compute some info
- // for the entire |obj|, which is reused by the feature extraction performed
- // for different args. If that is not the case, you can use the simpler
- // version GetFeaturesNoCaching below.
- void GetFeatures(const OBJ &obj, ARGS... args, const WorkspaceSet &workspace,
- std::vector<FeatureVector> *features) const {
- feature_extractor_.ExtractFeatures(workspace, obj, args..., features);
- }
-
- // Simpler version of GetFeatures(), for cases when there is no opportunity to
- // reuse computation between feature extractions for the same |obj|, but with
- // different |args|. Returns the extracted features. For more info, see the
- // doc for GetFeatures().
- std::vector<FeatureVector> GetFeaturesNoCaching(OBJ *obj,
- ARGS... args) const {
- // Technically, we still use a workspace, because
- // feature_extractor_.ExtractFeatures requires one. But there is no real
- // caching here, as we start from scratch for each call to ExtractFeatures.
- WorkspaceSet workspace;
- Preprocess(&workspace, obj);
- std::vector<FeatureVector> features(NumEmbeddings());
- GetFeatures(*obj, args..., workspace, &features);
- return features;
- }
-
- // Returns number of embedding spaces.
- int NumEmbeddings() const { return feature_extractor_.NumEmbeddings(); }
-
- private:
- // Typed feature extractor for embeddings.
- EmbeddingFeatureExtractor<EXTRACTOR, OBJ, ARGS...> feature_extractor_;
-
- // The registry of shared workspaces in the feature extractor.
- WorkspaceRegistry workspace_registry_;
-};
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
diff --git a/lang_id/common/embedding-network-params.cc b/lang_id/common/embedding-network-params.cc
deleted file mode 100644
index be7c80e..0000000
--- a/lang_id/common/embedding-network-params.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/common/embedding-network-params.h"
-
-#include "lang_id/common/lite_base/logging.h"
-
-namespace libtextclassifier3 {
-
-QuantizationType ParseQuantizationType(const string &s) {
- if (s == "NONE") {
- return QuantizationType::NONE;
- }
- if (s == "UINT8") {
- return QuantizationType::UINT8;
- }
- if (s == "UINT4") {
- return QuantizationType::UINT4;
- }
- if (s == "FLOAT16") {
- return QuantizationType::FLOAT16;
- }
- SAFTM_LOG(FATAL) << "Unsupported quantization type: " << s;
-
- // Execution should never reach this point; just to keep the compiler happy.
- // TODO(salcianu): implement SAFTM_LOG(FATAL) in a way that doesn't require
- // this trick.
- return QuantizationType::NONE;
-}
-
-} // namespace nlp_saft
diff --git a/lang_id/common/embedding-network-params.h b/lang_id/common/embedding-network-params.h
deleted file mode 100755
index f43c653..0000000
--- a/lang_id/common/embedding-network-params.h
+++ /dev/null
@@ -1,316 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
-
-#include <string>
-
-#include "lang_id/common/fel/task-context.h"
-#include "lang_id/common/lite_base/float16.h"
-#include "lang_id/common/lite_base/logging.h"
-
-namespace libtextclassifier3 {
-
-enum class QuantizationType {
- NONE = 0,
-
- // Quantization to 8 bit unsigned ints.
- UINT8,
-
- // Quantization to 4 bit unsigned ints.
- UINT4,
-
- // Quantization to 16 bit floats, the type defined in
- // lang_id/common/float16.h
- FLOAT16,
-
- // NOTE: for backward compatibility, if you add a new value to this enum, add
- // it *at the end*, such that you do not change the integer values of the
- // existing enum values.
-};
-
-// Converts "UINT8" -> QuantizationType::UINT8, and so on.
-QuantizationType ParseQuantizationType(const string &s);
-
-// API for accessing parameters for a feed-forward neural network with
-// embeddings.
-//
-//
-// In fact, we provide two APIs: a high-level (and highly-recommented) API, with
-// methods named using the BigCamel notation (e.g., GetEmbeddingMatrix()) and a
-// low-level API, using C-style names (e.g., softmax_num_cols()).
-//
-// Note: the API below is meant to allow the inference code (the class
-// libtextclassifier3::mobile::EmbeddingNetwork) to use the data directly, with no need
-// for transposing any matrix (which would require extra overhead on mobile
-// devices). Hence, as indicated by the comments for the API methods, some of
-// the matrices below are the transposes of the corresponding matrices from the
-// original proto.
-class EmbeddingNetworkParams {
- public:
- virtual ~EmbeddingNetworkParams() {}
-
- // Returns true if these params are valid. False otherwise (e.g., if the
- // underlying data is corrupted). If is_valid() returns false, clients should
- // not call any other method on that instance of EmbeddingNetworkParams. If
- // is_valid() returns true, then calls to the API methods below should not
- // crash *if they are called with index parameters in bounds*. E.g., if
- // is_valid() and 0 <= i < embeddings_size(), then GetEmbeddingMatrix(i)
- // should not crash.
- virtual bool is_valid() const = 0;
-
- // **** High-level API.
-
- // Simple representation of a matrix. This small struct that doesn't own any
- // resource intentionally supports copy / assign, to simplify our APIs.
- struct Matrix {
- // Number of rows.
- int rows = 0;
-
- // Number of columns.
- int cols = 0;
-
- QuantizationType quant_type = QuantizationType::NONE;
-
- // Pointer to matrix elements, in row-major order
- // (https://en.wikipedia.org/wiki/Row-major_order) Not owned.
- const void *elements = nullptr;
-
- // Quantization scales: one scale for each row.
- const ::libtextclassifier3::mobile::float16 *quant_scales = nullptr;
- };
-
- // Returns i-th embedding matrix. Crashes on out of bounds indices.
- //
- // This is the transpose of the corresponding matrix from the original proto.
- Matrix GetEmbeddingMatrix(int i) const {
- CheckIndex(i, embeddings_size(), "embedding matrix");
- Matrix matrix;
- matrix.rows = embeddings_num_rows(i);
- matrix.cols = embeddings_num_cols(i);
- matrix.elements = embeddings_weights(i);
- matrix.quant_type = embeddings_quant_type(i);
- matrix.quant_scales = embeddings_quant_scales(i);
- return matrix;
- }
-
- // Returns weight matrix for i-th hidden layer. Crashes on out of bounds
- // indices.
- //
- // This is the transpose of the corresponding matrix from the original proto.
- Matrix GetHiddenLayerMatrix(int i) const {
- CheckIndex(i, hidden_size(), "hidden layer");
- Matrix matrix;
- matrix.rows = hidden_num_rows(i);
- matrix.cols = hidden_num_cols(i);
-
- // Quantization not supported here.
- matrix.quant_type = hidden_weights_quant_type(i);
- matrix.elements = hidden_weights(i);
- return matrix;
- }
-
- // Returns bias for i-th hidden layer. Technically a Matrix, but we expect it
- // to be a row/column vector (i.e., num rows or num cols is 1). However, we
- // don't CHECK for that: we just provide access to underlying data. Crashes
- // on out of bounds indices.
- Matrix GetHiddenLayerBias(int i) const {
- CheckIndex(i, hidden_bias_size(), "hidden layer bias");
- Matrix matrix;
- matrix.rows = hidden_bias_num_rows(i);
- matrix.cols = hidden_bias_num_cols(i);
-
- // Quantization not supported here.
- matrix.quant_type = QuantizationType::NONE;
- matrix.elements = hidden_bias_weights(i);
- return matrix;
- }
-
- // Returns true if a softmax layer exists.
- bool HasSoftmax() const {
- return softmax_size() == 1;
- }
-
- // Returns weight matrix for the softmax layer. Note: should be called only
- // if HasSoftmax() is true.
- //
- // This is the transpose of the corresponding matrix from the original proto.
- Matrix GetSoftmaxMatrix() const {
- SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
- Matrix matrix;
- matrix.rows = softmax_num_rows(0);
- matrix.cols = softmax_num_cols(0);
-
- // Quantization not supported here.
- matrix.quant_type = softmax_weights_quant_type(0);
- matrix.elements = softmax_weights(0);
- return matrix;
- }
-
- // Returns bias for the softmax layer. Technically a Matrix, but we expect it
- // to be a row/column vector (i.e., num rows or num cols is 1). However, we
- // don't CHECK for that: we just provide access to underlying data.
- Matrix GetSoftmaxBias() const {
- SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
- Matrix matrix;
- matrix.rows = softmax_bias_num_rows(0);
- matrix.cols = softmax_bias_num_cols(0);
-
- // Quantization not supported here.
- matrix.quant_type = QuantizationType::NONE;
- matrix.elements = softmax_bias_weights(0);
- return matrix;
- }
-
- // Updates the EmbeddingNetwork-related parameters from task_context. Returns
- // true on success, false on error.
- virtual bool UpdateTaskContextParameters(
- mobile::TaskContext *task_context) = 0;
-
- // **** Low-level API.
- //
- // * Most low-level API methods are documented by giving an equivalent
- // function call on proto, the original proto (of type
- // EmbeddingNetworkProto) which was used to generate the C++ code.
- //
- // * To simplify our generation code, optional proto fields of message type
- // are treated as repeated fields with 0 or 1 instances. As such, we have
- // *_size() methods for such optional fields: they return 0 or 1.
- //
- // * "transpose(M)" denotes the transpose of a matrix M.
-
- // ** Access methods for repeated MatrixParams embeddings.
- //
- // Returns proto.embeddings_size().
- virtual int embeddings_size() const = 0;
-
- // Returns number of rows of transpose(proto.embeddings(i)).
- virtual int embeddings_num_rows(int i) const = 0;
-
- // Returns number of columns of transpose(proto.embeddings(i)).
- virtual int embeddings_num_cols(int i) const = 0;
-
- // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major
- // order. NOTE: for unquantized embeddings, this returns a pointer to float;
- // for quantized embeddings, this returns a pointer to uint8.
- virtual const void *embeddings_weights(int i) const = 0;
-
- virtual QuantizationType embeddings_quant_type(int i) const {
- return QuantizationType::NONE;
- }
-
- virtual const ::libtextclassifier3::mobile::float16 *embeddings_quant_scales(
- int i) const {
- return nullptr;
- }
-
- // ** Access methods for repeated MatrixParams hidden.
- //
- // Returns embedding_network_proto.hidden_size().
- virtual int hidden_size() const = 0;
-
- // Returns embedding_network_proto.hidden(i).rows().
- virtual int hidden_num_rows(int i) const = 0;
-
- // Returns embedding_network_proto.hidden(i).rows().
- virtual int hidden_num_cols(int i) const = 0;
-
- // Returns quantization mode for the weights of the i-th hidden layer.
- virtual QuantizationType hidden_weights_quant_type(int i) const {
- return QuantizationType::NONE;
- }
-
- // Returns pointer to beginning of array of floats with all values from
- // embedding_network_proto.hidden(i).
- virtual const void *hidden_weights(int i) const = 0;
-
- // ** Access methods for repeated MatrixParams hidden_bias.
- //
- // Returns proto.hidden_bias_size().
- virtual int hidden_bias_size() const = 0;
-
- // Returns number of rows of proto.hidden_bias(i).
- virtual int hidden_bias_num_rows(int i) const = 0;
-
- // Returns number of columns of proto.hidden_bias(i).
- virtual int hidden_bias_num_cols(int i) const = 0;
-
- // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
- virtual const void *hidden_bias_weights(int i) const = 0;
-
- // ** Access methods for optional MatrixParams softmax.
- //
- // Returns 1 if proto has optional field softmax, 0 otherwise.
- virtual int softmax_size() const = 0;
-
- // Returns number of rows of transpose(proto.softmax()).
- virtual int softmax_num_rows(int i) const = 0;
-
- // Returns number of columns of transpose(proto.softmax()).
- virtual int softmax_num_cols(int i) const = 0;
-
- // Returns quantization mode for the softmax weights.
- virtual QuantizationType softmax_weights_quant_type(int i) const {
- return QuantizationType::NONE;
- }
-
- // Returns pointer to elements of transpose(proto.softmax()), in row-major
- // order.
- virtual const void *softmax_weights(int i) const = 0;
-
- // ** Access methods for optional MatrixParams softmax_bias.
- //
- // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
- virtual int softmax_bias_size() const = 0;
-
- // Returns number of rows of proto.softmax_bias().
- virtual int softmax_bias_num_rows(int i) const = 0;
-
- // Returns number of columns of proto.softmax_bias().
- virtual int softmax_bias_num_cols(int i) const = 0;
-
- // Returns pointer to elements of proto.softmax_bias(), in row-major order.
- virtual const void *softmax_bias_weights(int i) const = 0;
-
- // ** Access methods for repeated int32 embedding_num_features.
- //
- // Returns proto.embedding_num_features_size().
- virtual int embedding_num_features_size() const = 0;
-
- // Returns proto.embedding_num_features(i).
- virtual int embedding_num_features(int i) const = 0;
-
- // ** Access methods for is_precomputed
- //
- // Returns proto.has_is_precomputed().
- virtual bool has_is_precomputed() const = 0;
-
- // Returns proto.is_precomputed().
- virtual bool is_precomputed() const = 0;
-
- protected:
- void CheckIndex(int index, int size, const string &description) const {
- SAFTM_CHECK_GE(index, 0)
- << "Out-of-range index for " << description << ": " << index;
- SAFTM_CHECK_LT(index, size)
- << "Out-of-range index for " << description << ": " << index;
- }
-}; // class EmbeddingNetworkParams
-
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
diff --git a/lang_id/common/fel/feature-descriptors.cc b/lang_id/common/fel/feature-descriptors.cc
deleted file mode 100644
index bf03dd5..0000000
--- a/lang_id/common/fel/feature-descriptors.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/common/fel/feature-descriptors.h"
-
-#include "lang_id/common/lite_strings/str-cat.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-void ToFELFunction(const FeatureFunctionDescriptor &function, string *output) {
- LiteStrAppend(output, function.type());
- if (function.argument() != 0 || function.parameter_size() > 0) {
- LiteStrAppend(output, "(");
- bool first = true;
- if (function.argument() != 0) {
- LiteStrAppend(output, function.argument());
- first = false;
- }
- for (int i = 0; i < function.parameter_size(); ++i) {
- if (!first) LiteStrAppend(output, ",");
- LiteStrAppend(output, function.parameter(i).name(), "=\"",
- function.parameter(i).value(), "\"");
- first = false;
- }
- LiteStrAppend(output, ")");
- }
-}
-
-void ToFEL(const FeatureFunctionDescriptor &function, string *output) {
- ToFELFunction(function, output);
- if (function.feature_size() == 1) {
- LiteStrAppend(output, ".");
- ToFEL(function.feature(0), output);
- } else if (function.feature_size() > 1) {
- LiteStrAppend(output, " { ");
- for (int i = 0; i < function.feature_size(); ++i) {
- if (i > 0) LiteStrAppend(output, " ");
- ToFEL(function.feature(i), output);
- }
- LiteStrAppend(output, " } ");
- }
-}
-
-void ToFEL(const FeatureExtractorDescriptor &extractor, string *output) {
- for (int i = 0; i < extractor.feature_size(); ++i) {
- ToFEL(extractor.feature(i), output);
- LiteStrAppend(output, "\n");
- }
-}
-
-string FeatureFunctionDescriptor::DebugString() const {
- string str;
- ToFEL(*this, &str);
- return str;
-}
-
-string FeatureExtractorDescriptor::DebugString() const {
- string str;
- ToFEL(*this, &str);
- return str;
-}
-
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/common/fel/feature-descriptors.h b/lang_id/common/fel/feature-descriptors.h
deleted file mode 100644
index a9408c9..0000000
--- a/lang_id/common/fel/feature-descriptors.h
+++ /dev/null
@@ -1,159 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_DESCRIPTORS_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_DESCRIPTORS_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "lang_id/common/lite_base/integral-types.h"
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/lite_base/macros.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-// Named feature parameter.
-class Parameter {
- public:
- Parameter() {}
-
- void set_name(const string &name) { name_ = name; }
- const string &name() const { return name_; }
-
- void set_value(const string &value) { value_ = value; }
- const string &value() const { return value_; }
-
- private:
- string name_;
- string value_;
-};
-
-// Descriptor for a feature function. Used to store the results of parsing one
-// feature function.
-class FeatureFunctionDescriptor {
- public:
- FeatureFunctionDescriptor() {}
-
- // Accessors for the feature function type. The function type is the string
- // that the feature extractor code is registered under.
- void set_type(const string &type) { type_ = type; }
- const string &type() const { return type_; }
-
- // Accessors for the feature function name. The function name (if available)
- // is used for some log messages. Otherwise, a more precise, but also more
- // verbose name based on the feature specification is used.
- void set_name(const string &name) { name_ = name; }
- const string &name() const { return name_; }
-
- // Accessors for the default (name-less) parameter.
- void set_argument(int32 argument) { argument_ = argument; }
- bool has_argument() const {
- // If argument has not been specified, clients should treat it as 0. This
- // makes the test below correct, without having a separate has_argument_
- // bool field.
- return argument_ != 0;
- }
- int32 argument() const { return argument_; }
-
- // Accessors for the named parameters.
- Parameter *add_parameter() {
- parameters_.emplace_back();
- return &(parameters_.back());
- }
- int parameter_size() const { return parameters_.size(); }
- const Parameter ¶meter(int i) const {
- SAFTM_DCHECK((i >= 0) && (i < parameter_size()));
- return parameters_[i];
- }
-
- // Accessors for the sub (i.e., nested) features. Nested features: as in
- // offset(1).label.
- FeatureFunctionDescriptor *add_feature() {
- sub_features_.emplace_back(new FeatureFunctionDescriptor());
- return sub_features_.back().get();
- }
- int feature_size() const { return sub_features_.size(); }
- const FeatureFunctionDescriptor &feature(int i) const {
- SAFTM_DCHECK((i >= 0) && (i < feature_size()));
- return *(sub_features_[i].get());
- }
-
- // Returns human-readable representation of this FeatureFunctionDescriptor.
- string DebugString() const;
-
- private:
- // See comments for set_type().
- string type_;
-
- // See comments for set_name().
- string name_;
-
- // See comments for set_argument().
- int32 argument_ = 0;
-
- // See comemnts for add_parameter().
- std::vector<Parameter> parameters_;
-
- // See comments for add_feature().
- std::vector<std::unique_ptr<FeatureFunctionDescriptor>> sub_features_;
-
- SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureFunctionDescriptor);
-};
-
-// List of FeatureFunctionDescriptors. Used to store the result of parsing the
-// spec for several feature functions.
-class FeatureExtractorDescriptor {
- public:
- FeatureExtractorDescriptor() {}
-
- int feature_size() const { return features_.size(); }
-
- FeatureFunctionDescriptor *add_feature() {
- features_.emplace_back(new FeatureFunctionDescriptor());
- return features_.back().get();
- }
-
- const FeatureFunctionDescriptor &feature(int i) const {
- SAFTM_DCHECK((i >= 0) && (i < feature_size()));
- return *(features_[i].get());
- }
-
- // Returns human-readable representation of this FeatureExtractorDescriptor.
- string DebugString() const;
-
- private:
- std::vector<std::unique_ptr<FeatureFunctionDescriptor>> features_;
-
- SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureExtractorDescriptor);
-};
-
-// Appends to |*output| the FEL representation of the top-level feature from
-// |function|, without diving into the nested features.
-void ToFELFunction(const FeatureFunctionDescriptor &function, string *output);
-
-// Appends to |*output| the FEL representation of |function|.
-void ToFEL(const FeatureFunctionDescriptor &function, string *output);
-
-// Appends to |*output| the FEL representation of |extractor|.
-void ToFEL(const FeatureExtractorDescriptor &extractor, string *output);
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_DESCRIPTORS_H_
diff --git a/lang_id/common/fel/feature-extractor.cc b/lang_id/common/fel/feature-extractor.cc
deleted file mode 100644
index c256257..0000000
--- a/lang_id/common/fel/feature-extractor.cc
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/common/fel/feature-extractor.h"
-
-#include "lang_id/common/fel/feature-types.h"
-#include "lang_id/common/fel/fel-parser.h"
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/lite_strings/numbers.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-constexpr FeatureValue GenericFeatureFunction::kNone;
-
-GenericFeatureExtractor::GenericFeatureExtractor() {}
-
-GenericFeatureExtractor::~GenericFeatureExtractor() {}
-
-bool GenericFeatureExtractor::Parse(const string &source) {
- // Parse feature specification into descriptor.
- FELParser parser;
-
- if (!parser.Parse(source, mutable_descriptor())) {
- SAFTM_LOG(ERROR) << "Error parsing the FEL spec " << source;
- return false;
- }
-
- // Initialize feature extractor from descriptor.
- return InitializeFeatureFunctions();
-}
-
-bool GenericFeatureExtractor::InitializeFeatureTypes() {
- // Register all feature types.
- GetFeatureTypes(&feature_types_);
- for (size_t i = 0; i < feature_types_.size(); ++i) {
- FeatureType *ft = feature_types_[i];
- ft->set_base(i);
-
- // Check for feature space overflow.
- double domain_size = ft->GetDomainSize();
- if (domain_size < 0) {
- SAFTM_LOG(ERROR) << "Illegal domain size for feature " << ft->name()
- << ": " << domain_size;
- return false;
- }
- }
- return true;
-}
-
-string GenericFeatureFunction::GetParameter(const string &name,
- const string &default_value) const {
- // Find named parameter in feature descriptor.
- for (int i = 0; i < descriptor_->parameter_size(); ++i) {
- if (name == descriptor_->parameter(i).name()) {
- return descriptor_->parameter(i).value();
- }
- }
- return default_value;
-}
-
-GenericFeatureFunction::GenericFeatureFunction() {}
-
-GenericFeatureFunction::~GenericFeatureFunction() { delete feature_type_; }
-
-int GenericFeatureFunction::GetIntParameter(const string &name,
- int default_value) const {
- string value_str = GetParameter(name, "");
- if (value_str.empty()) {
- // Parameter not specified, use default value for it.
- return default_value;
- }
- int value = 0;
- if (!LiteAtoi(value_str, &value)) {
- SAFTM_LOG(DFATAL) << "Unable to parse '" << value_str
- << "' as int for parameter " << name;
- return default_value;
- }
- return value;
-}
-
-bool GenericFeatureFunction::GetBoolParameter(const string &name,
- bool default_value) const {
- string value = GetParameter(name, "");
- if (value.empty()) return default_value;
- if (value == "true") return true;
- if (value == "false") return false;
- SAFTM_LOG(DFATAL) << "Illegal value '" << value << "' for bool parameter "
- << name;
- return default_value;
-}
-
-void GenericFeatureFunction::GetFeatureTypes(
- std::vector<FeatureType *> *types) const {
- if (feature_type_ != nullptr) types->push_back(feature_type_);
-}
-
-FeatureType *GenericFeatureFunction::GetFeatureType() const {
- // If a single feature type has been registered return it.
- if (feature_type_ != nullptr) return feature_type_;
-
- // Get feature types for function.
- std::vector<FeatureType *> types;
- GetFeatureTypes(&types);
-
- // If there is exactly one feature type return this, else return null.
- if (types.size() == 1) return types[0];
- return nullptr;
-}
-
-string GenericFeatureFunction::name() const {
- string output;
- if (descriptor_->name().empty()) {
- if (!prefix_.empty()) {
- output.append(prefix_);
- output.append(".");
- }
- ToFEL(*descriptor_, &output);
- } else {
- output = descriptor_->name();
- }
- return output;
-}
-
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/common/fel/feature-extractor.h b/lang_id/common/fel/feature-extractor.h
deleted file mode 100644
index 8763852..0000000
--- a/lang_id/common/fel/feature-extractor.h
+++ /dev/null
@@ -1,651 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Generic feature extractor for extracting features from objects. The feature
-// extractor can be used for extracting features from any object. The feature
-// extractor and feature function classes are template classes that have to
-// be instantiated for extracting feature from a specific object type.
-//
-// A feature extractor consists of a hierarchy of feature functions. Each
-// feature function extracts one or more feature type and value pairs from the
-// object.
-//
-// The feature extractor has a modular design where new feature functions can be
-// registered as components. The feature extractor is initialized from a
-// descriptor represented by a protocol buffer. The feature extractor can also
-// be initialized from a text-based source specification of the feature
-// extractor. Feature specification parsers can be added as components. By
-// default the feature extractor can be read from an ASCII protocol buffer or in
-// a simple feature modeling language (fml).
-
-// A feature function is invoked with a focus. Nested feature function can be
-// invoked with another focus determined by the parent feature function.
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_
-
-#include <stddef.h>
-
-#include <string>
-#include <vector>
-
-#include "lang_id/common/fel/feature-descriptors.h"
-#include "lang_id/common/fel/feature-types.h"
-#include "lang_id/common/fel/task-context.h"
-#include "lang_id/common/fel/workspace.h"
-#include "lang_id/common/lite_base/attributes.h"
-#include "lang_id/common/lite_base/integral-types.h"
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/lite_base/macros.h"
-#include "lang_id/common/registry.h"
-#include "lang_id/common/stl-util.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-// TODO(djweiss) Clean this up as well.
-// Use the same type for feature values as is used for predicated.
-typedef int64 Predicate;
-typedef Predicate FeatureValue;
-
-// A union used to represent discrete and continuous feature values.
-union FloatFeatureValue {
- public:
- explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {}
- FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {}
- FeatureValue discrete_value;
- struct {
- uint32 id;
- float weight;
- };
-};
-
-// A feature vector contains feature type and value pairs.
-class FeatureVector {
- public:
- FeatureVector() {}
-
- // Adds feature type and value pair to feature vector.
- void add(FeatureType *type, FeatureValue value) {
- features_.emplace_back(type, value);
- }
-
- // Removes all elements from the feature vector.
- void clear() { features_.clear(); }
-
- // Returns the number of elements in the feature vector.
- int size() const { return features_.size(); }
-
- // Reserves space in the underlying feature vector.
- void reserve(int n) { features_.reserve(n); }
-
- // Returns feature type for an element in the feature vector.
- FeatureType *type(int index) const { return features_[index].type; }
-
- // Returns feature value for an element in the feature vector.
- FeatureValue value(int index) const { return features_[index].value; }
-
- private:
- // Structure for holding feature type and value pairs.
- struct Element {
- Element() : type(nullptr), value(-1) {}
- Element(FeatureType *t, FeatureValue v) : type(t), value(v) {}
-
- FeatureType *type;
- FeatureValue value;
- };
-
- // Array for storing feature vector elements.
- std::vector<Element> features_;
-
- SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureVector);
-};
-
-// The generic feature extractor is the type-independent part of a feature
-// extractor. This holds the descriptor for the feature extractor and the
-// collection of feature types used in the feature extractor. The feature
-// types are not available until FeatureExtractor<>::Init() has been called.
-class GenericFeatureExtractor {
- public:
- GenericFeatureExtractor();
- virtual ~GenericFeatureExtractor();
-
- // Initializes the feature extractor from the FEL specification |source|.
- //
- // Returns true on success, false otherwise (e.g., FEL syntax error).
- SAFTM_MUST_USE_RESULT bool Parse(const string &source);
-
- // Returns the feature extractor descriptor.
- const FeatureExtractorDescriptor &descriptor() const { return descriptor_; }
- FeatureExtractorDescriptor *mutable_descriptor() { return &descriptor_; }
-
- // Returns the number of feature types in the feature extractor. Invalid
- // before Init() has been called.
- int feature_types() const { return feature_types_.size(); }
-
- protected:
- // Initializes the feature types used by the extractor. Called from
- // FeatureExtractor<>::Init().
- //
- // Returns true on success, false otherwise.
- SAFTM_MUST_USE_RESULT bool InitializeFeatureTypes();
-
- private:
- // Initializes the top-level feature functions.
- //
- // Returns true on success, false otherwise.
- SAFTM_MUST_USE_RESULT virtual bool InitializeFeatureFunctions() = 0;
-
- // Returns all feature types used by the extractor. The feature types are
- // added to the result array.
- virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const = 0;
-
- // Descriptor for the feature extractor. This is a protocol buffer that
- // contains all the information about the feature extractor. The feature
- // functions are initialized from the information in the descriptor.
- FeatureExtractorDescriptor descriptor_;
-
- // All feature types used by the feature extractor. The collection of all the
- // feature types describes the feature space of the feature set produced by
- // the feature extractor. Not owned.
- std::vector<FeatureType *> feature_types_;
-};
-
-// The generic feature function is the type-independent part of a feature
-// function. Each feature function is associated with the descriptor that it is
-// instantiated from. The feature types associated with this feature function
-// will be established by the time FeatureExtractor<>::Init() completes.
-class GenericFeatureFunction {
- public:
- // A feature value that represents the absence of a value.
- static constexpr FeatureValue kNone = -1;
-
- GenericFeatureFunction();
- virtual ~GenericFeatureFunction();
-
- // Sets up the feature function. NB: FeatureTypes of nested functions are not
- // guaranteed to be available until Init().
- //
- // Returns true on success, false otherwise.
- SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context) {
- return true;
- }
-
- // Initializes the feature function. NB: The FeatureType of this function must
- // be established when this method completes.
- //
- // Returns true on success, false otherwise.
- SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context) { return true; }
-
- // Requests workspaces from a registry to obtain indices into a WorkspaceSet
- // for any Workspace objects used by this feature function. NB: This will be
- // called after Init(), so it can depend on resources and arguments.
- virtual void RequestWorkspaces(WorkspaceRegistry *registry) {}
-
- // Appends the feature types produced by the feature function to types. The
- // default implementation appends feature_type(), if non-null. Invalid
- // before Init() has been called.
- virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const;
-
- // Returns the feature type for feature produced by this feature function. If
- // the feature function produces features of different types this returns
- // null. Invalid before Init() has been called.
- virtual FeatureType *GetFeatureType() const;
-
- // Returns value of parameter |name| from the feature function descriptor.
- // If the parameter is not present, returns the indicated |default_value|.
- string GetParameter(const string &name, const string &default_value) const;
-
- // Returns value of int parameter |name| from feature function descriptor.
- // If the parameter is not present, or its value can't be parsed as an int,
- // returns |default_value|.
- int GetIntParameter(const string &name, int default_value) const;
-
- // Returns value of bool parameter |name| from feature function descriptor.
- // If the parameter is not present, or its value is not "true" or "false",
- // returns |default_value|. NOTE: this method is case sensitive, it doesn't
- // do any lower-casing.
- bool GetBoolParameter(const string &name, bool default_value) const;
-
- // Returns the FEL function description for the feature function, i.e. the
- // name and parameters without the nested features.
- string FunctionName() const {
- string output;
- ToFELFunction(*descriptor_, &output);
- return output;
- }
-
- // Returns the prefix for nested feature functions. This is the prefix of this
- // feature function concatenated with the feature function name.
- string SubPrefix() const {
- return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName();
- }
-
- // Returns/sets the feature extractor this function belongs to.
- const GenericFeatureExtractor *extractor() const { return extractor_; }
- void set_extractor(const GenericFeatureExtractor *extractor) {
- extractor_ = extractor;
- }
-
- // Returns/sets the feature function descriptor.
- const FeatureFunctionDescriptor *descriptor() const { return descriptor_; }
- void set_descriptor(const FeatureFunctionDescriptor *descriptor) {
- descriptor_ = descriptor;
- }
-
- // Returns a descriptive name for the feature function. The name is taken from
- // the descriptor for the feature function. If the name is empty or the
- // feature function is a variable the name is the FEL representation of the
- // feature, including the prefix.
- string name() const;
-
- // Returns the argument from the feature function descriptor. It defaults to
- // 0 if the argument has not been specified.
- int argument() const {
- return descriptor_->has_argument() ? descriptor_->argument() : 0;
- }
-
- // Returns/sets/clears function name prefix.
- const string &prefix() const { return prefix_; }
- void set_prefix(const string &prefix) { prefix_ = prefix; }
-
- protected:
- // Returns the feature type for single-type feature functions.
- FeatureType *feature_type() const { return feature_type_; }
-
- // Sets the feature type for single-type feature functions. This takes
- // ownership of feature_type. Can only be called once.
- void set_feature_type(FeatureType *feature_type) {
- SAFTM_CHECK_EQ(feature_type_, nullptr);
- feature_type_ = feature_type;
- }
-
- private:
- // Feature extractor this feature function belongs to. Not owned. Set to a
- // pointer != nullptr as soon as this object is created by Instantiate().
- // Normal methods can safely assume this is != nullptr.
- const GenericFeatureExtractor *extractor_ = nullptr;
-
- // Descriptor for feature function. Not owned. Set to a pointer != nullptr
- // as soon as this object is created by Instantiate(). Normal methods can
- // safely assume this is != nullptr.
- const FeatureFunctionDescriptor *descriptor_ = nullptr;
-
- // Feature type for features produced by this feature function. If the
- // feature function produces features of multiple feature types this is null
- // and the feature function must return it's feature types in
- // GetFeatureTypes(). Owned.
- FeatureType *feature_type_ = nullptr;
-
- // Prefix used for sub-feature types of this function.
- string prefix_;
-};
-
-// Feature function that can extract features from an object. Templated on
-// two type arguments:
-//
-// OBJ: The "object" from which features are extracted; e.g., a sentence. This
-// should be a plain type, rather than a reference or pointer.
-//
-// ARGS: A set of 0 or more types that are used to "index" into some part of the
-// object that should be extracted, e.g. an int token index for a sentence
-// object. This should not be a reference type.
-template <class OBJ, class... ARGS>
-class FeatureFunction
- : public GenericFeatureFunction,
- public RegisterableClass<FeatureFunction<OBJ, ARGS...> > {
- public:
- using Self = FeatureFunction<OBJ, ARGS...>;
-
- // Preprocesses the object. This will be called prior to calling Evaluate()
- // or Compute() on that object.
- virtual void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const {}
-
- // Appends features computed from the object and focus to the result. The
- // default implementation delegates to Compute(), adding a single value if
- // available. Multi-valued feature functions must override this method.
- virtual void Evaluate(const WorkspaceSet &workspaces, const OBJ &object,
- ARGS... args, FeatureVector *result) const {
- FeatureValue value = Compute(workspaces, object, args...);
- if (value != kNone) result->add(feature_type(), value);
- }
-
- // Returns a feature value computed from the object and focus, or kNone if no
- // value is computed. Single-valued feature functions only need to override
- // this method.
- virtual FeatureValue Compute(const WorkspaceSet &workspaces,
- const OBJ &object, ARGS... args) const {
- return kNone;
- }
-
- // Instantiates a new feature function in a feature extractor from a feature
- // descriptor.
- //
- // Returns a pointer to the newly-created object if everything goes well.
- // Returns nullptr if the feature function could not be instantiated (e.g., if
- // the function with that name is not registered; this usually happens because
- // the relevant cc_library was not linked-in).
- static Self *Instantiate(const GenericFeatureExtractor *extractor,
- const FeatureFunctionDescriptor *fd,
- const string &prefix) {
- Self *f = Self::Create(fd->type());
- if (f != nullptr) {
- f->set_extractor(extractor);
- f->set_descriptor(fd);
- f->set_prefix(prefix);
- }
- return f;
- }
-
- private:
- // Special feature function class for resolving variable references. The type
- // of the feature function is used for resolving the variable reference. When
- // evaluated it will either get the feature value(s) from the variable portion
- // of the feature vector, if present, or otherwise it will call the referenced
- // feature extractor function directly to extract the feature(s).
- class Reference;
-};
-
-// Base class for features with nested feature functions. The nested functions
-// are of type NES, which may be different from the type of the parent function.
-// NB: NestedFeatureFunction will ensure that all initialization of nested
-// functions takes place during Setup() and Init() -- after the nested features
-// are initialized, the parent feature is initialized via SetupNested() and
-// InitNested(). Alternatively, a derived classes that overrides Setup() and
-// Init() directly should call Parent::Setup(), Parent::Init(), etc. first.
-//
-// Note: NestedFeatureFunction cannot know how to call Preprocess, Evaluate, or
-// Compute, since the nested functions may be of a different type.
-template <class NES, class OBJ, class... ARGS>
-class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> {
- public:
- using Parent = NestedFeatureFunction<NES, OBJ, ARGS...>;
-
- // Clean up nested functions.
- ~NestedFeatureFunction() override { utils::STLDeleteElements(&nested_); }
-
- // By default, just appends the nested feature types.
- void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
- SAFTM_CHECK(!this->nested().empty())
- << "Nested features require nested features to be defined.";
- for (auto *function : nested_) function->GetFeatureTypes(types);
- }
-
- // Sets up the nested features.
- //
- // Returns true on success, false otherwise.
- SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override {
- bool success = CreateNested(this->extractor(), this->descriptor(), &nested_,
- this->SubPrefix());
- if (!success) return false;
- for (auto *function : nested_) {
- if (!function->Setup(context)) return false;
- }
- if (!SetupNested(context)) return false;
- return true;
- }
-
- // Sets up this NestedFeatureFunction specifically.
- //
- // Returns true on success, false otherwise.
- SAFTM_MUST_USE_RESULT virtual bool SetupNested(TaskContext *context) {
- return true;
- }
-
- // Initializes the nested features.
- //
- // Returns true on success, false otherwise.
- SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override {
- for (auto *function : nested_) {
- if (!function->Init(context)) return false;
- }
- if (!InitNested(context)) return false;
- return true;
- }
-
- // Initializes this NestedFeatureFunction specifically.
- //
- // Returns true on success, false otherwise.
- SAFTM_MUST_USE_RESULT virtual bool InitNested(TaskContext *context) {
- return true;
- }
-
- // Gets all the workspaces needed for the nested functions.
- void RequestWorkspaces(WorkspaceRegistry *registry) override {
- for (auto *function : nested_) function->RequestWorkspaces(registry);
- }
-
- // Returns the list of nested feature functions.
- const std::vector<NES *> &nested() const { return nested_; }
-
- // Instantiates nested feature functions for a feature function. Creates and
- // initializes one feature function for each sub-descriptor in the feature
- // descriptor.
- //
- // Returns true on success, false otherwise.
- SAFTM_MUST_USE_RESULT static bool CreateNested(
- const GenericFeatureExtractor *extractor,
- const FeatureFunctionDescriptor *fd, std::vector<NES *> *functions,
- const string &prefix) {
- for (int i = 0; i < fd->feature_size(); ++i) {
- const FeatureFunctionDescriptor &sub = fd->feature(i);
- NES *f = NES::Instantiate(extractor, &sub, prefix);
- if (f == nullptr) return false;
- functions->push_back(f);
- }
- return true;
- }
-
- protected:
- // The nested feature functions, if any, in order of declaration in the
- // feature descriptor. Owned.
- std::vector<NES *> nested_;
-};
-
-// Base class for a nested feature function that takes nested features with the
-// same signature as these features, i.e. a meta feature. For this class, we can
-// provide preprocessing of the nested features.
-template <class OBJ, class... ARGS>
-class MetaFeatureFunction
- : public NestedFeatureFunction<FeatureFunction<OBJ, ARGS...>, OBJ,
- ARGS...> {
- public:
- // Preprocesses using the nested features.
- void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const override {
- for (auto *function : this->nested_) {
- function->Preprocess(workspaces, object);
- }
- }
-};
-
-// Template for a special type of locator: The locator of type
-// FeatureFunction<OBJ, ARGS...> calls nested functions of type
-// FeatureFunction<OBJ, IDX, ARGS...>, where the derived class DER is
-// responsible for translating by providing the following:
-//
-// // Gets the new additional focus.
-// IDX GetFocus(const WorkspaceSet &workspaces, const OBJ &object);
-//
-// This is useful to e.g. add a token focus to a parser state based on some
-// desired property of that state.
-template <class DER, class OBJ, class IDX, class... ARGS>
-class FeatureAddFocusLocator
- : public NestedFeatureFunction<FeatureFunction<OBJ, IDX, ARGS...>, OBJ,
- ARGS...> {
- public:
- void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const override {
- for (auto *function : this->nested_) {
- function->Preprocess(workspaces, object);
- }
- }
-
- void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
- FeatureVector *result) const override {
- IDX focus =
- static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
- for (auto *function : this->nested()) {
- function->Evaluate(workspaces, object, focus, args..., result);
- }
- }
-
- // Returns the first nested feature's computed value.
- FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
- ARGS... args) const override {
- IDX focus =
- static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
- return this->nested()[0]->Compute(workspaces, object, focus, args...);
- }
-};
-
-// CRTP feature locator class. This is a meta feature that modifies ARGS and
-// then calls the nested feature functions with the modified ARGS. Note that in
-// order for this template to work correctly, all of ARGS must be types for
-// which the reference operator & can be interpreted as a pointer to the
-// argument. The derived class DER must implement the UpdateFocus method which
-// takes pointers to the ARGS arguments:
-//
-// // Updates the current arguments.
-// void UpdateArgs(const OBJ &object, ARGS *...args) const;
-template <class DER, class OBJ, class... ARGS>
-class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> {
- public:
- // Feature locators have an additional check that there is no intrinsic type.
- void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
- SAFTM_CHECK_EQ(this->feature_type(), nullptr)
- << "FeatureLocators should not have an intrinsic type.";
- MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types);
- }
-
- // Evaluates the locator.
- void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
- FeatureVector *result) const override {
- static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
- for (auto *function : this->nested()) {
- function->Evaluate(workspaces, object, args..., result);
- }
- }
-
- // Returns the first nested feature's computed value.
- FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
- ARGS... args) const override {
- static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
- return this->nested()[0]->Compute(workspaces, object, args...);
- }
-};
-
-// Feature extractor for extracting features from objects of a certain class.
-// Template type parameters are as defined for FeatureFunction.
-template <class OBJ, class... ARGS>
-class FeatureExtractor : public GenericFeatureExtractor {
- public:
- // Feature function type for top-level functions in the feature extractor.
- typedef FeatureFunction<OBJ, ARGS...> Function;
- typedef FeatureExtractor<OBJ, ARGS...> Self;
-
- // Feature locator type for the feature extractor.
- template <class DER>
- using Locator = FeatureLocator<DER, OBJ, ARGS...>;
-
- // Initializes feature extractor.
- FeatureExtractor() {}
-
- ~FeatureExtractor() override { utils::STLDeleteElements(&functions_); }
-
- // Sets up the feature extractor. Note that only top-level functions exist
- // until Setup() is called. This does not take ownership over the context,
- // which must outlive this.
- //
- // Returns true on success, false otherwise.
- SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) {
- for (Function *function : functions_) {
- if (!function->Setup(context)) return false;
- }
- return true;
- }
-
- // Initializes the feature extractor. Must be called after Setup(). This
- // does not take ownership over the context, which must outlive this.
- //
- // Returns true on success, false otherwise.
- SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) {
- for (Function *function : functions_) {
- if (!function->Init(context)) return false;
- }
- if (!this->InitializeFeatureTypes()) return false;
- return true;
- }
-
- // Requests workspaces from the registry. Must be called after Init(), and
- // before Preprocess(). Does not take ownership over registry. This should be
- // the same registry used to initialize the WorkspaceSet used in Preprocess()
- // and ExtractFeatures(). NB: This is a different ordering from that used in
- // SentenceFeatureRepresentation style feature computation.
- void RequestWorkspaces(WorkspaceRegistry *registry) {
- for (auto *function : functions_) function->RequestWorkspaces(registry);
- }
-
- // Preprocesses the object using feature functions for the phase. Must be
- // called before any calls to ExtractFeatures() on that object and phase.
- void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const {
- for (Function *function : functions_) {
- function->Preprocess(workspaces, object);
- }
- }
-
- // Extracts features from an object with a focus. This invokes all the
- // top-level feature functions in the feature extractor. Only feature
- // functions belonging to the specified phase are invoked.
- void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &object,
- ARGS... args, FeatureVector *result) const {
- result->reserve(this->feature_types());
-
- // Extract features.
- for (int i = 0; i < functions_.size(); ++i) {
- functions_[i]->Evaluate(workspaces, object, args..., result);
- }
- }
-
- private:
- // Creates and initializes all feature functions in the feature extractor.
- //
- // Returns true on success, false otherwise.
- SAFTM_MUST_USE_RESULT bool InitializeFeatureFunctions() override {
- // Create all top-level feature functions.
- for (int i = 0; i < descriptor().feature_size(); ++i) {
- const FeatureFunctionDescriptor &fd = descriptor().feature(i);
- Function *function = Function::Instantiate(this, &fd, "");
- if (function == nullptr) return false;
- functions_.push_back(function);
- }
- return true;
- }
-
- // Collect all feature types used in the feature extractor.
- void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
- for (int i = 0; i < functions_.size(); ++i) {
- functions_[i]->GetFeatureTypes(types);
- }
- }
-
- // Top-level feature functions (and variables) in the feature extractor.
- // Owned.
- std::vector<Function *> functions_;
-};
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_
diff --git a/lang_id/common/fel/feature-types.h b/lang_id/common/fel/feature-types.h
deleted file mode 100644
index 18cf69a..0000000
--- a/lang_id/common/fel/feature-types.h
+++ /dev/null
@@ -1,189 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Common feature types for parser components.
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_TYPES_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_TYPES_H_
-
-#include <algorithm>
-#include <map>
-#include <string>
-#include <utility>
-
-#include "lang_id/common/lite_base/integral-types.h"
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/lite_strings/str-cat.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-// TODO(djweiss) Clean this up as well.
-// Use the same type for feature values as is used for predicated.
-typedef int64 Predicate;
-typedef Predicate FeatureValue;
-
-// Each feature value in a feature vector has a feature type. The feature type
-// is used for converting feature type and value pairs to predicate values. The
-// feature type can also return names for feature values and calculate the size
-// of the feature value domain. The FeatureType class is abstract and must be
-// specialized for the concrete feature types.
-class FeatureType {
- public:
- // Initializes a feature type.
- explicit FeatureType(const string &name)
- : name_(name), base_(0),
- is_continuous_(name.find("continuous") != string::npos) {
- }
-
- virtual ~FeatureType() {}
-
- // Converts a feature value to a name.
- virtual string GetFeatureValueName(FeatureValue value) const = 0;
-
- // Returns the size of the feature values domain.
- virtual int64 GetDomainSize() const = 0;
-
- // Returns the feature type name.
- const string &name() const { return name_; }
-
- Predicate base() const { return base_; }
- void set_base(Predicate base) { base_ = base; }
-
- // Returns true iff this feature is continuous; see FloatFeatureValue.
- bool is_continuous() const { return is_continuous_; }
-
- private:
- // Feature type name.
- string name_;
-
- // "Base" feature value: i.e. a "slot" in a global ordering of features.
- Predicate base_;
-
- // See doc for is_continuous().
- bool is_continuous_;
-};
-
-// Feature type that is defined using an explicit map from FeatureValue to
-// string values. This can reduce some of the boilerplate when defining
-// features that generate enum values. Example usage:
-//
-// class BeverageSizeFeature : public FeatureFunction<Beverage>
-// enum FeatureValue { SMALL, MEDIUM, LARGE }; // values for this feature
-// void Init(TaskContext *context) override {
-// set_feature_type(new EnumFeatureType("beverage_size",
-// {{SMALL, "SMALL"}, {MEDIUM, "MEDIUM"}, {LARGE, "LARGE"}});
-// }
-// [...]
-// };
-class EnumFeatureType : public FeatureType {
- public:
- EnumFeatureType(const string &name,
- const std::map<FeatureValue, string> &value_names)
- : FeatureType(name), value_names_(value_names) {
- for (const auto &pair : value_names) {
- SAFTM_CHECK_GE(pair.first, 0)
- << "Invalid feature value: " << pair.first << ", " << pair.second;
- domain_size_ = std::max(domain_size_, pair.first + 1);
- }
- }
-
- // Returns the feature name for a given feature value.
- string GetFeatureValueName(FeatureValue value) const override {
- auto it = value_names_.find(value);
- if (it == value_names_.end()) {
- SAFTM_LOG(ERROR) << "Invalid feature value " << value << " for "
- << name();
- return "<INVALID>";
- }
- return it->second;
- }
-
- // Returns the number of possible values for this feature type. This is one
- // greater than the largest value in the value_names map.
- FeatureValue GetDomainSize() const override { return domain_size_; }
-
- protected:
- // Maximum possible value this feature could take.
- FeatureValue domain_size_ = 0;
-
- // Names of feature values.
- std::map<FeatureValue, string> value_names_;
-};
-
-// Feature type for binary features.
-class BinaryFeatureType : public FeatureType {
- public:
- BinaryFeatureType(const string &name, const string &off, const string &on)
- : FeatureType(name), off_(off), on_(on) {}
-
- // Returns the feature name for a given feature value.
- string GetFeatureValueName(FeatureValue value) const override {
- if (value == 0) return off_;
- if (value == 1) return on_;
- return "";
- }
-
- // Binary features always have two feature values.
- FeatureValue GetDomainSize() const override { return 2; }
-
- private:
- // Feature value names for on and off.
- string off_;
- string on_;
-};
-
-// Feature type for numeric features.
-class NumericFeatureType : public FeatureType {
- public:
- // Initializes numeric feature.
- NumericFeatureType(const string &name, FeatureValue size)
- : FeatureType(name), size_(size) {}
-
- // Returns numeric feature value.
- string GetFeatureValueName(FeatureValue value) const override {
- if (value < 0) return "";
- return LiteStrCat(value);
- }
-
- // Returns the number of feature values.
- FeatureValue GetDomainSize() const override { return size_; }
-
- private:
- // The underlying size of the numeric feature.
- FeatureValue size_;
-};
-
-// Feature type for byte features, including an "outside" value.
-class ByteFeatureType : public NumericFeatureType {
- public:
- explicit ByteFeatureType(const string &name)
- : NumericFeatureType(name, 257) {}
-
- string GetFeatureValueName(FeatureValue value) const override {
- if (value == 256) {
- return "<NULL>";
- }
- string result;
- result += static_cast<char>(value);
- return result;
- }
-};
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_TYPES_H_
diff --git a/lang_id/common/fel/fel-parser.cc b/lang_id/common/fel/fel-parser.cc
deleted file mode 100644
index 4346fb7..0000000
--- a/lang_id/common/fel/fel-parser.cc
+++ /dev/null
@@ -1,289 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/common/fel/fel-parser.h"
-
-#include <ctype.h>
-#include <string>
-
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/lite_strings/numbers.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-namespace {
-inline bool IsValidCharAtStartOfIdentifier(char c) {
- return isalpha(c) || (c == '_') || (c == '/');
-}
-
-// Returns true iff character c can appear inside an identifier.
-inline bool IsValidCharInsideIdentifier(char c) {
- return isalnum(c) || (c == '_') || (c == '-') || (c == '/');
-}
-
-// Returns true iff character c can appear at the beginning of a number.
-inline bool IsValidCharAtStartOfNumber(char c) {
- return isdigit(c) || (c == '+') || (c == '-');
-}
-
-// Returns true iff character c can appear inside a number.
-inline bool IsValidCharInsideNumber(char c) {
- return isdigit(c) || (c == '.');
-}
-} // namespace
-
-bool FELParser::Initialize(const string &source) {
- // Initialize parser state.
- source_ = source;
- current_ = source_.begin();
- item_start_ = line_start_ = current_;
- line_number_ = item_line_number_ = 1;
-
- // Read first input item.
- return NextItem();
-}
-
-void FELParser::ReportError(const string &error_message) {
- const int position = item_start_ - line_start_ + 1;
- const string line(line_start_, current_);
-
- SAFTM_LOG(ERROR) << "Error in feature model, line " << item_line_number_
- << ", position " << position << ": " << error_message
- << "\n " << line << " <--HERE";
-}
-
-void FELParser::Next() {
- // Move to the next input character. If we are at a line break update line
- // number and line start position.
- if (CurrentChar() == '\n') {
- ++line_number_;
- ++current_;
- line_start_ = current_;
- } else {
- ++current_;
- }
-}
-
-bool FELParser::NextItem() {
- // Skip white space and comments.
- while (!eos()) {
- if (CurrentChar() == '#') {
- // Skip comment.
- while (!eos() && CurrentChar() != '\n') Next();
- } else if (isspace(CurrentChar())) {
- // Skip whitespace.
- while (!eos() && isspace(CurrentChar())) Next();
- } else {
- break;
- }
- }
-
- // Record start position for next item.
- item_start_ = current_;
- item_line_number_ = line_number_;
-
- // Check for end of input.
- if (eos()) {
- item_type_ = END;
- return true;
- }
-
- // Parse number.
- if (IsValidCharAtStartOfNumber(CurrentChar())) {
- string::iterator start = current_;
- Next();
- while (!eos() && IsValidCharInsideNumber(CurrentChar())) Next();
- item_text_.assign(start, current_);
- item_type_ = NUMBER;
- return true;
- }
-
- // Parse string.
- if (CurrentChar() == '"') {
- Next();
- string::iterator start = current_;
- while (CurrentChar() != '"') {
- if (eos()) {
- ReportError("Unterminated string");
- return false;
- }
- Next();
- }
- item_text_.assign(start, current_);
- item_type_ = STRING;
- Next();
- return true;
- }
-
- // Parse identifier name.
- if (IsValidCharAtStartOfIdentifier(CurrentChar())) {
- string::iterator start = current_;
- while (!eos() && IsValidCharInsideIdentifier(CurrentChar())) {
- Next();
- }
- item_text_.assign(start, current_);
- item_type_ = NAME;
- return true;
- }
-
- // Single character item.
- item_type_ = CurrentChar();
- Next();
- return true;
-}
-
-bool FELParser::Parse(const string &source,
- FeatureExtractorDescriptor *result) {
- // Initialize parser.
- if (!Initialize(source)) {
- return false;
- }
-
- while (item_type_ != END) {
- // Current item should be a feature name.
- if (item_type_ != NAME) {
- ReportError("Feature type name expected");
- return false;
- }
- string name = item_text_;
- if (!NextItem()) {
- return false;
- }
-
- if (item_type_ == '=') {
- ReportError("Invalid syntax: feature expected");
- return false;
- } else {
- // Parse feature.
- FeatureFunctionDescriptor *descriptor = result->add_feature();
- descriptor->set_type(name);
- if (!ParseFeature(descriptor)) {
- return false;
- }
- }
- }
-
- return true;
-}
-
-bool FELParser::ParseFeature(FeatureFunctionDescriptor *result) {
- // Parse argument and parameters.
- if (item_type_ == '(') {
- if (!NextItem()) return false;
- if (!ParseParameter(result)) return false;
- while (item_type_ == ',') {
- if (!NextItem()) return false;
- if (!ParseParameter(result)) return false;
- }
-
- if (item_type_ != ')') {
- ReportError(") expected");
- return false;
- }
- if (!NextItem()) return false;
- }
-
- // Parse feature name.
- if (item_type_ == ':') {
- if (!NextItem()) return false;
- if (item_type_ != NAME && item_type_ != STRING) {
- ReportError("Feature name expected");
- return false;
- }
- string name = item_text_;
- if (!NextItem()) return false;
-
- // Set feature name.
- result->set_name(name);
- }
-
- // Parse sub-features.
- if (item_type_ == '.') {
- // Parse dotted sub-feature.
- if (!NextItem()) return false;
- if (item_type_ != NAME) {
- ReportError("Feature type name expected");
- return false;
- }
- string type = item_text_;
- if (!NextItem()) return false;
-
- // Parse sub-feature.
- FeatureFunctionDescriptor *subfeature = result->add_feature();
- subfeature->set_type(type);
- if (!ParseFeature(subfeature)) return false;
- } else if (item_type_ == '{') {
- // Parse sub-feature block.
- if (!NextItem()) return false;
- while (item_type_ != '}') {
- if (item_type_ != NAME) {
- ReportError("Feature type name expected");
- return false;
- }
- string type = item_text_;
- if (!NextItem()) return false;
-
- // Parse sub-feature.
- FeatureFunctionDescriptor *subfeature = result->add_feature();
- subfeature->set_type(type);
- if (!ParseFeature(subfeature)) return false;
- }
- if (!NextItem()) return false;
- }
- return true;
-}
-
-bool FELParser::ParseParameter(FeatureFunctionDescriptor *result) {
- if (item_type_ == NUMBER) {
- int argument;
- if (!LiteAtoi(item_text_, &argument)) {
- ReportError("Unable to parse number");
- return false;
- }
- if (!NextItem()) return false;
-
- // Set default argument for feature.
- result->set_argument(argument);
- } else if (item_type_ == NAME) {
- string name = item_text_;
- if (!NextItem()) return false;
- if (item_type_ != '=') {
- ReportError("= expected");
- return false;
- }
- if (!NextItem()) return false;
- if (item_type_ >= END) {
- ReportError("Parameter value expected");
- return false;
- }
- string value = item_text_;
- if (!NextItem()) return false;
-
- // Add parameter to feature.
- Parameter *parameter;
- parameter = result->add_parameter();
- parameter->set_name(name);
- parameter->set_value(value);
- } else {
- ReportError("Syntax error in parameter list");
- return false;
- }
- return true;
-}
-
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/common/fel/fel-parser.h b/lang_id/common/fel/fel-parser.h
deleted file mode 100644
index eacb442..0000000
--- a/lang_id/common/fel/fel-parser.h
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Feature extraction language (FEL) parser.
-//
-// BNF grammar for FEL:
-//
-// <feature model> ::= { <feature extractor> }
-//
-// <feature extractor> ::= <extractor spec> |
-// <extractor spec> '.' <feature extractor> |
-// <extractor spec> '{' { <feature extractor> } '}'
-//
-// <extractor spec> ::= <extractor type>
-// [ '(' <parameter list> ')' ]
-// [ ':' <extractor name> ]
-//
-// <parameter list> = ( <parameter> | <argument> ) { ',' <parameter> }
-//
-// <parameter> ::= <parameter name> '=' <parameter value>
-//
-// <extractor type> ::= NAME
-// <extractor name> ::= NAME | STRING
-// <argument> ::= NUMBER
-// <parameter name> ::= NAME
-// <parameter value> ::= NUMBER | STRING | NAME
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEL_PARSER_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEL_PARSER_H_
-
-#include <string>
-
-#include "lang_id/common/fel/feature-descriptors.h"
-#include "lang_id/common/lite_base/logging.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-class FELParser {
- public:
- // Parses fml specification into feature extractor descriptor.
- // Returns true on success, false on error (e.g., syntax errors).
- bool Parse(const string &source, FeatureExtractorDescriptor *result);
-
- private:
- // Initializes the parser with the source text.
- // Returns true on success, false on syntax error.
- bool Initialize(const string &source);
-
- // Outputs an error message, with context info.
- void ReportError(const string &error_message);
-
- // Moves to the next input character.
- void Next();
-
- // Moves to the next input item. Sets item_text_ and item_type_ accordingly.
- // Returns true on success, false on syntax error.
- bool NextItem();
-
- // Parses a feature descriptor.
- // Returns true on success, false on syntax error.
- bool ParseFeature(FeatureFunctionDescriptor *result);
-
- // Parses a parameter specification.
- // Returns true on success, false on syntax error.
- bool ParseParameter(FeatureFunctionDescriptor *result);
-
- // Returns true if end of source input has been reached.
- bool eos() const { return current_ >= source_.end(); }
-
- // Returns current character. Other methods should access the current
- // character through this method (instead of using *current_ directly): this
- // method performs extra safety checks.
- //
- // In case of an unsafe access, returns '\0'.
- char CurrentChar() const {
- if ((current_ >= source_.begin()) && (current_ < source_.end())) {
- return *current_;
- } else {
- SAFTM_LOG(ERROR) << "Unsafe char read";
- return '\0';
- }
- }
-
- // Item types.
- enum ItemTypes {
- END = 0,
- NAME = -1,
- NUMBER = -2,
- STRING = -3,
- };
-
- // Source text.
- string source_;
-
- // Current input position.
- string::iterator current_;
-
- // Line number for current input position.
- int line_number_;
-
- // Start position for current item.
- string::iterator item_start_;
-
- // Start position for current line.
- string::iterator line_start_;
-
- // Line number for current item.
- int item_line_number_;
-
- // Item type for current item. If this is positive it is interpreted as a
- // character. If it is negative it is interpreted as an item type.
- int item_type_;
-
- // Text for current item.
- string item_text_;
-};
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEL_PARSER_H_
diff --git a/lang_id/common/fel/task-context.cc b/lang_id/common/fel/task-context.cc
deleted file mode 100644
index f8b0701..0000000
--- a/lang_id/common/fel/task-context.cc
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/common/fel/task-context.h"
-
-#include <string>
-
-#include "lang_id/common/lite_strings/numbers.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-string TaskContext::GetInputPath(const string &name) const {
- auto it = inputs_.find(name);
- if (it != inputs_.end()) {
- return it->second;
- }
- return "";
-}
-
-void TaskContext::SetInputPath(const string &name, const string &path) {
- inputs_[name] = path;
-}
-
-string TaskContext::Get(const string &name, const char *defval) const {
- auto it = parameters_.find(name);
- if (it != parameters_.end()) {
- return it->second;
- }
- return defval;
-}
-
-int TaskContext::Get(const string &name, int defval) const {
- const string s = Get(name, "");
- int value = defval;
- if (LiteAtoi(s, &value)) {
- return value;
- }
- return defval;
-}
-
-float TaskContext::Get(const string &name, float defval) const {
- const string s = Get(name, "");
- float value = defval;
- if (LiteAtof(s, &value)) {
- return value;
- }
- return defval;
-}
-
-bool TaskContext::Get(const string &name, bool defval) const {
- string value = Get(name, "");
- return value.empty() ? defval : value == "true";
-}
-
-void TaskContext::SetParameter(const string &name, const string &value) {
- parameters_[name] = value;
-}
-
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/common/fel/task-context.h b/lang_id/common/fel/task-context.h
deleted file mode 100644
index ddc8cfe..0000000
--- a/lang_id/common/fel/task-context.h
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef TC3_STD_STRING_IMPORT
-#define TC3_STD_STRING_IMPORT
-#include <string>
-
-namespace libtextclassifier3 {
-using string = std::string;
-template <class CharT, class Traits = std::char_traits<CharT>,
- class Allocator = std::allocator<CharT> >
-using basic_string = std::basic_string<CharT, Traits, Allocator>;
-} // namespace libtextclassifier3
-#endif
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_TASK_CONTEXT_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_TASK_CONTEXT_H_
-
-#include <map>
-#include <string>
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-// Class that provides access to model parameter and inputs.
-//
-// Note: This class is related to the servers-side nlp_saft::TaskContext, but it
-// has been simplified to reduce code dependencies.
-class TaskContext {
- public:
- // Returns path for the input named |name|. Returns empty string ("") if
- // there is no input with that name. Note: this can be a standard file path,
- // or a path in a more special file system.
- string GetInputPath(const string &name) const;
-
- // Sets path for input |name|. Previous path, if any, is overwritten.
- void SetInputPath(const string &name, const string &path);
-
- // Returns parameter value. If the parameter is not specified in this
- // context, the default value is returned.
- string Get(const string &name, const char *defval) const;
- int Get(const string &name, int defval) const;
- float Get(const string &name, float defval) const;
- bool Get(const string &name, bool defval) const;
-
- // Sets value of parameter |name| to |value|.
- void SetParameter(const string &name, const string &value);
-
- private:
- // Maps input name -> path.
- std::map<string, string> inputs_;
-
- // Maps parameter name -> value.
- std::map<string, string> parameters_;
-};
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_TASK_CONTEXT_H_
diff --git a/lang_id/common/fel/workspace.cc b/lang_id/common/fel/workspace.cc
deleted file mode 100644
index 8cab281..0000000
--- a/lang_id/common/fel/workspace.cc
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/common/fel/workspace.h"
-
-#include <atomic>
-#include <string>
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-// static
-int GetFreshTypeId() {
- // Static local below is initialized the first time this method is run.
- static std::atomic<int> counter(0);
- return counter++;
-}
-
-string WorkspaceRegistry::DebugString() const {
- string str;
- for (auto &it : workspace_names_) {
- const string &type_name = workspace_types_.at(it.first);
- for (size_t index = 0; index < it.second.size(); ++index) {
- const string &workspace_name = it.second[index];
- str.append("\n ");
- str.append(type_name);
- str.append(" :: ");
- str.append(workspace_name);
- }
- }
- return str;
-}
-
-VectorIntWorkspace::VectorIntWorkspace(int size) : elements_(size) {}
-
-VectorIntWorkspace::VectorIntWorkspace(int size, int value)
- : elements_(size, value) {}
-
-VectorIntWorkspace::VectorIntWorkspace(const std::vector<int> &elements)
- : elements_(elements) {}
-
-string VectorIntWorkspace::TypeName() { return "Vector"; }
-
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/common/fel/workspace.h b/lang_id/common/fel/workspace.h
deleted file mode 100644
index 09095e4..0000000
--- a/lang_id/common/fel/workspace.h
+++ /dev/null
@@ -1,204 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Notes on thread-safety: All of the classes here are thread-compatible. More
-// specifically, the registry machinery is thread-safe, as long as each thread
-// performs feature extraction on a different Sentence object.
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_WORKSPACE_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_WORKSPACE_H_
-
-#include <stddef.h>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/lite_base/macros.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-// A base class for shared workspaces. Derived classes implement a static member
-// function TypeName() which returns a human readable string name for the class.
-class Workspace {
- public:
- // Polymorphic destructor.
- virtual ~Workspace() {}
-
- protected:
- // Create an empty workspace.
- Workspace() {}
-
- private:
- SAFTM_DISALLOW_COPY_AND_ASSIGN(Workspace);
-};
-
-// Returns a new, strictly increasing int every time it is invoked.
-int GetFreshTypeId();
-
-// Struct to simulate typeid, but without RTTI.
-template <typename T>
-struct TypeId {
- static int type_id;
-};
-
-template <typename T>
-int TypeId<T>::type_id = GetFreshTypeId();
-
-// A registry that keeps track of workspaces.
-class WorkspaceRegistry {
- public:
- // Create an empty registry.
- WorkspaceRegistry() {}
-
- // Returns the index of a named workspace, adding it to the registry first
- // if necessary.
- template <class W>
- int Request(const string &name) {
- const int id = TypeId<W>::type_id;
- max_workspace_id_ = std::max(id, max_workspace_id_);
- workspace_types_[id] = W::TypeName();
- std::vector<string> &names = workspace_names_[id];
- for (int i = 0; i < names.size(); ++i) {
- if (names[i] == name) return i;
- }
- names.push_back(name);
- return names.size() - 1;
- }
-
- // Returns the maximum workspace id that has been registered.
- int MaxId() const {
- return max_workspace_id_;
- }
-
- const std::unordered_map<int, std::vector<string> > &WorkspaceNames()
- const {
- return workspace_names_;
- }
-
- // Returns a string describing the registered workspaces.
- string DebugString() const;
-
- private:
- // Workspace type names, indexed as workspace_types_[typeid].
- std::unordered_map<int, string> workspace_types_;
-
- // Workspace names, indexed as workspace_names_[typeid][workspace].
- std::unordered_map<int, std::vector<string> > workspace_names_;
-
- // The maximum workspace id that has been registered.
- int max_workspace_id_ = 0;
-
- SAFTM_DISALLOW_COPY_AND_ASSIGN(WorkspaceRegistry);
-};
-
-// A typed collected of workspaces. The workspaces are indexed according to an
-// external WorkspaceRegistry. If the WorkspaceSet is const, the contents are
-// also immutable.
-class WorkspaceSet {
- public:
- ~WorkspaceSet() { Reset(WorkspaceRegistry()); }
-
- // Returns true if a workspace has been set.
- template <class W>
- bool Has(int index) const {
- const int id = TypeId<W>::type_id;
- SAFTM_DCHECK_GE(id, 0);
- SAFTM_DCHECK_LT(id, workspaces_.size());
- SAFTM_DCHECK_GE(index, 0);
- SAFTM_DCHECK_LT(index, workspaces_[id].size());
- if (id >= workspaces_.size()) return false;
- return workspaces_[id][index] != nullptr;
- }
-
- // Returns an indexed workspace; the workspace must have been set.
- template <class W>
- const W &Get(int index) const {
- SAFTM_DCHECK(Has<W>(index));
- const int id = TypeId<W>::type_id;
- const Workspace *w = workspaces_[id][index];
- return reinterpret_cast<const W &>(*w);
- }
-
- // Sets an indexed workspace; this takes ownership of the workspace, which
- // must have been new-allocated. It is an error to set a workspace twice.
- template <class W>
- void Set(int index, W *workspace) {
- const int id = TypeId<W>::type_id;
- SAFTM_DCHECK_GE(id, 0);
- SAFTM_DCHECK_LT(id, workspaces_.size());
- SAFTM_DCHECK_GE(index, 0);
- SAFTM_DCHECK_LT(index, workspaces_[id].size());
- SAFTM_DCHECK(workspaces_[id][index] == nullptr);
- SAFTM_DCHECK(workspace != nullptr);
- workspaces_[id][index] = workspace;
- }
-
- void Reset(const WorkspaceRegistry ®istry) {
- // Deallocate current workspaces.
- for (auto &it : workspaces_) {
- for (size_t index = 0; index < it.size(); ++index) {
- delete it[index];
- }
- }
- workspaces_.clear();
- workspaces_.resize(registry.MaxId() + 1, std::vector<Workspace *>());
- for (auto &it : registry.WorkspaceNames()) {
- workspaces_[it.first].resize(it.second.size());
- }
- }
-
- private:
- // The set of workspaces, indexed as workspaces_[typeid][index].
- std::vector<std::vector<Workspace *> > workspaces_;
-};
-
-// A workspace that wraps around a vector of int.
-class VectorIntWorkspace : public Workspace {
- public:
- // Creates a vector of the given size.
- explicit VectorIntWorkspace(int size);
-
- // Creates a vector initialized with the given array.
- explicit VectorIntWorkspace(const std::vector<int> &elements);
-
- // Creates a vector of the given size, with each element initialized to the
- // given value.
- VectorIntWorkspace(int size, int value);
-
- // Returns the name of this type of workspace.
- static string TypeName();
-
- // Returns the i'th element.
- int element(int i) const { return elements_[i]; }
-
- // Sets the i'th element.
- void set_element(int i, int value) { elements_[i] = value; }
-
- // Returns the size of the underlying vector.
- int size() const { return elements_.size(); }
-
- private:
- // The enclosed vector.
- std::vector<int> elements_;
-};
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_WORKSPACE_H_
diff --git a/lang_id/common/file/file-utils.cc b/lang_id/common/file/file-utils.cc
deleted file mode 100644
index 108c7d5..0000000
--- a/lang_id/common/file/file-utils.cc
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/common/file/file-utils.h"
-
-#include <fcntl.h>
-#include <stdio.h>
-#include <sys/stat.h>
-#include <sys/types.h>
-
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/lite_strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-namespace file_utils {
-
-bool GetFileContent(const string &filename, string *content) {
- ScopedMmap scoped_mmap(filename);
- const MmapHandle &handle = scoped_mmap.handle();
- if (!handle.ok()) {
- SAFTM_LOG(ERROR) << "Error opening " << filename;
- return false;
- }
- StringPiece sp = handle.to_stringpiece();
- content->assign(sp.data(), sp.size());
- return true;
-}
-
-bool FileExists(const string &filename) {
- struct stat s = {0};
- if (!stat(filename.c_str(), &s)) {
- return s.st_mode & S_IFREG;
- } else {
- return false;
- }
-}
-
-bool DirectoryExists(const string &dirpath) {
- struct stat s = {0};
- if (!stat(dirpath.c_str(), &s)) {
- return s.st_mode & S_IFDIR;
- } else {
- return false;
- }
-}
-
-} // namespace file_utils
-
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/common/file/file-utils.h b/lang_id/common/file/file-utils.h
deleted file mode 100644
index 6377d7a..0000000
--- a/lang_id/common/file/file-utils.h
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_FILE_UTILS_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_FILE_UTILS_H_
-
-#include <stddef.h>
-#include <string>
-
-#include "lang_id/common/file/mmap.h"
-#include "lang_id/common/lite_strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-namespace file_utils {
-
-// Reads the entire content of a file into a string. Returns true on success,
-// false on error.
-bool GetFileContent(const string &filename, string *content);
-
-// Parses a proto from its serialized representation in memory. That
-// representation starts at address |data| and should contain exactly
-// |num_bytes| bytes. Returns true on success, false otherwise.
-template <class Proto>
-bool ParseProtoFromMemory(const char *data, size_t num_bytes, Proto *proto) {
- if (data == nullptr) {
- // Avoid passing a nullptr to ParseFromArray below.
- return false;
- }
- return proto->ParseFromArray(data, num_bytes);
-}
-
-// Convenience StringPiece-based version of ParseProtoFromMemory.
-template <class Proto>
-inline bool ParseProtoFromMemory(StringPiece sp, Proto *proto) {
- return ParseProtoFromMemory(sp.data(), sp.size(), proto);
-}
-
-// Parses a proto from a file. Returns true on success, false otherwise.
-//
-// Note: the entire content of the file should be the binary (not
-// human-readable) serialization of a protocol buffer.
-//
-// Note: when we compile for Android, the proto parsing methods need to know the
-// type of the message they are parsing. We use template polymorphism for that.
-template<class Proto>
-bool ReadProtoFromFile(const string &filename, Proto *proto) {
- ScopedMmap scoped_mmap(filename);
- const MmapHandle &handle = scoped_mmap.handle();
- if (!handle.ok()) {
- return false;
- }
- return ParseProtoFromMemory(handle.to_stringpiece(), proto);
-}
-
-// Returns true if filename is the name of an existing file, and false
-// otherwise.
-bool FileExists(const string &filename);
-
-// Returns true if dirpath is the path to an existing directory, and false
-// otherwise.
-bool DirectoryExists(const string &dirpath);
-
-} // namespace file_utils
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_FILE_UTILS_H_
diff --git a/lang_id/common/file/mmap.cc b/lang_id/common/file/mmap.cc
deleted file mode 100644
index 89efa99..0000000
--- a/lang_id/common/file/mmap.cc
+++ /dev/null
@@ -1,133 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/common/file/mmap.h"
-
-#include <errno.h>
-#include <fcntl.h>
-#include <stdint.h>
-#include <string.h>
-#include <sys/mman.h>
-#include <sys/stat.h>
-#include <unistd.h>
-
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/lite_base/macros.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-namespace {
-inline string GetLastSystemError() {
- return string(strerror(errno));
-}
-
-inline MmapHandle GetErrorMmapHandle() {
- return MmapHandle(nullptr, 0);
-}
-
-class FileCloser {
- public:
- explicit FileCloser(int fd) : fd_(fd) {}
- ~FileCloser() {
- int result = close(fd_);
- if (result != 0) {
- const string last_error = GetLastSystemError();
- SAFTM_LOG(ERROR) << "Error closing file descriptor: " << last_error;
- }
- }
- private:
- const int fd_;
-
- SAFTM_DISALLOW_COPY_AND_ASSIGN(FileCloser);
-};
-} // namespace
-
-MmapHandle MmapFile(const string &filename) {
- int fd = open(filename.c_str(), O_RDONLY);
-
- if (fd < 0) {
- const string last_error = GetLastSystemError();
- SAFTM_LOG(ERROR) << "Error opening " << filename << ": " << last_error;
- return GetErrorMmapHandle();
- }
-
- // Make sure we close fd no matter how we exit this function. As the man page
- // for mmap clearly states: "closing the file descriptor does not unmap the
- // region." Hence, we can close fd as soon as we return from here.
- FileCloser file_closer(fd);
-
- return MmapFile(fd);
-}
-
-MmapHandle MmapFile(int fd) {
- // Get file stats to obtain file size.
- struct stat sb;
- if (fstat(fd, &sb) != 0) {
- const string last_error = GetLastSystemError();
- SAFTM_LOG(ERROR) << "Unable to stat fd: " << last_error;
- return GetErrorMmapHandle();
- }
- size_t file_size_in_bytes = static_cast<size_t>(sb.st_size);
-
- // Perform actual mmap.
- void *mmap_addr = mmap(
-
- // Let system pick address for mmapp-ed data.
- nullptr,
-
- // Mmap all bytes from the file.
- file_size_in_bytes,
-
- // One can read / write the mapped data (but see MAP_PRIVATE below).
- // Normally, we expect only to read it, but in the future, we may want to
- // write it, to fix e.g., endianness differences.
- PROT_READ | PROT_WRITE,
-
- // Updates to mmaped data are *not* propagated to actual file.
- // AFAIK(salcianu) that's anyway not possible on Android.
- MAP_PRIVATE,
-
- // Descriptor of file to mmap.
- fd,
-
- // Map bytes right from the beginning of the file. This, and
- // file_size_in_bytes (2nd argument) means we map all bytes from the file.
- 0);
- if (mmap_addr == MAP_FAILED) {
- const string last_error = GetLastSystemError();
- SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error;
- return GetErrorMmapHandle();
- }
-
- return MmapHandle(mmap_addr, file_size_in_bytes);
-}
-
-bool Unmap(MmapHandle mmap_handle) {
- if (!mmap_handle.ok()) {
- // Unmapping something that hasn't been mapped is trivially successful.
- return true;
- }
- if (munmap(mmap_handle.start(), mmap_handle.num_bytes()) != 0) {
- const string last_error = GetLastSystemError();
- SAFTM_LOG(ERROR) << "Error during Unmap / munmap: " << last_error;
- return false;
- }
- return true;
-}
-
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/common/file/mmap.h b/lang_id/common/file/mmap.h
deleted file mode 100644
index 6131803..0000000
--- a/lang_id/common/file/mmap.h
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_MMAP_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_MMAP_H_
-
-#include <stddef.h>
-
-#include <string>
-
-#include "lang_id/common/lite_strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-// Handle for a memory area where a file has been mmapped.
-//
-// Similar to a pointer: you "allocate" it using MmapFile(filename) and "delete"
-// it using Unmap(). Just like a pointer, it is passed around by value (see
-// signature of MmapFile and Unmap; fortunately, it's a small class, so there
-// shouldn't be any significant performance penalty) and its usage is not
-// necessarily scoped (that's why the destructor is not performing the unmap).
-//
-// Note: on program termination, each still unmapped file is automatically
-// unmapped. Hence, it is not an error if you don't call Unmap() (provided you
-// are ok keeping that file in memory the whole time).
-class MmapHandle {
- public:
- MmapHandle(void *start, size_t num_bytes)
- : start_(start), num_bytes_(num_bytes) {}
-
- // Returns start address for the memory area where a file has been mmapped.
- void *start() const { return start_; }
-
- // Returns number of bytes of the memory area from start().
- size_t num_bytes() const { return num_bytes_; }
-
- // Shortcut to simplify checking success of MmapFile(). See usage example
- // from the doc of that function.
- bool ok() const { return start() != nullptr; }
-
- // Returns a StringPiece pointing to the same underlying bytes.
- StringPiece to_stringpiece() const {
- return StringPiece(reinterpret_cast<char *>(start_), num_bytes_);
- }
-
- private:
- // See doc for start(). Not owned.
- void *const start_;
-
- // See doc for num_bytes().
- const size_t num_bytes_;
-};
-
-// Maps the full content of a file in memory (using mmap).
-//
-// When done using the file content, one can unmap using Unmap(). Otherwise,
-// all mapped files are unmapped when the program terminates.
-//
-// Sample usage:
-//
-// MmapHandle mmap_handle = MmapFile(filename);
-// CHECK(mmap_handle.ok()) << "Unable to mmap " << filename;
-//
-// ... use data from addresses
-// ... [mmap_handle.start, mmap_handle.start + mmap_handle.num_bytes)
-//
-// Unmap(mmap_handle); // Unmap logs errors internally.
-//
-// Note: one can read *and* write the num_bytes bytes from start, but those
-// writes are not propagated to the underlying file, nor to other processes that
-// may have mmapped that file (all changes are local to current process).
-MmapHandle MmapFile(const string &filename);
-
-// Like MmapFile(const string &filename), but uses a file descriptor.
-MmapHandle MmapFile(int fd);
-
-// Unmaps a file mapped using MmapFile. Returns true on success, false
-// otherwise.
-bool Unmap(MmapHandle mmap_handle);
-
-// Scoped mmapping of a file. Mmaps a file on construction, unmaps it on
-// destruction.
-class ScopedMmap {
- public:
- explicit ScopedMmap(const string &filename)
- : handle_(MmapFile(filename)) {}
-
- explicit ScopedMmap(int fd)
- : handle_(MmapFile(fd)) {}
-
- ~ScopedMmap() {
- if (handle_.ok()) {
- Unmap(handle_);
- }
- }
-
- const MmapHandle &handle() { return handle_; }
-
- private:
- MmapHandle handle_;
-};
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_MMAP_H_
diff --git a/lang_id/common/flatbuffers/model-utils.cc b/lang_id/common/flatbuffers/model-utils.cc
deleted file mode 100644
index 2c57aa2..0000000
--- a/lang_id/common/flatbuffers/model-utils.cc
+++ /dev/null
@@ -1,208 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/common/flatbuffers/model-utils.h"
-
-#include <string.h>
-
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/math/checksum.h"
-
-namespace libtextclassifier3 {
-namespace saft_fbs {
-
-namespace {
-
-// Returns true if we have clear evidence that |model| fails its checksum.
-//
-// E.g., if |model| has the crc32 field, and the value of that field does not
-// match the checksum, then this function returns true. If there is no crc32
-// field, then we don't know what the original (at build time) checksum was, so
-// we don't know anything clear and this function returns false.
-bool ClearlyFailsChecksum(const Model &model) {
- if (!flatbuffers::IsFieldPresent(&model, Model::VT_CRC32)) {
- SAFTM_LOG(WARNING)
- << "No CRC32, most likely an old model; skip CRC32 check";
- return false;
- }
- const mobile::uint32 expected_crc32 = model.crc32();
- const mobile::uint32 actual_crc32 = ComputeCrc2Checksum(&model);
- if (actual_crc32 != expected_crc32) {
- SAFTM_LOG(ERROR) << "Corrupt model: different CRC32: " << actual_crc32
- << " vs " << expected_crc32;
- return true;
- }
- SAFTM_LOG(INFO) << "Successfully checked CRC32 " << actual_crc32;
- return false;
-}
-} // namespace
-
-const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes) {
- if ((data == nullptr) || (num_bytes == 0)) {
- SAFTM_LOG(ERROR) << "GetModel called on an empty sequence of bytes";
- return nullptr;
- }
- const uint8_t *start = reinterpret_cast<const uint8_t *>(data);
- flatbuffers::Verifier verifier(start, num_bytes);
- if (!VerifyModelBuffer(verifier)) {
- SAFTM_LOG(ERROR) << "Not a valid Model flatbuffer";
- return nullptr;
- }
- const Model *model = GetModel(start);
- if (model == nullptr) {
- return nullptr;
- }
- if (ClearlyFailsChecksum(*model)) {
- return nullptr;
- }
- return model;
-}
-
-const ModelInput *GetInputByName(const Model *model, const string &name) {
- if (model == nullptr) {
- SAFTM_LOG(ERROR) << "GetInputByName called with model == nullptr";
- return nullptr;
- }
- const auto *inputs = model->inputs();
- if (inputs == nullptr) {
- // We should always have a list of inputs; maybe an empty one, if no inputs,
- // but the list should be there.
- SAFTM_LOG(ERROR) << "null inputs";
- return nullptr;
- }
- for (const ModelInput *input : *inputs) {
- if (input != nullptr) {
- const flatbuffers::String *input_name = input->name();
- if (input_name && input_name->str() == name) {
- return input;
- }
- }
- }
- return nullptr;
-}
-
-mobile::StringPiece GetInputBytes(const ModelInput *input) {
- if ((input == nullptr) || (input->data() == nullptr)) {
- SAFTM_LOG(ERROR) << "ModelInput has no content";
- return mobile::StringPiece(nullptr, 0);
- }
- const flatbuffers::Vector<uint8_t> *input_data = input->data();
- if (input_data == nullptr) {
- SAFTM_LOG(ERROR) << "null input data";
- return mobile::StringPiece(nullptr, 0);
- }
- return mobile::StringPiece(reinterpret_cast<const char *>(input_data->data()),
- input_data->size());
-}
-
-bool FillParameters(const Model &model, mobile::TaskContext *context) {
- if (context == nullptr) {
- SAFTM_LOG(ERROR) << "null context";
- return false;
- }
- const auto *parameters = model.parameters();
- if (parameters == nullptr) {
- // We should always have a list of parameters; maybe an empty one, if no
- // parameters, but the list should be there.
- SAFTM_LOG(ERROR) << "null list of parameters";
- return false;
- }
- for (const ModelParameter *p : *parameters) {
- if (p == nullptr) {
- SAFTM_LOG(ERROR) << "null parameter";
- return false;
- }
- if (p->name() == nullptr) {
- SAFTM_LOG(ERROR) << "null parameter name";
- return false;
- }
- const string name = p->name()->str();
- if (name.empty()) {
- SAFTM_LOG(ERROR) << "empty parameter name";
- return false;
- }
- if (p->value() == nullptr) {
- SAFTM_LOG(ERROR) << "null parameter name";
- return false;
- }
- context->SetParameter(name, p->value()->str());
- }
- return true;
-}
-
-namespace {
-// Updates |*crc| with the information from |s|. Auxiliary for
-// ComputeCrc2Checksum.
-//
-// The bytes from |info| are also used to update the CRC32 checksum. |info|
-// should be a brief tag that indicates what |s| represents. The idea is to add
-// some structure to the information that goes into the CRC32 computation.
-template <typename T>
-void UpdateCrc(mobile::Crc32 *crc, const flatbuffers::Vector<T> *s,
- mobile::StringPiece info) {
- crc->Update("|");
- crc->Update(info.data(), info.size());
- crc->Update(":");
- if (s == nullptr) {
- crc->Update("empty");
- } else {
- crc->Update(reinterpret_cast<const char *>(s->data()),
- s->size() * sizeof(T));
- }
-}
-} // namespace
-
-mobile::uint32 ComputeCrc2Checksum(const Model *model) {
- // Implementation note: originally, I (salcianu@) thought we can just compute
- // a CRC32 checksum of the model bytes. Unfortunately, the expected checksum
- // is there too (and because we don't control the flatbuffer format, we can't
- // "arrange" for it to be placed at the head / tail of those bytes). Instead,
- // we traverse |model| and feed into the CRC32 computation those parts we are
- // interested in (which excludes the crc32 field).
- //
- // Note: storing the checksum outside the Model would be too disruptive for
- // the way we currently ship our models.
- mobile::Crc32 crc;
- if (model == nullptr) {
- return crc.Get();
- }
- crc.Update("|Parameters:");
- const auto *parameters = model->parameters();
- if (parameters != nullptr) {
- for (const ModelParameter *p : *parameters) {
- if (p != nullptr) {
- UpdateCrc(&crc, p->name(), "name");
- UpdateCrc(&crc, p->value(), "value");
- }
- }
- }
- crc.Update("|Inputs:");
- const auto *inputs = model->inputs();
- if (inputs != nullptr) {
- for (const ModelInput *input : *inputs) {
- if (input != nullptr) {
- UpdateCrc(&crc, input->name(), "name");
- UpdateCrc(&crc, input->type(), "type");
- UpdateCrc(&crc, input->sub_type(), "sub-type");
- UpdateCrc(&crc, input->data(), "data");
- }
- }
- }
- return crc.Get();
-}
-
-} // namespace saft_fbs
-} // namespace nlp_saft
diff --git a/lang_id/common/flatbuffers/model-utils.h b/lang_id/common/flatbuffers/model-utils.h
deleted file mode 100644
index 5427f70..0000000
--- a/lang_id/common/flatbuffers/model-utils.h
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_MODEL_UTILS_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_MODEL_UTILS_H_
-
-#include <stddef.h>
-
-#include <string>
-
-#include "lang_id/common/fel/task-context.h"
-#include "lang_id/common/flatbuffers/model_generated.h"
-#include "lang_id/common/lite_base/integral-types.h"
-#include "lang_id/common/lite_strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-namespace saft_fbs {
-
-// Verifies that the |num_bytes| bytes that start at |data| represent a valid
-// Model flatbuffer. If so, returns that Model. Otherwise, returns nullptr.
-//
-// Note: if the Model has the crc32 field, this method checks that the Model
-// checksum matches that field; if they don't match, the Model is considered
-// invalid, and this function returns nullptr. The checksum test is in addition
-// to the standard flatbuffer validity checking.
-const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes);
-
-// Convenience StringPiece version of GetVerifiedModelFromBytes.
-inline const Model *GetVerifiedModelFromBytes(mobile::StringPiece bytes) {
- return GetVerifiedModelFromBytes(bytes.data(), bytes.size());
-}
-
-// Returns the |model| input with specified |name|. Returns nullptr if no such
-// input exists. If |model| contains multiple inputs with that |name|, returns
-// the first one (model builders should avoid building such models).
-const ModelInput *GetInputByName(const Model *model, const string &name);
-
-// Returns a StringPiece pointing to the bytes for the content of |input|. In
-// case of errors, returns StringPiece(nullptr, 0).
-mobile::StringPiece GetInputBytes(const ModelInput *input);
-
-// Fills parameters from |context|, based on the parameters from |model|.
-// Returns false if any error is encountered, true otherwise. In the case of an
-// error, some parameters may have been added to |context| (e.g., if we find a
-// problem with the 3rd parameter, the first 2 have been added).
-bool FillParameters(const Model &model, mobile::TaskContext *context);
-
-// Returns the CRC32 checksum of |model|. This checksum is computed over the
-// entire information from the model (including the bytes of the inputs),
-// *except* the crc32 field. Hence, when a model is build, one can store the
-// result of this function into that field; on the user side, one can check that
-// the result of this function matches the crc32 field, to guard against model
-// corruption. GetVerifiedModelFromBytes performs this check.
-mobile::uint32 ComputeCrc2Checksum(const Model *model);
-
-} // namespace saft_fbs
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_MODEL_UTILS_H_
diff --git a/lang_id/common/lite_base/compact-logging-raw.cc b/lang_id/common/lite_base/compact-logging-raw.cc
deleted file mode 100644
index 53dfc8e..0000000
--- a/lang_id/common/lite_base/compact-logging-raw.cc
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/common/lite_base/compact-logging-raw.h"
-
-#include <stdio.h>
-#include <string>
-
-// NOTE: this file contains two implementations: one for Android, one for all
-// other cases. We always build exactly one implementation.
-#if defined(__ANDROID__)
-
-// Compiled as part of Android.
-#include <android/log.h>
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace internal_logging {
-
-namespace {
-// Converts LogSeverity to level for __android_log_write.
-int GetAndroidLogLevel(LogSeverity severity) {
- switch (severity) {
- case FATAL:
- return ANDROID_LOG_FATAL;
- case ERROR:
- return ANDROID_LOG_ERROR;
- case WARNING:
- return ANDROID_LOG_WARN;
- case INFO:
- return ANDROID_LOG_INFO;
- default:
- return ANDROID_LOG_DEBUG;
- }
-}
-} // namespace
-
-void LowLevelLogging(LogSeverity severity, const string &tag,
- const string &message) {
- const int android_log_level = GetAndroidLogLevel(severity);
-#if !defined(SAFTM_DEBUG_LOGGING)
- if (android_log_level != ANDROID_LOG_ERROR &&
- android_log_level != ANDROID_LOG_FATAL) {
- return;
- }
-#endif
- __android_log_write(android_log_level, tag.c_str(), message.c_str());
-}
-
-} // namespace internal_logging
-} // namespace mobile
-} // namespace nlp_saft
-
-#else // if defined(__ANDROID__)
-
-// Not on Android: implement LowLevelLogging to print to stderr (see below).
-namespace libtextclassifier3 {
-namespace mobile {
-namespace internal_logging {
-
-namespace {
-// Converts LogSeverity to human-readable text.
-const char *LogSeverityToString(LogSeverity severity) {
- switch (severity) {
- case INFO:
- return "INFO";
- case WARNING:
- return "WARNING";
- case ERROR:
- return "ERROR";
- case FATAL:
- return "FATAL";
- default:
- return "UNKNOWN";
- }
-}
-} // namespace
-
-void LowLevelLogging(LogSeverity severity, const string &tag,
- const string &message) {
- fprintf(stderr, "[%s] %s : %s\n", LogSeverityToString(severity), tag.c_str(),
- message.c_str());
- fflush(stderr);
-}
-
-} // namespace internal_logging
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // if defined(__ANDROID__)
diff --git a/lang_id/common/lite_base/compact-logging-raw.h b/lang_id/common/lite_base/compact-logging-raw.h
deleted file mode 100644
index f67287c..0000000
--- a/lang_id/common/lite_base/compact-logging-raw.h
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_RAW_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_RAW_H_
-
-#include <string>
-
-#include "lang_id/common/lite_base/compact-logging-levels.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace internal_logging {
-
-// Low-level logging primitive. Logs a message, with the indicated log
-// severity. From android/log.h: "the tag normally corresponds to the component
-// that emits the log message, and should be reasonably small".
-void LowLevelLogging(LogSeverity severity, const string &tag,
- const string &message);
-
-} // namespace internal_logging
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_RAW_H_
diff --git a/lang_id/common/lite_base/compact-logging.h b/lang_id/common/lite_base/compact-logging.h
deleted file mode 100644
index eccb7d1..0000000
--- a/lang_id/common/lite_base/compact-logging.h
+++ /dev/null
@@ -1,177 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_H_
-
-#include <cassert>
-#include <string>
-
-#include "lang_id/common/lite_base/attributes.h"
-#include "lang_id/common/lite_base/compact-logging-levels.h"
-#include "lang_id/common/lite_strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace internal_logging {
-
-// A tiny code footprint string stream for assembling log messages.
-struct LoggingStringStream {
- LoggingStringStream() {}
- LoggingStringStream &stream() { return *this; }
-
- // Needed for invocation in SAFTM_CHECK macro.
- explicit operator bool() const { return true; }
-
- string message;
-};
-
-template <typename T>
-inline LoggingStringStream &operator<<(LoggingStringStream &stream,
- const T &entry) {
- stream.message.append(std::to_string(entry));
- return stream;
-}
-
-inline LoggingStringStream &operator<<(LoggingStringStream &stream,
- const char *message) {
- stream.message.append(message);
- return stream;
-}
-
-inline LoggingStringStream &operator<<(LoggingStringStream &stream,
- const string &message) {
- stream.message.append(message);
- return stream;
-}
-
-inline LoggingStringStream &operator<<(LoggingStringStream &stream,
- StringPiece sp) {
- stream.message.append(sp.data(), sp.size());
- return stream;
-}
-
-// The class that does all the work behind our SAFTM_LOG(severity) macros. Each
-// SAFTM_LOG(severity) << obj1 << obj2 << ...; logging statement creates a
-// LogMessage temporary object containing a stringstream. Each operator<< adds
-// info to that stringstream and the LogMessage destructor performs the actual
-// logging. The reason this works is that in C++, "all temporary objects are
-// destroyed as the last step in evaluating the full-expression that (lexically)
-// contains the point where they were created." For more info, see
-// http://en.cppreference.com/w/cpp/language/lifetime. Hence, the destructor is
-// invoked after the last << from that logging statement.
-class LogMessage {
- public:
- LogMessage(LogSeverity severity, const char *file_name,
- int line_number) SAFTM_ATTRIBUTE_NOINLINE;
-
- ~LogMessage() SAFTM_ATTRIBUTE_NOINLINE;
-
- // Returns the stream associated with the logger object.
- LoggingStringStream &stream() { return stream_; }
-
- private:
- const LogSeverity severity_;
-
- // Stream that "prints" all info into a string (not to a file). We construct
- // here the entire logging message and next print it in one operation.
- LoggingStringStream stream_;
-};
-
-// Pseudo-stream that "eats" the tokens <<-pumped into it, without printing
-// anything.
-class NullStream {
- public:
- NullStream() {}
- NullStream &stream() { return *this; }
-};
-template <typename T>
-inline NullStream &operator<<(NullStream &str, const T &) {
- return str;
-}
-
-} // namespace internal_logging
-} // namespace mobile
-} // namespace nlp_saft
-
-#define SAFTM_LOG(severity) \
- ::libtextclassifier3::mobile::internal_logging::LogMessage( \
- ::libtextclassifier3::mobile::internal_logging::severity, __FILE__, __LINE__) \
- .stream()
-
-// If condition x is true, does nothing. Otherwise, crashes the program (liek
-// LOG(FATAL)) with an informative message. Can be continued with extra
-// messages, via <<, like any logging macro, e.g.,
-//
-// SAFTM_CHECK(my_cond) << "I think we hit a problem";
-#define SAFTM_CHECK(x) \
- (x) || SAFTM_LOG(FATAL) << __FILE__ << ":" << __LINE__ \
- << ": check failed: \"" << #x
-
-#define SAFTM_CHECK_EQ(x, y) SAFTM_CHECK((x) == (y))
-#define SAFTM_CHECK_LT(x, y) SAFTM_CHECK((x) < (y))
-#define SAFTM_CHECK_GT(x, y) SAFTM_CHECK((x) > (y))
-#define SAFTM_CHECK_LE(x, y) SAFTM_CHECK((x) <= (y))
-#define SAFTM_CHECK_GE(x, y) SAFTM_CHECK((x) >= (y))
-#define SAFTM_CHECK_NE(x, y) SAFTM_CHECK((x) != (y))
-
-#define SAFTM_NULLSTREAM \
- ::libtextclassifier3::mobile::internal_logging::NullStream().stream()
-
-// Debug checks: a SAFTM_DCHECK<suffix> macro should behave like
-// SAFTM_CHECK<suffix> in debug mode an don't check / don't print anything in
-// non-debug mode.
-#ifdef NDEBUG
-
-#define SAFTM_DCHECK(x) SAFTM_NULLSTREAM
-#define SAFTM_DCHECK_EQ(x, y) SAFTM_NULLSTREAM
-#define SAFTM_DCHECK_LT(x, y) SAFTM_NULLSTREAM
-#define SAFTM_DCHECK_GT(x, y) SAFTM_NULLSTREAM
-#define SAFTM_DCHECK_LE(x, y) SAFTM_NULLSTREAM
-#define SAFTM_DCHECK_GE(x, y) SAFTM_NULLSTREAM
-#define SAFTM_DCHECK_NE(x, y) SAFTM_NULLSTREAM
-
-// In non-debug mode, SAFT_DLOG statements do not generate any logging.
-#define SAFTM_DLOG(severity) SAFTM_NULLSTREAM
-
-#else // NDEBUG
-
-// In debug mode, each SAFTM_DCHECK<suffix> is equivalent to
-// SAFTM_CHECK<suffix>, i.e., a real check that crashes when the condition is
-// not true.
-#define SAFTM_DCHECK(x) SAFTM_CHECK(x)
-#define SAFTM_DCHECK_EQ(x, y) SAFTM_CHECK_EQ(x, y)
-#define SAFTM_DCHECK_LT(x, y) SAFTM_CHECK_LT(x, y)
-#define SAFTM_DCHECK_GT(x, y) SAFTM_CHECK_GT(x, y)
-#define SAFTM_DCHECK_LE(x, y) SAFTM_CHECK_LE(x, y)
-#define SAFTM_DCHECK_GE(x, y) SAFTM_CHECK_GE(x, y)
-#define SAFTM_DCHECK_NE(x, y) SAFTM_CHECK_NE(x, y)
-
-// In debug mode, SAFT_DLOG statements are like SAFT_LOG.
-#define SAFTM_DLOG SAFTM_LOG
-
-#endif // NDEBUG
-
-#ifdef LIBTEXTCLASSIFIER_VLOG
-#define SAFTM_VLOG(severity) \
- ::libtextclassifier3::mobile::internal_logging::LogMessage( \
- ::libtextclassifier3::mobile::internal_logging::INFO, __FILE__, __LINE__) \
- .stream()
-#else
-#define SAFTM_VLOG(severity) SAFTM_NULLSTREAM
-#endif
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_H_
diff --git a/lang_id/common/lite_strings/numbers.cc b/lang_id/common/lite_strings/numbers.cc
deleted file mode 100644
index e0c66f3..0000000
--- a/lang_id/common/lite_strings/numbers.cc
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/common/lite_strings/numbers.h"
-
-#include <ctype.h>
-#include <stdlib.h>
-#include <climits>
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-// Returns true if the characters that start at address ptr (inclusive) and stop
-// at the first '\0' consist of only whitespaces, as determined by isspace().
-// Note: this function returns false if ptr is nullptr.
-static bool OnlyWhitespaces(const char *ptr) {
- if (ptr == nullptr) {
- return false;
- }
- for (; *ptr != '\0'; ++ptr) {
- if (!isspace(*ptr)) {
- return false;
- }
- }
- return true;
-}
-
-bool LiteAtoi(const char *c_str, int *value) {
- if (c_str == nullptr) {
- return false;
- }
-
- // Short version of man strtol:
- //
- // strtol parses some optional whitespaces, an optional +/- sign, and next a
- // succession of digits. If it finds some digits, it sets temp to point to
- // the first character after that succession of digits and returns the parsed
- // integer.
- //
- // If there were no digits at all, strtol() sets temp to be c_str (the start
- // address) and returns 0.
- char *temp = nullptr;
- const long int parsed_value = strtol(c_str, &temp, 0); // NOLINT
-
- // Check for overflow. Note: to simplify the code, we assume that LONG_MIN /
- // LONG_MAX means that strtol encountered an overflow (normally, in that case,
- // one should also inspect errno). Hence, we maybe give up the possibility to
- // parse one extreme value on each side (min/max). That should be ok.
- if ((parsed_value == LONG_MIN) || (parsed_value == LONG_MAX) ||
- (parsed_value < INT_MIN) || (parsed_value > INT_MAX)) {
- return false;
- }
- *value = static_cast<int>(parsed_value);
-
- // First part of the expression below means that the input string contained at
- // least one digit. The other part checks that what remains after the number
- // (if anything) consists only of whitespaces.
- return (temp != c_str) && OnlyWhitespaces(temp);
-}
-
-bool LiteAtof(const char *c_str, float *value) {
- if (c_str == nullptr) {
- return false;
- }
-
- // strtof is similar to strtol, see more detailed comments inside LiteAtoi.
- char *temp = nullptr;
- *value = strtof(c_str, &temp);
- return (temp != c_str) && OnlyWhitespaces(temp);
-}
-
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/common/lite_strings/numbers.h b/lang_id/common/lite_strings/numbers.h
deleted file mode 100644
index 4b3c93c..0000000
--- a/lang_id/common/lite_strings/numbers.h
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_NUMBERS_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_NUMBERS_H_
-
-#include <string>
-
-#include "lang_id/common/lite_strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-// Parses an int from a C-style string; similar to absl::SimpleAtoi.
-//
-// c_str should point to a zero-terminated array of chars that contains the
-// number representation as (a) "<radix-10-number>" (e.g., "721"), (b)
-// "0x<radix-16-number>" (e.g., "0xa1"), or (c) "0<radix-8-number>" (e.g.,
-// "017201"). Whitespaces (as determined by isspace()) are allowed before and
-// after the number representation (but obviously not in the middle).
-//
-// Stores parsed number into *value. Returns true on success, false on error.
-// Note: presence of extra non-whitespace characters after the number counts as
-// an error: e.g., parsing "123a" will return false due to the extra "a" (which
-// is not a valid radix-10 digit). This function also returns false for strings
-// that do not contain any digit (e.g., ""), as well as for overflows /
-// underflows.
-bool LiteAtoi(const char *c_str, int *value);
-
-inline bool LiteAtoi(const string &s, int *value) {
- return LiteAtoi(s.c_str(), value);
-}
-
-inline bool LiteAtoi(StringPiece sp, int *value) {
- // Unfortunately, we can't directly call LiteAtoi(sp.data()): LiteAtoi(const
- // char *) needs a zero-terminated string.
- const string temp(sp.data(), sp.size());
- return LiteAtoi(temp.c_str(), value);
-}
-
-// Like LiteAtoi, but for float; similar to absl::SimpleAtof.
-//
-// NOTE: currently, does not properly handle overflow / underflow.
-// TODO(salcianu): fix that.
-bool LiteAtof(const char *c_str, float *value);
-
-inline bool LiteAtof(const string &s, float *value) {
- return LiteAtof(s.c_str(), value);
-}
-
-inline bool LiteAtof(StringPiece sp, float *value) {
- // Unfortunately, we can't directly call LiteAtoi(sp.data()): LiteAtoi(const
- // char *) needs a zero-terminated string.
- const string temp(sp.data(), sp.size());
- return LiteAtof(temp.c_str(), value);
-}
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_NUMBERS_H_
diff --git a/lang_id/common/lite_strings/str-cat.h b/lang_id/common/lite_strings/str-cat.h
deleted file mode 100644
index f0c1682..0000000
--- a/lang_id/common/lite_strings/str-cat.h
+++ /dev/null
@@ -1,98 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_CAT_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_CAT_H_
-
-// Less efficient but more compact versions of several absl string utils.
-//
-// "More compact" means "pulls in fewer code dependencies". That's useful if
-// one tries to minimize the code size.
-//
-// Note: the name and the signature of the functions from this header were
-// chosen to minimize the effort of converting code that uses absl::LiteStrCat &
-// co to our more compact functions.
-
-#include <string>
-
-#ifdef COMPILER_MSVC
-#include <sstream>
-#endif // COMPILER_MSVC
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-// Less efficient but more compact version of absl::LiteStrCat().
-//
-// Given a value v (see supported types below) LiteStrCat(v) returns a new
-// string that contains the representation of v. For examples, see
-// str-cat_test.cc.
-template <typename T>
-inline string LiteStrCat(T v) {
-#ifdef COMPILER_MSVC
- std::stringstream stream;
- stream << input;
- return stream.str();
-#else
- return std::to_string(v);
-#endif
-}
-
-template <>
-inline string LiteStrCat(const char *v) {
- return string(v);
-}
-
-// TODO(salcianu): use a reference type (const string &). For some reason, I
-// couldn't get that to work on a first try.
-template <>
-inline string LiteStrCat(string v) {
- return v;
-}
-
-template <>
-inline string LiteStrCat(char v) {
- return string(1, v);
-}
-
-// Less efficient but more compact version of absl::LiteStrAppend().
-template <typename T>
-inline void LiteStrAppend(string *dest, T v) {
- dest->append(LiteStrCat(v)); // NOLINT
-}
-
-template <typename T1, typename T2>
-inline void LiteStrAppend(string *dest, T1 v1, T2 v2) {
- dest->append(LiteStrCat(v1)); // NOLINT
- dest->append(LiteStrCat(v2)); // NOLINT
-}
-
-template <typename T1, typename T2, typename T3>
-inline void LiteStrAppend(string *dest, T1 v1, T2 v2, T3 v3) {
- LiteStrAppend(dest, v1, v2);
- dest->append(LiteStrCat(v3)); // NOLINT
-}
-
-template <typename T1, typename T2, typename T3, typename T4>
-inline void LiteStrAppend(string *dest, T1 v1, T2 v2, T3 v3, T4 v4) {
- LiteStrAppend(dest, v1, v2, v3);
- dest->append(LiteStrCat(v4)); // NOLINT
-}
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_CAT_H_
diff --git a/lang_id/common/lite_strings/stringpiece.h b/lang_id/common/lite_strings/stringpiece.h
deleted file mode 100644
index 59a2176..0000000
--- a/lang_id/common/lite_strings/stringpiece.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef TC3_STD_STRING_IMPORT
-#define TC3_STD_STRING_IMPORT
-#include <string>
-
-namespace libtextclassifier3 {
-using string = std::string;
-template <class CharT, class Traits = std::char_traits<CharT>,
- class Allocator = std::allocator<CharT> >
-using basic_string = std::basic_string<CharT, Traits, Allocator>;
-} // namespace libtextclassifier3
-#endif
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STRINGPIECE_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STRINGPIECE_H_
-
-#include <stddef.h>
-#include <string.h>
-
-#include <ostream>
-#include <string>
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-// Read-only "view" of a piece of data. Does not own the underlying data.
-class StringPiece {
- public:
- StringPiece() : StringPiece(nullptr, 0) {}
-
- StringPiece(const char *str) // NOLINT
- : start_(str), size_(strlen(str)) {}
-
- StringPiece(const char *start, size_t size) : start_(start), size_(size) {}
-
- // Intentionally no "explicit" keyword: in function calls, we want strings to
- // be converted to StringPiece implicitly.
- StringPiece(const string &s) // NOLINT
- : StringPiece(s.data(), s.size()) {}
-
- StringPiece(const string &s, int offset, int len)
- : StringPiece(s.data() + offset, len) {}
-
- char operator[](size_t i) const { return start_[i]; }
-
- // Returns start address of underlying data.
- const char *data() const { return start_; }
-
- // Returns number of bytes of underlying data.
- size_t size() const { return size_; }
-
- // Returns true if this StringPiece does not refer to any characters.
- bool empty() const { return size() == 0; }
-
- template <typename A>
- explicit operator basic_string<char, std::char_traits<char>, A>() const {
- if (!data()) return {};
- return basic_string<char, std::char_traits<char>, A>(data(), size());
- }
-
- private:
- const char *start_; // Not owned.
- size_t size_;
-};
-
-inline std::ostream &operator<<(std::ostream &out, StringPiece sp) {
- return out.write(sp.data(), sp.size());
-}
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STRINGPIECE_H_
diff --git a/lang_id/common/math/algorithm.h b/lang_id/common/math/algorithm.h
deleted file mode 100644
index a963807..0000000
--- a/lang_id/common/math/algorithm.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Generic utils similar to those from the C++ header <algorithm>.
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
-
-#include <algorithm>
-#include <vector>
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-// Returns index of max element from the vector |elements|. Returns 0 if
-// |elements| is empty. T should be a type that can be compared by operator<.
-template<typename T>
-inline int GetArgMax(const std::vector<T> &elements) {
- return std::distance(
- elements.begin(),
- std::max_element(elements.begin(), elements.end()));
-}
-
-// Returns index of min element from the vector |elements|. Returns 0 if
-// |elements| is empty. T should be a type that can be compared by operator<.
-template<typename T>
-inline int GetArgMin(const std::vector<T> &elements) {
- return std::distance(
- elements.begin(),
- std::min_element(elements.begin(), elements.end()));
-}
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
diff --git a/lang_id/common/math/fastexp.h b/lang_id/common/math/fastexp.h
deleted file mode 100644
index 05b654a..0000000
--- a/lang_id/common/math/fastexp.h
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Fast approximation for exp.
-//
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
-
-#include <cassert>
-#include <cmath>
-#include <limits>
-
-#include "lang_id/common/lite_base/casts.h"
-#include "lang_id/common/lite_base/integral-types.h"
-#include "lang_id/common/lite_base/logging.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-class FastMathClass {
- private:
- static const int kBits = 7;
- static const int kMask1 = (1 << kBits) - 1;
- static const int kMask2 = 0xFF << kBits;
- static constexpr float kLogBase2OfE = 1.44269504088896340736f;
-
- struct Table {
- int32 exp1[1 << kBits];
- };
-
- public:
- float VeryFastExp2(float f) const {
- SAFTM_DCHECK_LE(fabs(f), 126);
- const float g = f + (127 + (1 << (23 - kBits)));
- const int32 x = bit_cast<int32>(g);
- int32 ret = ((x & kMask2) << (23 - kBits))
- | cache_.exp1[x & kMask1];
- return bit_cast<float>(ret);
- }
-
- float VeryFastExp(float f) const {
- return VeryFastExp2(f * kLogBase2OfE);
- }
-
- private:
- static const Table cache_;
-};
-
-extern FastMathClass FastMathInstance;
-
-inline float VeryFastExp(float f) { return FastMathInstance.VeryFastExp(f); }
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
diff --git a/lang_id/common/math/hash.h b/lang_id/common/math/hash.h
deleted file mode 100644
index 08c32be..0000000
--- a/lang_id/common/math/hash.h
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef TC3_STD_STRING_IMPORT
-#define TC3_STD_STRING_IMPORT
-#include <string>
-
-namespace libtextclassifier3 {
-using string = std::string;
-template <class CharT, class Traits = std::char_traits<CharT>,
- class Allocator = std::allocator<CharT> >
-using basic_string = std::basic_string<CharT, Traits, Allocator>;
-} // namespace libtextclassifier3
-#endif
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_HASH_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_HASH_H_
-
-#include <string>
-
-#include "lang_id/common/lite_base/integral-types.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace utils {
-
-// Returns a 32 bit hash of the |n| bytes that start at |data|, using |seed| for
-// internal initialization. By changing the seed, one effectively gets
-// different hash functions.
-//
-// NOTE: this function is guaranteed not to change in the future.
-//
-// IMPORTANT: for speed reasons, this method does not check its parameters
-// |data| and |n|. The caller should ensure that n >= 0 and that one can read
-// from the memory area [data, data + n).
-uint32 Hash32(const char *data, size_t n, uint32 seed);
-
-static inline uint32 Hash32WithDefaultSeed(const char *data, size_t n) {
- return Hash32(data, n, 0xBEEF);
-}
-
-static inline uint32 Hash32WithDefaultSeed(const string &input) {
- return Hash32WithDefaultSeed(input.data(), input.size());
-}
-
-} // namespace utils
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_HASH_H_
diff --git a/lang_id/common/registry.h b/lang_id/common/registry.h
deleted file mode 100644
index d2c5271..0000000
--- a/lang_id/common/registry.h
+++ /dev/null
@@ -1,321 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Mechanism to instantiate classes by name.
-//
-// This mechanism is useful if the concrete classes to be instantiated are not
-// statically known (e.g., if their names are read from a dynamically-provided
-// config).
-//
-// In that case, the first step is to define the API implemented by the
-// instantiated classes. E.g.,
-//
-// // In a header file function.h:
-//
-// // Abstract function that takes a double and returns a double.
-// class Function : public RegisterableClass<Function> {
-// public:
-// virtual ~Function() {}
-// virtual double Evaluate(double x) = 0;
-// };
-//
-// // Should be inside namespace libtextclassifier3::mobile.
-// SAFTM_DECLARE_CLASS_REGISTRY_NAME(Function);
-//
-// Notice the inheritance from RegisterableClass<Function>. RegisterableClass
-// is defined by this file (registry.h). Under the hood, this inheritanace
-// defines a "registry" that maps names (zero-terminated arrays of chars) to
-// factory methods that create Functions. You should give a human-readable name
-// to this registry. To do that, use the following macro in a .cc file (it has
-// to be a .cc file, as it defines some static data):
-//
-// // Inside function.cc
-// // Should be inside namespace libtextclassifier3::mobile.
-// SAFTM_DEFINE_CLASS_REGISTRY_NAME("function", Function);
-//
-// Now, let's define a few concrete Functions: e.g.,
-//
-// class Cos : public Function {
-// public:
-// double Evaluate(double x) override { return cos(x); }
-// SAFTM_DEFINE_REGISTRATION_METHOD("cos", Cos);
-// };
-//
-// class Exp : public Function {
-// public:
-// double Evaluate(double x) override { return exp(x); }
-// SAFTM_DEFINE_REGISTRATION_METHOD("sin", Sin);
-// };
-//
-// Each concrete Function implementation should have (in the public section) the
-// macro
-//
-// SAFTM_DEFINE_REGISTRATION_METHOD("name", implementation_class);
-//
-// This defines a RegisterClass static method that, when invoked, associates
-// "name" with a factory method that creates instances of implementation_class.
-//
-// Before instantiating Functions by name, we need to tell our system which
-// Functions we may be interested in. This is done by calling the
-// Foo::RegisterClass() for each relevant Foo implementation of Function. It is
-// ok to call Foo::RegisterClass() multiple times (even in parallel): only the
-// first call will perform something, the others will return immediately.
-//
-// Cos::RegisterClass();
-// Exp::RegisterClass();
-//
-// Now, let's instantiate a Function based on its name. This get a lot more
-// interesting if the Function name is not statically known (i.e.,
-// read from an input proto:
-//
-// std::unique_ptr<Function> f(Function::Create("cos"));
-// double result = f->Evaluate(arg);
-//
-// NOTE: the same binary can use this mechanism for different APIs. E.g., one
-// can also have (in the binary with Function, Sin, Cos, etc):
-//
-// class IntFunction : public RegisterableClass<IntFunction> {
-// public:
-// virtual ~IntFunction() {}
-// virtual int Evaluate(int k) = 0;
-// };
-//
-// SAFTM_DECLARE_CLASS_REGISTRY_NAME(IntFunction);
-//
-// SAFTM_DEFINE_CLASS_REGISTRY_NAME("int function", IntFunction);
-//
-// class Inc : public IntFunction {
-// public:
-// int Evaluate(int k) override { return k + 1; }
-// SAFTM_DEFINE_REGISTRATION_METHOD("inc", Inc);
-// };
-//
-// RegisterableClass<Function> and RegisterableClass<IntFunction> define their
-// own registries: each maps string names to implementation of the corresponding
-// API.
-//
-// NOTE: the mechanism described above requires you to explicitly call
-// RegisterClass() for all relevant classes before instantiating them. You can
-// do this in the main() function or in any other function that is guaranteed to
-// run before the code that instantiates those classes. Alternatively, you can
-// use the macro SAFTM_STATIC_REGISTRATION to perform this registration in a
-// decentralized fashion. Just use that macro in a .cc file, outside any
-// function / class, e.g.,
-//
-// SAFTM_STATIC_REGISTRATION(Cos);
-//
-// and make sure you link in all symbols from that .cc file; e.g., in bazel, use
-// alwayslink = 1 for the corresponding cc_library. Still, please be aware that
-// using alwayslink = 1 limits the ability of the linker to perform dead code
-// elimination.
-
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_REGISTRY_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_REGISTRY_H_
-
-#include <stdlib.h>
-#include <string.h>
-
-#include <string>
-#include <vector>
-
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/lite_base/macros.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-
-namespace internal {
-// Registry that associates keys (zero-terminated array of chars) with values.
-// Values are pointers to type T (the template parameter). This is used to
-// store the association between component names and factory methods that
-// produce those components; the error messages are focused on that case.
-//
-// Internally, this registry uses a linked list of (key, value) pairs. We do
-// not use an STL map, list, etc because we aim for small code size.
-template <class T>
-class ComponentRegistry {
- public:
- explicit ComponentRegistry(const char *name) : name_(name), head_(nullptr) {}
-
- // Adds a the (key, value) pair to this registry (if the key does not already
- // exists in this registry) and returns true. If the registry already has a
- // mapping for key, returns false and does not modify the registry. NOTE: the
- // error (false) case happens even if the existing value for key is equal with
- // the new one.
- //
- // This method does not take ownership of key, nor of value.
- bool Add(const char *key, T *value) {
- const Cell *old_cell = FindCell(key);
- if (old_cell != nullptr) {
- SAFTM_LOG(ERROR) << "Duplicate component: " << key;
- return false;
- }
- Cell *new_cell = new Cell(key, value, head_);
- head_ = new_cell;
- return true;
- }
-
- // Returns the value attached to a key in this registry. Returns nullptr on
- // error (e.g., unknown key).
- T *Lookup(const char *key) const {
- const Cell *cell = FindCell(key);
- if (cell == nullptr) {
- SAFTM_LOG(ERROR) << "Unknown " << name() << " component: " << key;
- }
- return (cell == nullptr) ? nullptr : cell->value();
- }
-
- T *Lookup(const string &key) const { return Lookup(key.c_str()); }
-
- // Returns name of this ComponentRegistry.
- const char *name() const { return name_; }
-
- // Fills *names with names of all components registered in this
- // ComponentRegistry. Previous content of *names is cleared out.
- void GetComponentNames(std::vector<string> *names) {
- names->clear();
- for (const Cell *c = head_; c!= nullptr; c = c->next()) {
- names->emplace_back(c->key());
- }
- }
-
- private:
- // Cell for the singly-linked list underlying this ComponentRegistry. Each
- // cell contains a key, the value for that key, as well as a pointer to the
- // next Cell from the list.
- class Cell {
- public:
- // Constructs a new Cell.
- Cell(const char *key, T *value, Cell *next)
- : key_(key), value_(value), next_(next) {}
-
- const char *key() const { return key_; }
- T *value() const { return value_; }
- Cell *next() const { return next_; }
-
- private:
- const char *const key_;
- T *const value_;
- Cell *const next_;
- };
-
- // Finds Cell for indicated key in the singly-linked list pointed to by head_.
- // Returns pointer to that first Cell with that key, or nullptr if no such
- // Cell (i.e., unknown key).
- //
- // Caller does NOT own the returned pointer.
- const Cell *FindCell(const char *key) const {
- const Cell *c = head_;
- while (c != nullptr && strcmp(key, c->key()) != 0) {
- c = c->next();
- }
- return c;
- }
-
- // Human-readable description for this ComponentRegistry. For debug purposes.
- const char *const name_;
-
- // Pointer to the first Cell from the underlying list of (key, value) pairs.
- Cell *head_;
-};
-} // namespace internal
-
-// Base class for registerable classes.
-template <class T>
-class RegisterableClass {
- public:
- // Factory function type.
- typedef T *(Factory)();
-
- // Registry type.
- typedef internal::ComponentRegistry<Factory> Registry;
-
- // Creates a new instance of T. Returns pointer to new instance or nullptr in
- // case of errors (e.g., unknown component).
- //
- // Passes ownership of the returned pointer to the caller.
- static T *Create(const string &name) { // NOLINT
- auto *factory = registry()->Lookup(name);
- if (factory == nullptr) {
- SAFTM_LOG(ERROR) << "Unknown RegisterableClass " << name;
- return nullptr;
- }
- return factory();
- }
-
- // Returns registry for class.
- static Registry *registry() {
- static Registry *registry_for_type_t = new Registry(kRegistryName);
- return registry_for_type_t;
- }
-
- protected:
- // Factory method for subclass ComponentClass. Used internally by the static
- // method RegisterClass() defined by SAFTM_DEFINE_REGISTRATION_METHOD.
- template <class ComponentClass>
- static T *_internal_component_factory() {
- return new ComponentClass();
- }
-
- private:
- // Human-readable name for the registry for this class.
- static const char kRegistryName[];
-};
-
-// Defines the static method component_class::RegisterClass() that should be
-// called before trying to instantiate component_class by name. Should be used
-// inside the public section of the declaration of component_class. See
-// comments at the top-level of this file.
-#define SAFTM_DEFINE_REGISTRATION_METHOD(component_name, component_class) \
- static void RegisterClass() { \
- static bool once = registry()->Add( \
- component_name, &_internal_component_factory<component_class>); \
- if (!once) { \
- SAFTM_LOG(ERROR) << "Problem registering " << component_name; \
- } \
- SAFTM_DCHECK(once); \
- }
-
-// Defines the human-readable name of the registry associated with base_class.
-#define SAFTM_DECLARE_CLASS_REGISTRY_NAME(base_class) \
- template <> \
- const char ::libtextclassifier3::mobile::RegisterableClass<base_class>::kRegistryName[]
-
-// Defines the human-readable name of the registry associated with base_class.
-#define SAFTM_DEFINE_CLASS_REGISTRY_NAME(registry_name, base_class) \
- template <> \
- const char \
- ::libtextclassifier3::mobile::RegisterableClass<base_class>::kRegistryName[] \
- = registry_name
-
-// Register component_name, by calling component_class::RegisterClass() on
-// program start-up, before main. NOTE: this macro should be used in
-// conjunction with something like alwayslink = 1 from bazel. That is
-// discouraged, as it prevents the linker from doing dead code elimination, so
-// please use this macro only in special cases. Instead, if you care about code
-// size, then you should aim to explicitly call RegisterClass from your code
-// (e.g., from the main method, or from the constructor of the class that may
-// need those registered components).
-#define SAFTM_STATIC_REGISTRATION(component_class) \
- static bool SAFTM_UNIQUE_ID(_kRegistrationDummy) = [] { \
- component_class::RegisterClass(); \
- return true; \
- }()
-
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_REGISTRY_H_
diff --git a/lang_id/common/utf8.h b/lang_id/common/utf8.h
deleted file mode 100644
index 2365429..0000000
--- a/lang_id/common/utf8.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef TC3_STD_STRING_IMPORT
-#define TC3_STD_STRING_IMPORT
-#include <string>
-
-namespace libtextclassifier3 {
-using string = std::string;
-template <class CharT, class Traits = std::char_traits<CharT>,
- class Allocator = std::allocator<CharT> >
-using basic_string = std::basic_string<CharT, Traits, Allocator>;
-} // namespace libtextclassifier3
-#endif
-#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_UTF8_H_
-#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_UTF8_H_
-
-#include <stddef.h>
-
-#include <string>
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace utils {
-
-// Returns the length (number of bytes) of the UTF8 code point starting at src,
-// by reading only the byte from address src.
-//
-// The result is a number from the set {1, 2, 3, 4}.
-static inline int OneCharLen(const char *src) {
- // On most platforms, char is unsigned by default, but iOS is an exception.
- // The cast below makes sure we always interpret *src as an unsigned char.
- return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"
- [(*(reinterpret_cast<const unsigned char *>(src)) & 0xFF) >> 4];
-}
-
-// Returns a pointer "end" inside [data, data + size) such that the prefix from
-// [data, end) is the largest one that does not contain '\0' and offers the
-// following guarantee: if one starts with
-//
-// curr = text.data()
-//
-// and keeps executing
-//
-// curr += OneCharLen(curr)
-//
-// one would eventually reach curr == end (the pointer returned by this
-// function) without accessing data outside the string. This guards against
-// scenarios like a broken UTF8 string which has only e.g., the first 2 bytes
-// from a 3-byte UTF8 sequence.
-//
-// Preconditions: data != nullptr.
-const char *GetSafeEndOfUtf8String(const char *data, size_t size);
-
-static inline const char *GetSafeEndOfUtf8String(const string &text) {
- return GetSafeEndOfUtf8String(text.data(), text.size());
-}
-
-} // namespace utils
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_UTF8_H_
diff --git a/lang_id/custom-tokenizer.cc b/lang_id/custom-tokenizer.cc
deleted file mode 100644
index f77ad53..0000000
--- a/lang_id/custom-tokenizer.cc
+++ /dev/null
@@ -1,162 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/custom-tokenizer.h"
-
-#include <ctype.h>
-
-#include <string>
-
-#include "lang_id/common/lite_base/attributes.h"
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/utf8.h"
-#include "utf.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace lang_id {
-
-namespace {
-inline bool IsTokenSeparator(int num_bytes, const char *curr) {
- if (num_bytes != 1) {
- return false;
- }
- return !isalpha(*curr);
-}
-
-// Appends to *word the UTF8 encoding for the lowercase version of the UTF8
-// character that starts at |curr| and has |num_bytes| bytes.
-//
-// NOTE: if the current UTF8 character does not have a lowercase version, then
-// we append the original UTF8 character.
-inline SAFTM_ATTRIBUTE_ALWAYS_INLINE void AppendLowerCase(const char *curr,
- int num_bytes,
- string *word) {
- if (num_bytes == 1) {
- // Optimize the ASCII case.
- word->push_back(tolower(*curr));
- return;
- }
-
- // Harder, general case.
- //
- // NOTE: for lowercasing, we use the utils from utf.h:
- // charntorune + tolowerrune + runetochar. Unfortunately, that library does
- // not contain any fast util for determining the number of bytes for the UTF8
- // character that starts at a given address *without* converting to a full
- // codepoint (like our utils::OneCharLen, which is used intensively by the
- // rest of our code, including by the performance-critical char ngram
- // feature). Hence, the rest of our code continues to use utils::OneCharLen,
- // and here, when we append the bytes to *word, we make sure that's consistent
- // with utils::OneCharLen.
-
- // charntorune() below reads the UTF8 character that starts at curr (using at
- // most num_bytes bytes) and stores the corresponding codepoint into rune.
- Rune rune;
- charntorune(&rune, curr, num_bytes);
- if (rune != Runeerror) {
- Rune lower = tolowerrune(rune);
- char lower_buf[UTFmax];
- runetochar(lower_buf, &lower);
-
- // When appending the UTF8 bytes to word, we do not use the number of bytes
- // returned by runetochar(); instead, we use utils::OneCharLen(), the same
- // method used by the char ngram feature. We expect them to be equal, but
- // just in case.
- int lower_num_bytes = utils::OneCharLen(lower_buf);
-
- // Using lower_num_bytes below is safe, because, by definition of UTFmax,
- SAFTM_DCHECK_GE(UTFmax, 4);
-
- // And, by implementation of utils::OneCharLen():
- SAFTM_DCHECK_GT(lower_num_bytes, 0);
- SAFTM_DCHECK_LE(lower_num_bytes, 4);
- word->append(lower_buf, lower_num_bytes);
- } else {
- // There are sequences of bytes that charntorune() can't convert into a
- // valid Rune (a special case is [0xEF, 0xBF, 0xBD], the UTF8 encoding for
- // the U+FFFD special Unicode character, which is also the value of
- // Runeerror). We keep those bytes unchanged.
- word->append(curr, num_bytes);
- }
-}
-} // namespace
-
-void TokenizerForLangId::Setup(TaskContext *context) {
- lowercase_input_ = context->Get("lang_id_lowercase_input", false);
-}
-
-void TokenizerForLangId::Tokenize(StringPiece text,
- LightSentence *sentence) const {
- const char *const start = text.data();
- const char *curr = start;
- const char *end = utils::GetSafeEndOfUtf8String(start, text.size());
-
- // Corner case: the safe part of the text is empty ("").
- if (curr >= end) {
- return;
- }
-
- // Number of bytes for UTF8 character starting at *curr. Note: the loop below
- // is guaranteed to terminate because in each iteration, we move curr by at
- // least num_bytes, and num_bytes is guaranteed to be > 0.
- int num_bytes = utils::OneCharLen(curr);
- while (curr < end) {
- // Jump over consecutive token separators.
- while (IsTokenSeparator(num_bytes, curr)) {
- curr += num_bytes;
- if (curr >= end) {
- return;
- }
- num_bytes = utils::OneCharLen(curr);
- }
-
- // If control reaches this point, we are at beginning of a non-empty token.
- sentence->emplace_back();
- string *word = &(sentence->back());
-
- // Add special token-start character.
- word->push_back('^');
-
- // Add UTF8 characters to word, until we hit the end of the safe text or a
- // token separator.
- while (true) {
- if (lowercase_input_) {
- AppendLowerCase(curr, num_bytes, word);
- } else {
- word->append(curr, num_bytes);
- }
- curr += num_bytes;
- if (curr >= end) {
- break;
- }
- num_bytes = utils::OneCharLen(curr);
- if (IsTokenSeparator(num_bytes, curr)) {
- curr += num_bytes;
- if (curr >= end) {
- break;
- }
- num_bytes = utils::OneCharLen(curr);
- break;
- }
- }
- word->push_back('$');
- }
-}
-
-} // namespace lang_id
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/fb_model/lang-id-from-fb.cc b/lang_id/fb_model/lang-id-from-fb.cc
deleted file mode 100644
index f8e39d7..0000000
--- a/lang_id/fb_model/lang-id-from-fb.cc
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/fb_model/lang-id-from-fb.h"
-
-#include "lang_id/fb_model/model-provider-from-fb.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace lang_id {
-
-std::unique_ptr<LangId> GetLangIdFromFlatbufferFile(const string &filename) {
- std::unique_ptr<ModelProvider> model_provider(
- new ModelProviderFromFlatbuffer(filename));
-
- // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
- return std::unique_ptr<LangId>( // NOLINT
- new LangId(std::move(model_provider)));
-}
-
-std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(int fd) {
- std::unique_ptr<ModelProvider> model_provider(
- new ModelProviderFromFlatbuffer(fd));
-
- // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
- return std::unique_ptr<LangId>( // NOLINT
- new LangId(std::move(model_provider)));
-}
-
-std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(const char *data,
- size_t num_bytes) {
- std::unique_ptr<ModelProvider> model_provider(
- new ModelProviderFromFlatbuffer(data, num_bytes));
-
- // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
- return std::unique_ptr<LangId>( // NOLINT
- new LangId(std::move(model_provider)));
-}
-
-} // namespace lang_id
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/fb_model/lang-id-from-fb.h b/lang_id/fb_model/lang-id-from-fb.h
deleted file mode 100644
index 51bcffe..0000000
--- a/lang_id/fb_model/lang-id-from-fb.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_LANG_ID_FROM_FB_H_
-#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_LANG_ID_FROM_FB_H_
-
-#include <stddef.h>
-
-#include <memory>
-#include <string>
-
-#include "lang_id/lang-id.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace lang_id {
-
-// Returns a LangId built using the SAFT model in flatbuffer format from
-// |filename|.
-std::unique_ptr<LangId> GetLangIdFromFlatbufferFile(const string &filename);
-
-// Returns a LangId built using the SAFT model in flatbuffer format from
-// given file descriptor.
-std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(int fd);
-
-// Returns a LangId built using the SAFT model in flatbuffer format from
-// the |num_bytes| bytes that start at address |data|.
-//
-// IMPORTANT: the model bytes must be alive during the lifetime of the returned
-// LangId. To avoid overhead (e.g., heap allocation), this method does not make
-// a private copy of the model bytes. Avoiding overhead is the main reason we
-// use flatbuffers.
-std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(const char *data,
- size_t num_bytes);
-
-// Convenience string-based version of GetLangIdFromFlatbufferBytes.
-//
-// IMPORTANT: |bytes| must be alive during the lifetime of the returned LangId.
-inline std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(
- const string &bytes) {
- return GetLangIdFromFlatbufferBytes(bytes.data(), bytes.size());
-}
-
-} // namespace lang_id
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_LANG_ID_FROM_FB_H_
diff --git a/lang_id/fb_model/model-provider-from-fb.cc b/lang_id/fb_model/model-provider-from-fb.cc
deleted file mode 100644
index 3357963..0000000
--- a/lang_id/fb_model/model-provider-from-fb.cc
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/fb_model/model-provider-from-fb.h"
-
-#include "lang_id/common/file/file-utils.h"
-#include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
-#include "lang_id/common/flatbuffers/model-utils.h"
-#include "lang_id/common/lite_strings/str-split.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace lang_id {
-
-ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(const string &filename)
-
- // Using mmap as a fast way to read the model bytes. As the file is
- // unmapped only when the field scoped_mmap_ is destructed, the model bytes
- // stay alive for the entire lifetime of this object.
- : scoped_mmap_(new ScopedMmap(filename)) {
- Initialize(scoped_mmap_->handle().to_stringpiece());
-}
-
-ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(int fd)
-
- // Using mmap as a fast way to read the model bytes. As the file is
- // unmapped only when the field scoped_mmap_ is destructed, the model bytes
- // stay alive for the entire lifetime of this object.
- : scoped_mmap_(new ScopedMmap(fd)) {
- Initialize(scoped_mmap_->handle().to_stringpiece());
-}
-
-void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
- // Note: valid_ was initialized to false. In the code below, we set valid_ to
- // true only if all initialization steps completed successfully. Otherwise,
- // we return early, leaving valid_ to its default value false.
- model_ = saft_fbs::GetVerifiedModelFromBytes(model_bytes);
- if (model_ == nullptr) {
- SAFTM_LOG(ERROR) << "Unable to initialize ModelProviderFromFlatbuffer";
- return;
- }
-
- // Initialize context_ parameters.
- if (!saft_fbs::FillParameters(*model_, &context_)) {
- // FillParameters already performs error logging.
- return;
- }
-
- // Init languages_.
- const string known_languages_str = context_.Get("supported_languages", "");
- for (StringPiece sp : LiteStrSplit(known_languages_str, ',')) {
- languages_.emplace_back(sp);
- }
- if (languages_.empty()) {
- SAFTM_LOG(ERROR) << "Unable to find list of supported_languages";
- return;
- }
-
- // Init nn_params_.
- if (!InitNetworkParams()) {
- // InitNetworkParams already performs error logging.
- return;
- }
-
- // Everything looks fine.
- valid_ = true;
-}
-
-bool ModelProviderFromFlatbuffer::InitNetworkParams() {
- const string kInputName = "language-identifier-network";
- StringPiece bytes =
- saft_fbs::GetInputBytes(saft_fbs::GetInputByName(model_, kInputName));
- if ((bytes.data() == nullptr) || bytes.empty()) {
- SAFTM_LOG(ERROR) << "Unable to get bytes for model input " << kInputName;
- return false;
- }
- std::unique_ptr<EmbeddingNetworkParamsFromFlatbuffer> nn_params_from_fb(
- new EmbeddingNetworkParamsFromFlatbuffer(bytes));
- if (!nn_params_from_fb->is_valid()) {
- SAFTM_LOG(ERROR) << "EmbeddingNetworkParamsFromFlatbuffer not valid";
- return false;
- }
- nn_params_ = std::move(nn_params_from_fb);
- return true;
-}
-
-} // namespace lang_id
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/fb_model/model-provider-from-fb.h b/lang_id/fb_model/model-provider-from-fb.h
deleted file mode 100644
index d25c903..0000000
--- a/lang_id/fb_model/model-provider-from-fb.h
+++ /dev/null
@@ -1,118 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_
-#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_
-
-#include <cstddef>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "lang_id/common/fel/task-context.h"
-#include "lang_id/common/file/mmap.h"
-#include "lang_id/common/flatbuffers/model_generated.h"
-#include "lang_id/common/lite_strings/stringpiece.h"
-#include "lang_id/model-provider.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace lang_id {
-
-// ModelProvider for LangId, based on a SAFT model in flatbuffer format.
-class ModelProviderFromFlatbuffer : public ModelProvider {
- public:
- // Constructs a model provider based on a flatbuffer-format SAFT model from
- // |filename|.
- explicit ModelProviderFromFlatbuffer(const string &filename);
-
- // Constructs a model provider based on a flatbuffer-format SAFT model from
- // file descriptor |fd|.
- explicit ModelProviderFromFlatbuffer(int fd);
-
- // Constructs a model provider from a flatbuffer-format SAFT model the bytes
- // of which are already in RAM (size bytes starting from address data).
- // Useful if you "transport" these bytes otherwise than via a normal file
- // (e.g., if you embed them somehow in your binary).
- //
- // IMPORTANT: |data| should be alive during the lifetime of the
- // newly-constructed ModelProviderFromFlatbuffer. This is trivial to ensure
- // for data that's statically embedded in your binary, but more complex in
- // other cases. To avoid overhead (e.g., heap allocation), this method does
- // not make a private copy of the data. In general, the ownership of the
- // newly-constructed ModelProviderFromFlatbuffer is immediately passed to a
- // LangId object (which doesn't pass it further); hence, one needs to make
- // sure |data| is alive during the lifetime of that LangId object.
- ModelProviderFromFlatbuffer(const char *data, std::size_t size) {
- StringPiece model_bytes(data, size);
- Initialize(model_bytes);
- }
-
- ~ModelProviderFromFlatbuffer() override = default;
-
- const TaskContext *GetTaskContext() const override {
- return &context_;
- }
-
- const EmbeddingNetworkParams *GetNnParams() const override {
- return nn_params_.get();
- }
-
- std::vector<string> GetLanguages() const override {
- return languages_;
- }
-
- private:
- // Initializes the fields of this class based on the flatbuffer from
- // |model_bytes|. These bytes are supposed to be the representation of a
- // Model flatbuffer and should be alive during the lifetime of this object.
- void Initialize(StringPiece model_bytes);
-
- // Initializes nn_params_ based on model_.
- bool InitNetworkParams();
-
- // If filename-based constructor is used, scoped_mmap_ keeps the file mmapped
- // during the lifetime of this object, such that references inside the Model
- // flatbuffer from those bytes remain valid.
- const std::unique_ptr<ScopedMmap> scoped_mmap_;
-
- // Pointer to the flatbuffer from
- //
- // (a) [if filename constructor was used:] the bytes mmapped by scoped_mmap_
- // (for safety considerations, see comment for that field), or
- //
- // (b) [of (data, size) constructor was used:] the bytes from [data,
- // data+size). Please read carefully the doc for that constructor.
- const saft_fbs::Model *model_;
-
- // Context returned by this model provider. We set its parameters based on
- // model_, at construction time.
- TaskContext context_;
-
- // List of supported languages, see GetLanguages(). We expect this list to be
- // specified by the ModelParameter named "supported_languages" from model_.
- std::vector<string> languages_;
-
- // EmbeddingNetworkParams, see GetNnParams(). Set based on the ModelInput
- // named "language-identifier-network" from model_.
- std::unique_ptr<EmbeddingNetworkParams> nn_params_;
-};
-
-} // namespace lang_id
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_
diff --git a/lang_id/features/char-ngram-feature.cc b/lang_id/features/char-ngram-feature.cc
deleted file mode 100644
index 83d7588..0000000
--- a/lang_id/features/char-ngram-feature.cc
+++ /dev/null
@@ -1,156 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/features/char-ngram-feature.h"
-
-#include <utility>
-#include <vector>
-
-#include "lang_id/common/fel/feature-types.h"
-#include "lang_id/common/fel/task-context.h"
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/math/hash.h"
-#include "lang_id/common/utf8.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace lang_id {
-
-bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) {
- // Parameters in the feature function descriptor.
- bool include_terminators = GetBoolParameter("include_terminators", false);
- if (!include_terminators) {
- SAFTM_LOG(ERROR) << "No support for include_terminators=true";
- return false;
- }
-
- bool include_spaces = GetBoolParameter("include_spaces", false);
- if (include_spaces) {
- SAFTM_LOG(ERROR) << "No support for include_spaces=true";
- return false;
- }
-
- bool use_equal_ngram_weight = GetBoolParameter("use_equal_weight", false);
- if (use_equal_ngram_weight) {
- SAFTM_LOG(ERROR) << "No support for use_equal_weight=true";
- return false;
- }
-
- ngram_id_dimension_ = GetIntParameter("id_dim", 10000);
- ngram_size_ = GetIntParameter("size", 3);
-
- counts_.assign(ngram_id_dimension_, 0);
- return true;
-}
-
-bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) {
- set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_));
- return true;
-}
-
-int ContinuousBagOfNgramsFunction::ComputeNgramCounts(
- const LightSentence &sentence) const {
- SAFTM_CHECK_EQ(counts_.size(), ngram_id_dimension_);
- SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0);
-
- int total_count = 0;
-
- for (const string &word : sentence) {
- const char *const word_end = word.data() + word.size();
-
- // Set ngram_start at the start of the current token (word).
- const char *ngram_start = word.data();
-
- // Set ngram_end ngram_size UTF8 characters after ngram_start. Note: each
- // UTF8 character contains between 1 and 4 bytes.
- const char *ngram_end = ngram_start;
- int num_utf8_chars = 0;
- do {
- ngram_end += utils::OneCharLen(ngram_end);
- num_utf8_chars++;
- } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end));
-
- if (num_utf8_chars < ngram_size_) {
- // Current token is so small, it does not contain a single ngram of
- // ngram_size UTF8 characters. Not much we can do in this case ...
- continue;
- }
-
- // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size
- // UTF8 characters from current token.
- while (true) {
- // Compute ngram id: hash(ngram) % ngram_id_dimension
- int ngram_id = (
- utils::Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start)
- % ngram_id_dimension_);
-
- // Use a reference to the actual count, such that we can both test whether
- // the count was 0 and increment it without perfoming two lookups.
- int &ref_to_count_for_ngram = counts_[ngram_id];
- if (ref_to_count_for_ngram == 0) {
- non_zero_count_indices_.push_back(ngram_id);
- }
- ref_to_count_for_ngram++;
- total_count++;
- if (ngram_end >= word_end) {
- break;
- }
-
- // Advance both ngram_start and ngram_end by one UTF8 character. This
- // way, the number of UTF8 characters between them remains constant
- // (ngram_size).
- ngram_start += utils::OneCharLen(ngram_start);
- ngram_end += utils::OneCharLen(ngram_end);
- }
- } // end of loop over tokens.
-
- return total_count;
-}
-
-void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces,
- const LightSentence &sentence,
- FeatureVector *result) const {
- // NOTE: we use std::* constructs (instead of absl::Mutex & co) to simplify
- // porting to Android and to avoid pulling in absl (which increases our code
- // size).
- std::lock_guard<std::mutex> mlock(state_mutex_);
-
- // Find the char ngram counts.
- int total_count = ComputeNgramCounts(sentence);
-
- // Populate the feature vector.
- const float norm = static_cast<float>(total_count);
-
- // TODO(salcianu): explore treating dense vectors (i.e., many non-zero
- // elements) separately.
- for (int ngram_id : non_zero_count_indices_) {
- const float weight = counts_[ngram_id] / norm;
- FloatFeatureValue value(ngram_id, weight);
- result->add(feature_type(), value.discrete_value);
-
- // Clear up counts_, for the next invocation of Evaluate().
- counts_[ngram_id] = 0;
- }
-
- // Clear up non_zero_count_indices_, for the next invocation of Evaluate().
- non_zero_count_indices_.clear();
-}
-
-SAFTM_STATIC_REGISTRATION(ContinuousBagOfNgramsFunction);
-
-} // namespace lang_id
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/features/relevant-script-feature.cc b/lang_id/features/relevant-script-feature.cc
deleted file mode 100644
index 0fde87b..0000000
--- a/lang_id/features/relevant-script-feature.cc
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/features/relevant-script-feature.h"
-
-#include <string>
-
-#include "lang_id/common/fel/feature-types.h"
-#include "lang_id/common/fel/task-context.h"
-#include "lang_id/common/fel/workspace.h"
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/utf8.h"
-#include "lang_id/script/script-detector.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace lang_id {
-
-bool RelevantScriptFeature::Setup(TaskContext *context) {
- string script_detector_name = GetParameter(
- "script_detector_name", /* default_value = */ "tiny-script-detector");
-
- // We don't use absl::WrapUnique, nor the rest of absl, see http://b/71873194
- script_detector_.reset(ScriptDetector::Create(script_detector_name));
- if (script_detector_ == nullptr) {
- // This means ScriptDetector::Create() could not find the requested
- // script_detector_name. In that case, Create() already logged an error
- // message.
- return false;
- }
-
- // We use default value 172 because this is the number of scripts supported by
- // the first model we trained with this feature. See http://b/70617713.
- // Newer models may support more scripts.
- num_supported_scripts_ = GetIntParameter("num_supported_scripts", 172);
- return true;
-}
-
-bool RelevantScriptFeature::Init(TaskContext *context) {
- set_feature_type(new NumericFeatureType(name(), num_supported_scripts_));
- return true;
-}
-
-void RelevantScriptFeature::Evaluate(
- const WorkspaceSet &workspaces, const LightSentence &sentence,
- FeatureVector *result) const {
- // counts[s] is the number of characters with script s.
- std::vector<int> counts(num_supported_scripts_);
- int total_count = 0;
- for (const string &word : sentence) {
- const char *const word_end = word.data() + word.size();
- const char *curr = word.data();
-
- // Skip over token start '^'.
- SAFTM_DCHECK_EQ(*curr, '^');
- curr += utils::OneCharLen(curr);
- while (true) {
- const int num_bytes = utils::OneCharLen(curr);
-
- int script = script_detector_->GetScript(curr, num_bytes);
-
- // We do this update and the if (...) break below *before* incrementing
- // counts[script] in order to skip the token end '$'.
- curr += num_bytes;
- if (curr >= word_end) {
- SAFTM_DCHECK_EQ(*(curr - num_bytes), '$');
- break;
- }
- SAFTM_DCHECK_GE(script, 0);
-
- if (script < num_supported_scripts_) {
- counts[script]++;
- total_count++;
- } else {
- // Unsupported script: this usually indicates a script that is
- // recognized by newer versions of the code, after the model was
- // trained. E.g., new code running with old model.
- }
- }
- }
-
- for (int script_id = 0; script_id < num_supported_scripts_; ++script_id) {
- int count = counts[script_id];
- if (count > 0) {
- const float weight = static_cast<float>(count) / total_count;
- FloatFeatureValue value(script_id, weight);
- result->add(feature_type(), value.discrete_value);
- }
- }
-}
-
-SAFTM_STATIC_REGISTRATION(RelevantScriptFeature);
-
-} // namespace lang_id
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/lang-id.cc b/lang_id/lang-id.cc
deleted file mode 100644
index 1339223..0000000
--- a/lang_id/lang-id.cc
+++ /dev/null
@@ -1,320 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/lang-id.h"
-
-#include <stdio.h>
-
-#include <algorithm>
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "lang_id/common/embedding-feature-interface.h"
-#include "lang_id/common/embedding-network-params.h"
-#include "lang_id/common/embedding-network.h"
-#include "lang_id/common/fel/feature-extractor.h"
-#include "lang_id/common/lite_base/logging.h"
-#include "lang_id/common/lite_strings/numbers.h"
-#include "lang_id/common/lite_strings/str-split.h"
-#include "lang_id/common/lite_strings/stringpiece.h"
-#include "lang_id/common/math/algorithm.h"
-#include "lang_id/common/math/softmax.h"
-#include "lang_id/custom-tokenizer.h"
-#include "lang_id/features/light-sentence-features.h"
-#include "lang_id/light-sentence.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace lang_id {
-
-namespace {
-// Default value for the confidence threshold. If the confidence of the top
-// prediction is below this threshold, then FindLanguage() returns
-// LangId::kUnknownLanguageCode. Note: this is just a default value; if the
-// TaskSpec from the model specifies a "reliability_thresh" parameter, then we
-// use that value instead. Note: for legacy reasons, our code and comments use
-// the terms "confidence", "probability" and "reliability" equivalently.
-static const float kDefaultConfidenceThreshold = 0.50f;
-} // namespace
-
-// Class that performs all work behind LangId.
-class LangIdImpl {
- public:
- explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider)
- : model_provider_(std::move(model_provider)),
- lang_id_brain_interface_("language_identifier") {
- // Note: in the code below, we set valid_ to true only if all initialization
- // steps completed successfully. Otherwise, we return early, leaving valid_
- // to its default value false.
- if (!model_provider_ || !model_provider_->is_valid()) {
- SAFTM_LOG(ERROR) << "Invalid model provider";
- return;
- }
-
- auto *nn_params = model_provider_->GetNnParams();
- if (!nn_params) {
- SAFTM_LOG(ERROR) << "No NN params";
- return;
- }
- network_.reset(new EmbeddingNetwork(nn_params));
-
- languages_ = model_provider_->GetLanguages();
- if (languages_.empty()) {
- SAFTM_LOG(ERROR) << "No known languages";
- return;
- }
-
- TaskContext context = *model_provider_->GetTaskContext();
- if (!Setup(&context)) {
- SAFTM_LOG(ERROR) << "Unable to Setup() LangId";
- return;
- }
- if (!Init(&context)) {
- SAFTM_LOG(ERROR) << "Unable to Init() LangId";
- return;
- }
- valid_ = true;
- }
-
- string FindLanguage(StringPiece text) const {
- // NOTE: it would be wasteful to implement this method in terms of
- // FindLanguages(). We just need the most likely language and its
- // probability; no need to compute (and allocate) a vector of pairs for all
- // languages, nor to compute probabilities for all non-top languages.
- if (!is_valid()) {
- return LangId::kUnknownLanguageCode;
- }
-
- // Create a Sentence storing the input text.
- LightSentence sentence;
- tokenizer_.Tokenize(text, &sentence);
-
- // Test input size here, after pre-processing removed irrelevant chars.
- if (IsTooShort(sentence)) {
- return LangId::kUnknownLanguageCode;
- }
-
- std::vector<float> scores;
- ComputeScores(&sentence, &scores);
-
- int prediction_id = GetArgMax(scores);
- const string language = GetLanguageForSoftmaxLabel(prediction_id);
- float probability = ComputeSoftmaxProbability(scores, prediction_id);
- SAFTM_DLOG(INFO) << "Predicted " << language
- << " with prob: " << probability << " for \"" << text
- << "\"";
-
- // Find confidence threshold for language.
- float threshold = default_threshold_;
- auto it = per_lang_thresholds_.find(language);
- if (it != per_lang_thresholds_.end()) {
- threshold = it->second;
- }
- if (probability < threshold) {
- SAFTM_DLOG(INFO) << " below threshold => "
- << LangId::kUnknownLanguageCode;
- return LangId::kUnknownLanguageCode;
- }
- return language;
- }
-
- void FindLanguages(StringPiece text, LangIdResult *result) const {
- if (result == nullptr) return;
-
- result->predictions.clear();
- if (!is_valid()) {
- result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
- return;
- }
-
- // Create a Sentence storing the input text.
- LightSentence sentence;
- tokenizer_.Tokenize(text, &sentence);
-
- // Test input size here, after pre-processing removed irrelevant chars.
- if (IsTooShort(sentence)) {
- result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
- return;
- }
-
- std::vector<float> scores;
- ComputeScores(&sentence, &scores);
-
- // Compute and sort softmax in descending order by probability and convert
- // IDs to language code strings. When probabilities are equal, we sort by
- // language code string in ascending order.
- std::vector<float> softmax = ComputeSoftmax(scores);
-
- for (int i = 0; i < softmax.size(); ++i) {
- result->predictions.emplace_back(GetLanguageForSoftmaxLabel(i),
- softmax[i]);
- }
-
- // Sort the resulting language predictions by probability in descending
- // order.
- std::sort(result->predictions.begin(), result->predictions.end(),
- [](const std::pair<string, float> &a,
- const std::pair<string, float> &b) {
- if (a.second == b.second) {
- return a.first.compare(b.first) < 0;
- } else {
- return a.second > b.second;
- }
- });
- }
-
- bool is_valid() const { return valid_; }
-
- int GetModelVersion() const { return model_version_; }
-
- // Returns a property stored in the model file.
- template <typename T, typename R>
- R GetProperty(const string &property, T default_value) const {
- return model_provider_->GetTaskContext()->Get(property, default_value);
- }
-
- private:
- bool Setup(TaskContext *context) {
- tokenizer_.Setup(context);
- if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;
-
- min_text_size_in_bytes_ = context->Get("min_text_size_in_bytes", 0);
- default_threshold_ =
- context->Get("reliability_thresh", kDefaultConfidenceThreshold);
-
- // Parse task parameter "per_lang_reliability_thresholds", fill
- // per_lang_thresholds_.
- const string thresholds_str =
- context->Get("per_lang_reliability_thresholds", "");
- std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ',');
- for (const auto &token : tokens) {
- if (token.empty()) continue;
- std::vector<StringPiece> parts = LiteStrSplit(token, '=');
- float threshold = 0.0f;
- if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) {
- per_lang_thresholds_[string(parts[0])] = threshold;
- } else {
- SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\"";
- }
- }
- model_version_ = context->Get("model_version", model_version_);
- return true;
- }
-
- bool Init(TaskContext *context) {
- return lang_id_brain_interface_.InitForProcessing(context);
- }
-
- // Extracts features for |text|, runs them through the feed-forward neural
- // network, and computes the output scores (activations from the last layer).
- // These scores can be used to compute the softmax probabilities for our
- // labels (in this case, the languages).
- void ComputeScores(LightSentence* sentence, std::vector<float> *scores) const {
- std::vector<FeatureVector> features =
- lang_id_brain_interface_.GetFeaturesNoCaching(sentence);
-
- // Run feed-forward neural network to compute scores.
- network_->ComputeFinalScores(features, scores);
- }
-
- // Returns language code for a softmax label. See comments for languages_
- // field. If label is out of range, returns LangId::kUnknownLanguageCode.
- string GetLanguageForSoftmaxLabel(int label) const {
- if ((label >= 0) && (label < languages_.size())) {
- return languages_[label];
- } else {
- SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
- << languages_.size() << ")";
- return LangId::kUnknownLanguageCode;
- }
- }
-
- bool IsTooShort(const LightSentence &sentence) const {
- int text_size = 0;
- for (const std::string &token : sentence) {
- // Each token has the form ^...$: we subtract 2 because we want to count
- // only the real text, not the chars added by us.
- text_size += token.size() - 2;
- }
- return text_size < min_text_size_in_bytes_;
- }
-
- std::unique_ptr<ModelProvider> model_provider_;
-
- TokenizerForLangId tokenizer_;
-
- EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence>
- lang_id_brain_interface_;
-
- // Neural network to use for scoring.
- std::unique_ptr<EmbeddingNetwork> network_;
-
- // True if this object is ready to perform language predictions.
- bool valid_ = false;
-
- // The model returns LangId::kUnknownLanguageCode for input text that has
- // fewer than min_text_size_in_bytes_ bytes (excluding ASCII whitespaces,
- // digits, and punctuation).
- int min_text_size_in_bytes_ = 0;
-
- // Only predictions with a probability (confidence) above this threshold are
- // reported. Otherwise, we report LangId::kUnknownLanguageCode.
- float default_threshold_ = kDefaultConfidenceThreshold;
-
- std::unordered_map<string, float> per_lang_thresholds_;
-
- // Recognized languages: softmax label i means languages_[i] (something like
- // "en", "fr", "ru", etc).
- std::vector<string> languages_;
-
- // Version of the model used by this LangIdImpl object. Zero means that the
- // model version could not be determined.
- int model_version_ = 0;
-};
-
-const char LangId::kUnknownLanguageCode[] = "und";
-
-LangId::LangId(std::unique_ptr<ModelProvider> model_provider)
- : pimpl_(new LangIdImpl(std::move(model_provider))) {}
-
-LangId::~LangId() = default;
-
-string LangId::FindLanguage(const char *data, size_t num_bytes) const {
- StringPiece text(data, num_bytes);
- return pimpl_->FindLanguage(text);
-}
-
-void LangId::FindLanguages(const char *data, size_t num_bytes,
- LangIdResult *result) const {
- SAFTM_DCHECK(result) << "LangIdResult must not be null.";
- StringPiece text(data, num_bytes);
- pimpl_->FindLanguages(text, result);
-}
-
-bool LangId::is_valid() const { return pimpl_->is_valid(); }
-
-int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); }
-
-float LangId::GetFloatProperty(const string &property,
- float default_value) const {
- return pimpl_->GetProperty<float, float>(property, default_value);
-}
-
-} // namespace lang_id
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/lang_id/lang-id.h b/lang_id/lang-id.h
deleted file mode 100644
index 94af0c3..0000000
--- a/lang_id/lang-id.h
+++ /dev/null
@@ -1,137 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_H_
-#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_H_
-
-
-#include <stddef.h>
-
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "lang_id/common/lite_base/macros.h"
-#include "lang_id/model-provider.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace lang_id {
-
-// Forward-declaration of the class that performs all underlying work.
-class LangIdImpl;
-
-struct LangIdResult {
- // An n-best list of possible language codes for a given input sorted in
- // descending order according to each code's respective probability.
- //
- // This list is guaranteed to be non-empty after calling
- // LangId::FindLanguages. The most likely language code is always the first
- // item in this array.
- //
- // If the model cannot make a prediction, this array contains a single result:
- // a language code LangId::kUnknownLanguageCode with probability 1.
- std::vector<std::pair<string, float>> predictions;
-};
-
-// Class for detecting the language of a document.
-//
-// Note: this class does not handle the details of loading the actual model.
-// Those details have been "outsourced" to the ModelProvider class.
-//
-// This class is thread safe.
-class LangId {
- public:
- // Standard BCP-47 language code for Unknown/Undetermined language.
- static const char kUnknownLanguageCode[];
-
- // Constructs a LangId object, based on |model_provider|.
- //
- // Note: we don't crash if we detect a problem at construction time (e.g., the
- // model provider can't read an underlying file). Instead, we mark the
- // newly-constructed object as invalid; clients can invoke FindLanguage() on
- // an invalid object: nothing crashes, but accuracy will be bad.
- explicit LangId(std::unique_ptr<ModelProvider> model_provider);
-
- virtual ~LangId();
-
- // Computes the an n-best list of language codes and probabilities
- // corresponding to the most likely languages the given input text is written
- // in. The list is sorted in descending order by language probability.
- //
- // The input text consists of the |num_bytes| bytes that starts at |data|.
- //
- // Note: If this LangId object is not valid (see is_valid()) or if this LangId
- // object can't make a prediction, this method sets the LangIdResult to
- // contain a single entry with kUnknownLanguageCode with probability 1.
- void FindLanguages(const char *data, size_t num_bytes,
- LangIdResult *result) const;
-
- // Convenience version of FindLanguages(const char *, size_t, LangIdResult *).
- void FindLanguages(const string &text, LangIdResult *result) const {
- FindLanguages(text.data(), text.size(), result);
- }
-
- // Returns language code for the most likely language for a piece of text.
- //
- // The input text consists of the |num_bytes| bytes that start at |data|.
- //
- // Note: this method reports the most likely (1-best) language only if its
- // probability is high enough; otherwise, it returns
- // LangId::kUnknownLanguageCode. The specific probability threshold is tuned
- // to the needs of an early client. If you need a different threshold, you
- // can use FindLanguages (plural) to get the full LangIdResult, and apply your
- // own threshold.
- //
- // Note: if this LangId object is not valid (see is_valid()) or if this LangId
- // object can't make a prediction, then this method returns
- // LangId::kUnknownLanguageCode.
- //
- string FindLanguage(const char *data, size_t num_bytes) const;
-
- // Convenience version of FindLanguage(const char *, size_t).
- string FindLanguage(const string &text) const {
- return FindLanguage(text.data(), text.size());
- }
-
- // Returns true if this object has been correctly initialized and is ready to
- // perform predictions. For more info, see doc for LangId
- // constructor above.
- bool is_valid() const;
-
- // Returns the version of the model used by this LangId object. On success,
- // the returned version number is a strictly positive integer. Returns 0 if
- // the model version can not be determined (e.g., for old models that do not
- // specify a version number).
- int GetModelVersion() const;
-
- // Returns a typed property stored in the model file.
- float GetFloatProperty(const string &property, float default_value) const;
-
- private:
- // Pimpl ("pointer to implementation") pattern, to hide all internals from our
- // clients.
- std::unique_ptr<LangIdImpl> pimpl_;
-
- SAFTM_DISALLOW_COPY_AND_ASSIGN(LangId);
-};
-
-} // namespace lang_id
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_H_
diff --git a/lang_id/lang-id_jni.cc b/lang_id/lang-id_jni.cc
deleted file mode 100644
index 61547e5..0000000
--- a/lang_id/lang-id_jni.cc
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "lang_id/lang-id_jni.h"
-
-#include <jni.h>
-#include <type_traits>
-#include <vector>
-
-#include "utils/base/logging.h"
-#include "utils/java/scoped_local_ref.h"
-#include "lang_id/fb_model/lang-id-from-fb.h"
-#include "lang_id/lang-id.h"
-
-using libtextclassifier3::ScopedLocalRef;
-using libtextclassifier3::ToStlString;
-using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile;
-using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor;
-using libtextclassifier3::mobile::lang_id::LangId;
-using libtextclassifier3::mobile::lang_id::LangIdResult;
-
-namespace {
-jobjectArray LangIdResultToJObjectArray(JNIEnv* env,
- const LangIdResult& lang_id_result) {
- const ScopedLocalRef<jclass> result_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_LANG_ID_CLASS_NAME_STR
- "$LanguageResult"),
- env);
- if (!result_class) {
- TC3_LOG(ERROR) << "Couldn't find LanguageResult class.";
- return nullptr;
- }
-
- std::vector<std::pair<std::string, float>> predictions;
- std::copy_if(lang_id_result.predictions.begin(),
- lang_id_result.predictions.end(),
- std::back_inserter(predictions),
- [](std::pair<std::string, float> pair) {
- return pair.first != "und";
- });
-
- const jmethodID result_class_constructor =
- env->GetMethodID(result_class.get(), "<init>", "(Ljava/lang/String;F)V");
- const jobjectArray results =
- env->NewObjectArray(predictions.size(), result_class.get(), nullptr);
- for (int i = 0; i < predictions.size(); i++) {
- ScopedLocalRef<jobject> result(
- env->NewObject(result_class.get(), result_class_constructor,
- env->NewStringUTF(predictions[i].first.c_str()),
- static_cast<jfloat>(predictions[i].second)));
- env->SetObjectArrayElement(results, i, result.get());
- }
- return results;
-}
-} // namespace
-
-TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
-(JNIEnv* env, jobject thiz, jint fd) {
- std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
- if (!lang_id->is_valid()) {
- return reinterpret_cast<jlong>(nullptr);
- }
- return reinterpret_cast<jlong>(lang_id.release());
-}
-
-TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
-(JNIEnv* env, jobject thiz, jstring path) {
- const std::string path_str = ToStlString(env, path);
- std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFile(path_str);
- if (!lang_id->is_valid()) {
- return reinterpret_cast<jlong>(nullptr);
- }
- return reinterpret_cast<jlong>(lang_id.release());
-}
-
-TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
-(JNIEnv* env, jobject clazz, jlong ptr, jstring text) {
- LangId* model = reinterpret_cast<LangId*>(ptr);
- if (!model) {
- return nullptr;
- }
-
- const std::string text_str = ToStlString(env, text);
- LangIdResult result;
- model->FindLanguages(text_str, &result);
-
- return LangIdResultToJObjectArray(env, result);
-}
-
-TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
-(JNIEnv* env, jobject clazz, jlong ptr) {
- if (!ptr) {
- TC3_LOG(ERROR) << "Trying to close null LangId.";
- return;
- }
- LangId* model = reinterpret_cast<LangId*>(ptr);
- delete model;
-}
-
-TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
-(JNIEnv* env, jobject clazz, jlong ptr) {
- if (!ptr) {
- return -1;
- }
- LangId* model = reinterpret_cast<LangId*>(ptr);
- return model->GetModelVersion();
-}
-
-TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
-(JNIEnv* env, jobject clazz, jint fd) {
- std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
- if (!lang_id->is_valid()) {
- return -1;
- }
- return lang_id->GetModelVersion();
-}
-
-TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdThreshold)
-(JNIEnv* env, jobject thizz, jlong ptr) {
- if (!ptr) {
- return -1.0;
- }
- LangId* model = reinterpret_cast<LangId*>(ptr);
- return model->GetFloatProperty("text_classifier_langid_threshold", -1.0);
-}
diff --git a/lang_id/lang-id_jni.h b/lang_id/lang-id_jni.h
deleted file mode 100644
index cd67a4c..0000000
--- a/lang_id/lang-id_jni.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// JNI wrapper for LangId.
-
-#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
-#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
-
-#include <jni.h>
-#include <string>
-#include "utils/java/jni-base.h"
-
-#ifndef TC3_LANG_ID_CLASS_NAME
-#define TC3_LANG_ID_CLASS_NAME LangIdModel
-#endif
-
-#define TC3_LANG_ID_CLASS_NAME_STR TC3_ADD_QUOTES(TC3_LANG_ID_CLASS_NAME)
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
-(JNIEnv* env, jobject clazz, jint fd);
-
-TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
-(JNIEnv* env, jobject clazz, jstring path);
-
-TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
-(JNIEnv* env, jobject clazz, jlong ptr, jstring text);
-
-TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
-(JNIEnv* env, jobject clazz, jlong ptr);
-
-TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
-(JNIEnv* env, jobject clazz, jlong ptr);
-
-TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
-(JNIEnv* env, jobject clazz, jint fd);
-
-TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdThreshold)
-(JNIEnv* env, jobject thizz, jlong ptr);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
diff --git a/lang_id/light-sentence.h b/lang_id/light-sentence.h
deleted file mode 100644
index 2937549..0000000
--- a/lang_id/light-sentence.h
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LIGHT_SENTENCE_H_
-#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LIGHT_SENTENCE_H_
-
-#include <string>
-#include <vector>
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace lang_id {
-
-// Very simplified alternative to heavy sentence.proto, for the purpose of
-// LangId. It turns out that in this case, all we need is a vector of strings,
-// which uses a lot less code size than a Sentence proto.
-using LightSentence = std::vector<string>;
-
-} // namespace lang_id
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LIGHT_SENTENCE_H_
diff --git a/lang_id/model-provider.h b/lang_id/model-provider.h
deleted file mode 100644
index a076871..0000000
--- a/lang_id/model-provider.h
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
-#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
-
-#include <string>
-#include <vector>
-
-#include "lang_id/common/embedding-network-params.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace lang_id {
-
-// Interface for accessing parameters for the LangId model.
-//
-// Note: some clients prefer to include the model parameters in the binary,
-// others prefer loading them from a separate file. This file provides a common
-// interface for these alternative mechanisms.
-class ModelProvider {
- public:
- virtual ~ModelProvider() = default;
-
- // Returns true if this ModelProvider has been succesfully constructed (e.g.,
- // can return false if an underlying model file could not be read). Clients
- // should not use invalid ModelProviders.
- bool is_valid() { return valid_; }
-
- // Returns the TaskContext with parameters for the LangId model. E.g., one
- // important parameter specifies the features to use.
- virtual const TaskContext *GetTaskContext() const = 0;
-
- // Returns parameters for the underlying Neurosis feed-forward neural network.
- virtual const EmbeddingNetworkParams *GetNnParams() const = 0;
-
- // Returns list of languages recognized by the model. Each element of the
- // returned vector should be a BCP-47 language code (e.g., "en", "ro", etc).
- // Language at index i from the returned vector corresponds to softmax label
- // i.
- virtual std::vector<string> GetLanguages() const = 0;
-
- protected:
- bool valid_ = false;
-};
-
-} // namespace lang_id
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
diff --git a/lang_id/script/approx-script-data.cc b/lang_id/script/approx-script-data.cc
deleted file mode 100755
index e11d7b7..0000000
--- a/lang_id/script/approx-script-data.cc
+++ /dev/null
@@ -1,1146 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Internal data for approx-script.cc; see approx-script-data.h
-//
-// DO NOT EDIT BY HAND
-//
-// Generated by
-// lang_id/script/update-script-data.sh
-
-#include "lang_id/script/approx-script-data.h"
-
-namespace libtextclassifier3 {
-namespace mobile {
-namespace approx_script_internal {
-
-const int kNumRanges = 367;
-
-const uint32 kRangeFirst[] = {
- 65, // Range #0: [65, 90, Latin]
- 97, // Range #1: [97, 122, Latin]
- 170, // Range #2: [170, 170, Latin]
- 186, // Range #3: [186, 186, Latin]
- 192, // Range #4: [192, 214, Latin]
- 216, // Range #5: [216, 246, Latin]
- 248, // Range #6: [248, 696, Latin]
- 736, // Range #7: [736, 740, Latin]
- 746, // Range #8: [746, 747, Bopomofo]
- 880, // Range #9: [880, 883, Greek]
- 885, // Range #10: [885, 893, Greek]
- 895, // Range #11: [895, 900, Greek]
- 902, // Range #12: [902, 902, Greek]
- 904, // Range #13: [904, 993, Greek]
- 994, // Range #14: [994, 1007, Coptic]
- 1008, // Range #15: [1008, 1023, Greek]
- 1024, // Range #16: [1024, 1156, Cyrillic]
- 1159, // Range #17: [1159, 1327, Cyrillic]
- 1329, // Range #18: [1329, 1416, Armenian]
- 1418, // Range #19: [1418, 1423, Armenian]
- 1425, // Range #20: [1425, 1479, Hebrew]
- 1488, // Range #21: [1488, 1524, Hebrew]
- 1536, // Range #22: [1536, 1540, Arabic]
- 1542, // Range #23: [1542, 1547, Arabic]
- 1549, // Range #24: [1549, 1562, Arabic]
- 1564, // Range #25: [1564, 1566, Arabic]
- 1568, // Range #26: [1568, 1599, Arabic]
- 1601, // Range #27: [1601, 1610, Arabic]
- 1622, // Range #28: [1622, 1647, Arabic]
- 1649, // Range #29: [1649, 1756, Arabic]
- 1758, // Range #30: [1758, 1791, Arabic]
- 1792, // Range #31: [1792, 1871, Syriac]
- 1872, // Range #32: [1872, 1919, Arabic]
- 1920, // Range #33: [1920, 1969, Thaana]
- 1984, // Range #34: [1984, 2047, Nko]
- 2048, // Range #35: [2048, 2110, Samaritan]
- 2112, // Range #36: [2112, 2142, Mandaic]
- 2144, // Range #37: [2144, 2154, Syriac]
- 2208, // Range #38: [2208, 2237, Arabic]
- 2259, // Range #39: [2259, 2273, Arabic]
- 2275, // Range #40: [2275, 2303, Arabic]
- 2304, // Range #41: [2304, 2384, Devanagari]
- 2389, // Range #42: [2389, 2403, Devanagari]
- 2406, // Range #43: [2406, 2431, Devanagari]
- 2432, // Range #44: [2432, 2510, Bengali]
- 2519, // Range #45: [2519, 2558, Bengali]
- 2561, // Range #46: [2561, 2641, Gurmukhi]
- 2649, // Range #47: [2649, 2654, Gurmukhi]
- 2662, // Range #48: [2662, 2678, Gurmukhi]
- 2689, // Range #49: [2689, 2768, Gujarati]
- 2784, // Range #50: [2784, 2801, Gujarati]
- 2809, // Range #51: [2809, 2815, Gujarati]
- 2817, // Range #52: [2817, 2893, Oriya]
- 2902, // Range #53: [2902, 2935, Oriya]
- 2946, // Range #54: [2946, 3024, Tamil]
- 3031, // Range #55: [3031, 3031, Tamil]
- 3046, // Range #56: [3046, 3066, Tamil]
- 3072, // Range #57: [3072, 3149, Telugu]
- 3157, // Range #58: [3157, 3162, Telugu]
- 3168, // Range #59: [3168, 3183, Telugu]
- 3191, // Range #60: [3191, 3199, Telugu]
- 3200, // Range #61: [3200, 3277, Kannada]
- 3285, // Range #62: [3285, 3286, Kannada]
- 3294, // Range #63: [3294, 3314, Kannada]
- 3328, // Range #64: [3328, 3455, Malayalam]
- 3458, // Range #65: [3458, 3551, Sinhala]
- 3558, // Range #66: [3558, 3572, Sinhala]
- 3585, // Range #67: [3585, 3642, Thai]
- 3648, // Range #68: [3648, 3675, Thai]
- 3713, // Range #69: [3713, 3807, Lao]
- 3840, // Range #70: [3840, 4052, Tibetan]
- 4057, // Range #71: [4057, 4058, Tibetan]
- 4096, // Range #72: [4096, 4255, Myanmar]
- 4256, // Range #73: [4256, 4295, Georgian]
- 4301, // Range #74: [4301, 4346, Georgian]
- 4348, // Range #75: [4348, 4351, Georgian]
- 4352, // Range #76: [4352, 4607, Hangul]
- 4608, // Range #77: [4608, 5017, Ethiopic]
- 5024, // Range #78: [5024, 5117, Cherokee]
- 5120, // Range #79: [5120, 5759, Canadian_Aboriginal]
- 5760, // Range #80: [5760, 5788, Ogham]
- 5792, // Range #81: [5792, 5866, Runic]
- 5870, // Range #82: [5870, 5880, Runic]
- 5888, // Range #83: [5888, 5908, Tagalog]
- 5920, // Range #84: [5920, 5940, Hanunoo]
- 5952, // Range #85: [5952, 5971, Buhid]
- 5984, // Range #86: [5984, 6003, Tagbanwa]
- 6016, // Range #87: [6016, 6121, Khmer]
- 6128, // Range #88: [6128, 6137, Khmer]
- 6144, // Range #89: [6144, 6145, Mongolian]
- 6148, // Range #90: [6148, 6148, Mongolian]
- 6150, // Range #91: [6150, 6169, Mongolian]
- 6176, // Range #92: [6176, 6264, Mongolian]
- 6272, // Range #93: [6272, 6314, Mongolian]
- 6320, // Range #94: [6320, 6389, Canadian_Aboriginal]
- 6400, // Range #95: [6400, 6479, Limbu]
- 6480, // Range #96: [6480, 6516, Tai_Le]
- 6528, // Range #97: [6528, 6601, New_Tai_Lue]
- 6608, // Range #98: [6608, 6623, New_Tai_Lue]
- 6624, // Range #99: [6624, 6655, Khmer]
- 6656, // Range #100: [6656, 6687, Buginese]
- 6688, // Range #101: [6688, 6793, Tai_Tham]
- 6800, // Range #102: [6800, 6809, Tai_Tham]
- 6816, // Range #103: [6816, 6829, Tai_Tham]
- 6912, // Range #104: [6912, 7036, Balinese]
- 7040, // Range #105: [7040, 7103, Sundanese]
- 7104, // Range #106: [7104, 7155, Batak]
- 7164, // Range #107: [7164, 7167, Batak]
- 7168, // Range #108: [7168, 7247, Lepcha]
- 7248, // Range #109: [7248, 7295, Ol_Chiki]
- 7296, // Range #110: [7296, 7304, Cyrillic]
- 7312, // Range #111: [7312, 7359, Georgian]
- 7360, // Range #112: [7360, 7367, Sundanese]
- 7424, // Range #113: [7424, 7461, Latin]
- 7462, // Range #114: [7462, 7466, Greek]
- 7467, // Range #115: [7467, 7467, Cyrillic]
- 7468, // Range #116: [7468, 7516, Latin]
- 7517, // Range #117: [7517, 7521, Greek]
- 7522, // Range #118: [7522, 7525, Latin]
- 7526, // Range #119: [7526, 7530, Greek]
- 7531, // Range #120: [7531, 7543, Latin]
- 7544, // Range #121: [7544, 7544, Cyrillic]
- 7545, // Range #122: [7545, 7614, Latin]
- 7615, // Range #123: [7615, 7615, Greek]
- 7680, // Range #124: [7680, 7935, Latin]
- 7936, // Range #125: [7936, 8190, Greek]
- 8305, // Range #126: [8305, 8305, Latin]
- 8319, // Range #127: [8319, 8319, Latin]
- 8336, // Range #128: [8336, 8348, Latin]
- 8486, // Range #129: [8486, 8486, Greek]
- 8490, // Range #130: [8490, 8491, Latin]
- 8498, // Range #131: [8498, 8498, Latin]
- 8526, // Range #132: [8526, 8526, Latin]
- 8544, // Range #133: [8544, 8584, Latin]
- 10240, // Range #134: [10240, 10495, Braille]
- 11264, // Range #135: [11264, 11358, Glagolitic]
- 11360, // Range #136: [11360, 11391, Latin]
- 11392, // Range #137: [11392, 11507, Coptic]
- 11513, // Range #138: [11513, 11519, Coptic]
- 11520, // Range #139: [11520, 11559, Georgian]
- 11565, // Range #140: [11565, 11565, Georgian]
- 11568, // Range #141: [11568, 11623, Tifinagh]
- 11631, // Range #142: [11631, 11632, Tifinagh]
- 11647, // Range #143: [11647, 11647, Tifinagh]
- 11648, // Range #144: [11648, 11670, Ethiopic]
- 11680, // Range #145: [11680, 11742, Ethiopic]
- 11744, // Range #146: [11744, 11775, Cyrillic]
- 11904, // Range #147: [11904, 12019, Han]
- 12032, // Range #148: [12032, 12245, Han]
- 12293, // Range #149: [12293, 12293, Han]
- 12295, // Range #150: [12295, 12295, Han]
- 12321, // Range #151: [12321, 12329, Han]
- 12334, // Range #152: [12334, 12335, Hangul]
- 12344, // Range #153: [12344, 12347, Han]
- 12353, // Range #154: [12353, 12438, Hiragana]
- 12445, // Range #155: [12445, 12447, Hiragana]
- 12449, // Range #156: [12449, 12538, Katakana]
- 12541, // Range #157: [12541, 12543, Katakana]
- 12549, // Range #158: [12549, 12591, Bopomofo]
- 12593, // Range #159: [12593, 12686, Hangul]
- 12704, // Range #160: [12704, 12730, Bopomofo]
- 12784, // Range #161: [12784, 12799, Katakana]
- 12800, // Range #162: [12800, 12830, Hangul]
- 12896, // Range #163: [12896, 12926, Hangul]
- 13008, // Range #164: [13008, 13054, Katakana]
- 13056, // Range #165: [13056, 13143, Katakana]
- 13312, // Range #166: [13312, 19893, Han]
- 19968, // Range #167: [19968, 40943, Han]
- 40960, // Range #168: [40960, 42182, Yi]
- 42192, // Range #169: [42192, 42239, Lisu]
- 42240, // Range #170: [42240, 42539, Vai]
- 42560, // Range #171: [42560, 42655, Cyrillic]
- 42656, // Range #172: [42656, 42743, Bamum]
- 42786, // Range #173: [42786, 42887, Latin]
- 42891, // Range #174: [42891, 42950, Latin]
- 42999, // Range #175: [42999, 43007, Latin]
- 43008, // Range #176: [43008, 43051, Syloti_Nagri]
- 43072, // Range #177: [43072, 43127, Phags_Pa]
- 43136, // Range #178: [43136, 43205, Saurashtra]
- 43214, // Range #179: [43214, 43225, Saurashtra]
- 43232, // Range #180: [43232, 43263, Devanagari]
- 43264, // Range #181: [43264, 43309, Kayah_Li]
- 43311, // Range #182: [43311, 43311, Kayah_Li]
- 43312, // Range #183: [43312, 43347, Rejang]
- 43359, // Range #184: [43359, 43359, Rejang]
- 43360, // Range #185: [43360, 43388, Hangul]
- 43392, // Range #186: [43392, 43469, Javanese]
- 43472, // Range #187: [43472, 43487, Javanese]
- 43488, // Range #188: [43488, 43518, Myanmar]
- 43520, // Range #189: [43520, 43574, Cham]
- 43584, // Range #190: [43584, 43615, Cham]
- 43616, // Range #191: [43616, 43647, Myanmar]
- 43648, // Range #192: [43648, 43714, Tai_Viet]
- 43739, // Range #193: [43739, 43743, Tai_Viet]
- 43744, // Range #194: [43744, 43766, Meetei_Mayek]
- 43777, // Range #195: [43777, 43798, Ethiopic]
- 43808, // Range #196: [43808, 43822, Ethiopic]
- 43824, // Range #197: [43824, 43866, Latin]
- 43868, // Range #198: [43868, 43876, Latin]
- 43877, // Range #199: [43877, 43877, Greek]
- 43878, // Range #200: [43878, 43879, Latin]
- 43888, // Range #201: [43888, 43967, Cherokee]
- 43968, // Range #202: [43968, 44025, Meetei_Mayek]
- 44032, // Range #203: [44032, 55203, Hangul]
- 55216, // Range #204: [55216, 55291, Hangul]
- 63744, // Range #205: [63744, 64217, Han]
- 64256, // Range #206: [64256, 64262, Latin]
- 64275, // Range #207: [64275, 64279, Armenian]
- 64285, // Range #208: [64285, 64335, Hebrew]
- 64336, // Range #209: [64336, 64449, Arabic]
- 64467, // Range #210: [64467, 64829, Arabic]
- 64848, // Range #211: [64848, 64967, Arabic]
- 65008, // Range #212: [65008, 65021, Arabic]
- 65070, // Range #213: [65070, 65071, Cyrillic]
- 65136, // Range #214: [65136, 65276, Arabic]
- 65313, // Range #215: [65313, 65338, Latin]
- 65345, // Range #216: [65345, 65370, Latin]
- 65382, // Range #217: [65382, 65391, Katakana]
- 65393, // Range #218: [65393, 65437, Katakana]
- 65440, // Range #219: [65440, 65500, Hangul]
- 65536, // Range #220: [65536, 65629, Linear_B]
- 65664, // Range #221: [65664, 65786, Linear_B]
- 65856, // Range #222: [65856, 65934, Greek]
- 65952, // Range #223: [65952, 65952, Greek]
- 66176, // Range #224: [66176, 66204, Lycian]
- 66208, // Range #225: [66208, 66256, Carian]
- 66304, // Range #226: [66304, 66339, Old_Italic]
- 66349, // Range #227: [66349, 66351, Old_Italic]
- 66352, // Range #228: [66352, 66378, Gothic]
- 66384, // Range #229: [66384, 66426, Old_Permic]
- 66432, // Range #230: [66432, 66463, Ugaritic]
- 66464, // Range #231: [66464, 66517, Old_Persian]
- 66560, // Range #232: [66560, 66639, Deseret]
- 66640, // Range #233: [66640, 66687, Shavian]
- 66688, // Range #234: [66688, 66729, Osmanya]
- 66736, // Range #235: [66736, 66811, Osage]
- 66816, // Range #236: [66816, 66855, Elbasan]
- 66864, // Range #237: [66864, 66915, Caucasian_Albanian]
- 66927, // Range #238: [66927, 66927, Caucasian_Albanian]
- 67072, // Range #239: [67072, 67382, Linear_A]
- 67392, // Range #240: [67392, 67413, Linear_A]
- 67424, // Range #241: [67424, 67431, Linear_A]
- 67584, // Range #242: [67584, 67647, Cypriot]
- 67648, // Range #243: [67648, 67679, Imperial_Aramaic]
- 67680, // Range #244: [67680, 67711, Palmyrene]
- 67712, // Range #245: [67712, 67742, Nabataean]
- 67751, // Range #246: [67751, 67759, Nabataean]
- 67808, // Range #247: [67808, 67829, Hatran]
- 67835, // Range #248: [67835, 67839, Hatran]
- 67840, // Range #249: [67840, 67871, Phoenician]
- 67872, // Range #250: [67872, 67897, Lydian]
- 67903, // Range #251: [67903, 67903, Lydian]
- 67968, // Range #252: [67968, 67999, Meroitic_Hieroglyphs]
- 68000, // Range #253: [68000, 68095, Meroitic_Cursive]
- 68096, // Range #254: [68096, 68102, Kharoshthi]
- 68108, // Range #255: [68108, 68168, Kharoshthi]
- 68176, // Range #256: [68176, 68184, Kharoshthi]
- 68192, // Range #257: [68192, 68223, Old_South_Arabian]
- 68224, // Range #258: [68224, 68255, Old_North_Arabian]
- 68288, // Range #259: [68288, 68342, Manichaean]
- 68352, // Range #260: [68352, 68415, Avestan]
- 68416, // Range #261: [68416, 68447, Inscriptional_Parthian]
- 68448, // Range #262: [68448, 68466, Inscriptional_Pahlavi]
- 68472, // Range #263: [68472, 68479, Inscriptional_Pahlavi]
- 68480, // Range #264: [68480, 68497, Psalter_Pahlavi]
- 68505, // Range #265: [68505, 68508, Psalter_Pahlavi]
- 68521, // Range #266: [68521, 68527, Psalter_Pahlavi]
- 68608, // Range #267: [68608, 68680, Old_Turkic]
- 68736, // Range #268: [68736, 68786, Old_Hungarian]
- 68800, // Range #269: [68800, 68850, Old_Hungarian]
- 68858, // Range #270: [68858, 68863, Old_Hungarian]
- 68864, // Range #271: [68864, 68903, Hanifi_Rohingya]
- 68912, // Range #272: [68912, 68921, Hanifi_Rohingya]
- 69216, // Range #273: [69216, 69246, Arabic]
- 69376, // Range #274: [69376, 69415, Old_Sogdian]
- 69424, // Range #275: [69424, 69465, Sogdian]
- 69600, // Range #276: [69600, 69622, Elymaic]
- 69632, // Range #277: [69632, 69743, Brahmi]
- 69759, // Range #278: [69759, 69759, Brahmi]
- 69760, // Range #279: [69760, 69825, Kaithi]
- 69837, // Range #280: [69837, 69837, Kaithi]
- 69840, // Range #281: [69840, 69864, Sora_Sompeng]
- 69872, // Range #282: [69872, 69881, Sora_Sompeng]
- 69888, // Range #283: [69888, 69958, Chakma]
- 69968, // Range #284: [69968, 70006, Mahajani]
- 70016, // Range #285: [70016, 70111, Sharada]
- 70113, // Range #286: [70113, 70132, Sinhala]
- 70144, // Range #287: [70144, 70206, Khojki]
- 70272, // Range #288: [70272, 70313, Multani]
- 70320, // Range #289: [70320, 70378, Khudawadi]
- 70384, // Range #290: [70384, 70393, Khudawadi]
- 70400, // Range #291: [70400, 70457, Grantha]
- 70460, // Range #292: [70460, 70480, Grantha]
- 70487, // Range #293: [70487, 70487, Grantha]
- 70493, // Range #294: [70493, 70516, Grantha]
- 70656, // Range #295: [70656, 70751, Newa]
- 70784, // Range #296: [70784, 70855, Tirhuta]
- 70864, // Range #297: [70864, 70873, Tirhuta]
- 71040, // Range #298: [71040, 71133, Siddham]
- 71168, // Range #299: [71168, 71236, Modi]
- 71248, // Range #300: [71248, 71257, Modi]
- 71264, // Range #301: [71264, 71276, Mongolian]
- 71296, // Range #302: [71296, 71352, Takri]
- 71360, // Range #303: [71360, 71369, Takri]
- 71424, // Range #304: [71424, 71487, Ahom]
- 71680, // Range #305: [71680, 71739, Dogra]
- 71840, // Range #306: [71840, 71922, Warang_Citi]
- 71935, // Range #307: [71935, 71935, Warang_Citi]
- 72096, // Range #308: [72096, 72164, Nandinagari]
- 72192, // Range #309: [72192, 72263, Zanabazar_Square]
- 72272, // Range #310: [72272, 72354, Soyombo]
- 72384, // Range #311: [72384, 72440, Pau_Cin_Hau]
- 72704, // Range #312: [72704, 72773, Bhaiksuki]
- 72784, // Range #313: [72784, 72812, Bhaiksuki]
- 72816, // Range #314: [72816, 72886, Marchen]
- 72960, // Range #315: [72960, 73031, Masaram_Gondi]
- 73040, // Range #316: [73040, 73049, Masaram_Gondi]
- 73056, // Range #317: [73056, 73112, Gunjala_Gondi]
- 73120, // Range #318: [73120, 73129, Gunjala_Gondi]
- 73440, // Range #319: [73440, 73464, Makasar]
- 73664, // Range #320: [73664, 73713, Tamil]
- 73727, // Range #321: [73727, 73727, Tamil]
- 73728, // Range #322: [73728, 74649, Cuneiform]
- 74752, // Range #323: [74752, 74868, Cuneiform]
- 74880, // Range #324: [74880, 75075, Cuneiform]
- 77824, // Range #325: [77824, 78904, Egyptian_Hieroglyphs]
- 82944, // Range #326: [82944, 83526, Anatolian_Hieroglyphs]
- 92160, // Range #327: [92160, 92728, Bamum]
- 92736, // Range #328: [92736, 92783, Mro]
- 92880, // Range #329: [92880, 92917, Bassa_Vah]
- 92928, // Range #330: [92928, 92997, Pahawh_Hmong]
- 93008, // Range #331: [93008, 93047, Pahawh_Hmong]
- 93053, // Range #332: [93053, 93071, Pahawh_Hmong]
- 93760, // Range #333: [93760, 93850, Medefaidrin]
- 93952, // Range #334: [93952, 94087, Miao]
- 94095, // Range #335: [94095, 94111, Miao]
- 94176, // Range #336: [94176, 94176, Tangut]
- 94177, // Range #337: [94177, 94177, Nushu]
- 94208, // Range #338: [94208, 100343, Tangut]
- 100352, // Range #339: [100352, 101106, Tangut]
- 110592, // Range #340: [110592, 110592, Katakana]
- 110593, // Range #341: [110593, 110878, Hiragana]
- 110928, // Range #342: [110928, 110930, Hiragana]
- 110948, // Range #343: [110948, 110951, Katakana]
- 110960, // Range #344: [110960, 111355, Nushu]
- 113664, // Range #345: [113664, 113770, Duployan]
- 113776, // Range #346: [113776, 113800, Duployan]
- 113808, // Range #347: [113808, 113823, Duployan]
- 119296, // Range #348: [119296, 119365, Greek]
- 120832, // Range #349: [120832, 121483, SignWriting]
- 121499, // Range #350: [121499, 121519, SignWriting]
- 122880, // Range #351: [122880, 122922, Glagolitic]
- 123136, // Range #352: [123136, 123215, Nyiakeng_Puachue_Hmong]
- 123584, // Range #353: [123584, 123641, Wancho]
- 123647, // Range #354: [123647, 123647, Wancho]
- 124928, // Range #355: [124928, 125142, Mende_Kikakui]
- 125184, // Range #356: [125184, 125279, Adlam]
- 126464, // Range #357: [126464, 126523, Arabic]
- 126530, // Range #358: [126530, 126619, Arabic]
- 126625, // Range #359: [126625, 126651, Arabic]
- 126704, // Range #360: [126704, 126705, Arabic]
- 127488, // Range #361: [127488, 127488, Hiragana]
- 131072, // Range #362: [131072, 173782, Han]
- 173824, // Range #363: [173824, 177972, Han]
- 177984, // Range #364: [177984, 183969, Han]
- 183984, // Range #365: [183984, 191456, Han]
- 194560, // Range #366: [194560, 195101, Han]
-};
-
-const uint16 kRangeSizeMinusOne[] = {
- 25, // Range #0: [65, 90, Latin]
- 25, // Range #1: [97, 122, Latin]
- 0, // Range #2: [170, 170, Latin]
- 0, // Range #3: [186, 186, Latin]
- 22, // Range #4: [192, 214, Latin]
- 30, // Range #5: [216, 246, Latin]
- 448, // Range #6: [248, 696, Latin]
- 4, // Range #7: [736, 740, Latin]
- 1, // Range #8: [746, 747, Bopomofo]
- 3, // Range #9: [880, 883, Greek]
- 8, // Range #10: [885, 893, Greek]
- 5, // Range #11: [895, 900, Greek]
- 0, // Range #12: [902, 902, Greek]
- 89, // Range #13: [904, 993, Greek]
- 13, // Range #14: [994, 1007, Coptic]
- 15, // Range #15: [1008, 1023, Greek]
- 132, // Range #16: [1024, 1156, Cyrillic]
- 168, // Range #17: [1159, 1327, Cyrillic]
- 87, // Range #18: [1329, 1416, Armenian]
- 5, // Range #19: [1418, 1423, Armenian]
- 54, // Range #20: [1425, 1479, Hebrew]
- 36, // Range #21: [1488, 1524, Hebrew]
- 4, // Range #22: [1536, 1540, Arabic]
- 5, // Range #23: [1542, 1547, Arabic]
- 13, // Range #24: [1549, 1562, Arabic]
- 2, // Range #25: [1564, 1566, Arabic]
- 31, // Range #26: [1568, 1599, Arabic]
- 9, // Range #27: [1601, 1610, Arabic]
- 25, // Range #28: [1622, 1647, Arabic]
- 107, // Range #29: [1649, 1756, Arabic]
- 33, // Range #30: [1758, 1791, Arabic]
- 79, // Range #31: [1792, 1871, Syriac]
- 47, // Range #32: [1872, 1919, Arabic]
- 49, // Range #33: [1920, 1969, Thaana]
- 63, // Range #34: [1984, 2047, Nko]
- 62, // Range #35: [2048, 2110, Samaritan]
- 30, // Range #36: [2112, 2142, Mandaic]
- 10, // Range #37: [2144, 2154, Syriac]
- 29, // Range #38: [2208, 2237, Arabic]
- 14, // Range #39: [2259, 2273, Arabic]
- 28, // Range #40: [2275, 2303, Arabic]
- 80, // Range #41: [2304, 2384, Devanagari]
- 14, // Range #42: [2389, 2403, Devanagari]
- 25, // Range #43: [2406, 2431, Devanagari]
- 78, // Range #44: [2432, 2510, Bengali]
- 39, // Range #45: [2519, 2558, Bengali]
- 80, // Range #46: [2561, 2641, Gurmukhi]
- 5, // Range #47: [2649, 2654, Gurmukhi]
- 16, // Range #48: [2662, 2678, Gurmukhi]
- 79, // Range #49: [2689, 2768, Gujarati]
- 17, // Range #50: [2784, 2801, Gujarati]
- 6, // Range #51: [2809, 2815, Gujarati]
- 76, // Range #52: [2817, 2893, Oriya]
- 33, // Range #53: [2902, 2935, Oriya]
- 78, // Range #54: [2946, 3024, Tamil]
- 0, // Range #55: [3031, 3031, Tamil]
- 20, // Range #56: [3046, 3066, Tamil]
- 77, // Range #57: [3072, 3149, Telugu]
- 5, // Range #58: [3157, 3162, Telugu]
- 15, // Range #59: [3168, 3183, Telugu]
- 8, // Range #60: [3191, 3199, Telugu]
- 77, // Range #61: [3200, 3277, Kannada]
- 1, // Range #62: [3285, 3286, Kannada]
- 20, // Range #63: [3294, 3314, Kannada]
- 127, // Range #64: [3328, 3455, Malayalam]
- 93, // Range #65: [3458, 3551, Sinhala]
- 14, // Range #66: [3558, 3572, Sinhala]
- 57, // Range #67: [3585, 3642, Thai]
- 27, // Range #68: [3648, 3675, Thai]
- 94, // Range #69: [3713, 3807, Lao]
- 212, // Range #70: [3840, 4052, Tibetan]
- 1, // Range #71: [4057, 4058, Tibetan]
- 159, // Range #72: [4096, 4255, Myanmar]
- 39, // Range #73: [4256, 4295, Georgian]
- 45, // Range #74: [4301, 4346, Georgian]
- 3, // Range #75: [4348, 4351, Georgian]
- 255, // Range #76: [4352, 4607, Hangul]
- 409, // Range #77: [4608, 5017, Ethiopic]
- 93, // Range #78: [5024, 5117, Cherokee]
- 639, // Range #79: [5120, 5759, Canadian_Aboriginal]
- 28, // Range #80: [5760, 5788, Ogham]
- 74, // Range #81: [5792, 5866, Runic]
- 10, // Range #82: [5870, 5880, Runic]
- 20, // Range #83: [5888, 5908, Tagalog]
- 20, // Range #84: [5920, 5940, Hanunoo]
- 19, // Range #85: [5952, 5971, Buhid]
- 19, // Range #86: [5984, 6003, Tagbanwa]
- 105, // Range #87: [6016, 6121, Khmer]
- 9, // Range #88: [6128, 6137, Khmer]
- 1, // Range #89: [6144, 6145, Mongolian]
- 0, // Range #90: [6148, 6148, Mongolian]
- 19, // Range #91: [6150, 6169, Mongolian]
- 88, // Range #92: [6176, 6264, Mongolian]
- 42, // Range #93: [6272, 6314, Mongolian]
- 69, // Range #94: [6320, 6389, Canadian_Aboriginal]
- 79, // Range #95: [6400, 6479, Limbu]
- 36, // Range #96: [6480, 6516, Tai_Le]
- 73, // Range #97: [6528, 6601, New_Tai_Lue]
- 15, // Range #98: [6608, 6623, New_Tai_Lue]
- 31, // Range #99: [6624, 6655, Khmer]
- 31, // Range #100: [6656, 6687, Buginese]
- 105, // Range #101: [6688, 6793, Tai_Tham]
- 9, // Range #102: [6800, 6809, Tai_Tham]
- 13, // Range #103: [6816, 6829, Tai_Tham]
- 124, // Range #104: [6912, 7036, Balinese]
- 63, // Range #105: [7040, 7103, Sundanese]
- 51, // Range #106: [7104, 7155, Batak]
- 3, // Range #107: [7164, 7167, Batak]
- 79, // Range #108: [7168, 7247, Lepcha]
- 47, // Range #109: [7248, 7295, Ol_Chiki]
- 8, // Range #110: [7296, 7304, Cyrillic]
- 47, // Range #111: [7312, 7359, Georgian]
- 7, // Range #112: [7360, 7367, Sundanese]
- 37, // Range #113: [7424, 7461, Latin]
- 4, // Range #114: [7462, 7466, Greek]
- 0, // Range #115: [7467, 7467, Cyrillic]
- 48, // Range #116: [7468, 7516, Latin]
- 4, // Range #117: [7517, 7521, Greek]
- 3, // Range #118: [7522, 7525, Latin]
- 4, // Range #119: [7526, 7530, Greek]
- 12, // Range #120: [7531, 7543, Latin]
- 0, // Range #121: [7544, 7544, Cyrillic]
- 69, // Range #122: [7545, 7614, Latin]
- 0, // Range #123: [7615, 7615, Greek]
- 255, // Range #124: [7680, 7935, Latin]
- 254, // Range #125: [7936, 8190, Greek]
- 0, // Range #126: [8305, 8305, Latin]
- 0, // Range #127: [8319, 8319, Latin]
- 12, // Range #128: [8336, 8348, Latin]
- 0, // Range #129: [8486, 8486, Greek]
- 1, // Range #130: [8490, 8491, Latin]
- 0, // Range #131: [8498, 8498, Latin]
- 0, // Range #132: [8526, 8526, Latin]
- 40, // Range #133: [8544, 8584, Latin]
- 255, // Range #134: [10240, 10495, Braille]
- 94, // Range #135: [11264, 11358, Glagolitic]
- 31, // Range #136: [11360, 11391, Latin]
- 115, // Range #137: [11392, 11507, Coptic]
- 6, // Range #138: [11513, 11519, Coptic]
- 39, // Range #139: [11520, 11559, Georgian]
- 0, // Range #140: [11565, 11565, Georgian]
- 55, // Range #141: [11568, 11623, Tifinagh]
- 1, // Range #142: [11631, 11632, Tifinagh]
- 0, // Range #143: [11647, 11647, Tifinagh]
- 22, // Range #144: [11648, 11670, Ethiopic]
- 62, // Range #145: [11680, 11742, Ethiopic]
- 31, // Range #146: [11744, 11775, Cyrillic]
- 115, // Range #147: [11904, 12019, Han]
- 213, // Range #148: [12032, 12245, Han]
- 0, // Range #149: [12293, 12293, Han]
- 0, // Range #150: [12295, 12295, Han]
- 8, // Range #151: [12321, 12329, Han]
- 1, // Range #152: [12334, 12335, Hangul]
- 3, // Range #153: [12344, 12347, Han]
- 85, // Range #154: [12353, 12438, Hiragana]
- 2, // Range #155: [12445, 12447, Hiragana]
- 89, // Range #156: [12449, 12538, Katakana]
- 2, // Range #157: [12541, 12543, Katakana]
- 42, // Range #158: [12549, 12591, Bopomofo]
- 93, // Range #159: [12593, 12686, Hangul]
- 26, // Range #160: [12704, 12730, Bopomofo]
- 15, // Range #161: [12784, 12799, Katakana]
- 30, // Range #162: [12800, 12830, Hangul]
- 30, // Range #163: [12896, 12926, Hangul]
- 46, // Range #164: [13008, 13054, Katakana]
- 87, // Range #165: [13056, 13143, Katakana]
- 6581, // Range #166: [13312, 19893, Han]
- 20975, // Range #167: [19968, 40943, Han]
- 1222, // Range #168: [40960, 42182, Yi]
- 47, // Range #169: [42192, 42239, Lisu]
- 299, // Range #170: [42240, 42539, Vai]
- 95, // Range #171: [42560, 42655, Cyrillic]
- 87, // Range #172: [42656, 42743, Bamum]
- 101, // Range #173: [42786, 42887, Latin]
- 59, // Range #174: [42891, 42950, Latin]
- 8, // Range #175: [42999, 43007, Latin]
- 43, // Range #176: [43008, 43051, Syloti_Nagri]
- 55, // Range #177: [43072, 43127, Phags_Pa]
- 69, // Range #178: [43136, 43205, Saurashtra]
- 11, // Range #179: [43214, 43225, Saurashtra]
- 31, // Range #180: [43232, 43263, Devanagari]
- 45, // Range #181: [43264, 43309, Kayah_Li]
- 0, // Range #182: [43311, 43311, Kayah_Li]
- 35, // Range #183: [43312, 43347, Rejang]
- 0, // Range #184: [43359, 43359, Rejang]
- 28, // Range #185: [43360, 43388, Hangul]
- 77, // Range #186: [43392, 43469, Javanese]
- 15, // Range #187: [43472, 43487, Javanese]
- 30, // Range #188: [43488, 43518, Myanmar]
- 54, // Range #189: [43520, 43574, Cham]
- 31, // Range #190: [43584, 43615, Cham]
- 31, // Range #191: [43616, 43647, Myanmar]
- 66, // Range #192: [43648, 43714, Tai_Viet]
- 4, // Range #193: [43739, 43743, Tai_Viet]
- 22, // Range #194: [43744, 43766, Meetei_Mayek]
- 21, // Range #195: [43777, 43798, Ethiopic]
- 14, // Range #196: [43808, 43822, Ethiopic]
- 42, // Range #197: [43824, 43866, Latin]
- 8, // Range #198: [43868, 43876, Latin]
- 0, // Range #199: [43877, 43877, Greek]
- 1, // Range #200: [43878, 43879, Latin]
- 79, // Range #201: [43888, 43967, Cherokee]
- 57, // Range #202: [43968, 44025, Meetei_Mayek]
- 11171, // Range #203: [44032, 55203, Hangul]
- 75, // Range #204: [55216, 55291, Hangul]
- 473, // Range #205: [63744, 64217, Han]
- 6, // Range #206: [64256, 64262, Latin]
- 4, // Range #207: [64275, 64279, Armenian]
- 50, // Range #208: [64285, 64335, Hebrew]
- 113, // Range #209: [64336, 64449, Arabic]
- 362, // Range #210: [64467, 64829, Arabic]
- 119, // Range #211: [64848, 64967, Arabic]
- 13, // Range #212: [65008, 65021, Arabic]
- 1, // Range #213: [65070, 65071, Cyrillic]
- 140, // Range #214: [65136, 65276, Arabic]
- 25, // Range #215: [65313, 65338, Latin]
- 25, // Range #216: [65345, 65370, Latin]
- 9, // Range #217: [65382, 65391, Katakana]
- 44, // Range #218: [65393, 65437, Katakana]
- 60, // Range #219: [65440, 65500, Hangul]
- 93, // Range #220: [65536, 65629, Linear_B]
- 122, // Range #221: [65664, 65786, Linear_B]
- 78, // Range #222: [65856, 65934, Greek]
- 0, // Range #223: [65952, 65952, Greek]
- 28, // Range #224: [66176, 66204, Lycian]
- 48, // Range #225: [66208, 66256, Carian]
- 35, // Range #226: [66304, 66339, Old_Italic]
- 2, // Range #227: [66349, 66351, Old_Italic]
- 26, // Range #228: [66352, 66378, Gothic]
- 42, // Range #229: [66384, 66426, Old_Permic]
- 31, // Range #230: [66432, 66463, Ugaritic]
- 53, // Range #231: [66464, 66517, Old_Persian]
- 79, // Range #232: [66560, 66639, Deseret]
- 47, // Range #233: [66640, 66687, Shavian]
- 41, // Range #234: [66688, 66729, Osmanya]
- 75, // Range #235: [66736, 66811, Osage]
- 39, // Range #236: [66816, 66855, Elbasan]
- 51, // Range #237: [66864, 66915, Caucasian_Albanian]
- 0, // Range #238: [66927, 66927, Caucasian_Albanian]
- 310, // Range #239: [67072, 67382, Linear_A]
- 21, // Range #240: [67392, 67413, Linear_A]
- 7, // Range #241: [67424, 67431, Linear_A]
- 63, // Range #242: [67584, 67647, Cypriot]
- 31, // Range #243: [67648, 67679, Imperial_Aramaic]
- 31, // Range #244: [67680, 67711, Palmyrene]
- 30, // Range #245: [67712, 67742, Nabataean]
- 8, // Range #246: [67751, 67759, Nabataean]
- 21, // Range #247: [67808, 67829, Hatran]
- 4, // Range #248: [67835, 67839, Hatran]
- 31, // Range #249: [67840, 67871, Phoenician]
- 25, // Range #250: [67872, 67897, Lydian]
- 0, // Range #251: [67903, 67903, Lydian]
- 31, // Range #252: [67968, 67999, Meroitic_Hieroglyphs]
- 95, // Range #253: [68000, 68095, Meroitic_Cursive]
- 6, // Range #254: [68096, 68102, Kharoshthi]
- 60, // Range #255: [68108, 68168, Kharoshthi]
- 8, // Range #256: [68176, 68184, Kharoshthi]
- 31, // Range #257: [68192, 68223, Old_South_Arabian]
- 31, // Range #258: [68224, 68255, Old_North_Arabian]
- 54, // Range #259: [68288, 68342, Manichaean]
- 63, // Range #260: [68352, 68415, Avestan]
- 31, // Range #261: [68416, 68447, Inscriptional_Parthian]
- 18, // Range #262: [68448, 68466, Inscriptional_Pahlavi]
- 7, // Range #263: [68472, 68479, Inscriptional_Pahlavi]
- 17, // Range #264: [68480, 68497, Psalter_Pahlavi]
- 3, // Range #265: [68505, 68508, Psalter_Pahlavi]
- 6, // Range #266: [68521, 68527, Psalter_Pahlavi]
- 72, // Range #267: [68608, 68680, Old_Turkic]
- 50, // Range #268: [68736, 68786, Old_Hungarian]
- 50, // Range #269: [68800, 68850, Old_Hungarian]
- 5, // Range #270: [68858, 68863, Old_Hungarian]
- 39, // Range #271: [68864, 68903, Hanifi_Rohingya]
- 9, // Range #272: [68912, 68921, Hanifi_Rohingya]
- 30, // Range #273: [69216, 69246, Arabic]
- 39, // Range #274: [69376, 69415, Old_Sogdian]
- 41, // Range #275: [69424, 69465, Sogdian]
- 22, // Range #276: [69600, 69622, Elymaic]
- 111, // Range #277: [69632, 69743, Brahmi]
- 0, // Range #278: [69759, 69759, Brahmi]
- 65, // Range #279: [69760, 69825, Kaithi]
- 0, // Range #280: [69837, 69837, Kaithi]
- 24, // Range #281: [69840, 69864, Sora_Sompeng]
- 9, // Range #282: [69872, 69881, Sora_Sompeng]
- 70, // Range #283: [69888, 69958, Chakma]
- 38, // Range #284: [69968, 70006, Mahajani]
- 95, // Range #285: [70016, 70111, Sharada]
- 19, // Range #286: [70113, 70132, Sinhala]
- 62, // Range #287: [70144, 70206, Khojki]
- 41, // Range #288: [70272, 70313, Multani]
- 58, // Range #289: [70320, 70378, Khudawadi]
- 9, // Range #290: [70384, 70393, Khudawadi]
- 57, // Range #291: [70400, 70457, Grantha]
- 20, // Range #292: [70460, 70480, Grantha]
- 0, // Range #293: [70487, 70487, Grantha]
- 23, // Range #294: [70493, 70516, Grantha]
- 95, // Range #295: [70656, 70751, Newa]
- 71, // Range #296: [70784, 70855, Tirhuta]
- 9, // Range #297: [70864, 70873, Tirhuta]
- 93, // Range #298: [71040, 71133, Siddham]
- 68, // Range #299: [71168, 71236, Modi]
- 9, // Range #300: [71248, 71257, Modi]
- 12, // Range #301: [71264, 71276, Mongolian]
- 56, // Range #302: [71296, 71352, Takri]
- 9, // Range #303: [71360, 71369, Takri]
- 63, // Range #304: [71424, 71487, Ahom]
- 59, // Range #305: [71680, 71739, Dogra]
- 82, // Range #306: [71840, 71922, Warang_Citi]
- 0, // Range #307: [71935, 71935, Warang_Citi]
- 68, // Range #308: [72096, 72164, Nandinagari]
- 71, // Range #309: [72192, 72263, Zanabazar_Square]
- 82, // Range #310: [72272, 72354, Soyombo]
- 56, // Range #311: [72384, 72440, Pau_Cin_Hau]
- 69, // Range #312: [72704, 72773, Bhaiksuki]
- 28, // Range #313: [72784, 72812, Bhaiksuki]
- 70, // Range #314: [72816, 72886, Marchen]
- 71, // Range #315: [72960, 73031, Masaram_Gondi]
- 9, // Range #316: [73040, 73049, Masaram_Gondi]
- 56, // Range #317: [73056, 73112, Gunjala_Gondi]
- 9, // Range #318: [73120, 73129, Gunjala_Gondi]
- 24, // Range #319: [73440, 73464, Makasar]
- 49, // Range #320: [73664, 73713, Tamil]
- 0, // Range #321: [73727, 73727, Tamil]
- 921, // Range #322: [73728, 74649, Cuneiform]
- 116, // Range #323: [74752, 74868, Cuneiform]
- 195, // Range #324: [74880, 75075, Cuneiform]
- 1080, // Range #325: [77824, 78904, Egyptian_Hieroglyphs]
- 582, // Range #326: [82944, 83526, Anatolian_Hieroglyphs]
- 568, // Range #327: [92160, 92728, Bamum]
- 47, // Range #328: [92736, 92783, Mro]
- 37, // Range #329: [92880, 92917, Bassa_Vah]
- 69, // Range #330: [92928, 92997, Pahawh_Hmong]
- 39, // Range #331: [93008, 93047, Pahawh_Hmong]
- 18, // Range #332: [93053, 93071, Pahawh_Hmong]
- 90, // Range #333: [93760, 93850, Medefaidrin]
- 135, // Range #334: [93952, 94087, Miao]
- 16, // Range #335: [94095, 94111, Miao]
- 0, // Range #336: [94176, 94176, Tangut]
- 0, // Range #337: [94177, 94177, Nushu]
- 6135, // Range #338: [94208, 100343, Tangut]
- 754, // Range #339: [100352, 101106, Tangut]
- 0, // Range #340: [110592, 110592, Katakana]
- 285, // Range #341: [110593, 110878, Hiragana]
- 2, // Range #342: [110928, 110930, Hiragana]
- 3, // Range #343: [110948, 110951, Katakana]
- 395, // Range #344: [110960, 111355, Nushu]
- 106, // Range #345: [113664, 113770, Duployan]
- 24, // Range #346: [113776, 113800, Duployan]
- 15, // Range #347: [113808, 113823, Duployan]
- 69, // Range #348: [119296, 119365, Greek]
- 651, // Range #349: [120832, 121483, SignWriting]
- 20, // Range #350: [121499, 121519, SignWriting]
- 42, // Range #351: [122880, 122922, Glagolitic]
- 79, // Range #352: [123136, 123215, Nyiakeng_Puachue_Hmong]
- 57, // Range #353: [123584, 123641, Wancho]
- 0, // Range #354: [123647, 123647, Wancho]
- 214, // Range #355: [124928, 125142, Mende_Kikakui]
- 95, // Range #356: [125184, 125279, Adlam]
- 59, // Range #357: [126464, 126523, Arabic]
- 89, // Range #358: [126530, 126619, Arabic]
- 26, // Range #359: [126625, 126651, Arabic]
- 1, // Range #360: [126704, 126705, Arabic]
- 0, // Range #361: [127488, 127488, Hiragana]
- 42710, // Range #362: [131072, 173782, Han]
- 4148, // Range #363: [173824, 177972, Han]
- 5985, // Range #364: [177984, 183969, Han]
- 7472, // Range #365: [183984, 191456, Han]
- 541, // Range #366: [194560, 195101, Han]
-};
-
-const uint8 kRangeScript[] = {
- 25, // Range #0: [65, 90, Latin]
- 25, // Range #1: [97, 122, Latin]
- 25, // Range #2: [170, 170, Latin]
- 25, // Range #3: [186, 186, Latin]
- 25, // Range #4: [192, 214, Latin]
- 25, // Range #5: [216, 246, Latin]
- 25, // Range #6: [248, 696, Latin]
- 25, // Range #7: [736, 740, Latin]
- 5, // Range #8: [746, 747, Bopomofo]
- 14, // Range #9: [880, 883, Greek]
- 14, // Range #10: [885, 893, Greek]
- 14, // Range #11: [895, 900, Greek]
- 14, // Range #12: [902, 902, Greek]
- 14, // Range #13: [904, 993, Greek]
- 7, // Range #14: [994, 1007, Coptic]
- 14, // Range #15: [1008, 1023, Greek]
- 8, // Range #16: [1024, 1156, Cyrillic]
- 8, // Range #17: [1159, 1327, Cyrillic]
- 3, // Range #18: [1329, 1416, Armenian]
- 3, // Range #19: [1418, 1423, Armenian]
- 19, // Range #20: [1425, 1479, Hebrew]
- 19, // Range #21: [1488, 1524, Hebrew]
- 2, // Range #22: [1536, 1540, Arabic]
- 2, // Range #23: [1542, 1547, Arabic]
- 2, // Range #24: [1549, 1562, Arabic]
- 2, // Range #25: [1564, 1566, Arabic]
- 2, // Range #26: [1568, 1599, Arabic]
- 2, // Range #27: [1601, 1610, Arabic]
- 2, // Range #28: [1622, 1647, Arabic]
- 2, // Range #29: [1649, 1756, Arabic]
- 2, // Range #30: [1758, 1791, Arabic]
- 34, // Range #31: [1792, 1871, Syriac]
- 2, // Range #32: [1872, 1919, Arabic]
- 37, // Range #33: [1920, 1969, Thaana]
- 87, // Range #34: [1984, 2047, Nko]
- 126, // Range #35: [2048, 2110, Samaritan]
- 84, // Range #36: [2112, 2142, Mandaic]
- 34, // Range #37: [2144, 2154, Syriac]
- 2, // Range #38: [2208, 2237, Arabic]
- 2, // Range #39: [2259, 2273, Arabic]
- 2, // Range #40: [2275, 2303, Arabic]
- 10, // Range #41: [2304, 2384, Devanagari]
- 10, // Range #42: [2389, 2403, Devanagari]
- 10, // Range #43: [2406, 2431, Devanagari]
- 4, // Range #44: [2432, 2510, Bengali]
- 4, // Range #45: [2519, 2558, Bengali]
- 16, // Range #46: [2561, 2641, Gurmukhi]
- 16, // Range #47: [2649, 2654, Gurmukhi]
- 16, // Range #48: [2662, 2678, Gurmukhi]
- 15, // Range #49: [2689, 2768, Gujarati]
- 15, // Range #50: [2784, 2801, Gujarati]
- 15, // Range #51: [2809, 2815, Gujarati]
- 31, // Range #52: [2817, 2893, Oriya]
- 31, // Range #53: [2902, 2935, Oriya]
- 35, // Range #54: [2946, 3024, Tamil]
- 35, // Range #55: [3031, 3031, Tamil]
- 35, // Range #56: [3046, 3066, Tamil]
- 36, // Range #57: [3072, 3149, Telugu]
- 36, // Range #58: [3157, 3162, Telugu]
- 36, // Range #59: [3168, 3183, Telugu]
- 36, // Range #60: [3191, 3199, Telugu]
- 21, // Range #61: [3200, 3277, Kannada]
- 21, // Range #62: [3285, 3286, Kannada]
- 21, // Range #63: [3294, 3314, Kannada]
- 26, // Range #64: [3328, 3455, Malayalam]
- 33, // Range #65: [3458, 3551, Sinhala]
- 33, // Range #66: [3558, 3572, Sinhala]
- 38, // Range #67: [3585, 3642, Thai]
- 38, // Range #68: [3648, 3675, Thai]
- 24, // Range #69: [3713, 3807, Lao]
- 39, // Range #70: [3840, 4052, Tibetan]
- 39, // Range #71: [4057, 4058, Tibetan]
- 28, // Range #72: [4096, 4255, Myanmar]
- 12, // Range #73: [4256, 4295, Georgian]
- 12, // Range #74: [4301, 4346, Georgian]
- 12, // Range #75: [4348, 4351, Georgian]
- 18, // Range #76: [4352, 4607, Hangul]
- 11, // Range #77: [4608, 5017, Ethiopic]
- 6, // Range #78: [5024, 5117, Cherokee]
- 40, // Range #79: [5120, 5759, Canadian_Aboriginal]
- 29, // Range #80: [5760, 5788, Ogham]
- 32, // Range #81: [5792, 5866, Runic]
- 32, // Range #82: [5870, 5880, Runic]
- 42, // Range #83: [5888, 5908, Tagalog]
- 43, // Range #84: [5920, 5940, Hanunoo]
- 44, // Range #85: [5952, 5971, Buhid]
- 45, // Range #86: [5984, 6003, Tagbanwa]
- 23, // Range #87: [6016, 6121, Khmer]
- 23, // Range #88: [6128, 6137, Khmer]
- 27, // Range #89: [6144, 6145, Mongolian]
- 27, // Range #90: [6148, 6148, Mongolian]
- 27, // Range #91: [6150, 6169, Mongolian]
- 27, // Range #92: [6176, 6264, Mongolian]
- 27, // Range #93: [6272, 6314, Mongolian]
- 40, // Range #94: [6320, 6389, Canadian_Aboriginal]
- 48, // Range #95: [6400, 6479, Limbu]
- 52, // Range #96: [6480, 6516, Tai_Le]
- 59, // Range #97: [6528, 6601, New_Tai_Lue]
- 59, // Range #98: [6608, 6623, New_Tai_Lue]
- 23, // Range #99: [6624, 6655, Khmer]
- 55, // Range #100: [6656, 6687, Buginese]
- 106, // Range #101: [6688, 6793, Tai_Tham]
- 106, // Range #102: [6800, 6809, Tai_Tham]
- 106, // Range #103: [6816, 6829, Tai_Tham]
- 62, // Range #104: [6912, 7036, Balinese]
- 113, // Range #105: [7040, 7103, Sundanese]
- 63, // Range #106: [7104, 7155, Batak]
- 63, // Range #107: [7164, 7167, Batak]
- 82, // Range #108: [7168, 7247, Lepcha]
- 109, // Range #109: [7248, 7295, Ol_Chiki]
- 8, // Range #110: [7296, 7304, Cyrillic]
- 12, // Range #111: [7312, 7359, Georgian]
- 113, // Range #112: [7360, 7367, Sundanese]
- 25, // Range #113: [7424, 7461, Latin]
- 14, // Range #114: [7462, 7466, Greek]
- 8, // Range #115: [7467, 7467, Cyrillic]
- 25, // Range #116: [7468, 7516, Latin]
- 14, // Range #117: [7517, 7521, Greek]
- 25, // Range #118: [7522, 7525, Latin]
- 14, // Range #119: [7526, 7530, Greek]
- 25, // Range #120: [7531, 7543, Latin]
- 8, // Range #121: [7544, 7544, Cyrillic]
- 25, // Range #122: [7545, 7614, Latin]
- 14, // Range #123: [7615, 7615, Greek]
- 25, // Range #124: [7680, 7935, Latin]
- 14, // Range #125: [7936, 8190, Greek]
- 25, // Range #126: [8305, 8305, Latin]
- 25, // Range #127: [8319, 8319, Latin]
- 25, // Range #128: [8336, 8348, Latin]
- 14, // Range #129: [8486, 8486, Greek]
- 25, // Range #130: [8490, 8491, Latin]
- 25, // Range #131: [8498, 8498, Latin]
- 25, // Range #132: [8526, 8526, Latin]
- 25, // Range #133: [8544, 8584, Latin]
- 46, // Range #134: [10240, 10495, Braille]
- 56, // Range #135: [11264, 11358, Glagolitic]
- 25, // Range #136: [11360, 11391, Latin]
- 7, // Range #137: [11392, 11507, Coptic]
- 7, // Range #138: [11513, 11519, Coptic]
- 12, // Range #139: [11520, 11559, Georgian]
- 12, // Range #140: [11565, 11565, Georgian]
- 60, // Range #141: [11568, 11623, Tifinagh]
- 60, // Range #142: [11631, 11632, Tifinagh]
- 60, // Range #143: [11647, 11647, Tifinagh]
- 11, // Range #144: [11648, 11670, Ethiopic]
- 11, // Range #145: [11680, 11742, Ethiopic]
- 8, // Range #146: [11744, 11775, Cyrillic]
- 17, // Range #147: [11904, 12019, Han]
- 17, // Range #148: [12032, 12245, Han]
- 17, // Range #149: [12293, 12293, Han]
- 17, // Range #150: [12295, 12295, Han]
- 17, // Range #151: [12321, 12329, Han]
- 18, // Range #152: [12334, 12335, Hangul]
- 17, // Range #153: [12344, 12347, Han]
- 20, // Range #154: [12353, 12438, Hiragana]
- 20, // Range #155: [12445, 12447, Hiragana]
- 22, // Range #156: [12449, 12538, Katakana]
- 22, // Range #157: [12541, 12543, Katakana]
- 5, // Range #158: [12549, 12591, Bopomofo]
- 18, // Range #159: [12593, 12686, Hangul]
- 5, // Range #160: [12704, 12730, Bopomofo]
- 22, // Range #161: [12784, 12799, Katakana]
- 18, // Range #162: [12800, 12830, Hangul]
- 18, // Range #163: [12896, 12926, Hangul]
- 22, // Range #164: [13008, 13054, Katakana]
- 22, // Range #165: [13056, 13143, Katakana]
- 17, // Range #166: [13312, 19893, Han]
- 17, // Range #167: [19968, 40943, Han]
- 41, // Range #168: [40960, 42182, Yi]
- 131, // Range #169: [42192, 42239, Lisu]
- 99, // Range #170: [42240, 42539, Vai]
- 8, // Range #171: [42560, 42655, Cyrillic]
- 130, // Range #172: [42656, 42743, Bamum]
- 25, // Range #173: [42786, 42887, Latin]
- 25, // Range #174: [42891, 42950, Latin]
- 25, // Range #175: [42999, 43007, Latin]
- 58, // Range #176: [43008, 43051, Syloti_Nagri]
- 90, // Range #177: [43072, 43127, Phags_Pa]
- 111, // Range #178: [43136, 43205, Saurashtra]
- 111, // Range #179: [43214, 43225, Saurashtra]
- 10, // Range #180: [43232, 43263, Devanagari]
- 79, // Range #181: [43264, 43309, Kayah_Li]
- 79, // Range #182: [43311, 43311, Kayah_Li]
- 110, // Range #183: [43312, 43347, Rejang]
- 110, // Range #184: [43359, 43359, Rejang]
- 18, // Range #185: [43360, 43388, Hangul]
- 78, // Range #186: [43392, 43469, Javanese]
- 78, // Range #187: [43472, 43487, Javanese]
- 28, // Range #188: [43488, 43518, Myanmar]
- 66, // Range #189: [43520, 43574, Cham]
- 66, // Range #190: [43584, 43615, Cham]
- 28, // Range #191: [43616, 43647, Myanmar]
- 127, // Range #192: [43648, 43714, Tai_Viet]
- 127, // Range #193: [43739, 43743, Tai_Viet]
- 115, // Range #194: [43744, 43766, Meetei_Mayek]
- 11, // Range #195: [43777, 43798, Ethiopic]
- 11, // Range #196: [43808, 43822, Ethiopic]
- 25, // Range #197: [43824, 43866, Latin]
- 25, // Range #198: [43868, 43876, Latin]
- 14, // Range #199: [43877, 43877, Greek]
- 25, // Range #200: [43878, 43879, Latin]
- 6, // Range #201: [43888, 43967, Cherokee]
- 115, // Range #202: [43968, 44025, Meetei_Mayek]
- 18, // Range #203: [44032, 55203, Hangul]
- 18, // Range #204: [55216, 55291, Hangul]
- 17, // Range #205: [63744, 64217, Han]
- 25, // Range #206: [64256, 64262, Latin]
- 3, // Range #207: [64275, 64279, Armenian]
- 19, // Range #208: [64285, 64335, Hebrew]
- 2, // Range #209: [64336, 64449, Arabic]
- 2, // Range #210: [64467, 64829, Arabic]
- 2, // Range #211: [64848, 64967, Arabic]
- 2, // Range #212: [65008, 65021, Arabic]
- 8, // Range #213: [65070, 65071, Cyrillic]
- 2, // Range #214: [65136, 65276, Arabic]
- 25, // Range #215: [65313, 65338, Latin]
- 25, // Range #216: [65345, 65370, Latin]
- 22, // Range #217: [65382, 65391, Katakana]
- 22, // Range #218: [65393, 65437, Katakana]
- 18, // Range #219: [65440, 65500, Hangul]
- 49, // Range #220: [65536, 65629, Linear_B]
- 49, // Range #221: [65664, 65786, Linear_B]
- 14, // Range #222: [65856, 65934, Greek]
- 14, // Range #223: [65952, 65952, Greek]
- 107, // Range #224: [66176, 66204, Lycian]
- 104, // Range #225: [66208, 66256, Carian]
- 30, // Range #226: [66304, 66339, Old_Italic]
- 30, // Range #227: [66349, 66351, Old_Italic]
- 13, // Range #228: [66352, 66378, Gothic]
- 89, // Range #229: [66384, 66426, Old_Permic]
- 53, // Range #230: [66432, 66463, Ugaritic]
- 61, // Range #231: [66464, 66517, Old_Persian]
- 9, // Range #232: [66560, 66639, Deseret]
- 51, // Range #233: [66640, 66687, Shavian]
- 50, // Range #234: [66688, 66729, Osmanya]
- 171, // Range #235: [66736, 66811, Osage]
- 136, // Range #236: [66816, 66855, Elbasan]
- 159, // Range #237: [66864, 66915, Caucasian_Albanian]
- 159, // Range #238: [66927, 66927, Caucasian_Albanian]
- 83, // Range #239: [67072, 67382, Linear_A]
- 83, // Range #240: [67392, 67413, Linear_A]
- 83, // Range #241: [67424, 67431, Linear_A]
- 47, // Range #242: [67584, 67647, Cypriot]
- 116, // Range #243: [67648, 67679, Imperial_Aramaic]
- 144, // Range #244: [67680, 67711, Palmyrene]
- 143, // Range #245: [67712, 67742, Nabataean]
- 143, // Range #246: [67751, 67759, Nabataean]
- 162, // Range #247: [67808, 67829, Hatran]
- 162, // Range #248: [67835, 67839, Hatran]
- 91, // Range #249: [67840, 67871, Phoenician]
- 108, // Range #250: [67872, 67897, Lydian]
- 108, // Range #251: [67903, 67903, Lydian]
- 86, // Range #252: [67968, 67999, Meroitic_Hieroglyphs]
- 141, // Range #253: [68000, 68095, Meroitic_Cursive]
- 57, // Range #254: [68096, 68102, Kharoshthi]
- 57, // Range #255: [68108, 68168, Kharoshthi]
- 57, // Range #256: [68176, 68184, Kharoshthi]
- 133, // Range #257: [68192, 68223, Old_South_Arabian]
- 142, // Range #258: [68224, 68255, Old_North_Arabian]
- 121, // Range #259: [68288, 68342, Manichaean]
- 117, // Range #260: [68352, 68415, Avestan]
- 125, // Range #261: [68416, 68447, Inscriptional_Parthian]
- 122, // Range #262: [68448, 68466, Inscriptional_Pahlavi]
- 122, // Range #263: [68472, 68479, Inscriptional_Pahlavi]
- 123, // Range #264: [68480, 68497, Psalter_Pahlavi]
- 123, // Range #265: [68505, 68508, Psalter_Pahlavi]
- 123, // Range #266: [68521, 68527, Psalter_Pahlavi]
- 88, // Range #267: [68608, 68680, Old_Turkic]
- 76, // Range #268: [68736, 68786, Old_Hungarian]
- 76, // Range #269: [68800, 68850, Old_Hungarian]
- 76, // Range #270: [68858, 68863, Old_Hungarian]
- 182, // Range #271: [68864, 68903, Hanifi_Rohingya]
- 182, // Range #272: [68912, 68921, Hanifi_Rohingya]
- 2, // Range #273: [69216, 69246, Arabic]
- 184, // Range #274: [69376, 69415, Old_Sogdian]
- 183, // Range #275: [69424, 69465, Sogdian]
- 185, // Range #276: [69600, 69622, Elymaic]
- 65, // Range #277: [69632, 69743, Brahmi]
- 65, // Range #278: [69759, 69759, Brahmi]
- 120, // Range #279: [69760, 69825, Kaithi]
- 120, // Range #280: [69837, 69837, Kaithi]
- 152, // Range #281: [69840, 69864, Sora_Sompeng]
- 152, // Range #282: [69872, 69881, Sora_Sompeng]
- 118, // Range #283: [69888, 69958, Chakma]
- 160, // Range #284: [69968, 70006, Mahajani]
- 151, // Range #285: [70016, 70111, Sharada]
- 33, // Range #286: [70113, 70132, Sinhala]
- 157, // Range #287: [70144, 70206, Khojki]
- 164, // Range #288: [70272, 70313, Multani]
- 145, // Range #289: [70320, 70378, Khudawadi]
- 145, // Range #290: [70384, 70393, Khudawadi]
- 137, // Range #291: [70400, 70457, Grantha]
- 137, // Range #292: [70460, 70480, Grantha]
- 137, // Range #293: [70487, 70487, Grantha]
- 137, // Range #294: [70493, 70516, Grantha]
- 170, // Range #295: [70656, 70751, Newa]
- 158, // Range #296: [70784, 70855, Tirhuta]
- 158, // Range #297: [70864, 70873, Tirhuta]
- 166, // Range #298: [71040, 71133, Siddham]
- 163, // Range #299: [71168, 71236, Modi]
- 163, // Range #300: [71248, 71257, Modi]
- 27, // Range #301: [71264, 71276, Mongolian]
- 153, // Range #302: [71296, 71352, Takri]
- 153, // Range #303: [71360, 71369, Takri]
- 161, // Range #304: [71424, 71487, Ahom]
- 178, // Range #305: [71680, 71739, Dogra]
- 146, // Range #306: [71840, 71922, Warang_Citi]
- 146, // Range #307: [71935, 71935, Warang_Citi]
- 187, // Range #308: [72096, 72164, Nandinagari]
- 177, // Range #309: [72192, 72263, Zanabazar_Square]
- 176, // Range #310: [72272, 72354, Soyombo]
- 165, // Range #311: [72384, 72440, Pau_Cin_Hau]
- 168, // Range #312: [72704, 72773, Bhaiksuki]
- 168, // Range #313: [72784, 72812, Bhaiksuki]
- 169, // Range #314: [72816, 72886, Marchen]
- 175, // Range #315: [72960, 73031, Masaram_Gondi]
- 175, // Range #316: [73040, 73049, Masaram_Gondi]
- 179, // Range #317: [73056, 73112, Gunjala_Gondi]
- 179, // Range #318: [73120, 73129, Gunjala_Gondi]
- 180, // Range #319: [73440, 73464, Makasar]
- 35, // Range #320: [73664, 73713, Tamil]
- 35, // Range #321: [73727, 73727, Tamil]
- 101, // Range #322: [73728, 74649, Cuneiform]
- 101, // Range #323: [74752, 74868, Cuneiform]
- 101, // Range #324: [74880, 75075, Cuneiform]
- 71, // Range #325: [77824, 78904, Egyptian_Hieroglyphs]
- 156, // Range #326: [82944, 83526, Anatolian_Hieroglyphs]
- 130, // Range #327: [92160, 92728, Bamum]
- 149, // Range #328: [92736, 92783, Mro]
- 134, // Range #329: [92880, 92917, Bassa_Vah]
- 75, // Range #330: [92928, 92997, Pahawh_Hmong]
- 75, // Range #331: [93008, 93047, Pahawh_Hmong]
- 75, // Range #332: [93053, 93071, Pahawh_Hmong]
- 181, // Range #333: [93760, 93850, Medefaidrin]
- 92, // Range #334: [93952, 94087, Miao]
- 92, // Range #335: [94095, 94111, Miao]
- 154, // Range #336: [94176, 94176, Tangut]
- 150, // Range #337: [94177, 94177, Nushu]
- 154, // Range #338: [94208, 100343, Tangut]
- 154, // Range #339: [100352, 101106, Tangut]
- 22, // Range #340: [110592, 110592, Katakana]
- 20, // Range #341: [110593, 110878, Hiragana]
- 20, // Range #342: [110928, 110930, Hiragana]
- 22, // Range #343: [110948, 110951, Katakana]
- 150, // Range #344: [110960, 111355, Nushu]
- 135, // Range #345: [113664, 113770, Duployan]
- 135, // Range #346: [113776, 113800, Duployan]
- 135, // Range #347: [113808, 113823, Duployan]
- 14, // Range #348: [119296, 119365, Greek]
- 112, // Range #349: [120832, 121483, SignWriting]
- 112, // Range #350: [121499, 121519, SignWriting]
- 56, // Range #351: [122880, 122922, Glagolitic]
- 186, // Range #352: [123136, 123215, Nyiakeng_Puachue_Hmong]
- 188, // Range #353: [123584, 123641, Wancho]
- 188, // Range #354: [123647, 123647, Wancho]
- 140, // Range #355: [124928, 125142, Mende_Kikakui]
- 167, // Range #356: [125184, 125279, Adlam]
- 2, // Range #357: [126464, 126523, Arabic]
- 2, // Range #358: [126530, 126619, Arabic]
- 2, // Range #359: [126625, 126651, Arabic]
- 2, // Range #360: [126704, 126705, Arabic]
- 20, // Range #361: [127488, 127488, Hiragana]
- 17, // Range #362: [131072, 173782, Han]
- 17, // Range #363: [173824, 177972, Han]
- 17, // Range #364: [177984, 183969, Han]
- 17, // Range #365: [183984, 191456, Han]
- 17, // Range #366: [194560, 195101, Han]
-};
-
-const uint8 kMaxScript = 188;
-
-} // namespace approx_script_internal
-} // namespace mobile
-} // namespace nlp_saft
diff --git a/models/actions_suggestions.en.model b/models/actions_suggestions.en.model
deleted file mode 100644
index 6cec2b7..0000000
--- a/models/actions_suggestions.en.model
+++ /dev/null
Binary files differ
diff --git a/models/actions_suggestions.universal.model b/models/actions_suggestions.universal.model
deleted file mode 100644
index 60f10e6..0000000
--- a/models/actions_suggestions.universal.model
+++ /dev/null
Binary files differ
diff --git a/models/lang_id.model b/models/lang_id.model
deleted file mode 100644
index 49b4b07..0000000
--- a/models/lang_id.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.ar.model b/models/textclassifier.ar.model
deleted file mode 100644
index 9d8e2eb..0000000
--- a/models/textclassifier.ar.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.en.model b/models/textclassifier.en.model
deleted file mode 100644
index 917db91..0000000
--- a/models/textclassifier.en.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.es.model b/models/textclassifier.es.model
deleted file mode 100644
index 94b7835..0000000
--- a/models/textclassifier.es.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.fr.model b/models/textclassifier.fr.model
deleted file mode 100644
index 19081e5..0000000
--- a/models/textclassifier.fr.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.it.model b/models/textclassifier.it.model
deleted file mode 100644
index 2f72c36..0000000
--- a/models/textclassifier.it.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.ja.model b/models/textclassifier.ja.model
deleted file mode 100644
index 92d7cef..0000000
--- a/models/textclassifier.ja.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.ko.model b/models/textclassifier.ko.model
deleted file mode 100644
index 7e88f54..0000000
--- a/models/textclassifier.ko.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.nl.model b/models/textclassifier.nl.model
deleted file mode 100644
index b2e3923..0000000
--- a/models/textclassifier.nl.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.pl.model b/models/textclassifier.pl.model
deleted file mode 100644
index 7231c49..0000000
--- a/models/textclassifier.pl.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.pt.model b/models/textclassifier.pt.model
deleted file mode 100644
index cae8692..0000000
--- a/models/textclassifier.pt.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.ru.model b/models/textclassifier.ru.model
deleted file mode 100644
index 5be2ecc..0000000
--- a/models/textclassifier.ru.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.th.model b/models/textclassifier.th.model
deleted file mode 100644
index 321edd7..0000000
--- a/models/textclassifier.th.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.tr.model b/models/textclassifier.tr.model
deleted file mode 100644
index 6d11cef..0000000
--- a/models/textclassifier.tr.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.universal.model b/models/textclassifier.universal.model
deleted file mode 100644
index af19e67..0000000
--- a/models/textclassifier.universal.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.zh-Hant.model b/models/textclassifier.zh-Hant.model
deleted file mode 100644
index 366c923..0000000
--- a/models/textclassifier.zh-Hant.model
+++ /dev/null
Binary files differ
diff --git a/models/textclassifier.zh.model b/models/textclassifier.zh.model
deleted file mode 100644
index 22f2777..0000000
--- a/models/textclassifier.zh.model
+++ /dev/null
Binary files differ
diff --git a/native/Android.bp b/native/Android.bp
new file mode 100644
index 0000000..ebbd423
--- /dev/null
+++ b/native/Android.bp
@@ -0,0 +1,394 @@
+// Copyright (C) 2017 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+cc_library_headers {
+ name: "libtextclassifier_hash_headers",
+ vendor_available: true,
+ export_include_dirs: ["."],
+ apex_available: [
+ "//apex_available:platform",
+ "com.android.neuralnetworks",
+ "test_com.android.neuralnetworks",
+ ],
+}
+
+cc_defaults {
+ name: "libtextclassifier_hash_defaults",
+ srcs: [
+ "utils/hash/farmhash.cc",
+ "util/hash/hash.cc",
+ ],
+ cflags: [
+ "-DNAMESPACE_FOR_HASH_FUNCTIONS=farmhash",
+ "-Wall",
+ "-Werror",
+ "-Wno-unused-function",
+ ],
+}
+
+cc_library_shared {
+ name: "libtextclassifier_hash",
+ defaults: ["libtextclassifier_hash_defaults"],
+ vendor_available: true,
+ double_loadable: true,
+}
+
+cc_library_static {
+ name: "libtextclassifier_hash_static",
+ defaults: ["libtextclassifier_hash_defaults"],
+ vendor_available: true,
+ sdk_version: "current",
+ stl: "libc++_static",
+ apex_available: [
+ "//apex_available:platform",
+ "com.android.neuralnetworks",
+ "test_com.android.neuralnetworks",
+ "com.android.extservices",
+ ],
+}
+
+cc_defaults {
+ name: "libtextclassifier_defaults",
+ stl: "libc++_static",
+ sdk_version: "current",
+ // For debug / treemap purposes.
+ //strip: {
+ // keep_symbols: true,
+ //},
+
+ cflags: [
+ "-Wall",
+ "-Werror",
+ "-Wno-deprecated-declarations",
+ "-Wno-ignored-qualifiers",
+ "-Wno-missing-field-initializers",
+ "-Wno-sign-compare",
+ "-Wno-tautological-constant-out-of-range-compare",
+ "-Wno-undefined-var-template",
+ "-Wno-unused-function",
+ "-Wno-unused-parameter",
+ "-Wno-extern-c-compat",
+
+ "-funsigned-char",
+ "-fvisibility=hidden",
+ "-DLIBTEXTCLASSIFIER_UNILIB_ICU",
+ "-DZLIB_CONST",
+ "-DSAFTM_COMPACT_LOGGING",
+ "-DTC3_WITH_ACTIONS_OPS",
+ "-DTC3_UNILIB_JAVAICU",
+ "-DTC3_CALENDAR_JAVAICU",
+ "-DTC3_AOSP"
+ ],
+
+ product_variables: {
+ debuggable: {
+ // Only enable debug logging in userdebug/eng builds.
+ cflags: ["-DTC3_DEBUG_LOGGING=1"],
+ },
+ },
+
+ generated_headers: [
+ "libtextclassifier_fbgen_flatbuffers",
+ "libtextclassifier_fbgen_tokenizer",
+ "libtextclassifier_fbgen_codepoint_range",
+ "libtextclassifier_fbgen_entity-data",
+ "libtextclassifier_fbgen_zlib_buffer",
+ "libtextclassifier_fbgen_resources_extra",
+ "libtextclassifier_fbgen_intent_config",
+ "libtextclassifier_fbgen_annotator_model",
+ "libtextclassifier_fbgen_annotator_experimental_model",
+ "libtextclassifier_fbgen_actions_model",
+ "libtextclassifier_fbgen_tflite_text_encoder_config",
+ "libtextclassifier_fbgen_lang_id_embedded_network",
+ "libtextclassifier_fbgen_lang_id_model",
+ "libtextclassifier_fbgen_actions-entity-data",
+ "libtextclassifier_fbgen_normalization",
+ "libtextclassifier_fbgen_language-tag",
+ "libtextclassifier_fbgen_person_name_model",
+ "libtextclassifier_fbgen_grammar_dates",
+ "libtextclassifier_fbgen_timezone_code",
+ "libtextclassifier_fbgen_grammar_rules"
+ ],
+
+ header_libs: [
+ "tensorflow_headers",
+ "flatbuffer_headers",
+ ],
+
+ shared_libs: [
+ "liblog",
+ "libz",
+ ],
+
+ static_libs: [
+ "liblua",
+ "libutf",
+ "libtflite_static",
+ ],
+}
+
+// -----------------
+// Generate headers with FlatBuffer schema compiler.
+// -----------------
+genrule_defaults {
+ name: "fbgen",
+ tools: ["flatc"],
+ // "depfile" is used here in conjunction with flatc's -M to gather the deps
+ cmd: "$(location flatc) --cpp --no-union-value-namespacing --gen-object-api --keep-prefix -I external/libtextclassifier/native -M $(in) >$(depfile) && " +
+ "$(location flatc) --cpp --no-union-value-namespacing --gen-object-api --keep-prefix -I external/libtextclassifier/native -o $$(dirname $(out)) $(in)",
+ depfile: true,
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_flatbuffers",
+ srcs: ["utils/flatbuffers.fbs"],
+ out: ["utils/flatbuffers_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_tokenizer",
+ srcs: ["utils/tokenizer.fbs"],
+ out: ["utils/tokenizer_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_codepoint_range",
+ srcs: ["utils/codepoint-range.fbs"],
+ out: ["utils/codepoint-range_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_resources_extra",
+ srcs: ["utils/resources.fbs"],
+ out: ["utils/resources_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_entity-data",
+ srcs: ["annotator/entity-data.fbs"],
+ out: ["annotator/entity-data_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_zlib_buffer",
+ srcs: ["utils/zlib/buffer.fbs"],
+ out: ["utils/zlib/buffer_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_intent_config",
+ srcs: ["utils/intents/intent-config.fbs"],
+ out: ["utils/intents/intent-config_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_annotator_model",
+ srcs: ["annotator/model.fbs"],
+ out: ["annotator/model_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_annotator_experimental_model",
+ srcs: ["annotator/experimental/experimental.fbs"],
+ out: ["annotator/experimental/experimental_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_actions_model",
+ srcs: ["actions/actions_model.fbs"],
+ out: ["actions/actions_model_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_tflite_text_encoder_config",
+ srcs: ["utils/tflite/text_encoder_config.fbs"],
+ out: ["utils/tflite/text_encoder_config_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_lang_id_embedded_network",
+ srcs: ["lang_id/common/flatbuffers/embedding-network.fbs"],
+ out: ["lang_id/common/flatbuffers/embedding-network_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_lang_id_model",
+ srcs: ["lang_id/common/flatbuffers/model.fbs"],
+ out: ["lang_id/common/flatbuffers/model_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_actions-entity-data",
+ srcs: ["actions/actions-entity-data.fbs"],
+ out: ["actions/actions-entity-data_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_normalization",
+ srcs: ["utils/normalization.fbs"],
+ out: ["utils/normalization_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_language-tag",
+ srcs: ["utils/i18n/language-tag.fbs"],
+ out: ["utils/i18n/language-tag_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_person_name_model",
+ srcs: ["annotator/person_name/person_name_model.fbs"],
+ out: ["annotator/person_name/person_name_model_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_grammar_dates",
+ srcs: ["annotator/grammar/dates/dates.fbs"],
+ out: ["annotator/grammar/dates/dates_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_timezone_code",
+ srcs: ["annotator/grammar/dates/timezone-code.fbs"],
+ out: ["annotator/grammar/dates/timezone-code_generated.h"],
+ defaults: ["fbgen"],
+}
+
+genrule {
+ name: "libtextclassifier_fbgen_grammar_rules",
+ srcs: ["utils/grammar/rules.fbs"],
+ out: ["utils/grammar/rules_generated.h"],
+ defaults: ["fbgen"],
+}
+
+// -----------------
+// libtextclassifier
+// -----------------
+cc_library_shared {
+ name: "libtextclassifier",
+ defaults: ["libtextclassifier_defaults"],
+
+ srcs: ["**/*.cc"],
+ exclude_srcs: [
+ "**/*_test.cc",
+ "**/*-test-lib.cc",
+ "**/testing/*.cc",
+ "**/*test-util.*",
+ "**/*test-utils.*",
+ "**/*_test-include.*",
+ "**/*unittest.cc",
+ ],
+
+ version_script: "jni.lds",
+
+ apex_available: [
+ "//apex_available:platform",
+ "com.android.extservices",
+ ],
+}
+
+// -----------------------
+// libtextclassifier_tests
+// -----------------------
+cc_test {
+ name: "libtextclassifier_tests",
+ defaults: ["libtextclassifier_defaults"],
+
+ test_suites: ["device-tests"],
+
+ data: [
+ "annotator/test_data/**/*",
+ "actions/test_data/**/*",
+ ],
+
+ srcs: ["**/*.cc"],
+
+ header_libs: ["jni_headers"],
+
+ static_libs: [
+ "libgmock_ndk",
+ "libgtest_ndk_c++",
+ ],
+
+ multilib: {
+ lib32: {
+ cppflags: ["-DTC3_TEST_DATA_DIR=\"/data/nativetest/libtextclassifier_tests/test_data/\""],
+ },
+ lib64: {
+ cppflags: ["-DTC3_TEST_DATA_DIR=\"/data/nativetest64/libtextclassifier_tests/test_data/\""],
+ },
+ },
+}
+
+// ----------------
+// Annotator models
+// ----------------
+
+prebuilt_etc {
+ name: "libtextclassifier_annotator_en_model",
+ filename: "textclassifier.en.model",
+ owner: "google",
+ src: "models/textclassifier.en.model",
+ sub_dir: "textclassifier",
+}
+
+prebuilt_etc {
+ name: "libtextclassifier_annotator_universal_model",
+ filename: "textclassifier.universal.model",
+ owner: "google",
+ src: "models/textclassifier.universal.model",
+ sub_dir: "textclassifier",
+}
+
+// ---------------------------
+// Actions Suggestions models
+// ---------------------------
+
+prebuilt_etc {
+ name: "libtextclassifier_actions_suggestions_universal_model",
+ filename: "actions_suggestions.universal.model",
+ owner: "google",
+ src: "models/actions_suggestions.universal.model",
+ sub_dir: "textclassifier",
+}
+
+// ------------
+// LangId model
+// ------------
+
+prebuilt_etc {
+ name: "libtextclassifier_lang_id_model",
+ filename: "lang_id.model",
+ owner: "google",
+ src: "models/lang_id.model",
+ sub_dir: "textclassifier",
+}
diff --git a/AndroidTest.xml b/native/AndroidTest.xml
similarity index 100%
rename from AndroidTest.xml
rename to native/AndroidTest.xml
diff --git a/native/actions/actions-entity-data.fbs b/native/actions/actions-entity-data.fbs
new file mode 100755
index 0000000..21584b6
--- /dev/null
+++ b/native/actions/actions-entity-data.fbs
@@ -0,0 +1,24 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Extra information and data associated with actions.
+namespace libtextclassifier3;
+table ActionsEntityData {
+ // Extracted text.
+ text:string (shared);
+}
+
+root_type libtextclassifier3.ActionsEntityData;
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
new file mode 100644
index 0000000..1fcd35c
--- /dev/null
+++ b/native/actions/actions-suggestions.cc
@@ -0,0 +1,1400 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/actions-suggestions.h"
+
+#include <memory>
+
+#include "actions/lua-actions.h"
+#include "actions/types.h"
+#include "actions/utils.h"
+#include "actions/zlib-utils.h"
+#include "annotator/collections.h"
+#include "utils/base/logging.h"
+#include "utils/flatbuffers.h"
+#include "utils/lua-utils.h"
+#include "utils/normalization.h"
+#include "utils/optional.h"
+#include "utils/strings/split.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace libtextclassifier3 {
+
+const std::string& ActionsSuggestions::kViewCalendarType =
+ *[]() { return new std::string("view_calendar"); }();
+const std::string& ActionsSuggestions::kViewMapType =
+ *[]() { return new std::string("view_map"); }();
+const std::string& ActionsSuggestions::kTrackFlightType =
+ *[]() { return new std::string("track_flight"); }();
+const std::string& ActionsSuggestions::kOpenUrlType =
+ *[]() { return new std::string("open_url"); }();
+const std::string& ActionsSuggestions::kSendSmsType =
+ *[]() { return new std::string("send_sms"); }();
+const std::string& ActionsSuggestions::kCallPhoneType =
+ *[]() { return new std::string("call_phone"); }();
+const std::string& ActionsSuggestions::kSendEmailType =
+ *[]() { return new std::string("send_email"); }();
+const std::string& ActionsSuggestions::kShareLocation =
+ *[]() { return new std::string("share_location"); }();
+
+// Name for a datetime annotation that only includes time but no date.
+const std::string& kTimeAnnotation =
+ *[]() { return new std::string("time"); }();
+
+constexpr float kDefaultFloat = 0.0;
+constexpr bool kDefaultBool = false;
+constexpr int kDefaultInt = 1;
+
+namespace {
+
+const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
+ flatbuffers::Verifier verifier(addr, size);
+ if (VerifyActionsModelBuffer(verifier)) {
+ return GetActionsModel(addr);
+ } else {
+ return nullptr;
+ }
+}
+
+template <typename T>
+T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
+ const T default_value) {
+ if (values == nullptr) {
+ return default_value;
+ }
+ return values->GetField<T>(field_offset, default_value);
+}
+
+// Returns number of (tail) messages of a conversation to consider.
+int NumMessagesToConsider(const Conversation& conversation,
+ const int max_conversation_history_length) {
+ return ((max_conversation_history_length < 0 ||
+ conversation.messages.size() < max_conversation_history_length)
+ ? conversation.messages.size()
+ : max_conversation_history_length);
+}
+
+} // namespace
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
+ const uint8_t* buffer, const int size, const UniLib* unilib,
+ const std::string& triggering_preconditions_overlay) {
+ auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
+ const ActionsModel* model = LoadAndVerifyModel(buffer, size);
+ if (model == nullptr) {
+ return nullptr;
+ }
+ actions->model_ = model;
+ actions->SetOrCreateUnilib(unilib);
+ actions->triggering_preconditions_overlay_buffer_ =
+ triggering_preconditions_overlay;
+ if (!actions->ValidateAndInitialize()) {
+ return nullptr;
+ }
+ return actions;
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
+ const std::string& triggering_preconditions_overlay) {
+ if (!mmap->handle().ok()) {
+ TC3_VLOG(1) << "Mmap failed.";
+ return nullptr;
+ }
+ const ActionsModel* model = LoadAndVerifyModel(
+ reinterpret_cast<const uint8_t*>(mmap->handle().start()),
+ mmap->handle().num_bytes());
+ if (!model) {
+ TC3_LOG(ERROR) << "Model verification failed.";
+ return nullptr;
+ }
+ auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
+ actions->model_ = model;
+ actions->mmap_ = std::move(mmap);
+ actions->SetOrCreateUnilib(unilib);
+ actions->triggering_preconditions_overlay_buffer_ =
+ triggering_preconditions_overlay;
+ if (!actions->ValidateAndInitialize()) {
+ return nullptr;
+ }
+ return actions;
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
+ std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay) {
+ if (!mmap->handle().ok()) {
+ TC3_VLOG(1) << "Mmap failed.";
+ return nullptr;
+ }
+ const ActionsModel* model = LoadAndVerifyModel(
+ reinterpret_cast<const uint8_t*>(mmap->handle().start()),
+ mmap->handle().num_bytes());
+ if (!model) {
+ TC3_LOG(ERROR) << "Model verification failed.";
+ return nullptr;
+ }
+ auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
+ actions->model_ = model;
+ actions->mmap_ = std::move(mmap);
+ actions->owned_unilib_ = std::move(unilib);
+ actions->unilib_ = actions->owned_unilib_.get();
+ actions->triggering_preconditions_overlay_buffer_ =
+ triggering_preconditions_overlay;
+ if (!actions->ValidateAndInitialize()) {
+ return nullptr;
+ }
+ return actions;
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
+ const int fd, const int offset, const int size, const UniLib* unilib,
+ const std::string& triggering_preconditions_overlay) {
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
+ if (offset >= 0 && size >= 0) {
+ mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
+ } else {
+ mmap.reset(new libtextclassifier3::ScopedMmap(fd));
+ }
+ return FromScopedMmap(std::move(mmap), unilib,
+ triggering_preconditions_overlay);
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
+ const int fd, const int offset, const int size,
+ std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay) {
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
+ if (offset >= 0 && size >= 0) {
+ mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
+ } else {
+ mmap.reset(new libtextclassifier3::ScopedMmap(fd));
+ }
+ return FromScopedMmap(std::move(mmap), std::move(unilib),
+ triggering_preconditions_overlay);
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
+ const int fd, const UniLib* unilib,
+ const std::string& triggering_preconditions_overlay) {
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd));
+ return FromScopedMmap(std::move(mmap), unilib,
+ triggering_preconditions_overlay);
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
+ const int fd, std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay) {
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd));
+ return FromScopedMmap(std::move(mmap), std::move(unilib),
+ triggering_preconditions_overlay);
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
+ const std::string& path, const UniLib* unilib,
+ const std::string& triggering_preconditions_overlay) {
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(path));
+ return FromScopedMmap(std::move(mmap), unilib,
+ triggering_preconditions_overlay);
+}
+
+std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
+ const std::string& path, std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay) {
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(path));
+ return FromScopedMmap(std::move(mmap), std::move(unilib),
+ triggering_preconditions_overlay);
+}
+
+void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
+ if (unilib != nullptr) {
+ unilib_ = unilib;
+ } else {
+ owned_unilib_.reset(new UniLib);
+ unilib_ = owned_unilib_.get();
+ }
+}
+
+bool ActionsSuggestions::ValidateAndInitialize() {
+ if (model_ == nullptr) {
+ TC3_LOG(ERROR) << "No model specified.";
+ return false;
+ }
+
+ if (model_->smart_reply_action_type() == nullptr) {
+ TC3_LOG(ERROR) << "No smart reply action type specified.";
+ return false;
+ }
+
+ if (!InitializeTriggeringPreconditions()) {
+ TC3_LOG(ERROR) << "Could not initialize preconditions.";
+ return false;
+ }
+
+ if (model_->locales() &&
+ !ParseLocales(model_->locales()->c_str(), &locales_)) {
+ TC3_LOG(ERROR) << "Could not parse model supported locales.";
+ return false;
+ }
+
+ if (model_->tflite_model_spec() != nullptr) {
+ model_executor_ = TfLiteModelExecutor::FromBuffer(
+ model_->tflite_model_spec()->tflite_model());
+ if (!model_executor_) {
+ TC3_LOG(ERROR) << "Could not initialize model executor.";
+ return false;
+ }
+ }
+
+ // Gather annotation entities for the rules.
+ if (model_->annotation_actions_spec() != nullptr &&
+ model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
+ for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
+ *model_->annotation_actions_spec()->annotation_mapping()) {
+ annotation_entity_types_.insert(mapping->annotation_collection()->str());
+ }
+ }
+
+ if (model_->actions_entity_data_schema() != nullptr) {
+ entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
+ model_->actions_entity_data_schema()->Data(),
+ model_->actions_entity_data_schema()->size());
+ if (entity_data_schema_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not load entity data schema data.";
+ return false;
+ }
+
+ entity_data_builder_.reset(
+ new ReflectiveFlatbufferBuilder(entity_data_schema_));
+ } else {
+ entity_data_schema_ = nullptr;
+ }
+
+ // Initialize regular expressions model.
+ std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
+ regex_actions_.reset(
+ new RegexActions(unilib_, model_->smart_reply_action_type()->str()));
+ if (!regex_actions_->InitializeRules(
+ model_->rules(), model_->low_confidence_rules(),
+ triggering_preconditions_overlay_, decompressor.get())) {
+ TC3_LOG(ERROR) << "Could not initialize regex rules.";
+ return false;
+ }
+
+ // Setup grammar model.
+ if (model_->rules() != nullptr &&
+ model_->rules()->grammar_rules() != nullptr) {
+ grammar_actions_.reset(new GrammarActions(
+ unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(),
+ model_->smart_reply_action_type()->str()));
+
+ // Gather annotation entities for the grammars.
+ if (auto annotation_nt = model_->rules()
+ ->grammar_rules()
+ ->rules()
+ ->nonterminals()
+ ->annotation_nt()) {
+ for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry :
+ *annotation_nt) {
+ annotation_entity_types_.insert(entry->key()->str());
+ }
+ }
+ }
+
+ std::string actions_script;
+ if (GetUncompressedString(model_->lua_actions_script(),
+ model_->compressed_lua_actions_script(),
+ decompressor.get(), &actions_script) &&
+ !actions_script.empty()) {
+ if (!Compile(actions_script, &lua_bytecode_)) {
+ TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
+ return false;
+ }
+ }
+
+ if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+ model_->ranking_options(), decompressor.get(),
+ model_->smart_reply_action_type()->str()))) {
+ TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
+ return false;
+ }
+
+ // Create feature processor if specified.
+ const ActionsTokenFeatureProcessorOptions* options =
+ model_->feature_processor_options();
+ if (options != nullptr) {
+ if (options->tokenizer_options() == nullptr) {
+ TC3_LOG(ERROR) << "No tokenizer options specified.";
+ return false;
+ }
+
+ feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
+ embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
+ options->embedding_model(), options->embedding_size(),
+ options->embedding_quantization_bits());
+
+ if (embedding_executor_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not initialize embedding executor.";
+ return false;
+ }
+
+ // Cache embedding of padding, start and end token.
+ if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
+ !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
+ !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
+ TC3_LOG(ERROR) << "Could not precompute token embeddings.";
+ return false;
+ }
+ token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
+ }
+
+ // Create low confidence model if specified.
+ if (model_->low_confidence_ngram_model() != nullptr) {
+ ngram_model_ = NGramModel::Create(
+ unilib_, model_->low_confidence_ngram_model(),
+ feature_processor_ == nullptr ? nullptr
+ : feature_processor_->tokenizer());
+ if (ngram_model_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool ActionsSuggestions::InitializeTriggeringPreconditions() {
+ triggering_preconditions_overlay_ =
+ LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
+ triggering_preconditions_overlay_buffer_);
+
+ if (triggering_preconditions_overlay_ == nullptr &&
+ !triggering_preconditions_overlay_buffer_.empty()) {
+ TC3_LOG(ERROR) << "Could not load triggering preconditions overwrites.";
+ return false;
+ }
+ const flatbuffers::Table* overlay =
+ reinterpret_cast<const flatbuffers::Table*>(
+ triggering_preconditions_overlay_);
+ const TriggeringPreconditions* defaults = model_->preconditions();
+ if (defaults == nullptr) {
+ TC3_LOG(ERROR) << "No triggering conditions specified.";
+ return false;
+ }
+
+ preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
+ defaults->min_smart_reply_triggering_score());
+ preconditions_.max_sensitive_topic_score = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
+ defaults->max_sensitive_topic_score());
+ preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
+ defaults->suppress_on_sensitive_topic());
+ preconditions_.min_input_length =
+ ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
+ defaults->min_input_length());
+ preconditions_.max_input_length =
+ ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
+ defaults->max_input_length());
+ preconditions_.min_locale_match_fraction = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
+ defaults->min_locale_match_fraction());
+ preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
+ defaults->handle_missing_locale_as_supported());
+ preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
+ defaults->handle_unknown_locale_as_supported());
+ preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
+ defaults->suppress_on_low_confidence_input());
+ preconditions_.min_reply_score_threshold = ValueOrDefault(
+ overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
+ defaults->min_reply_score_threshold());
+
+ return true;
+}
+
+bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
+ std::vector<float>* embedding) const {
+ return feature_processor_->AppendFeatures(
+ {token_id},
+ /*dense_features=*/{}, embedding_executor_.get(), embedding);
+}
+
+std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
+ const std::vector<std::string>& context) const {
+ std::vector<std::vector<Token>> tokens;
+ tokens.reserve(context.size());
+ for (const std::string& message : context) {
+ tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
+ }
+ return tokens;
+}
+
+bool ActionsSuggestions::EmbedTokensPerMessage(
+ const std::vector<std::vector<Token>>& tokens,
+ std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
+ const int num_messages = tokens.size();
+ *max_num_tokens_per_message = 0;
+ for (int i = 0; i < num_messages; i++) {
+ const int num_message_tokens = tokens[i].size();
+ if (num_message_tokens > *max_num_tokens_per_message) {
+ *max_num_tokens_per_message = num_message_tokens;
+ }
+ }
+
+ if (model_->feature_processor_options()->min_num_tokens_per_message() >
+ *max_num_tokens_per_message) {
+ *max_num_tokens_per_message =
+ model_->feature_processor_options()->min_num_tokens_per_message();
+ }
+ if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
+ *max_num_tokens_per_message >
+ model_->feature_processor_options()->max_num_tokens_per_message()) {
+ *max_num_tokens_per_message =
+ model_->feature_processor_options()->max_num_tokens_per_message();
+ }
+
+ // Embed all tokens and add paddings to pad tokens of each message to the
+ // maximum number of tokens in a message of the conversation.
+ // If a number of tokens is specified in the model config, tokens at the
+ // beginning of a message are dropped if they don't fit in the limit.
+ for (int i = 0; i < num_messages; i++) {
+ const int start =
+ std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
+ for (int pos = start; pos < tokens[i].size(); pos++) {
+ if (!feature_processor_->AppendTokenFeatures(
+ tokens[i][pos], embedding_executor_.get(), embeddings)) {
+ TC3_LOG(ERROR) << "Could not run token feature extractor.";
+ return false;
+ }
+ }
+ // Add padding.
+ for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
+ embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
+ embedded_padding_token_.end());
+ }
+ }
+
+ return true;
+}
+
+bool ActionsSuggestions::EmbedAndFlattenTokens(
+ const std::vector<std::vector<Token>>& tokens,
+ std::vector<float>* embeddings, int* total_token_count) const {
+ const int num_messages = tokens.size();
+ int start_message = 0;
+ int message_token_offset = 0;
+
+ // If a maximum model input length is specified, we need to check how
+ // much we need to trim at the start.
+ const int max_num_total_tokens =
+ model_->feature_processor_options()->max_num_total_tokens();
+ if (max_num_total_tokens > 0) {
+ int total_tokens = 0;
+ start_message = num_messages - 1;
+ for (; start_message >= 0; start_message--) {
+ // Tokens of the message + start and end token.
+ const int num_message_tokens = tokens[start_message].size() + 2;
+ total_tokens += num_message_tokens;
+
+ // Check whether we exhausted the budget.
+ if (total_tokens >= max_num_total_tokens) {
+ message_token_offset = total_tokens - max_num_total_tokens;
+ break;
+ }
+ }
+ }
+
+ // Add embeddings.
+ *total_token_count = 0;
+ for (int i = start_message; i < num_messages; i++) {
+ if (message_token_offset == 0) {
+ ++(*total_token_count);
+ // Add `start message` token.
+ embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
+ embedded_start_token_.end());
+ }
+
+ for (int pos = std::max(0, message_token_offset - 1);
+ pos < tokens[i].size(); pos++) {
+ ++(*total_token_count);
+ if (!feature_processor_->AppendTokenFeatures(
+ tokens[i][pos], embedding_executor_.get(), embeddings)) {
+ TC3_LOG(ERROR) << "Could not run token feature extractor.";
+ return false;
+ }
+ }
+
+ // Add `end message` token.
+ ++(*total_token_count);
+ embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
+ embedded_end_token_.end());
+
+ // Reset for the subsequent messages.
+ message_token_offset = 0;
+ }
+
+ // Add optional padding.
+ const int min_num_total_tokens =
+ model_->feature_processor_options()->min_num_total_tokens();
+ for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
+ embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
+ embedded_padding_token_.end());
+ }
+
+ return true;
+}
+
+bool ActionsSuggestions::AllocateInput(const int conversation_length,
+ const int max_tokens,
+ const int total_token_count,
+ tflite::Interpreter* interpreter) const {
+ if (model_->tflite_model_spec()->resize_inputs()) {
+ if (model_->tflite_model_spec()->input_context() >= 0) {
+ interpreter->ResizeInputTensor(
+ interpreter->inputs()[model_->tflite_model_spec()->input_context()],
+ {1, conversation_length});
+ }
+ if (model_->tflite_model_spec()->input_user_id() >= 0) {
+ interpreter->ResizeInputTensor(
+ interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
+ {1, conversation_length});
+ }
+ if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
+ interpreter->ResizeInputTensor(
+ interpreter
+ ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
+ {1, conversation_length});
+ }
+ if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
+ interpreter->ResizeInputTensor(
+ interpreter
+ ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
+ {conversation_length, 1});
+ }
+ if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
+ interpreter->ResizeInputTensor(
+ interpreter
+ ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
+ {conversation_length, max_tokens, token_embedding_size_});
+ }
+ if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
+ interpreter->ResizeInputTensor(
+ interpreter->inputs()[model_->tflite_model_spec()
+ ->input_flattened_token_embeddings()],
+ {1, total_token_count});
+ }
+ }
+
+ return interpreter->AllocateTensors() == kTfLiteOk;
+}
+
+bool ActionsSuggestions::SetupModelInput(
+ const std::vector<std::string>& context, const std::vector<int>& user_ids,
+ const std::vector<float>& time_diffs, const int num_suggestions,
+ const ActionSuggestionOptions& options,
+ tflite::Interpreter* interpreter) const {
+ // Compute token embeddings.
+ std::vector<std::vector<Token>> tokens;
+ std::vector<float> token_embeddings;
+ std::vector<float> flattened_token_embeddings;
+ int max_tokens = 0;
+ int total_token_count = 0;
+ if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
+ model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
+ model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
+ if (feature_processor_ == nullptr) {
+ TC3_LOG(ERROR) << "No feature processor specified.";
+ return false;
+ }
+
+ // Tokenize the messages in the conversation.
+ tokens = Tokenize(context);
+ if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
+ if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
+ TC3_LOG(ERROR) << "Could not extract token features.";
+ return false;
+ }
+ }
+ if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
+ if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
+ &total_token_count)) {
+ TC3_LOG(ERROR) << "Could not extract token features.";
+ return false;
+ }
+ }
+ }
+
+ if (!AllocateInput(context.size(), max_tokens, total_token_count,
+ interpreter)) {
+ TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
+ return false;
+ }
+ if (model_->tflite_model_spec()->input_context() >= 0) {
+ model_executor_->SetInput<std::string>(
+ model_->tflite_model_spec()->input_context(), context, interpreter);
+ }
+ if (model_->tflite_model_spec()->input_context_length() >= 0) {
+ model_executor_->SetInput<int>(
+ model_->tflite_model_spec()->input_context_length(), context.size(),
+ interpreter);
+ }
+ if (model_->tflite_model_spec()->input_user_id() >= 0) {
+ model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
+ user_ids, interpreter);
+ }
+ if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
+ model_executor_->SetInput<int>(
+ model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
+ interpreter);
+ }
+ if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
+ model_executor_->SetInput<float>(
+ model_->tflite_model_spec()->input_time_diffs(), time_diffs,
+ interpreter);
+ }
+ if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
+ std::vector<int> num_tokens_per_message(tokens.size());
+ for (int i = 0; i < tokens.size(); i++) {
+ num_tokens_per_message[i] = tokens[i].size();
+ }
+ model_executor_->SetInput<int>(
+ model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
+ interpreter);
+ }
+ if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
+ model_executor_->SetInput<float>(
+ model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
+ interpreter);
+ }
+ if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
+ model_executor_->SetInput<float>(
+ model_->tflite_model_spec()->input_flattened_token_embeddings(),
+ flattened_token_embeddings, interpreter);
+ }
+ // Set up additional input parameters.
+ if (const auto* input_name_index =
+ model_->tflite_model_spec()->input_name_index()) {
+ const std::unordered_map<std::string, Variant>& model_parameters =
+ options.model_parameters;
+ for (const TensorflowLiteModelSpec_::InputNameIndexEntry* entry :
+ *input_name_index) {
+ const std::string param_name = entry->key()->str();
+ const int param_index = entry->value();
+ const TfLiteType param_type =
+ interpreter->tensor(interpreter->inputs()[param_index])->type;
+ const auto param_value_it = model_parameters.find(param_name);
+ const bool has_value = param_value_it != model_parameters.end();
+ switch (param_type) {
+ case kTfLiteFloat32:
+ model_executor_->SetInput<float>(
+ param_index,
+ has_value ? param_value_it->second.Value<float>() : kDefaultFloat,
+ interpreter);
+ break;
+ case kTfLiteInt32:
+ model_executor_->SetInput<int32_t>(
+ param_index,
+ has_value ? param_value_it->second.Value<int>() : kDefaultInt,
+ interpreter);
+ break;
+ case kTfLiteInt64:
+ model_executor_->SetInput<int64_t>(
+ param_index,
+ has_value ? param_value_it->second.Value<int64>() : kDefaultInt,
+ interpreter);
+ break;
+ case kTfLiteUInt8:
+ model_executor_->SetInput<uint8_t>(
+ param_index,
+ has_value ? param_value_it->second.Value<uint8>() : kDefaultInt,
+ interpreter);
+ break;
+ case kTfLiteInt8:
+ model_executor_->SetInput<int8_t>(
+ param_index,
+ has_value ? param_value_it->second.Value<int8>() : kDefaultInt,
+ interpreter);
+ break;
+ case kTfLiteBool:
+ model_executor_->SetInput<bool>(
+ param_index,
+ has_value ? param_value_it->second.Value<bool>() : kDefaultBool,
+ interpreter);
+ break;
+ default:
+ TC3_LOG(ERROR) << "Unsupported type of additional input parameter: "
+ << param_name;
+ }
+ }
+ }
+ return true;
+}
+
+void ActionsSuggestions::PopulateTextReplies(
+ const tflite::Interpreter* interpreter, int suggestion_index,
+ int score_index, const std::string& type,
+ ActionsSuggestionsResponse* response) const {
+ const std::vector<tflite::StringRef> replies =
+ model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
+ const TensorView<float> scores =
+ model_executor_->OutputView<float>(score_index, interpreter);
+ for (int i = 0; i < replies.size(); i++) {
+ if (replies[i].len == 0) {
+ continue;
+ }
+ const float score = scores.data()[i];
+ if (score < preconditions_.min_reply_score_threshold) {
+ continue;
+ }
+ response->actions.push_back(
+ {std::string(replies[i].str, replies[i].len), type, score});
+ }
+}
+
+void ActionsSuggestions::FillSuggestionFromSpecWithEntityData(
+ const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const {
+ std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
+ : nullptr;
+ FillSuggestionFromSpec(spec, entity_data.get(), suggestion);
+}
+
+void ActionsSuggestions::PopulateIntentTriggering(
+ const tflite::Interpreter* interpreter, int suggestion_index,
+ int score_index, const ActionSuggestionSpec* task_spec,
+ ActionsSuggestionsResponse* response) const {
+ if (!task_spec || task_spec->type()->size() == 0) {
+ TC3_LOG(ERROR)
+ << "Task type for intent (action) triggering cannot be empty!";
+ return;
+ }
+ const TensorView<bool> intent_prediction =
+ model_executor_->OutputView<bool>(suggestion_index, interpreter);
+ const TensorView<float> intent_scores =
+ model_executor_->OutputView<float>(score_index, interpreter);
+ // Two result corresponding to binary triggering case.
+ TC3_CHECK_EQ(intent_prediction.size(), 2);
+ TC3_CHECK_EQ(intent_scores.size(), 2);
+ // We rely on in-graph thresholding logic so at this point the results
+ // have been ranked properly according to threshold.
+ const bool triggering = intent_prediction.data()[0];
+ const float trigger_score = intent_scores.data()[0];
+
+ if (triggering) {
+ ActionSuggestion suggestion;
+ std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
+ : nullptr;
+ FillSuggestionFromSpecWithEntityData(task_spec, &suggestion);
+ suggestion.score = trigger_score;
+ response->actions.push_back(std::move(suggestion));
+ }
+}
+
+bool ActionsSuggestions::ReadModelOutput(
+ tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
+ ActionsSuggestionsResponse* response) const {
+ // Read sensitivity and triggering score predictions.
+ if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
+ const TensorView<float> triggering_score =
+ model_executor_->OutputView<float>(
+ model_->tflite_model_spec()->output_triggering_score(),
+ interpreter);
+ if (!triggering_score.is_valid() || triggering_score.size() == 0) {
+ TC3_LOG(ERROR) << "Could not compute triggering score.";
+ return false;
+ }
+ response->triggering_score = triggering_score.data()[0];
+ response->output_filtered_min_triggering_score =
+ (response->triggering_score <
+ preconditions_.min_smart_reply_triggering_score);
+ }
+ if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
+ const TensorView<float> sensitive_topic_score =
+ model_executor_->OutputView<float>(
+ model_->tflite_model_spec()->output_sensitive_topic_score(),
+ interpreter);
+ if (!sensitive_topic_score.is_valid() ||
+ sensitive_topic_score.dim(0) != 1) {
+ TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
+ return false;
+ }
+ response->sensitivity_score = sensitive_topic_score.data()[0];
+ response->output_filtered_sensitivity =
+ (response->sensitivity_score >
+ preconditions_.max_sensitive_topic_score);
+ }
+
+ // Suppress model outputs.
+ if (response->output_filtered_sensitivity) {
+ return true;
+ }
+
+ // Read smart reply predictions.
+ if (!response->output_filtered_min_triggering_score &&
+ model_->tflite_model_spec()->output_replies() >= 0) {
+ PopulateTextReplies(interpreter,
+ model_->tflite_model_spec()->output_replies(),
+ model_->tflite_model_spec()->output_replies_scores(),
+ model_->smart_reply_action_type()->str(), response);
+ }
+
+ // Read actions suggestions.
+ if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
+ const TensorView<float> actions_scores = model_executor_->OutputView<float>(
+ model_->tflite_model_spec()->output_actions_scores(), interpreter);
+ for (int i = 0; i < model_->action_type()->size(); i++) {
+ const ActionTypeOptions* action_type = model_->action_type()->Get(i);
+ // Skip disabled action classes, such as the default other category.
+ if (!action_type->enabled()) {
+ continue;
+ }
+ const float score = actions_scores.data()[i];
+ if (score < action_type->min_triggering_score()) {
+ continue;
+ }
+
+ // Create action from model output.
+ ActionSuggestion suggestion;
+ suggestion.type = action_type->name()->str();
+ std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
+ : nullptr;
+ FillSuggestionFromSpecWithEntityData(action_type->action(), &suggestion);
+ suggestion.score = score;
+ response->actions.push_back(std::move(suggestion));
+ }
+ }
+
+ // Read multi-task predictions and construct the result properly.
+ if (const auto* prediction_metadata =
+ model_->tflite_model_spec()->prediction_metadata()) {
+ for (const PredictionMetadata* metadata : *prediction_metadata) {
+ const ActionSuggestionSpec* task_spec = metadata->task_spec();
+ const int suggestions_index = metadata->output_suggestions();
+ const int suggestions_scores_index =
+ metadata->output_suggestions_scores();
+ switch (metadata->prediction_type()) {
+ case PredictionType_NEXT_MESSAGE_PREDICTION:
+ if (!task_spec || task_spec->type()->size() == 0) {
+ TC3_LOG(WARNING) << "Task type not provided, use default "
+ "smart_reply_action_type!";
+ }
+ PopulateTextReplies(
+ interpreter, suggestions_index, suggestions_scores_index,
+ task_spec ? task_spec->type()->str()
+ : model_->smart_reply_action_type()->str(),
+ response);
+ break;
+ case PredictionType_INTENT_TRIGGERING:
+ PopulateIntentTriggering(interpreter, suggestions_index,
+ suggestions_scores_index, task_spec,
+ response);
+ break;
+ default:
+ TC3_LOG(ERROR) << "Unsupported prediction type!";
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+bool ActionsSuggestions::SuggestActionsFromModel(
+ const Conversation& conversation, const int num_messages,
+ const ActionSuggestionOptions& options,
+ ActionsSuggestionsResponse* response,
+ std::unique_ptr<tflite::Interpreter>* interpreter) const {
+ TC3_CHECK_LE(num_messages, conversation.messages.size());
+
+ if (!model_executor_) {
+ return true;
+ }
+ *interpreter = model_executor_->CreateInterpreter();
+
+ if (!*interpreter) {
+ TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
+ "actions suggestions model.";
+ return false;
+ }
+
+ std::vector<std::string> context;
+ std::vector<int> user_ids;
+ std::vector<float> time_diffs;
+ context.reserve(num_messages);
+ user_ids.reserve(num_messages);
+ time_diffs.reserve(num_messages);
+
+ // Gather last `num_messages` messages from the conversation.
+ int64 last_message_reference_time_ms_utc = 0;
+ const float second_in_ms = 1000;
+ for (int i = conversation.messages.size() - num_messages;
+ i < conversation.messages.size(); i++) {
+ const ConversationMessage& message = conversation.messages[i];
+ context.push_back(message.text);
+ user_ids.push_back(message.user_id);
+
+ float time_diff_secs = 0;
+ if (message.reference_time_ms_utc != 0 &&
+ last_message_reference_time_ms_utc != 0) {
+ time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
+ last_message_reference_time_ms_utc) /
+ second_in_ms);
+ }
+ if (message.reference_time_ms_utc != 0) {
+ last_message_reference_time_ms_utc = message.reference_time_ms_utc;
+ }
+ time_diffs.push_back(time_diff_secs);
+ }
+
+ if (!SetupModelInput(context, user_ids, time_diffs,
+ /*num_suggestions=*/model_->num_smart_replies(), options,
+ interpreter->get())) {
+ TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
+ return false;
+ }
+
+ if ((*interpreter)->Invoke() != kTfLiteOk) {
+ TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
+ return false;
+ }
+
+ return ReadModelOutput(interpreter->get(), options, response);
+}
+
+AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
+ const ConversationMessage& message) const {
+ AnnotationOptions options;
+ options.detected_text_language_tags = message.detected_text_language_tags;
+ options.reference_time_ms_utc = message.reference_time_ms_utc;
+ options.reference_timezone = message.reference_timezone;
+ options.annotation_usecase =
+ model_->annotation_actions_spec()->annotation_usecase();
+ options.is_serialized_entity_data_enabled =
+ model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
+ options.entity_types = annotation_entity_types_;
+ return options;
+}
+
+// Run annotator on the messages of a conversation.
+Conversation ActionsSuggestions::AnnotateConversation(
+    const Conversation& conversation, const Annotator* annotator) const {
+  if (annotator == nullptr) {
+    return conversation;
+  }
+  const int num_messages_grammar =
+      ((model_->rules() && model_->rules()->grammar_rules() &&
+        model_->rules()
+            ->grammar_rules()
+            ->rules()
+            ->nonterminals()
+            ->annotation_nt())
+           ? 1
+           : 0);
+  const int num_messages_mapping =
+      (model_->annotation_actions_spec()
+           ? std::max(model_->annotation_actions_spec()
+                          ->max_history_from_any_person(),
+                      model_->annotation_actions_spec()
+                          ->max_history_from_last_person())
+           : 0);
+  const int num_messages = std::max(num_messages_grammar, num_messages_mapping);
+  if (num_messages == 0) {
+    // No annotations are used.
+    return conversation;
+  }
+  Conversation annotated_conversation = conversation;
+  for (int i = 0, message_index = annotated_conversation.messages.size() - 1;
+       i < num_messages && message_index >= 0; i++, message_index--) {
+    ConversationMessage* message =
+        &annotated_conversation.messages[message_index];
+    if (message->annotations.empty()) {
+      message->annotations = annotator->Annotate(
+          message->text, AnnotationOptionsForMessage(*message));
+      for (int j = 0; j < message->annotations.size(); j++) {
+        if (message->annotations[j].classification.empty()) continue;  // front() on empty is UB.
+        ClassificationResult* classification =
+            &message->annotations[j].classification.front();
+        // Specialize datetime annotation to time annotation if no date
+        // component is present.
+        if (classification->collection == Collections::DateTime() &&
+            classification->datetime_parse_result.IsSet()) {
+          bool has_only_time = true;
+          for (const DatetimeComponent& component :
+               classification->datetime_parse_result.datetime_components) {
+            if (component.component_type !=
+                    DatetimeComponent::ComponentType::UNSPECIFIED &&
+                component.component_type <
+                    DatetimeComponent::ComponentType::HOUR) {
+              has_only_time = false;
+              break;
+            }
+          }
+          if (has_only_time) {
+            classification->collection = kTimeAnnotation;
+          }
+        }
+      }
+    }
+  }
+  return annotated_conversation;
+}
+
+void ActionsSuggestions::SuggestActionsFromAnnotations(
+ const Conversation& conversation,
+ std::vector<ActionSuggestion>* actions) const {
+ if (model_->annotation_actions_spec() == nullptr ||
+ model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
+ model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
+ return;
+ }
+
+ // Create actions based on the annotations.
+ const int max_from_any_person =
+ model_->annotation_actions_spec()->max_history_from_any_person();
+ const int max_from_last_person =
+ model_->annotation_actions_spec()->max_history_from_last_person();
+ const int last_person = conversation.messages.back().user_id;
+
+ int num_messages_last_person = 0;
+ int num_messages_any_person = 0;
+ bool all_from_last_person = true;
+ for (int message_index = conversation.messages.size() - 1; message_index >= 0;
+ message_index--) {
+ const ConversationMessage& message = conversation.messages[message_index];
+ std::vector<AnnotatedSpan> annotations = message.annotations;
+
+ // Update how many messages we have processed from the last person in the
+ // conversation and from any person in the conversation.
+ num_messages_any_person++;
+ if (all_from_last_person && message.user_id == last_person) {
+ num_messages_last_person++;
+ } else {
+ all_from_last_person = false;
+ }
+
+ if (num_messages_any_person > max_from_any_person &&
+ (!all_from_last_person ||
+ num_messages_last_person > max_from_last_person)) {
+ break;
+ }
+
+ if (message.user_id == kLocalUserId) {
+ if (model_->annotation_actions_spec()->only_until_last_sent()) {
+ break;
+ }
+ if (!model_->annotation_actions_spec()->include_local_user_messages()) {
+ continue;
+ }
+ }
+
+ std::vector<ActionSuggestionAnnotation> action_annotations;
+ action_annotations.reserve(annotations.size());
+ for (const AnnotatedSpan& annotation : annotations) {
+ if (annotation.classification.empty()) {
+ continue;
+ }
+
+ const ClassificationResult& classification_result =
+ annotation.classification[0];
+
+ ActionSuggestionAnnotation action_annotation;
+ action_annotation.span = {
+ message_index, annotation.span,
+ UTF8ToUnicodeText(message.text, /*do_copy=*/false)
+ .UTF8Substring(annotation.span.first, annotation.span.second)};
+ action_annotation.entity = classification_result;
+ action_annotation.name = classification_result.collection;
+ action_annotations.push_back(std::move(action_annotation));
+ }
+
+ if (model_->annotation_actions_spec()->deduplicate_annotations()) {
+ // Create actions only for deduplicated annotations.
+ for (const int annotation_id :
+ DeduplicateAnnotations(action_annotations)) {
+ SuggestActionsFromAnnotation(
+ message_index, action_annotations[annotation_id], actions);
+ }
+ } else {
+ // Create actions for all annotations.
+ for (const ActionSuggestionAnnotation& annotation : action_annotations) {
+ SuggestActionsFromAnnotation(message_index, annotation, actions);
+ }
+ }
+ }
+}
+
+void ActionsSuggestions::SuggestActionsFromAnnotation(
+ const int message_index, const ActionSuggestionAnnotation& annotation,
+ std::vector<ActionSuggestion>* actions) const {
+ for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
+ *model_->annotation_actions_spec()->annotation_mapping()) {
+ if (annotation.entity.collection ==
+ mapping->annotation_collection()->str()) {
+ if (annotation.entity.score < mapping->min_annotation_score()) {
+ continue;
+ }
+
+ std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
+ : nullptr;
+
+ // Set annotation text as (additional) entity data field.
+ if (mapping->entity_field() != nullptr) {
+ TC3_CHECK_NE(entity_data, nullptr);
+
+ UnicodeText normalized_annotation_text =
+ UTF8ToUnicodeText(annotation.span.text, /*do_copy=*/false);
+
+ // Apply normalization if specified.
+ if (mapping->normalization_options() != nullptr) {
+ normalized_annotation_text =
+ NormalizeText(*unilib_, mapping->normalization_options(),
+ normalized_annotation_text);
+ }
+
+ entity_data->ParseAndSet(mapping->entity_field(),
+ normalized_annotation_text.ToUTF8String());
+ }
+
+ ActionSuggestion suggestion;
+ FillSuggestionFromSpec(mapping->action(), entity_data.get(), &suggestion);
+ if (mapping->use_annotation_score()) {
+ suggestion.score = annotation.entity.score;
+ }
+ suggestion.annotations = {annotation};
+ actions->push_back(std::move(suggestion));
+ }
+ }
+}
+
+std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
+    const std::vector<ActionSuggestionAnnotation>& annotations) const {
+  std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;
+
+  for (int i = 0; i < annotations.size(); i++) {
+    const std::pair<std::string, std::string> key = {annotations[i].name,
+                                                     annotations[i].span.text};
+    auto entry = deduplicated_annotations.find(key);
+    if (entry != deduplicated_annotations.end()) {
+      // Keep the annotation with the higher score.
+      if (annotations[entry->second].entity.score <
+          annotations[i].entity.score) {
+        entry->second = i;
+      }
+      continue;
+    }
+    deduplicated_annotations.insert(entry, {key, i});
+  }
+
+  std::vector<int> result;
+  result.reserve(deduplicated_annotations.size());
+  for (const auto& key_and_annotation : deduplicated_annotations) {
+    result.push_back(key_and_annotation.second);
+  }
+  return result;
+}
+
+bool ActionsSuggestions::SuggestActionsFromLua(
+ const Conversation& conversation, const TfLiteModelExecutor* model_executor,
+ const tflite::Interpreter* interpreter,
+ const reflection::Schema* annotation_entity_data_schema,
+ std::vector<ActionSuggestion>* actions) const {
+ if (lua_bytecode_.empty()) {
+ return true;
+ }
+
+ auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
+ lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
+ interpreter, entity_data_schema_, annotation_entity_data_schema);
+ if (lua_actions == nullptr) {
+ TC3_LOG(ERROR) << "Could not create lua actions.";
+ return false;
+ }
+ return lua_actions->SuggestActions(actions);
+}
+
+bool ActionsSuggestions::GatherActionsSuggestions(
+    const Conversation& conversation, const Annotator* annotator,
+    const ActionSuggestionOptions& options,
+    ActionsSuggestionsResponse* response) const {
+  if (conversation.messages.empty()) {
+    return true;
+  }
+
+  // Run annotator against messages.
+  const Conversation annotated_conversation =
+      AnnotateConversation(conversation, annotator);
+
+  const int num_messages = NumMessagesToConsider(
+      annotated_conversation, model_->max_conversation_history_length());
+
+  if (num_messages <= 0) {
+    TC3_LOG(INFO) << "No messages provided for actions suggestions.";
+    return false;
+  }
+
+  SuggestActionsFromAnnotations(annotated_conversation, &response->actions);
+
+  if (grammar_actions_ != nullptr &&
+      !grammar_actions_->SuggestActions(annotated_conversation,
+                                        &response->actions)) {
+    TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
+    return false;
+  }
+
+  int input_text_length = 0;
+  int num_matching_locales = 0;
+  for (int i = annotated_conversation.messages.size() - num_messages;
+       i < annotated_conversation.messages.size(); i++) {
+    input_text_length += annotated_conversation.messages[i].text.length();
+    std::vector<Locale> message_languages;
+    if (!ParseLocales(
+            annotated_conversation.messages[i].detected_text_language_tags,
+            &message_languages)) {
+      continue;
+    }
+    if (Locale::IsAnyLocaleSupported(
+            message_languages, locales_,
+            preconditions_.handle_unknown_locale_as_supported)) {
+      ++num_matching_locales;
+    }
+  }
+
+  // Bail out if we are provided with too few or too much input.
+  if (input_text_length < preconditions_.min_input_length ||
+      (preconditions_.max_input_length >= 0 &&
+       input_text_length > preconditions_.max_input_length)) {
+    TC3_LOG(INFO) << "Too much or not enough input for inference.";
+    return true;  // Suppression is not an error (was `return response;`, a pointer-to-bool conversion).
+  }
+
+  // Bail out if the text does not look like it can be handled by the model.
+  const float matching_fraction =
+      static_cast<float>(num_matching_locales) / num_messages;
+  if (matching_fraction < preconditions_.min_locale_match_fraction) {
+    TC3_LOG(INFO) << "Not enough locale matches.";
+    response->output_filtered_locale_mismatch = true;
+    return true;
+  }
+
+  std::vector<const UniLib::RegexPattern*> post_check_rules;
+  if (preconditions_.suppress_on_low_confidence_input) {
+    if ((ngram_model_ != nullptr &&
+         ngram_model_->EvalConversation(annotated_conversation,
+                                        num_messages)) ||
+        regex_actions_->IsLowConfidenceInput(annotated_conversation,
+                                             num_messages, &post_check_rules)) {
+      response->output_filtered_low_confidence = true;
+      return true;
+    }
+  }
+
+  std::unique_ptr<tflite::Interpreter> interpreter;
+  if (!SuggestActionsFromModel(annotated_conversation, num_messages, options,
+                               response, &interpreter)) {
+    TC3_LOG(ERROR) << "Could not run model.";
+    return false;
+  }
+
+  // Suppress all predictions if the conversation was deemed sensitive.
+  if (preconditions_.suppress_on_sensitive_topic &&
+      response->output_filtered_sensitivity) {
+    return true;
+  }
+
+  if (!SuggestActionsFromLua(
+          annotated_conversation, model_executor_.get(), interpreter.get(),
+          annotator != nullptr ? annotator->entity_data_schema() : nullptr,
+          &response->actions)) {
+    TC3_LOG(ERROR) << "Could not suggest actions from script.";
+    return false;
+  }
+
+  if (!regex_actions_->SuggestActions(annotated_conversation,
+                                      entity_data_builder_.get(),
+                                      &response->actions)) {
+    TC3_LOG(ERROR) << "Could not suggest actions from regex rules.";
+    return false;
+  }
+
+  if (preconditions_.suppress_on_low_confidence_input &&
+      !regex_actions_->FilterConfidenceOutput(post_check_rules,
+                                              &response->actions)) {
+    TC3_LOG(ERROR) << "Could not post-check actions.";
+    return false;
+  }
+
+  return true;
+}
+
+ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
+ const Conversation& conversation, const Annotator* annotator,
+ const ActionSuggestionOptions& options) const {
+ ActionsSuggestionsResponse response;
+
+ // Assert that messages are sorted correctly.
+ for (int i = 1; i < conversation.messages.size(); i++) {
+ if (conversation.messages[i].reference_time_ms_utc <
+ conversation.messages[i - 1].reference_time_ms_utc) {
+ TC3_LOG(ERROR) << "Messages are not sorted most recent last.";
+ }
+ }
+
+ if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
+ TC3_LOG(ERROR) << "Could not gather actions suggestions.";
+ response.actions.clear();
+ } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
+ annotator != nullptr
+ ? annotator->entity_data_schema()
+ : nullptr)) {
+ TC3_LOG(ERROR) << "Could not rank actions.";
+ response.actions.clear();
+ }
+ return response;
+}
+
+ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
+ const Conversation& conversation,
+ const ActionSuggestionOptions& options) const {
+ return SuggestActions(conversation, /*annotator=*/nullptr, options);
+}
+
+const ActionsModel* ActionsSuggestions::model() const { return model_; }
+const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
+ return entity_data_schema_;
+}
+
+const ActionsModel* ViewActionsModel(const void* buffer, int size) {
+ if (buffer == nullptr) {
+ return nullptr;
+ }
+ return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
new file mode 100644
index 0000000..2a321f0
--- /dev/null
+++ b/native/actions/actions-suggestions.h
@@ -0,0 +1,302 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "actions/actions_model_generated.h"
+#include "actions/feature-processor.h"
+#include "actions/grammar-actions.h"
+#include "actions/ngram-model.h"
+#include "actions/ranker.h"
+#include "actions/regex-actions.h"
+#include "actions/types.h"
+#include "annotator/annotator.h"
+#include "annotator/model-executor.h"
+#include "annotator/types.h"
+#include "utils/flatbuffers.h"
+#include "utils/i18n/locale.h"
+#include "utils/memory/mmap.h"
+#include "utils/tflite-model-executor.h"
+#include "utils/utf8/unilib.h"
+#include "utils/variant.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Options for suggesting actions.
+struct ActionSuggestionOptions {
+ static ActionSuggestionOptions Default() { return ActionSuggestionOptions(); }
+ std::unordered_map<std::string, Variant> model_parameters;
+};
+
+// Class for predicting actions following a conversation.
+class ActionsSuggestions {
+ public:
+ // Creates ActionsSuggestions from given data buffer with model.
+ static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
+ const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
+ const std::string& triggering_preconditions_overlay = "");
+
+ // Creates ActionsSuggestions from model in the ScopedMmap object and takes
+ // ownership of it.
+ static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
+ const UniLib* unilib = nullptr,
+ const std::string& triggering_preconditions_overlay = "");
+ // Same as above, but also takes ownership of the unilib.
+ static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
+ std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay);
+
+ // Creates ActionsSuggestions from model given as a file descriptor, offset
+ // and size in it. If offset and size are less than 0, will ignore them and
+ // will just use the fd.
+ static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
+ const int fd, const int offset, const int size,
+ const UniLib* unilib = nullptr,
+ const std::string& triggering_preconditions_overlay = "");
+ // Same as above, but also takes ownership of the unilib.
+ static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
+ const int fd, const int offset, const int size,
+ std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay = "");
+
+ // Creates ActionsSuggestions from model given as a file descriptor.
+ static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
+ const int fd, const UniLib* unilib = nullptr,
+ const std::string& triggering_preconditions_overlay = "");
+ // Same as above, but also takes ownership of the unilib.
+ static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
+ const int fd, std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay);
+
+ // Creates ActionsSuggestions from model given as a POSIX path.
+ static std::unique_ptr<ActionsSuggestions> FromPath(
+ const std::string& path, const UniLib* unilib = nullptr,
+ const std::string& triggering_preconditions_overlay = "");
+ // Same as above, but also takes ownership of unilib.
+ static std::unique_ptr<ActionsSuggestions> FromPath(
+ const std::string& path, std::unique_ptr<UniLib> unilib,
+ const std::string& triggering_preconditions_overlay);
+
+ ActionsSuggestionsResponse SuggestActions(
+ const Conversation& conversation,
+ const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
+
+ ActionsSuggestionsResponse SuggestActions(
+ const Conversation& conversation, const Annotator* annotator,
+ const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
+
+ const ActionsModel* model() const;
+ const reflection::Schema* entity_data_schema() const;
+
+ static constexpr int kLocalUserId = 0;
+
+ // Should be in sync with those defined in Android.
+ // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
+ static const std::string& kViewCalendarType;
+ static const std::string& kViewMapType;
+ static const std::string& kTrackFlightType;
+ static const std::string& kOpenUrlType;
+ static const std::string& kSendSmsType;
+ static const std::string& kCallPhoneType;
+ static const std::string& kSendEmailType;
+ static const std::string& kShareLocation;
+
+ protected:
+ // Exposed for testing.
+ bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
+
+ // Embeds the tokens per message separately. Each message is padded to the
+ // maximum length with the padding token.
+ bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
+ std::vector<float>* embeddings,
+ int* max_num_tokens_per_message) const;
+
+ // Concatenates the embedded message tokens - separated by start and end
+ // token between messages.
+ // If the total token count is greater than the maximum length, tokens at the
+ // start are dropped to fit into the limit.
+ // If the total token count is smaller than the minimum length, padding tokens
+ // are added to the end.
+ // Messages are assumed to be ordered by recency - most recent is last.
+ bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>>& tokens,
+ std::vector<float>* embeddings,
+ int* total_token_count) const;
+
+ const ActionsModel* model_;
+
+ // Feature extractor and options.
+ std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
+ std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
+ std::vector<float> embedded_padding_token_;
+ std::vector<float> embedded_start_token_;
+ std::vector<float> embedded_end_token_;
+ int token_embedding_size_;
+
+ private:
+ // Checks that model contains all required fields, and initializes internal
+ // datastructures.
+ bool ValidateAndInitialize();
+
+ void SetOrCreateUnilib(const UniLib* unilib);
+
+ // Prepare preconditions.
+ // Takes values from flag provided data, but falls back to model provided
+ // values for parameters that are not explicitly provided.
+ bool InitializeTriggeringPreconditions();
+
+ // Tokenizes a conversation and produces the tokens per message.
+ std::vector<std::vector<Token>> Tokenize(
+ const std::vector<std::string>& context) const;
+
+ bool AllocateInput(const int conversation_length, const int max_tokens,
+ const int total_token_count,
+ tflite::Interpreter* interpreter) const;
+
+ bool SetupModelInput(const std::vector<std::string>& context,
+ const std::vector<int>& user_ids,
+ const std::vector<float>& time_diffs,
+ const int num_suggestions,
+ const ActionSuggestionOptions& options,
+ tflite::Interpreter* interpreter) const;
+
+ void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
+ ActionSuggestion* suggestion) const;
+
+ void PopulateTextReplies(const tflite::Interpreter* interpreter,
+ int suggestion_index, int score_index,
+ const std::string& type,
+ ActionsSuggestionsResponse* response) const;
+
+ void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
+ int suggestion_index, int score_index,
+ const ActionSuggestionSpec* task_spec,
+ ActionsSuggestionsResponse* response) const;
+
+ bool ReadModelOutput(tflite::Interpreter* interpreter,
+ const ActionSuggestionOptions& options,
+ ActionsSuggestionsResponse* response) const;
+
+ bool SuggestActionsFromModel(
+ const Conversation& conversation, const int num_messages,
+ const ActionSuggestionOptions& options,
+ ActionsSuggestionsResponse* response,
+ std::unique_ptr<tflite::Interpreter>* interpreter) const;
+
+ // Creates options for annotation of a message.
+ AnnotationOptions AnnotationOptionsForMessage(
+ const ConversationMessage& message) const;
+
+ void SuggestActionsFromAnnotations(
+ const Conversation& conversation,
+ std::vector<ActionSuggestion>* actions) const;
+
+ void SuggestActionsFromAnnotation(
+ const int message_index, const ActionSuggestionAnnotation& annotation,
+ std::vector<ActionSuggestion>* actions) const;
+
+ // Run annotator on the messages of a conversation.
+ Conversation AnnotateConversation(const Conversation& conversation,
+ const Annotator* annotator) const;
+
+ // Deduplicates equivalent annotations - annotations that have the same type
+ // and same span text.
+ // Returns the indices of the deduplicated annotations.
+ std::vector<int> DeduplicateAnnotations(
+ const std::vector<ActionSuggestionAnnotation>& annotations) const;
+
+ bool SuggestActionsFromLua(
+ const Conversation& conversation,
+ const TfLiteModelExecutor* model_executor,
+ const tflite::Interpreter* interpreter,
+ const reflection::Schema* annotation_entity_data_schema,
+ std::vector<ActionSuggestion>* actions) const;
+
+ bool GatherActionsSuggestions(const Conversation& conversation,
+ const Annotator* annotator,
+ const ActionSuggestionOptions& options,
+ ActionsSuggestionsResponse* response) const;
+
+ std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
+
+ // Tensorflow Lite models.
+ std::unique_ptr<const TfLiteModelExecutor> model_executor_;
+
+ // Regex rules model.
+ std::unique_ptr<RegexActions> regex_actions_;
+
+ // The grammar rules model.
+ std::unique_ptr<GrammarActions> grammar_actions_;
+
+ std::unique_ptr<UniLib> owned_unilib_;
+ const UniLib* unilib_;
+
+ // Locales supported by the model.
+ std::vector<Locale> locales_;
+
+ // Annotation entities used by the model.
+ std::unordered_set<std::string> annotation_entity_types_;
+
+ // Builder for creating extra data.
+ const reflection::Schema* entity_data_schema_;
+ std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
+ std::unique_ptr<ActionsSuggestionsRanker> ranker_;
+
+ std::string lua_bytecode_;
+
+ // Triggering preconditions. These parameters can be backed by the model and
+ // (partially) be provided by flags.
+ TriggeringPreconditionsT preconditions_;
+ std::string triggering_preconditions_overlay_buffer_;
+ const TriggeringPreconditions* triggering_preconditions_overlay_;
+
+ // Low confidence input ngram classifier.
+ std::unique_ptr<const NGramModel> ngram_model_;
+};
+
+// Interprets the buffer as a Model flatbuffer and returns it for reading.
+const ActionsModel* ViewActionsModel(const void* buffer, int size);
+
+// Opens model from given path and runs a function, passing the loaded Model
+// flatbuffer as an argument.
+//
+// This is mainly useful if we don't want to pay the cost for the model
+// initialization because we'll be only reading some flatbuffer values from the
+// file.
+template <typename ReturnType, typename Func>
+ReturnType VisitActionsModel(const std::string& path, Func function) {
+  ScopedMmap mmap(path);
+  if (!mmap.handle().ok()) {
+    return function(/*model=*/nullptr);  // Missing `return` would fall through and read an invalid mmap.
+  }
+  const ActionsModel* model =
+      ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
+  return function(model);
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
diff --git a/native/actions/actions_jni.cc b/native/actions/actions_jni.cc
new file mode 100644
index 0000000..7dd0169
--- /dev/null
+++ b/native/actions/actions_jni.cc
@@ -0,0 +1,517 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// JNI wrapper for actions.
+
+#include "actions/actions_jni.h"
+
+#include <jni.h>
+
+#include <map>
+#include <type_traits>
+#include <vector>
+
+#include "actions/actions-suggestions.h"
+#include "annotator/annotator.h"
+#include "annotator/annotator_jni_common.h"
+#include "utils/base/integral_types.h"
+#include "utils/base/statusor.h"
+#include "utils/intents/intent-generator.h"
+#include "utils/intents/jni.h"
+#include "utils/intents/remote-action-template.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-cache.h"
+#include "utils/java/jni-helper.h"
+#include "utils/java/string_utils.h"
+#include "utils/memory/mmap.h"
+
+using libtextclassifier3::ActionsSuggestions;
+using libtextclassifier3::ActionsSuggestionsResponse;
+using libtextclassifier3::ActionSuggestion;
+using libtextclassifier3::ActionSuggestionOptions;
+using libtextclassifier3::Annotator;
+using libtextclassifier3::Conversation;
+using libtextclassifier3::IntentGenerator;
+using libtextclassifier3::ScopedLocalRef;
+using libtextclassifier3::StatusOr;
+using libtextclassifier3::ToStlString;
+
+// When using the Java's ICU, UniLib needs to be instantiated with a JavaVM
+// pointer from JNI. When using a standard ICU the pointer is not needed and the
+// objects are instantiated implicitly.
+#ifdef TC3_UNILIB_JAVAICU
+using libtextclassifier3::UniLib;
+#endif
+
+namespace libtextclassifier3 {
+
+namespace {
+
+// Cached state for model inference.
+// Keeps a jni cache, intent generator and model instance so that they don't
+// have to be recreated for each call.
+class ActionsSuggestionsJniContext {
+ public:
+  // Builds a fully initialized context, or returns nullptr if the JNI cache,
+  // the model, the intent generator or the template handler could not be
+  // created. Ownership of the returned raw pointer passes to the caller; it
+  // is deleted in nativeCloseActionsModel.
+  static ActionsSuggestionsJniContext* Create(
+      const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
+      std::unique_ptr<ActionsSuggestions> model) {
+    if (jni_cache == nullptr || model == nullptr) {
+      return nullptr;
+    }
+    std::unique_ptr<IntentGenerator> intent_generator =
+        IntentGenerator::Create(model->model()->android_intent_options(),
+                                model->model()->resources(), jni_cache);
+    std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
+        libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
+
+    if (intent_generator == nullptr || template_handler == nullptr) {
+      return nullptr;
+    }
+
+    return new ActionsSuggestionsJniContext(jni_cache, std::move(model),
+                                            std::move(intent_generator),
+                                            std::move(template_handler));
+  }
+
+  std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const {
+    return jni_cache_;
+  }
+
+  // Accessors return non-owning pointers; the context keeps ownership.
+  ActionsSuggestions* model() const { return model_.get(); }
+
+  IntentGenerator* intent_generator() const { return intent_generator_.get(); }
+
+  RemoteActionTemplatesHandler* template_handler() const {
+    return template_handler_.get();
+  }
+
+ private:
+  ActionsSuggestionsJniContext(
+      const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
+      std::unique_ptr<ActionsSuggestions> model,
+      std::unique_ptr<IntentGenerator> intent_generator,
+      std::unique_ptr<RemoteActionTemplatesHandler> template_handler)
+      : jni_cache_(jni_cache),
+        model_(std::move(model)),
+        intent_generator_(std::move(intent_generator)),
+        template_handler_(std::move(template_handler)) {}
+
+  std::shared_ptr<libtextclassifier3::JniCache> jni_cache_;
+  std::unique_ptr<ActionsSuggestions> model_;
+  std::unique_ptr<IntentGenerator> intent_generator_;
+  std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
+};
+
+// Converts the Java ActionSuggestionOptions object into native options.
+// NOTE(review): `env` and `joptions` are currently ignored — the defaults are
+// always returned.
+ActionSuggestionOptions FromJavaActionSuggestionOptions(JNIEnv* env,
+                                                        jobject joptions) {
+  ActionSuggestionOptions options = ActionSuggestionOptions::Default();
+  return options;
+}
+
+// Converts the native action suggestions into a Java ActionSuggestion[].
+// For each suggestion this fills reply text, type, score, entity-data extras,
+// the raw serialized entity data and — when `generate_intents` is set —
+// remote action templates produced by the intent generator.
+StatusOr<ScopedLocalRef<jobjectArray>> ActionSuggestionsToJObjectArray(
+    JNIEnv* env, const ActionsSuggestionsJniContext* context,
+    jobject app_context,
+    const reflection::Schema* annotations_entity_data_schema,
+    const std::vector<ActionSuggestion>& action_result,
+    const Conversation& conversation, const jstring device_locales,
+    const bool generate_intents) {
+  auto status_or_result_class = JniHelper::FindClass(
+      env, TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ActionSuggestion");
+  if (!status_or_result_class.ok()) {
+    TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class.";
+    return status_or_result_class.status();
+  }
+  ScopedLocalRef<jclass> result_class =
+      std::move(status_or_result_class.ValueOrDie());
+
+  // Constructor signature: (String replyText, String type, float score,
+  // NamedVariant[] extras, byte[] serializedEntityData,
+  // RemoteActionTemplate[] remoteActionTemplates).
+  TC3_ASSIGN_OR_RETURN(
+      const jmethodID result_class_constructor,
+      JniHelper::GetMethodID(
+          env, result_class.get(), "<init>",
+          "(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH
+              TC3_NAMED_VARIANT_CLASS_NAME_STR
+          ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR
+          ";)V"));
+  TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jobjectArray> results,
+                       JniHelper::NewObjectArray(env, action_result.size(),
+                                                 result_class.get(), nullptr));
+  for (int i = 0; i < action_result.size(); i++) {
+    ScopedLocalRef<jobjectArray> extras;
+    const reflection::Schema* actions_entity_data_schema =
+        context->model()->entity_data_schema();
+    // Entity data is only expanded into named variants when the model
+    // declares a schema for it.
+    if (actions_entity_data_schema != nullptr &&
+        !action_result[i].serialized_entity_data.empty()) {
+      TC3_ASSIGN_OR_RETURN(
+          extras, context->template_handler()->EntityDataAsNamedVariantArray(
+                      actions_entity_data_schema,
+                      action_result[i].serialized_entity_data));
+    }
+
+    // The raw serialized entity data is always passed through as a byte[].
+    ScopedLocalRef<jbyteArray> serialized_entity_data;
+    if (!action_result[i].serialized_entity_data.empty()) {
+      TC3_ASSIGN_OR_RETURN(
+          serialized_entity_data,
+          JniHelper::NewByteArray(
+              env, action_result[i].serialized_entity_data.size()));
+      env->SetByteArrayRegion(
+          serialized_entity_data.get(), 0,
+          action_result[i].serialized_entity_data.size(),
+          reinterpret_cast<const jbyte*>(
+              action_result[i].serialized_entity_data.data()));
+    }
+
+    ScopedLocalRef<jobjectArray> remote_action_templates_result;
+    if (generate_intents) {
+      std::vector<RemoteActionTemplate> remote_action_templates;
+      // Intent generation failure is non-fatal: the suggestion is still
+      // returned, just without remote action templates.
+      if (context->intent_generator()->GenerateIntents(
+              device_locales, action_result[i], conversation, app_context,
+              /*annotations_entity_data_schema=*/annotations_entity_data_schema,
+              /*actions_entity_data_schema=*/actions_entity_data_schema,
+              &remote_action_templates)) {
+        TC3_ASSIGN_OR_RETURN(
+            remote_action_templates_result,
+            context->template_handler()->RemoteActionTemplatesToJObjectArray(
+                remote_action_templates));
+      }
+    }
+
+    TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> reply,
+                         context->jni_cache()->ConvertToJavaString(
+                             action_result[i].response_text));
+
+    TC3_ASSIGN_OR_RETURN(
+        ScopedLocalRef<jstring> action_type,
+        JniHelper::NewStringUTF(env, action_result[i].type.c_str()));
+
+    TC3_ASSIGN_OR_RETURN(
+        ScopedLocalRef<jobject> result,
+        JniHelper::NewObject(env, result_class.get(), result_class_constructor,
+                             reply.get(), action_type.get(),
+                             static_cast<jfloat>(action_result[i].score),
+                             extras.get(), serialized_entity_data.get(),
+                             remote_action_templates_result.get()));
+    env->SetObjectArrayElement(results.get(), i, result.get());
+  }
+  return results;
+}
+
+// Reads a Java ConversationMessage into its native counterpart by calling
+// the message's getter methods one by one.
+StatusOr<ConversationMessage> FromJavaConversationMessage(JNIEnv* env,
+                                                          jobject jmessage) {
+  if (!jmessage) {
+    // NOTE(review): a null message yields a default-constructed StatusOr;
+    // presumably that surfaces as an error status to the caller — confirm
+    // against the StatusOr default constructor.
+    return {};
+  }
+
+  TC3_ASSIGN_OR_RETURN(
+      ScopedLocalRef<jclass> message_class,
+      JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
+                           "$ConversationMessage"));
+  // .getText()
+  TC3_ASSIGN_OR_RETURN(
+      jmethodID get_text_method,
+      JniHelper::GetMethodID(env, message_class.get(), "getText",
+                             "()Ljava/lang/String;"));
+  TC3_ASSIGN_OR_RETURN(
+      ScopedLocalRef<jstring> text,
+      JniHelper::CallObjectMethod<jstring>(env, jmessage, get_text_method));
+
+  // .getUserId()
+  TC3_ASSIGN_OR_RETURN(
+      jmethodID get_user_id_method,
+      JniHelper::GetMethodID(env, message_class.get(), "getUserId", "()I"));
+  TC3_ASSIGN_OR_RETURN(int32 user_id, JniHelper::CallIntMethod(
+                                          env, jmessage, get_user_id_method));
+
+  // .getReferenceTimeMsUtc()
+  TC3_ASSIGN_OR_RETURN(jmethodID get_reference_time_method,
+                       JniHelper::GetMethodID(env, message_class.get(),
+                                              "getReferenceTimeMsUtc", "()J"));
+  TC3_ASSIGN_OR_RETURN(
+      int64 reference_time,
+      JniHelper::CallLongMethod(env, jmessage, get_reference_time_method));
+
+  // .getReferenceTimezone()
+  TC3_ASSIGN_OR_RETURN(
+      jmethodID get_reference_timezone_method,
+      JniHelper::GetMethodID(env, message_class.get(), "getReferenceTimezone",
+                             "()Ljava/lang/String;"));
+  TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> reference_timezone,
+                       JniHelper::CallObjectMethod<jstring>(
+                           env, jmessage, get_reference_timezone_method));
+
+  // .getDetectedTextLanguageTags()
+  TC3_ASSIGN_OR_RETURN(jmethodID get_detected_text_language_tags_method,
+                       JniHelper::GetMethodID(env, message_class.get(),
+                                              "getDetectedTextLanguageTags",
+                                              "()Ljava/lang/String;"));
+  TC3_ASSIGN_OR_RETURN(
+      ScopedLocalRef<jstring> detected_text_language_tags,
+      JniHelper::CallObjectMethod<jstring>(
+          env, jmessage, get_detected_text_language_tags_method));
+
+  // Copy the Java strings into the native message last, so any JNI failure
+  // above aborts before partial construction.
+  ConversationMessage message;
+  TC3_ASSIGN_OR_RETURN(message.text, ToStlString(env, text.get()));
+  message.user_id = user_id;
+  message.reference_time_ms_utc = reference_time;
+  TC3_ASSIGN_OR_RETURN(message.reference_timezone,
+                       ToStlString(env, reference_timezone.get()));
+  TC3_ASSIGN_OR_RETURN(message.detected_text_language_tags,
+                       ToStlString(env, detected_text_language_tags.get()));
+  return message;
+}
+
+// Converts a Java Conversation into a native Conversation by converting each
+// contained ConversationMessage. Fails with UNKNOWN on a null conversation.
+StatusOr<Conversation> FromJavaConversation(JNIEnv* env,
+                                            jobject jconversation) {
+  if (!jconversation) {
+    return {Status::UNKNOWN};
+  }
+
+  TC3_ASSIGN_OR_RETURN(
+      ScopedLocalRef<jclass> conversation_class,
+      JniHelper::FindClass(
+          env, TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$Conversation"));
+
+  TC3_ASSIGN_OR_RETURN(
+      jmethodID get_conversation_messages_method,
+      JniHelper::GetMethodID(env, conversation_class.get(),
+                             "getConversationMessages",
+                             "()[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
+                             "$ConversationMessage;"));
+  TC3_ASSIGN_OR_RETURN(
+      ScopedLocalRef<jobjectArray> jmessages,
+      JniHelper::CallObjectMethod<jobjectArray>(
+          env, jconversation, get_conversation_messages_method));
+
+  // Any message that fails to convert aborts the whole conversation.
+  std::vector<ConversationMessage> messages;
+  const int size = env->GetArrayLength(jmessages.get());
+  for (int i = 0; i < size; i++) {
+    TC3_ASSIGN_OR_RETURN(
+        ScopedLocalRef<jobject> jmessage,
+        JniHelper::GetObjectArrayElement<jobject>(env, jmessages.get(), i));
+    TC3_ASSIGN_OR_RETURN(ConversationMessage message,
+                         FromJavaConversationMessage(env, jmessage.get()));
+    messages.push_back(message);
+  }
+  Conversation conversation;
+  conversation.messages = messages;
+  return conversation;
+}
+
+// Reads the supported-locales string from a memory-mapped model file.
+// Returns the empty string when the mapping or the model is invalid, or when
+// the model declares no locales.
+StatusOr<ScopedLocalRef<jstring>> GetLocalesFromMmap(
+    JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+  if (!mmap->handle().ok()) {
+    return JniHelper::NewStringUTF(env, "");
+  }
+  const ActionsModel* model = libtextclassifier3::ViewActionsModel(
+      mmap->handle().start(), mmap->handle().num_bytes());
+  if (!model || !model->locales()) {
+    return JniHelper::NewStringUTF(env, "");
+  }
+  return JniHelper::NewStringUTF(env, model->locales()->c_str());
+}
+
+// Reads the model version from a memory-mapped model file; returns 0 when
+// the mapping or the model is invalid.
+jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+  if (!mmap->handle().ok()) {
+    return 0;
+  }
+  const ActionsModel* model = libtextclassifier3::ViewActionsModel(
+      mmap->handle().start(), mmap->handle().num_bytes());
+  if (!model) {
+    return 0;
+  }
+  return model->version();
+}
+
+// Reads the model name from a memory-mapped model file. Returns the empty
+// string when the mapping or the model is invalid, or when no name is set.
+StatusOr<ScopedLocalRef<jstring>> GetNameFromMmap(
+    JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+  if (!mmap->handle().ok()) {
+    return JniHelper::NewStringUTF(env, "");
+  }
+  const ActionsModel* model = libtextclassifier3::ViewActionsModel(
+      mmap->handle().start(), mmap->handle().num_bytes());
+  if (!model || !model->name()) {
+    return JniHelper::NewStringUTF(env, "");
+  }
+  return JniHelper::NewStringUTF(env, model->name()->c_str());
+}
+} // namespace
+} // namespace libtextclassifier3
+
+using libtextclassifier3::ActionsSuggestionsJniContext;
+using libtextclassifier3::ActionSuggestionsToJObjectArray;
+using libtextclassifier3::FromJavaActionSuggestionOptions;
+using libtextclassifier3::FromJavaConversation;
+
+// Creates a new actions model from a whole-file descriptor. Returns the
+// native context pointer as a jlong (0 on failure). The optional
+// `serialized_preconditions` byte array overlays the model's triggering
+// preconditions.
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
+(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions) {
+  std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
+      libtextclassifier3::JniCache::Create(env);
+  std::string preconditions;
+  if (serialized_preconditions != nullptr &&
+      !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
+                                              &preconditions)) {
+    TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
+    return 0;
+  }
+// With Java ICU, UniLib needs the JniCache; otherwise it is created
+// implicitly by the model.
+#ifdef TC3_UNILIB_JAVAICU
+  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
+      jni_cache,
+      ActionsSuggestions::FromFileDescriptor(
+          fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions)));
+#else
+  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
+      jni_cache, ActionsSuggestions::FromFileDescriptor(fd, /*unilib=*/nullptr,
+                                                        preconditions)));
+#endif  // TC3_UNILIB_JAVAICU
+}
+
+// Creates a new actions model from a file path. Returns the native context
+// pointer as a jlong (0 on failure). See nativeNewActionsModel for the
+// meaning of `serialized_preconditions`.
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
+(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions) {
+  std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
+      libtextclassifier3::JniCache::Create(env);
+  TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path));
+  std::string preconditions;
+  if (serialized_preconditions != nullptr &&
+      !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
+                                              &preconditions)) {
+    TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
+    return 0;
+  }
+#ifdef TC3_UNILIB_JAVAICU
+  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
+      jni_cache, ActionsSuggestions::FromPath(
+                     path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
+                     preconditions)));
+#else
+  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
+      jni_cache, ActionsSuggestions::FromPath(path_str, /*unilib=*/nullptr,
+                                              preconditions)));
+#endif  // TC3_UNILIB_JAVAICU
+}
+
+// Creates a new actions model from a region [offset, offset + size) of a
+// file descriptor. Returns the native context pointer as a jlong (0 on
+// failure). See nativeNewActionsModel for `serialized_preconditions`.
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size,
+ jbyteArray serialized_preconditions) {
+  std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
+      libtextclassifier3::JniCache::Create(env);
+  std::string preconditions;
+  if (serialized_preconditions != nullptr &&
+      !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
+                                              &preconditions)) {
+    TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
+    return 0;
+  }
+#ifdef TC3_UNILIB_JAVAICU
+  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
+      jni_cache,
+      ActionsSuggestions::FromFileDescriptor(
+          fd, offset, size, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
+          preconditions)));
+#else
+  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
+      jni_cache, ActionsSuggestions::FromFileDescriptor(
+                     fd, offset, size, /*unilib=*/nullptr, preconditions)));
+#endif  // TC3_UNILIB_JAVAICU
+}
+
+// Runs action suggestion for the given Java conversation and converts the
+// results into a Java ActionSuggestion[]. `annotatorPtr` may be 0; in that
+// case no annotator-backed entity data schema is used. Returns nullptr on a
+// null context or on any conversion failure.
+TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
+(JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions,
+ jlong annotatorPtr, jobject app_context, jstring device_locales,
+ jboolean generate_intents) {
+  if (!ptr) {
+    return nullptr;
+  }
+  TC3_ASSIGN_OR_RETURN_NULL(const Conversation conversation,
+                            FromJavaConversation(env, jconversation));
+  const ActionSuggestionOptions options =
+      FromJavaActionSuggestionOptions(env, joptions);
+  const ActionsSuggestionsJniContext* context =
+      reinterpret_cast<ActionsSuggestionsJniContext*>(ptr);
+  const Annotator* annotator = reinterpret_cast<Annotator*>(annotatorPtr);
+
+  const ActionsSuggestionsResponse response =
+      context->model()->SuggestActions(conversation, annotator, options);
+
+  // Renamed local (was "anntotations_entity_data_schema" — typo).
+  const reflection::Schema* annotations_entity_data_schema =
+      annotator ? annotator->entity_data_schema() : nullptr;
+
+  TC3_ASSIGN_OR_RETURN_NULL(
+      ScopedLocalRef<jobjectArray> result,
+      ActionSuggestionsToJObjectArray(
+          env, context, app_context, annotations_entity_data_schema,
+          response.actions, conversation, device_locales, generate_intents));
+  return result.release();
+}
+
+// Destroys the native context created by one of the nativeNewActionsModel*
+// entry points. A zero handle is harmless (deleting nullptr is a no-op).
+TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
+(JNIEnv* env, jobject clazz, jlong model_ptr) {
+  const ActionsSuggestionsJniContext* context =
+      reinterpret_cast<ActionsSuggestionsJniContext*>(model_ptr);
+  delete context;
+}
+
+// Returns the locales string of the model behind `fd` without fully loading
+// the model (mmaps the file and reads only the flatbuffer field).
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
+(JNIEnv* env, jobject clazz, jint fd) {
+  const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+      new libtextclassifier3::ScopedMmap(fd));
+  TC3_ASSIGN_OR_RETURN_NULL(
+      ScopedLocalRef<jstring> result,
+      libtextclassifier3::GetLocalesFromMmap(env, mmap.get()));
+  return result.release();
+}
+
+// Same as nativeGetLocales, but reads the model from a region
+// [offset, offset + size) of the file descriptor.
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocalesWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
+  const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+      new libtextclassifier3::ScopedMmap(fd, offset, size));
+  TC3_ASSIGN_OR_RETURN_NULL(
+      ScopedLocalRef<jstring> result,
+      libtextclassifier3::GetLocalesFromMmap(env, mmap.get()));
+  return result.release();
+}
+
+// Returns the name of the model behind `fd` without fully loading the model.
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
+(JNIEnv* env, jobject clazz, jint fd) {
+  const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+      new libtextclassifier3::ScopedMmap(fd));
+  TC3_ASSIGN_OR_RETURN_NULL(
+      ScopedLocalRef<jstring> result,
+      libtextclassifier3::GetNameFromMmap(env, mmap.get()));
+  return result.release();
+}
+
+// Same as nativeGetName, but reads the model from a region
+// [offset, offset + size) of the file descriptor.
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetNameWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
+  const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+      new libtextclassifier3::ScopedMmap(fd, offset, size));
+  TC3_ASSIGN_OR_RETURN_NULL(
+      ScopedLocalRef<jstring> result,
+      libtextclassifier3::GetNameFromMmap(env, mmap.get()));
+  return result.release();
+}
+
+// Returns the version of the model behind `fd`, or 0 if it cannot be read.
+TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd) {
+  const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+      new libtextclassifier3::ScopedMmap(fd));
+  return libtextclassifier3::GetVersionFromMmap(env, mmap.get());
+}
+
+// Same as nativeGetVersion, but reads the model from a region
+// [offset, offset + size) of the file descriptor.
+TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
+  const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+      new libtextclassifier3::ScopedMmap(fd, offset, size));
+  return libtextclassifier3::GetVersionFromMmap(env, mmap.get());
+}
diff --git a/native/actions/actions_jni.h b/native/actions/actions_jni.h
new file mode 100644
index 0000000..276e361
--- /dev/null
+++ b/native/actions/actions_jni.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_JNI_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_JNI_H_
+
+#include <jni.h>
+#include <string>
+#include "utils/java/jni-base.h"
+
+#ifndef TC3_ACTIONS_CLASS_NAME
+#define TC3_ACTIONS_CLASS_NAME ActionsSuggestionsModel
+#endif
+
+#define TC3_ACTIONS_CLASS_NAME_STR TC3_ADD_QUOTES(TC3_ACTIONS_CLASS_NAME)
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
+(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions);
+
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
+(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions);
+
+TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size,
+ jbyteArray serialized_preconditions);
+
+TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
+(JNIEnv* env, jobject thiz, jlong ptr, jobject jconversation, jobject joptions,
+ jlong annotatorPtr, jobject app_context, jstring device_locales,
+ jboolean generate_intents);
+
+TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
+(JNIEnv* env, jobject thiz, jlong ptr);
+
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocalesWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
+
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetNameWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
+
+TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_JNI_H_
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
new file mode 100755
index 0000000..251610e
--- /dev/null
+++ b/native/actions/actions_model.fbs
@@ -0,0 +1,556 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+include "actions/actions-entity-data.fbs";
+include "annotator/model.fbs";
+include "utils/codepoint-range.fbs";
+include "utils/flatbuffers.fbs";
+include "utils/grammar/rules.fbs";
+include "utils/intents/intent-config.fbs";
+include "utils/normalization.fbs";
+include "utils/resources.fbs";
+include "utils/tokenizer.fbs";
+include "utils/zlib/buffer.fbs";
+
+file_identifier "TC3A";
+
+// Prediction type for a multi-task model.
+namespace libtextclassifier3;
+enum PredictionType : int {
+ UNSUPPORTED = 0,
+ NEXT_MESSAGE_PREDICTION = 1,
+ INTENT_TRIGGERING = 2,
+ ENTITY_ANNOTATION = 3,
+}
+
+// Prediction metadata for an arbitrary task.
+namespace libtextclassifier3;
+table PredictionMetadata {
+ prediction_type:PredictionType;
+ task_spec:ActionSuggestionSpec;
+ output_suggestions:int;
+ output_suggestions_scores:int;
+ output_suggestions_spans:int;
+}
+
+namespace libtextclassifier3.TensorflowLiteModelSpec_;
+table InputNameIndexEntry {
+ key:string (key, shared);
+ value:int;
+}
+
+// TensorFlow Lite model for suggesting actions.
+namespace libtextclassifier3;
+table TensorflowLiteModelSpec {
+ // TensorFlow Lite model for suggesting actions.
+ tflite_model:[ubyte] (force_align: 16);
+
+ // Input specification.
+ // (num messages,) int32 tensor, the user id per message.
+ input_user_id:int = 0;
+
+ // (num messages,) string tensor, each message of the conversation.
+ input_context:int = 1;
+
+ // int, the number of messages in the conversation.
+ input_context_length:int = 2;
+
+ // (num messages,) float tensor, the time difference in seconds of the
+ // messages in the conversation.
+ input_time_diffs:int = 3;
+
+ // int, the number of smart replies to produce.
+ input_num_suggestions:int = 4;
+
+ // float, the output diversification distance parameter.
+ reserved_7:int (deprecated);
+
+ // float, the empirical probability factor parameter.
+ reserved_8:int (deprecated);
+
+ // float, the confidence threshold.
+ reserved_9:int (deprecated);
+
+ // Input port for hashed and embedded tokens, a (num messages, max tokens,
+ // embedding size) float tensor specifying the embeddings of each token of
+ // each message in the conversation.
+ input_token_embeddings:int = -1;
+
+ // Input port for the number of tokens per message.
+ // (num messages) int32 tensor specifying the number of tokens in each message
+ // in the conversation.
+ input_num_tokens:int = -1;
+
+ // Output specification.
+ output_replies:int = 0;
+
+ output_replies_scores:int = 1;
+ output_sensitive_topic_score:int = 3;
+ output_triggering_score:int = 4;
+ output_actions_scores:int = 5;
+
+ // Model setup.
+ // When true, the inputs are resized to the concrete input sizes before
+ // inference otherwise, it's assumed that the model has the correct input
+ // shapes set.
+ resize_inputs:bool = false;
+
+ // Input port for the hashed, embedded and flattened/concatenated tokens.
+ // A (max tokens, embedding_size) float tensor specifying the embeddings of
+ // each token.
+ input_flattened_token_embeddings:int = -1;
+
+ // Generalized output specification that handles arbitrary number of
+ // prediction tasks.
+ prediction_metadata:[PredictionMetadata];
+
+ // Map of additional input tensor name to its index.
+ input_name_index:[TensorflowLiteModelSpec_.InputNameIndexEntry];
+}
+
+// Configuration for the tokenizer.
+namespace libtextclassifier3;
+table ActionsTokenizerOptions {
+ type:TokenizationType = INTERNAL_TOKENIZER;
+
+ // If true, white space tokens will be kept when using the icu tokenizer.
+ icu_preserve_whitespace_tokens:bool = false;
+
+ // Codepoint ranges that determine what role the different codepoints play
+ // during tokenized. The ranges must not overlap.
+ tokenization_codepoint_config:[TokenizationCodepointRange];
+
+ // A set of codepoint ranges to use in the mixed tokenization mode to identify
+ // stretches of tokens to re-tokenize using the internal tokenizer.
+ internal_tokenizer_codepoint_ranges:[CodepointRange];
+
+ // If true, tokens will be also split when the codepoint's script_id changes
+ // as defined in TokenizationCodepointRange.
+ tokenize_on_script_change:bool = false;
+}
+
+// Configuration for the feature processor.
+namespace libtextclassifier3;
+table ActionsTokenFeatureProcessorOptions {
+ // Tokenizer options.
+ tokenizer_options:ActionsTokenizerOptions;
+
+ // Serialized TensorFlow Lite model with weights for the token embeddings.
+ embedding_model:[ubyte] (force_align: 16);
+
+ // Size of the embedding.
+ embedding_size:int = -1;
+
+ // Number of bits for quantization for embeddings.
+ embedding_quantization_bits:int = 8;
+
+ // Number of buckets used for hashing charactergrams.
+ num_buckets:int = -1;
+
+ // Orders of charactergrams to extract, e.g. 2 means character bigrams, 3
+ // character trigrams etc.
+ chargram_orders:[int];
+
+ // Whether to extract the token case feature.
+ extract_case_feature:bool;
+
+ // If true, will use the unicode-aware functionality for extracting features.
+ unicode_aware_features:bool;
+
+ // Regexp features to extract.
+ regexp_features:[string];
+
+ // Whether to remap digits to a single number.
+ remap_digits:bool;
+
+ // Whether to lowercase all tokens.
+ lowercase_tokens:bool;
+
+ // Maximum length of a word.
+ max_token_length:int = 20;
+
+ // The `max_num_tokens_per_message` and `min_num_tokens_per_message` are
+ // applied when tokens are embedded per message.
+ // If set and the number of tokens of a message is bigger than this limit,
+ // tokens at the beginning of the message are dropped to fit the limit.
+ max_num_tokens_per_message:int = -1;
+
+ // If set, the tokens of each message will be padded to this fixed number of
+ // tokens.
+ min_num_tokens_per_message:int = -1;
+
+ // If set and the total number of concatenated tokens is bigger than this
+ // limit, tokens at the start of the conversation are dropped.
+ max_num_total_tokens:int = -1;
+
+ // If set and the total number of concatenated tokens is smaller than this
+ // limit, the conversation is padded with padding tokens.
+ min_num_total_tokens:int = -1;
+
+ // Id that is used as encoding of the padding token.
+ padding_token_id:int = 0;
+
+ // Id that is used as encoding of the start of message token.
+ start_token_id:int = 1;
+
+ // Id that is used as encoding of the end of message token.
+ end_token_id:int = 2;
+}
+
+// N-Gram based linear regression model.
+namespace libtextclassifier3;
+table NGramLinearRegressionModel {
+ // A flat list of all the hashed n-grams concatenated back to back. Elements
+ // should only ever be accessed via the offset table below.
+ hashed_ngram_tokens:[uint];
+
+ // Offsets to the start of the n-grams in hashed_ngram_tokens. The last
+ // element in this array is the length of hashed_ngrams to make it easier to
+ // compute n-gram lengths.
+ ngram_start_offsets:[ushort];
+
+ // Weights of the n-grams.
+ ngram_weights:[float];
+
+ // The default weight assigned to n-grams that weren't matched.
+ default_token_weight:float;
+
+ // Maximum n-gram length to consider when calculating the denominator.
+ // This should usually be the same as max_ngram_length but can diverge
+ // if additional (longer) n-grams are added to a model as part of a minor
+ // update.
+ max_denom_ngram_length:int;
+
+ // If non-zero, the order of the skip-gram to match.
+ max_skips:int;
+
+ // The threshold above which the model output is considered positive.
+ threshold:float;
+
+ // Model specific tokenizer options.
+ // If not specified, will reuse the feature processor tokenizer.
+ tokenizer_options:ActionsTokenizerOptions;
+}
+
+namespace libtextclassifier3;
+table TriggeringPreconditions {
+ // Lower bound thresholds for the smart reply model prediction output.
+ min_smart_reply_triggering_score:float;
+
+ // Maximum sensitive score for which actions and smart replies are shown.
+ max_sensitive_topic_score:float = 1;
+
+ // Whether to suppress all model output when a conversation is classified as
+ // sensitive.
+ suppress_on_sensitive_topic:bool = true;
+
+ // Thresholds on the model prediction input.
+ // The minimal length of input to consider for prediction.
+ min_input_length:int = 0;
+
+ // The maximal length of input to consider for prediction, -1 if unbounded.
+ max_input_length:int = -1;
+
+ // Minimal fraction of messages in the input conversation that need to match
+ // a locale that the model can handle.
+ min_locale_match_fraction:float = 0.75;
+
+ handle_missing_locale_as_supported:bool = false;
+ handle_unknown_locale_as_supported:bool = false;
+
+ // Filter input with low-confidence triggers.
+ suppress_on_low_confidence_input:bool = true;
+
+ // Same as low_confidence_rules in ActionsModel.
+ // NOTE: Only fill this when the TriggeringPreconditions are pushed separately
+ // as a flag value (i.e. as overlay).
+ low_confidence_rules:RulesModel;
+
+ reserved_11:float (deprecated);
+ reserved_12:float (deprecated);
+ reserved_13:float (deprecated);
+
+ // Smart reply thresholds.
+ min_reply_score_threshold:float = 0;
+}
+
+namespace libtextclassifier3;
+table ActionSuggestionSpec {
+ // Type of the action suggestion.
+ type:string (shared);
+
+ // Text of a smart reply action.
+ response_text:string (shared);
+
+ // Score.
+ score:float;
+
+ // Additional entity information.
+ serialized_entity_data:string (shared);
+
+ // Priority score used for internal conflict resolution.
+ priority_score:float = 0;
+
+ entity_data:ActionsEntityData;
+}
+
+// Options to specify triggering behaviour per action class.
+namespace libtextclassifier3;
+table ActionTypeOptions {
+ // The name of the predicted action.
+ name:string (shared);
+
+ // Triggering behaviour.
+ // Whether the action class is considered in the model output or not.
+ enabled:bool = true;
+
+ // Minimal output score threshold.
+ min_triggering_score:float = 0;
+
+ // The action to trigger.
+ action:ActionSuggestionSpec;
+}
+
+namespace libtextclassifier3.AnnotationActionsSpec_;
+table AnnotationMapping {
+ // The annotation collection.
+ annotation_collection:string (shared);
+
+ // The action name to use.
+ action:ActionSuggestionSpec;
+
+ // Whether to use the score of the annotation as the action score.
+ use_annotation_score:bool = true;
+
+ // Minimum threshold for the annotation score for filtering.
+ min_annotation_score:float;
+
+ // If set, the text of the annotation will be used to set a field in the
+ // action entity data.
+ entity_field:FlatbufferFieldPath;
+
+ // If set, normalization to apply to the annotation text.
+ normalization_options:NormalizationOptions;
+}
+
+// Configuration for actions based on annotations.
+namespace libtextclassifier3;
+table AnnotationActionsSpec {
+ annotation_mapping:[AnnotationActionsSpec_.AnnotationMapping];
+
+ // Whether to deduplicate annotations by type and text prior to generating
+ // actions.
+ deduplicate_annotations:bool = true;
+
+ // Annotation usecase to specify for text annotation.
+ annotation_usecase:AnnotationUsecase = ANNOTATION_USECASE_SMART;
+
+ // Maximum number of recent messages to consider from any person.
+ // We consider at most `max_history_from_any_person` many recent messages if
+ // they were received from different users or at most the maximum of this and
+ // `max_history_from_last_person` if they are all from the same user.
+ max_history_from_any_person:int = 1;
+
+ // Maximum number of recent messages to consider from the last person.
+ max_history_from_last_person:int = 1;
+
+ // Whether to include messages from the local user.
+ include_local_user_messages:bool = false;
+
+ // Whether to only consider messages up to the last one sent by the local
+ // user.
+ only_until_last_sent:bool = true;
+
+ // If true, annotator would populate serialized_entity_data in the results.
+ is_serialized_entity_data_enabled:bool = true;
+}
+
+// Ranking options.
+namespace libtextclassifier3;
+table RankingOptions {
+ // When true, actions suggestions are deduplicated by `type`, `response_text`
+ // and associated annotations, keeping the higher scoring actions.
+ deduplicate_suggestions:bool = true;
+
+ // When true, actions are deduplicated by the span they are referring to.
+ deduplicate_suggestions_by_span:bool = true;
+
+ // Optional script to run for ranking and filtering the action suggestions.
+ // The following global variables are available to the script:
+ // * input: (optionally deduplicated) action suggestions, via the `actions`
+ // global
+ // * output: indices of the actions to keep in the provided order.
+ lua_ranking_script:string (shared);
+
+ compressed_lua_ranking_script:CompressedBuffer;
+
+ // If true, suppresses smart replies if other smart actions are suggested.
+ suppress_smart_replies_with_actions:bool = false;
+
+ // If true, keep actions from the same entities together for ranking.
+ group_by_annotations:bool = true;
+}
+
+// Entity data to set from capturing groups.
+namespace libtextclassifier3.RulesModel_.RuleActionSpec_;
+table RuleCapturingGroup {
+  // The id of the group.
+ group_id:int;
+
+ // If set, the text of the capturing group will be used to set a field
+ // in the action entity data.
+ entity_field:FlatbufferFieldPath;
+
+ // If set, the capturing group will be used to create a text annotation
+ // with the given name and type.
+ annotation_type:string (shared);
+
+ annotation_name:string (shared);
+
+ // If set, the capturing group text will be used to create a text
+ // reply.
+ text_reply:ActionSuggestionSpec;
+
+ // If set, normalization to apply to the capturing group text.
+ normalization_options:NormalizationOptions;
+
+ // If set to true, an existing annotator annotation will be used to
+ // create the actions suggestions text annotation.
+ use_annotation_match:bool;
+
+ // If set, merge in fixed entity data for a match.
+ entity_data:ActionsEntityData;
+}
+
+// The actions to produce upon triggering.
+namespace libtextclassifier3.RulesModel_;
+table RuleActionSpec {
+ // The action.
+ action:ActionSuggestionSpec;
+
+ capturing_group:[RuleActionSpec_.RuleCapturingGroup];
+}
+
+// List of regular expression matchers.
+namespace libtextclassifier3.RulesModel_;
+table RegexRule {
+ // The regular expression pattern.
+ pattern:string (shared);
+
+ compressed_pattern:CompressedBuffer;
+ actions:[RuleActionSpec];
+
+ // Patterns for post-checking the outputs.
+ output_pattern:string (shared);
+
+ compressed_output_pattern:CompressedBuffer;
+}
+
+// Action configuration.
+// Specifies an action rules match.
+namespace libtextclassifier3.RulesModel_.GrammarRules_;
+table RuleMatch {
+ // The actions to produce as part of this match.
+ // These are indices into the `actions` array below.
+ action_id:[uint];
+}
+
+// Configuration for actions based on context-free grammars.
+namespace libtextclassifier3.RulesModel_;
+table GrammarRules {
+ // The tokenizer config.
+ tokenizer_options:ActionsTokenizerOptions;
+
+ // The grammar.
+ rules:grammar.RulesSet;
+
+ rule_match:[GrammarRules_.RuleMatch];
+
+ // The action specifications used by the rule matches.
+ actions:[RuleActionSpec];
+}
+
+// Rule based actions.
+namespace libtextclassifier3;
+table RulesModel {
+ regex_rule:[RulesModel_.RegexRule];
+
+ // If true, will compile the regexes only on first use.
+ lazy_regex_compilation:bool = true;
+
+ grammar_rules:RulesModel_.GrammarRules;
+}
+
+namespace libtextclassifier3;
+table ActionsModel {
+ // Comma-separated list of locales supported by the model as BCP 47 tags.
+ locales:string (shared);
+
+ // Version of the actions model.
+ version:int;
+
+ // A name for the model that can be used e.g. for logging.
+ name:string (shared);
+
+ tflite_model_spec:TensorflowLiteModelSpec;
+
+ // Output classes.
+ smart_reply_action_type:string (shared);
+
+ action_type:[ActionTypeOptions];
+
+ // Triggering conditions of the model.
+ preconditions:TriggeringPreconditions;
+
+ // Default number of smart reply predictions.
+ num_smart_replies:int = 3;
+
+ // Length of message history to consider, -1 if unbounded.
+ max_conversation_history_length:int = 1;
+
+ // Configuration for mapping annotations to action suggestions.
+ annotation_actions_spec:AnnotationActionsSpec;
+
+ // Configuration for rules.
+ rules:RulesModel;
+
+ // Configuration for intent generation on Android.
+ android_intent_options:IntentFactoryModel;
+
+ // Model resources.
+ resources:ResourcePool;
+
+ // Schema data for handling entity data.
+ actions_entity_data_schema:[ubyte];
+
+ // Action ranking options.
+ ranking_options:RankingOptions;
+
+ // Lua based actions.
+ lua_actions_script:string (shared);
+
+ compressed_lua_actions_script:CompressedBuffer;
+
+ // Low confidence classifiers.
+ low_confidence_rules:RulesModel;
+
+ low_confidence_ngram_model:NGramLinearRegressionModel;
+
+ // Feature processor options.
+ feature_processor_options:ActionsTokenFeatureProcessorOptions;
+}
+
+root_type libtextclassifier3.ActionsModel;
diff --git a/native/actions/feature-processor.cc b/native/actions/feature-processor.cc
new file mode 100644
index 0000000..249a132
--- /dev/null
+++ b/native/actions/feature-processor.cc
@@ -0,0 +1,132 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/feature-processor.h"
+
+namespace libtextclassifier3 {
+namespace {
+TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
+ const ActionsTokenFeatureProcessorOptions* const options) {
+ TokenFeatureExtractorOptions extractor_options;
+ extractor_options.num_buckets = options->num_buckets();
+ if (options->chargram_orders() != nullptr) {
+ for (int order : *options->chargram_orders()) {
+ extractor_options.chargram_orders.push_back(order);
+ }
+ }
+ extractor_options.max_word_length = options->max_token_length();
+ extractor_options.extract_case_feature = options->extract_case_feature();
+ extractor_options.unicode_aware_features = options->unicode_aware_features();
+ extractor_options.extract_selection_mask_feature = false;
+ if (options->regexp_features() != nullptr) {
+ for (const auto regexp_feature : *options->regexp_features()) {
+ extractor_options.regexp_features.push_back(regexp_feature->str());
+ }
+ }
+ extractor_options.remap_digits = options->remap_digits();
+ extractor_options.lowercase_tokens = options->lowercase_tokens();
+ return extractor_options;
+}
+} // namespace
+
+std::unique_ptr<Tokenizer> CreateTokenizer(
+ const ActionsTokenizerOptions* options, const UniLib* unilib) {
+ std::vector<const TokenizationCodepointRange*> codepoint_config;
+ if (options->tokenization_codepoint_config() != nullptr) {
+ codepoint_config.insert(codepoint_config.end(),
+ options->tokenization_codepoint_config()->begin(),
+ options->tokenization_codepoint_config()->end());
+ }
+ std::vector<const CodepointRange*> internal_codepoint_config;
+ if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
+ internal_codepoint_config.insert(
+ internal_codepoint_config.end(),
+ options->internal_tokenizer_codepoint_ranges()->begin(),
+ options->internal_tokenizer_codepoint_ranges()->end());
+ }
+ const bool tokenize_on_script_change =
+ options->tokenization_codepoint_config() != nullptr &&
+ options->tokenize_on_script_change();
+ return std::unique_ptr<Tokenizer>(new Tokenizer(
+ options->type(), unilib, codepoint_config, internal_codepoint_config,
+ tokenize_on_script_change, options->icu_preserve_whitespace_tokens()));
+}
+
+ActionsFeatureProcessor::ActionsFeatureProcessor(
+ const ActionsTokenFeatureProcessorOptions* options, const UniLib* unilib)
+ : options_(options),
+ tokenizer_(CreateTokenizer(options->tokenizer_options(), unilib)),
+ token_feature_extractor_(BuildTokenFeatureExtractorOptions(options),
+ unilib) {}
+
+int ActionsFeatureProcessor::GetTokenEmbeddingSize() const {
+ return options_->embedding_size() +
+ token_feature_extractor_.DenseFeaturesCount();
+}
+
+bool ActionsFeatureProcessor::AppendFeatures(
+ const std::vector<int>& sparse_features,
+ const std::vector<float>& dense_features,
+ const EmbeddingExecutor* embedding_executor,
+ std::vector<float>* output_features) const {
+ // Embed the sparse features, appending them directly to the output.
+ const int embedding_size = options_->embedding_size();
+ output_features->resize(output_features->size() + embedding_size);
+ float* output_features_end =
+ output_features->data() + output_features->size();
+ if (!embedding_executor->AddEmbedding(
+ TensorView<int>(sparse_features.data(),
+ {static_cast<int>(sparse_features.size())}),
+ /*dest=*/output_features_end - embedding_size,
+ /*dest_size=*/embedding_size)) {
+ TC3_LOG(ERROR) << "Could not embed token's sparse features.";
+ return false;
+ }
+
+ // Append the dense features to the output.
+ output_features->insert(output_features->end(), dense_features.begin(),
+ dense_features.end());
+ return true;
+}
+
+bool ActionsFeatureProcessor::AppendTokenFeatures(
+ const Token& token, const EmbeddingExecutor* embedding_executor,
+ std::vector<float>* output_features) const {
+ // Extract the sparse and dense features.
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ if (!token_feature_extractor_.Extract(token, /*(unused) is_in_span=*/false,
+ &sparse_features, &dense_features)) {
+ TC3_LOG(ERROR) << "Could not extract token's features.";
+ return false;
+ }
+ return AppendFeatures(sparse_features, dense_features, embedding_executor,
+ output_features);
+}
+
+bool ActionsFeatureProcessor::AppendTokenFeatures(
+ const std::vector<Token>& tokens,
+ const EmbeddingExecutor* embedding_executor,
+ std::vector<float>* output_features) const {
+ for (const Token& token : tokens) {
+ if (!AppendTokenFeatures(token, embedding_executor, output_features)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/actions/feature-processor.h b/native/actions/feature-processor.h
new file mode 100644
index 0000000..5e4085a
--- /dev/null
+++ b/native/actions/feature-processor.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_FEATURE_PROCESSOR_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_FEATURE_PROCESSOR_H_
+
+#include <memory>
+
+#include "actions/actions_model_generated.h"
+#include "annotator/model-executor.h"
+#include "annotator/types.h"
+#include "utils/token-feature-extractor.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// Create tokenizer from options.
+std::unique_ptr<Tokenizer> CreateTokenizer(
+ const ActionsTokenizerOptions* options, const UniLib* unilib);
+
+// Feature processor for the actions suggestions model.
+class ActionsFeatureProcessor {
+ public:
+ explicit ActionsFeatureProcessor(
+ const ActionsTokenFeatureProcessorOptions* options, const UniLib* unilib);
+
+ // Embeds and appends features to the output vector.
+ bool AppendFeatures(const std::vector<int>& sparse_features,
+ const std::vector<float>& dense_features,
+ const EmbeddingExecutor* embedding_executor,
+ std::vector<float>* output_features) const;
+
+ // Extracts the features of a token and appends them to the output vector.
+ bool AppendTokenFeatures(const Token& token,
+ const EmbeddingExecutor* embedding_executor,
+ std::vector<float>* output_features) const;
+
+ // Extracts the features of a vector of tokens and appends each to the output
+ // vector.
+ bool AppendTokenFeatures(const std::vector<Token>& tokens,
+ const EmbeddingExecutor* embedding_executor,
+ std::vector<float>* output_features) const;
+
+ int GetTokenEmbeddingSize() const;
+
+ const Tokenizer* tokenizer() const { return tokenizer_.get(); }
+
+ private:
+ const ActionsTokenFeatureProcessorOptions* options_;
+ const std::unique_ptr<Tokenizer> tokenizer_;
+ const TokenFeatureExtractor token_feature_extractor_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_FEATURE_PROCESSOR_H_
diff --git a/native/actions/feature-processor_test.cc b/native/actions/feature-processor_test.cc
new file mode 100644
index 0000000..969bbf7
--- /dev/null
+++ b/native/actions/feature-processor_test.cc
@@ -0,0 +1,130 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/feature-processor.h"
+
+#include "actions/actions_model_generated.h"
+#include "annotator/model-executor.h"
+#include "utils/tensor-view.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using ::testing::FloatEq;
+using ::testing::SizeIs;
+
+// EmbeddingExecutor that always returns features based on
+// the id of the sparse features.
+class FakeEmbeddingExecutor : public EmbeddingExecutor {
+ public:
+ bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+ const int dest_size) const override {
+ TC3_CHECK_GE(dest_size, 4);
+ EXPECT_THAT(sparse_features, SizeIs(1));
+ dest[0] = sparse_features.data()[0];
+ dest[1] = sparse_features.data()[0];
+ dest[2] = -sparse_features.data()[0];
+ dest[3] = -sparse_features.data()[0];
+ return true;
+ }
+
+ private:
+ std::vector<float> storage_;
+};
+
+class FeatureProcessorTest : public ::testing::Test {
+ protected:
+ FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+
+ flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
+ ActionsTokenFeatureProcessorOptionsT* options) const {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(CreateActionsTokenFeatureProcessorOptions(builder, options));
+ return builder.Release();
+ }
+
+ FakeEmbeddingExecutor embedding_executor_;
+ UniLib unilib_;
+};
+
+TEST_F(FeatureProcessorTest, TokenEmbeddings) {
+ ActionsTokenFeatureProcessorOptionsT options;
+ options.embedding_size = 4;
+ options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
+
+ flatbuffers::DetachedBuffer options_fb =
+ PackFeatureProcessorOptions(&options);
+ ActionsFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
+ options_fb.data()),
+ &unilib_);
+
+ Token token("aaa", 0, 3);
+ std::vector<float> token_features;
+ EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
+ &token_features));
+ EXPECT_THAT(token_features, SizeIs(4));
+}
+
+TEST_F(FeatureProcessorTest, TokenEmbeddingsCaseFeature) {
+ ActionsTokenFeatureProcessorOptionsT options;
+ options.embedding_size = 4;
+ options.extract_case_feature = true;
+ options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
+
+ flatbuffers::DetachedBuffer options_fb =
+ PackFeatureProcessorOptions(&options);
+ ActionsFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
+ options_fb.data()),
+ &unilib_);
+
+ Token token("Aaa", 0, 3);
+ std::vector<float> token_features;
+ EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
+ &token_features));
+ EXPECT_THAT(token_features, SizeIs(5));
+ EXPECT_THAT(token_features[4], FloatEq(1.0));
+}
+
+TEST_F(FeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
+ ActionsTokenFeatureProcessorOptionsT options;
+ options.embedding_size = 4;
+ options.extract_case_feature = true;
+ options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
+
+ flatbuffers::DetachedBuffer options_fb =
+ PackFeatureProcessorOptions(&options);
+ ActionsFeatureProcessor feature_processor(
+ flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
+ options_fb.data()),
+ &unilib_);
+
+ const std::vector<Token> tokens = {Token("Aaa", 0, 3), Token("bbb", 4, 7),
+ Token("Cccc", 8, 12)};
+ std::vector<float> token_features;
+ EXPECT_TRUE(feature_processor.AppendTokenFeatures(
+ tokens, &embedding_executor_, &token_features));
+ EXPECT_THAT(token_features, SizeIs(15));
+ EXPECT_THAT(token_features[4], FloatEq(1.0));
+ EXPECT_THAT(token_features[9], FloatEq(-1.0));
+ EXPECT_THAT(token_features[14], FloatEq(1.0));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/actions/grammar-actions.cc b/native/actions/grammar-actions.cc
new file mode 100644
index 0000000..7f3e71f
--- /dev/null
+++ b/native/actions/grammar-actions.cc
@@ -0,0 +1,254 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/grammar-actions.h"
+
+#include <algorithm>
+#include <unordered_map>
+
+#include "actions/feature-processor.h"
+#include "actions/utils.h"
+#include "annotator/types.h"
+#include "utils/grammar/callback-delegate.h"
+#include "utils/grammar/match.h"
+#include "utils/grammar/matcher.h"
+#include "utils/grammar/rules-utils.h"
+#include "utils/i18n/language-tag_generated.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+class GrammarActionsCallbackDelegate : public grammar::CallbackDelegate {
+ public:
+ GrammarActionsCallbackDelegate(const UniLib* unilib,
+ const RulesModel_::GrammarRules* grammar_rules)
+ : unilib_(*unilib), grammar_rules_(grammar_rules) {}
+
+ // Handle a grammar rule match in the actions grammar.
+ void MatchFound(const grammar::Match* match, grammar::CallbackId type,
+ int64 value, grammar::Matcher* matcher) override {
+ switch (static_cast<GrammarActions::Callback>(type)) {
+ case GrammarActions::Callback::kActionRuleMatch: {
+ HandleRuleMatch(match, /*rule_id=*/value);
+ return;
+ }
+ default:
+ grammar::CallbackDelegate::MatchFound(match, type, value, matcher);
+ }
+ }
+
+ // Deduplicate, verify and populate actions from grammar matches.
+ bool GetActions(const Conversation& conversation,
+ const std::string& smart_reply_action_type,
+ const ReflectiveFlatbufferBuilder* entity_data_builder,
+ std::vector<ActionSuggestion>* action_suggestions) const {
+ std::vector<UnicodeText::const_iterator> codepoint_offsets;
+ const UnicodeText message_unicode =
+ UTF8ToUnicodeText(conversation.messages.back().text,
+ /*do_copy=*/false);
+ for (auto it = message_unicode.begin(); it != message_unicode.end(); it++) {
+ codepoint_offsets.push_back(it);
+ }
+ codepoint_offsets.push_back(message_unicode.end());
+ for (const grammar::Derivation& candidate :
+ grammar::DeduplicateDerivations(candidates_)) {
+ // Check that assertions are fulfilled.
+ if (!VerifyAssertions(candidate.match)) {
+ continue;
+ }
+ if (!InstantiateActionsFromMatch(
+ codepoint_offsets,
+ /*message_index=*/conversation.messages.size() - 1,
+ smart_reply_action_type, candidate, entity_data_builder,
+ action_suggestions)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private:
+ // Handles action rule matches.
+ void HandleRuleMatch(const grammar::Match* match, const int64 rule_id) {
+ candidates_.push_back(grammar::Derivation{match, rule_id});
+ }
+
+ // Instantiates action suggestions from verified and deduplicated rule matches
+ // and appends them to the result.
+ // Expects the message as codepoints for text extraction from capturing
+ // matches as well as the index of the message, for correct span production.
+ bool InstantiateActionsFromMatch(
+ const std::vector<UnicodeText::const_iterator>& message_codepoint_offsets,
+ int message_index, const std::string& smart_reply_action_type,
+ const grammar::Derivation& candidate,
+ const ReflectiveFlatbufferBuilder* entity_data_builder,
+ std::vector<ActionSuggestion>* result) const {
+ const RulesModel_::GrammarRules_::RuleMatch* rule_match =
+ grammar_rules_->rule_match()->Get(candidate.rule_id);
+ if (rule_match == nullptr || rule_match->action_id() == nullptr) {
+ TC3_LOG(ERROR) << "No rule action defined.";
+ return false;
+ }
+
+ // Gather active capturing matches.
+ std::unordered_map<uint16, const grammar::Match*> capturing_matches;
+ for (const grammar::MappingMatch* match :
+ grammar::SelectAllOfType<grammar::MappingMatch>(
+ candidate.match, grammar::Match::kMappingMatch)) {
+ capturing_matches[match->id] = match;
+ }
+
+ // Instantiate actions from the rule match.
+ for (const uint16 action_id : *rule_match->action_id()) {
+ const RulesModel_::RuleActionSpec* action_spec =
+ grammar_rules_->actions()->Get(action_id);
+ std::vector<ActionSuggestionAnnotation> annotations;
+
+ std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ entity_data_builder != nullptr ? entity_data_builder->NewRoot()
+ : nullptr;
+
+ // Set information from capturing matches.
+ if (action_spec->capturing_group() != nullptr) {
+ for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
+ *action_spec->capturing_group()) {
+ auto it = capturing_matches.find(group->group_id());
+ if (it == capturing_matches.end()) {
+ // Capturing match is not active, skip.
+ continue;
+ }
+
+ const grammar::Match* capturing_match = it->second;
+ StringPiece match_text = StringPiece(
+ message_codepoint_offsets[capturing_match->codepoint_span.first]
+ .utf8_data(),
+ message_codepoint_offsets[capturing_match->codepoint_span.second]
+ .utf8_data() -
+ message_codepoint_offsets[capturing_match->codepoint_span
+ .first]
+ .utf8_data());
+ UnicodeText normalized_match_text =
+ NormalizeMatchText(unilib_, group, match_text);
+
+ if (!MergeEntityDataFromCapturingMatch(
+ group, normalized_match_text.ToUTF8String(),
+ entity_data.get())) {
+ TC3_LOG(ERROR)
+ << "Could not merge entity data from a capturing match.";
+ return false;
+ }
+
+ // Add smart reply suggestions.
+ SuggestTextRepliesFromCapturingMatch(entity_data_builder, group,
+ normalized_match_text,
+ smart_reply_action_type, result);
+
+ // Add annotation.
+ ActionSuggestionAnnotation annotation;
+ if (FillAnnotationFromCapturingMatch(
+ /*span=*/capturing_match->codepoint_span, group,
+ /*message_index=*/message_index, match_text, &annotation)) {
+ if (group->use_annotation_match()) {
+ const grammar::AnnotationMatch* annotation_match =
+ grammar::SelectFirstOfType<grammar::AnnotationMatch>(
+ capturing_match, grammar::Match::kAnnotationMatch);
+ if (!annotation_match) {
+ TC3_LOG(ERROR) << "Could not get annotation for match.";
+ return false;
+ }
+ annotation.entity = *annotation_match->annotation;
+ }
+ annotations.push_back(std::move(annotation));
+ }
+ }
+ }
+
+ if (action_spec->action() != nullptr) {
+ ActionSuggestion suggestion;
+ suggestion.annotations = annotations;
+ FillSuggestionFromSpec(action_spec->action(), entity_data.get(),
+ &suggestion);
+ result->push_back(std::move(suggestion));
+ }
+ }
+ return true;
+ }
+
+ const UniLib& unilib_;
+ const RulesModel_::GrammarRules* grammar_rules_;
+
+ // All action rule match candidates.
+ // Grammar rule matches are recorded, deduplicated, verified and then
+ // instantiated.
+ std::vector<grammar::Derivation> candidates_;
+};
+} // namespace
+
+GrammarActions::GrammarActions(
+ const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
+ const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const std::string& smart_reply_action_type)
+ : unilib_(*unilib),
+ grammar_rules_(grammar_rules),
+ tokenizer_(CreateTokenizer(grammar_rules->tokenizer_options(), unilib)),
+ lexer_(unilib, grammar_rules->rules()),
+ entity_data_builder_(entity_data_builder),
+ smart_reply_action_type_(smart_reply_action_type),
+ rules_locales_(ParseRulesLocales(grammar_rules->rules())) {}
+
+bool GrammarActions::SuggestActions(
+ const Conversation& conversation,
+ std::vector<ActionSuggestion>* result) const {
+ if (grammar_rules_->rules()->rules() == nullptr) {
+ // Nothing to do.
+ return true;
+ }
+
+ std::vector<Locale> locales;
+ if (!ParseLocales(conversation.messages.back().detected_text_language_tags,
+ &locales)) {
+ TC3_LOG(ERROR) << "Could not parse locales of input text.";
+ return false;
+ }
+
+ // Select locale matching rules.
+ std::vector<const grammar::RulesSet_::Rules*> locale_rules =
+ SelectLocaleMatchingShards(grammar_rules_->rules(), rules_locales_,
+ locales);
+ if (locale_rules.empty()) {
+ // Nothing to do.
+ return true;
+ }
+
+ GrammarActionsCallbackDelegate callback_handler(&unilib_, grammar_rules_);
+ grammar::Matcher matcher(&unilib_, grammar_rules_->rules(), locale_rules,
+ &callback_handler);
+
+ const UnicodeText text =
+ UTF8ToUnicodeText(conversation.messages.back().text, /*do_copy=*/false);
+
+ // Run grammar on last message.
+ lexer_.Process(text, tokenizer_->Tokenize(text),
+ /*annotations=*/&conversation.messages.back().annotations,
+ &matcher);
+
+ // Populate results.
+ return callback_handler.GetActions(conversation, smart_reply_action_type_,
+ entity_data_builder_, result);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/actions/grammar-actions.h b/native/actions/grammar-actions.h
new file mode 100644
index 0000000..fc3270d
--- /dev/null
+++ b/native/actions/grammar-actions.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_GRAMMAR_ACTIONS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_GRAMMAR_ACTIONS_H_
+
+#include <memory>
+#include <vector>
+
+#include "actions/actions_model_generated.h"
+#include "actions/types.h"
+#include "utils/flatbuffers.h"
+#include "utils/grammar/lexer.h"
+#include "utils/grammar/types.h"
+#include "utils/i18n/locale.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// Grammar backed actions suggestions.
+class GrammarActions {
+ public:
+ enum class Callback : grammar::CallbackId { kActionRuleMatch = 1 };
+
+ explicit GrammarActions(
+ const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
+ const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const std::string& smart_reply_action_type);
+
+ // Suggests actions for a conversation from a message stream.
+ bool SuggestActions(const Conversation& conversation,
+ std::vector<ActionSuggestion>* result) const;
+
+ private:
+ const UniLib& unilib_;
+ const RulesModel_::GrammarRules* grammar_rules_;
+ const std::unique_ptr<Tokenizer> tokenizer_;
+ const grammar::Lexer lexer_;
+ const ReflectiveFlatbufferBuilder* entity_data_builder_;
+ const std::string smart_reply_action_type_;
+
+ // Pre-parsed locales of the rules.
+ const std::vector<std::vector<Locale>> rules_locales_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_GRAMMAR_ACTIONS_H_
diff --git a/native/actions/lua-actions.cc b/native/actions/lua-actions.cc
new file mode 100644
index 0000000..7cf871a
--- /dev/null
+++ b/native/actions/lua-actions.cc
@@ -0,0 +1,168 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/lua-actions.h"
+
+#include "utils/base/logging.h"
+#include "utils/lua-utils.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lualib.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+namespace {
+
+TensorView<float> GetTensorViewForOutput(
+ const TfLiteModelExecutor* model_executor,
+ const tflite::Interpreter* interpreter, int output) {
+ if (output < 0 || model_executor == nullptr || interpreter == nullptr) {
+ return TensorView<float>::Invalid();
+ }
+ return model_executor->OutputView<float>(output, interpreter);
+}
+
+std::vector<std::string> GetStringTensorForOutput(
+ const TfLiteModelExecutor* model_executor,
+ const tflite::Interpreter* interpreter, int output) {
+ if (output < 0 || model_executor == nullptr || interpreter == nullptr) {
+ return {};
+ }
+ return model_executor->Output<std::string>(output, interpreter);
+}
+
+} // namespace
+
+std::unique_ptr<LuaActionsSuggestions>
+LuaActionsSuggestions::CreateLuaActionsSuggestions(
+ const std::string& snippet, const Conversation& conversation,
+ const TfLiteModelExecutor* model_executor,
+ const TensorflowLiteModelSpec* model_spec,
+ const tflite::Interpreter* interpreter,
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema) {
+ auto lua_actions =
+ std::unique_ptr<LuaActionsSuggestions>(new LuaActionsSuggestions(
+ snippet, conversation, model_executor, model_spec, interpreter,
+ actions_entity_data_schema, annotations_entity_data_schema));
+ if (!lua_actions->Initialize()) {
+ TC3_LOG(ERROR)
+ << "Could not initialize lua environment for actions suggestions.";
+ return nullptr;
+ }
+ return lua_actions;
+}
+
+LuaActionsSuggestions::LuaActionsSuggestions(
+ const std::string& snippet, const Conversation& conversation,
+ const TfLiteModelExecutor* model_executor,
+ const TensorflowLiteModelSpec* model_spec,
+ const tflite::Interpreter* interpreter,
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema)
+ : snippet_(snippet),
+ conversation_(conversation),
+ actions_scores_(
+ model_spec == nullptr
+ ? TensorView<float>::Invalid()
+ : GetTensorViewForOutput(model_executor, interpreter,
+ model_spec->output_actions_scores())),
+ smart_reply_scores_(
+ model_spec == nullptr
+ ? TensorView<float>::Invalid()
+ : GetTensorViewForOutput(model_executor, interpreter,
+ model_spec->output_replies_scores())),
+ sensitivity_score_(model_spec == nullptr
+ ? TensorView<float>::Invalid()
+ : GetTensorViewForOutput(
+ model_executor, interpreter,
+ model_spec->output_sensitive_topic_score())),
+ triggering_score_(
+ model_spec == nullptr
+ ? TensorView<float>::Invalid()
+ : GetTensorViewForOutput(model_executor, interpreter,
+ model_spec->output_triggering_score())),
+ smart_replies_(model_spec == nullptr ? std::vector<std::string>{}
+ : GetStringTensorForOutput(
+ model_executor, interpreter,
+ model_spec->output_replies())),
+ actions_entity_data_schema_(actions_entity_data_schema),
+ annotations_entity_data_schema_(annotations_entity_data_schema) {}
+
+bool LuaActionsSuggestions::Initialize() {
+ return RunProtected([this] {
+ LoadDefaultLibraries();
+
+ // Expose conversation message stream.
+ PushConversation(&conversation_.messages,
+ annotations_entity_data_schema_);
+ lua_setglobal(state_, "messages");
+
+ // Expose ML model output.
+ lua_newtable(state_);
+
+ PushTensor(&actions_scores_);
+ lua_setfield(state_, /*idx=*/-2, "actions_scores");
+
+ PushTensor(&smart_reply_scores_);
+ lua_setfield(state_, /*idx=*/-2, "reply_scores");
+
+ PushTensor(&sensitivity_score_);
+ lua_setfield(state_, /*idx=*/-2, "sensitivity");
+
+ PushTensor(&triggering_score_);
+ lua_setfield(state_, /*idx=*/-2, "triggering_score");
+
+ PushVectorIterator(&smart_replies_);
+ lua_setfield(state_, /*idx=*/-2, "reply");
+
+ lua_setglobal(state_, "model");
+
+ return LUA_OK;
+ }) == LUA_OK;
+}
+
+bool LuaActionsSuggestions::SuggestActions(
+ std::vector<ActionSuggestion>* actions) {
+ if (luaL_loadbuffer(state_, snippet_.data(), snippet_.size(),
+ /*name=*/nullptr) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not load actions suggestions snippet.";
+ return false;
+ }
+
+  if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not run actions suggestions snippet.";
+ return false;
+ }
+
+ if (RunProtected(
+ [this, actions] {
+ return ReadActions(actions_entity_data_schema_,
+ annotations_entity_data_schema_, actions);
+ },
+ /*num_args=*/1) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not read lua result.";
+ return false;
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/actions/lua-actions.h b/native/actions/lua-actions.h
new file mode 100644
index 0000000..b0c68b6
--- /dev/null
+++ b/native/actions/lua-actions.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_ACTIONS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_LUA_ACTIONS_H_
+
+#include "actions/actions_model_generated.h"
+#include "actions/types.h"
+#include "utils/lua-utils.h"
+#include "utils/tensor-view.h"
+#include "utils/tflite-model-executor.h"
+
+namespace libtextclassifier3 {
+
+// Lua backed actions suggestions.
+class LuaActionsSuggestions : public LuaEnvironment {
+ public:
+ static std::unique_ptr<LuaActionsSuggestions> CreateLuaActionsSuggestions(
+ const std::string& snippet, const Conversation& conversation,
+ const TfLiteModelExecutor* model_executor,
+ const TensorflowLiteModelSpec* model_spec,
+ const tflite::Interpreter* interpreter,
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema);
+
+ bool SuggestActions(std::vector<ActionSuggestion>* actions);
+
+ private:
+ LuaActionsSuggestions(
+ const std::string& snippet, const Conversation& conversation,
+ const TfLiteModelExecutor* model_executor,
+ const TensorflowLiteModelSpec* model_spec,
+ const tflite::Interpreter* interpreter,
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema);
+
+ bool Initialize();
+
+ template <typename T>
+ void PushTensor(const TensorView<T>* tensor) const {
+ PushIterator(tensor ? tensor->size() : 0,
+ [this, tensor](const int64 index) {
+ Push(tensor->data()[index]);
+ return 1; // Num. values pushed.
+ });
+ }
+
+ const std::string& snippet_;
+ const Conversation& conversation_;
+ TensorView<float> actions_scores_;
+ TensorView<float> smart_reply_scores_;
+ TensorView<float> sensitivity_score_;
+ TensorView<float> triggering_score_;
+ const std::vector<std::string> smart_replies_;
+ const reflection::Schema* actions_entity_data_schema_;
+ const reflection::Schema* annotations_entity_data_schema_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_LUA_ACTIONS_H_
diff --git a/native/actions/lua-actions_test.cc b/native/actions/lua-actions_test.cc
new file mode 100644
index 0000000..72cae2c
--- /dev/null
+++ b/native/actions/lua-actions_test.cc
@@ -0,0 +1,212 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/lua-actions.h"
+
+#include <map>
+#include <string>
+
+#include "actions/test-utils.h"
+#include "actions/types.h"
+#include "utils/tflite-model-executor.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAre;
+
+TEST(LuaActions, SimpleAction) {
+ Conversation conversation;
+ const std::string test_snippet = R"(
+ return {{ type = "test_action" }}
+ )";
+ std::vector<ActionSuggestion> actions;
+ EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
+ test_snippet, conversation,
+ /*model_executor=*/nullptr,
+ /*model_spec=*/nullptr,
+ /*interpreter=*/nullptr,
+ /*actions_entity_data_schema=*/nullptr,
+ /*annotations_entity_data_schema=*/nullptr)
+ ->SuggestActions(&actions));
+ EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
+}
+
+TEST(LuaActions, ConversationActions) {
+ Conversation conversation;
+ conversation.messages.push_back({/*user_id=*/0, "hello there!"});
+ conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
+ const std::string test_snippet = R"(
+ local actions = {}
+ for i, message in pairs(messages) do
+ if i < #messages then
+ if message.text == "hello there!" and
+ messages[i+1].text == "general kenobi!" then
+ table.insert(actions, {
+ type = "text_reply",
+ response_text = "you are a bold one!"
+ })
+ end
+ if message.text == "i am the senate!" and
+ messages[i+1].text == "not yet!" then
+ table.insert(actions, {
+ type = "text_reply",
+ response_text = "it's treason then"
+ })
+ end
+ end
+ end
+ return actions;
+ )";
+ std::vector<ActionSuggestion> actions;
+ EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
+ test_snippet, conversation,
+ /*model_executor=*/nullptr,
+ /*model_spec=*/nullptr,
+ /*interpreter=*/nullptr,
+ /*actions_entity_data_schema=*/nullptr,
+ /*annotations_entity_data_schema=*/nullptr)
+ ->SuggestActions(&actions));
+ EXPECT_THAT(actions, ElementsAre(IsSmartReply("you are a bold one!")));
+}
+
+TEST(LuaActions, SimpleModelAction) {
+ Conversation conversation;
+ const std::string test_snippet = R"(
+ if #model.actions_scores == 0 then
+ return {{ type = "test_action" }}
+ end
+ return {}
+ )";
+ std::vector<ActionSuggestion> actions;
+ EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
+ test_snippet, conversation,
+ /*model_executor=*/nullptr,
+ /*model_spec=*/nullptr,
+ /*interpreter=*/nullptr,
+ /*actions_entity_data_schema=*/nullptr,
+ /*annotations_entity_data_schema=*/nullptr)
+ ->SuggestActions(&actions));
+ EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
+}
+
+TEST(LuaActions, SimpleModelRepliesAction) {
+ Conversation conversation;
+ const std::string test_snippet = R"(
+ if #model.reply == 0 then
+ return {{ type = "test_action" }}
+ end
+ return {}
+ )";
+ std::vector<ActionSuggestion> actions;
+ EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
+ test_snippet, conversation,
+ /*model_executor=*/nullptr,
+ /*model_spec=*/nullptr,
+ /*interpreter=*/nullptr,
+ /*actions_entity_data_schema=*/nullptr,
+ /*annotations_entity_data_schema=*/nullptr)
+ ->SuggestActions(&actions));
+ EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
+}
+
+TEST(LuaActions, AnnotationActions) {
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {ClassificationResult("address", 1.0)};
+ Conversation conversation = {{{/*user_id=*/1, "are you at home?",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}};
+ const std::string test_snippet = R"(
+ local actions = {}
+ local last_message = messages[#messages]
+ for i, annotation in pairs(last_message.annotation) do
+ if #annotation.classification > 0 then
+ if annotation.classification[1].collection == "address" then
+ local text = string.sub(last_message.text,
+ annotation.span["begin"] + 1,
+ annotation.span["end"])
+ table.insert(actions, {
+ type = "text_reply",
+ response_text = "i am at " .. text,
+ annotation = {{
+ name = "location",
+ span = {
+ text = text
+ },
+ entity = annotation.classification[1]
+ }},
+ })
+ end
+ end
+ end
+ return actions;
+ )";
+ std::vector<ActionSuggestion> actions;
+ EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
+ test_snippet, conversation,
+ /*model_executor=*/nullptr,
+ /*model_spec=*/nullptr,
+ /*interpreter=*/nullptr,
+ /*actions_entity_data_schema=*/nullptr,
+ /*annotations_entity_data_schema=*/nullptr)
+ ->SuggestActions(&actions));
+ EXPECT_THAT(actions, ElementsAre(IsSmartReply("i am at home")));
+ EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
+}
+
+TEST(LuaActions, EntityData) {
+ std::string test_schema = TestEntityDataSchema();
+ Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
+ const std::string test_snippet = R"(
+ return {{
+ type = "test",
+ entity = {
+ greeting = "hello",
+ location = "there",
+ person = "Kenobi",
+ },
+ }};
+ )";
+ std::vector<ActionSuggestion> actions;
+ EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
+ test_snippet, conversation,
+ /*model_executor=*/nullptr,
+ /*model_spec=*/nullptr,
+ /*interpreter=*/nullptr,
+ /*actions_entity_data_schema=*/
+ flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
+ /*annotations_entity_data_schema=*/nullptr)
+ ->SuggestActions(&actions));
+ EXPECT_THAT(actions, testing::SizeIs(1));
+ EXPECT_EQ("test", actions.front().type);
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ actions.front().serialized_entity_data.data()));
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "hello");
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
+ "there");
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
+ "Kenobi");
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/actions/lua-ranker.cc b/native/actions/lua-ranker.cc
new file mode 100644
index 0000000..73032d0
--- /dev/null
+++ b/native/actions/lua-ranker.cc
@@ -0,0 +1,117 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/lua-ranker.h"
+
+#include "utils/base/logging.h"
+#include "utils/lua-utils.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lualib.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+
+std::unique_ptr<ActionsSuggestionsLuaRanker>
+ActionsSuggestionsLuaRanker::Create(
+ const Conversation& conversation, const std::string& ranker_code,
+ const reflection::Schema* entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema,
+ ActionsSuggestionsResponse* response) {
+ auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>(
+ new ActionsSuggestionsLuaRanker(
+ conversation, ranker_code, entity_data_schema,
+ annotations_entity_data_schema, response));
+ if (!ranker->Initialize()) {
+ TC3_LOG(ERROR) << "Could not initialize lua environment for ranker.";
+ return nullptr;
+ }
+ return ranker;
+}
+
+bool ActionsSuggestionsLuaRanker::Initialize() {
+ return RunProtected([this] {
+ LoadDefaultLibraries();
+
+ // Expose generated actions.
+ PushActions(&response_->actions, actions_entity_data_schema_,
+ annotations_entity_data_schema_);
+ lua_setglobal(state_, "actions");
+
+ // Expose conversation message stream.
+ PushConversation(&conversation_.messages,
+ annotations_entity_data_schema_);
+ lua_setglobal(state_, "messages");
+ return LUA_OK;
+ }) == LUA_OK;
+}
+
+int ActionsSuggestionsLuaRanker::ReadActionsRanking() {
+ if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected actions table, got: "
+ << lua_type(state_, /*idx=*/-1);
+ lua_pop(state_, 1);
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+ std::vector<ActionSuggestion> ranked_actions;
+ lua_pushnil(state_);
+ while (Next(/*index=*/-2)) {
+ const int action_id = Read<int>(/*index=*/-1) - 1;
+ lua_pop(state_, 1);
+ if (action_id < 0 || action_id >= response_->actions.size()) {
+ TC3_LOG(ERROR) << "Invalid action index: " << action_id;
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+ ranked_actions.push_back(response_->actions[action_id]);
+ }
+ lua_pop(state_, 1);
+ response_->actions = ranked_actions;
+ return LUA_OK;
+}
+
+bool ActionsSuggestionsLuaRanker::RankActions() {
+ if (response_->actions.empty()) {
+ // Nothing to do.
+ return true;
+ }
+
+ if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(),
+ /*name=*/nullptr) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not load compiled ranking snippet.";
+ return false;
+ }
+
+ if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not run ranking snippet.";
+ return false;
+ }
+
+ if (RunProtected([this] { return ReadActionsRanking(); }, /*num_args=*/1) !=
+ LUA_OK) {
+ TC3_LOG(ERROR) << "Could not read lua result.";
+ return false;
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/actions/lua-ranker.h b/native/actions/lua-ranker.h
new file mode 100644
index 0000000..040d303
--- /dev/null
+++ b/native/actions/lua-ranker.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_RANKER_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_LUA_RANKER_H_
+
+#include <memory>
+#include <string>
+
+#include "actions/types.h"
+#include "utils/lua-utils.h"
+
+namespace libtextclassifier3 {
+
+// Lua backed action suggestion ranking.
+class ActionsSuggestionsLuaRanker : public LuaEnvironment {
+ public:
+ static std::unique_ptr<ActionsSuggestionsLuaRanker> Create(
+ const Conversation& conversation, const std::string& ranker_code,
+ const reflection::Schema* entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema,
+ ActionsSuggestionsResponse* response);
+
+ bool RankActions();
+
+ private:
+ explicit ActionsSuggestionsLuaRanker(
+ const Conversation& conversation, const std::string& ranker_code,
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema,
+ ActionsSuggestionsResponse* response)
+ : conversation_(conversation),
+ ranker_code_(ranker_code),
+ actions_entity_data_schema_(actions_entity_data_schema),
+ annotations_entity_data_schema_(annotations_entity_data_schema),
+ response_(response) {}
+
+ bool Initialize();
+
+ // Reads ranking results from the lua stack.
+ int ReadActionsRanking();
+
+ const Conversation& conversation_;
+ const std::string& ranker_code_;
+ const reflection::Schema* actions_entity_data_schema_;
+ const reflection::Schema* annotations_entity_data_schema_;
+ ActionsSuggestionsResponse* response_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_LUA_RANKER_H_
diff --git a/actions/lua-ranker_test.cc b/native/actions/lua-ranker_test.cc
similarity index 100%
rename from actions/lua-ranker_test.cc
rename to native/actions/lua-ranker_test.cc
diff --git a/native/actions/ngram-model.cc b/native/actions/ngram-model.cc
new file mode 100644
index 0000000..fb3992c
--- /dev/null
+++ b/native/actions/ngram-model.cc
@@ -0,0 +1,225 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/ngram-model.h"
+
+#include <algorithm>
+
+#include "actions/feature-processor.h"
+#include "utils/hash/farmhash.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// An iterator to iterate over the initial tokens of the n-grams of a model.
+class FirstTokenIterator
+ : public std::iterator<std::random_access_iterator_tag,
+ /*value_type=*/uint32, /*difference_type=*/ptrdiff_t,
+ /*pointer=*/const uint32*,
+ /*reference=*/uint32&> {
+ public:
+ explicit FirstTokenIterator(const NGramLinearRegressionModel* model,
+ int index)
+ : model_(model), index_(index) {}
+
+ FirstTokenIterator& operator++() {
+ index_++;
+ return *this;
+ }
+ FirstTokenIterator& operator+=(ptrdiff_t dist) {
+ index_ += dist;
+ return *this;
+ }
+ ptrdiff_t operator-(const FirstTokenIterator& other_it) const {
+ return index_ - other_it.index_;
+ }
+ uint32 operator*() const {
+ const uint32 token_offset = (*model_->ngram_start_offsets())[index_];
+ return (*model_->hashed_ngram_tokens())[token_offset];
+ }
+ int index() const { return index_; }
+
+ private:
+ const NGramLinearRegressionModel* model_;
+ int index_;
+};
+
+} // anonymous namespace
+
+std::unique_ptr<NGramModel> NGramModel::Create(
+ const UniLib* unilib, const NGramLinearRegressionModel* model,
+ const Tokenizer* tokenizer) {
+ if (model == nullptr) {
+ return nullptr;
+ }
+ if (tokenizer == nullptr && model->tokenizer_options() == nullptr) {
+ TC3_LOG(ERROR) << "No tokenizer options specified.";
+ return nullptr;
+ }
+ return std::unique_ptr<NGramModel>(new NGramModel(unilib, model, tokenizer));
+}
+
+NGramModel::NGramModel(const UniLib* unilib,
+ const NGramLinearRegressionModel* model,
+ const Tokenizer* tokenizer)
+ : model_(model) {
+ // Create new tokenizer if options are specified, reuse feature processor
+ // tokenizer otherwise.
+ if (model->tokenizer_options() != nullptr) {
+ owned_tokenizer_ = CreateTokenizer(model->tokenizer_options(), unilib);
+ tokenizer_ = owned_tokenizer_.get();
+ } else {
+ tokenizer_ = tokenizer;
+ }
+}
+
+// Returns whether a given n-gram matches the token stream.
+bool NGramModel::IsNGramMatch(const uint32* tokens, size_t num_tokens,
+ const uint32* ngram_tokens,
+ size_t num_ngram_tokens, int max_skips) const {
+ int token_idx = 0, ngram_token_idx = 0, skip_remain = 0;
+ for (; token_idx < num_tokens && ngram_token_idx < num_ngram_tokens;) {
+ if (tokens[token_idx] == ngram_tokens[ngram_token_idx]) {
+ // Token matches. Advance both and reset the skip budget.
+ ++token_idx;
+ ++ngram_token_idx;
+ skip_remain = max_skips;
+ } else if (skip_remain > 0) {
+ // No match, but we have skips left, so just advance over the token.
+ ++token_idx;
+ skip_remain--;
+ } else {
+ // No match and we're out of skips. Reject.
+ return false;
+ }
+ }
+ return ngram_token_idx == num_ngram_tokens;
+}
+
+// Calculates the total number of skip-grams that can be created for a stream
+// with the given number of tokens.
+uint64 NGramModel::GetNumSkipGrams(int num_tokens, int max_ngram_length,
+ int max_skips) {
+ // Start with unigrams.
+ uint64 total = num_tokens;
+ for (int ngram_len = 2;
+ ngram_len <= max_ngram_length && ngram_len <= num_tokens; ++ngram_len) {
+ // We can easily compute the expected length of the n-gram (with skips),
+ // but it doesn't account for the fact that they may be longer than the
+ // input and should be pruned.
+ // Instead, we iterate over the distribution of effective n-gram lengths
+ // and add each length individually.
+ const int num_gaps = ngram_len - 1;
+ const int len_min = ngram_len;
+ const int len_max = ngram_len + num_gaps * max_skips;
+ const int len_mid = (len_max + len_min) / 2;
+ for (int len_i = len_min; len_i <= len_max; ++len_i) {
+ if (len_i > num_tokens) continue;
+ const int num_configs_of_len_i =
+ len_i <= len_mid ? len_i - len_min + 1 : len_max - len_i + 1;
+ const int num_start_offsets = num_tokens - len_i + 1;
+ total += num_configs_of_len_i * num_start_offsets;
+ }
+ }
+ return total;
+}
+
+std::pair<int, int> NGramModel::GetFirstTokenMatches(uint32 token_hash) const {
+ const int num_ngrams = model_->ngram_weights()->size();
+ const auto start_it = FirstTokenIterator(model_, 0);
+ const auto end_it = FirstTokenIterator(model_, num_ngrams);
+ const int start = std::lower_bound(start_it, end_it, token_hash).index();
+ const int end = std::upper_bound(start_it, end_it, token_hash).index();
+ return std::make_pair(start, end);
+}
+
+bool NGramModel::Eval(const UnicodeText& text, float* score) const {
+ const std::vector<Token> raw_tokens = tokenizer_->Tokenize(text);
+
+ // If we have no tokens, then just bail early.
+ if (raw_tokens.empty()) {
+ if (score != nullptr) {
+ *score = model_->default_token_weight();
+ }
+ return false;
+ }
+
+ // Hash the tokens.
+ std::vector<uint32> tokens;
+ tokens.reserve(raw_tokens.size());
+ for (const Token& raw_token : raw_tokens) {
+ tokens.push_back(tc3farmhash::Fingerprint32(raw_token.value.data(),
+ raw_token.value.length()));
+ }
+
+ // Calculate the total number of skip-grams that can be generated for the
+ // input text.
+ const uint64 num_candidates = GetNumSkipGrams(
+ tokens.size(), model_->max_denom_ngram_length(), model_->max_skips());
+
+ // For each token, see whether it denotes the start of an n-gram in the model.
+ int num_matches = 0;
+ float weight_matches = 0.f;
+ for (size_t start_i = 0; start_i < tokens.size(); ++start_i) {
+ const std::pair<int, int> ngram_range =
+ GetFirstTokenMatches(tokens[start_i]);
+ for (int ngram_idx = ngram_range.first; ngram_idx < ngram_range.second;
+ ++ngram_idx) {
+ const uint16 ngram_tokens_begin =
+ (*model_->ngram_start_offsets())[ngram_idx];
+ const uint16 ngram_tokens_end =
+ (*model_->ngram_start_offsets())[ngram_idx + 1];
+ if (IsNGramMatch(
+ /*tokens=*/tokens.data() + start_i,
+ /*num_tokens=*/tokens.size() - start_i,
+ /*ngram_tokens=*/model_->hashed_ngram_tokens()->data() +
+ ngram_tokens_begin,
+ /*num_ngram_tokens=*/ngram_tokens_end - ngram_tokens_begin,
+ /*max_skips=*/model_->max_skips())) {
+ ++num_matches;
+ weight_matches += (*model_->ngram_weights())[ngram_idx];
+ }
+ }
+ }
+
+ // Calculate the score.
+ const int num_misses = num_candidates - num_matches;
+ const float internal_score =
+ (weight_matches + (model_->default_token_weight() * num_misses)) /
+ num_candidates;
+ if (score != nullptr) {
+ *score = internal_score;
+ }
+ return internal_score > model_->threshold();
+}
+
+bool NGramModel::EvalConversation(const Conversation& conversation,
+ const int num_messages) const {
+ for (int i = 1; i <= num_messages; i++) {
+ const std::string& message =
+ conversation.messages[conversation.messages.size() - i].text;
+ const UnicodeText message_unicode(
+ UTF8ToUnicodeText(message, /*do_copy=*/false));
+ // Run ngram linear regression model.
+ if (Eval(message_unicode)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/actions/ngram-model.h b/native/actions/ngram-model.h
new file mode 100644
index 0000000..a9072cd
--- /dev/null
+++ b/native/actions/ngram-model.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_NGRAM_MODEL_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_NGRAM_MODEL_H_
+
+#include <memory>
+
+#include "actions/actions_model_generated.h"
+#include "actions/types.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+class NGramModel {
+ public:
+ static std::unique_ptr<NGramModel> Create(
+ const UniLib* unilib, const NGramLinearRegressionModel* model,
+ const Tokenizer* tokenizer);
+
+ // Evaluates an n-gram linear regression model, and tests against the
+ // threshold. Returns true in case of a positive classification. The caller
+ // may also optionally query the score.
+ bool Eval(const UnicodeText& text, float* score = nullptr) const;
+
+ // Evaluates an n-gram linear regression model against all messages in a
+ // conversation and returns true in case of any positive classification.
+ bool EvalConversation(const Conversation& conversation,
+ const int num_messages) const;
+
+ // Exposed for testing only.
+ static uint64 GetNumSkipGrams(int num_tokens, int max_ngram_length,
+ int max_skips);
+
+ private:
+ NGramModel(const UniLib* unilib, const NGramLinearRegressionModel* model,
+ const Tokenizer* tokenizer);
+
+ // Returns the (begin,end] range of n-grams where the first hashed token
+ // matches the given value.
+ std::pair<int, int> GetFirstTokenMatches(uint32 token_hash) const;
+
+ // Returns whether a given n-gram matches the token stream.
+ bool IsNGramMatch(const uint32* tokens, size_t num_tokens,
+ const uint32* ngram_tokens, size_t num_ngram_tokens,
+ int max_skips) const;
+
+ const NGramLinearRegressionModel* model_;
+ const Tokenizer* tokenizer_;
+ std::unique_ptr<Tokenizer> owned_tokenizer_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_NGRAM_MODEL_H_
diff --git a/actions/ranker.cc b/native/actions/ranker.cc
similarity index 100%
rename from actions/ranker.cc
rename to native/actions/ranker.cc
diff --git a/actions/ranker.h b/native/actions/ranker.h
similarity index 100%
rename from actions/ranker.h
rename to native/actions/ranker.h
diff --git a/actions/ranker_test.cc b/native/actions/ranker_test.cc
similarity index 100%
rename from actions/ranker_test.cc
rename to native/actions/ranker_test.cc
diff --git a/native/actions/regex-actions.cc b/native/actions/regex-actions.cc
new file mode 100644
index 0000000..7d5a4b2
--- /dev/null
+++ b/native/actions/regex-actions.cc
@@ -0,0 +1,262 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/regex-actions.h"
+
+#include "actions/utils.h"
+#include "utils/base/logging.h"
+#include "utils/regex-match.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/zlib/zlib_regex.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Creates an annotation from a regex capturing group.
+bool FillAnnotationFromMatchGroup(
+ const UniLib::RegexMatcher* matcher,
+ const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
+ const std::string& group_match_text, const int message_index,
+ ActionSuggestionAnnotation* annotation) {
+ if (group->annotation_name() != nullptr ||
+ group->annotation_type() != nullptr) {
+ int status = UniLib::RegexMatcher::kNoError;
+ const CodepointSpan span = {matcher->Start(group->group_id(), &status),
+ matcher->End(group->group_id(), &status)};
+ if (status != UniLib::RegexMatcher::kNoError) {
+ TC3_LOG(ERROR) << "Could not extract span from rule capturing group.";
+ return false;
+ }
+ return FillAnnotationFromCapturingMatch(span, group, message_index,
+ group_match_text, annotation);
+ }
+ return true;
+}
+
+} // namespace
+
+bool RegexActions::InitializeRules(
+ const RulesModel* rules, const RulesModel* low_confidence_rules,
+ const TriggeringPreconditions* triggering_preconditions_overlay,
+ ZlibDecompressor* decompressor) {
+ if (rules != nullptr) {
+ if (!InitializeRulesModel(rules, decompressor, &rules_)) {
+ TC3_LOG(ERROR) << "Could not initialize action rules.";
+ return false;
+ }
+ }
+
+ if (low_confidence_rules != nullptr) {
+ if (!InitializeRulesModel(low_confidence_rules, decompressor,
+ &low_confidence_rules_)) {
+ TC3_LOG(ERROR) << "Could not initialize low confidence rules.";
+ return false;
+ }
+ }
+
+ // Extend by rules provided by the overwrite.
+ // NOTE: The rules from the original models are *not* cleared.
+ if (triggering_preconditions_overlay != nullptr &&
+ triggering_preconditions_overlay->low_confidence_rules() != nullptr) {
+ // These rules are optionally compressed, but separately.
+ std::unique_ptr<ZlibDecompressor> overwrite_decompressor =
+ ZlibDecompressor::Instance();
+ if (overwrite_decompressor == nullptr) {
+ TC3_LOG(ERROR) << "Could not initialize decompressor for overwrite rules.";
+ return false;
+ }
+ if (!InitializeRulesModel(
+ triggering_preconditions_overlay->low_confidence_rules(),
+ overwrite_decompressor.get(), &low_confidence_rules_)) {
+ TC3_LOG(ERROR)
+ << "Could not initialize low confidence rules from overwrite.";
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool RegexActions::InitializeRulesModel(
+ const RulesModel* rules, ZlibDecompressor* decompressor,
+ std::vector<CompiledRule>* compiled_rules) const {
+ for (const RulesModel_::RegexRule* rule : *rules->regex_rule()) {
+ std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
+ UncompressMakeRegexPattern(
+ unilib_, rule->pattern(), rule->compressed_pattern(),
+ rules->lazy_regex_compilation(), decompressor);
+ if (compiled_pattern == nullptr) {
+ TC3_LOG(ERROR) << "Failed to load rule pattern.";
+ return false;
+ }
+
+ // Check whether the rule defines an additional check on the output.
+ std::unique_ptr<UniLib::RegexPattern> compiled_output_pattern;
+ if (rule->output_pattern() != nullptr ||
+ rule->compressed_output_pattern() != nullptr) {
+ compiled_output_pattern = UncompressMakeRegexPattern(
+ unilib_, rule->output_pattern(), rule->compressed_output_pattern(),
+ rules->lazy_regex_compilation(), decompressor);
+ if (compiled_output_pattern == nullptr) {
+ TC3_LOG(ERROR) << "Failed to load rule output pattern.";
+ return false;
+ }
+ }
+
+ compiled_rules->emplace_back(rule, std::move(compiled_pattern),
+ std::move(compiled_output_pattern));
+ }
+
+ return true;
+}
+
+bool RegexActions::IsLowConfidenceInput(
+ const Conversation& conversation, const int num_messages,
+ std::vector<const UniLib::RegexPattern*>* post_check_rules) const {
+ for (int i = 1; i <= num_messages; i++) {
+ const std::string& message =
+ conversation.messages[conversation.messages.size() - i].text;
+ const UnicodeText message_unicode(
+ UTF8ToUnicodeText(message, /*do_copy=*/false));
+ for (int low_confidence_rule = 0;
+ low_confidence_rule < low_confidence_rules_.size();
+ low_confidence_rule++) {
+ const CompiledRule& rule = low_confidence_rules_[low_confidence_rule];
+ const std::unique_ptr<UniLib::RegexMatcher> matcher =
+ rule.pattern->Matcher(message_unicode);
+ int status = UniLib::RegexMatcher::kNoError;
+ if (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+ // Rule only applies to input-output pairs, so defer the check.
+ if (rule.output_pattern != nullptr) {
+ post_check_rules->push_back(rule.output_pattern.get());
+ continue;
+ }
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+bool RegexActions::FilterConfidenceOutput(
+ const std::vector<const UniLib::RegexPattern*>& post_check_rules,
+ std::vector<ActionSuggestion>* actions) const {
+ if (post_check_rules.empty() || actions->empty()) {
+ return true;
+ }
+ std::vector<ActionSuggestion> filtered_text_replies;
+ for (const ActionSuggestion& action : *actions) {
+ if (action.response_text.empty()) {
+ filtered_text_replies.push_back(action);
+ continue;
+ }
+ bool passes_post_check = true;
+ const UnicodeText text_reply_unicode(
+ UTF8ToUnicodeText(action.response_text, /*do_copy=*/false));
+ for (const UniLib::RegexPattern* post_check_rule : post_check_rules) {
+ const std::unique_ptr<UniLib::RegexMatcher> matcher =
+ post_check_rule->Matcher(text_reply_unicode);
+ if (matcher == nullptr) {
+ TC3_LOG(ERROR) << "Could not create matcher for post check rule.";
+ return false;
+ }
+ int status = UniLib::RegexMatcher::kNoError;
+ if (matcher->Find(&status) || status != UniLib::RegexMatcher::kNoError) {
+ passes_post_check = false;
+ break;
+ }
+ }
+ if (passes_post_check) {
+ filtered_text_replies.push_back(action);
+ }
+ }
+ *actions = std::move(filtered_text_replies);
+ return true;
+}
+
+bool RegexActions::SuggestActions(
+ const Conversation& conversation,
+ const ReflectiveFlatbufferBuilder* entity_data_builder,
+ std::vector<ActionSuggestion>* actions) const {
+ // Create actions based on rules checking the last message.
+ const int message_index = conversation.messages.size() - 1;
+ const std::string& message = conversation.messages.back().text;
+ const UnicodeText message_unicode(
+ UTF8ToUnicodeText(message, /*do_copy=*/false));
+ for (const CompiledRule& rule : rules_) {
+ const std::unique_ptr<UniLib::RegexMatcher> matcher =
+ rule.pattern->Matcher(message_unicode);
+ int status = UniLib::RegexMatcher::kNoError;
+ while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+ for (const RulesModel_::RuleActionSpec* rule_action :
+ *rule.rule->actions()) {
+ const ActionSuggestionSpec* action = rule_action->action();
+ std::vector<ActionSuggestionAnnotation> annotations;
+
+ std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ entity_data_builder != nullptr ? entity_data_builder->NewRoot()
+ : nullptr;
+
+ // Add entity data from rule capturing groups.
+ if (rule_action->capturing_group() != nullptr) {
+ for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
+ *rule_action->capturing_group()) {
+ Optional<std::string> group_match_text =
+ GetCapturingGroupText(matcher.get(), group->group_id());
+ if (!group_match_text.has_value()) {
+ // The group was not part of the match, ignore and continue.
+ continue;
+ }
+
+ UnicodeText normalized_group_match_text =
+ NormalizeMatchText(unilib_, group, group_match_text.value());
+
+ if (!MergeEntityDataFromCapturingMatch(
+ group, normalized_group_match_text.ToUTF8String(),
+ entity_data.get())) {
+ TC3_LOG(ERROR)
+ << "Could not merge entity data from a capturing match.";
+ return false;
+ }
+
+ // Create a text annotation for the group span.
+ ActionSuggestionAnnotation annotation;
+ if (FillAnnotationFromMatchGroup(matcher.get(), group,
+ group_match_text.value(),
+ message_index, &annotation)) {
+ annotations.push_back(annotation);
+ }
+
+ // Create text reply.
+ SuggestTextRepliesFromCapturingMatch(
+ entity_data_builder, group, normalized_group_match_text,
+ smart_reply_action_type_, actions);
+ }
+ }
+
+ if (action != nullptr) {
+ ActionSuggestion suggestion;
+ suggestion.annotations = annotations;
+ FillSuggestionFromSpec(action, entity_data.get(), &suggestion);
+ actions->push_back(suggestion);
+ }
+ }
+ }
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/actions/regex-actions.h b/native/actions/regex-actions.h
new file mode 100644
index 0000000..871f08b
--- /dev/null
+++ b/native/actions/regex-actions.h
@@ -0,0 +1,86 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_REGEX_ACTIONS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_REGEX_ACTIONS_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "actions/actions_model_generated.h"
+#include "actions/types.h"
+#include "utils/flatbuffers.h"
+#include "utils/utf8/unilib.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Regular expression backed actions suggestions.
+class RegexActions {
+ public:
+ explicit RegexActions(const UniLib* unilib,
+ const std::string& smart_reply_action_type)
+ : unilib_(*unilib), smart_reply_action_type_(smart_reply_action_type) {}
+
+ // Decompresses and initializes all rules in a model.
+ bool InitializeRules(
+ const RulesModel* rules, const RulesModel* low_confidence_rules,
+ const TriggeringPreconditions* triggering_preconditions_overlay,
+ ZlibDecompressor* decompressor);
+
+ // Checks whether the input triggers the low confidence rules.
+ bool IsLowConfidenceInput(
+ const Conversation& conversation, const int num_messages,
+ std::vector<const UniLib::RegexPattern*>* post_check_rules) const;
+
+ // Checks and filters suggestions triggering the low confidence post checks.
+ bool FilterConfidenceOutput(
+ const std::vector<const UniLib::RegexPattern*>& post_check_rules,
+ std::vector<ActionSuggestion>* actions) const;
+
+ // Suggests actions for a conversation from a message stream using the regex
+ // rules.
+ bool SuggestActions(const Conversation& conversation,
+ const ReflectiveFlatbufferBuilder* entity_data_builder,
+ std::vector<ActionSuggestion>* actions) const;
+
+ private:
+ struct CompiledRule {
+ const RulesModel_::RegexRule* rule;
+ std::unique_ptr<UniLib::RegexPattern> pattern;
+ std::unique_ptr<UniLib::RegexPattern> output_pattern;
+ CompiledRule(const RulesModel_::RegexRule* rule,
+ std::unique_ptr<UniLib::RegexPattern> pattern,
+ std::unique_ptr<UniLib::RegexPattern> output_pattern)
+ : rule(rule),
+ pattern(std::move(pattern)),
+ output_pattern(std::move(output_pattern)) {}
+ };
+
+ // Decompresses and initializes a set of regular expression rules.
+ bool InitializeRulesModel(const RulesModel* rules,
+ ZlibDecompressor* decompressor,
+ std::vector<CompiledRule>* compiled_rules) const;
+
+ const UniLib& unilib_;
+ const std::string smart_reply_action_type_;
+ std::vector<CompiledRule> rules_, low_confidence_rules_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_REGEX_ACTIONS_H_
diff --git a/native/actions/test-utils.cc b/native/actions/test-utils.cc
new file mode 100644
index 0000000..9b003dd
--- /dev/null
+++ b/native/actions/test-utils.cc
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/test-utils.h"
+
+namespace libtextclassifier3 {
+
+std::string TestEntityDataSchema() {
+ // Create fake entity data schema meta data.
+ // Cannot use object oriented API here as that is not available for the
+ // reflection schema.
+ flatbuffers::FlatBufferBuilder schema_builder;
+ std::vector<flatbuffers::Offset<reflection::Field>> fields = {
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("greeting"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/0,
+ /*offset=*/4),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("location"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/1,
+ /*offset=*/6),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("person"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/2,
+ /*offset=*/8)};
+ std::vector<flatbuffers::Offset<reflection::Enum>> enums;
+ std::vector<flatbuffers::Offset<reflection::Object>> objects = {
+ reflection::CreateObject(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("EntityData"),
+ /*fields=*/
+ schema_builder.CreateVectorOfSortedTables(&fields))};
+ schema_builder.Finish(reflection::CreateSchema(
+ schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
+ schema_builder.CreateVectorOfSortedTables(&enums),
+ /*(unused) file_ident=*/0,
+ /*(unused) file_ext=*/0,
+ /*root_table*/ objects[0]));
+
+ return std::string(
+ reinterpret_cast<const char*>(schema_builder.GetBufferPointer()),
+ schema_builder.GetSize());
+}
+
+void SetTestEntityDataSchema(ActionsModelT* test_model) {
+ const std::string serialized_schema = TestEntityDataSchema();
+
+ test_model->actions_entity_data_schema.assign(
+ serialized_schema.data(),
+ serialized_schema.data() + serialized_schema.size());
+}
+
+} // namespace libtextclassifier3
diff --git a/native/actions/test-utils.h b/native/actions/test-utils.h
new file mode 100644
index 0000000..c05d6a9
--- /dev/null
+++ b/native/actions/test-utils.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_TEST_UTILS_H_
+
+#include <string>
+
+#include "actions/actions_model_generated.h"
+#include "utils/flatbuffers.h"
+#include "gmock/gmock.h"
+
+namespace libtextclassifier3 {
+
+using testing::ExplainMatchResult;
+using testing::Value;
+
+// Create test entity data schema.
+std::string TestEntityDataSchema();
+void SetTestEntityDataSchema(ActionsModelT* test_model);
+
+MATCHER_P(IsActionOfType, type, "") { return Value(arg.type, type); }
+MATCHER_P(IsSmartReply, response_text, "") {
+ return ExplainMatchResult(IsActionOfType("text_reply"), arg,
+ result_listener) &&
+ Value(arg.response_text, response_text);
+}
+MATCHER_P(IsSpan, span, "") {
+ return Value(arg.first, span.first) && Value(arg.second, span.second);
+}
+MATCHER_P3(IsActionSuggestionAnnotation, name, text, span, "") {
+ return Value(arg.name, name) && Value(arg.span.text, text) &&
+ ExplainMatchResult(IsSpan(span), arg.span.span, result_listener);
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_TEST_UTILS_H_
diff --git a/native/actions/types.h b/native/actions/types.h
new file mode 100644
index 0000000..e7d384f
--- /dev/null
+++ b/native/actions/types.h
@@ -0,0 +1,136 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_TYPES_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_TYPES_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "actions/actions-entity-data_generated.h"
+#include "annotator/types.h"
+#include "utils/flatbuffers.h"
+
+namespace libtextclassifier3 {
+
+// A text span in the conversation.
+struct MessageTextSpan {
+ explicit MessageTextSpan() = default;
+ MessageTextSpan(const int message_index, const CodepointSpan span,
+ const std::string& text)
+ : message_index(message_index), span(span), text(text) {}
+
+ // The referenced message.
+ // -1 if not referencing a particular message in the provided input.
+ int message_index = kInvalidIndex;
+
+ // The span within the reference message.
+ // (-1, -1) if not referencing a particular location.
+ CodepointSpan span = CodepointSpan{kInvalidIndex, kInvalidIndex};
+
+ // The span text.
+ std::string text;
+};
+
+// An entity associated with an action.
+struct ActionSuggestionAnnotation {
+ MessageTextSpan span;
+ ClassificationResult entity;
+
+ // Optional annotation name.
+ std::string name;
+};
+
+// Action suggestion that contains a response text and the type of the response.
+struct ActionSuggestion {
+ // Text of the action suggestion.
+ std::string response_text;
+
+ // Type of the action suggestion.
+ std::string type;
+
+ // Score.
+ float score = 0.f;
+
+ // Priority score for internal conflict resolution.
+ float priority_score = 0.f;
+
+ // The associated annotations.
+ std::vector<ActionSuggestionAnnotation> annotations;
+
+ // Extras information.
+ std::string serialized_entity_data;
+
+ const ActionsEntityData* entity_data() {
+ return LoadAndVerifyFlatbuffer<ActionsEntityData>(
+ serialized_entity_data.data(), serialized_entity_data.size());
+ }
+};
+
+// Actions suggestions result containing meta-information and the suggested
+// actions.
+struct ActionsSuggestionsResponse {
+ // The sensitivity assessment.
+ float sensitivity_score = -1.f;
+ float triggering_score = -1.f;
+
+ // Whether the output was suppressed by the sensitivity threshold.
+ bool output_filtered_sensitivity = false;
+
+ // Whether the output was suppressed by the triggering score threshold.
+ bool output_filtered_min_triggering_score = false;
+
+ // Whether the output was suppressed by the low confidence patterns.
+ bool output_filtered_low_confidence = false;
+
+ // Whether the output was suppressed due to locale mismatch.
+ bool output_filtered_locale_mismatch = false;
+
+ // The suggested actions.
+ std::vector<ActionSuggestion> actions;
+};
+
+// Represents a single message in the conversation.
+struct ConversationMessage {
+ // User ID distinguishing the user from other users in the conversation.
+ int user_id = 0;
+
+ // Text of the message.
+ std::string text;
+
+ // Reference time of this message.
+ int64 reference_time_ms_utc = 0;
+
+ // Timezone in which the input text was written (format as accepted by ICU).
+ std::string reference_timezone;
+
+ // Annotations on the text.
+ std::vector<AnnotatedSpan> annotations;
+
+ // Comma-separated list of BCP 47 language tags of the message.
+ std::string detected_text_language_tags;
+};
+
+// Conversation between multiple users.
+struct Conversation {
+ // Sequence of messages that were exchanged in the conversation.
+ std::vector<ConversationMessage> messages;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_TYPES_H_
diff --git a/native/actions/utils.cc b/native/actions/utils.cc
new file mode 100644
index 0000000..96f6f1f
--- /dev/null
+++ b/native/actions/utils.cc
@@ -0,0 +1,124 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/utils.h"
+
+#include "utils/base/logging.h"
+#include "utils/normalization.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+void FillSuggestionFromSpec(const ActionSuggestionSpec* action,
+ ReflectiveFlatbuffer* entity_data,
+ ActionSuggestion* suggestion) {
+ if (action != nullptr) {
+ suggestion->score = action->score();
+ suggestion->priority_score = action->priority_score();
+ if (action->type() != nullptr) {
+ suggestion->type = action->type()->str();
+ }
+ if (action->response_text() != nullptr) {
+ suggestion->response_text = action->response_text()->str();
+ }
+ if (action->serialized_entity_data() != nullptr) {
+ TC3_CHECK_NE(entity_data, nullptr);
+ entity_data->MergeFromSerializedFlatbuffer(
+ StringPiece(action->serialized_entity_data()->data(),
+ action->serialized_entity_data()->size()));
+ }
+ if (action->entity_data() != nullptr) {
+ TC3_CHECK_NE(entity_data, nullptr);
+ entity_data->MergeFrom(
+ reinterpret_cast<const flatbuffers::Table*>(action->entity_data()));
+ }
+ }
+ if (entity_data != nullptr && entity_data->HasExplicitlySetFields()) {
+ suggestion->serialized_entity_data = entity_data->Serialize();
+ }
+}
+
+void SuggestTextRepliesFromCapturingMatch(
+ const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
+ const UnicodeText& match_text, const std::string& smart_reply_action_type,
+ std::vector<ActionSuggestion>* actions) {
+ if (group->text_reply() != nullptr) {
+ ActionSuggestion suggestion;
+ suggestion.response_text = match_text.ToUTF8String();
+ suggestion.type = smart_reply_action_type;
+ std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ entity_data_builder != nullptr ? entity_data_builder->NewRoot()
+ : nullptr;
+ FillSuggestionFromSpec(group->text_reply(), entity_data.get(), &suggestion);
+ actions->push_back(suggestion);
+ }
+}
+
+UnicodeText NormalizeMatchText(
+ const UniLib& unilib,
+ const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
+ StringPiece match_text) {
+ UnicodeText normalized_match_text =
+ UTF8ToUnicodeText(match_text, /*do_copy=*/false);
+ if (group->normalization_options() != nullptr) {
+ normalized_match_text = NormalizeText(
+ unilib, group->normalization_options(), normalized_match_text);
+ }
+ return normalized_match_text;
+}
+
+bool FillAnnotationFromCapturingMatch(
+ const CodepointSpan& span,
+ const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
+ const int message_index, StringPiece match_text,
+ ActionSuggestionAnnotation* annotation) {
+ if (group->annotation_name() == nullptr &&
+ group->annotation_type() == nullptr) {
+ return false;
+ }
+ annotation->span.span = span;
+ annotation->span.message_index = message_index;
+ annotation->span.text = match_text.ToString();
+ if (group->annotation_name() != nullptr) {
+ annotation->name = group->annotation_name()->str();
+ }
+ if (group->annotation_type() != nullptr) {
+ annotation->entity.collection = group->annotation_type()->str();
+ }
+ return true;
+}
+
+bool MergeEntityDataFromCapturingMatch(
+ const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
+ StringPiece match_text, ReflectiveFlatbuffer* buffer) {
+ if (group->entity_field() != nullptr) {
+ if (!buffer->ParseAndSet(group->entity_field(), match_text.ToString())) {
+ TC3_LOG(ERROR) << "Could not set entity data from rule capturing group.";
+ return false;
+ }
+ }
+ if (group->entity_data() != nullptr) {
+ if (!buffer->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
+ group->entity_data()))) {
+ TC3_LOG(ERROR) << "Could not set entity data for capturing match.";
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/actions/utils.h b/native/actions/utils.h
new file mode 100644
index 0000000..820c79d
--- /dev/null
+++ b/native/actions/utils.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utils for creating action suggestions.
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_UTILS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_UTILS_H_
+
+#include <string>
+#include <vector>
+
+#include "actions/actions_model_generated.h"
+#include "actions/types.h"
+#include "annotator/types.h"
+#include "utils/flatbuffers.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// Fills an action suggestion from a template.
+void FillSuggestionFromSpec(const ActionSuggestionSpec* action,
+ ReflectiveFlatbuffer* entity_data,
+ ActionSuggestion* suggestion);
+
+// Creates text replies from capturing matches.
+void SuggestTextRepliesFromCapturingMatch(
+ const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
+ const UnicodeText& match_text, const std::string& smart_reply_action_type,
+ std::vector<ActionSuggestion>* actions);
+
+// Applies normalization to a capturing match.
+UnicodeText NormalizeMatchText(
+ const UniLib& unilib,
+ const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
+ StringPiece match_text);
+
+// Fills the fields in an annotation from a capturing match.
+bool FillAnnotationFromCapturingMatch(
+ const CodepointSpan& span,
+ const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
+ const int message_index, StringPiece match_text,
+ ActionSuggestionAnnotation* annotation);
+
+// Merges entity data from a capturing match.
+// Parses and sets values from the text and merges fixed data.
+bool MergeEntityDataFromCapturingMatch(
+ const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
+ StringPiece match_text, ReflectiveFlatbuffer* buffer);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_UTILS_H_
diff --git a/native/actions/zlib-utils.cc b/native/actions/zlib-utils.cc
new file mode 100644
index 0000000..c8ad4e7
--- /dev/null
+++ b/native/actions/zlib-utils.cc
@@ -0,0 +1,175 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/zlib-utils.h"
+
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/intents/zlib-utils.h"
+#include "utils/resources.h"
+
+namespace libtextclassifier3 {
+
+// Compress rule fields in the model.
+bool CompressActionsModel(ActionsModelT* model) {
+ std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance();
+ if (!zlib_compressor) {
+ TC3_LOG(ERROR) << "Cannot compress model.";
+ return false;
+ }
+
+ // Compress regex rules.
+ if (model->rules != nullptr) {
+ for (int i = 0; i < model->rules->regex_rule.size(); i++) {
+ RulesModel_::RegexRuleT* rule = model->rules->regex_rule[i].get();
+ rule->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(rule->pattern, rule->compressed_pattern.get());
+ rule->pattern.clear();
+ }
+ }
+
+ if (model->low_confidence_rules != nullptr) {
+ for (int i = 0; i < model->low_confidence_rules->regex_rule.size(); i++) {
+ RulesModel_::RegexRuleT* rule =
+ model->low_confidence_rules->regex_rule[i].get();
+ if (!rule->pattern.empty()) {
+ rule->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(rule->pattern,
+ rule->compressed_pattern.get());
+ rule->pattern.clear();
+ }
+ if (!rule->output_pattern.empty()) {
+ rule->compressed_output_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(rule->output_pattern,
+ rule->compressed_output_pattern.get());
+ rule->output_pattern.clear();
+ }
+ }
+ }
+
+ if (!model->lua_actions_script.empty()) {
+ model->compressed_lua_actions_script.reset(new CompressedBufferT);
+ zlib_compressor->Compress(model->lua_actions_script,
+ model->compressed_lua_actions_script.get());
+ }
+
+ if (model->ranking_options != nullptr &&
+ !model->ranking_options->lua_ranking_script.empty()) {
+ model->ranking_options->compressed_lua_ranking_script.reset(
+ new CompressedBufferT);
+ zlib_compressor->Compress(
+ model->ranking_options->lua_ranking_script,
+ model->ranking_options->compressed_lua_ranking_script.get());
+ }
+
+ // Compress resources.
+ if (model->resources != nullptr) {
+ CompressResources(model->resources.get());
+ }
+
+ // Compress intent generator.
+ if (model->android_intent_options != nullptr) {
+ CompressIntentModel(model->android_intent_options.get());
+ }
+
+ return true;
+}
+
+bool DecompressActionsModel(ActionsModelT* model) {
+ std::unique_ptr<ZlibDecompressor> zlib_decompressor =
+ ZlibDecompressor::Instance();
+ if (!zlib_decompressor) {
+ TC3_LOG(ERROR) << "Cannot initialize decompressor.";
+ return false;
+ }
+
+ // Decompress regex rules.
+ if (model->rules != nullptr) {
+ for (int i = 0; i < model->rules->regex_rule.size(); i++) {
+ RulesModel_::RegexRuleT* rule = model->rules->regex_rule[i].get();
+ if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(),
+ &rule->pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
+ return false;
+ }
+ rule->compressed_pattern.reset(nullptr);
+ }
+ }
+
+ // Decompress low confidence rules.
+ if (model->low_confidence_rules != nullptr) {
+ for (int i = 0; i < model->low_confidence_rules->regex_rule.size(); i++) {
+ RulesModel_::RegexRuleT* rule =
+ model->low_confidence_rules->regex_rule[i].get();
+ if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(),
+ &rule->pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
+ return false;
+ }
+ if (!zlib_decompressor->MaybeDecompress(
+ rule->compressed_output_pattern.get(), &rule->output_pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
+ return false;
+ }
+ rule->compressed_pattern.reset(nullptr);
+ rule->compressed_output_pattern.reset(nullptr);
+ }
+ }
+
+ if (!zlib_decompressor->MaybeDecompress(
+ model->compressed_lua_actions_script.get(),
+ &model->lua_actions_script)) {
+ TC3_LOG(ERROR) << "Cannot decompress actions script.";
+ return false;
+ }
+
+ if (model->ranking_options != nullptr &&
+ !zlib_decompressor->MaybeDecompress(
+ model->ranking_options->compressed_lua_ranking_script.get(),
+ &model->ranking_options->lua_ranking_script)) {
+ TC3_LOG(ERROR) << "Cannot decompress lua ranking script.";
+ return false;
+ }
+
+ return true;
+}
+
+std::string CompressSerializedActionsModel(const std::string& model) {
+ std::unique_ptr<ActionsModelT> unpacked_model =
+ UnPackActionsModel(model.c_str());
+ TC3_CHECK(unpacked_model != nullptr);
+ TC3_CHECK(CompressActionsModel(unpacked_model.get()));
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, unpacked_model.get()));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+bool GetUncompressedString(const flatbuffers::String* uncompressed_buffer,
+ const CompressedBuffer* compressed_buffer,
+ ZlibDecompressor* decompressor, std::string* out) {
+ if (uncompressed_buffer == nullptr && compressed_buffer == nullptr) {
+ out->clear();
+ return true;
+ }
+
+ return decompressor->MaybeDecompressOptionallyCompressedBuffer(
+ uncompressed_buffer, compressed_buffer, out);
+}
+
+} // namespace libtextclassifier3
diff --git a/actions/zlib-utils.h b/native/actions/zlib-utils.h
similarity index 100%
rename from actions/zlib-utils.h
rename to native/actions/zlib-utils.h
diff --git a/native/actions/zlib-utils_test.cc b/native/actions/zlib-utils_test.cc
new file mode 100644
index 0000000..75e4c78
--- /dev/null
+++ b/native/actions/zlib-utils_test.cc
@@ -0,0 +1,80 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/zlib-utils.h"
+
+#include <memory>
+
+#include "actions/actions_model_generated.h"
+#include "utils/zlib/zlib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAre;
+using testing::Field;
+using testing::Pointee;
+
+TEST(ActionsZlibUtilsTest, CompressModel) {
+ ActionsModelT model;
+ constexpr char kTestPattern1[] = "this is a test pattern";
+ constexpr char kTestPattern2[] = "this is a second test pattern";
+ constexpr char kTestOutputPattern[] = "this is an output pattern";
+ model.rules.reset(new RulesModelT);
+ model.rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
+ model.rules->regex_rule.back()->pattern = kTestPattern1;
+ model.rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
+ model.rules->regex_rule.back()->pattern = kTestPattern2;
+ model.rules->regex_rule.back()->output_pattern = kTestOutputPattern;
+
+ // Compress the model.
+ EXPECT_TRUE(CompressActionsModel(&model));
+
+ // Sanity check that uncompressed field is removed.
+ const auto is_empty_pattern =
+ Pointee(Field(&libtextclassifier3::RulesModel_::RegexRuleT::pattern,
+ testing::IsEmpty()));
+ EXPECT_THAT(model.rules->regex_rule,
+ ElementsAre(is_empty_pattern, is_empty_pattern));
+ // Pack and load the model.
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model));
+ const ActionsModel* compressed_model = GetActionsModel(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()));
+ ASSERT_TRUE(compressed_model != nullptr);
+
+ // Decompress the fields again and check that they match the original.
+ std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
+ ASSERT_TRUE(decompressor != nullptr);
+ std::string uncompressed_pattern;
+ EXPECT_TRUE(decompressor->MaybeDecompress(
+ compressed_model->rules()->regex_rule()->Get(0)->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, kTestPattern1);
+ EXPECT_TRUE(decompressor->MaybeDecompress(
+ compressed_model->rules()->regex_rule()->Get(1)->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, kTestPattern2);
+ EXPECT_TRUE(DecompressActionsModel(&model));
+ EXPECT_EQ(model.rules->regex_rule[0]->pattern, kTestPattern1);
+ EXPECT_EQ(model.rules->regex_rule[1]->pattern, kTestPattern2);
+ EXPECT_EQ(model.rules->regex_rule[1]->output_pattern, kTestOutputPattern);
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
new file mode 100644
index 0000000..6ee983f
--- /dev/null
+++ b/native/annotator/annotator.cc
@@ -0,0 +1,2946 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/annotator.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <iterator>
+#include <numeric>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "annotator/collections.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/base/status.h"
+#include "utils/base/statusor.h"
+#include "utils/checksum.h"
+#include "utils/i18n/locale.h"
+#include "utils/math/softmax.h"
+#include "utils/normalization.h"
+#include "utils/optional.h"
+#include "utils/regex-match.h"
+#include "utils/strings/numbers.h"
+#include "utils/strings/split.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib-common.h"
+#include "utils/zlib/zlib_regex.h"
+
+namespace libtextclassifier3 {
+
+using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
+
+const std::string& Annotator::kPhoneCollection =
+ *[]() { return new std::string("phone"); }();
+const std::string& Annotator::kAddressCollection =
+ *[]() { return new std::string("address"); }();
+const std::string& Annotator::kDateCollection =
+ *[]() { return new std::string("date"); }();
+const std::string& Annotator::kUrlCollection =
+ *[]() { return new std::string("url"); }();
+const std::string& Annotator::kEmailCollection =
+ *[]() { return new std::string("email"); }();
+
+namespace {
+const Model* LoadAndVerifyModel(const void* addr, int size) {
+ flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
+ if (VerifyModelBuffer(verifier)) {
+ return GetModel(addr);
+ } else {
+ return nullptr;
+ }
+}
+
+const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
+ int size) {
+ flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
+ if (VerifyPersonNameModelBuffer(verifier)) {
+ return GetPersonNameModel(addr);
+ } else {
+ return nullptr;
+ }
+}
+
+// If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
+// create a new instance, assign ownership to owned_lib, and return it.
+const UniLib* MaybeCreateUnilib(const UniLib* lib,
+ std::unique_ptr<UniLib>* owned_lib) {
+ if (lib) {
+ return lib;
+ } else {
+ owned_lib->reset(new UniLib);
+ return owned_lib->get();
+ }
+}
+
+// As above, but for CalendarLib.
+const CalendarLib* MaybeCreateCalendarlib(
+ const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
+ if (lib) {
+ return lib;
+ } else {
+ owned_lib->reset(new CalendarLib);
+ return owned_lib->get();
+ }
+}
+
+// Returns whether the provided input is valid:
+// * Valid utf8 text.
+// * Sane span indices.
+bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan span) {
+ if (!context.is_valid()) {
+ return false;
+ }
+ return (span.first >= 0 && span.first < span.second &&
+ span.second <= context.size_codepoints());
+}
+
+std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
+ const flatbuffers::Vector<int32_t>* ints) {
+ if (ints == nullptr) {
+ return {};
+ }
+ std::unordered_set<char32> ints_set;
+ for (auto value : *ints) {
+ ints_set.insert(static_cast<char32>(value));
+ }
+ return ints_set;
+}
+
+DateAnnotationOptions ToDateAnnotationOptions(
+ const GrammarDatetimeModel_::AnnotationOptions* fb_annotation_options,
+ const std::string& reference_timezone, const int64 reference_time_ms_utc) {
+ DateAnnotationOptions result_annotation_options;
+ result_annotation_options.base_timestamp_millis = reference_time_ms_utc;
+ result_annotation_options.reference_timezone = reference_timezone;
+ if (fb_annotation_options != nullptr) {
+ result_annotation_options.enable_special_day_offset =
+ fb_annotation_options->enable_special_day_offset();
+ result_annotation_options.merge_adjacent_components =
+ fb_annotation_options->merge_adjacent_components();
+ result_annotation_options.enable_date_range =
+ fb_annotation_options->enable_date_range();
+ result_annotation_options.include_preposition =
+ fb_annotation_options->include_preposition();
+ if (fb_annotation_options->extra_requested_dates() != nullptr) {
+ for (const auto& extra_requested_date :
+ *fb_annotation_options->extra_requested_dates()) {
+ result_annotation_options.extra_requested_dates.push_back(
+ extra_requested_date->str());
+ }
+ }
+ if (fb_annotation_options->ignored_spans() != nullptr) {
+ for (const auto& ignored_span : *fb_annotation_options->ignored_spans()) {
+ result_annotation_options.ignored_spans.push_back(ignored_span->str());
+ }
+ }
+ }
+ return result_annotation_options;
+}
+
+} // namespace
+
+tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
+ if (!selection_interpreter_) {
+ TC3_CHECK(selection_executor_);
+ selection_interpreter_ = selection_executor_->CreateInterpreter();
+ if (!selection_interpreter_) {
+ TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
+ }
+ }
+ return selection_interpreter_.get();
+}
+
+tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
+ if (!classification_interpreter_) {
+ TC3_CHECK(classification_executor_);
+ classification_interpreter_ = classification_executor_->CreateInterpreter();
+ if (!classification_interpreter_) {
+ TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
+ }
+ }
+ return classification_interpreter_.get();
+}
+
+std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
+ const char* buffer, int size, const UniLib* unilib,
+ const CalendarLib* calendarlib) {
+ const Model* model = LoadAndVerifyModel(buffer, size);
+ if (model == nullptr) {
+ return nullptr;
+ }
+
+ auto classifier =
+ std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
+ if (!classifier->IsInitialized()) {
+ return nullptr;
+ }
+
+ return classifier;
+}
+
+std::unique_ptr<Annotator> Annotator::FromScopedMmap(
+ std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
+ const CalendarLib* calendarlib) {
+ if (!(*mmap)->handle().ok()) {
+ TC3_VLOG(1) << "Mmap failed.";
+ return nullptr;
+ }
+
+ const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
+ (*mmap)->handle().num_bytes());
+ if (!model) {
+ TC3_LOG(ERROR) << "Model verification failed.";
+ return nullptr;
+ }
+
+ auto classifier = std::unique_ptr<Annotator>(
+ new Annotator(mmap, model, unilib, calendarlib));
+ if (!classifier->IsInitialized()) {
+ return nullptr;
+ }
+
+ return classifier;
+}
+
+std::unique_ptr<Annotator> Annotator::FromScopedMmap(
+ std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib) {
+ if (!(*mmap)->handle().ok()) {
+ TC3_VLOG(1) << "Mmap failed.";
+ return nullptr;
+ }
+
+ const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
+ (*mmap)->handle().num_bytes());
+ if (model == nullptr) {
+ TC3_LOG(ERROR) << "Model verification failed.";
+ return nullptr;
+ }
+
+ auto classifier = std::unique_ptr<Annotator>(
+ new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
+ if (!classifier->IsInitialized()) {
+ return nullptr;
+ }
+
+ return classifier;
+}
+
+std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
+ int fd, int offset, int size, const UniLib* unilib,
+ const CalendarLib* calendarlib) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
+ return FromScopedMmap(&mmap, unilib, calendarlib);
+}
+
+std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
+ int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
+ return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
+}
+
+std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
+ int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
+ return FromScopedMmap(&mmap, unilib, calendarlib);
+}
+
+std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
+ int fd, std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
+ return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
+}
+
+std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
+ const UniLib* unilib,
+ const CalendarLib* calendarlib) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
+ return FromScopedMmap(&mmap, unilib, calendarlib);
+}
+
+std::unique_ptr<Annotator> Annotator::FromPath(
+ const std::string& path, std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
+ return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
+}
+
+Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
+ const UniLib* unilib, const CalendarLib* calendarlib)
+ : model_(model),
+ mmap_(std::move(*mmap)),
+ owned_unilib_(nullptr),
+ unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
+ owned_calendarlib_(nullptr),
+ calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
+ ValidateAndInitialize();
+}
+
+Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
+ std::unique_ptr<UniLib> unilib,
+ std::unique_ptr<CalendarLib> calendarlib)
+ : model_(model),
+ mmap_(std::move(*mmap)),
+ owned_unilib_(std::move(unilib)),
+ unilib_(owned_unilib_.get()),
+ owned_calendarlib_(std::move(calendarlib)),
+ calendarlib_(owned_calendarlib_.get()) {
+ ValidateAndInitialize();
+}
+
+Annotator::Annotator(const Model* model, const UniLib* unilib,
+ const CalendarLib* calendarlib)
+ : model_(model),
+ owned_unilib_(nullptr),
+ unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
+ owned_calendarlib_(nullptr),
+ calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
+ ValidateAndInitialize();
+}
+
+void Annotator::ValidateAndInitialize() {
+ initialized_ = false;
+
+ if (model_ == nullptr) {
+ TC3_LOG(ERROR) << "No model specified.";
+ return;
+ }
+
+ const bool model_enabled_for_annotation =
+ (model_->triggering_options() != nullptr &&
+ (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
+ const bool model_enabled_for_classification =
+ (model_->triggering_options() != nullptr &&
+ (model_->triggering_options()->enabled_modes() &
+ ModeFlag_CLASSIFICATION));
+ const bool model_enabled_for_selection =
+ (model_->triggering_options() != nullptr &&
+ (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
+
+ // Annotation requires the selection model.
+ if (model_enabled_for_annotation || model_enabled_for_selection) {
+ if (!model_->selection_options()) {
+ TC3_LOG(ERROR) << "No selection options.";
+ return;
+ }
+ if (!model_->selection_feature_options()) {
+ TC3_LOG(ERROR) << "No selection feature options.";
+ return;
+ }
+ if (!model_->selection_feature_options()->bounds_sensitive_features()) {
+ TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
+ return;
+ }
+ if (!model_->selection_model()) {
+ TC3_LOG(ERROR) << "No selection model.";
+ return;
+ }
+ selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
+ if (!selection_executor_) {
+ TC3_LOG(ERROR) << "Could not initialize selection executor.";
+ return;
+ }
+ selection_feature_processor_.reset(
+ new FeatureProcessor(model_->selection_feature_options(), unilib_));
+ }
+
+ // Annotation requires the classification model for conflict resolution and
+ // scoring.
+ // Selection requires the classification model for conflict resolution.
+ if (model_enabled_for_annotation || model_enabled_for_classification ||
+ model_enabled_for_selection) {
+ if (!model_->classification_options()) {
+ TC3_LOG(ERROR) << "No classification options.";
+ return;
+ }
+
+ if (!model_->classification_feature_options()) {
+ TC3_LOG(ERROR) << "No classification feature options.";
+ return;
+ }
+
+ if (!model_->classification_feature_options()
+ ->bounds_sensitive_features()) {
+ TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
+ return;
+ }
+ if (!model_->classification_model()) {
+ TC3_LOG(ERROR) << "No clf model.";
+ return;
+ }
+
+ classification_executor_ =
+ ModelExecutor::FromBuffer(model_->classification_model());
+ if (!classification_executor_) {
+ TC3_LOG(ERROR) << "Could not initialize classification executor.";
+ return;
+ }
+
+ classification_feature_processor_.reset(new FeatureProcessor(
+ model_->classification_feature_options(), unilib_));
+ }
+
+ // The embeddings need to be specified if the model is to be used for
+ // classification or selection.
+ if (model_enabled_for_annotation || model_enabled_for_classification ||
+ model_enabled_for_selection) {
+ if (!model_->embedding_model()) {
+ TC3_LOG(ERROR) << "No embedding model.";
+ return;
+ }
+
+ // Check that the embedding size of the selection and classification model
+ // matches, as they are using the same embeddings.
+ if (model_enabled_for_selection &&
+ (model_->selection_feature_options()->embedding_size() !=
+ model_->classification_feature_options()->embedding_size() ||
+ model_->selection_feature_options()->embedding_quantization_bits() !=
+ model_->classification_feature_options()
+ ->embedding_quantization_bits())) {
+ TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
+ return;
+ }
+
+ embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
+ model_->embedding_model(),
+ model_->classification_feature_options()->embedding_size(),
+ model_->classification_feature_options()->embedding_quantization_bits(),
+ model_->embedding_pruning_mask());
+ if (!embedding_executor_) {
+ TC3_LOG(ERROR) << "Could not initialize embedding executor.";
+ return;
+ }
+ }
+
+ std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
+ if (model_->regex_model()) {
+ if (!InitializeRegexModel(decompressor.get())) {
+ TC3_LOG(ERROR) << "Could not initialize regex model.";
+ return;
+ }
+ }
+ if (model_->grammar_datetime_model() &&
+ model_->grammar_datetime_model()->datetime_rules()) {
+ cfg_datetime_parser_.reset(new dates::CfgDatetimeAnnotator(
+ unilib_,
+ /*tokenizer_options=*/
+ model_->grammar_datetime_model()->grammar_tokenizer_options(),
+ calendarlib_,
+ /*datetime_rules=*/model_->grammar_datetime_model()->datetime_rules(),
+ model_->grammar_datetime_model()->target_classification_score(),
+ model_->grammar_datetime_model()->priority_score()));
+ if (!cfg_datetime_parser_) {
+ TC3_LOG(ERROR) << "Could not initialize context free grammar based "
+ "datetime parser.";
+ return;
+ }
+ }
+
+ if (model_->datetime_model()) {
+ datetime_parser_ = DatetimeParser::Instance(
+ model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
+ if (!datetime_parser_) {
+ TC3_LOG(ERROR) << "Could not initialize datetime parser.";
+ return;
+ }
+ }
+
+ if (model_->output_options()) {
+ if (model_->output_options()->filtered_collections_annotation()) {
+ for (const auto collection :
+ *model_->output_options()->filtered_collections_annotation()) {
+ filtered_collections_annotation_.insert(collection->str());
+ }
+ }
+ if (model_->output_options()->filtered_collections_classification()) {
+ for (const auto collection :
+ *model_->output_options()->filtered_collections_classification()) {
+ filtered_collections_classification_.insert(collection->str());
+ }
+ }
+ if (model_->output_options()->filtered_collections_selection()) {
+ for (const auto collection :
+ *model_->output_options()->filtered_collections_selection()) {
+ filtered_collections_selection_.insert(collection->str());
+ }
+ }
+ }
+
+ if (model_->number_annotator_options() &&
+ model_->number_annotator_options()->enabled()) {
+ number_annotator_.reset(
+ new NumberAnnotator(model_->number_annotator_options(), unilib_));
+ }
+
+ if (model_->money_parsing_options()) {
+ money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
+ model_->money_parsing_options()->separators());
+ }
+
+ if (model_->duration_annotator_options() &&
+ model_->duration_annotator_options()->enabled()) {
+ duration_annotator_.reset(
+ new DurationAnnotator(model_->duration_annotator_options(),
+ selection_feature_processor_.get(), unilib_));
+ }
+
+ if (model_->entity_data_schema()) {
+ entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
+ model_->entity_data_schema()->Data(),
+ model_->entity_data_schema()->size());
+ if (entity_data_schema_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not load entity data schema data.";
+ return;
+ }
+
+ entity_data_builder_.reset(
+ new ReflectiveFlatbufferBuilder(entity_data_schema_));
+ } else {
+ entity_data_schema_ = nullptr;
+ entity_data_builder_ = nullptr;
+ }
+
+ if (model_->grammar_model()) {
+ grammar_annotator_.reset(new GrammarAnnotator(
+ unilib_, model_->grammar_model(), entity_data_builder_.get()));
+ }
+
+ if (model_->triggering_locales() &&
+ !ParseLocales(model_->triggering_locales()->c_str(),
+ &model_triggering_locales_)) {
+ TC3_LOG(ERROR) << "Could not parse model supported locales.";
+ return;
+ }
+
+ if (model_->triggering_options() != nullptr &&
+ model_->triggering_options()->locales() != nullptr &&
+ !ParseLocales(model_->triggering_options()->locales()->c_str(),
+ &ml_model_triggering_locales_)) {
+ TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
+ return;
+ }
+
+ if (model_->triggering_options() != nullptr &&
+ model_->triggering_options()->dictionary_locales() != nullptr &&
+ !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
+ &dictionary_locales_)) {
+ TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
+ return;
+ }
+
+ if (model_->conflict_resolution_options() != nullptr) {
+ prioritize_longest_annotation_ =
+ model_->conflict_resolution_options()->prioritize_longest_annotation();
+ do_conflict_resolution_in_raw_mode_ =
+ model_->conflict_resolution_options()
+ ->do_conflict_resolution_in_raw_mode();
+ }
+
+#ifdef TC3_EXPERIMENTAL
+ TC3_LOG(WARNING) << "Enabling experimental annotators.";
+ InitializeExperimentalAnnotators();
+#endif
+
+ initialized_ = true;
+}
+
+bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
+ if (!model_->regex_model()->patterns()) {
+ return true;
+ }
+
+ // Initialize pattern recognizers.
+ int regex_pattern_id = 0;
+ for (const auto regex_pattern : *model_->regex_model()->patterns()) {
+ std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
+ UncompressMakeRegexPattern(
+ *unilib_, regex_pattern->pattern(),
+ regex_pattern->compressed_pattern(),
+ model_->regex_model()->lazy_regex_compilation(), decompressor);
+ if (!compiled_pattern) {
+ TC3_LOG(INFO) << "Failed to load regex pattern";
+ return false;
+ }
+
+ if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
+ annotation_regex_patterns_.push_back(regex_pattern_id);
+ }
+ if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
+ classification_regex_patterns_.push_back(regex_pattern_id);
+ }
+ if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
+ selection_regex_patterns_.push_back(regex_pattern_id);
+ }
+ regex_patterns_.push_back({
+ regex_pattern,
+ std::move(compiled_pattern),
+ });
+ ++regex_pattern_id;
+ }
+
+ return true;
+}
+
+bool Annotator::InitializeKnowledgeEngine(
+ const std::string& serialized_config) {
+ std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
+ if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
+ TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
+ return false;
+ }
+ if (model_->triggering_options() != nullptr) {
+ knowledge_engine->SetPriorityScore(
+ model_->triggering_options()->knowledge_priority_score());
+ }
+ knowledge_engine_ = std::move(knowledge_engine);
+ return true;
+}
+
+bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
+ std::unique_ptr<ContactEngine> contact_engine(
+ new ContactEngine(selection_feature_processor_.get(), unilib_,
+ model_->contact_annotator_options()));
+ if (!contact_engine->Initialize(serialized_config)) {
+ TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
+ return false;
+ }
+ contact_engine_ = std::move(contact_engine);
+ return true;
+}
+
+bool Annotator::InitializeInstalledAppEngine(
+ const std::string& serialized_config) {
+ std::unique_ptr<InstalledAppEngine> installed_app_engine(
+ new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
+ if (!installed_app_engine->Initialize(serialized_config)) {
+ TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
+ return false;
+ }
+ installed_app_engine_ = std::move(installed_app_engine);
+ return true;
+}
+
+void Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
+ lang_id_ = lang_id;
+ if (lang_id_ != nullptr && model_->translate_annotator_options() &&
+ model_->translate_annotator_options()->enabled()) {
+ translate_annotator_.reset(new TranslateAnnotator(
+ model_->translate_annotator_options(), lang_id_, unilib_));
+ } else {
+ translate_annotator_.reset(nullptr);
+ }
+}
+
+bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
+ int size) {
+ const PersonNameModel* person_name_model =
+ LoadAndVerifyPersonNameModel(buffer, size);
+
+ if (person_name_model == nullptr) {
+ TC3_LOG(ERROR) << "Person name model verification failed.";
+ return false;
+ }
+
+ if (!person_name_model->enabled()) {
+ return true;
+ }
+
+ std::unique_ptr<PersonNameEngine> person_name_engine(
+ new PersonNameEngine(selection_feature_processor_.get(), unilib_));
+ if (!person_name_engine->Initialize(person_name_model)) {
+ TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
+ return false;
+ }
+ person_name_engine_ = std::move(person_name_engine);
+ return true;
+}
+
+bool Annotator::InitializePersonNameEngineFromScopedMmap(
+ const ScopedMmap& mmap) {
+ if (!mmap.handle().ok()) {
+ TC3_LOG(ERROR) << "Mmap for person name model failed.";
+ return false;
+ }
+
+ return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
+ mmap.handle().num_bytes());
+}
+
+bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
+ return InitializePersonNameEngineFromScopedMmap(*mmap);
+}
+
+bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
+ int size) {
+ std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
+ return InitializePersonNameEngineFromScopedMmap(*mmap);
+}
+
+bool Annotator::InitializeExperimentalAnnotators() {
+ if (ExperimentalAnnotator::IsEnabled()) {
+ experimental_annotator_.reset(new ExperimentalAnnotator(
+ model_->experimental_model(), *selection_feature_processor_, *unilib_));
+ return true;
+ }
+ return false;
+}
+
+namespace {
+
+int CountDigits(const std::string& str, CodepointSpan selection_indices) {
+ int count = 0;
+ int i = 0;
+ const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
+ for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
+ if (i >= selection_indices.first && i < selection_indices.second &&
+ IsDigit(*it)) {
+ ++count;
+ }
+ }
+ return count;
+}
+
+} // namespace
+
+namespace internal {
+// Helper function, which if the initial 'span' contains only white-spaces,
+// moves the selection to a single-codepoint selection on a left or right side
+// of this space.
+CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
+ const UnicodeText& context_unicode,
+ const UniLib& unilib) {
+ TC3_CHECK(ValidNonEmptySpan(span));
+
+ UnicodeText::const_iterator it;
+
+ // Check that the current selection is all whitespaces.
+ it = context_unicode.begin();
+ std::advance(it, span.first);
+ for (int i = 0; i < (span.second - span.first); ++i, ++it) {
+ if (!unilib.IsWhitespace(*it)) {
+ return span;
+ }
+ }
+
+ CodepointSpan result;
+
+ // Try moving left.
+ result = span;
+ it = context_unicode.begin();
+ std::advance(it, span.first);
+ while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
+ --result.first;
+ --it;
+ }
+ result.second = result.first + 1;
+ if (!unilib.IsWhitespace(*it)) {
+ return result;
+ }
+
+ // If moving left didn't find a non-whitespace character, just return the
+ // original span.
+ return span;
+}
+} // namespace internal
+
+bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
+ return !span.classification.empty() &&
+ filtered_collections_annotation_.find(
+ span.classification[0].collection) !=
+ filtered_collections_annotation_.end();
+}
+
+bool Annotator::FilteredForClassification(
+ const ClassificationResult& classification) const {
+ return filtered_collections_classification_.find(classification.collection) !=
+ filtered_collections_classification_.end();
+}
+
+bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
+ return !span.classification.empty() &&
+ filtered_collections_selection_.find(
+ span.classification[0].collection) !=
+ filtered_collections_selection_.end();
+}
+
+namespace {
+inline bool ClassifiedAsOther(
+ const std::vector<ClassificationResult>& classification) {
+ return !classification.empty() &&
+ classification[0].collection == Collections::Other();
+}
+
+} // namespace
+
+float Annotator::GetPriorityScore(
+ const std::vector<ClassificationResult>& classification) const {
+ if (!classification.empty() && !ClassifiedAsOther(classification)) {
+ return classification[0].priority_score;
+ } else {
+ if (model_->triggering_options() != nullptr) {
+ return model_->triggering_options()->other_collection_priority_score();
+ } else {
+ return -1000.0;
+ }
+ }
+}
+
+bool Annotator::VerifyRegexMatchCandidate(
+ const std::string& context, const VerificationOptions* verification_options,
+ const std::string& match, const UniLib::RegexMatcher* matcher) const {
+ if (verification_options == nullptr) {
+ return true;
+ }
+ if (verification_options->verify_luhn_checksum() &&
+ !VerifyLuhnChecksum(match)) {
+ return false;
+ }
+ const int lua_verifier = verification_options->lua_verifier();
+ if (lua_verifier >= 0) {
+ if (model_->regex_model()->lua_verifier() == nullptr ||
+ lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
+ TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
+ return false;
+ }
+ return VerifyMatch(
+ context, matcher,
+ model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
+ }
+ return true;
+}
+
+// Suggests a selection span around 'click_indices' in 'context'.
+//
+// Pipeline: validate inputs and modes, optionally snap a whitespace-only
+// click to the neighboring token, gather candidate spans from every enabled
+// annotator (ML selection model, regexes, datetime, knowledge, contacts,
+// installed apps, numbers, durations, person names, grammar, experimental),
+// resolve conflicts between overlapping candidates, and return the
+// highest-priority candidate that still overlaps the click. On any failure,
+// or when nothing applies, returns 'original_click_indices' unchanged.
+CodepointSpan Annotator::SuggestSelection(
+    const std::string& context, CodepointSpan click_indices,
+    const SelectionOptions& options) const {
+  CodepointSpan original_click_indices = click_indices;
+  if (!initialized_) {
+    TC3_LOG(ERROR) << "Not initialized";
+    return original_click_indices;
+  }
+  if (options.annotation_usecase !=
+      AnnotationUsecase_ANNOTATION_USECASE_SMART) {
+    TC3_LOG(WARNING)
+        << "Invoking SuggestSelection, which is not supported in RAW mode.";
+    return original_click_indices;
+  }
+  if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
+    return original_click_indices;
+  }
+
+  // Locale gating: a parse failure is non-fatal (we fall back to the
+  // default_value=true behavior below), but unsupported locales bail out.
+  std::vector<Locale> detected_text_language_tags;
+  if (!ParseLocales(options.detected_text_language_tags,
+                    &detected_text_language_tags)) {
+    TC3_LOG(WARNING)
+        << "Failed to parse the detected_text_language_tags in options: "
+        << options.detected_text_language_tags;
+  }
+  if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+                                    model_triggering_locales_,
+                                    /*default_value=*/true)) {
+    return original_click_indices;
+  }
+
+  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+                                                        /*do_copy=*/false);
+
+  if (!IsValidSpanInput(context_unicode, click_indices)) {
+    TC3_VLOG(1)
+        << "Trying to run SuggestSelection with invalid input, indices: "
+        << click_indices.first << " " << click_indices.second;
+    return original_click_indices;
+  }
+
+  if (model_->snap_whitespace_selections()) {
+    // We want to expand a purely white-space selection to a multi-selection it
+    // would've been part of. But with this feature disabled we would do a no-
+    // op, because no token is found. Therefore, we need to modify the
+    // 'click_indices' a bit to include a part of the token, so that the click-
+    // finding logic finds the clicked token correctly. This modification is
+    // done by the following function. Note, that it's enough to check the left
+    // side of the current selection, because if the white-space is a part of a
+    // multi-selection, necessarily both tokens - on the left and the right
+    // sides need to be selected. Thus snapping only to the left is sufficient
+    // (there's a check at the bottom that makes sure that if we snap to the
+    // left token but the result does not contain the initial white-space,
+    // returns the original indices).
+    click_indices = internal::SnapLeftIfWhitespaceSelection(
+        click_indices, context_unicode, *unilib_);
+  }
+
+  // Collect candidate spans from each enabled annotator. A failure from any
+  // required annotator aborts the whole suggestion.
+  std::vector<AnnotatedSpan> candidates;
+  InterpreterManager interpreter_manager(selection_executor_.get(),
+                                         classification_executor_.get());
+  std::vector<Token> tokens;
+  if (!ModelSuggestSelection(context_unicode, click_indices,
+                             detected_text_language_tags, &interpreter_manager,
+                             &tokens, &candidates)) {
+    TC3_LOG(ERROR) << "Model suggest selection failed.";
+    return original_click_indices;
+  }
+  if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates,
+                  /*is_serialized_entity_data_enabled=*/false)) {
+    TC3_LOG(ERROR) << "Regex suggest selection failed.";
+    return original_click_indices;
+  }
+  if (!DatetimeChunk(
+          UTF8ToUnicodeText(context, /*do_copy=*/false),
+          /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
+          options.locales, ModeFlag_SELECTION, options.annotation_usecase,
+          /*is_serialized_entity_data_enabled=*/false, &candidates)) {
+    TC3_LOG(ERROR) << "Datetime suggest selection failed.";
+    return original_click_indices;
+  }
+  if (knowledge_engine_ != nullptr &&
+      !knowledge_engine_->Chunk(context, options.annotation_usecase,
+                                options.location_context, Permissions(),
+                                &candidates)) {
+    TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
+    return original_click_indices;
+  }
+  if (contact_engine_ != nullptr &&
+      !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
+    TC3_LOG(ERROR) << "Contact suggest selection failed.";
+    return original_click_indices;
+  }
+  if (installed_app_engine_ != nullptr &&
+      !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
+    TC3_LOG(ERROR) << "Installed app suggest selection failed.";
+    return original_click_indices;
+  }
+  if (number_annotator_ != nullptr &&
+      !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
+                                  &candidates)) {
+    TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
+    return original_click_indices;
+  }
+  if (duration_annotator_ != nullptr &&
+      !duration_annotator_->FindAll(context_unicode, tokens,
+                                    options.annotation_usecase, &candidates)) {
+    TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
+    return original_click_indices;
+  }
+  if (person_name_engine_ != nullptr &&
+      !person_name_engine_->Chunk(context_unicode, tokens, &candidates)) {
+    TC3_LOG(ERROR) << "Person name suggest selection failed.";
+    return original_click_indices;
+  }
+
+  AnnotatedSpan grammar_suggested_span;
+  if (grammar_annotator_ != nullptr &&
+      grammar_annotator_->SuggestSelection(detected_text_language_tags,
+                                           context_unicode, click_indices,
+                                           &grammar_suggested_span)) {
+    candidates.push_back(grammar_suggested_span);
+  }
+
+  if (experimental_annotator_ != nullptr) {
+    candidates.push_back(experimental_annotator_->SuggestSelection(
+        context_unicode, click_indices));
+  }
+
+  // Sort candidates according to their position in the input, so that the next
+  // code can assume that any connected component of overlapping spans forms a
+  // contiguous block.
+  std::sort(candidates.begin(), candidates.end(),
+            [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+              return a.span.first < b.span.first;
+            });
+
+  std::vector<int> candidate_indices;
+  if (!ResolveConflicts(candidates, context, tokens,
+                        detected_text_language_tags, options.annotation_usecase,
+                        &interpreter_manager, &candidate_indices)) {
+    TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
+    return original_click_indices;
+  }
+
+  // Rank surviving candidates by priority score, best first.
+  std::sort(candidate_indices.begin(), candidate_indices.end(),
+            [this, &candidates](int a, int b) {
+              return GetPriorityScore(candidates[a].classification) >
+                     GetPriorityScore(candidates[b].classification);
+            });
+
+  // Return the best candidate that overlaps both the (possibly snapped)
+  // click and the original click; classify it on demand if filtering
+  // requires a collection label.
+  for (const int i : candidate_indices) {
+    if (SpansOverlap(candidates[i].span, click_indices) &&
+        SpansOverlap(candidates[i].span, original_click_indices)) {
+      // Run model classification if not present but requested and there's a
+      // classification collection filter specified.
+      if (candidates[i].classification.empty() &&
+          model_->selection_options()->always_classify_suggested_selection() &&
+          !filtered_collections_selection_.empty()) {
+        if (!ModelClassifyText(context, detected_text_language_tags,
+                               candidates[i].span, &interpreter_manager,
+                               /*embedding_cache=*/nullptr,
+                               &candidates[i].classification)) {
+          return original_click_indices;
+        }
+      }
+
+      // Ignore if span classification is filtered.
+      if (FilteredForSelection(candidates[i])) {
+        return original_click_indices;
+      }
+
+      return candidates[i].span;
+    }
+  }
+
+  return original_click_indices;
+}
+
+namespace {
+// Returns the index just past the run of candidates that (transitively)
+// overlap the candidate at 'start_index'. Candidates must be sorted by span
+// start. If the overlap run extends to the end of 'candidates', the returned
+// index points right behind the array.
+int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
+                                 int start_index) {
+  CodepointSpan merged_span = candidates[start_index].span;
+  int index = start_index + 1;
+  for (; index < candidates.size() &&
+         SpansOverlap(merged_span, candidates[index].span);
+       ++index) {
+    // Extend the union of spans so that transitive overlaps are detected too.
+    merged_span.second =
+        std::max(merged_span.second, candidates[index].span.second);
+  }
+  return index;
+}
+} // namespace
+
+// Walks 'candidates' (sorted by span start) group by group: a group is a
+// maximal run of transitively overlapping spans. Singleton groups are kept
+// as-is; larger groups are handed to ResolveConflict to pick the winners.
+// Indices of surviving candidates are appended to 'result'. Returns false if
+// conflict resolution fails for any group.
+bool Annotator::ResolveConflicts(
+    const std::vector<AnnotatedSpan>& candidates, const std::string& context,
+    const std::vector<Token>& cached_tokens,
+    const std::vector<Locale>& detected_text_language_tags,
+    AnnotationUsecase annotation_usecase,
+    InterpreterManager* interpreter_manager, std::vector<int>* result) const {
+  result->clear();
+  result->reserve(candidates.size());
+  int index = 0;
+  while (index < candidates.size()) {
+    const int group_end =
+        FirstNonOverlappingSpanIndex(candidates, /*start_index=*/index);
+    if (group_end == index + 1) {
+      // No conflict: the candidate stands alone.
+      result->push_back(index);
+    } else {
+      // A group of overlapping candidates: resolve it to a winner set.
+      std::vector<int> chosen;
+      if (!ResolveConflict(context, cached_tokens, candidates,
+                           detected_text_language_tags, index, group_end,
+                           annotation_usecase, interpreter_manager, &chosen)) {
+        return false;
+      }
+      result->insert(result->end(), chosen.begin(), chosen.end());
+    }
+    // Jump over the whole group to the next unprocessed candidate.
+    index = group_end;
+  }
+  return true;
+}
+
+namespace {
+// Returns true, if the given two sources do conflict in given annotation
+// usecase.
+// - In SMART usecase, all sources do conflict, because there's only 1 possible
+//   annotation for a given span.
+// - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
+//   and duration), while others not (e.g. duration and number).
+bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
+                       const AnnotatedSpan::Source source1,
+                       const AnnotatedSpan::Source source2) {
+  // One bit per source, so membership of either source in a group can be
+  // tested with a single mask intersection.
+  uint32 source_mask =
+      (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
+
+  switch (annotation_usecase) {
+    case AnnotationUsecase_ANNOTATION_USECASE_SMART:
+      // In the SMART mode, all annotations conflict.
+      return true;
+
+    case AnnotationUsecase_ANNOTATION_USECASE_RAW:
+      // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
+      // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
+      // hours" (duration).
+      if ((source_mask &
+           (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
+          (source_mask &
+           (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
+        return false;
+      }
+
+      // A KNOWLEDGE entity does not conflict with anything.
+      if ((source_mask &
+           (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
+        return false;
+      }
+
+      // A PERSONNAME entity does not conflict with anything.
+      if ((source_mask &
+           (1 << static_cast<int>(AnnotatedSpan::Source::PERSON_NAME)))) {
+        return false;
+      }
+
+      // Entities from other sources can conflict.
+      return true;
+  }
+
+  // FIX: previously control could flow off the end of this non-void function
+  // (undefined behavior) if 'annotation_usecase' held a value outside the two
+  // declared enumerators, which is possible for flatbuffer-backed enums read
+  // from untrusted model data. Conservatively treat unknown usecases as
+  // conflicting.
+  return true;
+}
+} // namespace
+
+// Resolves one group of conflicting (transitively overlapping) candidates in
+// the range [start_index, end_index) and appends the indices of the winning
+// candidates to 'chosen_indices' (sorted ascending on return).
+//
+// Candidates are ranked by priority score — span length breaks ties when
+// 'prioritize_longest_annotation_' is set — and placed greedily; a candidate
+// is dropped when it conflicts with an already-placed candidate whose source
+// it is not allowed to overlap in the given usecase. Returns false if an
+// on-demand model classification fails.
+bool Annotator::ResolveConflict(
+    const std::string& context, const std::vector<Token>& cached_tokens,
+    const std::vector<AnnotatedSpan>& candidates,
+    const std::vector<Locale>& detected_text_language_tags, int start_index,
+    int end_index, AnnotationUsecase annotation_usecase,
+    InterpreterManager* interpreter_manager,
+    std::vector<int>* chosen_indices) const {
+  std::vector<int> conflicting_indices;
+  // Maps candidate index -> (priority score, span length in codepoints).
+  std::unordered_map<int, std::pair<float, int>> scores_lengths;
+  for (int i = start_index; i < end_index; ++i) {
+    conflicting_indices.push_back(i);
+    if (!candidates[i].classification.empty()) {
+      scores_lengths[i] = {
+          GetPriorityScore(candidates[i].classification),
+          candidates[i].span.second - candidates[i].span.first};
+      continue;
+    }
+
+    // OPTIMIZATION: So that we don't have to classify all the ML model
+    // spans apriori, we wait until we get here, when they conflict with
+    // something and we need the actual classification scores. So if the
+    // candidate conflicts and comes from the model, we need to run a
+    // classification to determine its priority:
+    std::vector<ClassificationResult> classification;
+    if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
+                           candidates[i].span, interpreter_manager,
+                           /*embedding_cache=*/nullptr, &classification)) {
+      return false;
+    }
+
+    if (!classification.empty()) {
+      scores_lengths[i] = {
+          GetPriorityScore(classification),
+          candidates[i].span.second - candidates[i].span.first};
+    }
+  }
+
+  // Sort from high to low priority score.
+  // FIX: the comparator previously captured 'candidates' and
+  // 'conflicting_indices' by value, deep-copying both vectors for every
+  // conflict group even though neither was used inside the lambda. Capture
+  // only what the comparator actually reads.
+  std::sort(conflicting_indices.begin(), conflicting_indices.end(),
+            [this, &scores_lengths](int i, int j) {
+              if (scores_lengths[i].first == scores_lengths[j].first &&
+                  prioritize_longest_annotation_) {
+                return scores_lengths[i].second > scores_lengths[j].second;
+              }
+              return scores_lengths[i].first > scores_lengths[j].first;
+            });
+
+  // Here we keep a set of indices that were chosen, per-source, to enable
+  // effective computation.
+  std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
+      chosen_indices_for_source_map;
+
+  // Greedily place the candidates if they don't conflict with the already
+  // placed ones.
+  for (int i = 0; i < conflicting_indices.size(); ++i) {
+    const int considered_candidate = conflicting_indices[i];
+
+    // See if there is a conflict between the candidate and all already placed
+    // candidates.
+    bool conflict = false;
+    SortedIntSet* chosen_indices_for_source_ptr = nullptr;
+    for (auto& source_set_pair : chosen_indices_for_source_map) {
+      if (source_set_pair.first == candidates[considered_candidate].source) {
+        chosen_indices_for_source_ptr = &source_set_pair.second;
+      }
+
+      const bool needs_conflict_resolution =
+          annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_SMART ||
+          (annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW &&
+           do_conflict_resolution_in_raw_mode_);
+      if (needs_conflict_resolution &&
+          DoSourcesConflict(annotation_usecase, source_set_pair.first,
+                            candidates[considered_candidate].source) &&
+          DoesCandidateConflict(considered_candidate, candidates,
+                                source_set_pair.second)) {
+        conflict = true;
+        break;
+      }
+    }
+
+    // Skip the candidate if a conflict was found.
+    if (conflict) {
+      continue;
+    }
+
+    // If the set of indices for the current source doesn't exist yet,
+    // initialize it.
+    if (chosen_indices_for_source_ptr == nullptr) {
+      SortedIntSet new_set([&candidates](int a, int b) {
+        return candidates[a].span.first < candidates[b].span.first;
+      });
+      chosen_indices_for_source_map[candidates[considered_candidate].source] =
+          std::move(new_set);
+      chosen_indices_for_source_ptr =
+          &chosen_indices_for_source_map[candidates[considered_candidate]
+                                             .source];
+    }
+
+    // Place the candidate to the output and to the per-source conflict set.
+    chosen_indices->push_back(considered_candidate);
+    chosen_indices_for_source_ptr->insert(considered_candidate);
+  }
+
+  std::sort(chosen_indices->begin(), chosen_indices->end());
+
+  return true;
+}
+
+// Runs the ML selection model around the clicked span and appends candidate
+// spans to 'result'. Also outputs the tokenization of 'context_unicode' via
+// 'tokens' so later stages can reuse it. Returns true (without producing
+// candidates) when selection is disabled or the locale is unsupported;
+// returns false only on genuine failures (no click position, feature
+// extraction or chunking errors).
+bool Annotator::ModelSuggestSelection(
+    const UnicodeText& context_unicode, CodepointSpan click_indices,
+    const std::vector<Locale>& detected_text_language_tags,
+    InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
+    std::vector<AnnotatedSpan>* result) const {
+  if (model_->triggering_options() == nullptr ||
+      !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
+    return true;
+  }
+
+  if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+                                    ml_model_triggering_locales_,
+                                    /*default_value=*/true)) {
+    return true;
+  }
+
+  int click_pos;
+  *tokens = selection_feature_processor_->Tokenize(context_unicode);
+  selection_feature_processor_->RetokenizeAndFindClick(
+      context_unicode, click_indices,
+      selection_feature_processor_->GetOptions()->only_use_line_with_click(),
+      tokens, &click_pos);
+  if (click_pos == kInvalidIndex) {
+    TC3_VLOG(1) << "Could not calculate the click position.";
+    return false;
+  }
+
+  const int symmetry_context_size =
+      model_->selection_options()->symmetry_context_size();
+  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
+      bounds_sensitive_features = selection_feature_processor_->GetOptions()
+                                      ->bounds_sensitive_features();
+
+  // The symmetry context span is the clicked token with symmetry_context_size
+  // tokens on either side.
+  const TokenSpan symmetry_context_span = IntersectTokenSpans(
+      ExpandTokenSpan(SingleTokenSpan(click_pos),
+                      /*num_tokens_left=*/symmetry_context_size,
+                      /*num_tokens_right=*/symmetry_context_size),
+      {0, tokens->size()});
+
+  // Compute the extraction span based on the model type.
+  TokenSpan extraction_span;
+  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
+    // The extraction span is the symmetry context span expanded to include
+    // max_selection_span tokens on either side, which is how far a selection
+    // can stretch from the click, plus a relevant number of tokens outside of
+    // the bounds of the selection.
+    const int max_selection_span =
+        selection_feature_processor_->GetOptions()->max_selection_span();
+    extraction_span =
+        ExpandTokenSpan(symmetry_context_span,
+                        /*num_tokens_left=*/max_selection_span +
+                            bounds_sensitive_features->num_tokens_before(),
+                        /*num_tokens_right=*/max_selection_span +
+                            bounds_sensitive_features->num_tokens_after());
+  } else {
+    // The extraction span is the symmetry context span expanded to include
+    // context_size tokens on either side.
+    const int context_size =
+        selection_feature_processor_->GetOptions()->context_size();
+    extraction_span = ExpandTokenSpan(symmetry_context_span,
+                                      /*num_tokens_left=*/context_size,
+                                      /*num_tokens_right=*/context_size);
+  }
+  extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
+
+  // Quietly skip (not an error) when too few codepoints in the extraction
+  // window are supported by the model.
+  if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
+          *tokens, extraction_span)) {
+    return true;
+  }
+
+  std::unique_ptr<CachedFeatures> cached_features;
+  if (!selection_feature_processor_->ExtractFeatures(
+          *tokens, extraction_span,
+          /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
+          embedding_executor_.get(),
+          /*embedding_cache=*/nullptr,
+          selection_feature_processor_->EmbeddingSize() +
+              selection_feature_processor_->DenseFeaturesCount(),
+          &cached_features)) {
+    TC3_LOG(ERROR) << "Could not extract features.";
+    return false;
+  }
+
+  // Produce selection model candidates.
+  std::vector<TokenSpan> chunks;
+  if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
+                  interpreter_manager->SelectionInterpreter(), *cached_features,
+                  &chunks)) {
+    TC3_LOG(ERROR) << "Could not chunk.";
+    return false;
+  }
+
+  // Convert token chunks back to codepoint spans, trimming boundary
+  // codepoints and (optionally) unpaired brackets.
+  for (const TokenSpan& chunk : chunks) {
+    AnnotatedSpan candidate;
+    candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
+        context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
+    if (model_->selection_options()->strip_unpaired_brackets()) {
+      candidate.span =
+          StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
+    }
+
+    // Only output non-empty spans.
+    if (candidate.span.first != candidate.span.second) {
+      result->push_back(candidate);
+    }
+  }
+  return true;
+}
+
+// Convenience overload: classifies 'selection_indices' without a
+// pre-tokenized cache, which forces the main implementation to tokenize
+// 'context' from scratch.
+bool Annotator::ModelClassifyText(
+    const std::string& context,
+    const std::vector<Locale>& detected_text_language_tags,
+    CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+    FeatureProcessor::EmbeddingCache* embedding_cache,
+    std::vector<ClassificationResult>* classification_results) const {
+  // No cached tokens are available for this call site.
+  const std::vector<Token> no_cached_tokens;
+  return ModelClassifyText(context, no_cached_tokens,
+                           detected_text_language_tags, selection_indices,
+                           interpreter_manager, embedding_cache,
+                           classification_results);
+}
+
+namespace internal {
+// Returns a copy of the cached tokens that overlap 'selection_indices',
+// padded with up to 'tokens_around_selection_to_copy'.first tokens before and
+// .second tokens after the selection. 'cached_tokens' must be sorted by
+// codepoint position.
+std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
+                                    CodepointSpan selection_indices,
+                                    TokenSpan tokens_around_selection_to_copy) {
+  // First token whose end lies past the selection start, i.e. the first token
+  // touching the selection.
+  const auto first_selection_token = std::upper_bound(
+      cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
+      [](int selection_start, const Token& token) {
+        return selection_start < token.end;
+      });
+  // First token starting at or past the selection end, i.e. one past the last
+  // token touching the selection.
+  const auto last_selection_token = std::lower_bound(
+      cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
+      [](const Token& token, int selection_end) {
+        return token.start < selection_end;
+      });
+
+  // Widen by the requested padding, clamped to the valid index range.
+  const int64 first_token = std::max(
+      static_cast<int64>(0),
+      static_cast<int64>((first_selection_token - cached_tokens.begin()) -
+                         tokens_around_selection_to_copy.first));
+  const int64 last_token = std::min(
+      static_cast<int64>(cached_tokens.size()),
+      static_cast<int64>((last_selection_token - cached_tokens.begin()) +
+                         tokens_around_selection_to_copy.second));
+
+  std::vector<Token> tokens;
+  tokens.reserve(last_token - first_token);
+  for (int i = first_token; i < last_token; ++i) {
+    tokens.push_back(cached_tokens[i]);
+  }
+  return tokens;
+}
+} // namespace internal
+
+// Returns an upper bound on the number of tokens needed on each side of the
+// selection for ClassifyText; used to decide how many cached tokens
+// internal::CopyCachedTokens must keep around the selection.
+TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
+  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
+      bounds_sensitive_features =
+          classification_feature_processor_->GetOptions()
+              ->bounds_sensitive_features();
+  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
+    // The extraction span is the selection span expanded to include a relevant
+    // number of tokens outside of the bounds of the selection.
+    return {bounds_sensitive_features->num_tokens_before(),
+            bounds_sensitive_features->num_tokens_after()};
+  } else {
+    // The extraction span is the clicked token with context_size tokens on
+    // either side.
+    // FIX: read context_size from the *classification* feature processor.
+    // ModelClassifyText expands its extraction span with
+    // classification_feature_processor_'s context_size, so deriving this
+    // bound from the selection model's options could under-copy cached
+    // tokens whenever the two models are configured with different sizes.
+    const int context_size =
+        classification_feature_processor_->GetOptions()->context_size();
+    return {context_size, context_size};
+  }
+}
+
+namespace {
+// Orders 'classification_results' by descending score, best result first.
+void SortClassificationResults(
+    std::vector<ClassificationResult>* classification_results) {
+  const auto by_score_descending = [](const ClassificationResult& lhs,
+                                      const ClassificationResult& rhs) {
+    return lhs.score > rhs.score;
+  };
+  std::sort(classification_results->begin(), classification_results->end(),
+            by_score_descending);
+}
+} // namespace
+
+// Overload for callers that do not need the tokenization back: forwards to
+// the full implementation with a throwaway token vector.
+bool Annotator::ModelClassifyText(
+    const std::string& context, const std::vector<Token>& cached_tokens,
+    const std::vector<Locale>& detected_text_language_tags,
+    CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+    FeatureProcessor::EmbeddingCache* embedding_cache,
+    std::vector<ClassificationResult>* classification_results) const {
+  // Receives the tokenization produced by the callee; discarded on return.
+  std::vector<Token> discarded_tokens;
+  return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
+                           selection_indices, interpreter_manager,
+                           embedding_cache, classification_results,
+                           &discarded_tokens);
+}
+
+// Runs the ML classification model over 'selection_indices' in 'context' and
+// writes the result into 'classification_results'. Reuses 'cached_tokens'
+// when non-empty; otherwise tokenizes from scratch. The tokenization used is
+// returned via 'tokens'. Returns true with no results when classification is
+// disabled or the locale is unsupported; several sanity checks downgrade the
+// result to the "other" collection instead of failing. Returns false only on
+// genuine errors (invalid span, feature extraction or inference failures).
+bool Annotator::ModelClassifyText(
+    const std::string& context, const std::vector<Token>& cached_tokens,
+    const std::vector<Locale>& detected_text_language_tags,
+    CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+    FeatureProcessor::EmbeddingCache* embedding_cache,
+    std::vector<ClassificationResult>* classification_results,
+    std::vector<Token>* tokens) const {
+  if (model_->triggering_options() == nullptr ||
+      !(model_->triggering_options()->enabled_modes() &
+        ModeFlag_CLASSIFICATION)) {
+    return true;
+  }
+
+  if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+                                    ml_model_triggering_locales_,
+                                    /*default_value=*/true)) {
+    return true;
+  }
+
+  if (cached_tokens.empty()) {
+    *tokens = classification_feature_processor_->Tokenize(context);
+  } else {
+    *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
+                                         ClassifyTextUpperBoundNeededTokens());
+  }
+
+  int click_pos;
+  classification_feature_processor_->RetokenizeAndFindClick(
+      context, selection_indices,
+      classification_feature_processor_->GetOptions()
+          ->only_use_line_with_click(),
+      tokens, &click_pos);
+  const TokenSpan selection_token_span =
+      CodepointSpanToTokenSpan(*tokens, selection_indices);
+  const int selection_num_tokens = TokenSpanSize(selection_token_span);
+  // Overly long selections are classified as "other" rather than running the
+  // model on them.
+  if (model_->classification_options()->max_num_tokens() > 0 &&
+      model_->classification_options()->max_num_tokens() <
+          selection_num_tokens) {
+    *classification_results = {{Collections::Other(), 1.0}};
+    return true;
+  }
+
+  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
+      bounds_sensitive_features =
+          classification_feature_processor_->GetOptions()
+              ->bounds_sensitive_features();
+  if (selection_token_span.first == kInvalidIndex ||
+      selection_token_span.second == kInvalidIndex) {
+    TC3_LOG(ERROR) << "Could not determine span.";
+    return false;
+  }
+
+  // Compute the extraction span based on the model type.
+  TokenSpan extraction_span;
+  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
+    // The extraction span is the selection span expanded to include a relevant
+    // number of tokens outside of the bounds of the selection.
+    extraction_span = ExpandTokenSpan(
+        selection_token_span,
+        /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
+        /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
+  } else {
+    if (click_pos == kInvalidIndex) {
+      TC3_LOG(ERROR) << "Couldn't choose a click position.";
+      return false;
+    }
+    // The extraction span is the clicked token with context_size tokens on
+    // either side.
+    const int context_size =
+        classification_feature_processor_->GetOptions()->context_size();
+    extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
+                                      /*num_tokens_left=*/context_size,
+                                      /*num_tokens_right=*/context_size);
+  }
+  extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
+
+  // Too few supported codepoints: classify as "other" instead of guessing.
+  if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
+          *tokens, extraction_span)) {
+    *classification_results = {{Collections::Other(), 1.0}};
+    return true;
+  }
+
+  std::unique_ptr<CachedFeatures> cached_features;
+  if (!classification_feature_processor_->ExtractFeatures(
+          *tokens, extraction_span, selection_indices,
+          embedding_executor_.get(), embedding_cache,
+          classification_feature_processor_->EmbeddingSize() +
+              classification_feature_processor_->DenseFeaturesCount(),
+          &cached_features)) {
+    TC3_LOG(ERROR) << "Could not extract features.";
+    return false;
+  }
+
+  std::vector<float> features;
+  features.reserve(cached_features->OutputFeaturesSize());
+  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
+    cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
+                                                          &features);
+  } else {
+    cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
+  }
+
+  // Run inference and turn the logits into a per-collection distribution.
+  TensorView<float> logits = classification_executor_->ComputeLogits(
+      TensorView<float>(features.data(),
+                        {1, static_cast<int>(features.size())}),
+      interpreter_manager->ClassificationInterpreter());
+  if (!logits.is_valid()) {
+    TC3_LOG(ERROR) << "Couldn't compute logits.";
+    return false;
+  }
+
+  if (logits.dims() != 2 || logits.dim(0) != 1 ||
+      logits.dim(1) != classification_feature_processor_->NumCollections()) {
+    TC3_LOG(ERROR) << "Mismatching output";
+    return false;
+  }
+
+  const std::vector<float> scores =
+      ComputeSoftmax(logits.data(), logits.dim(1));
+
+  if (scores.empty()) {
+    *classification_results = {{Collections::Other(), 1.0}};
+    return true;
+  }
+
+  const int best_score_index =
+      std::max_element(scores.begin(), scores.end()) - scores.begin();
+  const std::string top_collection =
+      classification_feature_processor_->LabelToCollection(best_score_index);
+
+  // Sanity checks.
+  if (top_collection == Collections::Phone()) {
+    const int digit_count = CountDigits(context, selection_indices);
+    if (digit_count <
+            model_->classification_options()->phone_min_num_digits() ||
+        digit_count >
+            model_->classification_options()->phone_max_num_digits()) {
+      *classification_results = {{Collections::Other(), 1.0}};
+      return true;
+    }
+  } else if (top_collection == Collections::Address()) {
+    if (selection_num_tokens <
+        model_->classification_options()->address_min_num_tokens()) {
+      *classification_results = {{Collections::Other(), 1.0}};
+      return true;
+    }
+  } else if (top_collection == Collections::Dictionary()) {
+    if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+                                      dictionary_locales_,
+                                      /*default_value=*/false)) {
+      *classification_results = {{Collections::Other(), 1.0}};
+      return true;
+    }
+  }
+
+  *classification_results = {{top_collection, /*arg_score=*/1.0,
+                              /*arg_priority_score=*/scores[best_score_index]}};
+
+  // For some entities, we might want to clamp the priority score, for better
+  // conflict resolution between entities.
+  if (model_->triggering_options() != nullptr &&
+      model_->triggering_options()->collection_to_priority() != nullptr) {
+    if (auto entry =
+            model_->triggering_options()->collection_to_priority()->LookupByKey(
+                top_collection.c_str())) {
+      (*classification_results)[0].priority_score *= entry->value();
+    }
+  }
+  return true;
+}
+
+// Classifies the selected text by running every classification regex pattern
+// over it. Each pattern that matches (and passes its optional verification)
+// appends one ClassificationResult, including serialized entity data
+// extracted from the match groups. Returns false on regex or entity-data
+// errors; a non-match is not an error.
+bool Annotator::RegexClassifyText(
+    const std::string& context, CodepointSpan selection_indices,
+    std::vector<ClassificationResult>* classification_result) const {
+  const std::string selection_text =
+      UTF8ToUnicodeText(context, /*do_copy=*/false)
+          .UTF8Substring(selection_indices.first, selection_indices.second);
+  const UnicodeText selection_text_unicode(
+      UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
+
+  // Check whether any of the regular expressions match.
+  for (const int pattern_id : classification_regex_patterns_) {
+    const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
+    const std::unique_ptr<UniLib::RegexMatcher> matcher =
+        regex_pattern.pattern->Matcher(selection_text_unicode);
+    int status = UniLib::RegexMatcher::kNoError;
+    bool matches;
+    if (regex_pattern.config->use_approximate_matching()) {
+      matches = matcher->ApproximatelyMatches(&status);
+    } else {
+      matches = matcher->Matches(&status);
+    }
+    if (status != UniLib::RegexMatcher::kNoError) {
+      return false;
+    }
+    if (matches && VerifyRegexMatchCandidate(
+                       context, regex_pattern.config->verification_options(),
+                       selection_text, matcher.get())) {
+      classification_result->push_back(
+          {regex_pattern.config->collection_name()->str(),
+           regex_pattern.config->target_classification_score(),
+           regex_pattern.config->priority_score()});
+      if (!SerializedEntityDataFromRegexMatch(
+              regex_pattern.config, matcher.get(),
+              &classification_result->back().serialized_entity_data)) {
+        TC3_LOG(ERROR) << "Could not get entity data.";
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+namespace {
+// Maps a datetime parse result to a collection name: hour-or-finer
+// granularity means a concrete time ("datetime"); anything coarser is just a
+// calendar date ("date").
+std::string PickCollectionForDatetime(
+    const DatetimeParseResult& datetime_parse_result) {
+  const bool has_time_of_day =
+      datetime_parse_result.granularity == GRANULARITY_HOUR ||
+      datetime_parse_result.granularity == GRANULARITY_MINUTE ||
+      datetime_parse_result.granularity == GRANULARITY_SECOND;
+  if (has_time_of_day) {
+    return Collections::DateTime();
+  }
+  return Collections::Date();
+}
+
+// Packs a DatetimeParseResult into a serialized EntityData flatbuffer,
+// copying the absolute time, granularity, and every datetime component
+// (with its absolute/relative relation). Returns the raw buffer bytes.
+std::string CreateDatetimeSerializedEntityData(
+    const DatetimeParseResult& parse_result) {
+  EntityDataT entity_data;
+  entity_data.datetime.reset(new EntityData_::DatetimeT());
+  entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
+  entity_data.datetime->granularity =
+      static_cast<EntityData_::Datetime_::Granularity>(
+          parse_result.granularity);
+
+  for (const auto& c : parse_result.datetime_components) {
+    EntityData_::Datetime_::DatetimeComponentT datetime_component;
+    datetime_component.absolute_value = c.value;
+    datetime_component.relative_count = c.relative_count;
+    datetime_component.component_type =
+        static_cast<EntityData_::Datetime_::DatetimeComponent_::ComponentType>(
+            c.component_type);
+    // Components default to ABSOLUTE; a set relative qualifier flips the
+    // relation to RELATIVE.
+    datetime_component.relation_type =
+        EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE;
+    if (c.relative_qualifier !=
+        DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
+      datetime_component.relation_type =
+          EntityData_::Datetime_::DatetimeComponent_::RelationType_RELATIVE;
+    }
+    entity_data.datetime->datetime_component.emplace_back(
+        new EntityData_::Datetime_::DatetimeComponentT(datetime_component));
+  }
+  flatbuffers::FlatBufferBuilder builder;
+  FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
+  return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+                     builder.GetSize());
+}
+
+} // namespace
+
+// Classifies the selected text as a date/datetime by running the configured
+// datetime parser(s) (grammar-based and/or rule-based) over it. Only a parse
+// whose span covers the selection exactly is accepted; the first such span
+// contributes one result per DatetimeParseResult. Returns false only when
+// the rule-based parser errors out; no match is not an error.
+bool Annotator::DatetimeClassifyText(
+    const std::string& context, CodepointSpan selection_indices,
+    const ClassificationOptions& options,
+    std::vector<ClassificationResult>* classification_results) const {
+  if (!datetime_parser_ && !cfg_datetime_parser_) {
+    return true;
+  }
+
+  const std::string selection_text =
+      UTF8ToUnicodeText(context, /*do_copy=*/false)
+          .UTF8Substring(selection_indices.first, selection_indices.second);
+
+  std::vector<DatetimeParseResultSpan> datetime_spans;
+
+  if (cfg_datetime_parser_) {
+    if (!(model_->grammar_datetime_model()->enabled_modes() &
+          ModeFlag_CLASSIFICATION)) {
+      return true;
+    }
+    std::vector<Locale> parsed_locales;
+    ParseLocales(options.locales, &parsed_locales);
+    cfg_datetime_parser_->Parse(
+        selection_text,
+        ToDateAnnotationOptions(
+            model_->grammar_datetime_model()->annotation_options(),
+            options.reference_timezone, options.reference_time_ms_utc),
+        parsed_locales, &datetime_spans);
+  }
+
+  if (datetime_parser_) {
+    // anchor_start_end=true requires the match to be anchored at both ends of
+    // the selection text.
+    if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
+                                 options.reference_timezone, options.locales,
+                                 ModeFlag_CLASSIFICATION,
+                                 options.annotation_usecase,
+                                 /*anchor_start_end=*/true, &datetime_spans)) {
+      TC3_LOG(ERROR) << "Error during parsing datetime.";
+      return false;
+    }
+  }
+
+  for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
+    // Only consider the result valid if the selection and extracted datetime
+    // spans exactly match. (Parsed spans are relative to 'selection_text',
+    // hence the offset by selection_indices.first.)
+    if (std::make_pair(datetime_span.span.first + selection_indices.first,
+                       datetime_span.span.second + selection_indices.first) ==
+        selection_indices) {
+      for (const DatetimeParseResult& parse_result : datetime_span.data) {
+        classification_results->emplace_back(
+            PickCollectionForDatetime(parse_result),
+            datetime_span.target_classification_score);
+        classification_results->back().datetime_parse_result = parse_result;
+        classification_results->back().serialized_entity_data =
+            CreateDatetimeSerializedEntityData(parse_result);
+        classification_results->back().priority_score =
+            datetime_span.priority_score;
+      }
+      return true;
+    }
+  }
+  return true;
+}
+
+std::vector<ClassificationResult> Annotator::ClassifyText(
+ const std::string& context, CodepointSpan selection_indices,
+ const ClassificationOptions& options) const {
+ if (!initialized_) {
+ TC3_LOG(ERROR) << "Not initialized";
+ return {};
+ }
+ if (options.annotation_usecase !=
+ AnnotationUsecase_ANNOTATION_USECASE_SMART) {
+ TC3_LOG(WARNING)
+ << "Invoking ClassifyText, which is not supported in RAW mode.";
+ return {};
+ }
+ if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
+ return {};
+ }
+
+ std::vector<Locale> detected_text_language_tags;
+ if (!ParseLocales(options.detected_text_language_tags,
+ &detected_text_language_tags)) {
+ TC3_LOG(WARNING)
+ << "Failed to parse the detected_text_language_tags in options: "
+ << options.detected_text_language_tags;
+ }
+ if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+ model_triggering_locales_,
+ /*default_value=*/true)) {
+ return {};
+ }
+
+ if (!IsValidSpanInput(UTF8ToUnicodeText(context, /*do_copy=*/false),
+ selection_indices)) {
+ TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
+ << std::get<0>(selection_indices) << " "
+ << std::get<1>(selection_indices);
+ return {};
+ }
+
+ // We'll accumulate a list of candidates, and pick the best candidate in the
+ // end.
+ std::vector<AnnotatedSpan> candidates;
+
+ // Try the knowledge engine.
+ // TODO(b/126579108): Propagate error status.
+ ClassificationResult knowledge_result;
+ if (knowledge_engine_ &&
+ knowledge_engine_->ClassifyText(
+ context, selection_indices, options.annotation_usecase,
+ options.location_context, Permissions(), &knowledge_result)) {
+ candidates.push_back({selection_indices, {knowledge_result}});
+ candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
+ }
+
+ AddContactMetadataToKnowledgeClassificationResults(&candidates);
+
+ // Try the contact engine.
+ // TODO(b/126579108): Propagate error status.
+ ClassificationResult contact_result;
+ if (contact_engine_ && contact_engine_->ClassifyText(
+ context, selection_indices, &contact_result)) {
+ candidates.push_back({selection_indices, {contact_result}});
+ }
+
+ // Try the person name engine.
+ ClassificationResult person_name_result;
+ if (person_name_engine_ &&
+ person_name_engine_->ClassifyText(context, selection_indices,
+ &person_name_result)) {
+ candidates.push_back({selection_indices, {person_name_result}});
+ candidates.back().source = AnnotatedSpan::Source::PERSON_NAME;
+ }
+
+ // Try the installed app engine.
+ // TODO(b/126579108): Propagate error status.
+ ClassificationResult installed_app_result;
+ if (installed_app_engine_ &&
+ installed_app_engine_->ClassifyText(context, selection_indices,
+ &installed_app_result)) {
+ candidates.push_back({selection_indices, {installed_app_result}});
+ }
+
+ // Try the regular expression models.
+ std::vector<ClassificationResult> regex_results;
+ if (!RegexClassifyText(context, selection_indices, ®ex_results)) {
+ return {};
+ }
+ for (const ClassificationResult& result : regex_results) {
+ candidates.push_back({selection_indices, {result}});
+ }
+
+ // Try the date model.
+ //
+ // DatetimeClassifyText only returns the first result, which can however have
+ // more interpretations. They are inserted in the candidates as a single
+ // AnnotatedSpan, so that they get treated together by the conflict resolution
+ // algorithm.
+ std::vector<ClassificationResult> datetime_results;
+ if (!DatetimeClassifyText(context, selection_indices, options,
+ &datetime_results)) {
+ return {};
+ }
+ if (!datetime_results.empty()) {
+ candidates.push_back({selection_indices, std::move(datetime_results)});
+ candidates.back().source = AnnotatedSpan::Source::DATETIME;
+ }
+
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+ // Try the number annotator.
+ // TODO(b/126579108): Propagate error status.
+ ClassificationResult number_annotator_result;
+ if (number_annotator_ &&
+ number_annotator_->ClassifyText(context_unicode, selection_indices,
+ options.annotation_usecase,
+ &number_annotator_result)) {
+ candidates.push_back({selection_indices, {number_annotator_result}});
+ }
+
+ // Try the duration annotator.
+ ClassificationResult duration_annotator_result;
+ if (duration_annotator_ &&
+ duration_annotator_->ClassifyText(context_unicode, selection_indices,
+ options.annotation_usecase,
+ &duration_annotator_result)) {
+ candidates.push_back({selection_indices, {duration_annotator_result}});
+ candidates.back().source = AnnotatedSpan::Source::DURATION;
+ }
+
+ // Try the translate annotator.
+ ClassificationResult translate_annotator_result;
+ if (translate_annotator_ &&
+ translate_annotator_->ClassifyText(context_unicode, selection_indices,
+ options.user_familiar_language_tags,
+ &translate_annotator_result)) {
+ candidates.push_back({selection_indices, {translate_annotator_result}});
+ }
+
+ // Try the grammar model.
+ ClassificationResult grammar_annotator_result;
+ if (grammar_annotator_ && grammar_annotator_->ClassifyText(
+ detected_text_language_tags, context_unicode,
+ selection_indices, &grammar_annotator_result)) {
+ candidates.push_back({selection_indices, {grammar_annotator_result}});
+ }
+
+ ClassificationResult experimental_annotator_result;
+ if (experimental_annotator_ &&
+ experimental_annotator_->ClassifyText(context_unicode, selection_indices,
+ &experimental_annotator_result)) {
+ candidates.push_back({selection_indices, {experimental_annotator_result}});
+ }
+
+ // Try the ML model.
+ //
+ // The output of the model is considered as an exclusive 1-of-N choice. That's
+ // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
+ // span for each candidate, like e.g. the regex model.
+ InterpreterManager interpreter_manager(selection_executor_.get(),
+ classification_executor_.get());
+ std::vector<ClassificationResult> model_results;
+ std::vector<Token> tokens;
+ if (!ModelClassifyText(
+ context, /*cached_tokens=*/{}, detected_text_language_tags,
+ selection_indices, &interpreter_manager,
+ /*embedding_cache=*/nullptr, &model_results, &tokens)) {
+ return {};
+ }
+ if (!model_results.empty()) {
+ candidates.push_back({selection_indices, std::move(model_results)});
+ }
+
+ std::vector<int> candidate_indices;
+ if (!ResolveConflicts(candidates, context, tokens,
+ detected_text_language_tags, options.annotation_usecase,
+ &interpreter_manager, &candidate_indices)) {
+ TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
+ return {};
+ }
+
+ std::vector<ClassificationResult> results;
+ for (const int i : candidate_indices) {
+ for (const ClassificationResult& result : candidates[i].classification) {
+ if (!FilteredForClassification(result)) {
+ results.push_back(result);
+ }
+ }
+ }
+
+ // Sort results according to score.
+ std::sort(results.begin(), results.end(),
+ [](const ClassificationResult& a, const ClassificationResult& b) {
+ return a.score > b.score;
+ });
+
+ if (results.empty()) {
+ results = {{Collections::Other(), 1.0}};
+ }
+ return results;
+}
+
// Runs the ML selection+classification model over `context` and appends the
// spans it finds to `result`. Returns true on success — including the no-op
// cases where model annotation is disabled or the detected languages are not
// supported — and false on feature-extraction/chunking/classification errors.
// `tokens` is an out-parameter; after the call it holds the tokens of the
// last processed line (callers reuse them downstream).
bool Annotator::ModelAnnotate(
    const std::string& context,
    const std::vector<Locale>& detected_text_language_tags,
    InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
    std::vector<AnnotatedSpan>* result) const {
  if (model_->triggering_options() == nullptr ||
      !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
    return true;
  }

  if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
                                    ml_model_triggering_locales_,
                                    /*default_value=*/true)) {
    return true;
  }

  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
                                                        /*do_copy=*/false);
  // Either treat the whole context as a single "line", or split it into
  // lines and annotate each independently, per model options.
  std::vector<UnicodeTextRange> lines;
  if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
    lines.push_back({context_unicode.begin(), context_unicode.end()});
  } else {
    lines = selection_feature_processor_->SplitContext(
        context_unicode, selection_feature_processor_->GetOptions()
                             ->use_pipe_character_for_newline());
  }

  // triggering_options() was already null-checked above, so the ternary's
  // fallback branch is defensive only.
  const float min_annotate_confidence =
      (model_->triggering_options() != nullptr
           ? model_->triggering_options()->min_annotate_confidence()
           : 0.f);

  for (const UnicodeTextRange& line : lines) {
    FeatureProcessor::EmbeddingCache embedding_cache;
    const std::string line_str =
        UnicodeText::UTF8Substring(line.first, line.second);

    *tokens = selection_feature_processor_->Tokenize(line_str);
    selection_feature_processor_->RetokenizeAndFindClick(
        line_str, {0, std::distance(line.first, line.second)},
        selection_feature_processor_->GetOptions()->only_use_line_with_click(),
        tokens,
        /*click_pos=*/nullptr);
    const TokenSpan full_line_span = {0, tokens->size()};

    // TODO(zilka): Add support for greater granularity of this check.
    if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
            *tokens, full_line_span)) {
      continue;
    }

    std::unique_ptr<CachedFeatures> cached_features;
    // NOTE(review): the local `embedding_cache` is deliberately NOT passed
    // here; it is only used for the per-chunk classification further below.
    if (!selection_feature_processor_->ExtractFeatures(
            *tokens, full_line_span,
            /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
            embedding_executor_.get(),
            /*embedding_cache=*/nullptr,
            selection_feature_processor_->EmbeddingSize() +
                selection_feature_processor_->DenseFeaturesCount(),
            &cached_features)) {
      TC3_LOG(ERROR) << "Could not extract features.";
      return false;
    }

    std::vector<TokenSpan> local_chunks;
    if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
                    interpreter_manager->SelectionInterpreter(),
                    *cached_features, &local_chunks)) {
      TC3_LOG(ERROR) << "Could not chunk.";
      return false;
    }

    // Codepoint offset of this line within the whole context; chunk spans are
    // line-relative and must be shifted back to context coordinates.
    const int offset = std::distance(context_unicode.begin(), line.first);
    for (const TokenSpan& chunk : local_chunks) {
      const CodepointSpan codepoint_span =
          selection_feature_processor_->StripBoundaryCodepoints(
              line_str, TokenSpanToCodepointSpan(*tokens, chunk));

      // Skip empty spans.
      if (codepoint_span.first != codepoint_span.second) {
        std::vector<ClassificationResult> classification;
        if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
                               codepoint_span, interpreter_manager,
                               &embedding_cache, &classification)) {
          TC3_LOG(ERROR) << "Could not classify text: "
                         << (codepoint_span.first + offset) << " "
                         << (codepoint_span.second + offset);
          return false;
        }

        // Do not include the span if it's classified as "other".
        if (!classification.empty() && !ClassifiedAsOther(classification) &&
            classification[0].score >= min_annotate_confidence) {
          AnnotatedSpan result_span;
          result_span.span = {codepoint_span.first + offset,
                              codepoint_span.second + offset};
          result_span.classification = std::move(classification);
          result->push_back(std::move(result_span));
        }
      }
    }
  }
  return true;
}
+
// Test-only accessor: exposes the selection feature processor (non-owning).
const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
  return selection_feature_processor_.get();
}
+
// Test-only accessor: exposes the classification feature processor
// (non-owning).
const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
    const {
  return classification_feature_processor_.get();
}
+
// Test-only accessor: exposes the datetime parser (non-owning).
const DatetimeParser* Annotator::DatetimeParserForTests() const {
  return datetime_parser_.get();
}
+
+void Annotator::RemoveNotEnabledEntityTypes(
+ const EnabledEntityTypes& is_entity_type_enabled,
+ std::vector<AnnotatedSpan>* annotated_spans) const {
+ for (AnnotatedSpan& annotated_span : *annotated_spans) {
+ std::vector<ClassificationResult>& classifications =
+ annotated_span.classification;
+ classifications.erase(
+ std::remove_if(classifications.begin(), classifications.end(),
+ [&is_entity_type_enabled](
+ const ClassificationResult& classification_result) {
+ return !is_entity_type_enabled(
+ classification_result.collection);
+ }),
+ classifications.end());
+ }
+ annotated_spans->erase(
+ std::remove_if(annotated_spans->begin(), annotated_spans->end(),
+ [](const AnnotatedSpan& annotated_span) {
+ return annotated_span.classification.empty();
+ }),
+ annotated_spans->end());
+}
+
+void Annotator::AddContactMetadataToKnowledgeClassificationResults(
+ std::vector<AnnotatedSpan>* candidates) const {
+ if (candidates == nullptr || contact_engine_ == nullptr) {
+ return;
+ }
+ for (auto& candidate : *candidates) {
+ for (auto& classification_result : candidate.classification) {
+ contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
+ &classification_result);
+ }
+ }
+}
+
+Status Annotator::AnnotateSingleInput(
+ const std::string& context, const AnnotationOptions& options,
+ std::vector<AnnotatedSpan>* candidates) const {
+ if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
+ return Status(StatusCode::UNAVAILABLE, "Model annotation was not enabled.");
+ }
+
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ if (!context_unicode.is_valid()) {
+ return Status(StatusCode::INVALID_ARGUMENT,
+ "Context string isn't valid UTF8.");
+ }
+
+ std::vector<Locale> detected_text_language_tags;
+ if (!ParseLocales(options.detected_text_language_tags,
+ &detected_text_language_tags)) {
+ TC3_LOG(WARNING)
+ << "Failed to parse the detected_text_language_tags in options: "
+ << options.detected_text_language_tags;
+ }
+ if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+ model_triggering_locales_,
+ /*default_value=*/true)) {
+ return Status(
+ StatusCode::UNAVAILABLE,
+ "The detected language tags are not in the supported locales.");
+ }
+
+ InterpreterManager interpreter_manager(selection_executor_.get(),
+ classification_executor_.get());
+
+ // Annotate with the selection model.
+ std::vector<Token> tokens;
+ if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
+ &tokens, candidates)) {
+ return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
+ }
+
+ // Annotate with the regular expression models.
+ if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
+ annotation_regex_patterns_, candidates,
+ options.is_serialized_entity_data_enabled)) {
+ return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
+ }
+
+ // Annotate with the datetime model.
+ const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
+ if ((is_entity_type_enabled(Collections::Date()) ||
+ is_entity_type_enabled(Collections::DateTime())) &&
+ !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
+ options.reference_time_ms_utc, options.reference_timezone,
+ options.locales, ModeFlag_ANNOTATION,
+ options.annotation_usecase,
+ options.is_serialized_entity_data_enabled, candidates)) {
+ return Status(StatusCode::INTERNAL, "Couldn't run DatetimeChunk.");
+ }
+
+ // Annotate with the contact engine.
+ if (contact_engine_ &&
+ !contact_engine_->Chunk(context_unicode, tokens, candidates)) {
+ return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
+ }
+
+ // Annotate with the installed app engine.
+ if (installed_app_engine_ &&
+ !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
+ return Status(StatusCode::INTERNAL,
+ "Couldn't run installed app engine Chunk.");
+ }
+
+ // Annotate with the number annotator.
+ if (number_annotator_ != nullptr &&
+ !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
+ candidates)) {
+ return Status(StatusCode::INTERNAL,
+ "Couldn't run number annotator FindAll.");
+ }
+
+ // Annotate with the duration annotator.
+ if (is_entity_type_enabled(Collections::Duration()) &&
+ duration_annotator_ != nullptr &&
+ !duration_annotator_->FindAll(context_unicode, tokens,
+ options.annotation_usecase, candidates)) {
+ return Status(StatusCode::INTERNAL,
+ "Couldn't run duration annotator FindAll.");
+ }
+
+ // Annotate with the person name engine.
+ if (is_entity_type_enabled(Collections::PersonName()) &&
+ person_name_engine_ &&
+ !person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
+ return Status(StatusCode::INTERNAL,
+ "Couldn't run person name engine Chunk.");
+ }
+
+ // Annotate with the grammar annotators.
+ if (grammar_annotator_ != nullptr &&
+ !grammar_annotator_->Annotate(detected_text_language_tags,
+ context_unicode, candidates)) {
+ return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
+ }
+
+ if (experimental_annotator_ != nullptr &&
+ !experimental_annotator_->Annotate(context_unicode, candidates)) {
+ return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
+ }
+
+ // Sort candidates according to their position in the input, so that the next
+ // code can assume that any connected component of overlapping spans forms a
+ // contiguous block.
+ // Also sort them according to the end position and collection, so that the
+ // deduplication code below can assume that same spans and classifications
+ // form contiguous blocks.
+ std::sort(candidates->begin(), candidates->end(),
+ [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+ if (a.span.first != b.span.first) {
+ return a.span.first < b.span.first;
+ }
+
+ if (a.span.second != b.span.second) {
+ return a.span.second < b.span.second;
+ }
+
+ return a.classification[0].collection <
+ b.classification[0].collection;
+ });
+
+ std::vector<int> candidate_indices;
+ if (!ResolveConflicts(*candidates, context, tokens,
+ detected_text_language_tags, options.annotation_usecase,
+ &interpreter_manager, &candidate_indices)) {
+ return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
+ }
+
+ // Remove candidates that overlap exactly and have the same collection.
+ // This can e.g. happen for phone coming from both ML model and regex.
+ candidate_indices.erase(
+ std::unique(candidate_indices.begin(), candidate_indices.end(),
+ [&candidates](const int a_index, const int b_index) {
+ const AnnotatedSpan& a = (*candidates)[a_index];
+ const AnnotatedSpan& b = (*candidates)[b_index];
+ return a.span == b.span &&
+ a.classification[0].collection ==
+ b.classification[0].collection;
+ }),
+ candidate_indices.end());
+
+ std::vector<AnnotatedSpan> result;
+ result.reserve(candidate_indices.size());
+ for (const int i : candidate_indices) {
+ if ((*candidates)[i].classification.empty() ||
+ ClassifiedAsOther((*candidates)[i].classification) ||
+ FilteredForAnnotation((*candidates)[i])) {
+ continue;
+ }
+ result.push_back(std::move((*candidates)[i]));
+ }
+
+ // We generate all candidates and remove them later (with the exception of
+ // date/time/duration entities) because there are complex interdependencies
+ // between the entity types. E.g., the TLD of an email can be interpreted as a
+ // URL, but most likely a user of the API does not want such annotations if
+ // "url" is enabled and "email" is not.
+ RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
+
+ for (AnnotatedSpan& annotated_span : result) {
+ SortClassificationResults(&annotated_span.classification);
+ }
+ *candidates = result;
+ return Status::OK;
+}
+
// Annotates a list of input fragments and returns one vector of AnnotatedSpan
// per fragment. The knowledge engine (which can annotate all fragments in one
// call) runs first over the whole batch; every other annotator then runs per
// fragment via AnnotateSingleInput. Per-fragment datetime options, when
// present, override the batch-level reference time/timezone.
StatusOr<std::vector<std::vector<AnnotatedSpan>>>
Annotator::AnnotateStructuredInput(
    const std::vector<InputFragment>& string_fragments,
    const AnnotationOptions& options) const {
  std::vector<std::vector<AnnotatedSpan>> annotation_candidates(
      string_fragments.size());

  std::vector<std::string> text_to_annotate;
  text_to_annotate.reserve(string_fragments.size());
  for (const auto& string_fragment : string_fragments) {
    text_to_annotate.push_back(string_fragment.text);
  }

  // KnowledgeEngine is special, because it supports annotation of multiple
  // fragments at once.
  if (knowledge_engine_ &&
      !knowledge_engine_
           ->ChunkMultipleSpans(text_to_annotate, options.annotation_usecase,
                                options.location_context, options.permissions,
                                &annotation_candidates)
           .ok()) {
    return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
  }
  // The annotator engines shouldn't change the number of annotation vectors.
  if (annotation_candidates.size() != text_to_annotate.size()) {
    TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
                   << " texts to annotate but generated a different number of "
                      "lists of annotations:"
                   << annotation_candidates.size();
    return Status(StatusCode::INTERNAL,
                  "Number of annotation candidates differs from "
                  "number of texts to annotate.");
  }

  // Other annotators run on each fragment independently.
  for (int i = 0; i < text_to_annotate.size(); ++i) {
    AnnotationOptions annotation_options = options;
    if (string_fragments[i].datetime_options.has_value()) {
      DatetimeOptions reference_datetime =
          string_fragments[i].datetime_options.value();
      annotation_options.reference_time_ms_utc =
          reference_datetime.reference_time_ms_utc;
      annotation_options.reference_timezone =
          reference_datetime.reference_timezone;
    }

    // Enrich the knowledge results produced above before the single-input
    // pass adds its own candidates.
    AddContactMetadataToKnowledgeClassificationResults(
        &annotation_candidates[i]);

    Status annotation_status = AnnotateSingleInput(
        text_to_annotate[i], annotation_options, &annotation_candidates[i]);
    if (!annotation_status.ok()) {
      return annotation_status;
    }
  }
  return annotation_candidates;
}
+
+std::vector<AnnotatedSpan> Annotator::Annotate(
+ const std::string& context, const AnnotationOptions& options) const {
+ std::vector<InputFragment> string_fragments;
+ string_fragments.push_back({.text = context});
+ StatusOr<std::vector<std::vector<AnnotatedSpan>>> annotations =
+ AnnotateStructuredInput(string_fragments, options);
+ if (!annotations.ok()) {
+ TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
+ << annotations.status().error_message();
+ return {};
+ }
+ return annotations.ValueOrDie()[0];
+}
+
// Computes the selection span for a regex match according to the pattern
// config: when the pattern declares no capturing groups, group 1 defines the
// selection; otherwise the result is the union of all groups marked
// extend_selection(). Returns {kInvalidIndex, kInvalidIndex} on matcher
// errors or when no extending group matched.
CodepointSpan Annotator::ComputeSelectionBoundaries(
    const UniLib::RegexMatcher* match,
    const RegexModel_::Pattern* config) const {
  if (config->capturing_group() == nullptr) {
    // Use first capturing group to specify the selection.
    int status = UniLib::RegexMatcher::kNoError;
    // `status` is shared by both calls; either failure makes it non-kNoError.
    const CodepointSpan result = {match->Start(1, &status),
                                  match->End(1, &status)};
    if (status != UniLib::RegexMatcher::kNoError) {
      return {kInvalidIndex, kInvalidIndex};
    }
    return result;
  }

  CodepointSpan result = {kInvalidIndex, kInvalidIndex};
  const int num_groups = config->capturing_group()->size();
  for (int i = 0; i < num_groups; i++) {
    if (!config->capturing_group()->Get(i)->extend_selection()) {
      continue;
    }

    int status = UniLib::RegexMatcher::kNoError;
    // Check match and adjust bounds.
    const int group_start = match->Start(i, &status);
    const int group_end = match->End(i, &status);
    if (status != UniLib::RegexMatcher::kNoError) {
      return {kInvalidIndex, kInvalidIndex};
    }
    if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
      // Group did not participate in the match; skip it.
      continue;
    }
    if (result.first == kInvalidIndex) {
      result = {group_start, group_end};
    } else {
      // Grow the union of all extending groups.
      result.first = std::min(result.first, group_start);
      result.second = std::max(result.second, group_end);
    }
  }
  return result;
}
+
+bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
+ if (pattern->serialized_entity_data() != nullptr ||
+ pattern->entity_data() != nullptr) {
+ return true;
+ }
+ if (pattern->capturing_group() != nullptr) {
+ for (const CapturingGroup* group : *pattern->capturing_group()) {
+ if (group->entity_field_path() != nullptr) {
+ return true;
+ }
+ if (group->serialized_entity_data() != nullptr ||
+ group->entity_data() != nullptr) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+bool Annotator::SerializedEntityDataFromRegexMatch(
+ const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
+ std::string* serialized_entity_data) const {
+ if (!HasEntityData(pattern)) {
+ serialized_entity_data->clear();
+ return true;
+ }
+ TC3_CHECK(entity_data_builder_ != nullptr);
+
+ std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ entity_data_builder_->NewRoot();
+
+ TC3_CHECK(entity_data != nullptr);
+
+ // Set fixed entity data.
+ if (pattern->serialized_entity_data() != nullptr) {
+ entity_data->MergeFromSerializedFlatbuffer(
+ StringPiece(pattern->serialized_entity_data()->c_str(),
+ pattern->serialized_entity_data()->size()));
+ }
+ if (pattern->entity_data() != nullptr) {
+ entity_data->MergeFrom(
+ reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
+ }
+
+ // Add entity data from rule capturing groups.
+ if (pattern->capturing_group() != nullptr) {
+ const int num_groups = pattern->capturing_group()->size();
+ for (int i = 0; i < num_groups; i++) {
+ const CapturingGroup* group = pattern->capturing_group()->Get(i);
+
+ // Check whether the group matched.
+ Optional<std::string> group_match_text =
+ GetCapturingGroupText(matcher, /*group_id=*/i);
+ if (!group_match_text.has_value()) {
+ continue;
+ }
+
+ // Set fixed entity data from capturing group match.
+ if (group->serialized_entity_data() != nullptr) {
+ entity_data->MergeFromSerializedFlatbuffer(
+ StringPiece(group->serialized_entity_data()->c_str(),
+ group->serialized_entity_data()->size()));
+ }
+ if (group->entity_data() != nullptr) {
+ entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
+ pattern->entity_data()));
+ }
+
+ // Set entity field from capturing group text.
+ if (group->entity_field_path() != nullptr) {
+ UnicodeText normalized_group_match_text =
+ UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
+
+ // Apply normalization if specified.
+ if (group->normalization_options() != nullptr) {
+ normalized_group_match_text =
+ NormalizeText(*unilib_, group->normalization_options(),
+ normalized_group_match_text);
+ }
+
+ if (!entity_data->ParseAndSet(
+ group->entity_field_path(),
+ normalized_group_match_text.ToUTF8String())) {
+ TC3_LOG(ERROR)
+ << "Could not set entity data from rule capturing group.";
+ return false;
+ }
+ }
+ }
+ }
+
+ *serialized_entity_data = entity_data->Serialize();
+ return true;
+}
+
+UnicodeText RemoveMoneySeparators(
+ const std::unordered_set<char32>& decimal_separators,
+ const UnicodeText& amount,
+ UnicodeText::const_iterator it_decimal_separator) {
+ UnicodeText whole_amount;
+ for (auto it = amount.begin();
+ it != amount.end() && it != it_decimal_separator; ++it) {
+ if (std::find(decimal_separators.begin(), decimal_separators.end(),
+ static_cast<char32>(*it)) == decimal_separators.end()) {
+ whole_amount.push_back(*it);
+ }
+ }
+ return whole_amount;
+}
+
+bool Annotator::ParseAndFillInMoneyAmount(
+ std::string* serialized_entity_data) const {
+ std::unique_ptr<EntityDataT> data =
+ LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
+ *serialized_entity_data);
+ if (data == nullptr) {
+ if (model_->version() >= 706) {
+ // This way of parsing money entity data is enabled for models newer than
+ // v706, consequently logging errors only for them (b/156634162).
+ TC3_LOG(ERROR)
+ << "Data field is null when trying to parse Money Entity Data";
+ }
+ return false;
+ }
+ if (data->money->unnormalized_amount.empty()) {
+ if (model_->version() >= 706) {
+ // This way of parsing money entity data is enabled for models newer than
+ // v706, consequently logging errors only for them (b/156634162).
+ TC3_LOG(ERROR)
+ << "Data unnormalized_amount is empty when trying to parse "
+ "Money Entity Data";
+ }
+ return false;
+ }
+
+ UnicodeText amount =
+ UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
+ int separator_back_index = 0;
+ auto it_decimal_separator = --amount.end();
+ for (; it_decimal_separator != amount.begin();
+ --it_decimal_separator, ++separator_back_index) {
+ if (std::find(money_separators_.begin(), money_separators_.end(),
+ static_cast<char32>(*it_decimal_separator)) !=
+ money_separators_.end()) {
+ break;
+ }
+ }
+
+ // If there are 3 digits after the last separator, we consider that a
+ // thousands separator => the number is an int (e.g. 1.234 is considered int).
+ // If there is no separator in number, also that number is an int.
+ if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
+ it_decimal_separator = amount.end();
+ }
+
+ if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
+ it_decimal_separator),
+ &data->money->amount_whole_part)) {
+ TC3_LOG(ERROR) << "Could not parse the money whole part as int32 from the "
+ "amount: "
+ << data->money->unnormalized_amount;
+ return false;
+ }
+ if (it_decimal_separator == amount.end()) {
+ data->money->amount_decimal_part = 0;
+ } else {
+ const int amount_codepoints_size = amount.size_codepoints();
+ if (!unilib_->ParseInt32(
+ UnicodeText::Substring(
+ amount, amount_codepoints_size - separator_back_index,
+ amount_codepoints_size, /*do_copy=*/false),
+ &data->money->amount_decimal_part)) {
+ TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
+ "the amount: "
+ << data->money->unnormalized_amount;
+ return false;
+ }
+ }
+
+ *serialized_entity_data =
+ PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
+ return true;
+}
+
+bool Annotator::RegexChunk(const UnicodeText& context_unicode,
+ const std::vector<int>& rules,
+ std::vector<AnnotatedSpan>* result,
+ bool is_serialized_entity_data_enabled) const {
+ for (int pattern_id : rules) {
+ const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
+ const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
+ if (!matcher) {
+ TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
+ << pattern_id;
+ return false;
+ }
+
+ int status = UniLib::RegexMatcher::kNoError;
+ while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+ if (regex_pattern.config->verification_options()) {
+ if (!VerifyRegexMatchCandidate(
+ context_unicode.ToUTF8String(),
+ regex_pattern.config->verification_options(),
+ matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
+ continue;
+ }
+ }
+
+ std::string serialized_entity_data;
+ if (is_serialized_entity_data_enabled) {
+ if (!SerializedEntityDataFromRegexMatch(
+ regex_pattern.config, matcher.get(), &serialized_entity_data)) {
+ TC3_LOG(ERROR) << "Could not get entity data.";
+ return false;
+ }
+
+ // Further parsing unnormalized_amount for money into amount_whole_part
+ // and amount_decimal_part. Can't do this with regexes because we cannot
+ // have empty groups (amount_decimal_part might be an empty group).
+ if (regex_pattern.config->collection_name()->str() ==
+ Collections::Money()) {
+ if (!ParseAndFillInMoneyAmount(&serialized_entity_data)) {
+ if (model_->version() >= 706) {
+ // This way of parsing money entity data is enabled for models
+ // newer than v706 => logging errors only for them (b/156634162).
+ TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
+ }
+ }
+ }
+ }
+
+ result->emplace_back();
+
+ // Selection/annotation regular expressions need to specify a capturing
+ // group specifying the selection.
+ result->back().span =
+ ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
+
+ result->back().classification = {
+ {regex_pattern.config->collection_name()->str(),
+ regex_pattern.config->target_classification_score(),
+ regex_pattern.config->priority_score()}};
+
+ result->back().classification[0].serialized_entity_data =
+ serialized_entity_data;
+ }
+ }
+ return true;
+}
+
// Produces the set of non-overlapping token spans ("chunks") the selection
// model proposes inside `span_of_interest`: scores candidate spans with
// either the bounds-sensitive or the click-context scorer, then greedily
// keeps the highest-scoring spans that don't reuse a token. The resulting
// chunks are returned sorted by position.
bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
                           tflite::Interpreter* selection_interpreter,
                           const CachedFeatures& cached_features,
                           std::vector<TokenSpan>* chunks) const {
  const int max_selection_span =
      selection_feature_processor_->GetOptions()->max_selection_span();
  // The inference span is the span of interest expanded to include
  // max_selection_span tokens on either side, which is how far a selection can
  // stretch from the click.
  const TokenSpan inference_span = IntersectTokenSpans(
      ExpandTokenSpan(span_of_interest,
                      /*num_tokens_left=*/max_selection_span,
                      /*num_tokens_right=*/max_selection_span),
      {0, num_tokens});

  std::vector<ScoredChunk> scored_chunks;
  if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
      selection_feature_processor_->GetOptions()
          ->bounds_sensitive_features()
          ->enabled()) {
    if (!ModelBoundsSensitiveScoreChunks(
            num_tokens, span_of_interest, inference_span, cached_features,
            selection_interpreter, &scored_chunks)) {
      return false;
    }
  } else {
    if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
                                      cached_features, selection_interpreter,
                                      &scored_chunks)) {
      return false;
    }
  }
  // Ascending sort over reverse iterators == descending by score in forward
  // order, so the greedy pass below sees the best chunks first.
  std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
            [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
              return lhs.score < rhs.score;
            });

  // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
  // them greedily as long as they do not overlap with any previously picked
  // chunks.
  std::vector<bool> token_used(TokenSpanSize(inference_span));
  chunks->clear();
  for (const ScoredChunk& scored_chunk : scored_chunks) {
    bool feasible = true;
    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      if (token_used[i - inference_span.first]) {
        feasible = false;
        break;
      }
    }

    if (!feasible) {
      continue;
    }

    // Claim this chunk's tokens so later (lower-scoring) chunks can't
    // overlap it.
    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      token_used[i - inference_span.first] = true;
    }

    chunks->push_back(scored_chunk.token_span);
  }

  // Return the picked chunks in document order, not score order.
  std::sort(chunks->begin(), chunks->end());

  return true;
}
+
+namespace {
+// Updates the value at the given key in the map to maximum of the current value
+// and the given value, or simply inserts the value if the key is not yet there.
+template <typename Map>
+void UpdateMax(Map* map, typename Map::key_type key,
+               typename Map::mapped_type value) {
+  const auto it = map->find(key);
+  if (it != map->end()) {
+    // Key already present: keep the larger of the stored and the new value.
+    it->second = std::max(it->second, value);
+  } else {
+    (*map)[key] = value;
+  }
+}
+}  // namespace
+
+// Scores all candidate chunks for a click-context selection model. For each
+// possible click position in span_of_interest it runs batched TFLite
+// inference, maps every selection label to a candidate token span, and keeps
+// the maximum score seen per span.
+bool Annotator::ModelClickContextScoreChunks(
+    int num_tokens, const TokenSpan& span_of_interest,
+    const CachedFeatures& cached_features,
+    tflite::Interpreter* selection_interpreter,
+    std::vector<ScoredChunk>* scored_chunks) const {
+  const int max_batch_size = model_->selection_options()->batch_size();
+
+  std::vector<float> all_features;
+  // Maps a candidate span to the best score seen for it across all clicks.
+  std::map<TokenSpan, float> chunk_scores;
+  for (int batch_start = span_of_interest.first;
+       batch_start < span_of_interest.second; batch_start += max_batch_size) {
+    const int batch_end =
+        std::min(batch_start + max_batch_size, span_of_interest.second);
+
+    // Prepare features for the whole batch.
+    all_features.clear();
+    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
+    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
+      cached_features.AppendClickContextFeaturesForClick(click_pos,
+                                                         &all_features);
+    }
+
+    // Run batched inference.
+    const int batch_size = batch_end - batch_start;
+    const int features_size = cached_features.OutputFeaturesSize();
+    TensorView<float> logits = selection_executor_->ComputeLogits(
+        TensorView<float>(all_features.data(), {batch_size, features_size}),
+        selection_interpreter);
+    if (!logits.is_valid()) {
+      TC3_LOG(ERROR) << "Couldn't compute logits.";
+      return false;
+    }
+    // Expect one row per click position and one column per selection label.
+    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
+        logits.dim(1) !=
+            selection_feature_processor_->GetSelectionLabelCount()) {
+      TC3_LOG(ERROR) << "Mismatching output.";
+      return false;
+    }
+
+    // Save results.
+    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
+      // Softmax over the logits row belonging to this click position.
+      const std::vector<float> scores = ComputeSoftmax(
+          logits.data() + logits.dim(1) * (click_pos - batch_start),
+          logits.dim(1));
+      for (int j = 0;
+           j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
+        TokenSpan relative_token_span;
+        if (!selection_feature_processor_->LabelToTokenSpan(
+                j, &relative_token_span)) {
+          TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
+          return false;
+        }
+        const TokenSpan candidate_span = ExpandTokenSpan(
+            SingleTokenSpan(click_pos), relative_token_span.first,
+            relative_token_span.second);
+        // Drop spans that would reach outside of the available tokens.
+        if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
+          UpdateMax(&chunk_scores, candidate_span, scores[j]);
+        }
+      }
+    }
+  }
+
+  // Flatten the deduplicated span->score map into the output vector.
+  scored_chunks->clear();
+  scored_chunks->reserve(chunk_scores.size());
+  for (const auto& entry : chunk_scores) {
+    scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
+  }
+
+  return true;
+}
+
+// Scores chunk candidates for a bounds-sensitive selection model. Candidates
+// are all spans contained in the inference span that overlap the span of
+// interest and fit the maximum chunk length; they are scored in batches.
+bool Annotator::ModelBoundsSensitiveScoreChunks(
+    int num_tokens, const TokenSpan& span_of_interest,
+    const TokenSpan& inference_span, const CachedFeatures& cached_features,
+    tflite::Interpreter* selection_interpreter,
+    std::vector<ScoredChunk>* scored_chunks) const {
+  const int max_selection_span =
+      selection_feature_processor_->GetOptions()->max_selection_span();
+  // NOTE(review): presumably the reduced output space represents fewer
+  // relative spans, hence the shorter maximum chunk length — confirm against
+  // the feature processor.
+  const int max_chunk_length = selection_feature_processor_->GetOptions()
+                                       ->selection_reduced_output_space()
+                                   ? max_selection_span + 1
+                                   : 2 * max_selection_span + 1;
+  const bool score_single_token_spans_as_zero =
+      selection_feature_processor_->GetOptions()
+          ->bounds_sensitive_features()
+          ->score_single_token_spans_as_zero();
+
+  scored_chunks->clear();
+  if (score_single_token_spans_as_zero) {
+    scored_chunks->reserve(TokenSpanSize(span_of_interest));
+  }
+
+  // Prepare all chunk candidates into one batch:
+  // - Are contained in the inference span
+  // - Have a non-empty intersection with the span of interest
+  // - Are at least one token long
+  // - Are not longer than the maximum chunk length
+  std::vector<TokenSpan> candidate_spans;
+  for (int start = inference_span.first; start < span_of_interest.second;
+       ++start) {
+    const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
+    for (int end = leftmost_end_index;
+         end <= inference_span.second && end - start <= max_chunk_length;
+         ++end) {
+      const TokenSpan candidate_span = {start, end};
+      if (score_single_token_spans_as_zero &&
+          TokenSpanSize(candidate_span) == 1) {
+        // Do not include the single token span in the batch, add a zero score
+        // for it directly to the output.
+        scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
+      } else {
+        candidate_spans.push_back(candidate_span);
+      }
+    }
+  }
+
+  const int max_batch_size = model_->selection_options()->batch_size();
+
+  std::vector<float> all_features;
+  scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
+  for (int batch_start = 0; batch_start < candidate_spans.size();
+       batch_start += max_batch_size) {
+    const int batch_end = std::min(batch_start + max_batch_size,
+                                   static_cast<int>(candidate_spans.size()));
+
+    // Prepare features for the whole batch.
+    all_features.clear();
+    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
+    for (int i = batch_start; i < batch_end; ++i) {
+      cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
+                                                           &all_features);
+    }
+
+    // Run batched inference.
+    const int batch_size = batch_end - batch_start;
+    const int features_size = cached_features.OutputFeaturesSize();
+    TensorView<float> logits = selection_executor_->ComputeLogits(
+        TensorView<float>(all_features.data(), {batch_size, features_size}),
+        selection_interpreter);
+    if (!logits.is_valid()) {
+      TC3_LOG(ERROR) << "Couldn't compute logits.";
+      return false;
+    }
+    // Expect exactly one score per candidate span.
+    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
+        logits.dim(1) != 1) {
+      TC3_LOG(ERROR) << "Mismatching output.";
+      return false;
+    }
+
+    // Save results.
+    for (int i = batch_start; i < batch_end; ++i) {
+      scored_chunks->push_back(
+          ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
+    }
+  }
+
+  return true;
+}
+
+// Chunks the input using the datetime parsers and converts every datetime
+// parse result into an annotated span with source DATETIME.
+bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
+                              int64 reference_time_ms_utc,
+                              const std::string& reference_timezone,
+                              const std::string& locales, ModeFlag mode,
+                              AnnotationUsecase annotation_usecase,
+                              bool is_serialized_entity_data_enabled,
+                              std::vector<AnnotatedSpan>* result) const {
+  std::vector<DatetimeParseResultSpan> datetime_spans;
+  // Grammar-based datetime annotator, when the model provides one.
+  if (cfg_datetime_parser_) {
+    if (!(model_->grammar_datetime_model()->enabled_modes() & mode)) {
+      // Grammar datetime model not enabled for this mode; not an error.
+      return true;
+    }
+    std::vector<Locale> parsed_locales;
+    ParseLocales(locales, &parsed_locales);
+    cfg_datetime_parser_->Parse(
+        context_unicode.ToUTF8String(),
+        ToDateAnnotationOptions(
+            model_->grammar_datetime_model()->annotation_options(),
+            reference_timezone, reference_time_ms_utc),
+        parsed_locales, &datetime_spans);
+  }
+
+  // Regular datetime parser; appends to the same list of spans.
+  if (datetime_parser_) {
+    if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
+                                 reference_timezone, locales, mode,
+                                 annotation_usecase,
+                                 /*anchor_start_end=*/false, &datetime_spans)) {
+      return false;
+    }
+  }
+
+  // Convert each parse result into a classification on an annotated span.
+  for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
+    AnnotatedSpan annotated_span;
+    annotated_span.span = datetime_span.span;
+    for (const DatetimeParseResult& parse_result : datetime_span.data) {
+      annotated_span.classification.emplace_back(
+          PickCollectionForDatetime(parse_result),
+          datetime_span.target_classification_score,
+          datetime_span.priority_score);
+      annotated_span.classification.back().datetime_parse_result = parse_result;
+      if (is_serialized_entity_data_enabled) {
+        annotated_span.classification.back().serialized_entity_data =
+            CreateDatetimeSerializedEntityData(parse_result);
+      }
+    }
+    annotated_span.source = AnnotatedSpan::Source::DATETIME;
+    result->push_back(std::move(annotated_span));
+  }
+  return true;
+}
+
+// Read-only accessors to the backing flatbuffer model and the reflection
+// schema used for serialized entity data.
+const Model* Annotator::model() const { return model_; }
+const reflection::Schema* Annotator::entity_data_schema() const {
+  return entity_data_schema_;
+}
+
+// Interprets `buffer` as a Model flatbuffer and returns it for reading;
+// returns nullptr for a null buffer. Verification is delegated to
+// LoadAndVerifyModel.
+const Model* ViewModel(const void* buffer, int size) {
+  if (!buffer) {
+    return nullptr;
+  }
+
+  return LoadAndVerifyModel(buffer, size);
+}
+
+// Delegates the lookup to the knowledge engine. Returns false when the
+// knowledge engine was never initialized or when the lookup itself fails.
+bool Annotator::LookUpKnowledgeEntity(
+    const std::string& id, std::string* serialized_knowledge_result) const {
+  return knowledge_engine_ &&
+         knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h
new file mode 100644
index 0000000..ebd762c
--- /dev/null
+++ b/native/annotator/annotator.h
@@ -0,0 +1,563 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Inference code for the text classification model.
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
+
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/contact/contact-engine.h"
+#include "annotator/datetime/parser.h"
+#include "annotator/duration/duration.h"
+#include "annotator/experimental/experimental.h"
+#include "annotator/feature-processor.h"
+#include "annotator/grammar/dates/cfg-datetime-annotator.h"
+#include "annotator/grammar/grammar-annotator.h"
+#include "annotator/installed_app/installed-app-engine.h"
+#include "annotator/knowledge/knowledge-engine.h"
+#include "annotator/model-executor.h"
+#include "annotator/model_generated.h"
+#include "annotator/number/number.h"
+#include "annotator/person_name/person-name-engine.h"
+#include "annotator/strip-unpaired-brackets.h"
+#include "annotator/translate/translate.h"
+#include "annotator/types.h"
+#include "annotator/zlib-utils.h"
+#include "utils/base/status.h"
+#include "utils/base/statusor.h"
+#include "utils/flatbuffers.h"
+#include "utils/i18n/locale.h"
+#include "utils/memory/mmap.h"
+#include "utils/utf8/unilib.h"
+#include "utils/zlib/zlib.h"
+#include "lang_id/lang-id.h"
+
+namespace libtextclassifier3 {
+
+// Holds TFLite interpreters for selection and classification models.
+// NOTE: This class is not thread-safe, thus should NOT be re-used across
+// threads.
+class InterpreterManager {
+ public:
+  // The constructor can be called with nullptr for any of the executors, and is
+  // a defined behavior, as long as the corresponding *Interpreter() method is
+  // not called when the executor is null.
+  InterpreterManager(const ModelExecutor* selection_executor,
+                     const ModelExecutor* classification_executor)
+      : selection_executor_(selection_executor),
+        classification_executor_(classification_executor) {}
+
+  // Gets or creates and caches an interpreter for the selection model.
+  tflite::Interpreter* SelectionInterpreter();
+
+  // Gets or creates and caches an interpreter for the classification model.
+  tflite::Interpreter* ClassificationInterpreter();
+
+ private:
+  // Non-owning; raw pointers passed in via the constructor.
+  const ModelExecutor* selection_executor_;
+  const ModelExecutor* classification_executor_;
+
+  // Created on first use by the accessors above and cached afterwards.
+  std::unique_ptr<tflite::Interpreter> selection_interpreter_;
+  std::unique_ptr<tflite::Interpreter> classification_interpreter_;
+};
+
+// Stores entity types enabled for annotation, and provides operator() for
+// checking whether a given entity type is enabled.
+class EnabledEntityTypes {
+ public:
+  explicit EnabledEntityTypes(
+      const std::unordered_set<std::string>& entity_types)
+      : entity_types_(entity_types) {}
+
+  // Returns true if `entity_type` is enabled. An empty set means no filtering
+  // was requested, so every entity type counts as enabled.
+  bool operator()(const std::string& entity_type) const {
+    return entity_types_.empty() ||
+           entity_types_.find(entity_type) != entity_types_.cend();
+  }
+
+ private:
+  // Non-owning reference; the set must outlive this object.
+  const std::unordered_set<std::string>& entity_types_;
+};
+
+// A text processing model that provides text classification, annotation,
+// selection suggestion for various types.
+// NOTE: This class is not thread-safe.
+class Annotator {
+ public:
+  // Factory functions. Each pair of overloads either borrows the
+  // UniLib/CalendarLib instances (raw pointers, may be null) or takes
+  // ownership of them (unique_ptr overloads).
+  static std::unique_ptr<Annotator> FromUnownedBuffer(
+      const char* buffer, int size, const UniLib* unilib = nullptr,
+      const CalendarLib* calendarlib = nullptr);
+  // Takes ownership of the mmap.
+  static std::unique_ptr<Annotator> FromScopedMmap(
+      std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr,
+      const CalendarLib* calendarlib = nullptr);
+  static std::unique_ptr<Annotator> FromScopedMmap(
+      std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
+      std::unique_ptr<CalendarLib> calendarlib);
+  static std::unique_ptr<Annotator> FromFileDescriptor(
+      int fd, int offset, int size, const UniLib* unilib = nullptr,
+      const CalendarLib* calendarlib = nullptr);
+  static std::unique_ptr<Annotator> FromFileDescriptor(
+      int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
+      std::unique_ptr<CalendarLib> calendarlib);
+  static std::unique_ptr<Annotator> FromFileDescriptor(
+      int fd, const UniLib* unilib = nullptr,
+      const CalendarLib* calendarlib = nullptr);
+  static std::unique_ptr<Annotator> FromFileDescriptor(
+      int fd, std::unique_ptr<UniLib> unilib,
+      std::unique_ptr<CalendarLib> calendarlib);
+  static std::unique_ptr<Annotator> FromPath(
+      const std::string& path, const UniLib* unilib = nullptr,
+      const CalendarLib* calendarlib = nullptr);
+  static std::unique_ptr<Annotator> FromPath(
+      const std::string& path, std::unique_ptr<UniLib> unilib,
+      std::unique_ptr<CalendarLib> calendarlib);
+
+  // Returns true if the model is ready for use.
+  bool IsInitialized() { return initialized_; }
+
+  // Initializes the knowledge engine with the given config.
+  bool InitializeKnowledgeEngine(const std::string& serialized_config);
+
+  // Initializes the contact engine with the given config.
+  bool InitializeContactEngine(const std::string& serialized_config);
+
+  // Initializes the installed app engine with the given config.
+  bool InitializeInstalledAppEngine(const std::string& serialized_config);
+
+  // Initializes the person name engine with the given person name model in the
+  // provided buffer. The buffer needs to outlive the annotator.
+  bool InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
+                                                   int size);
+
+  // Initializes the person name engine with the given person name model from
+  // the provided mmap.
+  bool InitializePersonNameEngineFromScopedMmap(const ScopedMmap& mmap);
+
+  // Initializes the person name engine with the given person name model in the
+  // provided file path.
+  bool InitializePersonNameEngineFromPath(const std::string& path);
+
+  // Initializes the person name engine with the given person name model in the
+  // provided file descriptor.
+  bool InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
+                                                    int size);
+
+  // Initializes the experimental annotators if available.
+  // Returns true if there is an implementation of experimental annotators
+  // linked in.
+  bool InitializeExperimentalAnnotators();
+
+  // Sets up the lang-id instance that should be used.
+  void SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id);
+
+  // Runs inference for given a context and current selection (i.e. index
+  // of the first and one past last selected characters (utf8 codepoint
+  // offsets)). Returns the indices (utf8 codepoint offsets) of the selection
+  // beginning character and one past selection end character.
+  // Returns the original click_indices if an error occurs.
+  // NOTE: The selection indices are passed in and returned in terms of
+  // UTF8 codepoints (not bytes).
+  // Requires that the model is a smart selection model.
+  CodepointSpan SuggestSelection(
+      const std::string& context, CodepointSpan click_indices,
+      const SelectionOptions& options = SelectionOptions()) const;
+
+  // Classifies the selected text given the context string.
+  // Returns an empty result if an error occurs.
+  std::vector<ClassificationResult> ClassifyText(
+      const std::string& context, CodepointSpan selection_indices,
+      const ClassificationOptions& options = ClassificationOptions()) const;
+
+  // Annotates the given structed input request. Models which handle the full
+  // context request will receive all the metadata they require. While models
+  // that don't use the extra context are called using only a string.
+  // For each fragment the annotations are sorted by their position in
+  // the fragment and exclude spans classified as 'other'.
+  //
+  // The number of vectors of annotated spans will match the number
+  // of input fragments. The order of annotation span vectors will match the
+  // order of input fragments. If annotation is not possible for any of the
+  // annotators, no annotation is returned.
+  StatusOr<std::vector<std::vector<AnnotatedSpan>>> AnnotateStructuredInput(
+      const std::vector<InputFragment>& string_fragments,
+      const AnnotationOptions& options = AnnotationOptions()) const;
+
+  // Annotates given input text. The annotations are sorted by their position
+  // in the context string and exclude spans classified as 'other'.
+  std::vector<AnnotatedSpan> Annotate(
+      const std::string& context,
+      const AnnotationOptions& options = AnnotationOptions()) const;
+
+  // Looks up a knowledge entity by its id. If successful, populates the
+  // serialized knowledge result and returns true.
+  bool LookUpKnowledgeEntity(const std::string& id,
+                             std::string* serialized_knowledge_result) const;
+
+  const Model* model() const;
+  const reflection::Schema* entity_data_schema() const;
+
+  // Exposes the feature processor for tests and evaluations.
+  const FeatureProcessor* SelectionFeatureProcessorForTests() const;
+  const FeatureProcessor* ClassificationFeatureProcessorForTests() const;
+
+  // Exposes the date time parser for tests and evaluations.
+  const DatetimeParser* DatetimeParserForTests() const;
+
+  // Names of commonly used entity collections.
+  static const std::string& kPhoneCollection;
+  static const std::string& kAddressCollection;
+  static const std::string& kDateCollection;
+  static const std::string& kUrlCollection;
+  static const std::string& kEmailCollection;
+
+ protected:
+  // A candidate token span together with its model score.
+  struct ScoredChunk {
+    TokenSpan token_span;
+    float score;
+  };
+
+  // Constructs and initializes text classifier from given model.
+  // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'.
+  Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
+            const UniLib* unilib, const CalendarLib* calendarlib);
+  Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
+            std::unique_ptr<UniLib> unilib,
+            std::unique_ptr<CalendarLib> calendarlib);
+
+  // Constructs, validates and initializes text classifier from given model.
+  // Does not own the buffer that backs 'model'.
+  Annotator(const Model* model, const UniLib* unilib,
+            const CalendarLib* calendarlib);
+
+  // Checks that model contains all required fields, and initializes internal
+  // datastructures.
+  void ValidateAndInitialize();
+
+  // Initializes regular expressions for the regex model.
+  bool InitializeRegexModel(ZlibDecompressor* decompressor);
+
+  // Resolves conflicts in the list of candidates by removing some overlapping
+  // ones. Returns indices of the surviving ones.
+  // NOTE: Assumes that the candidates are sorted according to their position in
+  // the span.
+  bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
+                        const std::string& context,
+                        const std::vector<Token>& cached_tokens,
+                        const std::vector<Locale>& detected_text_language_tags,
+                        AnnotationUsecase annotation_usecase,
+                        InterpreterManager* interpreter_manager,
+                        std::vector<int>* result) const;
+
+  // Resolves one conflict between candidates on indices 'start_index'
+  // (inclusive) and 'end_index' (exclusive). Assigns the winning candidate
+  // indices to 'chosen_indices'. Returns false if a problem arises.
+  bool ResolveConflict(const std::string& context,
+                       const std::vector<Token>& cached_tokens,
+                       const std::vector<AnnotatedSpan>& candidates,
+                       const std::vector<Locale>& detected_text_language_tags,
+                       int start_index, int end_index,
+                       AnnotationUsecase annotation_usecase,
+                       InterpreterManager* interpreter_manager,
+                       std::vector<int>* chosen_indices) const;
+
+  // Gets selection candidates from the ML model.
+  // Provides the tokens produced during tokenization of the context string for
+  // reuse.
+  bool ModelSuggestSelection(
+      const UnicodeText& context_unicode, CodepointSpan click_indices,
+      const std::vector<Locale>& detected_text_language_tags,
+      InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
+      std::vector<AnnotatedSpan>* result) const;
+
+  // Classifies the selected text given the context string with the
+  // classification model.
+  // Returns true if no error occurred.
+  bool ModelClassifyText(
+      const std::string& context, const std::vector<Token>& cached_tokens,
+      const std::vector<Locale>& locales, CodepointSpan selection_indices,
+      InterpreterManager* interpreter_manager,
+      FeatureProcessor::EmbeddingCache* embedding_cache,
+      std::vector<ClassificationResult>* classification_results,
+      std::vector<Token>* tokens) const;
+
+  // Same as above but doesn't output tokens.
+  bool ModelClassifyText(
+      const std::string& context, const std::vector<Token>& cached_tokens,
+      const std::vector<Locale>& detected_text_language_tags,
+      CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+      FeatureProcessor::EmbeddingCache* embedding_cache,
+      std::vector<ClassificationResult>* classification_results) const;
+
+  // Same as above but doesn't take cached tokens and doesn't output tokens.
+  bool ModelClassifyText(
+      const std::string& context,
+      const std::vector<Locale>& detected_text_language_tags,
+      CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
+      FeatureProcessor::EmbeddingCache* embedding_cache,
+      std::vector<ClassificationResult>* classification_results) const;
+
+  // Returns a relative token span that represents how many tokens on the left
+  // from the selection and right from the selection are needed for the
+  // classifier input.
+  TokenSpan ClassifyTextUpperBoundNeededTokens() const;
+
+  // Classifies the selected text with the regular expressions models.
+  // Returns true if no error happened, false otherwise.
+  bool RegexClassifyText(
+      const std::string& context, CodepointSpan selection_indices,
+      std::vector<ClassificationResult>* classification_result) const;
+
+  // Classifies the selected text with the date time model.
+  // Returns true if no error happened, false otherwise.
+  bool DatetimeClassifyText(
+      const std::string& context, CodepointSpan selection_indices,
+      const ClassificationOptions& options,
+      std::vector<ClassificationResult>* classification_results) const;
+
+  // Chunks given input text with the selection model and classifies the spans
+  // with the classification model.
+  // The annotations are sorted by their position in the context string and
+  // exclude spans classified as 'other'.
+  // Provides the tokens produced during tokenization of the context string for
+  // reuse.
+  bool ModelAnnotate(const std::string& context,
+                     const std::vector<Locale>& detected_text_language_tags,
+                     InterpreterManager* interpreter_manager,
+                     std::vector<Token>* tokens,
+                     std::vector<AnnotatedSpan>* result) const;
+
+  // Groups the tokens into chunks. A chunk is a token span that should be the
+  // suggested selection when any of its contained tokens is clicked. The chunks
+  // are non-overlapping and are sorted by their position in the context string.
+  // "num_tokens" is the total number of tokens available (as this method does
+  // not need the actual vector of tokens).
+  // "span_of_interest" is a span of all the tokens that could be clicked.
+  // The resulting chunks all have to overlap with it and they cover this span
+  // completely. The first and last chunk might extend beyond it.
+  // The chunks vector is cleared before filling.
+  bool ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
+                  tflite::Interpreter* selection_interpreter,
+                  const CachedFeatures& cached_features,
+                  std::vector<TokenSpan>* chunks) const;
+
+  // A helper method for ModelChunk(). It generates scored chunk candidates for
+  // a click context model.
+  // NOTE: The returned chunks can (and most likely do) overlap.
+  bool ModelClickContextScoreChunks(
+      int num_tokens, const TokenSpan& span_of_interest,
+      const CachedFeatures& cached_features,
+      tflite::Interpreter* selection_interpreter,
+      std::vector<ScoredChunk>* scored_chunks) const;
+
+  // A helper method for ModelChunk(). It generates scored chunk candidates for
+  // a bounds-sensitive model.
+  // NOTE: The returned chunks can (and most likely do) overlap.
+  bool ModelBoundsSensitiveScoreChunks(
+      int num_tokens, const TokenSpan& span_of_interest,
+      const TokenSpan& inference_span, const CachedFeatures& cached_features,
+      tflite::Interpreter* selection_interpreter,
+      std::vector<ScoredChunk>* scored_chunks) const;
+
+  // Produces chunks isolated by a set of regular expressions.
+  bool RegexChunk(const UnicodeText& context_unicode,
+                  const std::vector<int>& rules,
+                  std::vector<AnnotatedSpan>* result,
+                  bool is_serialized_entity_data_enabled) const;
+
+  // Produces chunks from the datetime parser.
+  bool DatetimeChunk(const UnicodeText& context_unicode,
+                     int64 reference_time_ms_utc,
+                     const std::string& reference_timezone,
+                     const std::string& locales, ModeFlag mode,
+                     AnnotationUsecase annotation_usecase,
+                     bool is_serialized_entity_data_enabled,
+                     std::vector<AnnotatedSpan>* result) const;
+
+  // Returns whether a classification should be filtered.
+  bool FilteredForAnnotation(const AnnotatedSpan& span) const;
+  bool FilteredForClassification(
+      const ClassificationResult& classification) const;
+  bool FilteredForSelection(const AnnotatedSpan& span) const;
+
+  // Computes the selection boundaries from a regular expression match.
+  CodepointSpan ComputeSelectionBoundaries(
+      const UniLib::RegexMatcher* match,
+      const RegexModel_::Pattern* config) const;
+
+  // Returns whether a regex pattern provides entity data from a match.
+  bool HasEntityData(const RegexModel_::Pattern* pattern) const;
+
+  // Constructs and serializes entity data from regex matches.
+  bool SerializedEntityDataFromRegexMatch(
+      const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
+      std::string* serialized_entity_data) const;
+
+  // For knowledge candidates which have a ContactPointer, fill in the
+  // appropriate contact metadata, if possible.
+  void AddContactMetadataToKnowledgeClassificationResults(
+      std::vector<AnnotatedSpan>* candidates) const;
+
+  // Gets priority score from the list of classification results.
+  float GetPriorityScore(
+      const std::vector<ClassificationResult>& classification) const;
+
+  // Verifies a regex match and returns true if verification was successful.
+  bool VerifyRegexMatchCandidate(
+      const std::string& context,
+      const VerificationOptions* verification_options, const std::string& match,
+      const UniLib::RegexMatcher* matcher) const;
+
+  // The flatbuffer model; backed by mmap_ or by an external buffer, depending
+  // on which constructor was used.
+  const Model* model_;
+
+  std::unique_ptr<const ModelExecutor> selection_executor_;
+  std::unique_ptr<const ModelExecutor> classification_executor_;
+  std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
+
+  std::unique_ptr<const FeatureProcessor> selection_feature_processor_;
+  std::unique_ptr<const FeatureProcessor> classification_feature_processor_;
+
+  std::unique_ptr<const DatetimeParser> datetime_parser_;
+  std::unique_ptr<const dates::CfgDatetimeAnnotator> cfg_datetime_parser_;
+
+  std::unique_ptr<const GrammarAnnotator> grammar_annotator_;
+
+ private:
+  // A regex pattern from the model together with its compiled form.
+  struct CompiledRegexPattern {
+    const RegexModel_::Pattern* config;
+    std::unique_ptr<UniLib::RegexPattern> pattern;
+  };
+
+  // Removes annotations the entity type of which is not in the set of enabled
+  // entity types.
+  void RemoveNotEnabledEntityTypes(
+      const EnabledEntityTypes& is_entity_type_enabled,
+      std::vector<AnnotatedSpan>* annotated_spans) const;
+
+  // Runs only annotators that do not support structured input. Does conflict
+  // resolution, removal of disallowed entities and sorting on both new
+  // generated candidates and passed in entities.
+  // Returns Status::Error if the annotation failed, in which case the vector of
+  // candidates should be ignored.
+  Status AnnotateSingleInput(const std::string& context,
+                             const AnnotationOptions& options,
+                             std::vector<AnnotatedSpan>* candidates) const;
+
+  // Parses the money amount into whole and decimal part and fills in the
+  // entity data information.
+  bool ParseAndFillInMoneyAmount(std::string* serialized_entity_data) const;
+
+  std::unique_ptr<ScopedMmap> mmap_;
+  bool initialized_ = false;
+  bool enabled_for_annotation_ = false;
+  bool enabled_for_classification_ = false;
+  bool enabled_for_selection_ = false;
+  // Collections filtered from the respective results (see FilteredFor*).
+  std::unordered_set<std::string> filtered_collections_annotation_;
+  std::unordered_set<std::string> filtered_collections_classification_;
+  std::unordered_set<std::string> filtered_collections_selection_;
+
+  // Compiled regular expression patterns from the model.
+  std::vector<CompiledRegexPattern> regex_patterns_;
+
+  // Indices into regex_patterns_ for the different modes.
+  std::vector<int> annotation_regex_patterns_, classification_regex_patterns_,
+      selection_regex_patterns_;
+
+  // The UniLib/CalendarLib in use; the owned_* members hold the instances
+  // when this annotator owns them (see the unique_ptr constructors).
+  std::unique_ptr<UniLib> owned_unilib_;
+  const UniLib* unilib_;
+  std::unique_ptr<CalendarLib> owned_calendarlib_;
+  const CalendarLib* calendarlib_;
+
+  // Optional sub-annotators; some are set up via the Initialize* methods
+  // above.
+  std::unique_ptr<const KnowledgeEngine> knowledge_engine_;
+  std::unique_ptr<const ContactEngine> contact_engine_;
+  std::unique_ptr<const InstalledAppEngine> installed_app_engine_;
+  std::unique_ptr<const NumberAnnotator> number_annotator_;
+  std::unique_ptr<const DurationAnnotator> duration_annotator_;
+  std::unique_ptr<const PersonNameEngine> person_name_engine_;
+  std::unique_ptr<const TranslateAnnotator> translate_annotator_;
+  std::unique_ptr<const ExperimentalAnnotator> experimental_annotator_;
+
+  // Builder for creating extra data.
+  const reflection::Schema* entity_data_schema_;
+  std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
+
+  // Locales for which the entire model triggers.
+  std::vector<Locale> model_triggering_locales_;
+
+  // Locales for which the ML model triggers.
+  std::vector<Locale> ml_model_triggering_locales_;
+
+  // Locales that the dictionary classification support.
+  std::vector<Locale> dictionary_locales_;
+
+  // Decimal and thousands number separators.
+  std::unordered_set<char32> money_separators_;
+
+  // Model for language identification.
+  const libtextclassifier3::mobile::lang_id::LangId* lang_id_ = nullptr;
+
+  // If true, will prioritize the longest annotation during conflict resolution.
+  bool prioritize_longest_annotation_ = false;
+
+  // If true, the annotator will perform conflict resolution between the
+  // different sub-annotators also in the RAW mode. If false, no conflict
+  // resolution will be performed in RAW mode.
+  bool do_conflict_resolution_in_raw_mode_ = true;
+};
+
+namespace internal {
+// Implementation helpers declared in the header — presumably so tests can
+// reach them; confirm against the test targets.
+
+// Helper function, which if the initial 'span' contains only white-spaces,
+// moves the selection to a single-codepoint selection on the left side
+// of this block of white-space.
+CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
+                                            const UnicodeText& context_unicode,
+                                            const UniLib& unilib);
+
+// Copies tokens from 'cached_tokens' that are
+// 'tokens_around_selection_to_copy' (on the left, and right) tokens distant
+// from the tokens that correspond to 'selection_indices'.
+std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
+                                    CodepointSpan selection_indices,
+                                    TokenSpan tokens_around_selection_to_copy);
+}  // namespace internal
+
+// Interprets the buffer as a Model flatbuffer and returns it for reading.
+const Model* ViewModel(const void* buffer, int size);
+
+// Opens model from given path and runs a function, passing the loaded Model
+// flatbuffer as an argument.
+//
+// This is mainly useful if we don't want to pay the cost for the model
+// initialization because we'll be only reading some flatbuffer values from the
+// file.
+// Mmaps the model at `path` and invokes `function` with the viewed Model
+// flatbuffer, returning the function's result. On mmap failure, `function` is
+// invoked once with nullptr and its result is returned.
+template <typename ReturnType, typename Func>
+ReturnType VisitAnnotatorModel(const std::string& path, Func function) {
+  ScopedMmap mmap(path);
+  if (!mmap.handle().ok()) {
+    // Bug fix: previously the result of this call was discarded (no
+    // `return`), so `function` ran a second time on a Model viewed from the
+    // invalid handle. Bail out here instead.
+    return function(/*model=*/nullptr);
+  }
+  const Model* model =
+      ViewModel(mmap.handle().start(), mmap.handle().num_bytes());
+  return function(model);
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
diff --git a/native/annotator/annotator_jni.cc b/native/annotator/annotator_jni.cc
new file mode 100644
index 0000000..3e04f7f
--- /dev/null
+++ b/native/annotator/annotator_jni.cc
@@ -0,0 +1,905 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// JNI wrapper for the Annotator.
+
+#include "annotator/annotator_jni.h"
+
+#include <jni.h>
+
+#include <type_traits>
+#include <vector>
+
+#include "annotator/annotator.h"
+#include "annotator/annotator_jni_common.h"
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/base/status_macros.h"
+#include "utils/base/statusor.h"
+#include "utils/calendar/calendar.h"
+#include "utils/intents/intent-generator.h"
+#include "utils/intents/jni.h"
+#include "utils/intents/remote-action-template.h"
+#include "utils/java/jni-cache.h"
+#include "utils/java/jni-helper.h"
+#include "utils/java/string_utils.h"
+#include "utils/memory/mmap.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unilib.h"
+
+#ifdef TC3_UNILIB_JAVAICU
+#ifndef TC3_CALENDAR_JAVAICU
+#error Inconsistent usage of Java ICU components
+#else
+#define TC3_USE_JAVAICU
+#endif
+#endif
+
+using libtextclassifier3::AnnotatedSpan;
+using libtextclassifier3::Annotator;
+using libtextclassifier3::ClassificationResult;
+using libtextclassifier3::CodepointSpan;
+using libtextclassifier3::JniHelper;
+using libtextclassifier3::Model;
+using libtextclassifier3::ScopedLocalRef;
+using libtextclassifier3::StatusOr;
+// When using the Java's ICU, CalendarLib and UniLib need to be instantiated
+// with a JavaVM pointer from JNI. When using a standard ICU the pointer is
+// not needed and the objects are instantiated implicitly.
+#ifdef TC3_USE_JAVAICU
+using libtextclassifier3::CalendarLib;
+using libtextclassifier3::UniLib;
+#endif
+
+namespace libtextclassifier3 {
+
+using libtextclassifier3::CodepointSpan;
+
+namespace {
+class AnnotatorJniContext {
+ public:
+ static AnnotatorJniContext* Create(
+ const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
+ std::unique_ptr<Annotator> model) {
+ if (jni_cache == nullptr || model == nullptr) {
+ return nullptr;
+ }
+ // Intent generator will be null if the options are not specified.
+ std::unique_ptr<IntentGenerator> intent_generator =
+ IntentGenerator::Create(model->model()->intent_options(),
+ model->model()->resources(), jni_cache);
+ std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
+ libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
+ if (template_handler == nullptr) {
+ return nullptr;
+ }
+
+ return new AnnotatorJniContext(jni_cache, std::move(model),
+ std::move(intent_generator),
+ std::move(template_handler));
+ }
+
+ std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const {
+ return jni_cache_;
+ }
+
+ Annotator* model() const { return model_.get(); }
+
+ // NOTE: Intent generator will be null if the options are not specified in
+ // the model.
+ IntentGenerator* intent_generator() const { return intent_generator_.get(); }
+
+ RemoteActionTemplatesHandler* template_handler() const {
+ return template_handler_.get();
+ }
+
+ private:
+ AnnotatorJniContext(
+ const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
+ std::unique_ptr<Annotator> model,
+ std::unique_ptr<IntentGenerator> intent_generator,
+ std::unique_ptr<RemoteActionTemplatesHandler> template_handler)
+ : jni_cache_(jni_cache),
+ model_(std::move(model)),
+ intent_generator_(std::move(intent_generator)),
+ template_handler_(std::move(template_handler)) {}
+
+ std::shared_ptr<libtextclassifier3::JniCache> jni_cache_;
+ std::unique_ptr<Annotator> model_;
+ std::unique_ptr<IntentGenerator> intent_generator_;
+ std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
+};
+
+StatusOr<ScopedLocalRef<jobject>> ClassificationResultWithIntentsToJObject(
+ JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context,
+ jclass result_class, jmethodID result_class_constructor,
+ jclass datetime_parse_class, jmethodID datetime_parse_class_constructor,
+ const jstring device_locales, const ClassificationOptions* options,
+ const std::string& context, const CodepointSpan& selection_indices,
+ const ClassificationResult& classification_result, bool generate_intents) {
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> row_string,
+ JniHelper::NewStringUTF(env, classification_result.collection.c_str()));
+
+ ScopedLocalRef<jobject> row_datetime_parse;
+ if (classification_result.datetime_parse_result.IsSet()) {
+ TC3_ASSIGN_OR_RETURN(
+ row_datetime_parse,
+ JniHelper::NewObject(
+ env, datetime_parse_class, datetime_parse_class_constructor,
+ classification_result.datetime_parse_result.time_ms_utc,
+ classification_result.datetime_parse_result.granularity));
+ }
+
+ ScopedLocalRef<jbyteArray> serialized_knowledge_result;
+ const std::string& serialized_knowledge_result_string =
+ classification_result.serialized_knowledge_result;
+ if (!serialized_knowledge_result_string.empty()) {
+ TC3_ASSIGN_OR_RETURN(serialized_knowledge_result,
+ JniHelper::NewByteArray(
+ env, serialized_knowledge_result_string.size()));
+ env->SetByteArrayRegion(serialized_knowledge_result.get(), 0,
+ serialized_knowledge_result_string.size(),
+ reinterpret_cast<const jbyte*>(
+ serialized_knowledge_result_string.data()));
+ }
+
+ ScopedLocalRef<jstring> contact_name;
+ if (!classification_result.contact_name.empty()) {
+ TC3_ASSIGN_OR_RETURN(contact_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_name.c_str()));
+ }
+
+ ScopedLocalRef<jstring> contact_given_name;
+ if (!classification_result.contact_given_name.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ contact_given_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_given_name.c_str()));
+ }
+
+ ScopedLocalRef<jstring> contact_family_name;
+ if (!classification_result.contact_family_name.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ contact_family_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_family_name.c_str()));
+ }
+
+ ScopedLocalRef<jstring> contact_nickname;
+ if (!classification_result.contact_nickname.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ contact_nickname,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_nickname.c_str()));
+ }
+
+ ScopedLocalRef<jstring> contact_email_address;
+ if (!classification_result.contact_email_address.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ contact_email_address,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_email_address.c_str()));
+ }
+
+ ScopedLocalRef<jstring> contact_phone_number;
+ if (!classification_result.contact_phone_number.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ contact_phone_number,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_phone_number.c_str()));
+ }
+
+ ScopedLocalRef<jstring> contact_id;
+ if (!classification_result.contact_id.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ contact_id,
+ JniHelper::NewStringUTF(env, classification_result.contact_id.c_str()));
+ }
+
+ ScopedLocalRef<jstring> app_name;
+ if (!classification_result.app_name.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ app_name,
+ JniHelper::NewStringUTF(env, classification_result.app_name.c_str()));
+ }
+
+ ScopedLocalRef<jstring> app_package_name;
+ if (!classification_result.app_package_name.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ app_package_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.app_package_name.c_str()));
+ }
+
+ ScopedLocalRef<jobjectArray> extras;
+ if (model_context->model()->entity_data_schema() != nullptr &&
+ !classification_result.serialized_entity_data.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ extras,
+ model_context->template_handler()->EntityDataAsNamedVariantArray(
+ model_context->model()->entity_data_schema(),
+ classification_result.serialized_entity_data));
+ }
+
+ ScopedLocalRef<jbyteArray> serialized_entity_data;
+ if (!classification_result.serialized_entity_data.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ serialized_entity_data,
+ JniHelper::NewByteArray(
+ env, classification_result.serialized_entity_data.size()));
+ env->SetByteArrayRegion(
+ serialized_entity_data.get(), 0,
+ classification_result.serialized_entity_data.size(),
+ reinterpret_cast<const jbyte*>(
+ classification_result.serialized_entity_data.data()));
+ }
+
+ ScopedLocalRef<jobjectArray> remote_action_templates_result;
+ // Only generate RemoteActionTemplate for the top classification result
+ // as classifyText does not need RemoteAction from other results anyway.
+ if (generate_intents && model_context->intent_generator() != nullptr) {
+ std::vector<RemoteActionTemplate> remote_action_templates;
+ if (!model_context->intent_generator()->GenerateIntents(
+ device_locales, classification_result,
+ options->reference_time_ms_utc, context, selection_indices,
+ app_context, model_context->model()->entity_data_schema(),
+ &remote_action_templates)) {
+ return {Status::UNKNOWN};
+ }
+
+ TC3_ASSIGN_OR_RETURN(
+ remote_action_templates_result,
+ model_context->template_handler()->RemoteActionTemplatesToJObjectArray(
+ remote_action_templates));
+ }
+
+ return JniHelper::NewObject(
+ env, result_class, result_class_constructor, row_string.get(),
+ static_cast<jfloat>(classification_result.score),
+ row_datetime_parse.get(), serialized_knowledge_result.get(),
+ contact_name.get(), contact_given_name.get(), contact_family_name.get(),
+ contact_nickname.get(), contact_email_address.get(),
+ contact_phone_number.get(), contact_id.get(), app_name.get(),
+ app_package_name.get(), extras.get(), serialized_entity_data.get(),
+ remote_action_templates_result.get(), classification_result.duration_ms,
+ classification_result.numeric_value,
+ classification_result.numeric_double_value);
+}
+
+StatusOr<ScopedLocalRef<jobjectArray>>
+ClassificationResultsWithIntentsToJObjectArray(
+ JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context,
+ const jstring device_locales, const ClassificationOptions* options,
+ const std::string& context, const CodepointSpan& selection_indices,
+ const std::vector<ClassificationResult>& classification_result,
+ bool generate_intents) {
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> result_class,
+ JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$ClassificationResult"));
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> datetime_parse_class,
+ JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$DatetimeResult"));
+
+ TC3_ASSIGN_OR_RETURN(
+ const jmethodID result_class_constructor,
+ JniHelper::GetMethodID(
+ env, result_class.get(), "<init>",
+ "(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/"
+ "String;"
+ "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/"
+ "String;"
+ "Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
+ "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";[B[L" TC3_PACKAGE_PATH
+ "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";JJD)V"));
+ TC3_ASSIGN_OR_RETURN(const jmethodID datetime_parse_class_constructor,
+ JniHelper::GetMethodID(env, datetime_parse_class.get(),
+ "<init>", "(JI)V"));
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> results,
+ JniHelper::NewObjectArray(env, classification_result.size(),
+ result_class.get()));
+
+ for (int i = 0; i < classification_result.size(); i++) {
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> result,
+ ClassificationResultWithIntentsToJObject(
+ env, model_context, app_context, result_class.get(),
+ result_class_constructor, datetime_parse_class.get(),
+ datetime_parse_class_constructor, device_locales, options, context,
+ selection_indices, classification_result[i],
+ generate_intents && (i == 0)));
+ TC3_RETURN_IF_ERROR(
+ JniHelper::SetObjectArrayElement(env, results.get(), i, result.get()));
+ }
+ return results;
+}
+
+StatusOr<ScopedLocalRef<jobjectArray>> ClassificationResultsToJObjectArray(
+    JNIEnv* env, const AnnotatorJniContext* model_context,
+    const std::vector<ClassificationResult>& classification_result) {
+  return ClassificationResultsWithIntentsToJObjectArray(
+      env, model_context,
+      /*(unused) app_context=*/nullptr,
+      /*(unused) device_locales=*/nullptr,
+      /*(unused) options=*/nullptr,
+      /*(unused) context=*/"",
+      /*(unused) selection_indices=*/{kInvalidIndex, kInvalidIndex},
+      classification_result,
+      /*generate_intents=*/false);
+}
+
+CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
+ CodepointSpan orig_indices,
+ bool from_utf8) {
+ const libtextclassifier3::UnicodeText unicode_str =
+ libtextclassifier3::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
+
+ int unicode_index = 0;
+ int bmp_index = 0;
+
+ const int* source_index;
+ const int* target_index;
+ if (from_utf8) {
+ source_index = &unicode_index;
+ target_index = &bmp_index;
+ } else {
+ source_index = &bmp_index;
+ target_index = &unicode_index;
+ }
+
+ CodepointSpan result{-1, -1};
+ std::function<void()> assign_indices_fn = [&result, &orig_indices,
+ &source_index, &target_index]() {
+ if (orig_indices.first == *source_index) {
+ result.first = *target_index;
+ }
+
+ if (orig_indices.second == *source_index) {
+ result.second = *target_index;
+ }
+ };
+
+ for (auto it = unicode_str.begin(); it != unicode_str.end();
+ ++it, ++unicode_index, ++bmp_index) {
+ assign_indices_fn();
+
+ // There is 1 extra character in the input for each UTF8 character > 0xFFFF.
+ if (*it > 0xFFFF) {
+ ++bmp_index;
+ }
+ }
+ assign_indices_fn();
+
+ return result;
+}
+
+} // namespace
+
+CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
+ CodepointSpan bmp_indices) {
+ return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
+}
+
+CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
+ CodepointSpan utf8_indices) {
+ return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
+}
+
+StatusOr<ScopedLocalRef<jstring>> GetLocalesFromMmap(
+ JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+ if (!mmap->handle().ok()) {
+ return JniHelper::NewStringUTF(env, "");
+ }
+ const Model* model = libtextclassifier3::ViewModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model || !model->locales()) {
+ return JniHelper::NewStringUTF(env, "");
+ }
+
+ return JniHelper::NewStringUTF(env, model->locales()->c_str());
+}
+
+jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+ if (!mmap->handle().ok()) {
+ return 0;
+ }
+ const Model* model = libtextclassifier3::ViewModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model) {
+ return 0;
+ }
+ return model->version();
+}
+
+StatusOr<ScopedLocalRef<jstring>> GetNameFromMmap(
+ JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+ if (!mmap->handle().ok()) {
+ return JniHelper::NewStringUTF(env, "");
+ }
+ const Model* model = libtextclassifier3::ViewModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model || !model->name()) {
+ return JniHelper::NewStringUTF(env, "");
+ }
+ return JniHelper::NewStringUTF(env, model->name()->c_str());
+}
+
+} // namespace libtextclassifier3
+
+using libtextclassifier3::AnnotatorJniContext;
+using libtextclassifier3::ClassificationResultsToJObjectArray;
+using libtextclassifier3::ClassificationResultsWithIntentsToJObjectArray;
+using libtextclassifier3::ConvertIndicesBMPToUTF8;
+using libtextclassifier3::ConvertIndicesUTF8ToBMP;
+using libtextclassifier3::FromJavaAnnotationOptions;
+using libtextclassifier3::FromJavaClassificationOptions;
+using libtextclassifier3::FromJavaInputFragment;
+using libtextclassifier3::FromJavaSelectionOptions;
+using libtextclassifier3::InputFragment;
+using libtextclassifier3::ToStlString;
+
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator)
+(JNIEnv* env, jobject thiz, jint fd) {
+ std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
+ libtextclassifier3::JniCache::Create(env));
+#ifdef TC3_USE_JAVAICU
+ return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
+ jni_cache,
+ Annotator::FromFileDescriptor(
+ fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
+ std::unique_ptr<CalendarLib>(new CalendarLib(jni_cache)))));
+#else
+ return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
+ jni_cache, Annotator::FromFileDescriptor(fd)));
+#endif
+}
+
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
+(JNIEnv* env, jobject thiz, jstring path) {
+ TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path));
+ std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
+ libtextclassifier3::JniCache::Create(env));
+#ifdef TC3_USE_JAVAICU
+ return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
+ jni_cache,
+ Annotator::FromPath(
+ path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
+ std::unique_ptr<CalendarLib>(new CalendarLib(jni_cache)))));
+#else
+ return reinterpret_cast<jlong>(
+ AnnotatorJniContext::Create(jni_cache, Annotator::FromPath(path_str)));
+#endif
+}
+
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
+ std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
+ libtextclassifier3::JniCache::Create(env));
+#ifdef TC3_USE_JAVAICU
+ return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
+ jni_cache,
+ Annotator::FromFileDescriptor(
+ fd, offset, size, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
+ std::unique_ptr<CalendarLib>(new CalendarLib(jni_cache)))));
+#else
+ return reinterpret_cast<jlong>(AnnotatorJniContext::Create(
+ jni_cache, Annotator::FromFileDescriptor(fd, offset, size)));
+#endif
+}
+
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializeKnowledgeEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
+ if (!ptr) {
+ return false;
+ }
+
+ Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
+
+ std::string serialized_config_string;
+ TC3_ASSIGN_OR_RETURN_FALSE(jsize length,
+ JniHelper::GetArrayLength(env, serialized_config));
+ serialized_config_string.resize(length);
+ env->GetByteArrayRegion(serialized_config, 0, length,
+ reinterpret_cast<jbyte*>(const_cast<char*>(
+ serialized_config_string.data())));
+
+ return model->InitializeKnowledgeEngine(serialized_config_string);
+}
+
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializeContactEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
+ if (!ptr) {
+ return false;
+ }
+
+ Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
+
+ std::string serialized_config_string;
+ TC3_ASSIGN_OR_RETURN_FALSE(jsize length,
+ JniHelper::GetArrayLength(env, serialized_config));
+ serialized_config_string.resize(length);
+ env->GetByteArrayRegion(serialized_config, 0, length,
+ reinterpret_cast<jbyte*>(const_cast<char*>(
+ serialized_config_string.data())));
+
+ return model->InitializeContactEngine(serialized_config_string);
+}
+
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializeInstalledAppEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
+ if (!ptr) {
+ return false;
+ }
+
+ Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
+
+ std::string serialized_config_string;
+ TC3_ASSIGN_OR_RETURN_FALSE(jsize length,
+ JniHelper::GetArrayLength(env, serialized_config));
+ serialized_config_string.resize(length);
+ env->GetByteArrayRegion(serialized_config, 0, length,
+ reinterpret_cast<jbyte*>(const_cast<char*>(
+ serialized_config_string.data())));
+
+ return model->InitializeInstalledAppEngine(serialized_config_string);
+}
+
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializePersonNameEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jint fd, jlong offset, jlong size) {
+ if (!ptr) {
+ return false;
+ }
+
+ Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
+
+ return model->InitializePersonNameEngineFromFileDescriptor(fd, offset, size);
+}
+
+TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeSetLangId)
+(JNIEnv* env, jobject thiz, jlong annotator_ptr, jlong lang_id_ptr) {
+ if (!annotator_ptr) {
+ return;
+ }
+ Annotator* model =
+ reinterpret_cast<AnnotatorJniContext*>(annotator_ptr)->model();
+ libtextclassifier3::mobile::lang_id::LangId* lang_id_model =
+ reinterpret_cast<libtextclassifier3::mobile::lang_id::LangId*>(lang_id_ptr);
+ model->SetLangId(lang_id_model);
+}
+
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeGetNativeModelPtr)
+(JNIEnv* env, jobject thiz, jlong ptr) {
+ if (!ptr) {
+ return 0L;
+ }
+ return reinterpret_cast<jlong>(
+ reinterpret_cast<AnnotatorJniContext*>(ptr)->model());
+}
+
+TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end, jobject options) {
+ if (!ptr) {
+ return nullptr;
+ }
+ const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
+ ToStlString(env, context));
+ CodepointSpan input_indices =
+ ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
+ TC3_ASSIGN_OR_RETURN_NULL(
+ libtextclassifier3::SelectionOptions selection_options,
+ FromJavaSelectionOptions(env, options));
+ CodepointSpan selection =
+ model->SuggestSelection(context_utf8, input_indices, selection_options);
+ selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);
+
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jintArray> result,
+ JniHelper::NewIntArray(env, 2));
+ env->SetIntArrayRegion(result.get(), 0, 1, &(std::get<0>(selection)));
+ env->SetIntArrayRegion(result.get(), 1, 1, &(std::get<1>(selection)));
+ return result.release();
+}
+
+TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end, jobject options, jobject app_context,
+ jstring device_locales) {
+ if (!ptr) {
+ return nullptr;
+ }
+ const AnnotatorJniContext* model_context =
+ reinterpret_cast<AnnotatorJniContext*>(ptr);
+
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
+ ToStlString(env, context));
+ const CodepointSpan input_indices =
+ ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
+ TC3_ASSIGN_OR_RETURN_NULL(
+ const libtextclassifier3::ClassificationOptions classification_options,
+ FromJavaClassificationOptions(env, options));
+ const std::vector<ClassificationResult> classification_result =
+ model_context->model()->ClassifyText(context_utf8, input_indices,
+ classification_options);
+
+ ScopedLocalRef<jobjectArray> result;
+ if (app_context != nullptr) {
+ TC3_ASSIGN_OR_RETURN_NULL(
+ result, ClassificationResultsWithIntentsToJObjectArray(
+ env, model_context, app_context, device_locales,
+ &classification_options, context_utf8, input_indices,
+ classification_result,
+ /*generate_intents=*/true));
+
+ } else {
+ TC3_ASSIGN_OR_RETURN_NULL(
+ result, ClassificationResultsToJObjectArray(env, model_context,
+ classification_result));
+ }
+
+ return result.release();
+}
+
+TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) {
+ if (!ptr) {
+ return nullptr;
+ }
+ const AnnotatorJniContext* model_context =
+ reinterpret_cast<AnnotatorJniContext*>(ptr);
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
+ ToStlString(env, context));
+ TC3_ASSIGN_OR_RETURN_NULL(
+ libtextclassifier3::AnnotationOptions annotation_options,
+ FromJavaAnnotationOptions(env, options));
+ const std::vector<AnnotatedSpan> annotations =
+ model_context->model()->Annotate(context_utf8, annotation_options);
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jclass> result_class,
+ JniHelper::FindClass(
+ env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotatedSpan"));
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ jmethodID result_class_constructor,
+ JniHelper::GetMethodID(
+ env, result_class.get(), "<init>",
+ "(II[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$ClassificationResult;)V"));
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> results,
+ JniHelper::NewObjectArray(env, annotations.size(), result_class.get()));
+
+ for (int i = 0; i < annotations.size(); ++i) {
+ CodepointSpan span_bmp =
+ ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> classification_results,
+ ClassificationResultsToJObjectArray(env, model_context,
+ annotations[i].classification));
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobject> result,
+ JniHelper::NewObject(env, result_class.get(), result_class_constructor,
+ static_cast<jint>(span_bmp.first),
+ static_cast<jint>(span_bmp.second),
+ classification_results.get()));
+ if (!JniHelper::SetObjectArrayElement(env, results.get(), i, result.get())
+ .ok()) {
+ return nullptr;
+ }
+ }
+ return results.release();
+}
+
+TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME,
+ nativeAnnotateStructuredInput)
+(JNIEnv* env, jobject thiz, jlong ptr, jobjectArray jinput_fragments,
+ jobject options) {
+ if (!ptr) {
+ return nullptr;
+ }
+ const AnnotatorJniContext* model_context =
+ reinterpret_cast<AnnotatorJniContext*>(ptr);
+
+ std::vector<InputFragment> string_fragments;
+ TC3_ASSIGN_OR_RETURN_NULL(jsize input_size,
+ JniHelper::GetArrayLength(env, jinput_fragments));
+ for (int i = 0; i < input_size; ++i) {
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobject> jfragment,
+ JniHelper::GetObjectArrayElement<jobject>(env, jinput_fragments, i));
+ TC3_ASSIGN_OR_RETURN_NULL(InputFragment fragment,
+ FromJavaInputFragment(env, jfragment.get()));
+ string_fragments.push_back(std::move(fragment));
+ }
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ libtextclassifier3::AnnotationOptions annotation_options,
+ FromJavaAnnotationOptions(env, options));
+ const StatusOr<std::vector<std::vector<AnnotatedSpan>>> annotations_or =
+ model_context->model()->AnnotateStructuredInput(string_fragments,
+ annotation_options);
+ if (!annotations_or.ok()) {
+ TC3_LOG(ERROR) << "Annotation of structured input failed with error: "
+ << annotations_or.status().error_message();
+ return nullptr;
+ }
+
+ std::vector<std::vector<AnnotatedSpan>> annotations =
+ std::move(annotations_or.ValueOrDie());
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jclass> span_class,
+ JniHelper::FindClass(
+ env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotatedSpan"));
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ jmethodID span_class_constructor,
+ JniHelper::GetMethodID(
+ env, span_class.get(), "<init>",
+ "(II[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$ClassificationResult;)V"));
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jclass> span_class_array,
+ JniHelper::FindClass(env,
+ "[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$AnnotatedSpan;"));
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> results,
+ JniHelper::NewObjectArray(env, input_size, span_class_array.get()));
+
+ for (int fragment_index = 0; fragment_index < annotations.size();
+ ++fragment_index) {
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> jfragmentAnnotations,
+ JniHelper::NewObjectArray(env, annotations[fragment_index].size(),
+ span_class.get()));
+ for (int annotation_index = 0;
+ annotation_index < annotations[fragment_index].size();
+ ++annotation_index) {
+ CodepointSpan span_bmp = ConvertIndicesUTF8ToBMP(
+ string_fragments[fragment_index].text,
+ annotations[fragment_index][annotation_index].span);
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> classification_results,
+ ClassificationResultsToJObjectArray(
+ env, model_context,
+ annotations[fragment_index][annotation_index].classification));
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobject> single_annotation,
+ JniHelper::NewObject(env, span_class.get(), span_class_constructor,
+ static_cast<jint>(span_bmp.first),
+ static_cast<jint>(span_bmp.second),
+ classification_results.get()));
+
+ if (!JniHelper::SetObjectArrayElement(env, jfragmentAnnotations.get(),
+ annotation_index,
+ single_annotation.get())
+ .ok()) {
+ return nullptr;
+ }
+ }
+
+ if (!JniHelper::SetObjectArrayElement(env, results.get(), fragment_index,
+ jfragmentAnnotations.get())
+ .ok()) {
+ return nullptr;
+ }
+ }
+
+ return results.release();
+}
+
+TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
+ nativeLookUpKnowledgeEntity)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring id) {
+ if (!ptr) {
+ return nullptr;
+ }
+ const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string id_utf8, ToStlString(env, id));
+ std::string serialized_knowledge_result;
+ if (!model->LookUpKnowledgeEntity(id_utf8, &serialized_knowledge_result)) {
+ return nullptr;
+ }
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jbyteArray> result,
+ JniHelper::NewByteArray(env, serialized_knowledge_result.size()));
+ env->SetByteArrayRegion(
+ result.get(), 0, serialized_knowledge_result.size(),
+ reinterpret_cast<const jbyte*>(serialized_knowledge_result.data()));
+ return result.release();
+}
+
+TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
+(JNIEnv* env, jobject thiz, jlong ptr) {
+ const AnnotatorJniContext* context =
+ reinterpret_cast<AnnotatorJniContext*>(ptr);
+ delete context;
+}
+
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLanguage)
+(JNIEnv* env, jobject clazz, jint fd) {
+ TC3_LOG(WARNING) << "Using deprecated getLanguage().";
+ return TC3_JNI_METHOD_NAME(TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)(
+ env, clazz, fd);
+}
+
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)
+(JNIEnv* env, jobject clazz, jint fd) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd));
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
+ GetLocalesFromMmap(env, mmap.get()));
+ return value.release();
+}
+
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocalesWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd, offset, size));
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
+ GetLocalesFromMmap(env, mmap.get()));
+ return value.release();
+}
+
+TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd));
+ return GetVersionFromMmap(env, mmap.get());
+}
+
+TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd, offset, size));
+ return GetVersionFromMmap(env, mmap.get());
+}
+
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName)
+(JNIEnv* env, jobject clazz, jint fd) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd));
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
+ GetNameFromMmap(env, mmap.get()));
+ return value.release();
+}
+
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetNameWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd, offset, size));
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
+ GetNameFromMmap(env, mmap.get()));
+ return value.release();
+}
diff --git a/native/annotator/annotator_jni.h b/native/annotator/annotator_jni.h
new file mode 100644
index 0000000..39a9d9a
--- /dev/null
+++ b/native/annotator/annotator_jni.h
@@ -0,0 +1,127 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_
+
+#include <jni.h>
+#include <string>
+#include "annotator/annotator_jni_common.h"
+#include "annotator/types.h"
+#include "utils/java/jni-base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// SmartSelection.
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotator)
+(JNIEnv* env, jobject thiz, jint fd);
+
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
+(JNIEnv* env, jobject thiz, jstring path);
+
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
+
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializeKnowledgeEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);
+
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializeContactEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);
+
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializeInstalledAppEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);
+
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializePersonNameEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jint fd, jlong offset, jlong size);
+
+TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeSetLangId)
+(JNIEnv* env, jobject thiz, jlong annotator_ptr, jlong lang_id_ptr);
+
+TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeGetNativeModelPtr)
+(JNIEnv* env, jobject thiz, jlong ptr);
+
+TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end, jobject options);
+
+TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
+ jint selection_end, jobject options, jobject app_context,
+ jstring device_locales);
+
+TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME,
+ nativeAnnotateStructuredInput)
+(JNIEnv* env, jobject thiz, jlong ptr, jobjectArray jinput_fragments,
+ jobject options);
+
+TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options);
+
+TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
+ nativeLookUpKnowledgeEntity)
+(JNIEnv* env, jobject thiz, jlong ptr, jstring id);
+
+TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
+(JNIEnv* env, jobject thiz, jlong ptr);
+
+// DEPRECATED. Use nativeGetLocales instead.
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLanguage)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocales)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocalesWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
+
+TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
+
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetName)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetNameWithOffset)
+(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size);
+
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+
+// Given a utf8 string and a span expressed in Java BMP (basic multilingual
+// plane) codepoints, converts it to a span expressed in utf8 codepoints.
+libtextclassifier3::CodepointSpan ConvertIndicesBMPToUTF8(
+ const std::string& utf8_str, libtextclassifier3::CodepointSpan bmp_indices);
+
+// Given a utf8 string and a span expressed in utf8 codepoints, converts it to a
+// span expressed in Java BMP (basic multilingual plane) codepoints.
+libtextclassifier3::CodepointSpan ConvertIndicesUTF8ToBMP(
+ const std::string& utf8_str,
+ libtextclassifier3::CodepointSpan utf8_indices);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_H_
diff --git a/native/annotator/annotator_jni_common.cc b/native/annotator/annotator_jni_common.cc
new file mode 100644
index 0000000..de58b70
--- /dev/null
+++ b/native/annotator/annotator_jni_common.cc
@@ -0,0 +1,335 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/annotator_jni_common.h"
+
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-helper.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+StatusOr<std::unordered_set<std::string>> EntityTypesFromJObject(
+ JNIEnv* env, const jobject& jobject) {
+ std::unordered_set<std::string> entity_types;
+ jobjectArray jentity_types = reinterpret_cast<jobjectArray>(jobject);
+ const int size = env->GetArrayLength(jentity_types);
+ for (int i = 0; i < size; ++i) {
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> jentity_type,
+ JniHelper::GetObjectArrayElement<jstring>(env, jentity_types, i));
+ TC3_ASSIGN_OR_RETURN(std::string entity_type,
+ ToStlString(env, jentity_type.get()));
+ entity_types.insert(entity_type);
+ }
+ return entity_types;
+}
+
// Reads the option fields shared by all of the Java *Options classes (locale,
// reference time/timezone, detected language tags, annotation usecase, and
// user location) into a native options struct of type T.
//
// `class_name` is the fully-qualified JNI name of the concrete Java options
// class. Returns Status::UNKNOWN when `joptions` is null; callers that want
// default options for a null jobject must handle that case before calling.
// Every JNI lookup/call below short-circuits on failure via
// TC3_ASSIGN_OR_RETURN.
template <typename T>
StatusOr<T> FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
                                    const std::string& class_name) {
  if (!joptions) {
    return {Status::UNKNOWN};
  }

  TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jclass> options_class,
                       JniHelper::FindClass(env, class_name.c_str()));

  // .getLocale()
  TC3_ASSIGN_OR_RETURN(
      jmethodID get_locale,
      JniHelper::GetMethodID(env, options_class.get(), "getLocale",
                             "()Ljava/lang/String;"));
  TC3_ASSIGN_OR_RETURN(
      ScopedLocalRef<jstring> locales,
      JniHelper::CallObjectMethod<jstring>(env, joptions, get_locale));

  // .getReferenceTimeMsUtc()
  TC3_ASSIGN_OR_RETURN(jmethodID get_reference_time_method,
                       JniHelper::GetMethodID(env, options_class.get(),
                                              "getReferenceTimeMsUtc", "()J"));
  TC3_ASSIGN_OR_RETURN(
      int64 reference_time,
      JniHelper::CallLongMethod(env, joptions, get_reference_time_method));

  // .getReferenceTimezone()
  TC3_ASSIGN_OR_RETURN(
      jmethodID get_reference_timezone_method,
      JniHelper::GetMethodID(env, options_class.get(), "getReferenceTimezone",
                             "()Ljava/lang/String;"));
  TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> reference_timezone,
                       JniHelper::CallObjectMethod<jstring>(
                           env, joptions, get_reference_timezone_method));

  // .getDetectedTextLanguageTags()
  TC3_ASSIGN_OR_RETURN(jmethodID get_detected_text_language_tags_method,
                       JniHelper::GetMethodID(env, options_class.get(),
                                              "getDetectedTextLanguageTags",
                                              "()Ljava/lang/String;"));
  TC3_ASSIGN_OR_RETURN(
      ScopedLocalRef<jstring> detected_text_language_tags,
      JniHelper::CallObjectMethod<jstring>(
          env, joptions, get_detected_text_language_tags_method));

  // .getAnnotationUsecase()
  TC3_ASSIGN_OR_RETURN(jmethodID get_annotation_usecase,
                       JniHelper::GetMethodID(env, options_class.get(),
                                              "getAnnotationUsecase", "()I"));
  TC3_ASSIGN_OR_RETURN(
      int32 annotation_usecase,
      JniHelper::CallIntMethod(env, joptions, get_annotation_usecase));

  // .getUserLocationLat()
  TC3_ASSIGN_OR_RETURN(jmethodID get_user_location_lat,
                       JniHelper::GetMethodID(env, options_class.get(),
                                              "getUserLocationLat", "()D"));
  TC3_ASSIGN_OR_RETURN(
      double user_location_lat,
      JniHelper::CallDoubleMethod(env, joptions, get_user_location_lat));

  // .getUserLocationLng()
  TC3_ASSIGN_OR_RETURN(jmethodID get_user_location_lng,
                       JniHelper::GetMethodID(env, options_class.get(),
                                              "getUserLocationLng", "()D"));
  TC3_ASSIGN_OR_RETURN(
      double user_location_lng,
      JniHelper::CallDoubleMethod(env, joptions, get_user_location_lng));

  // .getUserLocationAccuracyMeters()
  TC3_ASSIGN_OR_RETURN(
      jmethodID get_user_location_accuracy_meters,
      JniHelper::GetMethodID(env, options_class.get(),
                             "getUserLocationAccuracyMeters", "()F"));
  TC3_ASSIGN_OR_RETURN(float user_location_accuracy_meters,
                       JniHelper::CallFloatMethod(
                           env, joptions, get_user_location_accuracy_meters));

  // Copy everything into the result struct; string-typed fields go through
  // ToStlString, which can itself fail and short-circuit.
  T options;
  TC3_ASSIGN_OR_RETURN(options.locales, ToStlString(env, locales.get()));
  TC3_ASSIGN_OR_RETURN(options.reference_timezone,
                       ToStlString(env, reference_timezone.get()));
  options.reference_time_ms_utc = reference_time;
  TC3_ASSIGN_OR_RETURN(options.detected_text_language_tags,
                       ToStlString(env, detected_text_language_tags.get()));
  options.annotation_usecase =
      static_cast<AnnotationUsecase>(annotation_usecase);
  // {latitude, longitude, accuracy in meters}, as read from the Java side.
  options.location_context = {user_location_lat, user_location_lng,
                              user_location_accuracy_meters};
  return options;
}
+} // namespace
+
+StatusOr<SelectionOptions> FromJavaSelectionOptions(JNIEnv* env,
+ jobject joptions) {
+ if (!joptions) {
+ // Falling back to default options in case joptions is null
+ SelectionOptions default_selection_options;
+ return default_selection_options;
+ }
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> options_class,
+ JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$SelectionOptions"));
+
+ // .getLocale()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_locales,
+ JniHelper::GetMethodID(env, options_class.get(), "getLocales",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> locales,
+ JniHelper::CallObjectMethod<jstring>(env, joptions, get_locales));
+
+ // .getAnnotationUsecase()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_annotation_usecase,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getAnnotationUsecase", "()I"));
+ TC3_ASSIGN_OR_RETURN(
+ int32 annotation_usecase,
+ JniHelper::CallIntMethod(env, joptions, get_annotation_usecase));
+
+ SelectionOptions options;
+ TC3_ASSIGN_OR_RETURN(options.locales, ToStlString(env, locales.get()));
+ options.annotation_usecase =
+ static_cast<AnnotationUsecase>(annotation_usecase);
+
+ return options;
+}
+
+StatusOr<ClassificationOptions> FromJavaClassificationOptions(
+ JNIEnv* env, jobject joptions) {
+ if (!joptions) {
+ // Falling back to default options in case joptions is null
+ ClassificationOptions default_classification_options;
+ return default_classification_options;
+ }
+
+ TC3_ASSIGN_OR_RETURN(ClassificationOptions classifier_options,
+ FromJavaOptionsInternal<ClassificationOptions>(
+ env, joptions,
+ TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$ClassificationOptions"));
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> options_class,
+ JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$ClassificationOptions"));
+ // .getUserFamiliarLanguageTags()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_user_familiar_language_tags,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getUserFamiliarLanguageTags",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> user_familiar_language_tags,
+ JniHelper::CallObjectMethod<jstring>(
+ env, joptions, get_user_familiar_language_tags));
+
+ TC3_ASSIGN_OR_RETURN(classifier_options.user_familiar_language_tags,
+ ToStlString(env, user_familiar_language_tags.get()));
+
+ return classifier_options;
+}
+
// Converts a Java AnnotationOptions object into the native AnnotationOptions.
//
// Falls back to default-constructed options when `joptions` is null. Reads
// the annotation-specific fields (entity types, serialized-entity-data flag
// and permission bits) directly, and delegates the fields shared with the
// other options classes to FromJavaOptionsInternal. Any failing JNI call
// short-circuits with an error status.
StatusOr<AnnotationOptions> FromJavaAnnotationOptions(JNIEnv* env,
                                                      jobject joptions) {
  if (!joptions) {
    // Falling back to default options in case joptions is null
    AnnotationOptions default_annotation_options;
    return default_annotation_options;
  }

  TC3_ASSIGN_OR_RETURN(
      ScopedLocalRef<jclass> options_class,
      JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
                                    "$AnnotationOptions"));

  // .getEntityTypes() — returns a Java String[].
  TC3_ASSIGN_OR_RETURN(
      jmethodID get_entity_types,
      JniHelper::GetMethodID(env, options_class.get(), "getEntityTypes",
                             "()[Ljava/lang/String;"));
  TC3_ASSIGN_OR_RETURN(
      ScopedLocalRef<jobject> entity_types,
      JniHelper::CallObjectMethod<jobject>(env, joptions, get_entity_types));

  // .isSerializedEntityDataEnabled()
  TC3_ASSIGN_OR_RETURN(
      jmethodID is_serialized_entity_data_enabled_method,
      JniHelper::GetMethodID(env, options_class.get(),
                             "isSerializedEntityDataEnabled", "()Z"));
  TC3_ASSIGN_OR_RETURN(
      bool is_serialized_entity_data_enabled,
      JniHelper::CallBooleanMethod(env, joptions,
                                   is_serialized_entity_data_enabled_method));

  // .hasLocationPermission()
  TC3_ASSIGN_OR_RETURN(jmethodID has_location_permission_method,
                       JniHelper::GetMethodID(env, options_class.get(),
                                              "hasLocationPermission", "()Z"));
  TC3_ASSIGN_OR_RETURN(bool has_location_permission,
                       JniHelper::CallBooleanMethod(
                           env, joptions, has_location_permission_method));

  // .hasPersonalizationPermission()
  TC3_ASSIGN_OR_RETURN(
      jmethodID has_personalization_permission_method,
      JniHelper::GetMethodID(env, options_class.get(),
                             "hasPersonalizationPermission", "()Z"));
  TC3_ASSIGN_OR_RETURN(
      bool has_personalization_permission,
      JniHelper::CallBooleanMethod(env, joptions,
                                   has_personalization_permission_method));

  // Fields shared with the other options classes.
  TC3_ASSIGN_OR_RETURN(
      AnnotationOptions annotation_options,
      FromJavaOptionsInternal<AnnotationOptions>(
          env, joptions,
          TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions"));
  TC3_ASSIGN_OR_RETURN(annotation_options.entity_types,
                       EntityTypesFromJObject(env, entity_types.get()));
  annotation_options.is_serialized_entity_data_enabled =
      is_serialized_entity_data_enabled;
  annotation_options.permissions.has_location_permission =
      has_location_permission;
  annotation_options.permissions.has_personalization_permission =
      has_personalization_permission;
  return annotation_options;
}
+
// Converts a Java InputFragment into the native InputFragment.
//
// Returns an INTERNAL error when `jfragment` is null. The fragment's
// `datetime_options` member is only populated when the Java side reports
// hasDatetimeOptions(); otherwise it is left unset. Any failing JNI call
// short-circuits with an error status.
StatusOr<InputFragment> FromJavaInputFragment(JNIEnv* env, jobject jfragment) {
  if (!jfragment) {
    return Status(StatusCode::INTERNAL, "Called with null input fragment.");
  }
  InputFragment fragment;

  TC3_ASSIGN_OR_RETURN(
      ScopedLocalRef<jclass> fragment_class,
      JniHelper::FindClass(
          env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$InputFragment"));

  // .getText()
  TC3_ASSIGN_OR_RETURN(
      jmethodID get_text,
      JniHelper::GetMethodID(env, fragment_class.get(), "getText",
                             "()Ljava/lang/String;"));

  TC3_ASSIGN_OR_RETURN(
      ScopedLocalRef<jstring> text,
      JniHelper::CallObjectMethod<jstring>(env, jfragment, get_text));

  TC3_ASSIGN_OR_RETURN(fragment.text, ToStlString(env, text.get()));

  // .hasDatetimeOptions()
  TC3_ASSIGN_OR_RETURN(jmethodID has_date_time_options_method,
                       JniHelper::GetMethodID(env, fragment_class.get(),
                                              "hasDatetimeOptions", "()Z"));

  TC3_ASSIGN_OR_RETURN(bool has_date_time_options,
                       JniHelper::CallBooleanMethod(
                           env, jfragment, has_date_time_options_method));

  if (has_date_time_options) {
    // .getReferenceTimeMsUtc()
    TC3_ASSIGN_OR_RETURN(
        jmethodID get_reference_time_method,
        JniHelper::GetMethodID(env, fragment_class.get(),
                               "getReferenceTimeMsUtc", "()J"));

    TC3_ASSIGN_OR_RETURN(
        int64 reference_time,
        JniHelper::CallLongMethod(env, jfragment, get_reference_time_method));

    // .getReferenceTimezone()
    TC3_ASSIGN_OR_RETURN(
        jmethodID get_reference_timezone_method,
        JniHelper::GetMethodID(env, fragment_class.get(),
                               "getReferenceTimezone", "()Ljava/lang/String;"));

    TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> jreference_timezone,
                         JniHelper::CallObjectMethod<jstring>(
                             env, jfragment, get_reference_timezone_method));

    TC3_ASSIGN_OR_RETURN(std::string reference_timezone,
                         ToStlString(env, jreference_timezone.get()));

    fragment.datetime_options =
        DatetimeOptions{.reference_time_ms_utc = reference_time,
                        .reference_timezone = reference_timezone};
  }

  return fragment;
}
+} // namespace libtextclassifier3
diff --git a/native/annotator/annotator_jni_common.h b/native/annotator/annotator_jni_common.h
new file mode 100644
index 0000000..cadd2fd
--- /dev/null
+++ b/native/annotator/annotator_jni_common.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_
+
+#include <jni.h>
+
+#include "annotator/annotator.h"
+#include "annotator/types.h"
+#include "utils/base/statusor.h"
+
+#ifndef TC3_ANNOTATOR_CLASS_NAME
+#define TC3_ANNOTATOR_CLASS_NAME AnnotatorModel
+#endif
+
+#define TC3_ANNOTATOR_CLASS_NAME_STR TC3_ADD_QUOTES(TC3_ANNOTATOR_CLASS_NAME)
+
+namespace libtextclassifier3 {
+
+StatusOr<SelectionOptions> FromJavaSelectionOptions(JNIEnv* env,
+ jobject joptions);
+
+StatusOr<ClassificationOptions> FromJavaClassificationOptions(JNIEnv* env,
+ jobject joptions);
+
+StatusOr<AnnotationOptions> FromJavaAnnotationOptions(JNIEnv* env,
+ jobject joptions);
+
+StatusOr<InputFragment> FromJavaInputFragment(JNIEnv* env, jobject jfragment);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_JNI_COMMON_H_
diff --git a/annotator/annotator_jni_test.cc b/native/annotator/annotator_jni_test.cc
similarity index 100%
rename from annotator/annotator_jni_test.cc
rename to native/annotator/annotator_jni_test.cc
diff --git a/annotator/cached-features.cc b/native/annotator/cached-features.cc
similarity index 100%
rename from annotator/cached-features.cc
rename to native/annotator/cached-features.cc
diff --git a/annotator/cached-features.h b/native/annotator/cached-features.h
similarity index 100%
rename from annotator/cached-features.h
rename to native/annotator/cached-features.h
diff --git a/annotator/cached-features_test.cc b/native/annotator/cached-features_test.cc
similarity index 100%
rename from annotator/cached-features_test.cc
rename to native/annotator/cached-features_test.cc
diff --git a/native/annotator/collections.h b/native/annotator/collections.h
new file mode 100644
index 0000000..417b447
--- /dev/null
+++ b/native/annotator/collections.h
@@ -0,0 +1,151 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_COLLECTIONS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_COLLECTIONS_H_
+
+#include <string>
+
+namespace libtextclassifier3 {
+
// String collection names for various classes.
//
// Each accessor returns a process-lifetime singleton string. The strings are
// deliberately heap-allocated and never destroyed, so references handed out
// here stay valid even during static destruction (no destruction-order
// issues). Initialization of each local static is thread-safe (C++11 magic
// statics).
class Collections {
 public:
  static const std::string& Address() {
    static const std::string* const kValue = new std::string("address");
    return *kValue;
  }
  static const std::string& App() {
    static const std::string* const kValue = new std::string("app");
    return *kValue;
  }
  static const std::string& Contact() {
    static const std::string* const kValue = new std::string("contact");
    return *kValue;
  }
  static const std::string& Date() {
    static const std::string* const kValue = new std::string("date");
    return *kValue;
  }
  static const std::string& DateTime() {
    static const std::string* const kValue = new std::string("datetime");
    return *kValue;
  }
  static const std::string& Dictionary() {
    static const std::string* const kValue = new std::string("dictionary");
    return *kValue;
  }
  static const std::string& Duration() {
    static const std::string* const kValue = new std::string("duration");
    return *kValue;
  }
  static const std::string& Email() {
    static const std::string* const kValue = new std::string("email");
    return *kValue;
  }
  static const std::string& Entity() {
    static const std::string* const kValue = new std::string("entity");
    return *kValue;
  }
  static const std::string& Flight() {
    static const std::string* const kValue = new std::string("flight");
    return *kValue;
  }
  static const std::string& Iban() {
    static const std::string* const kValue = new std::string("iban");
    return *kValue;
  }
  static const std::string& Isbn() {
    static const std::string* const kValue = new std::string("isbn");
    return *kValue;
  }
  static const std::string& Money() {
    static const std::string* const kValue = new std::string("money");
    return *kValue;
  }
  static const std::string& Unit() {
    static const std::string* const kValue = new std::string("unit");
    return *kValue;
  }
  static const std::string& Number() {
    static const std::string* const kValue = new std::string("number");
    return *kValue;
  }
  static const std::string& Other() {
    static const std::string* const kValue = new std::string("other");
    return *kValue;
  }
  static const std::string& PaymentCard() {
    static const std::string* const kValue = new std::string("payment_card");
    return *kValue;
  }
  static const std::string& Percentage() {
    static const std::string* const kValue = new std::string("percentage");
    return *kValue;
  }
  static const std::string& PersonName() {
    static const std::string* const kValue = new std::string("person_name");
    return *kValue;
  }
  static const std::string& Phone() {
    static const std::string* const kValue = new std::string("phone");
    return *kValue;
  }
  static const std::string& TrackingNumber() {
    static const std::string* const kValue = new std::string("tracking_number");
    return *kValue;
  }
  static const std::string& Translate() {
    static const std::string* const kValue = new std::string("translate");
    return *kValue;
  }
  static const std::string& Url() {
    static const std::string* const kValue = new std::string("url");
    return *kValue;
  }
  static const std::string& OtpCode() {
    static const std::string* const kValue = new std::string("otp_code");
    return *kValue;
  }
};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_COLLECTIONS_H_
diff --git a/native/annotator/contact/contact-engine-dummy.h b/native/annotator/contact/contact-engine-dummy.h
new file mode 100644
index 0000000..fe60203
--- /dev/null
+++ b/native/annotator/contact/contact-engine-dummy.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_DUMMY_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_DUMMY_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/feature-processor.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
// A dummy implementation of the contact engine.
//
// Compiled in when the real contact engine is unavailable. All entry points
// are inert: Initialize() always fails, ClassifyText() never produces a
// classification, and Chunk() succeeds without emitting any spans.
class ContactEngine {
 public:
  // The real engine consumes these dependencies; the dummy ignores them.
  explicit ContactEngine(const FeatureProcessor* feature_processor,
                         const UniLib* unilib,
                         const ContactAnnotatorOptions* options) {}

  // Always returns false: there is no engine to configure.
  bool Initialize(const std::string& serialized_config) {
    TC3_LOG(ERROR) << "No contact engine to initialize.";
    return false;
  }

  // Never classifies; returns false and leaves `classification_result`
  // untouched.
  bool ClassifyText(const std::string& context, CodepointSpan selection_indices,
                    ClassificationResult* classification_result) const {
    return false;
  }

  // Reports success without appending any spans to `result`.
  bool Chunk(const UnicodeText& context_unicode,
             const std::vector<Token>& tokens,
             std::vector<AnnotatedSpan>* result) const {
    return true;
  }

  // No-op in the dummy implementation.
  void AddContactMetadataToKnowledgeClassificationResult(
      ClassificationResult* classification_result) const {}
};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_DUMMY_H_
diff --git a/annotator/contact/contact-engine.h b/native/annotator/contact/contact-engine.h
similarity index 100%
rename from annotator/contact/contact-engine.h
rename to native/annotator/contact/contact-engine.h
diff --git a/native/annotator/datetime/extractor.cc b/native/annotator/datetime/extractor.cc
new file mode 100644
index 0000000..b8e1b7a
--- /dev/null
+++ b/native/annotator/datetime/extractor.cc
@@ -0,0 +1,537 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/extractor.h"
+
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+bool DatetimeExtractor::Extract(DatetimeParsedData* result,
+ CodepointSpan* result_span) const {
+ *result_span = {kInvalidIndex, kInvalidIndex};
+
+ if (rule_.regex->groups() == nullptr) {
+ return false;
+ }
+
+ // In the current implementation of extractor, the assumption is that there
+ // can only be one relative field.
+ DatetimeComponent::ComponentType component_type;
+ DatetimeComponent::RelativeQualifier relative_qualifier =
+ DatetimeComponent::RelativeQualifier::UNSPECIFIED;
+ int relative_count = 0;
+
+ for (int group_id = 0; group_id < rule_.regex->groups()->size(); group_id++) {
+ UnicodeText group_text;
+ const int group_type = rule_.regex->groups()->Get(group_id);
+ if (group_type == DatetimeGroupType_GROUP_UNUSED) {
+ continue;
+ }
+ if (!GroupTextFromMatch(group_id, &group_text)) {
+ TC3_LOG(ERROR) << "Couldn't retrieve group.";
+ return false;
+ }
+ // The pattern can have a group defined in a part that was not matched,
+ // e.g. an optional part. In this case we'll get an empty content here.
+ if (group_text.empty()) {
+ continue;
+ }
+
+ switch (group_type) {
+ case DatetimeGroupType_GROUP_YEAR: {
+ int year;
+ if (!ParseYear(group_text, &(year))) {
+ TC3_LOG(ERROR) << "Couldn't extract YEAR.";
+ return false;
+ }
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::YEAR, year);
+ break;
+ }
+ case DatetimeGroupType_GROUP_MONTH: {
+ int month;
+ if (!ParseMonth(group_text, &(month))) {
+ TC3_LOG(ERROR) << "Couldn't extract MONTH.";
+ return false;
+ }
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::MONTH,
+ month);
+ break;
+ }
+ case DatetimeGroupType_GROUP_DAY: {
+ int day_of_month;
+ if (!ParseDigits(group_text, &(day_of_month))) {
+ TC3_LOG(ERROR) << "Couldn't extract DAY.";
+ return false;
+ }
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_MONTH,
+ day_of_month);
+ break;
+ }
+ case DatetimeGroupType_GROUP_HOUR: {
+ int hour;
+ if (!ParseDigits(group_text, &(hour))) {
+ TC3_LOG(ERROR) << "Couldn't extract HOUR.";
+ return false;
+ }
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, hour);
+ break;
+ }
+ case DatetimeGroupType_GROUP_MINUTE: {
+ int minute;
+ if (!ParseDigits(group_text, &(minute)) &&
+ !ParseWrittenNumber(group_text, &(minute))) {
+ TC3_LOG(ERROR) << "Couldn't extract MINUTE.";
+ return false;
+ }
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE,
+ minute);
+ break;
+ }
+ case DatetimeGroupType_GROUP_SECOND: {
+ int second;
+ if (!ParseDigits(group_text, &(second))) {
+ TC3_LOG(ERROR) << "Couldn't extract SECOND.";
+ return false;
+ }
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::SECOND,
+ second);
+ break;
+ }
+ case DatetimeGroupType_GROUP_AMPM: {
+ int meridiem;
+ if (!ParseMeridiem(group_text, &(meridiem))) {
+ TC3_LOG(ERROR) << "Couldn't extract AMPM.";
+ return false;
+ }
+ result->SetAbsoluteValue(DatetimeComponent::ComponentType::MERIDIEM,
+ meridiem);
+ break;
+ }
+ case DatetimeGroupType_GROUP_RELATIONDISTANCE: {
+ relative_count = 0;
+ if (!ParseRelationDistance(group_text, &(relative_count))) {
+ TC3_LOG(ERROR) << "Couldn't extract RELATION_DISTANCE_FIELD.";
+ return false;
+ }
+ break;
+ }
+ case DatetimeGroupType_GROUP_RELATION: {
+ if (!ParseRelativeValue(group_text, &relative_qualifier)) {
+ TC3_LOG(ERROR) << "Couldn't extract RELATION_FIELD.";
+ return false;
+ }
+ ParseRelationAndConvertToRelativeCount(group_text, &relative_count);
+ if (relative_qualifier ==
+ DatetimeComponent::RelativeQualifier::TOMORROW ||
+ relative_qualifier == DatetimeComponent::RelativeQualifier::NOW ||
+ relative_qualifier ==
+ DatetimeComponent::RelativeQualifier::YESTERDAY) {
+ if (!ParseFieldType(group_text, &component_type)) {
+ TC3_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
+ return false;
+ }
+ }
+ break;
+ }
+ case DatetimeGroupType_GROUP_RELATIONTYPE: {
+ if (!ParseFieldType(group_text, &component_type)) {
+ TC3_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
+ return false;
+ }
+ if (component_type == DatetimeComponent::ComponentType::DAY_OF_WEEK) {
+ int day_of_week;
+ if (!ParseDayOfWeek(group_text, &day_of_week)) {
+ TC3_LOG(ERROR) << "Couldn't extract RELATION_TYPE_FIELD.";
+ return false;
+ }
+ result->SetAbsoluteValue(component_type, day_of_week);
+ }
+ break;
+ }
+ case DatetimeGroupType_GROUP_DUMMY1:
+ case DatetimeGroupType_GROUP_DUMMY2:
+ break;
+ default:
+ TC3_LOG(INFO) << "Unknown group type.";
+ continue;
+ }
+ if (!UpdateMatchSpan(group_id, result_span)) {
+ TC3_LOG(ERROR) << "Couldn't update span.";
+ return false;
+ }
+ }
+
+ if (relative_qualifier != DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
+ result->SetRelativeValue(component_type, relative_qualifier);
+ result->SetRelativeCount(component_type, relative_count);
+ }
+
+ if (result_span->first == kInvalidIndex ||
+ result_span->second == kInvalidIndex) {
+ *result_span = {kInvalidIndex, kInvalidIndex};
+ }
+
+ return true;
+}
+
+bool DatetimeExtractor::RuleIdForType(DatetimeExtractorType type,
+ int* rule_id) const {
+ auto type_it = type_and_locale_to_rule_.find(type);
+ if (type_it == type_and_locale_to_rule_.end()) {
+ return false;
+ }
+
+ auto locale_it = type_it->second.find(locale_id_);
+ if (locale_it == type_it->second.end()) {
+ return false;
+ }
+ *rule_id = locale_it->second;
+ return true;
+}
+
+bool DatetimeExtractor::ExtractType(const UnicodeText& input,
+ DatetimeExtractorType extractor_type,
+ UnicodeText* match_result) const {
+ int rule_id;
+ if (!RuleIdForType(extractor_type, &rule_id)) {
+ return false;
+ }
+
+ std::unique_ptr<UniLib::RegexMatcher> matcher =
+ rules_[rule_id]->Matcher(input);
+ if (!matcher) {
+ return false;
+ }
+
+ int status;
+ if (!matcher->Find(&status)) {
+ return false;
+ }
+
+ if (match_result != nullptr) {
+ *match_result = matcher->Group(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool DatetimeExtractor::GroupTextFromMatch(int group_id,
+ UnicodeText* result) const {
+ int status;
+ *result = matcher_.Group(group_id, &status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+ return true;
+}
+
+bool DatetimeExtractor::UpdateMatchSpan(int group_id,
+ CodepointSpan* span) const {
+ int status;
+ const int match_start = matcher_.Start(group_id, &status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+ const int match_end = matcher_.End(group_id, &status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+ if (span->first == kInvalidIndex || span->first > match_start) {
+ span->first = match_start;
+ }
+ if (span->second == kInvalidIndex || span->second < match_end) {
+ span->second = match_end;
+ }
+
+ return true;
+}
+
+template <typename T>
+bool DatetimeExtractor::MapInput(
+ const UnicodeText& input,
+ const std::vector<std::pair<DatetimeExtractorType, T>>& mapping,
+ T* result) const {
+ for (const auto& type_value_pair : mapping) {
+ if (ExtractType(input, type_value_pair.first)) {
+ *result = type_value_pair.second;
+ return true;
+ }
+ }
+ return false;
+}
+
+bool DatetimeExtractor::ParseWrittenNumber(const UnicodeText& input,
+ int* parsed_number) const {
+ std::vector<std::pair<int, int>> found_numbers;
+ for (const auto& type_value_pair :
+ std::vector<std::pair<DatetimeExtractorType, int>>{
+ {DatetimeExtractorType_ZERO, 0},
+ {DatetimeExtractorType_ONE, 1},
+ {DatetimeExtractorType_TWO, 2},
+ {DatetimeExtractorType_THREE, 3},
+ {DatetimeExtractorType_FOUR, 4},
+ {DatetimeExtractorType_FIVE, 5},
+ {DatetimeExtractorType_SIX, 6},
+ {DatetimeExtractorType_SEVEN, 7},
+ {DatetimeExtractorType_EIGHT, 8},
+ {DatetimeExtractorType_NINE, 9},
+ {DatetimeExtractorType_TEN, 10},
+ {DatetimeExtractorType_ELEVEN, 11},
+ {DatetimeExtractorType_TWELVE, 12},
+ {DatetimeExtractorType_THIRTEEN, 13},
+ {DatetimeExtractorType_FOURTEEN, 14},
+ {DatetimeExtractorType_FIFTEEN, 15},
+ {DatetimeExtractorType_SIXTEEN, 16},
+ {DatetimeExtractorType_SEVENTEEN, 17},
+ {DatetimeExtractorType_EIGHTEEN, 18},
+ {DatetimeExtractorType_NINETEEN, 19},
+ {DatetimeExtractorType_TWENTY, 20},
+ {DatetimeExtractorType_THIRTY, 30},
+ {DatetimeExtractorType_FORTY, 40},
+ {DatetimeExtractorType_FIFTY, 50},
+ {DatetimeExtractorType_SIXTY, 60},
+ {DatetimeExtractorType_SEVENTY, 70},
+ {DatetimeExtractorType_EIGHTY, 80},
+ {DatetimeExtractorType_NINETY, 90},
+ {DatetimeExtractorType_HUNDRED, 100},
+ {DatetimeExtractorType_THOUSAND, 1000},
+ }) {
+ int rule_id;
+ if (!RuleIdForType(type_value_pair.first, &rule_id)) {
+ return false;
+ }
+
+ std::unique_ptr<UniLib::RegexMatcher> matcher =
+ rules_[rule_id]->Matcher(input);
+ if (!matcher) {
+ return false;
+ }
+ int status;
+ while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+ int span_start = matcher->Start(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+ found_numbers.push_back({span_start, type_value_pair.second});
+ }
+ }
+
+ std::sort(found_numbers.begin(), found_numbers.end(),
+ [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
+ return a.first < b.first;
+ });
+
+ int sum = 0;
+ int running_value = -1;
+ // Simple math to make sure we handle written numerical modifiers correctly
+ // so that :="fifty one thousand and one" maps to 51001 and not 50 1 1000 1.
+ for (const std::pair<int, int>& position_number_pair : found_numbers) {
+ if (running_value >= 0) {
+ if (running_value > position_number_pair.second) {
+ sum += running_value;
+ running_value = position_number_pair.second;
+ } else {
+ running_value *= position_number_pair.second;
+ }
+ } else {
+ running_value = position_number_pair.second;
+ }
+ }
+ sum += running_value;
+ *parsed_number = sum;
+ return true;
+}
+
+bool DatetimeExtractor::ParseDigits(const UnicodeText& input,
+ int* parsed_digits) const {
+ UnicodeText digit;
+ if (!ExtractType(input, DatetimeExtractorType_DIGITS, &digit)) {
+ return false;
+ }
+
+ if (!unilib_.ParseInt32(digit, parsed_digits)) {
+ return false;
+ }
+ return true;
+}
+
+bool DatetimeExtractor::ParseYear(const UnicodeText& input,
+ int* parsed_year) const {
+ if (!ParseDigits(input, parsed_year)) {
+ return false;
+ }
+
+ // Logic to decide if XX will be 20XX or 19XX
+ if (*parsed_year < 100) {
+ if (*parsed_year < 50) {
+ *parsed_year += 2000;
+ } else {
+ *parsed_year += 1900;
+ }
+ }
+
+ return true;
+}
+
+bool DatetimeExtractor::ParseMonth(const UnicodeText& input,
+ int* parsed_month) const {
+ if (ParseDigits(input, parsed_month)) {
+ return true;
+ }
+
+ if (MapInput(input,
+ {
+ {DatetimeExtractorType_JANUARY, 1},
+ {DatetimeExtractorType_FEBRUARY, 2},
+ {DatetimeExtractorType_MARCH, 3},
+ {DatetimeExtractorType_APRIL, 4},
+ {DatetimeExtractorType_MAY, 5},
+ {DatetimeExtractorType_JUNE, 6},
+ {DatetimeExtractorType_JULY, 7},
+ {DatetimeExtractorType_AUGUST, 8},
+ {DatetimeExtractorType_SEPTEMBER, 9},
+ {DatetimeExtractorType_OCTOBER, 10},
+ {DatetimeExtractorType_NOVEMBER, 11},
+ {DatetimeExtractorType_DECEMBER, 12},
+ },
+ parsed_month)) {
+ return true;
+ }
+
+ return false;
+}
+
+bool DatetimeExtractor::ParseMeridiem(const UnicodeText& input,
+ int* parsed_meridiem) const {
+ return MapInput(input,
+ {
+ {DatetimeExtractorType_AM, 0 /* AM */},
+ {DatetimeExtractorType_PM, 1 /* PM */},
+ },
+ parsed_meridiem);
+}
+
+bool DatetimeExtractor::ParseRelationDistance(const UnicodeText& input,
+ int* parsed_distance) const {
+ if (ParseDigits(input, parsed_distance)) {
+ return true;
+ }
+ if (ParseWrittenNumber(input, parsed_distance)) {
+ return true;
+ }
+ return false;
+}
+
+bool DatetimeExtractor::ParseRelativeValue(
+ const UnicodeText& input,
+ DatetimeComponent::RelativeQualifier* parsed_relative_value) const {
+ return MapInput(input,
+ {
+ {DatetimeExtractorType_NOW,
+ DatetimeComponent::RelativeQualifier::NOW},
+ {DatetimeExtractorType_YESTERDAY,
+ DatetimeComponent::RelativeQualifier::YESTERDAY},
+ {DatetimeExtractorType_TOMORROW,
+ DatetimeComponent::RelativeQualifier::TOMORROW},
+ {DatetimeExtractorType_NEXT,
+ DatetimeComponent::RelativeQualifier::NEXT},
+ {DatetimeExtractorType_NEXT_OR_SAME,
+ DatetimeComponent::RelativeQualifier::THIS},
+ {DatetimeExtractorType_LAST,
+ DatetimeComponent::RelativeQualifier::LAST},
+ {DatetimeExtractorType_PAST,
+ DatetimeComponent::RelativeQualifier::PAST},
+ {DatetimeExtractorType_FUTURE,
+ DatetimeComponent::RelativeQualifier::FUTURE},
+ },
+ parsed_relative_value);
+}
+
+bool DatetimeExtractor::ParseRelationAndConvertToRelativeCount(
+ const UnicodeText& input, int* relative_count) const {
+ return MapInput(input,
+ {
+ {DatetimeExtractorType_NOW, 0},
+ {DatetimeExtractorType_YESTERDAY, -1},
+ {DatetimeExtractorType_TOMORROW, 1},
+ {DatetimeExtractorType_NEXT, 1},
+ {DatetimeExtractorType_NEXT_OR_SAME, 1},
+ {DatetimeExtractorType_LAST, -1},
+ {DatetimeExtractorType_PAST, -1},
+ },
+ relative_count);
+}
+
+bool DatetimeExtractor::ParseDayOfWeek(const UnicodeText& input,
+ int* parsed_day_of_week) const {
+ return MapInput(input,
+ {
+ {DatetimeExtractorType_SUNDAY, kSunday},
+ {DatetimeExtractorType_MONDAY, kMonday},
+ {DatetimeExtractorType_TUESDAY, kTuesday},
+ {DatetimeExtractorType_WEDNESDAY, kWednesday},
+ {DatetimeExtractorType_THURSDAY, kThursday},
+ {DatetimeExtractorType_FRIDAY, kFriday},
+ {DatetimeExtractorType_SATURDAY, kSaturday},
+ },
+ parsed_day_of_week);
+}
+
+bool DatetimeExtractor::ParseFieldType(
+ const UnicodeText& input,
+ DatetimeComponent::ComponentType* parsed_field_type) const {
+ return MapInput(
+ input,
+ {
+ {DatetimeExtractorType_MONDAY,
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
+ {DatetimeExtractorType_TUESDAY,
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
+ {DatetimeExtractorType_WEDNESDAY,
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
+ {DatetimeExtractorType_THURSDAY,
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
+ {DatetimeExtractorType_FRIDAY,
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
+ {DatetimeExtractorType_SATURDAY,
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
+ {DatetimeExtractorType_SUNDAY,
+ DatetimeComponent::ComponentType::DAY_OF_WEEK},
+ {DatetimeExtractorType_SECONDS,
+ DatetimeComponent::ComponentType::SECOND},
+ {DatetimeExtractorType_MINUTES,
+ DatetimeComponent::ComponentType::MINUTE},
+ {DatetimeExtractorType_NOW,
+ DatetimeComponent::ComponentType::DAY_OF_MONTH},
+ {DatetimeExtractorType_HOURS, DatetimeComponent::ComponentType::HOUR},
+ {DatetimeExtractorType_DAY,
+ DatetimeComponent::ComponentType::DAY_OF_MONTH},
+ {DatetimeExtractorType_TOMORROW,
+ DatetimeComponent::ComponentType::DAY_OF_MONTH},
+ {DatetimeExtractorType_YESTERDAY,
+ DatetimeComponent::ComponentType::DAY_OF_MONTH},
+ {DatetimeExtractorType_WEEK, DatetimeComponent::ComponentType::WEEK},
+ {DatetimeExtractorType_MONTH,
+ DatetimeComponent::ComponentType::MONTH},
+ {DatetimeExtractorType_YEAR, DatetimeComponent::ComponentType::YEAR},
+ },
+ parsed_field_type);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/extractor.h b/native/annotator/datetime/extractor.h
new file mode 100644
index 0000000..0f92b2a
--- /dev/null
+++ b/native/annotator/datetime/extractor.h
@@ -0,0 +1,113 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+struct CompiledRule {
+ // The compiled regular expression.
+ std::unique_ptr<const UniLib::RegexPattern> compiled_regex;
+
+ // The uncompiled pattern and information about the pattern groups.
+ const DatetimeModelPattern_::Regex* regex;
+
+ // DatetimeModelPattern which 'regex' is part of and comes from.
+ const DatetimeModelPattern* pattern;
+};
+
+// A helper class for DatetimeParser that extracts structured data
+// (DateParseDate) from the current match of the passed RegexMatcher.
+class DatetimeExtractor {
+ public:
+ explicit DatetimeExtractor(
+ const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+ int locale_id, const UniLib* unilib,
+ const std::vector<std::unique_ptr<const UniLib::RegexPattern>>&
+ extractor_rules,
+ const std::unordered_map<DatetimeExtractorType,
+ std::unordered_map<int, int>>&
+ type_and_locale_to_extractor_rule)
+ : rule_(rule),
+ matcher_(matcher),
+ locale_id_(locale_id),
+ unilib_(*unilib),
+ rules_(extractor_rules),
+ type_and_locale_to_rule_(type_and_locale_to_extractor_rule) {}
+ bool Extract(DatetimeParsedData* result, CodepointSpan* result_span) const;
+
+ private:
+ bool RuleIdForType(DatetimeExtractorType type, int* rule_id) const;
+
+ // Returns true if the rule for given extractor matched. If it matched,
+ // match_result will contain the first group of the rule (if match_result not
+ // nullptr).
+ bool ExtractType(const UnicodeText& input,
+ DatetimeExtractorType extractor_type,
+ UnicodeText* match_result = nullptr) const;
+
+ bool GroupTextFromMatch(int group_id, UnicodeText* result) const;
+
+ // Updates the span to include the current match for the given group.
+ bool UpdateMatchSpan(int group_id, CodepointSpan* span) const;
+
+ // Returns true if any of the extractors from 'mapping' matched. If it did,
+ // will fill 'result' with the associated value from 'mapping'.
+ template <typename T>
+ bool MapInput(const UnicodeText& input,
+ const std::vector<std::pair<DatetimeExtractorType, T>>& mapping,
+ T* result) const;
+
+ bool ParseDigits(const UnicodeText& input, int* parsed_digits) const;
+ bool ParseWrittenNumber(const UnicodeText& input, int* parsed_number) const;
+ bool ParseYear(const UnicodeText& input, int* parsed_year) const;
+ bool ParseMonth(const UnicodeText& input, int* parsed_month) const;
+ bool ParseMeridiem(const UnicodeText& input, int* parsed_meridiem) const;
+ bool ParseRelativeValue(
+ const UnicodeText& input,
+ DatetimeComponent::RelativeQualifier* parsed_relative_value) const;
+ bool ParseRelationDistance(const UnicodeText& input,
+ int* parsed_distance) const;
+ bool ParseFieldType(
+ const UnicodeText& input,
+ DatetimeComponent::ComponentType* parsed_field_type) const;
+ bool ParseDayOfWeek(const UnicodeText& input, int* parsed_day_of_week) const;
+ bool ParseRelationAndConvertToRelativeCount(const UnicodeText& input,
+ int* relative_count) const;
+
+ const CompiledRule& rule_;
+ const UniLib::RegexMatcher& matcher_;
+ int locale_id_;
+ const UniLib& unilib_;
+ const std::vector<std::unique_ptr<const UniLib::RegexPattern>>& rules_;
+ const std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>&
+ type_and_locale_to_rule_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_EXTRACTOR_H_
diff --git a/native/annotator/datetime/parser.cc b/native/annotator/datetime/parser.cc
new file mode 100644
index 0000000..72fd3ab
--- /dev/null
+++ b/native/annotator/datetime/parser.cc
@@ -0,0 +1,404 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/parser.h"
+
+#include <set>
+#include <unordered_set>
+
+#include "annotator/datetime/extractor.h"
+#include "annotator/datetime/utils.h"
+#include "utils/calendar/calendar.h"
+#include "utils/i18n/locale.h"
+#include "utils/strings/split.h"
+#include "utils/zlib/zlib_regex.h"
+
+namespace libtextclassifier3 {
+std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
+ const DatetimeModel* model, const UniLib* unilib,
+ const CalendarLib* calendarlib, ZlibDecompressor* decompressor) {
+ std::unique_ptr<DatetimeParser> result(
+ new DatetimeParser(model, unilib, calendarlib, decompressor));
+ if (!result->initialized_) {
+ result.reset();
+ }
+ return result;
+}
+
+DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib* unilib,
+ const CalendarLib* calendarlib,
+ ZlibDecompressor* decompressor)
+ : unilib_(*unilib), calendarlib_(*calendarlib) {
+ initialized_ = false;
+
+ if (model == nullptr) {
+ return;
+ }
+
+ if (model->patterns() != nullptr) {
+ for (const DatetimeModelPattern* pattern : *model->patterns()) {
+ if (pattern->regexes()) {
+ for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
+ std::unique_ptr<UniLib::RegexPattern> regex_pattern =
+ UncompressMakeRegexPattern(
+ unilib_, regex->pattern(), regex->compressed_pattern(),
+ model->lazy_regex_compilation(), decompressor);
+ if (!regex_pattern) {
+ TC3_LOG(ERROR) << "Couldn't create rule pattern.";
+ return;
+ }
+ rules_.push_back({std::move(regex_pattern), regex, pattern});
+ if (pattern->locales()) {
+ for (int locale : *pattern->locales()) {
+ locale_to_rules_[locale].push_back(rules_.size() - 1);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if (model->extractors() != nullptr) {
+ for (const DatetimeModelExtractor* extractor : *model->extractors()) {
+ std::unique_ptr<UniLib::RegexPattern> regex_pattern =
+ UncompressMakeRegexPattern(
+ unilib_, extractor->pattern(), extractor->compressed_pattern(),
+ model->lazy_regex_compilation(), decompressor);
+ if (!regex_pattern) {
+ TC3_LOG(ERROR) << "Couldn't create extractor pattern";
+ return;
+ }
+ extractor_rules_.push_back(std::move(regex_pattern));
+
+ if (extractor->locales()) {
+ for (int locale : *extractor->locales()) {
+ type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
+ extractor_rules_.size() - 1;
+ }
+ }
+ }
+ }
+
+ if (model->locales() != nullptr) {
+ for (int i = 0; i < model->locales()->size(); ++i) {
+ locale_string_to_id_[model->locales()->Get(i)->str()] = i;
+ }
+ }
+
+ if (model->default_locales() != nullptr) {
+ for (const int locale : *model->default_locales()) {
+ default_locale_ids_.push_back(locale);
+ }
+ }
+
+ use_extractors_for_locating_ = model->use_extractors_for_locating();
+ generate_alternative_interpretations_when_ambiguous_ =
+ model->generate_alternative_interpretations_when_ambiguous();
+ prefer_future_for_unspecified_date_ =
+ model->prefer_future_for_unspecified_date();
+
+ initialized_ = true;
+}
+
+bool DatetimeParser::Parse(
+ const std::string& input, const int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& locales,
+ ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* results) const {
+ return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
+ reference_time_ms_utc, reference_timezone, locales, mode,
+ annotation_usecase, anchor_start_end, results);
+}
+
+bool DatetimeParser::FindSpansUsingLocales(
+ const std::vector<int>& locale_ids, const UnicodeText& input,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
+ const std::string& reference_locale,
+ std::unordered_set<int>* executed_rules,
+ std::vector<DatetimeParseResultSpan>* found_spans) const {
+ for (const int locale_id : locale_ids) {
+ auto rules_it = locale_to_rules_.find(locale_id);
+ if (rules_it == locale_to_rules_.end()) {
+ continue;
+ }
+
+ for (const int rule_id : rules_it->second) {
+ // Skip rules that were already executed in previous locales.
+ if (executed_rules->find(rule_id) != executed_rules->end()) {
+ continue;
+ }
+
+ if ((rules_[rule_id].pattern->enabled_annotation_usecases() &
+ (1 << annotation_usecase)) == 0) {
+ continue;
+ }
+
+ if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
+ continue;
+ }
+
+ executed_rules->insert(rule_id);
+
+ if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id,
+ anchor_start_end, found_spans)) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+bool DatetimeParser::Parse(
+ const UnicodeText& input, const int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& locales,
+ ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* results) const {
+ std::vector<DatetimeParseResultSpan> found_spans;
+ std::unordered_set<int> executed_rules;
+ std::string reference_locale;
+ const std::vector<int> requested_locales =
+ ParseAndExpandLocales(locales, &reference_locale);
+ if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
+ reference_timezone, mode, annotation_usecase,
+ anchor_start_end, reference_locale,
+ &executed_rules, &found_spans)) {
+ return false;
+ }
+
+ std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
+ indexed_found_spans.reserve(found_spans.size());
+ for (int i = 0; i < found_spans.size(); i++) {
+ indexed_found_spans.push_back({found_spans[i], i});
+ }
+
+ // Resolve conflicts by always picking the longer span and breaking ties by
+ // selecting the earlier entry in the list for a given locale.
+ std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
+ [](const std::pair<DatetimeParseResultSpan, int>& a,
+ const std::pair<DatetimeParseResultSpan, int>& b) {
+ if ((a.first.span.second - a.first.span.first) !=
+ (b.first.span.second - b.first.span.first)) {
+ return (a.first.span.second - a.first.span.first) >
+ (b.first.span.second - b.first.span.first);
+ } else {
+ return a.second < b.second;
+ }
+ });
+
+ found_spans.clear();
+ for (auto& span_index_pair : indexed_found_spans) {
+ found_spans.push_back(span_index_pair.first);
+ }
+
+ std::set<int, std::function<bool(int, int)>> chosen_indices_set(
+ [&found_spans](int a, int b) {
+ return found_spans[a].span.first < found_spans[b].span.first;
+ });
+ for (int i = 0; i < found_spans.size(); ++i) {
+ if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
+ chosen_indices_set.insert(i);
+ results->push_back(found_spans[i]);
+ }
+ }
+
+ return true;
+}
+
+bool DatetimeParser::HandleParseMatch(
+ const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
+ std::vector<DatetimeParseResultSpan>* result) const {
+ int status = UniLib::RegexMatcher::kNoError;
+ const int start = matcher.Start(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+
+ const int end = matcher.End(&status);
+ if (status != UniLib::RegexMatcher::kNoError) {
+ return false;
+ }
+
+ DatetimeParseResultSpan parse_result;
+ std::vector<DatetimeParseResult> alternatives;
+ if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
+ reference_locale, locale_id, &alternatives,
+ &parse_result.span)) {
+ return false;
+ }
+
+ if (!use_extractors_for_locating_) {
+ parse_result.span = {start, end};
+ }
+
+ if (parse_result.span.first != kInvalidIndex &&
+ parse_result.span.second != kInvalidIndex) {
+ parse_result.target_classification_score =
+ rule.pattern->target_classification_score();
+ parse_result.priority_score = rule.pattern->priority_score();
+
+ for (DatetimeParseResult& alternative : alternatives) {
+ parse_result.data.push_back(alternative);
+ }
+ }
+ result->push_back(parse_result);
+ return true;
+}
+
+bool DatetimeParser::ParseWithRule(
+ const CompiledRule& rule, const UnicodeText& input,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale, const int locale_id,
+ bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
+ std::unique_ptr<UniLib::RegexMatcher> matcher =
+ rule.compiled_regex->Matcher(input);
+ int status = UniLib::RegexMatcher::kNoError;
+ if (anchor_start_end) {
+ if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
+ if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id,
+ result)) {
+ return false;
+ }
+ }
+ } else {
+ while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
+ if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id,
+ result)) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+std::vector<int> DatetimeParser::ParseAndExpandLocales(
+ const std::string& locales, std::string* reference_locale) const {
+ std::vector<StringPiece> split_locales = strings::Split(locales, ',');
+ if (!split_locales.empty()) {
+ *reference_locale = split_locales[0].ToString();
+ } else {
+ *reference_locale = "";
+ }
+
+ std::vector<int> result;
+ for (const StringPiece& locale_str : split_locales) {
+ auto locale_it = locale_string_to_id_.find(locale_str.ToString());
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+
+ const Locale locale = Locale::FromBCP47(locale_str.ToString());
+ if (!locale.IsValid()) {
+ continue;
+ }
+
+ const std::string language = locale.Language();
+ const std::string script = locale.Script();
+ const std::string region = locale.Region();
+
+ // First, try adding *-region locale.
+ if (!region.empty()) {
+ locale_it = locale_string_to_id_.find("*-" + region);
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
+ // Second, try adding language-script-* locale.
+ if (!script.empty()) {
+ locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
+ // Third, try adding language-* locale.
+ if (!language.empty()) {
+ locale_it = locale_string_to_id_.find(language + "-*");
+ if (locale_it != locale_string_to_id_.end()) {
+ result.push_back(locale_it->second);
+ }
+ }
+ }
+
+ // Add the default locales if they haven't been added already.
+ const std::unordered_set<int> result_set(result.begin(), result.end());
+ for (const int default_locale_id : default_locale_ids_) {
+ if (result_set.find(default_locale_id) == result_set.end()) {
+ result.push_back(default_locale_id);
+ }
+ }
+
+ return result;
+}
+
+bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ const int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale,
+ int locale_id,
+ std::vector<DatetimeParseResult>* results,
+ CodepointSpan* result_span) const {
+ DatetimeParsedData parse;
+ DatetimeExtractor extractor(rule, matcher, locale_id, &unilib_,
+ extractor_rules_,
+ type_and_locale_to_extractor_rule_);
+ if (!extractor.Extract(&parse, result_span)) {
+ return false;
+ }
+ std::vector<DatetimeParsedData> interpretations;
+ if (generate_alternative_interpretations_when_ambiguous_) {
+ FillInterpretations(parse, calendarlib_.GetGranularity(parse),
+ &interpretations);
+ } else {
+ interpretations.push_back(parse);
+ }
+
+ results->reserve(results->size() + interpretations.size());
+ for (const DatetimeParsedData& interpretation : interpretations) {
+ std::vector<DatetimeComponent> date_components;
+ interpretation.GetDatetimeComponents(&date_components);
+ DatetimeParseResult result;
+ // TODO(hassan): Text classifier only provides ambiguity limited to “AM/PM”
+ // which is encoded in the pair of DatetimeParseResult; both
+ // corresponding to the same date, but one corresponding to
+ // “AM” and the other one corresponding to “PM”.
+ // Remove multiple DatetimeParseResult per datetime span,
+ // once the ambiguities/DatetimeComponents are added in the
+ // response. For Details see b/130355975
+ if (!calendarlib_.InterpretParseData(
+ interpretation, reference_time_ms_utc, reference_timezone,
+ reference_locale, prefer_future_for_unspecified_date_,
+ &(result.time_ms_utc), &(result.granularity))) {
+ return false;
+ }
+
+ // Sort the date time units by component type.
+ std::sort(date_components.begin(), date_components.end(),
+ [](DatetimeComponent a, DatetimeComponent b) {
+ return a.component_type > b.component_type;
+ });
+ result.datetime_components.swap(date_components);
+ results->push_back(result);
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/parser.h b/native/annotator/datetime/parser.h
new file mode 100644
index 0000000..8b58388
--- /dev/null
+++ b/native/annotator/datetime/parser.h
@@ -0,0 +1,124 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/datetime/extractor.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/calendar/calendar.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Parses datetime expressions in the input and resolves them to actual absolute
+// time.
+class DatetimeParser {
+ public:
+ static std::unique_ptr<DatetimeParser> Instance(
+ const DatetimeModel* model, const UniLib* unilib,
+ const CalendarLib* calendarlib, ZlibDecompressor* decompressor);
+
+ // Parses the dates in 'input' and fills result. Makes sure that the results
+ // do not overlap.
+ // If 'anchor_start_end' is true the extracted results need to start at the
+ // beginning of 'input' and end at the end of it.
+ bool Parse(const std::string& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& locales,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* results) const;
+
+ // Same as above but takes UnicodeText.
+ bool Parse(const UnicodeText& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& locales,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* results) const;
+
+ protected:
+ explicit DatetimeParser(const DatetimeModel* model, const UniLib* unilib,
+ const CalendarLib* calendarlib,
+ ZlibDecompressor* decompressor);
+
+ // Returns a list of locale ids for given locale spec string (comma-separated
+ // locale names). Assigns the first parsed locale to reference_locale.
+ std::vector<int> ParseAndExpandLocales(const std::string& locales,
+ std::string* reference_locale) const;
+
+ // Helper function that finds datetime spans, only using the rules associated
+ // with the given locales.
+ bool FindSpansUsingLocales(
+ const std::vector<int>& locale_ids, const UnicodeText& input,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end, const std::string& reference_locale,
+ std::unordered_set<int>* executed_rules,
+ std::vector<DatetimeParseResultSpan>* found_spans) const;
+
+ bool ParseWithRule(const CompiledRule& rule, const UnicodeText& input,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale, const int locale_id,
+ bool anchor_start_end,
+ std::vector<DatetimeParseResultSpan>* result) const;
+
+ // Converts the current match in 'matcher' into DatetimeParseResult.
+ bool ExtractDatetime(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
+ std::vector<DatetimeParseResult>* results,
+ CodepointSpan* result_span) const;
+
+ // Parse and extract information from current match in 'matcher'.
+ bool HandleParseMatch(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
+ std::vector<DatetimeParseResultSpan>* result) const;
+
+ private:
+ bool initialized_;
+ const UniLib& unilib_;
+ const CalendarLib& calendarlib_;
+ std::vector<CompiledRule> rules_;
+ std::unordered_map<int, std::vector<int>> locale_to_rules_;
+ std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_;
+ std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
+ type_and_locale_to_extractor_rule_;
+ std::unordered_map<std::string, int> locale_string_to_id_;
+ std::vector<int> default_locale_ids_;
+ bool use_extractors_for_locating_;
+ bool generate_alternative_interpretations_when_ambiguous_;
+ bool prefer_future_for_unspecified_date_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_
diff --git a/native/annotator/datetime/utils.cc b/native/annotator/datetime/utils.cc
new file mode 100644
index 0000000..30a99a1
--- /dev/null
+++ b/native/annotator/datetime/utils.cc
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/utils.h"
+
+namespace libtextclassifier3 {
+
+void FillInterpretations(const DatetimeParsedData& parse,
+ const DatetimeGranularity& granularity,
+ std::vector<DatetimeParsedData>* interpretations) {
+ DatetimeParsedData modified_parse(parse);
+ // If the relation field is not set, but relation_type field *is*, assume
+ // the relation field is NEXT_OR_SAME. This is necessary to handle e.g.
+ // "monday 3pm" (otherwise only "this monday 3pm" would work).
+ if (parse.HasFieldType(DatetimeComponent::ComponentType::DAY_OF_WEEK)) {
+ DatetimeComponent::RelativeQualifier relative_value;
+ if (parse.GetRelativeValue(DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ &relative_value)) {
+ if (relative_value == DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
+ modified_parse.SetRelativeValue(
+ DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ DatetimeComponent::RelativeQualifier::THIS);
+ }
+ }
+ }
+
+ // Multiple interpretations of ambiguous datetime expressions are generated
+ // here.
+ if (granularity > DatetimeGranularity::GRANULARITY_DAY &&
+ modified_parse.HasFieldType(DatetimeComponent::ComponentType::HOUR) &&
+ !modified_parse.HasRelativeValue(
+ DatetimeComponent::ComponentType::HOUR) &&
+ !modified_parse.HasFieldType(
+ DatetimeComponent::ComponentType::MERIDIEM)) {
+ int hour_value;
+ modified_parse.GetFieldValue(DatetimeComponent::ComponentType::HOUR,
+ &hour_value);
+ if (hour_value <= 12) {
+ modified_parse.SetAbsoluteValue(
+ DatetimeComponent::ComponentType::MERIDIEM, 0);
+ interpretations->push_back(modified_parse);
+ modified_parse.SetAbsoluteValue(
+ DatetimeComponent::ComponentType::MERIDIEM, 1);
+ interpretations->push_back(modified_parse);
+ } else {
+ interpretations->push_back(modified_parse);
+ }
+ } else {
+ // Otherwise just generate 1 variant.
+ interpretations->push_back(modified_parse);
+ }
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/utils.h b/native/annotator/datetime/utils.h
new file mode 100644
index 0000000..cdf1c8b
--- /dev/null
+++ b/native/annotator/datetime/utils.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_UTILS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_UTILS_H_
+
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+
+namespace libtextclassifier3 {
+
+// Generates alternative interpretations when the datetime is ambiguous, e.g.
+// '9:45' with hour:9 and minute:45 will be resolved to 9:45 AM & 9:45 PM.
+void FillInterpretations(const DatetimeParsedData& parse,
+ const DatetimeGranularity& granularity,
+ std::vector<DatetimeParsedData>* interpretations);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_UTILS_H_
diff --git a/native/annotator/duration/duration.cc b/native/annotator/duration/duration.cc
new file mode 100644
index 0000000..07b9885
--- /dev/null
+++ b/native/annotator/duration/duration.cc
@@ -0,0 +1,357 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/duration/duration.h"
+
+#include <climits>
+#include <cstdlib>
+
+#include "annotator/collections.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/strings/numbers.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+using DurationUnit = internal::DurationUnit;
+
+namespace internal {
+
+namespace {
+std::string ToLowerString(const std::string& str, const UniLib* unilib) {
+ return unilib->ToLowerText(UTF8ToUnicodeText(str, /*do_copy=*/false))
+ .ToUTF8String();
+}
+
+void FillDurationUnitMap(
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
+ expressions,
+ DurationUnit duration_unit,
+ std::unordered_map<std::string, DurationUnit>* target_map,
+ const UniLib* unilib) {
+ if (expressions == nullptr) {
+ return;
+ }
+
+ for (const flatbuffers::String* expression_string : *expressions) {
+ (*target_map)[ToLowerString(expression_string->c_str(), unilib)] =
+ duration_unit;
+ }
+}
+} // namespace
+
+std::unordered_map<std::string, DurationUnit> BuildTokenToDurationUnitMapping(
+ const DurationAnnotatorOptions* options, const UniLib* unilib) {
+ std::unordered_map<std::string, DurationUnit> mapping;
+ FillDurationUnitMap(options->week_expressions(), DurationUnit::WEEK, &mapping,
+ unilib);
+ FillDurationUnitMap(options->day_expressions(), DurationUnit::DAY, &mapping,
+ unilib);
+ FillDurationUnitMap(options->hour_expressions(), DurationUnit::HOUR, &mapping,
+ unilib);
+ FillDurationUnitMap(options->minute_expressions(), DurationUnit::MINUTE,
+ &mapping, unilib);
+ FillDurationUnitMap(options->second_expressions(), DurationUnit::SECOND,
+ &mapping, unilib);
+ return mapping;
+}
+
+std::unordered_set<std::string> BuildStringSet(
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
+ strings,
+ const UniLib* unilib) {
+ std::unordered_set<std::string> result;
+ if (strings == nullptr) {
+ return result;
+ }
+
+ for (const flatbuffers::String* string_value : *strings) {
+ result.insert(ToLowerString(string_value->c_str(), unilib));
+ }
+
+ return result;
+}
+
+std::unordered_set<int32> BuildInt32Set(
+ const flatbuffers::Vector<int32>* ints) {
+ std::unordered_set<int32> result;
+ if (ints == nullptr) {
+ return result;
+ }
+
+ for (const int32 int_value : *ints) {
+ result.insert(int_value);
+ }
+
+ return result;
+}
+
+} // namespace internal
+
+bool DurationAnnotator::ClassifyText(
+ const UnicodeText& context, CodepointSpan selection_indices,
+ AnnotationUsecase annotation_usecase,
+ ClassificationResult* classification_result) const {
+ if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
+ (1 << annotation_usecase))) == 0) {
+ return false;
+ }
+
+ const UnicodeText selection =
+ UnicodeText::Substring(context, selection_indices.first,
+ selection_indices.second, /*do_copy=*/false);
+ const std::vector<Token> tokens = feature_processor_->Tokenize(selection);
+
+ AnnotatedSpan annotated_span;
+ if (tokens.empty() ||
+ FindDurationStartingAt(context, tokens, 0, &annotated_span) !=
+ tokens.size()) {
+ return false;
+ }
+
+ TC3_DCHECK(!annotated_span.classification.empty());
+
+ *classification_result = annotated_span.classification[0];
+ return true;
+}
+
+bool DurationAnnotator::FindAll(const UnicodeText& context,
+ const std::vector<Token>& tokens,
+ AnnotationUsecase annotation_usecase,
+ std::vector<AnnotatedSpan>* results) const {
+ if (!options_->enabled() || ((options_->enabled_annotation_usecases() &
+ (1 << annotation_usecase))) == 0) {
+ return true;
+ }
+
+ for (int i = 0; i < tokens.size();) {
+ AnnotatedSpan span;
+ const int next_i = FindDurationStartingAt(context, tokens, i, &span);
+ if (next_i != i) {
+ results->push_back(span);
+ i = next_i;
+ } else {
+ i++;
+ }
+ }
+ return true;
+}
+
+int DurationAnnotator::FindDurationStartingAt(const UnicodeText& context,
+ const std::vector<Token>& tokens,
+ int start_token_index,
+ AnnotatedSpan* result) const {
+ CodepointIndex start_index = kInvalidIndex;
+ CodepointIndex end_index = kInvalidIndex;
+
+ bool has_quantity = false;
+ ParsedDurationAtom parsed_duration;
+
+ std::vector<ParsedDurationAtom> parsed_duration_atoms;
+
+ // This is the core algorithm for finding the duration expressions. It
+ // basically iterates over tokens and changes the state variables above as it
+ // goes.
+ int token_index;
+ int quantity_end_index;
+ for (token_index = start_token_index; token_index < tokens.size();
+ token_index++) {
+ const Token& token = tokens[token_index];
+
+ if (ParseQuantityToken(token, &parsed_duration)) {
+ has_quantity = true;
+ if (start_index == kInvalidIndex) {
+ start_index = token.start;
+ }
+ quantity_end_index = token.end;
+ } else if (((!options_->require_quantity() || has_quantity) &&
+ ParseDurationUnitToken(token, &parsed_duration.unit)) ||
+ ParseQuantityDurationUnitToken(token, &parsed_duration)) {
+ if (start_index == kInvalidIndex) {
+ start_index = token.start;
+ }
+ end_index = token.end;
+ parsed_duration_atoms.push_back(parsed_duration);
+ has_quantity = false;
+ parsed_duration = ParsedDurationAtom();
+ } else if (ParseFillerToken(token)) {
+ } else {
+ break;
+ }
+ }
+
+ if (parsed_duration_atoms.empty()) {
+ return start_token_index;
+ }
+
+ const bool parse_ended_without_unit_for_last_mentioned_quantity =
+ has_quantity;
+
+ ClassificationResult classification{Collections::Duration(),
+ options_->score()};
+ classification.priority_score = options_->priority_score();
+ classification.duration_ms =
+ ParsedDurationAtomsToMillis(parsed_duration_atoms);
+
+ // Process suffix expressions like "and half" that don't have the
+ // duration_unit explicitly mentioned.
+ if (parse_ended_without_unit_for_last_mentioned_quantity) {
+ if (parsed_duration.plus_half) {
+ end_index = quantity_end_index;
+ ParsedDurationAtom atom = ParsedDurationAtom::Half();
+ atom.unit = parsed_duration_atoms.rbegin()->unit;
+ classification.duration_ms += ParsedDurationAtomsToMillis({atom});
+ } else if (options_->enable_dangling_quantity_interpretation()) {
+ end_index = quantity_end_index;
+ // TODO(b/144752747) Add dangling quantity to duration_ms.
+ }
+ }
+
+ result->span = feature_processor_->StripBoundaryCodepoints(
+ context, {start_index, end_index});
+ result->classification.push_back(classification);
+ result->source = AnnotatedSpan::Source::DURATION;
+
+ return token_index;
+}
+
+int64 DurationAnnotator::ParsedDurationAtomsToMillis(
+ const std::vector<ParsedDurationAtom>& atoms) const {
+ int64 result = 0;
+ for (auto atom : atoms) {
+ int multiplier;
+ switch (atom.unit) {
+ case DurationUnit::WEEK:
+ multiplier = 7 * 24 * 60 * 60 * 1000;
+ break;
+ case DurationUnit::DAY:
+ multiplier = 24 * 60 * 60 * 1000;
+ break;
+ case DurationUnit::HOUR:
+ multiplier = 60 * 60 * 1000;
+ break;
+ case DurationUnit::MINUTE:
+ multiplier = 60 * 1000;
+ break;
+ case DurationUnit::SECOND:
+ multiplier = 1000;
+ break;
+ case DurationUnit::UNKNOWN:
+ TC3_LOG(ERROR) << "Requesting parse of UNKNOWN duration duration_unit.";
+ return -1;
+ break;
+ }
+
+ int64 value = atom.value;
+ // This condition handles expressions like "an hour", where the quantity is
+ // not specified. In this case we assume quantity 1. Except for cases like
+ // "half hour".
+ if (value == 0 && !atom.plus_half) {
+ value = 1;
+ }
+ result += value * multiplier;
+ result += atom.plus_half * multiplier / 2;
+ }
+ return result;
+}
+
+bool DurationAnnotator::ParseQuantityToken(const Token& token,
+ ParsedDurationAtom* value) const {
+ if (token.value.empty()) {
+ return false;
+ }
+
+ std::string token_value_buffer;
+ const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
+ token.value, &token_value_buffer);
+ const std::string& lowercase_token_value =
+ internal::ToLowerString(token_value, unilib_);
+
+ if (half_expressions_.find(lowercase_token_value) !=
+ half_expressions_.end()) {
+ value->plus_half = true;
+ return true;
+ }
+
+ int32 parsed_value;
+ if (ParseInt32(lowercase_token_value.c_str(), &parsed_value)) {
+ value->value = parsed_value;
+ return true;
+ }
+
+ return false;
+}
+
+bool DurationAnnotator::ParseDurationUnitToken(
+ const Token& token, DurationUnit* duration_unit) const {
+ std::string token_value_buffer;
+ const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
+ token.value, &token_value_buffer);
+ const std::string& lowercase_token_value =
+ internal::ToLowerString(token_value, unilib_);
+
+ const auto it = token_value_to_duration_unit_.find(lowercase_token_value);
+ if (it == token_value_to_duration_unit_.end()) {
+ return false;
+ }
+
+ *duration_unit = it->second;
+ return true;
+}
+
+bool DurationAnnotator::ParseQuantityDurationUnitToken(
+ const Token& token, ParsedDurationAtom* value) const {
+ if (token.value.empty()) {
+ return false;
+ }
+
+ Token sub_token;
+ bool has_quantity = false;
+ for (const char c : token.value) {
+ if (sub_token_separator_codepoints_.find(c) !=
+ sub_token_separator_codepoints_.end()) {
+ if (has_quantity || !ParseQuantityToken(sub_token, value)) {
+ return false;
+ }
+ has_quantity = true;
+
+ sub_token = Token();
+ } else {
+ sub_token.value += c;
+ }
+ }
+
+ return (!options_->require_quantity() || has_quantity) &&
+ ParseDurationUnitToken(sub_token, &(value->unit));
+}
+
+bool DurationAnnotator::ParseFillerToken(const Token& token) const {
+ std::string token_value_buffer;
+ const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
+ token.value, &token_value_buffer);
+ const std::string& lowercase_token_value =
+ internal::ToLowerString(token_value, unilib_);
+
+ if (filler_expressions_.find(lowercase_token_value) ==
+ filler_expressions_.end()) {
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/duration/duration.h b/native/annotator/duration/duration.h
new file mode 100644
index 0000000..db4bdae
--- /dev/null
+++ b/native/annotator/duration/duration.h
@@ -0,0 +1,143 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DURATION_DURATION_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DURATION_DURATION_H_
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/feature-processor.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+namespace internal {
+enum class DurationUnit {
+ UNKNOWN = -1,
+ WEEK = 0,
+ DAY = 1,
+ HOUR = 2,
+ MINUTE = 3,
+ SECOND = 4
+
+  // NOTE: If we want to add MONTH and YEAR we'll have to think of a different
+ // parsing format, because MONTH and YEAR don't have a fixed number of
+ // milliseconds, unlike week/day/hour/minute/second. We ignore the daylight
+ // savings time and assume the day is always 24 hours.
+};
+
+// Prepares the mapping between token values and duration unit types.
+std::unordered_map<std::string, internal::DurationUnit>
+BuildTokenToDurationUnitMapping(const DurationAnnotatorOptions* options,
+ const UniLib* unilib);
+
+// Creates a set of strings from a flatbuffer string vector.
+std::unordered_set<std::string> BuildStringSet(
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
+ strings,
+ const UniLib* unilib);
+
+// Creates a set of ints from a flatbuffer int vector.
+std::unordered_set<int32> BuildInt32Set(const flatbuffers::Vector<int32>* ints);
+
+} // namespace internal
+
+// Annotator of duration expressions like "3 minutes 30 seconds".
+class DurationAnnotator {
+ public:
+ explicit DurationAnnotator(const DurationAnnotatorOptions* options,
+ const FeatureProcessor* feature_processor,
+ const UniLib* unilib)
+ : options_(options),
+ feature_processor_(feature_processor),
+ unilib_(unilib),
+ token_value_to_duration_unit_(
+ internal::BuildTokenToDurationUnitMapping(options, unilib)),
+ filler_expressions_(
+ internal::BuildStringSet(options->filler_expressions(), unilib)),
+ half_expressions_(
+ internal::BuildStringSet(options->half_expressions(), unilib)),
+ sub_token_separator_codepoints_(internal::BuildInt32Set(
+ options->sub_token_separator_codepoints())) {}
+
+ // Classifies given text, and if it is a duration, it passes the result in
+ // 'classification_result' and returns true, otherwise returns false.
+ bool ClassifyText(const UnicodeText& context, CodepointSpan selection_indices,
+ AnnotationUsecase annotation_usecase,
+ ClassificationResult* classification_result) const;
+
+ // Finds all duration instances in the input text.
+ bool FindAll(const UnicodeText& context, const std::vector<Token>& tokens,
+ AnnotationUsecase annotation_usecase,
+ std::vector<AnnotatedSpan>* results) const;
+
+ private:
+ // Represents a component of duration parsed from text (e.g. "3 hours" from
+ // the expression "3 hours and 20 minutes").
+ struct ParsedDurationAtom {
+ // Unit of the duration.
+ internal::DurationUnit unit = internal::DurationUnit::UNKNOWN;
+
+ // Quantity of the duration unit.
+ int value = 0;
+
+    // True, if half a unit was specified (either in addition, or exclusively).
+ // E.g. "hour and a half".
+ // NOTE: Quarter, three-quarters etc. is not supported.
+ bool plus_half = false;
+
+ static ParsedDurationAtom Half() {
+ ParsedDurationAtom result;
+ result.plus_half = true;
+ return result;
+ }
+ };
+
+ // Starts consuming tokens and returns the index past the last consumed token.
+ int FindDurationStartingAt(const UnicodeText& context,
+ const std::vector<Token>& tokens,
+ int start_token_index,
+ AnnotatedSpan* result) const;
+
+ bool ParseQuantityToken(const Token& token, ParsedDurationAtom* value) const;
+ bool ParseDurationUnitToken(const Token& token,
+ internal::DurationUnit* duration_unit) const;
+ bool ParseQuantityDurationUnitToken(const Token& token,
+ ParsedDurationAtom* value) const;
+ bool ParseFillerToken(const Token& token) const;
+
+ int64 ParsedDurationAtomsToMillis(
+ const std::vector<ParsedDurationAtom>& atoms) const;
+
+ const DurationAnnotatorOptions* options_;
+ const FeatureProcessor* feature_processor_;
+ const UniLib* unilib_;
+ const std::unordered_map<std::string, internal::DurationUnit>
+ token_value_to_duration_unit_;
+ const std::unordered_set<std::string> filler_expressions_;
+ const std::unordered_set<std::string> half_expressions_;
+ const std::unordered_set<int32> sub_token_separator_codepoints_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DURATION_DURATION_H_
diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc
new file mode 100644
index 0000000..a0985a2
--- /dev/null
+++ b/native/annotator/duration/duration_test.cc
@@ -0,0 +1,567 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/duration/duration.h"
+
+#include <string>
+#include <vector>
+
+#include "annotator/collections.h"
+#include "annotator/model_generated.h"
+#include "annotator/types-test-util.h"
+#include "annotator/types.h"
+#include "utils/test-utils.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::AllOf;
+using testing::ElementsAre;
+using testing::Field;
+using testing::IsEmpty;
+
+const DurationAnnotatorOptions* TestingDurationAnnotatorOptions() {
+ static const flatbuffers::DetachedBuffer* options_data = []() {
+ DurationAnnotatorOptionsT options;
+ options.enabled = true;
+
+ options.week_expressions.push_back("week");
+ options.week_expressions.push_back("weeks");
+
+ options.day_expressions.push_back("day");
+ options.day_expressions.push_back("days");
+
+ options.hour_expressions.push_back("hour");
+ options.hour_expressions.push_back("hours");
+
+ options.minute_expressions.push_back("minute");
+ options.minute_expressions.push_back("minutes");
+
+ options.second_expressions.push_back("second");
+ options.second_expressions.push_back("seconds");
+
+ options.filler_expressions.push_back("and");
+ options.filler_expressions.push_back("a");
+ options.filler_expressions.push_back("an");
+ options.filler_expressions.push_back("one");
+
+ options.half_expressions.push_back("half");
+
+ options.sub_token_separator_codepoints.push_back('-');
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+ }();
+
+ return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
+}
+
+std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) {
+ static const flatbuffers::DetachedBuffer* options_data = []() {
+ FeatureProcessorOptionsT options;
+ options.context_size = 1;
+ options.max_selection_span = 1;
+ options.snap_label_span_boundaries_to_containing_tokens = false;
+ options.ignored_span_boundary_codepoints.push_back(',');
+
+ options.tokenization_codepoint_config.emplace_back(
+ new TokenizationCodepointRangeT());
+ auto& config = options.tokenization_codepoint_config.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(FeatureProcessorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+ }();
+
+ const FeatureProcessorOptions* feature_processor_options =
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_data->data());
+
+ return std::unique_ptr<FeatureProcessor>(
+ new FeatureProcessor(feature_processor_options, unilib));
+}
+
+class DurationAnnotatorTest : public ::testing::Test {
+ protected:
+ DurationAnnotatorTest()
+ : INIT_UNILIB_FOR_TESTING(unilib_),
+ feature_processor_(BuildFeatureProcessor(&unilib_)),
+ duration_annotator_(TestingDurationAnnotatorOptions(),
+ feature_processor_.get(), &unilib_) {}
+
+ std::vector<Token> Tokenize(const UnicodeText& text) {
+ return feature_processor_->Tokenize(text);
+ }
+
+ UniLib unilib_;
+ std::unique_ptr<FeatureProcessor> feature_processor_;
+ DurationAnnotator duration_annotator_;
+};
+
+TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
+ ClassificationResult classification;
+ EXPECT_TRUE(duration_annotator_.ClassifyText(
+ UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
+
+ EXPECT_THAT(classification,
+ AllOf(Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
+}
+
+TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) {
+ ClassificationResult classification;
+ EXPECT_TRUE(duration_annotator_.ClassifyText(
+ UTF8ToUnicodeText("Wake me up in15 minutesok?"), {13, 23},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
+
+ EXPECT_THAT(classification,
+ AllOf(Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
+}
+
+TEST_F(DurationAnnotatorTest, DoNotClassifyWhenInputIsInvalid) {
+ ClassificationResult classification;
+ EXPECT_FALSE(duration_annotator_.ClassifyText(
+ UTF8ToUnicodeText("Weird space"), {5, 6},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
+}
+
+TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
+ const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 15 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 3.5 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsComposedDuration) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Wake me up in 3 hours and 5 seconds ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 35)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 3 * 60 * 60 * 1000 + 5 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) {
+ const UnicodeText text = UTF8ToUnicodeText(
+ "See you in a week and a day and an hour and a minute and a second");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(13, 65)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 7 * 24 * 60 * 60 * 1000 + 24 * 60 * 60 * 1000 +
+ 60 * 60 * 1000 + 60 * 1000 + 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsHalfAnHour) {
+ const UnicodeText text = UTF8ToUnicodeText("Set a timer for half an hour");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 28)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 0.5 * 60 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 1 hour and a half");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 33)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 1.5 * 60 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for an hour and a half");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(19, 34)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 1.5 * 60 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest,
+ FindsCorrectlyWhenSecondsComeSecondAndDontHaveNumber) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 10 minutes and a second ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 39)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 10 * 60 * 1000 + 1 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) {
+ const UnicodeText text = UTF8ToUnicodeText(
+ "Set a timer for a a a 10 minutes and 2 seconds an and an ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(22, 46)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 10 * 60 * 1000 + 2 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) {
+ const UnicodeText text = UTF8ToUnicodeText("Set a timer for half ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ ASSERT_EQ(result.size(), 0);
+}
+
+TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 10 ,minutes, ,and, ,2, seconds, ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 46)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 10 * 60 * 1000 + 2 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsCorrectlyWithCombinedQuantityUnitToken) {
+ const UnicodeText text = UTF8ToUnicodeText("Show 5-minute timer.");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(5, 13)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 5 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest,
+ DoesNotIntOverflowWithDurationThatHasMoreThanInt32Millis) {
+ ClassificationResult classification;
+ EXPECT_TRUE(duration_annotator_.ClassifyText(
+ UTF8ToUnicodeText("1400 hours"), {0, 10},
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
+
+ EXPECT_THAT(classification,
+ AllOf(Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 1400LL * 60LL * 60LL * 1000LL)));
+}
+
+TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) {
+ const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 MiNuTeS ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 15 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 3 and HaLf minutes ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 3.5 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest,
+ FindsDurationWithHalfExpressionIgnoringFillerWordCase) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 3 AnD half minutes ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 3.5 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, CorrectlyAnnotatesSpanWithDanglingQuantity) {
+ const UnicodeText text = UTF8ToUnicodeText("20 minutes 10");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ // TODO(b/144752747) Include test for duration_ms.
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 13)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(Field(&ClassificationResult::collection,
+ "duration")))))));
+}
+
+const DurationAnnotatorOptions* TestingJapaneseDurationAnnotatorOptions() {
+ static const flatbuffers::DetachedBuffer* options_data = []() {
+ DurationAnnotatorOptionsT options;
+ options.enabled = true;
+
+ options.week_expressions.push_back("週間");
+
+ options.day_expressions.push_back("日間");
+
+ options.hour_expressions.push_back("時間");
+
+ options.minute_expressions.push_back("分");
+ options.minute_expressions.push_back("分間");
+
+ options.second_expressions.push_back("秒");
+ options.second_expressions.push_back("秒間");
+
+ options.half_expressions.push_back("半");
+
+ options.require_quantity = true;
+ options.enable_dangling_quantity_interpretation = false;
+
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
+ return new flatbuffers::DetachedBuffer(builder.Release());
+ }();
+
+ return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
+}
+
+class JapaneseDurationAnnotatorTest : public ::testing::Test {
+ protected:
+ JapaneseDurationAnnotatorTest()
+ : INIT_UNILIB_FOR_TESTING(unilib_),
+ feature_processor_(BuildFeatureProcessor(&unilib_)),
+ duration_annotator_(TestingJapaneseDurationAnnotatorOptions(),
+ feature_processor_.get(), &unilib_) {}
+
+ std::vector<Token> Tokenize(const UnicodeText& text) {
+ return feature_processor_->Tokenize(text);
+ }
+
+ UniLib unilib_;
+ std::unique_ptr<FeatureProcessor> feature_processor_;
+ DurationAnnotator duration_annotator_;
+};
+
+TEST_F(JapaneseDurationAnnotatorTest, FindsDuration) {
+ const UnicodeText text = UTF8ToUnicodeText("10 分 の アラーム");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 4)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 10 * 60 * 1000)))))));
+}
+
+TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithHalfExpression) {
+ const UnicodeText text = UTF8ToUnicodeText("2 分 半 の アラーム");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 5)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 2.5 * 60 * 1000)))))));
+}
+
+TEST_F(JapaneseDurationAnnotatorTest, IgnoresDurationWithoutQuantity) {
+ const UnicodeText text = UTF8ToUnicodeText("分 の アラーム");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(result, IsEmpty());
+}
+
+TEST_F(JapaneseDurationAnnotatorTest, IgnoresDanglingQuantity) {
+ const UnicodeText text = UTF8ToUnicodeText("2 分 10 の アラーム");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 3)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 2 * 60 * 1000)))))));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/annotator/entity-data.fbs b/native/annotator/entity-data.fbs
new file mode 100755
index 0000000..4c02f6d
--- /dev/null
+++ b/native/annotator/entity-data.fbs
@@ -0,0 +1,225 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+namespace libtextclassifier3.EntityData_.Datetime_;
+enum Granularity : int {
+ GRANULARITY_UNKNOWN = -1,
+ GRANULARITY_YEAR = 0,
+ GRANULARITY_MONTH = 1,
+ GRANULARITY_WEEK = 2,
+ GRANULARITY_DAY = 3,
+ GRANULARITY_HOUR = 4,
+ GRANULARITY_MINUTE = 5,
+ GRANULARITY_SECOND = 6,
+}
+
+namespace libtextclassifier3.EntityData_.Datetime_.DatetimeComponent_;
+enum ComponentType : int {
+ UNSPECIFIED = 0,
+ YEAR = 1,
+ MONTH = 2,
+ WEEK = 3,
+ DAY_OF_WEEK = 4,
+ DAY_OF_MONTH = 5,
+ HOUR = 6,
+ MINUTE = 7,
+ SECOND = 8,
+ MERIDIEM = 9,
+ ZONE_OFFSET = 10,
+ DST_OFFSET = 11,
+}
+
+// Enum to identify if the datetime component are relative or absolute.
+namespace libtextclassifier3.EntityData_.Datetime_.DatetimeComponent_;
+enum RelationType : int {
+ RELATION_UNSPECIFIED = 0,
+
+ // Absolute represents the datetime component that need no further
+ // calculation e.g. in a datetime span "21-03-2019" components
+ // year=2019, month=3 and day=21 is explicitly mentioned in the span
+ ABSOLUTE = 1,
+
+ // Identify datetime component where datetime expressions are relative.
+ // e.g. "three days ago", "2 days after March 1st", "next monday",
+ // "last Mondays".
+ RELATIVE = 2,
+}
+
+namespace libtextclassifier3.EntityData_.Datetime_;
+table DatetimeComponent {
+ component_type:DatetimeComponent_.ComponentType = UNSPECIFIED;
+ absolute_value:int;
+ relative_count:int;
+ relation_type:DatetimeComponent_.RelationType = RELATION_UNSPECIFIED;
+}
+
+namespace libtextclassifier3.EntityData_;
+table Datetime {
+ time_ms_utc:long;
+ granularity:Datetime_.Granularity = GRANULARITY_UNKNOWN;
+ datetime_component:[Datetime_.DatetimeComponent];
+}
+
+namespace libtextclassifier3.EntityData_;
+table Contact {
+ name:string (shared);
+ given_name:string (shared);
+ nickname:string (shared);
+ email_address:string (shared);
+ phone_number:string (shared);
+ contact_id:string (shared);
+}
+
+namespace libtextclassifier3.EntityData_;
+table App {
+ name:string (shared);
+ package_name:string (shared);
+}
+
+// The issuer/network of the payment card.
+namespace libtextclassifier3.EntityData_.PaymentCard_;
+enum CardNetwork : int {
+ UNKNOWN_CARD_NETWORK = 0,
+ AMEX = 1,
+ DINERS_CLUB = 2,
+ DISCOVER = 3,
+ INTER_PAYMENT = 4,
+ JCB = 5,
+ MAESTRO = 6,
+ MASTERCARD = 7,
+ MIR = 8,
+ TROY = 9,
+ UNIONPAY = 10,
+ VISA = 11,
+}
+
+// Details about a payment card.
+namespace libtextclassifier3.EntityData_;
+table PaymentCard {
+ card_network:PaymentCard_.CardNetwork;
+
+ // The card number.
+ card_number:string (shared);
+}
+
+// Details about a flight number.
+namespace libtextclassifier3.EntityData_;
+table Flight {
+ // The IATA or ICAO airline code of the flight number.
+ airline_code:string (shared);
+
+ // The flight number.
+ flight_number:string (shared);
+}
+
+// Details about an ISBN number.
+namespace libtextclassifier3.EntityData_;
+table Isbn {
+ // The (normalized) number.
+ number:string (shared);
+}
+
+// Details about an IBAN number.
+namespace libtextclassifier3.EntityData_;
+table Iban {
+ // The (normalized) number.
+ number:string (shared);
+
+ // The country code.
+ country_code:string (shared);
+}
+
+// The issuer/network of the package tracking number.
+namespace libtextclassifier3.EntityData_.ParcelTracking_;
+enum Carrier : int {
+ UNKNOWN_CARRIER = 0,
+ FEDEX = 1,
+ UPS = 2,
+ DHL = 3,
+ USPS = 4,
+ ONTRAC = 5,
+ LASERSHIP = 6,
+ ISRAEL_POST = 7,
+ SWISS_POST = 8,
+ MSC = 9,
+ AMAZON = 10,
+ I_PARCEL = 11,
+}
+
+// Details about a tracking number.
+namespace libtextclassifier3.EntityData_;
+table ParcelTracking {
+ carrier:ParcelTracking_.Carrier;
+ tracking_number:string (shared);
+}
+
+// Parsed money amount.
+namespace libtextclassifier3.EntityData_;
+table Money {
+ // String representation of currency, unnormalized.
+ unnormalized_currency:string (shared);
+
+ // Whole part of the amount (e.g. 123 from "CHF 123.45").
+ amount_whole_part:int;
+
+ // Decimal part of the amount (e.g. 45 from "CHF 123.45").
+ amount_decimal_part:int;
+
+ // Money amount (e.g. 123.45 from "CHF 123.45").
+ unnormalized_amount:string (shared);
+}
+
+namespace libtextclassifier3.EntityData_.Translate_;
+table LanguagePredictionResult {
+ // BCP 47 tag for the language prediction result.
+ language_tag:string (shared);
+
+ // Confidence score for the language prediction result.
+ confidence_score:float;
+}
+
+// Details about detected foreign text.
+namespace libtextclassifier3.EntityData_;
+table Translate {
+ language_prediction_results:[Translate_.LanguagePredictionResult];
+}
+
+// Represents an entity annotated in text.
+namespace libtextclassifier3;
+table EntityData {
+ // Codepoint indices of the annotation, start is inclusive, end is
+ // exclusive.
+ start:int;
+
+ end:int;
+
+ // The entity type, as in the TextClassifier APIs.
+ type:string (shared);
+
+ datetime:EntityData_.Datetime;
+ reserved_5:int (deprecated);
+ contact:EntityData_.Contact;
+ app:EntityData_.App;
+ payment_card:EntityData_.PaymentCard;
+ flight:EntityData_.Flight;
+ isbn:EntityData_.Isbn;
+ iban:EntityData_.Iban;
+ parcel:EntityData_.ParcelTracking;
+ money:EntityData_.Money;
+ translate:EntityData_.Translate;
+}
+
+root_type libtextclassifier3.EntityData;
diff --git a/native/annotator/experimental/experimental-dummy.h b/native/annotator/experimental/experimental-dummy.h
new file mode 100644
index 0000000..389aae1
--- /dev/null
+++ b/native/annotator/experimental/experimental-dummy.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_EXPERIMENTAL_EXPERIMENTAL_DUMMY_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_EXPERIMENTAL_EXPERIMENTAL_DUMMY_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/feature-processor.h"
+#include "annotator/types.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+class ExperimentalAnnotator {
+ public:
+ // This is the dummy implementation of ExperimentalAnnotator and so it's
+ // always disabled;
+ static constexpr bool IsEnabled() { return false; }
+
+ explicit ExperimentalAnnotator(const ExperimentalModel* model,
+ const FeatureProcessor& feature_processor,
+ const UniLib& unilib) {}
+
+ bool Annotate(const UnicodeText& context,
+ std::vector<AnnotatedSpan>* candidates) const {
+ return false;
+ }
+
+ AnnotatedSpan SuggestSelection(const UnicodeText& context,
+ CodepointSpan click) const {
+ return {click, {}};
+ }
+
+ bool ClassifyText(const UnicodeText& context, CodepointSpan click,
+ ClassificationResult* result) const {
+ return false;
+ }
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_EXPERIMENTAL_EXPERIMENTAL_DUMMY_H_
diff --git a/native/annotator/experimental/experimental.fbs b/native/annotator/experimental/experimental.fbs
new file mode 100755
index 0000000..6e15d04
--- /dev/null
+++ b/native/annotator/experimental/experimental.fbs
@@ -0,0 +1,20 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+namespace libtextclassifier3;
+table ExperimentalModel {
+}
+
diff --git a/native/annotator/experimental/experimental.h b/native/annotator/experimental/experimental.h
new file mode 100644
index 0000000..8144996
--- /dev/null
+++ b/native/annotator/experimental/experimental.h
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_EXPERIMENTAL_EXPERIMENTAL_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_EXPERIMENTAL_EXPERIMENTAL_H_
+
+#include "annotator/experimental/experimental-dummy.h"
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_EXPERIMENTAL_EXPERIMENTAL_H_
diff --git a/native/annotator/feature-processor.cc b/native/annotator/feature-processor.cc
new file mode 100644
index 0000000..8d08574
--- /dev/null
+++ b/native/annotator/feature-processor.cc
@@ -0,0 +1,929 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/feature-processor.h"
+
+#include <iterator>
+#include <set>
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/strings/utf8.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+namespace internal {
+
+Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
+ const UniLib* unilib) {
+ std::vector<const TokenizationCodepointRange*> codepoint_config;
+ if (options->tokenization_codepoint_config() != nullptr) {
+ codepoint_config.insert(codepoint_config.end(),
+ options->tokenization_codepoint_config()->begin(),
+ options->tokenization_codepoint_config()->end());
+ }
+ std::vector<const CodepointRange*> internal_codepoint_config;
+ if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
+ internal_codepoint_config.insert(
+ internal_codepoint_config.end(),
+ options->internal_tokenizer_codepoint_ranges()->begin(),
+ options->internal_tokenizer_codepoint_ranges()->end());
+ }
+ const bool tokenize_on_script_change =
+ options->tokenization_codepoint_config() != nullptr &&
+ options->tokenize_on_script_change();
+ return Tokenizer(options->tokenization_type(), unilib, codepoint_config,
+ internal_codepoint_config, tokenize_on_script_change,
+ options->icu_preserve_whitespace_tokens());
+}
+
+TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
+ const FeatureProcessorOptions* const options) {
+ TokenFeatureExtractorOptions extractor_options;
+
+ extractor_options.num_buckets = options->num_buckets();
+ if (options->chargram_orders() != nullptr) {
+ for (int order : *options->chargram_orders()) {
+ extractor_options.chargram_orders.push_back(order);
+ }
+ }
+ extractor_options.max_word_length = options->max_word_length();
+ extractor_options.extract_case_feature = options->extract_case_feature();
+ extractor_options.unicode_aware_features = options->unicode_aware_features();
+ extractor_options.extract_selection_mask_feature =
+ options->extract_selection_mask_feature();
+ if (options->regexp_feature() != nullptr) {
+ for (const auto& regexp_feauture : *options->regexp_feature()) {
+ extractor_options.regexp_features.push_back(regexp_feauture->str());
+ }
+ }
+ extractor_options.remap_digits = options->remap_digits();
+ extractor_options.lowercase_tokens = options->lowercase_tokens();
+
+ if (options->allowed_chargrams() != nullptr) {
+ for (const auto& chargram : *options->allowed_chargrams()) {
+ extractor_options.allowed_chargrams.insert(chargram->str());
+ }
+ }
+ return extractor_options;
+}
+
+void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
+ std::vector<Token>* tokens) {
+ for (auto it = tokens->begin(); it != tokens->end(); ++it) {
+ const UnicodeText token_word =
+ UTF8ToUnicodeText(it->value, /*do_copy=*/false);
+
+ auto last_start = token_word.begin();
+ int last_start_index = it->start;
+ std::vector<UnicodeText::const_iterator> split_points;
+
+ // Selection start split point.
+ if (selection.first > it->start && selection.first < it->end) {
+ std::advance(last_start, selection.first - last_start_index);
+ split_points.push_back(last_start);
+ last_start_index = selection.first;
+ }
+
+ // Selection end split point.
+ if (selection.second > it->start && selection.second < it->end) {
+ std::advance(last_start, selection.second - last_start_index);
+ split_points.push_back(last_start);
+ }
+
+ if (!split_points.empty()) {
+ // Add a final split for the rest of the token unless it's been all
+ // consumed already.
+ if (split_points.back() != token_word.end()) {
+ split_points.push_back(token_word.end());
+ }
+
+ std::vector<Token> replacement_tokens;
+ last_start = token_word.begin();
+ int current_pos = it->start;
+ for (const auto& split_point : split_points) {
+ Token new_token(token_word.UTF8Substring(last_start, split_point),
+ current_pos,
+ current_pos + std::distance(last_start, split_point));
+
+ last_start = split_point;
+ current_pos = new_token.end;
+
+ replacement_tokens.push_back(new_token);
+ }
+
+ it = tokens->erase(it);
+ it = tokens->insert(it, replacement_tokens.begin(),
+ replacement_tokens.end());
+ std::advance(it, replacement_tokens.size() - 1);
+ }
+ }
+}
+
+} // namespace internal
+
+void FeatureProcessor::StripTokensFromOtherLines(
+ const std::string& context, CodepointSpan span,
+ std::vector<Token>* tokens) const {
+ const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+ /*do_copy=*/false);
+ StripTokensFromOtherLines(context_unicode, span, tokens);
+}
+
+void FeatureProcessor::StripTokensFromOtherLines(
+ const UnicodeText& context_unicode, CodepointSpan span,
+ std::vector<Token>* tokens) const {
+ std::vector<UnicodeTextRange> lines =
+ SplitContext(context_unicode, options_->use_pipe_character_for_newline());
+
+ auto span_start = context_unicode.begin();
+ if (span.first > 0) {
+ std::advance(span_start, span.first);
+ }
+ auto span_end = context_unicode.begin();
+ if (span.second > 0) {
+ std::advance(span_end, span.second);
+ }
+ for (const UnicodeTextRange& line : lines) {
+ // Find the line that completely contains the span.
+ if (line.first <= span_start && line.second >= span_end) {
+ const CodepointIndex last_line_begin_index =
+ std::distance(context_unicode.begin(), line.first);
+ const CodepointIndex last_line_end_index =
+ last_line_begin_index + std::distance(line.first, line.second);
+
+ for (auto token = tokens->begin(); token != tokens->end();) {
+ if (token->start >= last_line_begin_index &&
+ token->end <= last_line_end_index) {
+ ++token;
+ } else {
+ token = tokens->erase(token);
+ }
+ }
+ }
+ }
+}
+
+std::string FeatureProcessor::GetDefaultCollection() const {
+ if (options_->default_collection() < 0 ||
+ options_->collections() == nullptr ||
+ options_->default_collection() >= options_->collections()->size()) {
+ TC3_LOG(ERROR)
+ << "Invalid or missing default collection. Returning empty string.";
+ return "";
+ }
+ return (*options_->collections())[options_->default_collection()]->str();
+}
+
+std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const {
+ return tokenizer_.Tokenize(text);
+}
+
+std::vector<Token> FeatureProcessor::Tokenize(
+ const UnicodeText& text_unicode) const {
+ return tokenizer_.Tokenize(text_unicode);
+}
+
+bool FeatureProcessor::LabelToSpan(
+ const int label, const VectorSpan<Token>& tokens,
+ std::pair<CodepointIndex, CodepointIndex>* span) const {
+ if (tokens.size() != GetNumContextTokens()) {
+ return false;
+ }
+
+ TokenSpan token_span;
+ if (!LabelToTokenSpan(label, &token_span)) {
+ return false;
+ }
+
+ const int result_begin_token_index = token_span.first;
+ const Token& result_begin_token =
+ tokens[options_->context_size() - result_begin_token_index];
+ const int result_begin_codepoint = result_begin_token.start;
+ const int result_end_token_index = token_span.second;
+ const Token& result_end_token =
+ tokens[options_->context_size() + result_end_token_index];
+ const int result_end_codepoint = result_end_token.end;
+
+ if (result_begin_codepoint == kInvalidIndex ||
+ result_end_codepoint == kInvalidIndex) {
+ *span = CodepointSpan({kInvalidIndex, kInvalidIndex});
+ } else {
+ const UnicodeText token_begin_unicode =
+ UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
+ UnicodeText::const_iterator token_begin = token_begin_unicode.begin();
+ const UnicodeText token_end_unicode =
+ UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false);
+ UnicodeText::const_iterator token_end = token_end_unicode.end();
+
+ const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
+ token_begin, token_begin_unicode.end(),
+ /*count_from_beginning=*/true);
+ const int end_ignored =
+ CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end,
+ /*count_from_beginning=*/false);
+ // In case everything would be stripped, set the span to the original
+ // beginning and zero length.
+ if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
+ *span = {result_begin_codepoint, result_begin_codepoint};
+ } else {
+ *span = CodepointSpan({result_begin_codepoint + begin_ignored,
+ result_end_codepoint - end_ignored});
+ }
+ }
+ return true;
+}
+
+bool FeatureProcessor::LabelToTokenSpan(const int label,
+ TokenSpan* token_span) const {
+ if (label >= 0 && label < label_to_selection_.size()) {
+ *token_span = label_to_selection_[label];
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool FeatureProcessor::SpanToLabel(
+ const std::pair<CodepointIndex, CodepointIndex>& span,
+ const std::vector<Token>& tokens, int* label) const {
+ if (tokens.size() != GetNumContextTokens()) {
+ return false;
+ }
+
+ const int click_position =
+ options_->context_size(); // Click is always in the middle.
+ const int padding = options_->context_size() - options_->max_selection_span();
+
+ int span_left = 0;
+ for (int i = click_position - 1; i >= padding; i--) {
+ if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) {
+ ++span_left;
+ } else {
+ break;
+ }
+ }
+
+ int span_right = 0;
+ for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
+ if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) {
+ ++span_right;
+ } else {
+ break;
+ }
+ }
+
+ // Check that the spanned tokens cover the whole span.
+ bool tokens_match_span;
+ const CodepointIndex tokens_start = tokens[click_position - span_left].start;
+ const CodepointIndex tokens_end = tokens[click_position + span_right].end;
+ if (options_->snap_label_span_boundaries_to_containing_tokens()) {
+ tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
+ } else {
+ const UnicodeText token_left_unicode = UTF8ToUnicodeText(
+ tokens[click_position - span_left].value, /*do_copy=*/false);
+ const UnicodeText token_right_unicode = UTF8ToUnicodeText(
+ tokens[click_position + span_right].value, /*do_copy=*/false);
+
+ UnicodeText::const_iterator span_begin = token_left_unicode.begin();
+ UnicodeText::const_iterator span_end = token_right_unicode.end();
+
+ const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
+ span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
+ const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
+ token_right_unicode.begin(), span_end,
+ /*count_from_beginning=*/false);
+
+ tokens_match_span = tokens_start <= span.first &&
+ tokens_start + num_punctuation_start >= span.first &&
+ tokens_end >= span.second &&
+ tokens_end - num_punctuation_end <= span.second;
+ }
+
+ if (tokens_match_span) {
+ *label = TokenSpanToLabel({span_left, span_right});
+ } else {
+ *label = kInvalidLabel;
+ }
+
+ return true;
+}
+
+int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
+ auto it = selection_to_label_.find(span);
+ if (it != selection_to_label_.end()) {
+ return it->second;
+ } else {
+ return kInvalidLabel;
+ }
+}
+
+TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
+ CodepointSpan codepoint_span,
+ bool snap_boundaries_to_containing_tokens) {
+ const int codepoint_start = std::get<0>(codepoint_span);
+ const int codepoint_end = std::get<1>(codepoint_span);
+
+ TokenIndex start_token = kInvalidIndex;
+ TokenIndex end_token = kInvalidIndex;
+ for (int i = 0; i < selectable_tokens.size(); ++i) {
+ bool is_token_in_span;
+ if (snap_boundaries_to_containing_tokens) {
+ is_token_in_span = codepoint_start < selectable_tokens[i].end &&
+ codepoint_end > selectable_tokens[i].start;
+ } else {
+ is_token_in_span = codepoint_start <= selectable_tokens[i].start &&
+ codepoint_end >= selectable_tokens[i].end;
+ }
+ if (is_token_in_span && !selectable_tokens[i].is_padding) {
+ if (start_token == kInvalidIndex) {
+ start_token = i;
+ }
+ end_token = i + 1;
+ }
+ }
+ return {start_token, end_token};
+}
+
+CodepointSpan TokenSpanToCodepointSpan(
+ const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
+ return {selectable_tokens[token_span.first].start,
+ selectable_tokens[token_span.second - 1].end};
+}
+
+namespace {
+
+// Finds a single token that completely contains the given span.
+int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
+ CodepointSpan codepoint_span) {
+ const int codepoint_start = std::get<0>(codepoint_span);
+ const int codepoint_end = std::get<1>(codepoint_span);
+
+ for (int i = 0; i < selectable_tokens.size(); ++i) {
+ if (codepoint_start >= selectable_tokens[i].start &&
+ codepoint_end <= selectable_tokens[i].end) {
+ return i;
+ }
+ }
+ return kInvalidIndex;
+}
+
+} // namespace
+
+namespace internal {
+
+int CenterTokenFromClick(CodepointSpan span,
+ const std::vector<Token>& selectable_tokens) {
+ int range_begin;
+ int range_end;
+ std::tie(range_begin, range_end) =
+ CodepointSpanToTokenSpan(selectable_tokens, span);
+
+ // If no exact match was found, try finding a token that completely contains
+ // the click span. This is useful e.g. when Android builds the selection
+ // using ICU tokenization, and ends up with only a portion of our space-
+ // separated token. E.g. for "(857)" Android would select "857".
+ if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
+ int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
+ if (token_index != kInvalidIndex) {
+ range_begin = token_index;
+ range_end = token_index + 1;
+ }
+ }
+
+ // We only allow clicks that are exactly 1 selectable token.
+ if (range_end - range_begin == 1) {
+ return range_begin;
+ } else {
+ return kInvalidIndex;
+ }
+}
+
+int CenterTokenFromMiddleOfSelection(
+ CodepointSpan span, const std::vector<Token>& selectable_tokens) {
+ int range_begin;
+ int range_end;
+ std::tie(range_begin, range_end) =
+ CodepointSpanToTokenSpan(selectable_tokens, span);
+
+ // Center the clicked token in the selection range.
+ if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
+ return (range_begin + range_end - 1) / 2;
+ } else {
+ return kInvalidIndex;
+ }
+}
+
+} // namespace internal
+
+int FeatureProcessor::FindCenterToken(CodepointSpan span,
+ const std::vector<Token>& tokens) const {
+ if (options_->center_token_selection_method() ==
+ FeatureProcessorOptions_::
+ CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) {
+ return internal::CenterTokenFromClick(span, tokens);
+ } else if (options_->center_token_selection_method() ==
+ FeatureProcessorOptions_::
+ CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) {
+ return internal::CenterTokenFromMiddleOfSelection(span, tokens);
+ } else if (options_->center_token_selection_method() ==
+ FeatureProcessorOptions_::
+ CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) {
+ // TODO(zilka): Remove once we have new models on the device.
+ // It uses the fact that sharing model use
+ // split_tokens_on_selection_boundaries and selection not. So depending on
+ // this we select the right way of finding the click location.
+ if (!options_->split_tokens_on_selection_boundaries()) {
+ // SmartSelection model.
+ return internal::CenterTokenFromClick(span, tokens);
+ } else {
+ // SmartSharing model.
+ return internal::CenterTokenFromMiddleOfSelection(span, tokens);
+ }
+ } else {
+ TC3_LOG(ERROR) << "Invalid center token selection method.";
+ return kInvalidIndex;
+ }
+}
+
+bool FeatureProcessor::SelectionLabelSpans(
+ const VectorSpan<Token> tokens,
+ std::vector<CodepointSpan>* selection_label_spans) const {
+ for (int i = 0; i < label_to_selection_.size(); ++i) {
+ CodepointSpan span;
+ if (!LabelToSpan(i, tokens, &span)) {
+ TC3_LOG(ERROR) << "Could not convert label to span: " << i;
+ return false;
+ }
+ selection_label_spans->push_back(span);
+ }
+ return true;
+}
+
+void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
+ if (options_->ignored_span_boundary_codepoints() != nullptr) {
+ for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
+ ignored_span_boundary_codepoints_.insert(codepoint);
+ }
+ }
+}
+
+int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_start,
+ const UnicodeText::const_iterator& span_end,
+ bool count_from_beginning) const {
+ return CountIgnoredSpanBoundaryCodepoints(span_start, span_end,
+ count_from_beginning,
+ ignored_span_boundary_codepoints_);
+}
+
+int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_start,
+ const UnicodeText::const_iterator& span_end, bool count_from_beginning,
+ const std::unordered_set<int>& ignored_span_boundary_codepoints) const {
+ if (span_start == span_end) {
+ return 0;
+ }
+
+ UnicodeText::const_iterator it;
+ UnicodeText::const_iterator it_last;
+ if (count_from_beginning) {
+ it = span_start;
+ it_last = span_end;
+ // We can assume that the string is non-zero length because of the check
+ // above, thus the decrement is always valid here.
+ --it_last;
+ } else {
+ it = span_end;
+ it_last = span_start;
+ // We can assume that the string is non-zero length because of the check
+ // above, thus the decrement is always valid here.
+ --it;
+ }
+
+ // Move until we encounter a non-ignored character.
+ int num_ignored = 0;
+ while (ignored_span_boundary_codepoints.find(*it) !=
+ ignored_span_boundary_codepoints.end()) {
+ ++num_ignored;
+
+ if (it == it_last) {
+ break;
+ }
+
+ if (count_from_beginning) {
+ ++it;
+ } else {
+ --it;
+ }
+ }
+
+ return num_ignored;
+}
+
+namespace {
+
+void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
+ std::vector<UnicodeTextRange>* ranges) {
+ UnicodeText::const_iterator start = t.begin();
+ UnicodeText::const_iterator curr = start;
+ UnicodeText::const_iterator end = t.end();
+ for (; curr != end; ++curr) {
+ if (codepoints.find(*curr) != codepoints.end()) {
+ if (start != curr) {
+ ranges->push_back(std::make_pair(start, curr));
+ }
+ start = curr;
+ ++start;
+ }
+ }
+ if (start != end) {
+ ranges->push_back(std::make_pair(start, end));
+ }
+}
+
+} // namespace
+
+std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
+ const UnicodeText& context_unicode,
+ const bool use_pipe_character_for_newline) const {
+ std::vector<UnicodeTextRange> lines;
+ std::set<char32> codepoints{'\n'};
+ if (use_pipe_character_for_newline) {
+ codepoints.insert('|');
+ }
+ FindSubstrings(context_unicode, codepoints, &lines);
+ return lines;
+}
+
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+ const std::string& context, CodepointSpan span) const {
+ return StripBoundaryCodepoints(context, span,
+ ignored_span_boundary_codepoints_,
+ ignored_span_boundary_codepoints_);
+}
+
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+ const std::string& context, CodepointSpan span,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ return StripBoundaryCodepoints(context_unicode, span,
+ ignored_prefix_span_boundary_codepoints,
+ ignored_suffix_span_boundary_codepoints);
+}
+
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+ const UnicodeText& context_unicode, CodepointSpan span) const {
+ return StripBoundaryCodepoints(context_unicode, span,
+ ignored_span_boundary_codepoints_,
+ ignored_span_boundary_codepoints_);
+}
+
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+ const UnicodeText& context_unicode, CodepointSpan span,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const {
+ if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
+ return span;
+ }
+
+ UnicodeText::const_iterator span_begin = context_unicode.begin();
+ std::advance(span_begin, span.first);
+ UnicodeText::const_iterator span_end = context_unicode.begin();
+ std::advance(span_end, span.second);
+
+ return StripBoundaryCodepoints(span_begin, span_end, span,
+ ignored_prefix_span_boundary_codepoints,
+ ignored_suffix_span_boundary_codepoints);
+}
+
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span) const {
+ return StripBoundaryCodepoints(span_begin, span_end, span,
+ ignored_span_boundary_codepoints_,
+ ignored_span_boundary_codepoints_);
+}
+
+CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const {
+ if (!ValidNonEmptySpan(span) || span_begin == span_end) {
+ return span;
+ }
+
+ const int start_offset = CountIgnoredSpanBoundaryCodepoints(
+ span_begin, span_end, /*count_from_beginning=*/true,
+ ignored_prefix_span_boundary_codepoints);
+ const int end_offset = CountIgnoredSpanBoundaryCodepoints(
+ span_begin, span_end, /*count_from_beginning=*/false,
+ ignored_suffix_span_boundary_codepoints);
+
+ if (span.first + start_offset < span.second - end_offset) {
+ return {span.first + start_offset, span.second - end_offset};
+ } else {
+ return {span.first, span.first};
+ }
+}
+
+float FeatureProcessor::SupportedCodepointsRatio(
+ const TokenSpan& token_span, const std::vector<Token>& tokens) const {
+ int num_supported = 0;
+ int num_total = 0;
+ for (int i = token_span.first; i < token_span.second; ++i) {
+ const UnicodeText value =
+ UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
+ for (auto codepoint : value) {
+ if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
+ ++num_supported;
+ }
+ ++num_total;
+ }
+ }
+ // Avoid division by zero.
+ if (num_total == 0) {
+ return 0.0;
+ }
+ return static_cast<float>(num_supported) / static_cast<float>(num_total);
+}
+
+const std::string& FeatureProcessor::StripBoundaryCodepoints(
+ const std::string& value, std::string* buffer) const {
+ return StripBoundaryCodepoints(value, buffer,
+ ignored_span_boundary_codepoints_,
+ ignored_span_boundary_codepoints_);
+}
+
+const std::string& FeatureProcessor::StripBoundaryCodepoints(
+ const std::string& value, std::string* buffer,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const {
+ const UnicodeText value_unicode = UTF8ToUnicodeText(value, /*do_copy=*/false);
+ const CodepointSpan initial_span{0, value_unicode.size_codepoints()};
+ const CodepointSpan stripped_span = StripBoundaryCodepoints(
+ value_unicode, initial_span, ignored_prefix_span_boundary_codepoints,
+ ignored_suffix_span_boundary_codepoints);
+
+ if (initial_span != stripped_span) {
+ const UnicodeText stripped_token_value =
+ UnicodeText::Substring(value_unicode, stripped_span.first,
+ stripped_span.second, /*do_copy=*/false);
+ *buffer = stripped_token_value.ToUTF8String();
+ return *buffer;
+ }
+ return value;
+}
+
+int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
+ const auto it = collection_to_label_.find(collection);
+ if (it == collection_to_label_.end()) {
+ return options_->default_collection();
+ } else {
+ return it->second;
+ }
+}
+
+std::string FeatureProcessor::LabelToCollection(int label) const {
+ if (label >= 0 && label < collection_to_label_.size()) {
+ return (*options_->collections())[label]->str();
+ } else {
+ return GetDefaultCollection();
+ }
+}
+
+void FeatureProcessor::MakeLabelMaps() {
+ if (options_->collections() != nullptr) {
+ for (int i = 0; i < options_->collections()->size(); ++i) {
+ collection_to_label_[(*options_->collections())[i]->str()] = i;
+ }
+ }
+
+ int selection_label_id = 0;
+ for (int l = 0; l < (options_->max_selection_span() + 1); ++l) {
+ for (int r = 0; r < (options_->max_selection_span() + 1); ++r) {
+ if (!options_->selection_reduced_output_space() ||
+ r + l <= options_->max_selection_span()) {
+ TokenSpan token_span{l, r};
+ selection_to_label_[token_span] = selection_label_id;
+ label_to_selection_.push_back(token_span);
+ ++selection_label_id;
+ }
+ }
+ }
+}
+
+void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
+ CodepointSpan input_span,
+ bool only_use_line_with_click,
+ std::vector<Token>* tokens,
+ int* click_pos) const {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ RetokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click,
+ tokens, click_pos);
+}
+
+void FeatureProcessor::RetokenizeAndFindClick(
+ const UnicodeText& context_unicode, CodepointSpan input_span,
+ bool only_use_line_with_click, std::vector<Token>* tokens,
+ int* click_pos) const {
+ TC3_CHECK(tokens != nullptr);
+
+ if (options_->split_tokens_on_selection_boundaries()) {
+ internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
+ }
+
+ if (only_use_line_with_click) {
+ StripTokensFromOtherLines(context_unicode, input_span, tokens);
+ }
+
+ int local_click_pos;
+ if (click_pos == nullptr) {
+ click_pos = &local_click_pos;
+ }
+ *click_pos = FindCenterToken(input_span, *tokens);
+ if (*click_pos == kInvalidIndex) {
+ // If the default click method failed, let's try to do sub-token matching
+ // before we fail.
+ *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
+ }
+}
+
+namespace internal {
+
+void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
+ std::vector<Token>* tokens, int* click_pos) {
+ int right_context_needed = relative_click_span.second + context_size;
+ if (*click_pos + right_context_needed + 1 >= tokens->size()) {
+ // Pad max the context size.
+ const int num_pad_tokens = std::min(
+ context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
+ tokens->size()));
+ std::vector<Token> pad_tokens(num_pad_tokens);
+ tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
+ } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
+ // Strip unused tokens.
+ auto it = tokens->begin();
+ std::advance(it, *click_pos + right_context_needed + 1);
+ tokens->erase(it, tokens->end());
+ }
+
+ int left_context_needed = relative_click_span.first + context_size;
+ if (*click_pos < left_context_needed) {
+ // Pad max the context size.
+ const int num_pad_tokens =
+ std::min(context_size, left_context_needed - *click_pos);
+ std::vector<Token> pad_tokens(num_pad_tokens);
+ tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
+ *click_pos += num_pad_tokens;
+ } else if (*click_pos > left_context_needed) {
+ // Strip unused tokens.
+ auto it = tokens->begin();
+ std::advance(it, *click_pos - left_context_needed);
+ *click_pos -= it - tokens->begin();
+ tokens->erase(tokens->begin(), it);
+ }
+}
+
+} // namespace internal
+
+bool FeatureProcessor::HasEnoughSupportedCodepoints(
+ const std::vector<Token>& tokens, TokenSpan token_span) const {
+ if (options_->min_supported_codepoint_ratio() > 0) {
+ const float supported_codepoint_ratio =
+ SupportedCodepointsRatio(token_span, tokens);
+ if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) {
+ TC3_VLOG(1) << "Not enough supported codepoints in the context: "
+ << supported_codepoint_ratio;
+ return false;
+ }
+ }
+ return true;
+}
+
+bool FeatureProcessor::ExtractFeatures(
+ const std::vector<Token>& tokens, TokenSpan token_span,
+ CodepointSpan selection_span_for_feature,
+ const EmbeddingExecutor* embedding_executor,
+ EmbeddingCache* embedding_cache, int feature_vector_size,
+ std::unique_ptr<CachedFeatures>* cached_features) const {
+ std::unique_ptr<std::vector<float>> features(new std::vector<float>());
+ features->reserve(feature_vector_size * TokenSpanSize(token_span));
+ for (int i = token_span.first; i < token_span.second; ++i) {
+ if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature,
+ embedding_executor, embedding_cache,
+ features.get())) {
+ TC3_LOG(ERROR) << "Could not get token features.";
+ return false;
+ }
+ }
+
+ std::unique_ptr<std::vector<float>> padding_features(
+ new std::vector<float>());
+ padding_features->reserve(feature_vector_size);
+ if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature,
+ embedding_executor, embedding_cache,
+ padding_features.get())) {
+ TC3_LOG(ERROR) << "Count not get padding token features.";
+ return false;
+ }
+
+ *cached_features = CachedFeatures::Create(token_span, std::move(features),
+ std::move(padding_features),
+ options_, feature_vector_size);
+ if (!*cached_features) {
+ TC3_LOG(ERROR) << "Cound not create cached features.";
+ return false;
+ }
+
+ return true;
+}
+
+bool FeatureProcessor::AppendTokenFeaturesWithCache(
+ const Token& token, CodepointSpan selection_span_for_feature,
+ const EmbeddingExecutor* embedding_executor,
+ EmbeddingCache* embedding_cache,
+ std::vector<float>* output_features) const {
+ // Look for the embedded features for the token in the cache, if there is one.
+ if (embedding_cache) {
+ const auto it = embedding_cache->find({token.start, token.end});
+ if (it != embedding_cache->end()) {
+ // The embedded features were found in the cache, extract only the dense
+ // features.
+ std::vector<float> dense_features;
+ if (!feature_extractor_.Extract(
+ token, token.IsContainedInSpan(selection_span_for_feature),
+ /*sparse_features=*/nullptr, &dense_features)) {
+ TC3_LOG(ERROR) << "Could not extract token's dense features.";
+ return false;
+ }
+
+ // Append both embedded and dense features to the output and return.
+ output_features->insert(output_features->end(), it->second.begin(),
+ it->second.end());
+ output_features->insert(output_features->end(), dense_features.begin(),
+ dense_features.end());
+ return true;
+ }
+ }
+
+ // Extract the sparse and dense features.
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ if (!feature_extractor_.Extract(
+ token, token.IsContainedInSpan(selection_span_for_feature),
+ &sparse_features, &dense_features)) {
+ TC3_LOG(ERROR) << "Could not extract token's features.";
+ return false;
+ }
+
+ // Embed the sparse features, appending them directly to the output.
+ const int embedding_size = GetOptions()->embedding_size();
+ output_features->resize(output_features->size() + embedding_size);
+ float* output_features_end =
+ output_features->data() + output_features->size();
+ if (!embedding_executor->AddEmbedding(
+ TensorView<int>(sparse_features.data(),
+ {static_cast<int>(sparse_features.size())}),
+ /*dest=*/output_features_end - embedding_size,
+ /*dest_size=*/embedding_size)) {
+ TC3_LOG(ERROR) << "Cound not embed token's sparse features.";
+ return false;
+ }
+
+ // If there is a cache, the embedded features for the token were not in it,
+ // so insert them.
+ if (embedding_cache) {
+ (*embedding_cache)[{token.start, token.end}] = std::vector<float>(
+ output_features_end - embedding_size, output_features_end);
+ }
+
+ // Append the dense features to the output.
+ output_features->insert(output_features->end(), dense_features.begin(),
+ dense_features.end());
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/feature-processor.h b/native/annotator/feature-processor.h
new file mode 100644
index 0000000..78dbbce
--- /dev/null
+++ b/native/annotator/feature-processor.h
@@ -0,0 +1,327 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Feature processing for FFModel (feed-forward SmartSelection model).
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
+
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "annotator/cached-features.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+#include "utils/token-feature-extractor.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+constexpr int kInvalidLabel = -1;
+
+namespace internal {
+
+Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
+ const UniLib* unilib);
+
+TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
+ const FeatureProcessorOptions* options);
+
+// Splits tokens that contain the selection boundary inside them.
+// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
+void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
+ std::vector<Token>* tokens);
+
+// Returns the index of token that corresponds to the codepoint span.
+int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
+
+// Returns the index of token that corresponds to the middle of the codepoint
+// span.
+int CenterTokenFromMiddleOfSelection(
+ CodepointSpan span, const std::vector<Token>& selectable_tokens);
+
+// Strips the tokens from the tokens vector that are not used for feature
+// extraction because they are out of scope, or pads them so that there is
+// enough tokens in the required context_size for all inferences with a click
+// in relative_click_span.
+void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
+ std::vector<Token>* tokens, int* click_pos);
+
+} // namespace internal
+
+// Converts a codepoint span to a token span in the given list of tokens.
+// If snap_boundaries_to_containing_tokens is set to true, it is enough for a
+// token to overlap with the codepoint range to be considered part of it.
+// Otherwise it must be fully included in the range.
+TokenSpan CodepointSpanToTokenSpan(
+ const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
+ bool snap_boundaries_to_containing_tokens = false);
+
+// Converts a token span to a codepoint span in the given list of tokens.
+CodepointSpan TokenSpanToCodepointSpan(
+ const std::vector<Token>& selectable_tokens, TokenSpan token_span);
+
+// Takes care of preparing features for the span prediction model.
+class FeatureProcessor {
+ public:
+ // A cache mapping codepoint spans to embedded tokens features. An instance
+ // can be provided to multiple calls to ExtractFeatures() operating on the
+ // same context (the same codepoint spans corresponding to the same tokens),
+ // as an optimization. Note that the tokenizations do not have to be
+ // identical.
+ typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;
+
+ explicit FeatureProcessor(const FeatureProcessorOptions* options,
+ const UniLib* unilib)
+ : feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
+ unilib),
+ options_(options),
+ tokenizer_(internal::BuildTokenizer(options, unilib)) {
+ MakeLabelMaps();
+ if (options->supported_codepoint_ranges() != nullptr) {
+ SortCodepointRanges({options->supported_codepoint_ranges()->begin(),
+ options->supported_codepoint_ranges()->end()},
+ &supported_codepoint_ranges_);
+ }
+ PrepareIgnoredSpanBoundaryCodepoints();
+ }
+
+ // Tokenizes the input string using the selected tokenization method.
+ std::vector<Token> Tokenize(const std::string& text) const;
+
+ // Same as above but takes UnicodeText.
+ std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
+
+ // Converts a label into a token span.
+ bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
+
+ // Gets the total number of selection labels.
+ int GetSelectionLabelCount() const { return label_to_selection_.size(); }
+
+ // Gets the string value for given collection label.
+ std::string LabelToCollection(int label) const;
+
+ // Gets the total number of collections of the model.
+ int NumCollections() const { return collection_to_label_.size(); }
+
+ // Gets the name of the default collection.
+ std::string GetDefaultCollection() const;
+
+ const FeatureProcessorOptions* GetOptions() const { return options_; }
+
+ // Retokenizes the context and input span, and finds the click position.
+ // Depending on the options, might modify tokens (split them or remove them).
+ void RetokenizeAndFindClick(const std::string& context,
+ CodepointSpan input_span,
+ bool only_use_line_with_click,
+ std::vector<Token>* tokens, int* click_pos) const;
+
+ // Same as above but takes UnicodeText.
+ void RetokenizeAndFindClick(const UnicodeText& context_unicode,
+ CodepointSpan input_span,
+ bool only_use_line_with_click,
+ std::vector<Token>* tokens, int* click_pos) const;
+
+ // Returns true if the token span has enough supported codepoints (as defined
+ // in the model config) or not and model should not run.
+ bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
+ TokenSpan token_span) const;
+
+ // Extracts features as a CachedFeatures object that can be used for repeated
+ // inference over token spans in the given context.
+ bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
+ CodepointSpan selection_span_for_feature,
+ const EmbeddingExecutor* embedding_executor,
+ EmbeddingCache* embedding_cache, int feature_vector_size,
+ std::unique_ptr<CachedFeatures>* cached_features) const;
+
+ // Fills selection_label_spans with CodepointSpans that correspond to the
+ // selection labels. The CodepointSpans are based on the codepoint ranges of
+ // given tokens.
+ bool SelectionLabelSpans(
+ VectorSpan<Token> tokens,
+ std::vector<CodepointSpan>* selection_label_spans) const;
+
+ int DenseFeaturesCount() const {
+ return feature_extractor_.DenseFeaturesCount();
+ }
+
+ int EmbeddingSize() const { return options_->embedding_size(); }
+
+ // Splits context to several segments.
+ std::vector<UnicodeTextRange> SplitContext(
+ const UnicodeText& context_unicode,
+ const bool use_pipe_character_for_newline) const;
+
+ // Strips boundary codepoints from the span in context and returns the new
+ // start and end indices. If the span comprises entirely of boundary
+ // codepoints, the first index of span is returned for both indices.
+ CodepointSpan StripBoundaryCodepoints(const std::string& context,
+ CodepointSpan span) const;
+
+ // Same as previous, but also takes the ignored span boundary codepoints.
+ CodepointSpan StripBoundaryCodepoints(
+ const std::string& context, CodepointSpan span,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const;
+
+ // Same as above but takes UnicodeText.
+ CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
+ CodepointSpan span) const;
+
+ // Same as the previous, but also takes the ignored span boundary codepoints.
+ CodepointSpan StripBoundaryCodepoints(
+ const UnicodeText& context_unicode, CodepointSpan span,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const;
+
+ // Same as above but takes a pair of iterators for the span, for efficiency.
+ CodepointSpan StripBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span) const;
+
+ // Same as previous, but also takes the ignored span boundary codepoints.
+ CodepointSpan StripBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const;
+
+ // Same as above, but takes an optional buffer for saving the modified value.
+ // As an optimization, returns pointer to 'value' if nothing was stripped, or
+ // pointer to 'buffer' if something was stripped.
+ const std::string& StripBoundaryCodepoints(const std::string& value,
+ std::string* buffer) const;
+
+ // Same as previous, but also takes the ignored span boundary codepoints.
+ const std::string& StripBoundaryCodepoints(
+ const std::string& value, std::string* buffer,
+ const std::unordered_set<int>& ignored_prefix_span_boundary_codepoints,
+ const std::unordered_set<int>& ignored_suffix_span_boundary_codepoints)
+ const;
+
+ protected:
+ // Returns the class id corresponding to the given string collection
+ // identifier. There is a catch-all class id that the function returns for
+ // unknown collections.
+ int CollectionToLabel(const std::string& collection) const;
+
+ // Prepares mapping from collection names to labels.
+ void MakeLabelMaps();
+
+ // Gets the number of spannable tokens for the model.
+ //
+ // Spannable tokens are those tokens of context, which the model predicts
+ // selection spans over (i.e., there is 1:1 correspondence between the output
+ // classes of the model and each of the spannable tokens).
+ int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }
+
+ // Converts a label into a span of codepoint indices corresponding to it
+ // given output_tokens.
+ bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
+ CodepointSpan* span) const;
+
+ // Converts a span to the corresponding label given output_tokens.
+ bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
+ const std::vector<Token>& output_tokens, int* label) const;
+
+ // Converts a token span to the corresponding label.
+ int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
+
+ // Returns the ratio of supported codepoints to total number of codepoints in
+ // the given token span.
+ float SupportedCodepointsRatio(const TokenSpan& token_span,
+ const std::vector<Token>& tokens) const;
+
+ void PrepareIgnoredSpanBoundaryCodepoints();
+
+ // Counts the number of span boundary codepoints. If count_from_beginning is
+ // True, the counting will start at the span_start iterator (inclusive) and at
+ // maximum end at span_end (exclusive). If count_from_beginning is False, the
+ // counting will start from span_end (exclusive) and end at span_start
+ // (inclusive).
+ int CountIgnoredSpanBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_start,
+ const UnicodeText::const_iterator& span_end,
+ bool count_from_beginning) const;
+
+ // Same as previous, but also takes the ignored span boundary codepoints.
+ int CountIgnoredSpanBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_start,
+ const UnicodeText::const_iterator& span_end, bool count_from_beginning,
+ const std::unordered_set<int>& ignored_span_boundary_codepoints) const;
+
+ // Finds the center token index in tokens vector, using the method defined
+ // in options_.
+ int FindCenterToken(CodepointSpan span,
+ const std::vector<Token>& tokens) const;
+
+ // Removes all tokens from tokens that are not on a line (defined by calling
+ // SplitContext on the context) to which span points.
+ void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
+ std::vector<Token>* tokens) const;
+
+ // Same as above but takes UnicodeText.
+ void StripTokensFromOtherLines(const UnicodeText& context_unicode,
+ CodepointSpan span,
+ std::vector<Token>* tokens) const;
+
+ // Extracts the features of a token and appends them to the output vector.
+ // Uses the embedding cache to avoid re-extracting and re-embedding the
+ // sparse features for the same token.
+ bool AppendTokenFeaturesWithCache(const Token& token,
+ CodepointSpan selection_span_for_feature,
+ const EmbeddingExecutor* embedding_executor,
+ EmbeddingCache* embedding_cache,
+ std::vector<float>* output_features) const;
+
+ protected:
+ const TokenFeatureExtractor feature_extractor_;
+
+ // Codepoint ranges that define what codepoints are supported by the model.
+ // NOTE: Must be sorted.
+ std::vector<CodepointRangeStruct> supported_codepoint_ranges_;
+
+ private:
+ // Set of codepoints that will be stripped from beginning and end of
+ // predicted spans.
+ std::unordered_set<int32> ignored_span_boundary_codepoints_;
+
+ const FeatureProcessorOptions* const options_;
+
+ // Mapping between token selection spans and labels ids.
+ std::map<TokenSpan, int> selection_to_label_;
+ std::vector<TokenSpan> label_to_selection_;
+
+ // Mapping between collections and labels.
+ std::map<std::string, int> collection_to_label_;
+
+ Tokenizer tokenizer_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
diff --git a/native/annotator/grammar/dates/annotations/annotation-options.h b/native/annotator/grammar/dates/annotations/annotation-options.h
new file mode 100755
index 0000000..29e9939
--- /dev/null
+++ b/native/annotator/grammar/dates/annotations/annotation-options.h
@@ -0,0 +1,96 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_OPTIONS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_OPTIONS_H_
+
+#include <string>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+
+namespace libtextclassifier3 {
+
+// Options for date/datetime/date range annotations.
+struct DateAnnotationOptions {
+ // If enabled, extract special day offset like today, yesterday, etc.
+ bool enable_special_day_offset;
+
+ // If true, merge the adjacent day of week, time and date. e.g.
+ // "20/2/2016 at 8pm" is extracted as a single instance instead of two
+ // instances: "20/2/2016" and "8pm".
+ bool merge_adjacent_components;
+
+ // List the extra id of requested dates.
+ std::vector<std::string> extra_requested_dates;
+
+ // If true, try to include the preposition in the extracted annotation, e.g.
+ // "at 6pm". If it's false, only "6pm" is included. offline-actions has special
+ // requirements to include the preposition.
+ bool include_preposition;
+
+ // The base timestamp (milliseconds) which is used to convert relative time to
+ // absolute time.
+ // e.g.:
+ // base timestamp is 2016/4/25, then tomorrow will be converted to
+ // 2016/4/26.
+ // base timestamp is 2016/4/25 10:30:20am, then 1 days, 2 hours, 10 minutes
+ // and 5 seconds ago will be converted to 2016/4/24 08:20:15am
+ int64 base_timestamp_millis;
+
+ // If enabled, extract range in date annotator.
+ // input: Monday, 5-6pm
+ // If the flag is true, the extracted annotation only contains 1 range
+ // instance which is from Monday 5pm to 6pm.
+ // If the flag is false, the extracted annotation contains two date
+ // instances: "Monday" and "6pm".
+ bool enable_date_range;
+
+ // Timezone in which the input text was written
+ std::string reference_timezone;
+ // Localization params.
+ // The format of the locale lists should be "<lang_code>-<country_code>"
+ // comma-separated list of two-character language/country pairs.
+ std::string locales;
+
+ // If enabled, the annotation/rule_match priority score is used to set the
+ // priority score of the annotation.
+ // If false, the annotation priority score is set from
+ // GrammarDatetimeModel's priority_score.
+ bool use_rule_priority_score;
+
+ // If enabled, annotator will try to resolve the ambiguity by generating
+ // possible alternative interpretations of the input text
+ // e.g. '9:45' will be resolved to '9:45 AM' and '9:45 PM'.
+ bool generate_alternative_interpretations_when_ambiguous;
+
+ // Lists the ignored spans in the date string, e.g. in "12 March @12PM" the
+ // '@' can be an ignored token.
+ std::vector<std::string> ignored_spans;
+
+ // Default Constructor
+ DateAnnotationOptions()
+ : enable_special_day_offset(true),
+ merge_adjacent_components(true),
+ include_preposition(false),
+ base_timestamp_millis(0),
+ enable_date_range(false),
+ use_rule_priority_score(false),
+ generate_alternative_interpretations_when_ambiguous(false) {}
+};
+
+} // namespace libtextclassifier3
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_OPTIONS_H_
diff --git a/native/annotator/grammar/dates/annotations/annotation-util.cc b/native/annotator/grammar/dates/annotations/annotation-util.cc
new file mode 100644
index 0000000..438206f
--- /dev/null
+++ b/native/annotator/grammar/dates/annotations/annotation-util.cc
@@ -0,0 +1,101 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/dates/annotations/annotation-util.h"
+
+#include <algorithm>
+
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+int GetPropertyIndex(StringPiece name, const AnnotationData& annotation_data) {
+ for (int i = 0; i < annotation_data.properties.size(); ++i) {
+ if (annotation_data.properties[i].name == name.ToString()) {
+ return i;
+ }
+ }
+ return -1;
+}
+
+int GetPropertyIndex(StringPiece name, const Annotation& annotation) {
+ return GetPropertyIndex(name, annotation.data);
+}
+
+int GetIntProperty(StringPiece name, const Annotation& annotation) {
+ return GetIntProperty(name, annotation.data);
+}
+
+int GetIntProperty(StringPiece name, const AnnotationData& annotation_data) {
+ const int index = GetPropertyIndex(name, annotation_data);
+ if (index < 0) {
+ TC3_DCHECK_GE(index, 0)
+ << "No property with name " << name.ToString() << ".";
+ return 0;
+ }
+
+ if (annotation_data.properties.at(index).int_values.size() != 1) {
+ TC3_DCHECK_EQ(annotation_data.properties[index].int_values.size(), 1);
+ return 0;
+ }
+
+ return annotation_data.properties.at(index).int_values.at(0);
+}
+
+int AddIntProperty(StringPiece name, int value, Annotation* annotation) {
+ return AddRepeatedIntProperty(name, &value, 1, annotation);
+}
+
+int AddIntProperty(StringPiece name, int value,
+ AnnotationData* annotation_data) {
+ return AddRepeatedIntProperty(name, &value, 1, annotation_data);
+}
+
+int AddRepeatedIntProperty(StringPiece name, const int* start, int size,
+ Annotation* annotation) {
+ return AddRepeatedIntProperty(name, start, size, &annotation->data);
+}
+
+int AddRepeatedIntProperty(StringPiece name, const int* start, int size,
+ AnnotationData* annotation_data) {
+ Property property;
+ property.name = name.ToString();
+ auto first = start;
+ auto last = start + size;
+ while (first != last) {
+ property.int_values.push_back(*first);
+ first++;
+ }
+ annotation_data->properties.push_back(property);
+ return annotation_data->properties.size() - 1;
+}
+
+int AddAnnotationDataProperty(const std::string& key,
+ const AnnotationData& value,
+ AnnotationData* annotation_data) {
+ Property property;
+ property.name = key;
+ property.annotation_data_values.push_back(value);
+ annotation_data->properties.push_back(property);
+ return annotation_data->properties.size() - 1;
+}
+
+int AddAnnotationDataProperty(const std::string& key,
+ const AnnotationData& value,
+ Annotation* annotation) {
+ return AddAnnotationDataProperty(key, value, &annotation->data);
+}
+} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/annotations/annotation-util.h b/native/annotator/grammar/dates/annotations/annotation-util.h
new file mode 100644
index 0000000..e4afbfe
--- /dev/null
+++ b/native/annotator/grammar/dates/annotations/annotation-util.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_UTIL_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_UTIL_H_
+
+#include "annotator/grammar/dates/annotations/annotation.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Return the index of the property in annotation.data.properties.
+// Return -1 if the property does not exist.
+int GetPropertyIndex(StringPiece name, const Annotation& annotation);
+
+// Return the index of the property in annotation_data.properties.
+// Return -1 if the property does not exist.
+int GetPropertyIndex(StringPiece name, const AnnotationData& annotation_data);
+
+// Return the single int value for property 'name' of the annotation.
+// Returns 0 if the property does not exist or does not contain a single int
+// value.
+int GetIntProperty(StringPiece name, const Annotation& annotation);
+
+// Return the single int value for property 'name' of the annotation data.
+// Returns 0 if the property does not exist or does not contain a single int
+// value.
+int GetIntProperty(StringPiece name, const AnnotationData& annotation_data);
+
+// Add a new property with a single int value to an Annotation instance.
+// Return the index of the property.
+int AddIntProperty(StringPiece name, const int value, Annotation* annotation);
+
+// Add a new property with a single int value to an AnnotationData instance.
+// Return the index of the property.
+int AddIntProperty(StringPiece name, const int value,
+ AnnotationData* annotation_data);
+
+// Add a new property with repeated int values to an Annotation instance.
+// Return the index of the property.
+int AddRepeatedIntProperty(StringPiece name, const int* start, int size,
+ Annotation* annotation);
+
+// Add a new property with repeated int values to an AnnotationData instance.
+// Return the index of the property.
+int AddRepeatedIntProperty(StringPiece name, const int* start, int size,
+ AnnotationData* annotation_data);
+
+// Add a new property with an AnnotationData value.
+// Return the index of the property.
+int AddAnnotationDataProperty(const std::string& key,
+ const AnnotationData& value,
+ Annotation* annotation);
+
+// Add a new property with an AnnotationData value.
+// Return the index of the property.
+int AddAnnotationDataProperty(const std::string& key,
+ const AnnotationData& value,
+ AnnotationData* annotation_data);
+
+} // namespace libtextclassifier3
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_UTIL_H_
diff --git a/native/annotator/grammar/dates/annotations/annotation-util_test.cc b/native/annotator/grammar/dates/annotations/annotation-util_test.cc
new file mode 100644
index 0000000..6d25d64
--- /dev/null
+++ b/native/annotator/grammar/dates/annotations/annotation-util_test.cc
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/dates/annotations/annotation-util.h"
+
+#include "annotator/grammar/dates/annotations/annotation.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(AnnotationUtilTest, VerifyIntFunctions) {
+ Annotation annotation;
+
+ int index_key1 = AddIntProperty("key1", 1, &annotation);
+ int index_key2 = AddIntProperty("key2", 2, &annotation);
+
+ static const int kValuesKey3[] = {3, 4, 5};
+ int index_key3 =
+ AddRepeatedIntProperty("key3", kValuesKey3, /*size=*/3, &annotation);
+
+ EXPECT_EQ(2, GetIntProperty("key2", annotation));
+ EXPECT_EQ(1, GetIntProperty("key1", annotation));
+
+ EXPECT_EQ(index_key1, GetPropertyIndex("key1", annotation));
+ EXPECT_EQ(index_key2, GetPropertyIndex("key2", annotation));
+ EXPECT_EQ(index_key3, GetPropertyIndex("key3", annotation));
+ EXPECT_EQ(-1, GetPropertyIndex("invalid_key", annotation));
+}
+
+TEST(AnnotationUtilTest, VerifyAnnotationDataFunctions) {
+ Annotation annotation;
+
+ AnnotationData true_annotation_data;
+ Property true_property;
+ true_property.bool_values.push_back(true);
+ true_annotation_data.properties.push_back(true_property);
+ int index_key1 =
+ AddAnnotationDataProperty("key1", true_annotation_data, &annotation);
+
+ AnnotationData false_annotation_data;
+ Property false_property;
+ false_property.bool_values.push_back(false);
+ false_annotation_data.properties.push_back(false_property);
+ int index_key2 =
+ AddAnnotationDataProperty("key2", false_annotation_data, &annotation);
+
+ EXPECT_EQ(index_key1, GetPropertyIndex("key1", annotation));
+ EXPECT_EQ(index_key2, GetPropertyIndex("key2", annotation));
+ EXPECT_EQ(-1, GetPropertyIndex("invalid_key", annotation));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/annotations/annotation.h b/native/annotator/grammar/dates/annotations/annotation.h
new file mode 100644
index 0000000..e6ddb09
--- /dev/null
+++ b/native/annotator/grammar/dates/annotations/annotation.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_H_
+
+#include <string>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+
+namespace libtextclassifier3 {
+
+struct AnnotationData;
+
+// Define enum for each annotation.
+enum GrammarAnnotationType {
+ // Date&time like "May 1", "12:20pm", etc.
+ DATETIME = 0,
+ // Datetime range like "2pm - 3pm".
+ DATETIME_RANGE = 1,
+};
+
+struct Property {
+ // TODO(hassan): Replace the name with enum e.g. PropertyType.
+ std::string name;
+ // At most one of these will have any values.
+ std::vector<bool> bool_values;
+ std::vector<int64> int_values;
+ std::vector<double> double_values;
+ std::vector<std::string> string_values;
+ std::vector<AnnotationData> annotation_data_values;
+};
+
+struct AnnotationData {
+ // TODO(hassan): Replace it type with GrammarAnnotationType
+ std::string type;
+ std::vector<Property> properties;
+};
+
+// Represents an annotation instance.
+// TODO: Consider renaming this to AnnotationDetails.
+struct Annotation {
+ // Codepoint offsets into the original text specifying the substring of the
+ // text that was annotated.
+ int32 begin;
+ int32 end;
+
+ // Annotation priority score which can be used to resolve conflict between
+ // annotators.
+ float annotator_priority_score;
+
+ // Represents the details of the annotation instance, including the type of
+ // the annotation instance and its properties.
+ AnnotationData data;
+};
+} // namespace libtextclassifier3
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_H_
diff --git a/native/annotator/grammar/dates/cfg-datetime-annotator.cc b/native/annotator/grammar/dates/cfg-datetime-annotator.cc
new file mode 100644
index 0000000..99d3be0
--- /dev/null
+++ b/native/annotator/grammar/dates/cfg-datetime-annotator.cc
@@ -0,0 +1,139 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/dates/cfg-datetime-annotator.h"
+
+#include "annotator/datetime/utils.h"
+#include "annotator/grammar/dates/annotations/annotation-options.h"
+#include "annotator/grammar/utils.h"
+#include "utils/strings/split.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3::dates {
+namespace {
+
+static std::string GetReferenceLocale(const std::string& locales) {
+ std::vector<StringPiece> split_locales = strings::Split(locales, ',');
+ if (!split_locales.empty()) {
+ return split_locales[0].ToString();
+ }
+ return "";
+}
+
+static void InterpretParseData(const DatetimeParsedData& datetime_parsed_data,
+ const DateAnnotationOptions& options,
+ const CalendarLib& calendarlib,
+ int64* interpreted_time_ms_utc,
+ DatetimeGranularity* granularity) {
+ DatetimeGranularity local_granularity =
+ calendarlib.GetGranularity(datetime_parsed_data);
+ if (!calendarlib.InterpretParseData(
+ datetime_parsed_data, options.base_timestamp_millis,
+ options.reference_timezone, GetReferenceLocale(options.locales),
+ /*prefer_future_for_unspecified_date=*/true, interpreted_time_ms_utc,
+ granularity)) {
+ TC3_LOG(WARNING) << "Failed to extract time in millis and Granularity.";
+ // Falling back to DatetimeParsedData's finest granularity.
+ *granularity = local_granularity;
+ }
+}
+
+} // namespace
+
+CfgDatetimeAnnotator::CfgDatetimeAnnotator(
+ const UniLib* unilib, const GrammarTokenizerOptions* tokenizer_options,
+ const CalendarLib* calendar_lib, const DatetimeRules* datetime_rules,
+ const float annotator_target_classification_score,
+ const float annotator_priority_score)
+ : calendar_lib_(*calendar_lib),
+ tokenizer_(BuildTokenizer(unilib, tokenizer_options)),
+ parser_(unilib, datetime_rules),
+ annotator_target_classification_score_(
+ annotator_target_classification_score),
+ annotator_priority_score_(annotator_priority_score) {}
+
+void CfgDatetimeAnnotator::Parse(
+ const std::string& input, const DateAnnotationOptions& annotation_options,
+ const std::vector<Locale>& locales,
+ std::vector<DatetimeParseResultSpan>* results) const {
+ Parse(UTF8ToUnicodeText(input, /*do_copy=*/false), annotation_options,
+ locales, results);
+}
+
+void CfgDatetimeAnnotator::ProcessDatetimeParseResult(
+ const DateAnnotationOptions& annotation_options,
+ const DatetimeParseResult& datetime_parse_result,
+ std::vector<DatetimeParseResult>* results) const {
+ DatetimeParsedData datetime_parsed_data;
+ datetime_parsed_data.AddDatetimeComponents(
+ datetime_parse_result.datetime_components);
+
+ std::vector<DatetimeParsedData> interpretations;
+ if (annotation_options.generate_alternative_interpretations_when_ambiguous) {
+ FillInterpretations(datetime_parsed_data,
+ calendar_lib_.GetGranularity(datetime_parsed_data),
+ &interpretations);
+ } else {
+ interpretations.emplace_back(datetime_parsed_data);
+ }
+ for (const DatetimeParsedData& interpretation : interpretations) {
+ results->emplace_back();
+ interpretation.GetDatetimeComponents(&results->back().datetime_components);
+ InterpretParseData(interpretation, annotation_options, calendar_lib_,
+ &(results->back().time_ms_utc),
+ &(results->back().granularity));
+ std::sort(results->back().datetime_components.begin(),
+ results->back().datetime_components.end(),
+ [](const DatetimeComponent& a, const DatetimeComponent& b) {
+ return a.component_type > b.component_type;
+ });
+ }
+}
+
+void CfgDatetimeAnnotator::Parse(
+ const UnicodeText& input, const DateAnnotationOptions& annotation_options,
+ const std::vector<Locale>& locales,
+ std::vector<DatetimeParseResultSpan>* results) const {
+ std::vector<DatetimeParseResultSpan> grammar_datetime_parse_result_spans =
+ parser_.Parse(input.data(), tokenizer_.Tokenize(input), locales,
+ annotation_options);
+
+ for (const DatetimeParseResultSpan& grammar_datetime_parse_result_span :
+ grammar_datetime_parse_result_spans) {
+ DatetimeParseResultSpan datetime_parse_result_span;
+ datetime_parse_result_span.span.first =
+ grammar_datetime_parse_result_span.span.first;
+ datetime_parse_result_span.span.second =
+ grammar_datetime_parse_result_span.span.second;
+ datetime_parse_result_span.priority_score = annotator_priority_score_;
+ if (annotation_options.use_rule_priority_score) {
+ datetime_parse_result_span.priority_score =
+ grammar_datetime_parse_result_span.priority_score;
+ }
+ datetime_parse_result_span.target_classification_score =
+ annotator_target_classification_score_;
+ for (const DatetimeParseResult& grammar_datetime_parse_result :
+ grammar_datetime_parse_result_span.data) {
+ ProcessDatetimeParseResult(annotation_options,
+ grammar_datetime_parse_result,
+ &datetime_parse_result_span.data);
+ }
+ results->emplace_back(datetime_parse_result_span);
+ }
+}
+
+} // namespace libtextclassifier3::dates
diff --git a/native/annotator/grammar/dates/cfg-datetime-annotator.h b/native/annotator/grammar/dates/cfg-datetime-annotator.h
new file mode 100644
index 0000000..73c9b7b
--- /dev/null
+++ b/native/annotator/grammar/dates/cfg-datetime-annotator.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_CFG_DATETIME_ANNOTATOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_CFG_DATETIME_ANNOTATOR_H_
+
+#include "annotator/grammar/dates/annotations/annotation.h"
+#include "annotator/grammar/dates/dates_generated.h"
+#include "annotator/grammar/dates/parser.h"
+#include "annotator/grammar/dates/utils/annotation-keys.h"
+#include "annotator/model_generated.h"
+#include "utils/calendar/calendar.h"
+#include "utils/i18n/locale.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3::dates {
+
+// Helper class to convert the parsed datetime expression from AnnotationList
+// (List of annotation generated from Grammar rules) to DatetimeParseResultSpan.
+class CfgDatetimeAnnotator {
+ public:
+ explicit CfgDatetimeAnnotator(
+ const UniLib* unilib, const GrammarTokenizerOptions* tokenizer_options,
+ const CalendarLib* calendar_lib, const DatetimeRules* datetime_rules,
+ const float annotator_target_classification_score,
+ const float annotator_priority_score);
+
+ // CfgDatetimeAnnotator is neither copyable nor movable.
+ CfgDatetimeAnnotator(const CfgDatetimeAnnotator&) = delete;
+ CfgDatetimeAnnotator& operator=(const CfgDatetimeAnnotator&) = delete;
+
+ // Parses the dates in 'input' and fills result. Makes sure that the results
+ // do not overlap.
+ // Produces no results if the input does not contain any datetime span.
+ void Parse(const std::string& input,
+ const DateAnnotationOptions& annotation_options,
+ const std::vector<Locale>& locales,
+ std::vector<DatetimeParseResultSpan>* results) const;
+
+ // UnicodeText version of parse.
+ void Parse(const UnicodeText& input,
+ const DateAnnotationOptions& annotation_options,
+ const std::vector<Locale>& locales,
+ std::vector<DatetimeParseResultSpan>* results) const;
+
+ private:
+ void ProcessDatetimeParseResult(
+ const DateAnnotationOptions& annotation_options,
+ const DatetimeParseResult& datetime_parse_result,
+ std::vector<DatetimeParseResult>* results) const;
+
+ const CalendarLib& calendar_lib_;
+ const Tokenizer tokenizer_;
+ DateParser parser_;
+ const float annotator_target_classification_score_;
+ const float annotator_priority_score_;
+};
+
+} // namespace libtextclassifier3::dates
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_CFG_DATETIME_ANNOTATOR_H_
diff --git a/native/annotator/grammar/dates/dates.fbs b/native/annotator/grammar/dates/dates.fbs
new file mode 100755
index 0000000..6d535bc
--- /dev/null
+++ b/native/annotator/grammar/dates/dates.fbs
@@ -0,0 +1,351 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+include "annotator/grammar/dates/timezone-code.fbs";
+include "utils/grammar/rules.fbs";
+
+// Type identifiers of all non-trivial matches.
+namespace libtextclassifier3.dates;
+enum MatchType : int {
+  UNKNOWN = 0,
+
+  // Match of a date extraction rule.
+  DATETIME_RULE = 1,
+
+  // Match of a date range extraction rule.
+  DATETIME_RANGE_RULE = 2,
+
+  // Match defined by an ExtractionRule (e.g., a single time-result that is
+  // matched by a time-rule, which is ready to be output individually; with
+  // this kind of match, we can retrieve it in range rules).
+  DATETIME = 3,
+
+  // Match defined by TermValue.
+  TERM_VALUE = 4,
+
+  // Matches defined by Nonterminal.
+  NONTERMINAL = 5,
+
+  // Individual datetime components and auxiliary match types.
+  DIGITS = 6,
+  YEAR = 7,
+  MONTH = 8,
+  DAY = 9,
+  HOUR = 10,
+  MINUTE = 11,
+  SECOND = 12,
+  FRACTION_SECOND = 13,
+  DAY_OF_WEEK = 14,
+  TIME_VALUE = 15,
+  TIME_SPAN = 16,
+  TIME_ZONE_NAME = 17,
+  TIME_ZONE_OFFSET = 18,
+  TIME_PERIOD = 19,
+  RELATIVE_DATE = 20,
+  COMBINED_DIGITS = 21,
+}
+
+// Era indicator: BC ("before Christ") or AD ("anno Domini").
+namespace libtextclassifier3.dates;
+enum BCAD : int {
+  BCAD_NONE = -1,
+  BC = 0,
+  AD = 1,
+}
+
+// Days of the week, numbered from Sunday = 1 through Saturday = 7.
+namespace libtextclassifier3.dates;
+enum DayOfWeek : int {
+  DOW_NONE = -1,
+  SUNDAY = 1,
+  MONDAY = 2,
+  TUESDAY = 3,
+  WEDNESDAY = 4,
+  THURSDAY = 5,
+  FRIDAY = 6,
+  SATURDAY = 7,
+}
+
+// Named spans of the day (AM/PM, noon, midnight, ...).
+namespace libtextclassifier3.dates;
+enum TimespanCode : int {
+  TIMESPAN_CODE_NONE = -1,
+  AM = 0,
+  PM = 1,
+  NOON = 2,
+  MIDNIGHT = 3,
+
+  // English "tonight".
+  TONIGHT = 11,
+}
+
+// The datetime grammar rules.
+namespace libtextclassifier3.dates;
+table DatetimeRules {
+  // The context free grammar rules.
+  rules:grammar.RulesSet;
+
+  // Values associated with grammar rule matches.
+  extraction_rule:[ExtractionRuleParameter];
+
+  // Term-value and nonterminal-value tables referenced by the rules.
+  term_value:[TermValue];
+  nonterminal_value:[NonterminalValue];
+}
+
+// A value associated with a matched term, optionally carrying time-span or
+// time-zone-name interpretation data.
+namespace libtextclassifier3.dates;
+table TermValue {
+  value:int;
+
+  // A time segment, e.g. 10AM - 12AM.
+  time_span_spec:TimeSpanSpec;
+
+  // Time zone information representation.
+  time_zone_name_spec:TimeZoneNameSpec;
+}
+
+// Define nonterms from terms or other nonterms.
+namespace libtextclassifier3.dates;
+table NonterminalValue {
+  // Mapping value.
+  value:TermValue;
+
+  // Parameter describing formatting choices for nonterminal messages.
+  nonterminal_parameter:NonterminalParameter;
+
+  // Parameter interpreting past/future dates (e.g. "last year").
+  relative_parameter:RelativeParameter;
+
+  // Format info for nonterminals representing times.
+  time_value_parameter:TimeValueParameter;
+
+  // Parameter describing the format of time-zone info, e.g. "UTC-8".
+  time_zone_offset_parameter:TimeZoneOffsetParameter;
+}
+
+// The datetime component a relative expression refers to (e.g. "next WEEK").
+namespace libtextclassifier3.dates.RelativeParameter_;
+enum RelativeType : int {
+  NONE = 0,
+  YEAR = 1,
+  MONTH = 2,
+  DAY = 3,
+  WEEK = 4,
+  HOUR = 5,
+  MINUTE = 6,
+  SECOND = 7,
+}
+
+// Whether the relative expression points to the past or the future.
+namespace libtextclassifier3.dates.RelativeParameter_;
+enum Period : int {
+  PERIOD_UNKNOWN = 0,
+  PERIOD_PAST = 1,
+  PERIOD_FUTURE = 2,
+}
+
+// Relative interpretation.
+// Indicates which day the day of week could be; for example, "next Friday"
+// could mean either the closest upcoming Friday or the Friday of the
+// following week.
+namespace libtextclassifier3.dates.RelativeParameter_;
+enum Interpretation : int {
+  UNKNOWN = 0,
+
+  // The closest X in the past.
+  NEAREST_LAST = 1,
+
+  // The X before the closest X in the past.
+  SECOND_LAST = 2,
+
+  // The closest X in the future.
+  NEAREST_NEXT = 3,
+
+  // The X after the closest X in the future.
+  SECOND_NEXT = 4,
+
+  // X in the previous one.
+  PREVIOUS = 5,
+
+  // X in the coming one.
+  COMING = 6,
+
+  // X in current one, it can be both past and future.
+  CURRENT = 7,
+
+  // Some X.
+  SOME = 8,
+
+  // The closest X, it can be both past and future.
+  NEAREST = 9,
+}
+
+// Interpretation parameters for a relative-date nonterminal.
+namespace libtextclassifier3.dates;
+table RelativeParameter {
+  type:RelativeParameter_.RelativeType = NONE;
+  period:RelativeParameter_.Period = PERIOD_UNKNOWN;
+  day_of_week_interpretation:[RelativeParameter_.Interpretation];
+}
+
+// Flags describing how a nonterminal was written in the text.
+namespace libtextclassifier3.dates.NonterminalParameter_;
+enum Flag : int {
+  IS_SPELLED = 1,
+}
+
+// Formatting parameters for a nonterminal.
+namespace libtextclassifier3.dates;
+table NonterminalParameter {
+  // Bit-wise OR of Flag values.
+  flag:uint = 0;
+
+  // Format string for COMBINED_DIGITS matches (e.g. "YYYYMMDD"); each
+  // character denotes the datetime component of the digit at that position.
+  combined_digits_format:string (shared);
+}
+
+// Validation flags applied when parsing a time value.
+namespace libtextclassifier3.dates.TimeValueParameter_;
+enum TimeValueValidation : int {
+  // Allow extra spaces between sub-components in time-value.
+  ALLOW_EXTRA_SPACE = 1,
+  // 1 << 0
+
+  // Disallow colon- or dot-context with digits for time-value.
+  DISALLOW_COLON_DOT_CONTEXT = 2,
+  // 1 << 1
+}
+
+// Validation parameters for time-value nonterminals.
+namespace libtextclassifier3.dates;
+table TimeValueParameter {
+  validation:uint = 0;
+  // Bitwise-OR of TimeValueValidation flags.
+
+  flag:uint = 0;
+  // Bitwise-OR
+}
+
+// Recognized textual formats of a numeric time-zone offset.
+namespace libtextclassifier3.dates.TimeZoneOffsetParameter_;
+enum Format : int {
+  // Offset is in an uncategorized format.
+  FORMAT_UNKNOWN = 0,
+
+  // Offset contains 1-digit hour only, e.g. "UTC-8".
+  FORMAT_H = 1,
+
+  // Offset contains 2-digit hour only, e.g. "UTC-08".
+  FORMAT_HH = 2,
+
+  // Offset contains 1-digit hour and minute, e.g. "UTC-8:00".
+  FORMAT_H_MM = 3,
+
+  // Offset contains 2-digit hour and minute, e.g. "UTC-08:00".
+  FORMAT_HH_MM = 4,
+
+  // Offset contains 3-digit hour-and-minute, e.g. "UTC-800".
+  FORMAT_HMM = 5,
+
+  // Offset contains 4-digit hour-and-minute, e.g. "UTC-0800".
+  FORMAT_HHMM = 6,
+}
+
+namespace libtextclassifier3.dates;
+table TimeZoneOffsetParameter {
+  format:TimeZoneOffsetParameter_.Format = FORMAT_UNKNOWN;
+}
+
+// Bit flags controlling validation of an extraction-rule match.
+namespace libtextclassifier3.dates.ExtractionRuleParameter_;
+enum ExtractionValidation : int {
+  // Boundary checking for final match.
+  LEFT_BOUND = 1,
+
+  RIGHT_BOUND = 2,
+  SPELLED_YEAR = 4,
+  SPELLED_MONTH = 8,
+  SPELLED_DAY = 16,
+
+  // Without this validation-flag set, unconfident time-zone expressions
+  // are discarded in the output-callback, e.g. "-08:00, +8".
+  ALLOW_UNCONFIDENT_TIME_ZONE = 32,
+}
+
+// Parameter info for extraction rule, help rule explanation.
+namespace libtextclassifier3.dates;
+table ExtractionRuleParameter {
+  // Bit-wise OR of ExtractionValidation flags.
+  validation:uint = 0;
+
+  priority_delta:int;
+  id:string (shared);
+
+  // The score reflects the confidence score of the date/time match, which is
+  // set while creating grammar rules.
+  // e.g. given we have a rule which detects "22.33" as HH.MM, then because
+  // of ambiguity the confidence of this match may be relatively low.
+  annotator_priority_score:float;
+}
+
+// Internal structure used to describe an hour-mapping segment.
+namespace libtextclassifier3.dates.TimeSpanSpec_;
+table Segment {
+  // From 0 to 24, the beginning hour of the segment, always included.
+  begin:int;
+
+  // From 0 to 24, the ending hour of the segment, not included if the
+  // segment is not closed. The value 0 means the beginning of the next
+  // day, the same value as "begin" means a time-point.
+  end:int;
+
+  // From -24 to 24, the mapping offset in hours from spanned expressions
+  // to 24-hour expressions. The value 0 means identical mapping.
+  offset:int;
+
+  // True if the segment is a closed one instead of a half-open one.
+  // Always set it to true when describing time-points.
+  is_closed:bool = false;
+
+  // True if a strict check should be performed onto the segment which
+  // disallows already-offset hours to be used in spanned expressions,
+  // e.g. 15:30PM.
+  is_strict:bool = false;
+
+  // True if the time-span can be used without an explicitly specified
+  // hour value, then it can generate an exact time point (the "begin"
+  // o'clock sharp, like "noon") or a time range, like "Tonight".
+  is_stand_alone:bool = false;
+}
+
+// Maps a named span of the day (see TimespanCode) to its hour segments.
+namespace libtextclassifier3.dates;
+table TimeSpanSpec {
+  code:TimespanCode;
+  segment:[TimeSpanSpec_.Segment];
+}
+
+// Whether a time-zone name denotes standard time, daylight-saving time, or
+// either depending on context.
+namespace libtextclassifier3.dates.TimeZoneNameSpec_;
+enum TimeZoneType : int {
+  // The corresponding name might represent a standard or daylight-saving
+  // time-zone, depending on some external information, e.g. the date.
+  AMBIGUOUS = 0,
+
+  // The corresponding name represents a standard time-zone.
+  STANDARD = 1,
+
+  // The corresponding name represents a daylight-saving time-zone.
+  DAYLIGHT = 2,
+}
+
+// Interpretation data for a matched time-zone name.
+namespace libtextclassifier3.dates;
+table TimeZoneNameSpec {
+  code:TimezoneCode;
+  type:TimeZoneNameSpec_.TimeZoneType = AMBIGUOUS;
+
+  // Set to true if the corresponding name is internationally used as an
+  // abbreviation (or expression) of UTC. For example, "GMT" and "Z".
+  is_utc:bool = false;
+
+  // Set to false if the corresponding name is not an abbreviation. For example,
+  // "Pacific Time" and "China Standard Time".
+  is_abbreviation:bool = true;
+}
+
diff --git a/native/annotator/grammar/dates/extractor.cc b/native/annotator/grammar/dates/extractor.cc
new file mode 100644
index 0000000..d2db23e
--- /dev/null
+++ b/native/annotator/grammar/dates/extractor.cc
@@ -0,0 +1,913 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/dates/extractor.h"
+
+#include <initializer_list>
+#include <map>
+
+#include "annotator/grammar/dates/utils/date-match.h"
+#include "annotator/grammar/dates/utils/date-utils.h"
+#include "utils/base/casts.h"
+#include "utils/base/logging.h"
+#include "utils/strings/numbers.h"
+
+namespace libtextclassifier3::dates {
+namespace {
+
+// Helper struct for time-related components.
+// Extracts all subnodes of a specified type.
+struct MatchComponents {
+  // Collects all descendants of 'root' (excluding 'root' itself and untyped
+  // nodes) whose type appears in 'types', in traversal order.
+  // NOTE(review): 'types' holds int16 values but the loop below iterates
+  // them as int64; harmless, but the widths could be made consistent.
+  MatchComponents(const grammar::Match* root,
+                  std::initializer_list<int16> types)
+      : root(root),
+        components(grammar::SelectAll(
+            root, [root, &types](const grammar::Match* node) {
+              if (node == root || node->type == grammar::Match::kUnknownType) {
+                return false;
+              }
+              for (const int64 type : types) {
+                if (node->type == type) {
+                  return true;
+                }
+              }
+              return false;
+            })) {}
+
+  // Returns the index of the first submatch of the specified type at or
+  // after 'start_index', or -1 if not found.
+  int IndexOf(const int16 type, const int start_index = 0) const {
+    for (int i = start_index; i < components.size(); i++) {
+      if (components[i]->type == type) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  // Returns the first submatch of the specified type, or nullptr if not found.
+  template <typename T>
+  const T* SubmatchOf(const int16 type, const int start_index = 0) const {
+    return SubmatchAt<T>(IndexOf(type, start_index));
+  }
+
+  // Returns the submatch at 'index' downcast to T, or nullptr if 'index' is
+  // negative (e.g. the result of a failed IndexOf lookup).
+  template <typename T>
+  const T* SubmatchAt(const int index) const {
+    if (index < 0) {
+      return nullptr;
+    }
+    return static_cast<const T*>(components[index]);
+  }
+
+  const grammar::Match* root;
+  std::vector<const grammar::Match*> components;
+};
+
+// Helper method to check whether a time value has valid components.
+// NOTE(review): NO_VAL is presumably a negative "not set" sentinel -- the
+// normalization below relies on that; confirm against date-match.h.
+bool IsValidTimeValue(const TimeValueMatch& time_value) {
+  // Can only specify seconds if minutes are present.
+  if (time_value.minute == NO_VAL && time_value.second != NO_VAL) {
+    return false;
+  }
+  // Can only specify fraction of seconds if seconds are present.
+  if (time_value.second == NO_VAL && time_value.fraction_second >= 0.0) {
+    return false;
+  }
+
+  // Normalize unset components to zero for the bounds checks below.
+  const int8 h = time_value.hour;
+  const int8 m = (time_value.minute < 0 ? 0 : time_value.minute);
+  const int8 s = (time_value.second < 0 ? 0 : time_value.second);
+  const double f =
+      (time_value.fraction_second < 0.0 ? 0.0 : time_value.fraction_second);
+
+  // Check value bounds. Hour 24 and second 60 (leap second) are allowed
+  // here and constrained further below.
+  if (h == NO_VAL || h > 24 || m > 59 || s > 60) {
+    return false;
+  }
+  // "24:00" is only valid when all smaller components are zero.
+  if (h == 24 && (m != 0 || s != 0 || f > 0.0)) {
+    return false;
+  }
+  // A second value of 60 (leap second) is only valid at minute 59.
+  if (s == 60 && m != 59) {
+    return false;
+  }
+  return true;
+}
+
+// Parses a leading decimal integer from 'c_str'; returns NO_VAL on failure.
+int ParseLeadingDec32Value(const char* c_str) {
+  int value;
+  if (ParseInt32(c_str, &value)) {
+    return value;
+  }
+  return NO_VAL;
+}
+
+// Parses a leading double from 'c_str'; returns NO_VAL (implicitly
+// converted to double) on failure.
+double ParseLeadingDoubleValue(const char* c_str) {
+  double value;
+  if (ParseDouble(c_str, &value)) {
+    return value;
+  }
+  return NO_VAL;
+}
+
+// Extracts digits as an integer and adds a typed match accordingly.
+// No match is added when the digits do not parse or fail T::IsValid.
+template <typename T>
+void CheckDigits(const grammar::Match* match,
+                 const NonterminalValue* nonterminal, StringPiece match_text,
+                 grammar::Matcher* matcher) {
+  TC3_CHECK(match->IsUnaryRule());
+  const int value = ParseLeadingDec32Value(match_text.ToString().c_str());
+  if (!T::IsValid(value)) {
+    return;
+  }
+  const int num_digits = match_text.size();
+  T* result = matcher->AllocateAndInitMatch<T>(
+      match->lhs, match->codepoint_span, match->match_offset);
+  result->Reset();
+  result->nonterminal = nonterminal;
+  result->value = value;
+  result->count_of_digits = num_digits;
+  // e.g. "07" -- relevant for distinguishing zero-padded components.
+  result->is_zero_prefixed = (num_digits >= 2 && match_text[0] == '0');
+  matcher->AddMatch(result);
+}
+
+// Extracts digits as a decimal (as fraction, as if a "0." is prefixed) and
+// adds a typed match to the matcher accordingly.
+template <typename T>
+void CheckDigitsAsFraction(const grammar::Match* match,
+                           const NonterminalValue* nonterminal,
+                           StringPiece match_text, grammar::Matcher* matcher) {
+  TC3_CHECK(match->IsUnaryRule());
+  // TODO(smillius): This should be achievable in a more straight-forward way.
+  const double value =
+      ParseLeadingDoubleValue(("0." + match_text.ToString()).data());
+  if (!T::IsValid(value)) {
+    return;
+  }
+  T* result = matcher->AllocateAndInitMatch<T>(
+      match->lhs, match->codepoint_span, match->match_offset);
+  result->Reset();
+  result->nonterminal = nonterminal;
+  result->value = value;
+  result->count_of_digits = match_text.size();
+  matcher->AddMatch(result);
+}
+
+// Extracts consecutive digits as multiple integers according to a format and
+// adds a typed match to the matcher accordingly.
+// The format string (e.g. "YYYYMMDD") assigns one datetime component per
+// character; the match text must have exactly the same length.
+template <typename T>
+void CheckCombinedDigits(const grammar::Match* match,
+                         const NonterminalValue* nonterminal,
+                         StringPiece match_text, grammar::Matcher* matcher) {
+  TC3_CHECK(match->IsUnaryRule());
+  const std::string& format =
+      nonterminal->nonterminal_parameter()->combined_digits_format()->str();
+  if (match_text.size() != format.size()) {
+    return;
+  }
+
+  // Maps a format character to the component index it denotes. Allocated
+  // once and never destroyed, avoiding static destruction order issues.
+  static std::map<char, CombinedDigitsMatch::Index>& kCombinedDigitsMatchIndex =
+      *[]() {
+        return new std::map<char, CombinedDigitsMatch::Index>{
+            {'Y', CombinedDigitsMatch::INDEX_YEAR},
+            {'M', CombinedDigitsMatch::INDEX_MONTH},
+            {'D', CombinedDigitsMatch::INDEX_DAY},
+            {'h', CombinedDigitsMatch::INDEX_HOUR},
+            {'m', CombinedDigitsMatch::INDEX_MINUTE},
+            {'s', CombinedDigitsMatch::INDEX_SECOND}};
+      }();
+
+  struct Segment {
+    const int index;
+    const int length;
+    const int value;
+  };
+  std::vector<Segment> segments;
+  int slice_start = 0;
+  while (slice_start < format.size()) {
+    int slice_end = slice_start + 1;
+    // Advance right as long as we have the same format character.
+    while (slice_end < format.size() &&
+           format[slice_start] == format[slice_end]) {
+      slice_end++;
+    }
+
+    const int slice_length = slice_end - slice_start;
+    const int value = ParseLeadingDec32Value(
+        std::string(match_text.data() + slice_start, slice_length).c_str());
+
+    auto index = kCombinedDigitsMatchIndex.find(format[slice_start]);
+    if (index == kCombinedDigitsMatchIndex.end()) {
+      return;
+    }
+    if (!T::IsValid(index->second, value)) {
+      return;
+    }
+    segments.push_back(Segment{index->second, slice_length, value});
+    slice_start = slice_end;
+  }
+  T* result = matcher->AllocateAndInitMatch<T>(
+      match->lhs, match->codepoint_span, match->match_offset);
+  result->Reset();
+  result->nonterminal = nonterminal;
+  for (const Segment& segment : segments) {
+    result->values[segment.index] = segment.value;
+  }
+  result->count_of_digits = match_text.size();
+  // Zero-prefixed only if the first segment itself is zero-padded.
+  result->is_zero_prefixed =
+      (match_text[0] == '0' && segments.front().length >= 2);
+  matcher->AddMatch(result);
+}
+
+// Retrieves the corresponding value from an associated term-value mapping for
+// the nonterminal and adds a typed match to the matcher accordingly.
+// No match is added if there is no TERM_VALUE child or the value is invalid.
+template <typename T>
+void CheckMappedValue(const grammar::Match* match,
+                      const NonterminalValue* nonterminal,
+                      grammar::Matcher* matcher) {
+  const TermValueMatch* term =
+      grammar::SelectFirstOfType<TermValueMatch>(match, MatchType_TERM_VALUE);
+  if (term == nullptr) {
+    return;
+  }
+  const int value = term->term_value->value();
+  if (!T::IsValid(value)) {
+    return;
+  }
+  T* result = matcher->AllocateAndInitMatch<T>(
+      match->lhs, match->codepoint_span, match->match_offset);
+  result->Reset();
+  result->nonterminal = nonterminal;
+  result->value = value;
+  matcher->AddMatch(result);
+}
+
+// Checks if there is an associated value in the corresponding nonterminal and
+// adds a typed match to the matcher accordingly.
+// Precondition: nonterminal->value() != nullptr (the caller
+// CheckAndAddDirectOrMappedValue guarantees this).
+template <typename T>
+void CheckDirectValue(const grammar::Match* match,
+                      const NonterminalValue* nonterminal,
+                      grammar::Matcher* matcher) {
+  const int value = nonterminal->value()->value();
+  if (!T::IsValid(value)) {
+    return;
+  }
+  T* result = matcher->AllocateAndInitMatch<T>(
+      match->lhs, match->codepoint_span, match->match_offset);
+  result->Reset();
+  result->nonterminal = nonterminal;
+  result->value = value;
+  matcher->AddMatch(result);
+}
+
+// Adds a typed match using the nonterminal's direct value when present,
+// otherwise falling back to the mapped term value.
+template <typename T>
+void CheckAndAddDirectOrMappedValue(const grammar::Match* match,
+                                    const NonterminalValue* nonterminal,
+                                    grammar::Matcher* matcher) {
+  if (nonterminal->value() != nullptr) {
+    CheckDirectValue<T>(match, nonterminal, matcher);
+  } else {
+    CheckMappedValue<T>(match, nonterminal, matcher);
+  }
+}
+
+// Extracts a numeric value either from a spelled-out term (IS_SPELLED flag
+// set on the nonterminal) or from the matched digits.
+template <typename T>
+void CheckAndAddNumericValue(const grammar::Match* match,
+                             const NonterminalValue* nonterminal,
+                             StringPiece match_text,
+                             grammar::Matcher* matcher) {
+  if (nonterminal->nonterminal_parameter() != nullptr &&
+      nonterminal->nonterminal_parameter()->flag() &
+          NonterminalParameter_::Flag_IS_SPELLED) {
+    CheckMappedValue<T>(match, nonterminal, matcher);
+  } else {
+    CheckDigits<T>(match, nonterminal, match_text, matcher);
+  }
+}
+
+// Tries to parse the components as a digital time value (digit-based hour
+// plus optional minute/second/fraction). Returns true and adds a
+// TimeValueMatch on success.
+bool ParseDigitalTimeValue(const std::vector<UnicodeText::const_iterator>& text,
+                           const MatchComponents& components,
+                           const NonterminalValue* nonterminal,
+                           grammar::Matcher* matcher) {
+  // Required fields. A digital hour must consist of digits.
+  const HourMatch* hour = components.SubmatchOf<HourMatch>(MatchType_HOUR);
+  if (hour == nullptr || hour->count_of_digits == 0) {
+    return false;
+  }
+
+  // Optional fields; when present they must also be digit-based.
+  const MinuteMatch* minute =
+      components.SubmatchOf<MinuteMatch>(MatchType_MINUTE);
+  if (minute != nullptr && minute->count_of_digits == 0) {
+    return false;
+  }
+  const SecondMatch* second =
+      components.SubmatchOf<SecondMatch>(MatchType_SECOND);
+  if (second != nullptr && second->count_of_digits == 0) {
+    return false;
+  }
+  const FractionSecondMatch* fraction_second =
+      components.SubmatchOf<FractionSecondMatch>(MatchType_FRACTION_SECOND);
+  if (fraction_second != nullptr && fraction_second->count_of_digits == 0) {
+    return false;
+  }
+
+  // Validation. 'end' tracks the last present component, for the context
+  // checks below.
+  uint32 validation = nonterminal->time_value_parameter()->validation();
+  const grammar::Match* end = hour;
+  if (minute != nullptr) {
+    if (second != nullptr) {
+      if (fraction_second != nullptr) {
+        end = fraction_second;
+      } else {
+        end = second;
+      }
+    } else {
+      end = minute;
+    }
+  }
+
+  // Check if there is any extra space between h m s f.
+  if ((validation &
+       TimeValueParameter_::TimeValueValidation_ALLOW_EXTRA_SPACE) == 0) {
+    // Check whether there is whitespace between token.
+    if (minute != nullptr && minute->HasLeadingWhitespace()) {
+      return false;
+    }
+    if (second != nullptr && second->HasLeadingWhitespace()) {
+      return false;
+    }
+    if (fraction_second != nullptr && fraction_second->HasLeadingWhitespace()) {
+      return false;
+    }
+  }
+
+  // Check if there is any ':' or '.' as a prefix or suffix.
+  if (validation &
+      TimeValueParameter_::TimeValueValidation_DISALLOW_COLON_DOT_CONTEXT) {
+    const int begin_pos = hour->codepoint_span.first;
+    const int end_pos = end->codepoint_span.second;
+    // Reject e.g. the "30:45" inside "12:30:45.1" being re-parsed as a time.
+    if (begin_pos > 1 &&
+        (*text[begin_pos - 1] == ':' || *text[begin_pos - 1] == '.') &&
+        isdigit(*text[begin_pos - 2])) {
+      return false;
+    }
+    // Last valid codepoint is at text.size() - 2 as we added the end position
+    // of text for easier span extraction.
+    if (end_pos < text.size() - 2 &&
+        (*text[end_pos] == ':' || *text[end_pos] == '.') &&
+        isdigit(*text[end_pos + 1])) {
+      return false;
+    }
+  }
+
+  // Assemble the candidate time value and bounds-check it before emitting.
+  TimeValueMatch time_value;
+  time_value.Init(components.root->lhs, components.root->codepoint_span,
+                  components.root->match_offset);
+  time_value.Reset();
+  time_value.hour_match = hour;
+  time_value.minute_match = minute;
+  time_value.second_match = second;
+  time_value.fraction_second_match = fraction_second;
+  time_value.is_hour_zero_prefixed = hour->is_zero_prefixed;
+  time_value.is_minute_one_digit =
+      (minute != nullptr && minute->count_of_digits == 1);
+  time_value.is_second_one_digit =
+      (second != nullptr && second->count_of_digits == 1);
+  time_value.hour = hour->value;
+  time_value.minute = (minute != nullptr ? minute->value : NO_VAL);
+  time_value.second = (second != nullptr ? second->value : NO_VAL);
+  time_value.fraction_second =
+      (fraction_second != nullptr ? fraction_second->value : NO_VAL);
+
+  if (!IsValidTimeValue(time_value)) {
+    return false;
+  }
+
+  TimeValueMatch* result = matcher->AllocateMatch<TimeValueMatch>();
+  *result = time_value;
+  matcher->AddMatch(result);
+  return true;
+}
+
+// Tries to parse a time from spelled-out time components (zero digit counts;
+// complements ParseDigitalTimeValue). Returns true and adds a TimeValueMatch
+// on success.
+bool ParseSpelledTimeValue(const MatchComponents& components,
+                           const NonterminalValue* nonterminal,
+                           grammar::Matcher* matcher) {
+  // Required fields. A spelled hour must have no digits.
+  const HourMatch* hour = components.SubmatchOf<HourMatch>(MatchType_HOUR);
+  if (hour == nullptr || hour->count_of_digits != 0) {
+    return false;
+  }
+  // Optional fields; when present they must also be spelled out.
+  const MinuteMatch* minute =
+      components.SubmatchOf<MinuteMatch>(MatchType_MINUTE);
+  if (minute != nullptr && minute->count_of_digits != 0) {
+    return false;
+  }
+  const SecondMatch* second =
+      components.SubmatchOf<SecondMatch>(MatchType_SECOND);
+  if (second != nullptr && second->count_of_digits != 0) {
+    return false;
+  }
+
+  uint32 validation = nonterminal->time_value_parameter()->validation();
+  // Check if there is any extra space between h m s.
+  if ((validation &
+       TimeValueParameter_::TimeValueValidation_ALLOW_EXTRA_SPACE) == 0) {
+    // Check whether there is whitespace between token.
+    if (minute != nullptr && minute->HasLeadingWhitespace()) {
+      return false;
+    }
+    if (second != nullptr && second->HasLeadingWhitespace()) {
+      return false;
+    }
+  }
+
+  // Assemble the candidate time value and bounds-check it before emitting.
+  TimeValueMatch time_value;
+  time_value.Init(components.root->lhs, components.root->codepoint_span,
+                  components.root->match_offset);
+  time_value.Reset();
+  time_value.hour_match = hour;
+  time_value.minute_match = minute;
+  time_value.second_match = second;
+  time_value.is_hour_zero_prefixed = hour->is_zero_prefixed;
+  time_value.is_minute_one_digit =
+      (minute != nullptr && minute->count_of_digits == 1);
+  time_value.is_second_one_digit =
+      (second != nullptr && second->count_of_digits == 1);
+  time_value.hour = hour->value;
+  time_value.minute = (minute != nullptr ? minute->value : NO_VAL);
+  time_value.second = (second != nullptr ? second->value : NO_VAL);
+
+  if (!IsValidTimeValue(time_value)) {
+    return false;
+  }
+
+  TimeValueMatch* result = matcher->AllocateMatch<TimeValueMatch>();
+  *result = time_value;
+  matcher->AddMatch(result);
+  return true;
+}
+
+// Reconstructs and validates a time value from a match.
+// Tries the digital interpretation first, then the spelled-out one; at most
+// one match is added.
+void CheckTimeValue(const std::vector<UnicodeText::const_iterator>& text,
+                    const grammar::Match* match,
+                    const NonterminalValue* nonterminal,
+                    grammar::Matcher* matcher) {
+  MatchComponents components(
+      match, {MatchType_HOUR, MatchType_MINUTE, MatchType_SECOND,
+              MatchType_FRACTION_SECOND});
+  if (ParseDigitalTimeValue(text, components, nonterminal, matcher)) {
+    return;
+  }
+  if (ParseSpelledTimeValue(components, nonterminal, matcher)) {
+    return;
+  }
+}
+
+// Validates a time span match.
+// NOTE(review): 'ts_name' is dereferenced without a null check; unlike other
+// SelectFirstOfType call sites in this file, a missing TERM_VALUE child
+// would crash here. Presumably the grammar guarantees its presence --
+// confirm, or add a TC3_CHECK(ts_name != nullptr).
+void CheckTimeSpan(const grammar::Match* match,
+                   const NonterminalValue* nonterminal,
+                   grammar::Matcher* matcher) {
+  const TermValueMatch* ts_name =
+      grammar::SelectFirstOfType<TermValueMatch>(match, MatchType_TERM_VALUE);
+  const TermValue* term_value = ts_name->term_value;
+  TC3_CHECK(term_value != nullptr);
+  TC3_CHECK(term_value->time_span_spec() != nullptr);
+  const TimeSpanSpec* ts_spec = term_value->time_span_spec();
+  TimeSpanMatch* time_span = matcher->AllocateAndInitMatch<TimeSpanMatch>(
+      match->lhs, match->codepoint_span, match->match_offset);
+  time_span->Reset();
+  time_span->nonterminal = nonterminal;
+  time_span->time_span_spec = ts_spec;
+  time_span->time_span_code = ts_spec->code();
+  matcher->AddMatch(time_span);
+}
+
+// Validates a time period match.
+// The period value is resolved, in order of preference, from: the
+// nonterminal's direct value, a TERM_VALUE child, or a DIGITS child.
+void CheckTimePeriod(const std::vector<UnicodeText::const_iterator>& text,
+                     const grammar::Match* match,
+                     const NonterminalValue* nonterminal,
+                     grammar::Matcher* matcher) {
+  int period_value = NO_VAL;
+
+  // If a value mapping exists, use it.
+  if (nonterminal->value() != nullptr) {
+    period_value = nonterminal->value()->value();
+  } else if (const TermValueMatch* term =
+                 grammar::SelectFirstOfType<TermValueMatch>(
+                     match, MatchType_TERM_VALUE)) {
+    period_value = term->term_value->value();
+  } else if (const grammar::Match* digits =
+                 grammar::SelectFirstOfType<grammar::Match>(
+                     match, grammar::Match::kDigitsType)) {
+    // Re-parse the digits directly from the underlying text span.
+    period_value = ParseLeadingDec32Value(
+        std::string(text[digits->codepoint_span.first].utf8_data(),
+                    text[digits->codepoint_span.second].utf8_data() -
+                        text[digits->codepoint_span.first].utf8_data())
+            .c_str());
+  }
+
+  // Drop unresolved or sentinel-valued periods.
+  if (period_value <= NO_VAL) {
+    return;
+  }
+
+  TimePeriodMatch* result = matcher->AllocateAndInitMatch<TimePeriodMatch>(
+      match->lhs, match->codepoint_span, match->match_offset);
+  result->Reset();
+  result->nonterminal = nonterminal;
+  result->value = period_value;
+  matcher->AddMatch(result);
+}
+
+// Reconstructs a date from a relative date rule match.
+void CheckRelativeDate(const DateAnnotationOptions& options,
+                       const grammar::Match* match,
+                       const NonterminalValue* nonterminal,
+                       grammar::Matcher* matcher) {
+  if (!options.enable_special_day_offset &&
+      grammar::SelectFirstOfType<TermValueMatch>(match, MatchType_TERM_VALUE) !=
+          nullptr) {
+    // Special day offsets, like "Today", "Tomorrow" etc. are not enabled.
+    return;
+  }
+
+  RelativeMatch* relative_match = matcher->AllocateAndInitMatch<RelativeMatch>(
+      match->lhs, match->codepoint_span, match->match_offset);
+  relative_match->Reset();
+  relative_match->nonterminal = nonterminal;
+
+  // Fill relative date information from individual components.
+  grammar::Traverse(match, [match, relative_match](const grammar::Match* node) {
+    // Ignore the current match.
+    if (node == match || node->type == grammar::Match::kUnknownType) {
+      return true;
+    }
+
+    // Term values encode special day offsets ("today" = 0, "tomorrow" = 1,
+    // "yesterday" = -1, ...): the sign gives the direction, the magnitude
+    // the day offset.
+    if (node->type == MatchType_TERM_VALUE) {
+      const int value =
+          static_cast<const TermValueMatch*>(node)->term_value->value();
+      relative_match->day = abs(value);
+      if (value >= 0) {
+        // Marks "today" as in the future.
+        relative_match->is_future_date = true;
+      }
+      relative_match->existing |=
+          (RelativeMatch::HAS_DAY | RelativeMatch::HAS_IS_FUTURE);
+      return false;
+    }
+
+    // Parse info from nonterminal.
+    const NonterminalValue* nonterminal =
+        static_cast<const NonterminalMatch*>(node)->nonterminal;
+    if (nonterminal != nullptr &&
+        nonterminal->relative_parameter() != nullptr) {
+      const RelativeParameter* relative_parameter =
+          nonterminal->relative_parameter();
+      if (relative_parameter->period() !=
+          RelativeParameter_::Period_PERIOD_UNKNOWN) {
+        relative_match->is_future_date =
+            (relative_parameter->period() ==
+             RelativeParameter_::Period_PERIOD_FUTURE);
+        relative_match->existing |= RelativeMatch::HAS_IS_FUTURE;
+      }
+      if (relative_parameter->day_of_week_interpretation() != nullptr) {
+        relative_match->day_of_week_nonterminal = nonterminal;
+        relative_match->existing |= RelativeMatch::HAS_DAY_OF_WEEK;
+      }
+    }
+
+    // Relative day of week.
+    if (node->type == MatchType_DAY_OF_WEEK) {
+      relative_match->day_of_week =
+          static_cast<const DayOfWeekMatch*>(node)->value;
+      return false;
+    }
+
+    if (node->type != MatchType_TIME_PERIOD) {
+      return true;
+    }
+
+    // NOTE(review): the switch below dereferences
+    // nonterminal->relative_parameter() without a null check; the guard
+    // above does not return, so a TIME_PERIOD node whose nonterminal lacks
+    // a relative parameter would crash here. Presumably the grammar rules
+    // out this case -- confirm.
+    const TimePeriodMatch* period = static_cast<const TimePeriodMatch*>(node);
+    switch (nonterminal->relative_parameter()->type()) {
+      case RelativeParameter_::RelativeType_YEAR: {
+        relative_match->year = period->value;
+        relative_match->existing |= RelativeMatch::HAS_YEAR;
+        break;
+      }
+      case RelativeParameter_::RelativeType_MONTH: {
+        relative_match->month = period->value;
+        relative_match->existing |= RelativeMatch::HAS_MONTH;
+        break;
+      }
+      case RelativeParameter_::RelativeType_WEEK: {
+        relative_match->week = period->value;
+        relative_match->existing |= RelativeMatch::HAS_WEEK;
+        break;
+      }
+      case RelativeParameter_::RelativeType_DAY: {
+        relative_match->day = period->value;
+        relative_match->existing |= RelativeMatch::HAS_DAY;
+        break;
+      }
+      case RelativeParameter_::RelativeType_HOUR: {
+        relative_match->hour = period->value;
+        relative_match->existing |= RelativeMatch::HAS_HOUR;
+        break;
+      }
+      case RelativeParameter_::RelativeType_MINUTE: {
+        relative_match->minute = period->value;
+        relative_match->existing |= RelativeMatch::HAS_MINUTE;
+        break;
+      }
+      case RelativeParameter_::RelativeType_SECOND: {
+        relative_match->second = period->value;
+        relative_match->existing |= RelativeMatch::HAS_SECOND;
+        break;
+      }
+      default:
+        break;
+    }
+
+    return true;
+  });
+  matcher->AddMatch(relative_match);
+}
+
+// Checks that a UTC offset in minutes lies within the real-world range
+// [-12:00, +14:00] and is a multiple of 15 minutes.
+bool IsValidTimeZoneOffset(const int time_zone_offset) {
+  return (time_zone_offset >= -720 && time_zone_offset <= 840 &&
+          time_zone_offset % 15 == 0);
+}
+
+// Parses, validates and adds a time zone offset match.
+// Expects a sign term (+/-1), a digits component encoding hours (optionally
+// followed by a separate minutes digits component) or a combined
+// hour-and-minute value, and a nonterminal carrying the offset format.
+void CheckTimeZoneOffset(const grammar::Match* match,
+                         const NonterminalValue* nonterminal,
+                         grammar::Matcher* matcher) {
+  MatchComponents components(
+      match, {MatchType_DIGITS, MatchType_TERM_VALUE, MatchType_NONTERMINAL});
+  const TermValueMatch* tz_sign =
+      components.SubmatchOf<TermValueMatch>(MatchType_TERM_VALUE);
+  if (tz_sign == nullptr) {
+    return;
+  }
+  const int sign = tz_sign->term_value->value();
+  TC3_CHECK(sign == -1 || sign == 1);
+
+  const int tz_digits_index = components.IndexOf(MatchType_DIGITS);
+  if (tz_digits_index < 0) {
+    return;
+  }
+  const DigitsMatch* tz_digits =
+      components.SubmatchAt<DigitsMatch>(tz_digits_index);
+  if (tz_digits == nullptr) {
+    return;
+  }
+
+  int offset;
+  if (tz_digits->count_of_digits >= 3) {
+    // Combined hour-and-minute form, e.g. "-0800" or "-800".
+    offset = (tz_digits->value / 100) * 60 + (tz_digits->value % 100);
+  } else {
+    // Hour-only form; a second digits component, if any, carries minutes.
+    offset = tz_digits->value * 60;
+    if (const DigitsMatch* tz_digits_extra = components.SubmatchOf<DigitsMatch>(
+            MatchType_DIGITS, /*start_index=*/tz_digits_index + 1)) {
+      offset += tz_digits_extra->value;
+    }
+  }
+
+  const NonterminalMatch* tz_offset =
+      components.SubmatchOf<NonterminalMatch>(MatchType_NONTERMINAL);
+  if (tz_offset == nullptr) {
+    return;
+  }
+
+  const int time_zone_offset = sign * offset;
+  if (!IsValidTimeZoneOffset(time_zone_offset)) {
+    return;
+  }
+
+  TimeZoneOffsetMatch* result =
+      matcher->AllocateAndInitMatch<TimeZoneOffsetMatch>(
+          match->lhs, match->codepoint_span, match->match_offset);
+  result->Reset();
+  result->nonterminal = nonterminal;
+  result->time_zone_offset_param =
+      tz_offset->nonterminal->time_zone_offset_parameter();
+  result->time_zone_offset = time_zone_offset;
+  matcher->AddMatch(result);
+}
+
+// Validates and adds a time zone name match.
+// Expects `match` to be a unary rule whose right-hand side is a term value
+// carrying a time zone name spec (e.g. a mapped zone abbreviation).
+void CheckTimeZoneName(const grammar::Match* match,
+                       const NonterminalValue* nonterminal,
+                       grammar::Matcher* matcher) {
+  TC3_CHECK(match->IsUnaryRule());
+  const TermValueMatch* tz_name =
+      static_cast<const TermValueMatch*>(match->unary_rule_rhs());
+  if (tz_name == nullptr) {
+    return;
+  }
+  const TimeZoneNameSpec* tz_name_spec =
+      tz_name->term_value->time_zone_name_spec();
+  TimeZoneNameMatch* result = matcher->AllocateAndInitMatch<TimeZoneNameMatch>(
+      match->lhs, match->codepoint_span, match->match_offset);
+  result->Reset();
+  result->nonterminal = nonterminal;
+  result->time_zone_name_spec = tz_name_spec;
+  result->time_zone_code = tz_name_spec->code();
+  matcher->AddMatch(result);
+}
+
+// Adds a mapped term value match containing its value.
+// NOTE(review): the match is allocated via the matcher, which appears to own
+// its lifetime (arena-style) — confirm against grammar::Matcher.
+void AddTermValue(const grammar::Match* match, const TermValue* term_value,
+                  grammar::Matcher* matcher) {
+  TermValueMatch* term_match = matcher->AllocateAndInitMatch<TermValueMatch>(
+      match->lhs, match->codepoint_span, match->match_offset);
+  term_match->Reset();
+  term_match->term_value = term_value;
+  matcher->AddMatch(term_match);
+}
+
+// Adds a match for a nonterminal.
+// Copies the source match and attaches the nonterminal payload.
+void AddNonterminal(const grammar::Match* match,
+                    const NonterminalValue* nonterminal,
+                    grammar::Matcher* matcher) {
+  NonterminalMatch* result =
+      matcher->AllocateAndInitMatch<NonterminalMatch>(*match);
+  result->Reset();
+  result->nonterminal = nonterminal;
+  matcher->AddMatch(result);
+}
+
+// Adds a match for an extraction rule that is potentially used in a date range
+// rule.
+// Copies the source match and attaches the extraction rule parameters.
+void AddExtractionRuleMatch(const grammar::Match* match,
+                            const ExtractionRuleParameter* rule,
+                            grammar::Matcher* matcher) {
+  ExtractionMatch* result =
+      matcher->AllocateAndInitMatch<ExtractionMatch>(*match);
+  result->Reset();
+  result->extraction_rule = rule;
+  matcher->AddMatch(result);
+}
+
+} // namespace
+
+void DateExtractor::HandleExtractionRuleMatch(
+    const ExtractionRuleParameter* rule, const grammar::Match* match,
+    grammar::Matcher* matcher) {
+  // Rules carrying an id are opt-in: they are only kept when the id was
+  // explicitly requested via the annotation options.
+  if (rule->id() != nullptr) {
+    const std::string rule_id = rule->id()->str();
+    bool keep = false;
+    for (const std::string& extra_requested_dates_id :
+         options_.extra_requested_dates) {
+      if (extra_requested_dates_id == rule_id) {
+        keep = true;
+        break;
+      }
+    }
+    if (!keep) {
+      return;
+    }
+  }
+  // Record the rule and a copy of the match for later date generation.
+  output_.push_back(
+      Output{rule, matcher->AllocateAndInitMatch<grammar::Match>(*match)});
+}
+
+void DateExtractor::HandleRangeExtractionRuleMatch(const grammar::Match* match,
+                                                   grammar::Matcher* matcher) {
+  // Collect the two datetime roots that make up the range.
+  std::vector<const grammar::Match*> parts;
+  grammar::Traverse(match, [match, &parts](const grammar::Match* node) {
+    if (node == match || node->type == grammar::Match::kUnknownType) {
+      // Just continue traversing the match.
+      return true;
+    }
+
+    // Collect, but don't expand the individual datetime nodes.
+    parts.push_back(node);
+    return false;
+  });
+  // A range rule must produce exactly a `from` and a `to` endpoint.
+  TC3_CHECK_EQ(parts.size(), 2);
+  range_output_.push_back(
+      RangeOutput{matcher->AllocateAndInitMatch<grammar::Match>(*match),
+                  /*from=*/parts[0], /*to=*/parts[1]});
+}
+
+void DateExtractor::MatchFound(const grammar::Match* match,
+                               const grammar::CallbackId type,
+                               const int64 value, grammar::Matcher* matcher) {
+  // First handle the match types whose payload is resolved by `value`
+  // directly (rules and mapped terms); these return without falling through.
+  switch (type) {
+    case MatchType_DATETIME_RULE: {
+      HandleExtractionRuleMatch(
+          /*rule=*/
+          datetime_rules_->extraction_rule()->Get(value), match, matcher);
+      return;
+    }
+    case MatchType_DATETIME_RANGE_RULE: {
+      HandleRangeExtractionRuleMatch(match, matcher);
+      return;
+    }
+    case MatchType_DATETIME: {
+      // If an extraction rule is also part of a range extraction rule, then the
+      // extraction rule is treated as a rule match and nonterminal match.
+      // This type is used to match the rule as non terminal.
+      AddExtractionRuleMatch(
+          match, datetime_rules_->extraction_rule()->Get(value), matcher);
+      return;
+    }
+    case MatchType_TERM_VALUE: {
+      // Handle mapped terms.
+      AddTermValue(match, datetime_rules_->term_value()->Get(value), matcher);
+      return;
+    }
+    default:
+      break;
+  }
+
+  // Handle non-terminals.
+  const NonterminalValue* nonterminal =
+      datetime_rules_->nonterminal_value()->Get(value);
+  // UTF-8 text spanned by the match, reconstructed from the codepoint
+  // iterators of the underlying text.
+  StringPiece match_text =
+      StringPiece(text_[match->codepoint_span.first].utf8_data(),
+                  text_[match->codepoint_span.second].utf8_data() -
+                      text_[match->codepoint_span.first].utf8_data());
+  // Dispatch to the type-specific validation/extraction helper; each helper
+  // only adds a match if its checks pass.
+  switch (type) {
+    case MatchType_NONTERMINAL:
+      AddNonterminal(match, nonterminal, matcher);
+      break;
+    case MatchType_DIGITS:
+      CheckDigits<DigitsMatch>(match, nonterminal, match_text, matcher);
+      break;
+    case MatchType_YEAR:
+      CheckDigits<YearMatch>(match, nonterminal, match_text, matcher);
+      break;
+    case MatchType_MONTH:
+      CheckAndAddNumericValue<MonthMatch>(match, nonterminal, match_text,
+                                          matcher);
+      break;
+    case MatchType_DAY:
+      CheckAndAddNumericValue<DayMatch>(match, nonterminal, match_text,
+                                        matcher);
+      break;
+    case MatchType_DAY_OF_WEEK:
+      CheckAndAddDirectOrMappedValue<DayOfWeekMatch>(match, nonterminal,
+                                                     matcher);
+      break;
+    case MatchType_HOUR:
+      CheckAndAddNumericValue<HourMatch>(match, nonterminal, match_text,
+                                         matcher);
+      break;
+    case MatchType_MINUTE:
+      CheckAndAddNumericValue<MinuteMatch>(match, nonterminal, match_text,
+                                           matcher);
+      break;
+    case MatchType_SECOND:
+      CheckAndAddNumericValue<SecondMatch>(match, nonterminal, match_text,
+                                           matcher);
+      break;
+    case MatchType_FRACTION_SECOND:
+      CheckDigitsAsFraction<FractionSecondMatch>(match, nonterminal, match_text,
+                                                 matcher);
+      break;
+    case MatchType_TIME_VALUE:
+      CheckTimeValue(text_, match, nonterminal, matcher);
+      break;
+    case MatchType_TIME_SPAN:
+      CheckTimeSpan(match, nonterminal, matcher);
+      break;
+    case MatchType_TIME_ZONE_NAME:
+      CheckTimeZoneName(match, nonterminal, matcher);
+      break;
+    case MatchType_TIME_ZONE_OFFSET:
+      CheckTimeZoneOffset(match, nonterminal, matcher);
+      break;
+    case MatchType_TIME_PERIOD:
+      CheckTimePeriod(text_, match, nonterminal, matcher);
+      break;
+    case MatchType_RELATIVE_DATE:
+      CheckRelativeDate(options_, match, nonterminal, matcher);
+      break;
+    case MatchType_COMBINED_DIGITS:
+      CheckCombinedDigits<CombinedDigitsMatch>(match, nonterminal, match_text,
+                                               matcher);
+      break;
+    default:
+      TC3_VLOG(ERROR) << "Unhandled match type: " << type;
+  }
+}
+
+} // namespace libtextclassifier3::dates
diff --git a/native/annotator/grammar/dates/extractor.h b/native/annotator/grammar/dates/extractor.h
new file mode 100644
index 0000000..58c8880
--- /dev/null
+++ b/native/annotator/grammar/dates/extractor.h
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_EXTRACTOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_EXTRACTOR_H_
+
+#include <vector>
+
+#include "annotator/grammar/dates/annotations/annotation-options.h"
+#include "annotator/grammar/dates/dates_generated.h"
+#include "utils/base/integral_types.h"
+#include "utils/grammar/callback-delegate.h"
+#include "utils/grammar/match.h"
+#include "utils/grammar/matcher.h"
+#include "utils/grammar/types.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3::dates {
+
+// A helper class for the datetime parser that extracts structured data from
+// the datetime grammar matches.
+// It handles simple sanity checking of the rule matches and interacts with the
+// grammar matcher to extract all datetime occurrences in a text.
+class DateExtractor : public grammar::CallbackDelegate {
+ public:
+  // Represents a date match for an extraction rule.
+  struct Output {
+    const ExtractionRuleParameter* rule = nullptr;
+    const grammar::Match* match = nullptr;
+  };
+
+  // Represents a date match from a range extraction rule.
+  struct RangeOutput {
+    const grammar::Match* match = nullptr;
+    const grammar::Match* from = nullptr;
+    const grammar::Match* to = nullptr;
+  };
+
+  // `text`, `options` and `datetime_rules` are stored by reference/pointer
+  // and must outlive this instance.
+  DateExtractor(const std::vector<UnicodeText::const_iterator>& text,
+                const DateAnnotationOptions& options,
+                const DatetimeRules* datetime_rules)
+      : text_(text), options_(options), datetime_rules_(datetime_rules) {}
+
+  // Handle a rule match in the date time grammar.
+  // This checks the type of the match and does type dependent checks.
+  void MatchFound(const grammar::Match* match, grammar::CallbackId type,
+                  int64 value, grammar::Matcher* matcher) override;
+
+  // Accessors for the accumulated extraction results.
+  const std::vector<Output>& output() const { return output_; }
+  const std::vector<RangeOutput>& range_output() const { return range_output_; }
+
+ private:
+  // Extracts a date from a root rule match.
+  void HandleExtractionRuleMatch(const ExtractionRuleParameter* rule,
+                                 const grammar::Match* match,
+                                 grammar::Matcher* matcher);
+
+  // Extracts a date range from a root rule match.
+  void HandleRangeExtractionRuleMatch(const grammar::Match* match,
+                                      grammar::Matcher* matcher);
+
+  const std::vector<UnicodeText::const_iterator>& text_;
+  const DateAnnotationOptions& options_;
+  const DatetimeRules* datetime_rules_;
+
+  // Extraction results.
+  std::vector<Output> output_;
+  std::vector<RangeOutput> range_output_;
+};
+
+} // namespace libtextclassifier3::dates
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_EXTRACTOR_H_
diff --git a/native/annotator/grammar/dates/parser.cc b/native/annotator/grammar/dates/parser.cc
new file mode 100644
index 0000000..37e65fc
--- /dev/null
+++ b/native/annotator/grammar/dates/parser.cc
@@ -0,0 +1,794 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/dates/parser.h"
+
+#include "annotator/grammar/dates/extractor.h"
+#include "annotator/grammar/dates/utils/date-match.h"
+#include "annotator/grammar/dates/utils/date-utils.h"
+#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+#include "utils/base/macros.h"
+#include "utils/grammar/lexer.h"
+#include "utils/grammar/matcher.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/types.h"
+#include "utils/strings/split.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3::dates {
+namespace {
+
+// Helper methods to validate individual components from a date match.
+
+// Checks the validation requirement of a rule against a match.
+// For example if the rule asks for `SPELLED_MONTH`, then we check that the
+// match has the right flag.
+// Returns true when the rule imposes no such requirement; otherwise returns
+// whether the match's nonterminal parameter carries the required flag.
+// NOTE(review): `match` is downcast to NonterminalMatch without a type check —
+// callers must only pass nonterminal-backed nodes when validation is set.
+bool CheckMatchValidationAndFlag(
+    const grammar::Match* match, const ExtractionRuleParameter* rule,
+    const ExtractionRuleParameter_::ExtractionValidation validation,
+    const NonterminalParameter_::Flag flag) {
+  if (rule == nullptr || (rule->validation() & validation) == 0) {
+    // No validation requirement.
+    return true;
+  }
+  const NonterminalParameter* nonterminal_parameter =
+      static_cast<const NonterminalMatch*>(match)
+          ->nonterminal->nonterminal_parameter();
+  return (nonterminal_parameter != nullptr &&
+          (nonterminal_parameter->flag() & flag) != 0);
+}
+
+// Assembles a DateMatch from the components found below a rule match.
+// Returns false if any component fails the rule's validation requirements.
+bool GenerateDate(const ExtractionRuleParameter* rule,
+                  const grammar::Match* match, DateMatch* date) {
+  bool is_valid = true;
+
+  // Post check and assign date components.
+  // In the traversal callback, returning true expands the node's children,
+  // returning false stops descending below the node.
+  grammar::Traverse(match, [rule, date, &is_valid](const grammar::Match* node) {
+    switch (node->type) {
+      case MatchType_YEAR: {
+        if (CheckMatchValidationAndFlag(
+                node, rule,
+                ExtractionRuleParameter_::ExtractionValidation_SPELLED_YEAR,
+                NonterminalParameter_::Flag_IS_SPELLED)) {
+          date->year_match = static_cast<const YearMatch*>(node);
+          date->year = date->year_match->value;
+        } else {
+          is_valid = false;
+        }
+        break;
+      }
+      case MatchType_MONTH: {
+        if (CheckMatchValidationAndFlag(
+                node, rule,
+                ExtractionRuleParameter_::ExtractionValidation_SPELLED_MONTH,
+                NonterminalParameter_::Flag_IS_SPELLED)) {
+          date->month_match = static_cast<const MonthMatch*>(node);
+          date->month = date->month_match->value;
+        } else {
+          is_valid = false;
+        }
+        break;
+      }
+      case MatchType_DAY: {
+        if (CheckMatchValidationAndFlag(
+                node, rule,
+                ExtractionRuleParameter_::ExtractionValidation_SPELLED_DAY,
+                NonterminalParameter_::Flag_IS_SPELLED)) {
+          date->day_match = static_cast<const DayMatch*>(node);
+          date->day = date->day_match->value;
+        } else {
+          is_valid = false;
+        }
+        break;
+      }
+      case MatchType_DAY_OF_WEEK: {
+        date->day_of_week_match = static_cast<const DayOfWeekMatch*>(node);
+        date->day_of_week =
+            static_cast<DayOfWeek>(date->day_of_week_match->value);
+        break;
+      }
+      case MatchType_TIME_VALUE: {
+        // Time values are pre-aggregated; copy all fields and stop descending.
+        date->time_value_match = static_cast<const TimeValueMatch*>(node);
+        date->hour = date->time_value_match->hour;
+        date->minute = date->time_value_match->minute;
+        date->second = date->time_value_match->second;
+        date->fraction_second = date->time_value_match->fraction_second;
+        return false;
+      }
+      case MatchType_TIME_SPAN: {
+        date->time_span_match = static_cast<const TimeSpanMatch*>(node);
+        date->time_span_code = date->time_span_match->time_span_code;
+        return false;
+      }
+      case MatchType_TIME_ZONE_NAME: {
+        date->time_zone_name_match =
+            static_cast<const TimeZoneNameMatch*>(node);
+        date->time_zone_code = date->time_zone_name_match->time_zone_code;
+        return false;
+      }
+      case MatchType_TIME_ZONE_OFFSET: {
+        date->time_zone_offset_match =
+            static_cast<const TimeZoneOffsetMatch*>(node);
+        date->time_zone_offset = date->time_zone_offset_match->time_zone_offset;
+        return false;
+      }
+      case MatchType_RELATIVE_DATE: {
+        date->relative_match = static_cast<const RelativeMatch*>(node);
+        return false;
+      }
+      case MatchType_COMBINED_DIGITS: {
+        // Combined digits encode several fields at once; copy whichever are
+        // present.
+        date->combined_digits_match =
+            static_cast<const CombinedDigitsMatch*>(node);
+        if (date->combined_digits_match->HasYear()) {
+          date->year = date->combined_digits_match->GetYear();
+        }
+        if (date->combined_digits_match->HasMonth()) {
+          date->month = date->combined_digits_match->GetMonth();
+        }
+        if (date->combined_digits_match->HasDay()) {
+          date->day = date->combined_digits_match->GetDay();
+        }
+        if (date->combined_digits_match->HasHour()) {
+          date->hour = date->combined_digits_match->GetHour();
+        }
+        if (date->combined_digits_match->HasMinute()) {
+          date->minute = date->combined_digits_match->GetMinute();
+        }
+        if (date->combined_digits_match->HasSecond()) {
+          date->second = date->combined_digits_match->GetSecond();
+        }
+        return false;
+      }
+      default:
+        // Expand node further.
+        return true;
+    }
+
+    return false;
+  });
+
+  if (is_valid) {
+    date->begin = match->codepoint_span.first;
+    date->end = match->codepoint_span.second;
+    date->priority = rule ? rule->priority_delta() : 0;
+    date->annotator_priority_score =
+        rule ? rule->annotator_priority_score() : 0.0;
+  }
+  return is_valid;
+}
+
+// Generates a DateMatch for one endpoint of a range. The endpoint's own
+// extraction rule is only available when it was matched as MatchType_DATETIME.
+bool GenerateFromOrToDateRange(const grammar::Match* match, DateMatch* date) {
+  return GenerateDate(
+      /*rule=*/(
+          match->type == MatchType_DATETIME
+              ? static_cast<const ExtractionMatch*>(match)->extraction_rule
+              : nullptr),
+      match, date);
+}
+
+// Generates a DateRangeMatch from a range match and its two endpoints.
+// Returns false if either endpoint fails date generation.
+bool GenerateDateRange(const grammar::Match* match, const grammar::Match* from,
+                       const grammar::Match* to, DateRangeMatch* date_range) {
+  if (!GenerateFromOrToDateRange(from, &date_range->from)) {
+    TC3_LOG(WARNING) << "Failed to generate date for `from`.";
+    return false;
+  }
+  if (!GenerateFromOrToDateRange(to, &date_range->to)) {
+    TC3_LOG(WARNING) << "Failed to generate date for `to`.";
+    return false;
+  }
+  date_range->begin = match->codepoint_span.first;
+  date_range->end = match->codepoint_span.second;
+  return true;
+}
+
+// Adjusts the hour of `date` according to its time span (e.g. am/pm), if any.
+// Returns whether the adjustment succeeded (trivially true without a span).
+bool NormalizeHour(DateMatch* date) {
+  if (date->time_span_match == nullptr) {
+    // Nothing to do.
+    return true;
+  }
+  return NormalizeHourByTimeSpan(date->time_span_match->time_span_spec, date);
+}
+
+// Marks hours 1..12 without an am/pm marker (and without a zero-prefixed
+// hour like "07:30") as ambiguous: two interpretations, 12 hours apart.
+void CheckAndSetAmbiguousHour(DateMatch* date) {
+  if (date->HasHour()) {
+    // Use am-pm ambiguity as default.
+    if (!date->HasTimeSpanCode() && date->hour >= 1 && date->hour <= 12 &&
+        !(date->time_value_match != nullptr &&
+          date->time_value_match->hour_match != nullptr &&
+          date->time_value_match->hour_match->is_zero_prefixed)) {
+      date->SetAmbiguousHourProperties(2, 12);
+    }
+  }
+}
+
+// Normalizes a date candidate.
+// Returns whether the candidate was successfully normalized.
+// Note: an ill-formed date (IsValid() == false) is only logged here; it still
+// returns true.
+bool NormalizeDate(DateMatch* date) {
+  // Normalize hour.
+  if (!NormalizeHour(date)) {
+    TC3_VLOG(ERROR) << "Hour normalization (according to time-span) failed."
+                    << date->DebugString();
+    return false;
+  }
+  CheckAndSetAmbiguousHour(date);
+  if (!date->IsValid()) {
+    TC3_VLOG(ERROR) << "Fields inside date instance are ill-formed "
+                    << date->DebugString();
+  }
+  return true;
+}
+
+// Copies the field from one DateMatch to another whose field is null. For
+// example: if the from is "May 1, 8pm", and the to is "9pm", "May 1" will be
+// copied to "to". Now we only copy fields for the date range requirement.
+void CopyFieldsForDateMatch(const DateMatch& from, DateMatch* to) {
+  if (from.time_span_match != nullptr && to->time_span_match == nullptr) {
+    to->time_span_match = from.time_span_match;
+    to->time_span_code = from.time_span_code;
+  }
+  if (from.month_match != nullptr && to->month_match == nullptr) {
+    to->month_match = from.month_match;
+    to->month = from.month;
+  }
+}
+
+// Normalizes a date range candidate.
+// Returns whether the date range was successfully normalized.
+// Missing fields are first filled in from the opposite endpoint in both
+// directions before each endpoint is normalized.
+bool NormalizeDateRange(DateRangeMatch* date_range) {
+  CopyFieldsForDateMatch(date_range->from, &date_range->to);
+  CopyFieldsForDateMatch(date_range->to, &date_range->from);
+  return (NormalizeDate(&date_range->from) && NormalizeDate(&date_range->to));
+}
+
+// Applies rule-dependent sanity checks to a generated date.
+// Returns true if the date should be kept as a candidate.
+bool CheckDate(const DateMatch& date, const ExtractionRuleParameter* rule) {
+  // It's possible that "time_zone_name_match == NULL" when
+  // "HasTimeZoneCode() == true", or "time_zone_offset_match == NULL" when
+  // "HasTimeZoneOffset() == true" due to inference between endpoints, so we
+  // must check if they really exist before using them.
+  if (date.HasTimeZoneOffset()) {
+    if (date.HasTimeZoneCode()) {
+      if (date.time_zone_name_match != nullptr) {
+        TC3_CHECK(date.time_zone_name_match->time_zone_name_spec != nullptr);
+        const TimeZoneNameSpec* spec =
+            date.time_zone_name_match->time_zone_name_spec;
+        // With both a name and an offset, only UTC abbreviations are trusted.
+        if (!spec->is_utc()) {
+          return false;
+        }
+        if (!spec->is_abbreviation()) {
+          return false;
+        }
+      }
+    } else if (date.time_zone_offset_match != nullptr) {
+      TC3_CHECK(date.time_zone_offset_match->time_zone_offset_param != nullptr);
+      const TimeZoneOffsetParameter* param =
+          date.time_zone_offset_match->time_zone_offset_param;
+      // Hour-only offsets (e.g. "+8") are too ambiguous to keep.
+      if (param->format() == TimeZoneOffsetParameter_::Format_FORMAT_H ||
+          param->format() == TimeZoneOffsetParameter_::Format_FORMAT_HH) {
+        return false;
+      }
+      if (!(rule->validation() &
+            ExtractionRuleParameter_::
+                ExtractionValidation_ALLOW_UNCONFIDENT_TIME_ZONE)) {
+        // Without the opt-in, also drop the less confident H:MM style formats.
+        if (param->format() == TimeZoneOffsetParameter_::Format_FORMAT_H_MM ||
+            param->format() == TimeZoneOffsetParameter_::Format_FORMAT_HH_MM ||
+            param->format() == TimeZoneOffsetParameter_::Format_FORMAT_HMM) {
+          return false;
+        }
+      }
+    }
+  }
+
+  // Case: 1 April could be extracted as year 1, month april.
+  // We simply remove this case.
+  if (!date.HasBcAd() && date.year_match != nullptr && date.year < 1000) {
+    // We allow case like 11/5/01
+    if (date.HasMonth() && date.HasDay() &&
+        date.year_match->count_of_digits == 2) {
+    } else {
+      return false;
+    }
+  }
+
+  // Ignore the date if the year is larger than 9999 (The maximum number of 4
+  // digits).
+  if (date.year_match != nullptr && date.year > 9999) {
+    TC3_VLOG(ERROR) << "Year is greater than 9999.";
+    return false;
+  }
+
+  // Case: spelled may could be month 5, it also used very common as modal
+  // verbs. We ignore spelled may as month.
+  if ((rule->validation() &
+       ExtractionRuleParameter_::ExtractionValidation_SPELLED_MONTH) &&
+      date.month == 5 && !date.HasYear() && !date.HasDay()) {
+    return false;
+  }
+
+  return true;
+}
+
+// Checks the characters adjacent to the matched span against the rule's
+// LEFT_BOUND / RIGHT_BOUND validation requirements.
+// `text` holds codepoint iterators into the input, plus a sentinel end
+// position appended after the last codepoint.
+bool CheckContext(const std::vector<UnicodeText::const_iterator>& text,
+                  const DateExtractor::Output& output) {
+  const uint32 validation = output.rule->validation();
+
+  // Nothing to check if we don't have any validation requirements for the
+  // span boundaries.
+  if ((validation &
+       (ExtractionRuleParameter_::ExtractionValidation_LEFT_BOUND |
+        ExtractionRuleParameter_::ExtractionValidation_RIGHT_BOUND)) == 0) {
+    return true;
+  }
+
+  const int begin = output.match->codepoint_span.first;
+  const int end = output.match->codepoint_span.second;
+
+  // So far, we only check that the adjacent character cannot be a separator,
+  // like /, - or .
+  if ((validation &
+       ExtractionRuleParameter_::ExtractionValidation_LEFT_BOUND) != 0) {
+    if (begin > 0 && (*text[begin - 1] == '/' || *text[begin - 1] == '-' ||
+                      *text[begin - 1] == ':')) {
+      return false;
+    }
+  }
+  if ((validation &
+       ExtractionRuleParameter_::ExtractionValidation_RIGHT_BOUND) != 0) {
+    // Last valid codepoint is at text.size() - 2 as we added the end position
+    // of text for easier span extraction.
+    if (end < text.size() - 1 &&
+        (*text[end] == '/' || *text[end] == '-' || *text[end] == ':')) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+// Validates a date match. Returns true if the candidate is valid.
+// Combines the rule-specific date checks with the span boundary checks.
+bool ValidateDate(const std::vector<UnicodeText::const_iterator>& text,
+                  const DateExtractor::Output& output, const DateMatch& date) {
+  if (!CheckDate(date, output.rule)) {
+    return false;
+  }
+  if (!CheckContext(text, output)) {
+    return false;
+  }
+  return true;
+}
+
+// Builds matched date instances from the grammar output.
+// Candidates that fail generation, normalization or validation are dropped.
+std::vector<DateMatch> BuildDateMatches(
+    const std::vector<UnicodeText::const_iterator>& text,
+    const std::vector<DateExtractor::Output>& outputs) {
+  std::vector<DateMatch> result;
+  for (const DateExtractor::Output& output : outputs) {
+    DateMatch date;
+    if (GenerateDate(output.rule, output.match, &date)) {
+      if (!NormalizeDate(&date)) {
+        continue;
+      }
+      if (!ValidateDate(text, output, date)) {
+        continue;
+      }
+      result.push_back(date);
+    }
+  }
+  return result;
+}
+
+// Builds matched date range instances from the grammar output.
+// Ranges that fail generation or normalization are dropped.
+std::vector<DateRangeMatch> BuildDateRangeMatches(
+    const std::vector<UnicodeText::const_iterator>& text,
+    const std::vector<DateExtractor::RangeOutput>& range_outputs) {
+  std::vector<DateRangeMatch> result;
+  for (const DateExtractor::RangeOutput& range_output : range_outputs) {
+    DateRangeMatch date_range;
+    if (GenerateDateRange(range_output.match, range_output.from,
+                          range_output.to, &date_range)) {
+      if (!NormalizeDateRange(&date_range)) {
+        continue;
+      }
+      result.push_back(date_range);
+    }
+  }
+  return result;
+}
+
+// Compacts `matches` in place, dropping the entries whose index is flagged in
+// `removed` while preserving the relative order of the survivors.
+template <typename T>
+void RemoveDeletedMatches(const std::vector<bool>& removed,
+                          std::vector<T>* matches) {
+  int input = 0;
+  for (int next = 0; next < matches->size(); ++next) {
+    if (removed[next]) {
+      continue;
+    }
+    if (input != next) {
+      (*matches)[input] = (*matches)[next];
+    }
+    input++;
+  }
+  matches->resize(input);
+}
+
+// Removes duplicated date or date range instances.
+// Overlapping date and date ranges are not considered here.
+// For equal spans the lower-priority match loses; a match whose span is fully
+// contained in an earlier one, or that is a refinement of it, is dropped.
+template <typename T>
+void RemoveDuplicatedDates(std::vector<T>* matches) {
+  // Assumption: matches are sorted ascending by (begin, end).
+  std::vector<bool> removed(matches->size(), false);
+  for (int i = 0; i < matches->size(); i++) {
+    if (removed[i]) {
+      continue;
+    }
+    const T& candidate = matches->at(i);
+    for (int j = i + 1; j < matches->size(); j++) {
+      if (removed[j]) {
+        continue;
+      }
+      const T& next = matches->at(j);
+
+      // Not overlapping.
+      if (next.begin >= candidate.end) {
+        break;
+      }
+
+      // If matching the same span of text, then check the priority.
+      if (candidate.begin == next.begin && candidate.end == next.end) {
+        if (candidate.GetPriority() < next.GetPriority()) {
+          removed[i] = true;
+          break;
+        } else {
+          removed[j] = true;
+          continue;
+        }
+      }
+
+      // Checks if `next` is fully covered by fields of `candidate`.
+      if (next.end <= candidate.end) {
+        removed[j] = true;
+        continue;
+      }
+
+      // Checks whether `candidate`/`next` is a refinement.
+      if (IsRefinement(candidate, next)) {
+        removed[j] = true;
+        continue;
+      } else if (IsRefinement(next, candidate)) {
+        removed[i] = true;
+        break;
+      }
+    }
+  }
+  RemoveDeletedMatches(removed, matches);
+}
+
+// Filters out simple overtriggering matches.
+// Only three-codepoint spans are considered; they are lower-cased and
+// compared against known over-triggering abbreviations.
+bool IsBlacklistedDate(const UniLib& unilib,
+                       const std::vector<UnicodeText::const_iterator>& text,
+                       const DateMatch& match) {
+  const int begin = match.begin;
+  const int end = match.end;
+  if (end - begin != 3) {
+    return false;
+  }
+
+  std::string text_lower =
+      unilib
+          .ToLowerText(
+              UTF8ToUnicodeText(text[begin].utf8_data(),
+                                text[end].utf8_data() - text[begin].utf8_data(),
+                                /*do_copy=*/false))
+          .ToUTF8String();
+
+  // "sun" is not a good abbreviation for a standalone day of the week.
+  if (match.IsStandaloneRelativeDayOfWeek() &&
+      (text_lower == "sun" || text_lower == "mon")) {
+    return true;
+  }
+
+  // "mar" is not a good abbreviation for single month.
+  if (match.HasMonth() && text_lower == "mar") {
+    return true;
+  }
+
+  return false;
+}
+
+// Checks if two date matches are adjacent and mergeable.
+// Adjacent means `next` starts after `prev` ends and the text between them
+// (whitespace stripped, lower-cased) is empty or equals one of the configured
+// `ignored_spans` (e.g. connective words like "at").
+bool AreDateMatchesAdjacentAndMergeable(
+    const UniLib& unilib, const std::vector<UnicodeText::const_iterator>& text,
+    const std::vector<std::string>& ignored_spans, const DateMatch& prev,
+    const DateMatch& next) {
+  // Check the context between the two matches.
+  if (next.begin <= prev.end) {
+    // The two matches are not adjacent.
+    return false;
+  }
+  UnicodeText span;
+  for (int i = prev.end; i < next.begin; i++) {
+    const char32 codepoint = *text[i];
+    if (unilib.IsWhitespace(codepoint)) {
+      continue;
+    }
+    span.push_back(unilib.ToLower(codepoint));
+  }
+  if (span.empty()) {
+    return true;
+  }
+  const std::string span_text = span.ToUTF8String();
+  bool matched = false;
+  for (const std::string& ignored_span : ignored_spans) {
+    if (span_text == ignored_span) {
+      matched = true;
+      break;
+    }
+  }
+  if (!matched) {
+    return false;
+  }
+  // Gap text is ignorable; finally check field-level compatibility.
+  return IsDateMatchMergeable(prev, next);
+}
+
+// Merges adjacent date and date range.
+// For e.g. Monday, 5-10pm, the date "Monday" and the time range "5-10pm" will
+// be merged.
+// Both inputs are assumed sorted by span; `dates` is scanned once across all
+// ranges via `next_date`.
+void MergeDateRangeAndDate(const UniLib& unilib,
+                           const std::vector<UnicodeText::const_iterator>& text,
+                           const std::vector<std::string>& ignored_spans,
+                           const std::vector<DateMatch>& dates,
+                           std::vector<DateRangeMatch>* date_ranges) {
+  // For each range, check the date before or after the it to see if they could
+  // be merged. Both the range and date array are sorted, so we only need to
+  // scan the date array once.
+  int next_date = 0;
+  for (int i = 0; i < date_ranges->size(); i++) {
+    DateRangeMatch* date_range = &date_ranges->at(i);
+    // So far we only merge time range with a date.
+    if (!date_range->from.HasHour()) {
+      continue;
+    }
+
+    for (; next_date < dates.size(); next_date++) {
+      const DateMatch& date = dates[next_date];
+
+      // If the range is before the date, we check whether `date_range->to` can
+      // be merged with the date.
+      if (date_range->end <= date.begin) {
+        DateMatch merged_date = date;
+        if (AreDateMatchesAdjacentAndMergeable(unilib, text, ignored_spans,
+                                               date_range->to, date)) {
+          MergeDateMatch(date_range->to, &merged_date, /*update_span=*/true);
+          date_range->to = merged_date;
+          date_range->end = date_range->to.end;
+          // Also propagate the date's fields into the `from` endpoint, but
+          // keep its original span.
+          MergeDateMatch(date, &date_range->from, /*update_span=*/false);
+          next_date++;
+
+          // Check the second date after the range to see if it could be merged
+          // further. For example: 10-11pm, Monday, May 15. 10-11pm is merged
+          // with Monday and then we check that it could be merged with May 15
+          // as well.
+          if (next_date < dates.size()) {
+            DateMatch next_match = dates[next_date];
+            if (AreDateMatchesAdjacentAndMergeable(
+                    unilib, text, ignored_spans, date_range->to, next_match)) {
+              MergeDateMatch(date_range->to, &next_match, /*update_span=*/true);
+              date_range->to = next_match;
+              date_range->end = date_range->to.end;
+              MergeDateMatch(dates[next_date], &date_range->from,
+                             /*update_span=*/false);
+              next_date++;
+            }
+          }
+        }
+        // Since the range is before the date, we try to check if the next range
+        // could be merged with the current date.
+        break;
+      } else if (date_range->end > date.end && date_range->begin > date.begin) {
+        // If the range is after the date, we check if `date_range.from` can be
+        // merged with the date. Here is a special case, the date before range
+        // could be partially overlapped. This is because the range.from could
+        // be extracted as year in date. For example: March 3, 10-11pm is
+        // extracted as date March 3, 2010 and the range 10-11pm. In this
+        // case, we simply clear the year from date.
+        DateMatch merged_date = date;
+        if (date.HasYear() &&
+            date.year_match->codepoint_span.second > date_range->begin) {
+          merged_date.year_match = nullptr;
+          merged_date.year = NO_VAL;
+          merged_date.end = date.year_match->match_offset;
+        }
+        // Check and merge the range and the date before the range.
+        if (AreDateMatchesAdjacentAndMergeable(unilib, text, ignored_spans,
+                                               merged_date, date_range->from)) {
+          MergeDateMatch(merged_date, &date_range->from, /*update_span=*/true);
+          date_range->begin = date_range->from.begin;
+          MergeDateMatch(merged_date, &date_range->to, /*update_span=*/false);
+
+          // Check if the second date before the range can be merged as well.
+          if (next_date > 0) {
+            DateMatch prev_match = dates[next_date - 1];
+            if (prev_match.end <= date_range->from.begin) {
+              if (AreDateMatchesAdjacentAndMergeable(unilib, text,
+                                                     ignored_spans, prev_match,
+                                                     date_range->from)) {
+                MergeDateMatch(prev_match, &date_range->from,
+                               /*update_span=*/true);
+                date_range->begin = date_range->from.begin;
+                MergeDateMatch(prev_match, &date_range->to,
+                               /*update_span=*/false);
+              }
+            }
+          }
+          next_date++;
+          break;
+        } else {
+          // Since the date is before the date range, we move to the next date
+          // to check if it could be merged with the current range.
+          continue;
+        }
+      } else {
+        // The date is either fully overlapped by the date range or the date
+        // span end is after the date range. Move to the next date in both
+        // cases.
+      }
+    }
+  }
+}
+
+// Removes the dates which are part of a range. e.g. in "May 1 - 3", the date
+// "May 1" is fully contained in the range.
+// Both inputs are assumed sorted by span; dates are scanned once across all
+// ranges. Partially overlapping dates are left untouched.
+void RemoveOverlappedDateByRange(const std::vector<DateRangeMatch>& ranges,
+                                 std::vector<DateMatch>* dates) {
+  int next_date = 0;
+  std::vector<bool> removed(dates->size(), false);
+  for (int i = 0; i < ranges.size(); ++i) {
+    const auto& range = ranges[i];
+    for (; next_date < dates->size(); ++next_date) {
+      const auto& date = dates->at(next_date);
+      // So far we don't touch the partially overlapped case.
+      if (date.begin >= range.begin && date.end <= range.end) {
+        // Fully contained.
+        removed[next_date] = true;
+      } else if (date.end <= range.begin) {
+        continue;  // date is behind range, go to next date
+      } else if (date.begin >= range.end) {
+        break;  // range is behind date, go to next range
+      }
+    }
+  }
+  RemoveDeletedMatches(removed, dates);
+}
+
+// Converts candidate dates and date ranges.
+void FillDateInstances(
+ const UniLib& unilib, const std::vector<UnicodeText::const_iterator>& text,
+ const DateAnnotationOptions& options, std::vector<DateMatch>* date_matches,
+ std::vector<DatetimeParseResultSpan>* datetime_parse_result_spans) {
+ int i = 0;
+ for (int j = 1; j < date_matches->size(); j++) {
+ if (options.merge_adjacent_components &&
+ AreDateMatchesAdjacentAndMergeable(unilib, text, options.ignored_spans,
+ date_matches->at(i),
+ date_matches->at(j))) {
+ MergeDateMatch(date_matches->at(i), &date_matches->at(j), true);
+ } else {
+ if (!IsBlacklistedDate(unilib, text, date_matches->at(i))) {
+ DatetimeParseResultSpan datetime_parse_result_span;
+ FillDateInstance(date_matches->at(i), &datetime_parse_result_span);
+ datetime_parse_result_spans->push_back(datetime_parse_result_span);
+ }
+ }
+ i = j;
+ }
+ if (!IsBlacklistedDate(unilib, text, date_matches->at(i))) {
+ DatetimeParseResultSpan datetime_parse_result_span;
+ FillDateInstance(date_matches->at(i), &datetime_parse_result_span);
+ datetime_parse_result_spans->push_back(datetime_parse_result_span);
+ }
+}
+
+void FillDateRangeInstances(
+ const std::vector<DateRangeMatch>& date_range_matches,
+ std::vector<DatetimeParseResultSpan>* datetime_parse_result_spans) {
+ for (const DateRangeMatch& date_range_match : date_range_matches) {
+ DatetimeParseResultSpan datetime_parse_result_span;
+ FillDateRangeInstance(date_range_match, &datetime_parse_result_span);
+ datetime_parse_result_spans->push_back(datetime_parse_result_span);
+ }
+}
+
+// Fills `DatetimeParseResultSpan` from `DateMatch` and `DateRangeMatch`
+// instances.
+std::vector<DatetimeParseResultSpan> GetOutputAsAnnotationList(
+ const UniLib& unilib, const DateExtractor& extractor,
+ const std::vector<UnicodeText::const_iterator>& text,
+ const DateAnnotationOptions& options) {
+ std::vector<DatetimeParseResultSpan> datetime_parse_result_spans;
+ std::vector<DateMatch> date_matches =
+ BuildDateMatches(text, extractor.output());
+
+ std::sort(
+ date_matches.begin(), date_matches.end(),
+ // Order by increasing begin, and decreasing end (decreasing length).
+ [](const DateMatch& a, const DateMatch& b) {
+ return (a.begin < b.begin || (a.begin == b.begin && a.end > b.end));
+ });
+
+ if (!date_matches.empty()) {
+ RemoveDuplicatedDates(&date_matches);
+ }
+
+ if (options.enable_date_range) {
+ std::vector<DateRangeMatch> date_range_matches =
+ BuildDateRangeMatches(text, extractor.range_output());
+
+ if (!date_range_matches.empty()) {
+ std::sort(
+ date_range_matches.begin(), date_range_matches.end(),
+ // Order by increasing begin, and decreasing end (decreasing length).
+ [](const DateRangeMatch& a, const DateRangeMatch& b) {
+ return (a.begin < b.begin || (a.begin == b.begin && a.end > b.end));
+ });
+ RemoveDuplicatedDates(&date_range_matches);
+ }
+
+ if (!date_matches.empty()) {
+ MergeDateRangeAndDate(unilib, text, options.ignored_spans, date_matches,
+ &date_range_matches);
+ RemoveOverlappedDateByRange(date_range_matches, &date_matches);
+ }
+ FillDateRangeInstances(date_range_matches, &datetime_parse_result_spans);
+ }
+
+ if (!date_matches.empty()) {
+ FillDateInstances(unilib, text, options, &date_matches,
+ &datetime_parse_result_spans);
+ }
+ return datetime_parse_result_spans;
+}
+
+} // namespace
+
+std::vector<DatetimeParseResultSpan> DateParser::Parse(
+ StringPiece text, const std::vector<Token>& tokens,
+ const std::vector<Locale>& locales,
+ const DateAnnotationOptions& options) const {
+ std::vector<UnicodeText::const_iterator> codepoint_offsets;
+ const UnicodeText text_unicode = UTF8ToUnicodeText(text,
+ /*do_copy=*/false);
+ for (auto it = text_unicode.begin(); it != text_unicode.end(); it++) {
+ codepoint_offsets.push_back(it);
+ }
+ codepoint_offsets.push_back(text_unicode.end());
+ DateExtractor extractor(codepoint_offsets, options, datetime_rules_);
+ // Select locale matching rules.
+ // Only use a shard if locales match or the shard doesn't specify a locale
+ // restriction.
+ std::vector<const grammar::RulesSet_::Rules*> locale_rules =
+ SelectLocaleMatchingShards(datetime_rules_->rules(), rules_locales_,
+ locales);
+ if (locale_rules.empty()) {
+ return {};
+ }
+ grammar::Matcher matcher(&unilib_, datetime_rules_->rules(), locale_rules,
+ &extractor);
+ lexer_.Process(text_unicode, tokens, /*annotations=*/nullptr, &matcher);
+ return GetOutputAsAnnotationList(unilib_, extractor, codepoint_offsets,
+ options);
+}
+
+} // namespace libtextclassifier3::dates
diff --git a/native/annotator/grammar/dates/parser.h b/native/annotator/grammar/dates/parser.h
new file mode 100644
index 0000000..be919df
--- /dev/null
+++ b/native/annotator/grammar/dates/parser.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_PARSER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_PARSER_H_
+
+#include <vector>
+
+#include "annotator/grammar/dates/annotations/annotation-options.h"
+#include "annotator/grammar/dates/annotations/annotation.h"
+#include "annotator/grammar/dates/dates_generated.h"
+#include "annotator/grammar/dates/utils/date-match.h"
+#include "utils/grammar/lexer.h"
+#include "utils/grammar/rules-utils.h"
+#include "utils/i18n/locale.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3::dates {
+
+// Parses datetime expressions in the input with the datetime grammar and
+// constructs, validates, deduplicates and normalizes date time annotations.
+class DateParser {
+ public:
+ explicit DateParser(const UniLib* unilib, const DatetimeRules* datetime_rules)
+ : unilib_(*unilib),
+ lexer_(unilib, datetime_rules->rules()),
+ datetime_rules_(datetime_rules),
+ rules_locales_(ParseRulesLocales(datetime_rules->rules())) {}
+
+ // Parses the dates in the input. Makes sure that the results do not
+ // overlap.
+ std::vector<DatetimeParseResultSpan> Parse(
+ StringPiece text, const std::vector<Token>& tokens,
+ const std::vector<Locale>& locales,
+ const DateAnnotationOptions& options) const;
+
+ private:
+ const UniLib& unilib_;
+ const grammar::Lexer lexer_;
+
+ // The datetime grammar.
+ const DatetimeRules* datetime_rules_;
+
+ // Pre-parsed locales of the rules.
+ const std::vector<std::vector<Locale>> rules_locales_;
+};
+
+} // namespace libtextclassifier3::dates
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_PARSER_H_
diff --git a/native/annotator/grammar/dates/timezone-code.fbs b/native/annotator/grammar/dates/timezone-code.fbs
new file mode 100755
index 0000000..ff615ee
--- /dev/null
+++ b/native/annotator/grammar/dates/timezone-code.fbs
@@ -0,0 +1,593 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+namespace libtextclassifier3.dates;
+enum TimezoneCode : int {
+ TIMEZONE_CODE_NONE = -1,
+ ETC_UNKNOWN = 0,
+ PST8PDT = 1,
+ // Delegate.
+
+ AFRICA_ABIDJAN = 2,
+ AFRICA_ACCRA = 3,
+ AFRICA_ADDIS_ABABA = 4,
+ AFRICA_ALGIERS = 5,
+ AFRICA_ASMARA = 6,
+ AFRICA_BAMAKO = 7,
+ // Delegate.
+
+ AFRICA_BANGUI = 8,
+ AFRICA_BANJUL = 9,
+ AFRICA_BISSAU = 10,
+ AFRICA_BLANTYRE = 11,
+ AFRICA_BRAZZAVILLE = 12,
+ AFRICA_BUJUMBURA = 13,
+ EGYPT = 14,
+ // Delegate.
+
+ AFRICA_CASABLANCA = 15,
+ AFRICA_CEUTA = 16,
+ AFRICA_CONAKRY = 17,
+ AFRICA_DAKAR = 18,
+ AFRICA_DAR_ES_SALAAM = 19,
+ AFRICA_DJIBOUTI = 20,
+ AFRICA_DOUALA = 21,
+ AFRICA_EL_AAIUN = 22,
+ AFRICA_FREETOWN = 23,
+ AFRICA_GABORONE = 24,
+ AFRICA_HARARE = 25,
+ AFRICA_JOHANNESBURG = 26,
+ AFRICA_KAMPALA = 27,
+ AFRICA_KHARTOUM = 28,
+ AFRICA_KIGALI = 29,
+ AFRICA_KINSHASA = 30,
+ AFRICA_LAGOS = 31,
+ AFRICA_LIBREVILLE = 32,
+ AFRICA_LOME = 33,
+ AFRICA_LUANDA = 34,
+ AFRICA_LUBUMBASHI = 35,
+ AFRICA_LUSAKA = 36,
+ AFRICA_MALABO = 37,
+ AFRICA_MAPUTO = 38,
+ AFRICA_MASERU = 39,
+ AFRICA_MBABANE = 40,
+ AFRICA_MOGADISHU = 41,
+ AFRICA_MONROVIA = 42,
+ AFRICA_NAIROBI = 43,
+ AFRICA_NDJAMENA = 44,
+ AFRICA_NIAMEY = 45,
+ AFRICA_NOUAKCHOTT = 46,
+ AFRICA_OUAGADOUGOU = 47,
+ AFRICA_PORTO_NOVO = 48,
+ AFRICA_SAO_TOME = 49,
+ LIBYA = 51,
+ // Delegate.
+
+ AFRICA_TUNIS = 52,
+ AFRICA_WINDHOEK = 53,
+ US_ALEUTIAN = 54,
+ // Delegate.
+
+ US_ALASKA = 55,
+ // Delegate.
+
+ AMERICA_ANGUILLA = 56,
+ AMERICA_ANTIGUA = 57,
+ AMERICA_ARAGUAINA = 58,
+ AMERICA_BUENOS_AIRES = 59,
+ AMERICA_CATAMARCA = 60,
+ AMERICA_CORDOBA = 62,
+ AMERICA_JUJUY = 63,
+ AMERICA_ARGENTINA_LA_RIOJA = 64,
+ AMERICA_MENDOZA = 65,
+ AMERICA_ARGENTINA_RIO_GALLEGOS = 66,
+ AMERICA_ARGENTINA_SAN_JUAN = 67,
+ AMERICA_ARGENTINA_TUCUMAN = 68,
+ AMERICA_ARGENTINA_USHUAIA = 69,
+ AMERICA_ARUBA = 70,
+ AMERICA_ASUNCION = 71,
+ AMERICA_BAHIA = 72,
+ AMERICA_BARBADOS = 73,
+ AMERICA_BELEM = 74,
+ AMERICA_BELIZE = 75,
+ AMERICA_BOA_VISTA = 76,
+ AMERICA_BOGOTA = 77,
+ AMERICA_BOISE = 78,
+ AMERICA_CAMBRIDGE_BAY = 79,
+ AMERICA_CAMPO_GRANDE = 80,
+ AMERICA_CANCUN = 81,
+ AMERICA_CARACAS = 82,
+ AMERICA_CAYENNE = 83,
+ AMERICA_CAYMAN = 84,
+ CST6CDT = 85,
+ // Delegate.
+
+ AMERICA_CHIHUAHUA = 86,
+ AMERICA_COSTA_RICA = 87,
+ AMERICA_CUIABA = 88,
+ AMERICA_CURACAO = 89,
+ AMERICA_DANMARKSHAVN = 90,
+ AMERICA_DAWSON = 91,
+ AMERICA_DAWSON_CREEK = 92,
+ NAVAJO = 93,
+ // Delegate.
+
+ US_MICHIGAN = 94,
+ // Delegate.
+
+ AMERICA_DOMINICA = 95,
+ CANADA_MOUNTAIN = 96,
+ // Delegate.
+
+ AMERICA_EIRUNEPE = 97,
+ AMERICA_EL_SALVADOR = 98,
+ AMERICA_FORTALEZA = 99,
+ AMERICA_GLACE_BAY = 100,
+ AMERICA_GODTHAB = 101,
+ AMERICA_GOOSE_BAY = 102,
+ AMERICA_GRAND_TURK = 103,
+ AMERICA_GRENADA = 104,
+ AMERICA_GUADELOUPE = 105,
+ AMERICA_GUATEMALA = 106,
+ AMERICA_GUAYAQUIL = 107,
+ AMERICA_GUYANA = 108,
+ AMERICA_HALIFAX = 109,
+ // Delegate.
+
+ CUBA = 110,
+ // Delegate.
+
+ AMERICA_HERMOSILLO = 111,
+ AMERICA_KNOX_IN = 113,
+ // Delegate.
+
+ AMERICA_INDIANA_MARENGO = 114,
+ US_EAST_INDIANA = 115,
+ AMERICA_INDIANA_VEVAY = 116,
+ AMERICA_INUVIK = 117,
+ AMERICA_IQALUIT = 118,
+ JAMAICA = 119,
+ // Delegate.
+
+ AMERICA_JUNEAU = 120,
+ AMERICA_KENTUCKY_MONTICELLO = 122,
+ AMERICA_LA_PAZ = 123,
+ AMERICA_LIMA = 124,
+ AMERICA_LOUISVILLE = 125,
+ AMERICA_MACEIO = 126,
+ AMERICA_MANAGUA = 127,
+ BRAZIL_WEST = 128,
+ // Delegate.
+
+ AMERICA_MARTINIQUE = 129,
+ MEXICO_BAJASUR = 130,
+ // Delegate.
+
+ AMERICA_MENOMINEE = 131,
+ AMERICA_MERIDA = 132,
+ MEXICO_GENERAL = 133,
+ // Delegate.
+
+ AMERICA_MIQUELON = 134,
+ AMERICA_MONTERREY = 135,
+ AMERICA_MONTEVIDEO = 136,
+ AMERICA_MONTREAL = 137,
+ AMERICA_MONTSERRAT = 138,
+ AMERICA_NASSAU = 139,
+ EST5EDT = 140,
+ // Delegate.
+
+ AMERICA_NIPIGON = 141,
+ AMERICA_NOME = 142,
+ AMERICA_NORONHA = 143,
+ // Delegate.
+
+ AMERICA_NORTH_DAKOTA_CENTER = 144,
+ AMERICA_PANAMA = 145,
+ AMERICA_PANGNIRTUNG = 146,
+ AMERICA_PARAMARIBO = 147,
+ US_ARIZONA = 148,
+ // Delegate.
+
+ AMERICA_PORT_AU_PRINCE = 149,
+ AMERICA_PORT_OF_SPAIN = 150,
+ AMERICA_PORTO_VELHO = 151,
+ AMERICA_PUERTO_RICO = 152,
+ AMERICA_RAINY_RIVER = 153,
+ AMERICA_RANKIN_INLET = 154,
+ AMERICA_RECIFE = 155,
+ AMERICA_REGINA = 156,
+ // Delegate.
+
+ BRAZIL_ACRE = 157,
+ AMERICA_SANTIAGO = 158,
+ // Delegate.
+
+ AMERICA_SANTO_DOMINGO = 159,
+ BRAZIL_EAST = 160,
+ // Delegate.
+
+ AMERICA_SCORESBYSUND = 161,
+ AMERICA_ST_JOHNS = 163,
+ // Delegate.
+
+ AMERICA_ST_KITTS = 164,
+ AMERICA_ST_LUCIA = 165,
+ AMERICA_VIRGIN = 166,
+ // Delegate.
+
+ AMERICA_ST_VINCENT = 167,
+ AMERICA_SWIFT_CURRENT = 168,
+ AMERICA_TEGUCIGALPA = 169,
+ AMERICA_THULE = 170,
+ AMERICA_THUNDER_BAY = 171,
+ AMERICA_TIJUANA = 172,
+ CANADA_EASTERN = 173,
+ // Delegate.
+
+ AMERICA_TORTOLA = 174,
+ CANADA_PACIFIC = 175,
+ // Delegate.
+
+ CANADA_YUKON = 176,
+ // Delegate.
+
+ CANADA_CENTRAL = 177,
+ // Delegate.
+
+ AMERICA_YAKUTAT = 178,
+ AMERICA_YELLOWKNIFE = 179,
+ ANTARCTICA_CASEY = 180,
+ ANTARCTICA_DAVIS = 181,
+ ANTARCTICA_DUMONTDURVILLE = 182,
+ ANTARCTICA_MAWSON = 183,
+ ANTARCTICA_MCMURDO = 184,
+ ANTARCTICA_PALMER = 185,
+ ANTARCTICA_ROTHERA = 186,
+ ANTARCTICA_SYOWA = 188,
+ ANTARCTICA_VOSTOK = 189,
+ ATLANTIC_JAN_MAYEN = 190,
+ // Delegate.
+
+ ASIA_ADEN = 191,
+ ASIA_ALMATY = 192,
+ ASIA_AMMAN = 193,
+ ASIA_ANADYR = 194,
+ ASIA_AQTAU = 195,
+ ASIA_AQTOBE = 196,
+ ASIA_ASHGABAT = 197,
+ // Delegate.
+
+ ASIA_BAGHDAD = 198,
+ ASIA_BAHRAIN = 199,
+ ASIA_BAKU = 200,
+ ASIA_BANGKOK = 201,
+ ASIA_BEIRUT = 202,
+ ASIA_BISHKEK = 203,
+ ASIA_BRUNEI = 204,
+ ASIA_KOLKATA = 205,
+ // Delegate.
+
+ ASIA_CHOIBALSAN = 206,
+ ASIA_COLOMBO = 208,
+ ASIA_DAMASCUS = 209,
+ ASIA_DACCA = 210,
+ ASIA_DILI = 211,
+ ASIA_DUBAI = 212,
+ ASIA_DUSHANBE = 213,
+ ASIA_GAZA = 214,
+ HONGKONG = 216,
+ // Delegate.
+
+ ASIA_HOVD = 217,
+ ASIA_IRKUTSK = 218,
+ ASIA_JAKARTA = 220,
+ ASIA_JAYAPURA = 221,
+ ISRAEL = 222,
+ // Delegate.
+
+ ASIA_KABUL = 223,
+ ASIA_KAMCHATKA = 224,
+ ASIA_KARACHI = 225,
+ ASIA_KATMANDU = 227,
+ ASIA_KRASNOYARSK = 228,
+ ASIA_KUALA_LUMPUR = 229,
+ ASIA_KUCHING = 230,
+ ASIA_KUWAIT = 231,
+ ASIA_MACAO = 232,
+ ASIA_MAGADAN = 233,
+ ASIA_MAKASSAR = 234,
+ // Delegate.
+
+ ASIA_MANILA = 235,
+ ASIA_MUSCAT = 236,
+ ASIA_NICOSIA = 237,
+ // Delegate.
+
+ ASIA_NOVOSIBIRSK = 238,
+ ASIA_OMSK = 239,
+ ASIA_ORAL = 240,
+ ASIA_PHNOM_PENH = 241,
+ ASIA_PONTIANAK = 242,
+ ASIA_PYONGYANG = 243,
+ ASIA_QATAR = 244,
+ ASIA_QYZYLORDA = 245,
+ ASIA_RANGOON = 246,
+ ASIA_RIYADH = 247,
+ ASIA_SAIGON = 248,
+ ASIA_SAKHALIN = 249,
+ ASIA_SAMARKAND = 250,
+ ROK = 251,
+ // Delegate.
+
+ PRC = 252,
+ SINGAPORE = 253,
+ // Delegate.
+
+ ROC = 254,
+ // Delegate.
+
+ ASIA_TASHKENT = 255,
+ ASIA_TBILISI = 256,
+ IRAN = 257,
+ // Delegate.
+
+ ASIA_THIMBU = 258,
+ JAPAN = 259,
+ // Delegate.
+
+ ASIA_ULAN_BATOR = 260,
+ // Delegate.
+
+ ASIA_URUMQI = 261,
+ ASIA_VIENTIANE = 262,
+ ASIA_VLADIVOSTOK = 263,
+ ASIA_YAKUTSK = 264,
+ ASIA_YEKATERINBURG = 265,
+ ASIA_YEREVAN = 266,
+ ATLANTIC_AZORES = 267,
+ ATLANTIC_BERMUDA = 268,
+ ATLANTIC_CANARY = 269,
+ ATLANTIC_CAPE_VERDE = 270,
+ ATLANTIC_FAROE = 271,
+ // Delegate.
+
+ ATLANTIC_MADEIRA = 273,
+ ICELAND = 274,
+ // Delegate.
+
+ ATLANTIC_SOUTH_GEORGIA = 275,
+ ATLANTIC_STANLEY = 276,
+ ATLANTIC_ST_HELENA = 277,
+ AUSTRALIA_SOUTH = 278,
+ // Delegate.
+
+ AUSTRALIA_BRISBANE = 279,
+ // Delegate.
+
+ AUSTRALIA_YANCOWINNA = 280,
+ // Delegate.
+
+ AUSTRALIA_NORTH = 281,
+ // Delegate.
+
+ AUSTRALIA_HOBART = 282,
+ // Delegate.
+
+ AUSTRALIA_LINDEMAN = 283,
+ AUSTRALIA_LHI = 284,
+ AUSTRALIA_VICTORIA = 285,
+ // Delegate.
+
+ AUSTRALIA_WEST = 286,
+ // Delegate.
+
+ AUSTRALIA_ACT = 287,
+ EUROPE_AMSTERDAM = 288,
+ EUROPE_ANDORRA = 289,
+ EUROPE_ATHENS = 290,
+ EUROPE_BELGRADE = 292,
+ EUROPE_BERLIN = 293,
+ EUROPE_BRATISLAVA = 294,
+ EUROPE_BRUSSELS = 295,
+ EUROPE_BUCHAREST = 296,
+ EUROPE_BUDAPEST = 297,
+ EUROPE_CHISINAU = 298,
+ // Delegate.
+
+ EUROPE_COPENHAGEN = 299,
+ EIRE = 300,
+ EUROPE_GIBRALTAR = 301,
+ EUROPE_HELSINKI = 302,
+ TURKEY = 303,
+ EUROPE_KALININGRAD = 304,
+ EUROPE_KIEV = 305,
+ PORTUGAL = 306,
+ // Delegate.
+
+ EUROPE_LJUBLJANA = 307,
+ GB = 308,
+ EUROPE_LUXEMBOURG = 309,
+ EUROPE_MADRID = 310,
+ EUROPE_MALTA = 311,
+ EUROPE_MARIEHAMN = 312,
+ EUROPE_MINSK = 313,
+ EUROPE_MONACO = 314,
+ W_SU = 315,
+ // Delegate.
+
+ EUROPE_OSLO = 317,
+ EUROPE_PARIS = 318,
+ EUROPE_PRAGUE = 319,
+ EUROPE_RIGA = 320,
+ EUROPE_ROME = 321,
+ EUROPE_SAMARA = 322,
+ EUROPE_SAN_MARINO = 323,
+ EUROPE_SARAJEVO = 324,
+ EUROPE_SIMFEROPOL = 325,
+ EUROPE_SKOPJE = 326,
+ EUROPE_SOFIA = 327,
+ EUROPE_STOCKHOLM = 328,
+ EUROPE_TALLINN = 329,
+ EUROPE_TIRANE = 330,
+ EUROPE_UZHGOROD = 331,
+ EUROPE_VADUZ = 332,
+ EUROPE_VATICAN = 333,
+ EUROPE_VIENNA = 334,
+ EUROPE_VILNIUS = 335,
+ POLAND = 336,
+ // Delegate.
+
+ EUROPE_ZAGREB = 337,
+ EUROPE_ZAPOROZHYE = 338,
+ EUROPE_ZURICH = 339,
+ INDIAN_ANTANANARIVO = 340,
+ INDIAN_CHAGOS = 341,
+ INDIAN_CHRISTMAS = 342,
+ INDIAN_COCOS = 343,
+ INDIAN_COMORO = 344,
+ INDIAN_KERGUELEN = 345,
+ INDIAN_MAHE = 346,
+ INDIAN_MALDIVES = 347,
+ INDIAN_MAURITIUS = 348,
+ INDIAN_MAYOTTE = 349,
+ INDIAN_REUNION = 350,
+ PACIFIC_APIA = 351,
+ NZ = 352,
+ NZ_CHAT = 353,
+ PACIFIC_EASTER = 354,
+ PACIFIC_EFATE = 355,
+ PACIFIC_ENDERBURY = 356,
+ PACIFIC_FAKAOFO = 357,
+ PACIFIC_FIJI = 358,
+ PACIFIC_FUNAFUTI = 359,
+ PACIFIC_GALAPAGOS = 360,
+ PACIFIC_GAMBIER = 361,
+ PACIFIC_GUADALCANAL = 362,
+ PACIFIC_GUAM = 363,
+ US_HAWAII = 364,
+ // Delegate.
+
+ PACIFIC_JOHNSTON = 365,
+ PACIFIC_KIRITIMATI = 366,
+ PACIFIC_KOSRAE = 367,
+ KWAJALEIN = 368,
+ PACIFIC_MAJURO = 369,
+ PACIFIC_MARQUESAS = 370,
+ PACIFIC_MIDWAY = 371,
+ PACIFIC_NAURU = 372,
+ PACIFIC_NIUE = 373,
+ PACIFIC_NORFOLK = 374,
+ PACIFIC_NOUMEA = 375,
+ US_SAMOA = 376,
+ // Delegate.
+
+ PACIFIC_PALAU = 377,
+ PACIFIC_PITCAIRN = 378,
+ PACIFIC_PONAPE = 379,
+ PACIFIC_PORT_MORESBY = 380,
+ PACIFIC_RAROTONGA = 381,
+ PACIFIC_SAIPAN = 382,
+ PACIFIC_TAHITI = 383,
+ PACIFIC_TARAWA = 384,
+ PACIFIC_TONGATAPU = 385,
+ PACIFIC_YAP = 386,
+ PACIFIC_WAKE = 387,
+ PACIFIC_WALLIS = 388,
+ AMERICA_ATIKOKAN = 390,
+ AUSTRALIA_CURRIE = 391,
+ ETC_GMT_EAST_14 = 392,
+ ETC_GMT_EAST_13 = 393,
+ ETC_GMT_EAST_12 = 394,
+ ETC_GMT_EAST_11 = 395,
+ ETC_GMT_EAST_10 = 396,
+ ETC_GMT_EAST_9 = 397,
+ ETC_GMT_EAST_8 = 398,
+ ETC_GMT_EAST_7 = 399,
+ ETC_GMT_EAST_6 = 400,
+ ETC_GMT_EAST_5 = 401,
+ ETC_GMT_EAST_4 = 402,
+ ETC_GMT_EAST_3 = 403,
+ ETC_GMT_EAST_2 = 404,
+ ETC_GMT_EAST_1 = 405,
+ GMT = 406,
+ // Delegate.
+
+ ETC_GMT_WEST_1 = 407,
+ ETC_GMT_WEST_2 = 408,
+ ETC_GMT_WEST_3 = 409,
+ SYSTEMV_AST4 = 410,
+ // Delegate.
+
+ EST = 411,
+ SYSTEMV_CST6 = 412,
+ // Delegate.
+
+ MST = 413,
+ // Delegate.
+
+ SYSTEMV_PST8 = 414,
+ // Delegate.
+
+ SYSTEMV_YST9 = 415,
+ // Delegate.
+
+ HST = 416,
+ // Delegate.
+
+ ETC_GMT_WEST_11 = 417,
+ ETC_GMT_WEST_12 = 418,
+ AMERICA_NORTH_DAKOTA_NEW_SALEM = 419,
+ AMERICA_INDIANA_PETERSBURG = 420,
+ AMERICA_INDIANA_VINCENNES = 421,
+ AMERICA_MONCTON = 422,
+ AMERICA_BLANC_SABLON = 423,
+ EUROPE_GUERNSEY = 424,
+ EUROPE_ISLE_OF_MAN = 425,
+ EUROPE_JERSEY = 426,
+ EUROPE_PODGORICA = 427,
+ EUROPE_VOLGOGRAD = 428,
+ AMERICA_INDIANA_WINAMAC = 429,
+ AUSTRALIA_EUCLA = 430,
+ AMERICA_INDIANA_TELL_CITY = 431,
+ AMERICA_RESOLUTE = 432,
+ AMERICA_ARGENTINA_SAN_LUIS = 433,
+ AMERICA_SANTAREM = 434,
+ AMERICA_ARGENTINA_SALTA = 435,
+ AMERICA_BAHIA_BANDERAS = 436,
+ AMERICA_MARIGOT = 437,
+ AMERICA_MATAMOROS = 438,
+ AMERICA_OJINAGA = 439,
+ AMERICA_SANTA_ISABEL = 440,
+ AMERICA_ST_BARTHELEMY = 441,
+ ANTARCTICA_MACQUARIE = 442,
+ ASIA_NOVOKUZNETSK = 443,
+ AFRICA_JUBA = 444,
+ AMERICA_METLAKATLA = 445,
+ AMERICA_NORTH_DAKOTA_BEULAH = 446,
+ AMERICA_SITKA = 447,
+ ASIA_HEBRON = 448,
+ AMERICA_CRESTON = 449,
+ AMERICA_KRALENDIJK = 450,
+ AMERICA_LOWER_PRINCES = 451,
+ ANTARCTICA_TROLL = 452,
+ ASIA_KHANDYGA = 453,
+ ASIA_UST_NERA = 454,
+ EUROPE_BUSINGEN = 455,
+ ASIA_CHITA = 456,
+ ASIA_SREDNEKOLYMSK = 457,
+}
+
diff --git a/native/annotator/grammar/dates/utils/annotation-keys.cc b/native/annotator/grammar/dates/utils/annotation-keys.cc
new file mode 100644
index 0000000..3438c6d
--- /dev/null
+++ b/native/annotator/grammar/dates/utils/annotation-keys.cc
@@ -0,0 +1,29 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/dates/utils/annotation-keys.h"
+
+namespace libtextclassifier3 {
+namespace dates {
+const char* const kDateTimeType = "dateTime";
+const char* const kDateTimeRangeType = "dateTimeRange";
+const char* const kDateTime = "dateTime";
+const char* const kDateTimeSupplementary = "dateTimeSupplementary";
+const char* const kDateTimeRelative = "dateTimeRelative";
+const char* const kDateTimeRangeFrom = "dateTimeRangeFrom";
+const char* const kDateTimeRangeTo = "dateTimeRangeTo";
+} // namespace dates
+} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/utils/annotation-keys.h b/native/annotator/grammar/dates/utils/annotation-keys.h
new file mode 100644
index 0000000..f970a51
--- /dev/null
+++ b/native/annotator/grammar/dates/utils/annotation-keys.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_ANNOTATION_KEYS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_ANNOTATION_KEYS_H_
+
+namespace libtextclassifier3 {
+namespace dates {
+
+// Date time specific constants not defined in standard schemas.
+//
+// Date annotator output two type of annotation. One is date&time like "May 1",
+// "12:20pm", etc. Another is range like "2pm - 3pm". The two string identify
+// the type of annotation and are used as type in Thing proto.
+extern const char* const kDateTimeType;
+extern const char* const kDateTimeRangeType;
+
+// kDateTime contains most common field for date time. It's integer array and
+// the format is (year, month, day, hour, minute, second, fraction_sec,
+// day_of_week). All eight fields must be provided. If the field is not
+// extracted, the value is -1 in the array.
+extern const char* const kDateTime;
+
+// kDateTimeSupplementary contains uncommon field like timespan, timezone. It's
+// integer array and the format is (bc_ad, timespan_code, timezone_code,
+// timezone_offset). Al four fields must be provided. If the field is not
+// extracted, the value is -1 in the array.
+extern const char* const kDateTimeSupplementary;
+
+// kDateTimeRelative contains fields for relative date time. It's integer
+// array and the format is (is_future, year, month, day, week, hour, minute,
+// second, day_of_week, dow_interpretation*). The first nine fields must be
+// provided and dow_interpretation could have zero or multiple values.
+// If the field is not extracted, the value is -1 in the array.
+extern const char* const kDateTimeRelative;
+
+// Date time range specific constants not defined in standard schemas.
+// kDateTimeRangeFrom and kDateTimeRangeTo define the from/to of a date/time
+// range. The value is thing object which contains a date time.
+extern const char* const kDateTimeRangeFrom;
+extern const char* const kDateTimeRangeTo;
+
+} // namespace dates
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_ANNOTATION_KEYS_H_
diff --git a/native/annotator/grammar/dates/utils/date-match.cc b/native/annotator/grammar/dates/utils/date-match.cc
new file mode 100644
index 0000000..d9fca52
--- /dev/null
+++ b/native/annotator/grammar/dates/utils/date-match.cc
@@ -0,0 +1,440 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/dates/utils/date-match.h"
+
+#include <algorithm>
+
+#include "annotator/grammar/dates/utils/date-utils.h"
+#include "annotator/types.h"
+#include "utils/strings/append.h"
+
+static const int kAM = 0;
+static const int kPM = 1;
+
+namespace libtextclassifier3 {
+namespace dates {
+
+namespace {
+static int GetMeridiemValue(const TimespanCode& timespan_code) {
+ switch (timespan_code) {
+ case TimespanCode_AM:
+ case TimespanCode_MIDNIGHT:
+ // MIDNIGHT [3] -> AM
+ return kAM;
+ case TimespanCode_TONIGHT:
+ // TONIGHT [11] -> PM
+ case TimespanCode_NOON:
+ // NOON [2] -> PM
+ case TimespanCode_PM:
+ return kPM;
+ case TimespanCode_TIMESPAN_CODE_NONE:
+ default:
+ TC3_LOG(WARNING) << "Failed to extract time span code.";
+ }
+ return NO_VAL;
+}
+
+static int GetRelativeCount(const RelativeParameter* relative_parameter) {
+ for (const int interpretation :
+ *relative_parameter->day_of_week_interpretation()) {
+ switch (interpretation) {
+ case RelativeParameter_::Interpretation_NEAREST_LAST:
+ case RelativeParameter_::Interpretation_PREVIOUS:
+ return -1;
+ case RelativeParameter_::Interpretation_SECOND_LAST:
+ return -2;
+ case RelativeParameter_::Interpretation_SECOND_NEXT:
+ return 2;
+ case RelativeParameter_::Interpretation_COMING:
+ case RelativeParameter_::Interpretation_SOME:
+ case RelativeParameter_::Interpretation_NEAREST:
+ case RelativeParameter_::Interpretation_NEAREST_NEXT:
+ return 1;
+ case RelativeParameter_::Interpretation_CURRENT:
+ return 0;
+ }
+ }
+ return 0;
+}
+} // namespace
+
+using strings::JoinStrings;
+using strings::SStringAppendF;
+
+std::string DateMatch::DebugString() const {
+ std::string res;
+#if !defined(NDEBUG)
+ if (begin >= 0 && end >= 0) {
+ SStringAppendF(&res, 0, "[%u,%u)", begin, end);
+ }
+
+ if (HasDayOfWeek()) {
+ SStringAppendF(&res, 0, "%u", day_of_week);
+ }
+
+ if (HasYear()) {
+ int year_output = year;
+ if (HasBcAd() && bc_ad == BCAD_BC) {
+ year_output = -year;
+ }
+ SStringAppendF(&res, 0, "%u/", year_output);
+ } else {
+ SStringAppendF(&res, 0, "____/");
+ }
+
+ if (HasMonth()) {
+ SStringAppendF(&res, 0, "%u/", month);
+ } else {
+ SStringAppendF(&res, 0, "__/");
+ }
+
+ if (HasDay()) {
+ SStringAppendF(&res, 0, "%u ", day);
+ } else {
+ SStringAppendF(&res, 0, "__ ");
+ }
+
+ if (HasHour()) {
+ SStringAppendF(&res, 0, "%u:", hour);
+ } else {
+ SStringAppendF(&res, 0, "__:");
+ }
+
+ if (HasMinute()) {
+ SStringAppendF(&res, 0, "%u:", minute);
+ } else {
+ SStringAppendF(&res, 0, "__:");
+ }
+
+ if (HasSecond()) {
+ if (HasFractionSecond()) {
+ SStringAppendF(&res, 0, "%u.%lf ", second, fraction_second);
+ } else {
+ SStringAppendF(&res, 0, "%u ", second);
+ }
+ } else {
+ SStringAppendF(&res, 0, "__ ");
+ }
+
+ if (HasTimeSpanCode() && TimespanCode_TIMESPAN_CODE_NONE < time_span_code &&
+ time_span_code <= TimespanCode_MAX) {
+ SStringAppendF(&res, 0, "TS=%u ", time_span_code);
+ }
+
+ if (HasTimeZoneCode() && time_zone_code != -1) {
+ SStringAppendF(&res, 0, "TZ= %u ", time_zone_code);
+ }
+
+ if (HasTimeZoneOffset()) {
+ SStringAppendF(&res, 0, "TZO=%u ", time_zone_offset);
+ }
+
+ if (HasRelativeDate()) {
+ const RelativeMatch* rm = relative_match;
+ SStringAppendF(&res, 0, (rm->is_future_date ? "future " : "past "));
+ if (rm->day_of_week != NO_VAL) {
+ SStringAppendF(&res, 0, "DOW:%d ", rm->day_of_week);
+ }
+ if (rm->year != NO_VAL) {
+ SStringAppendF(&res, 0, "Y:%d ", rm->year);
+ }
+ if (rm->month != NO_VAL) {
+ SStringAppendF(&res, 0, "M:%d ", rm->month);
+ }
+ if (rm->day != NO_VAL) {
+ SStringAppendF(&res, 0, "D:%d ", rm->day);
+ }
+ if (rm->week != NO_VAL) {
+ SStringAppendF(&res, 0, "W:%d ", rm->week);
+ }
+ if (rm->hour != NO_VAL) {
+ SStringAppendF(&res, 0, "H:%d ", rm->hour);
+ }
+ if (rm->minute != NO_VAL) {
+ SStringAppendF(&res, 0, "M:%d ", rm->minute);
+ }
+ if (rm->second != NO_VAL) {
+ SStringAppendF(&res, 0, "S:%d ", rm->second);
+ }
+ }
+
+ SStringAppendF(&res, 0, "prio=%d ", priority);
+ SStringAppendF(&res, 0, "conf-score=%lf ", annotator_priority_score);
+
+ if (IsHourAmbiguous()) {
+ std::vector<int8> values;
+ GetPossibleHourValues(&values);
+ std::string str_values;
+
+ for (unsigned int i = 0; i < values.size(); ++i) {
+ SStringAppendF(&str_values, 0, "%u,", values[i]);
+ }
+ SStringAppendF(&res, 0, "amb=%s ", str_values.c_str());
+ }
+
+ std::vector<std::string> tags;
+ if (is_inferred) {
+ tags.push_back("inferred");
+ }
+ if (!tags.empty()) {
+ SStringAppendF(&res, 0, "tag=%s ", JoinStrings(",", tags).c_str());
+ }
+#endif // !defined(NDEBUG)
+ return res;
+}
+
+void DateMatch::GetPossibleHourValues(std::vector<int8>* values) const {
+ TC3_CHECK(values != nullptr);
+ values->clear();
+ if (HasHour()) {
+ int8 possible_hour = hour;
+ values->push_back(possible_hour);
+ for (int count = 1; count < ambiguous_hour_count; ++count) {
+ possible_hour += ambiguous_hour_interval;
+ if (possible_hour >= 24) {
+ possible_hour -= 24;
+ }
+ values->push_back(possible_hour);
+ }
+ }
+}
+
+DatetimeComponent::RelativeQualifier DateMatch::GetRelativeQualifier() const {
+ if (HasRelativeDate()) {
+ if (relative_match->existing & RelativeMatch::HAS_IS_FUTURE) {
+ if (!relative_match->is_future_date) {
+ return DatetimeComponent::RelativeQualifier::PAST;
+ }
+ }
+ return DatetimeComponent::RelativeQualifier::FUTURE;
+ }
+ return DatetimeComponent::RelativeQualifier::UNSPECIFIED;
+}
+
+// Embed RelativeQualifier information of DatetimeComponent as a sign of
+// relative counter field of datetime component i.e. relative counter is
+// negative when relative qualifier RelativeQualifier::PAST.
+int GetAdjustedRelativeCounter(
+ const DatetimeComponent::RelativeQualifier& relative_qualifier,
+ const int relative_counter) {
+ if (DatetimeComponent::RelativeQualifier::PAST == relative_qualifier) {
+ return -relative_counter;
+ }
+ return relative_counter;
+}
+
// Builds a DatetimeComponent of `component_type` from the absolute and/or
// relative value extracted for it. Returns an empty Optional when neither
// value is present. Absent (NO_VAL) fields are normalized to 0 in the
// component, and the relative qualifier is only attached when a relative
// value actually exists.
Optional<DatetimeComponent> CreateDatetimeComponent(
    const DatetimeComponent::ComponentType& component_type,
    const DatetimeComponent::RelativeQualifier& relative_qualifier,
    const int absolute_value, const int relative_value) {
  if (absolute_value == NO_VAL && relative_value == NO_VAL) {
    return Optional<DatetimeComponent>();
  }
  return Optional<DatetimeComponent>(DatetimeComponent(
      component_type,
      (relative_value != NO_VAL)
          ? relative_qualifier
          : DatetimeComponent::RelativeQualifier::UNSPECIFIED,
      (absolute_value != NO_VAL) ? absolute_value : 0,
      // The qualifier's direction is encoded in the counter's sign.
      (relative_value != NO_VAL)
          ? GetAdjustedRelativeCounter(relative_qualifier, relative_value)
          : 0));
}
+
// Builds the DAY_OF_WEEK component. Unlike the other components, a day of
// week may come from the absolute match, from a standalone relative match
// ("next Tuesday"), or from a relative match with a week period ("Tuesday
// next week"), so absolute value, relative counter and qualifier all have
// to be reconciled here.
Optional<DatetimeComponent> CreateDayOfWeekComponent(
    const RelativeMatch* relative_match,
    const DatetimeComponent::RelativeQualifier& relative_qualifier,
    const DayOfWeek& absolute_day_of_week) {
  DatetimeComponent::RelativeQualifier updated_relative_qualifier =
      relative_qualifier;
  int absolute_value = absolute_day_of_week;
  int relative_value = NO_VAL;
  if (relative_match) {
    relative_value = relative_match->day_of_week;
    if (relative_match->existing & RelativeMatch::HAS_DAY_OF_WEEK) {
      // A standalone relative day of week also names the day itself;
      // surface it as the absolute value when no absolute one was parsed.
      if (relative_match->IsStandaloneRelativeDayOfWeek() &&
          absolute_day_of_week == DayOfWeek_DOW_NONE) {
        absolute_value = relative_match->day_of_week;
      }
      // Check if the relative date has day of week with week period.
      if (relative_match->existing & RelativeMatch::HAS_WEEK) {
        relative_value = 1;
      } else {
        const NonterminalValue* nonterminal =
            relative_match->day_of_week_nonterminal;
        TC3_CHECK(nonterminal != nullptr);
        TC3_CHECK(nonterminal->relative_parameter());
        const RelativeParameter* rp = nonterminal->relative_parameter();
        if (rp->day_of_week_interpretation()) {
          // The sign of the relative count encodes direction; split it into
          // a non-negative counter plus a PAST/FUTURE qualifier.
          relative_value = GetRelativeCount(rp);
          if (relative_value < 0) {
            relative_value = abs(relative_value);
            updated_relative_qualifier =
                DatetimeComponent::RelativeQualifier::PAST;
          } else if (relative_value > 0) {
            updated_relative_qualifier =
                DatetimeComponent::RelativeQualifier::FUTURE;
          }
        }
      }
    }
  }
  return CreateDatetimeComponent(DatetimeComponent::ComponentType::DAY_OF_WEEK,
                                 updated_relative_qualifier, absolute_value,
                                 relative_value);
}
+
+// Resolve the year’s ambiguity.
+// If the year in the date has 4 digits i.e. DD/MM/YYYY then there is no
+// ambiguity, the year value is YYYY but certain format i.e. MM/DD/YY is
+// ambiguous e.g. in {April/23/15} year value can be 15 or 1915 or 2015.
+// Following heuristic is used to resolve the ambiguity.
+// - For YYYY there is nothing to resolve.
+// - For all YY years
+// - Value less than 50 will be resolved to 20YY
+// - Value greater or equal 50 will be resolved to 19YY
+static int InterpretYear(int parsed_year) {
+ if (parsed_year == NO_VAL) {
+ return parsed_year;
+ }
+ if (parsed_year < 100) {
+ if (parsed_year < 50) {
+ return parsed_year + 2000;
+ }
+ return parsed_year + 1900;
+ }
+ return parsed_year;
+}
+
// Returns the DatetimeComponent of the requested type, combining this
// match's absolute field with the corresponding field of the relative
// match (if any). Returns an empty Optional when the match carries no
// information for that component type.
Optional<DatetimeComponent> DateMatch::GetDatetimeComponent(
    const DatetimeComponent::ComponentType& component_type) const {
  switch (component_type) {
    case DatetimeComponent::ComponentType::YEAR:
      // Two-digit years are disambiguated before being exposed.
      return CreateDatetimeComponent(
          component_type, GetRelativeQualifier(), InterpretYear(year),
          (relative_match != nullptr) ? relative_match->year : NO_VAL);
    case DatetimeComponent::ComponentType::MONTH:
      return CreateDatetimeComponent(
          component_type, GetRelativeQualifier(), month,
          (relative_match != nullptr) ? relative_match->month : NO_VAL);
    case DatetimeComponent::ComponentType::DAY_OF_MONTH:
      return CreateDatetimeComponent(
          component_type, GetRelativeQualifier(), day,
          (relative_match != nullptr) ? relative_match->day : NO_VAL);
    case DatetimeComponent::ComponentType::HOUR:
      return CreateDatetimeComponent(
          component_type, GetRelativeQualifier(), hour,
          (relative_match != nullptr) ? relative_match->hour : NO_VAL);
    case DatetimeComponent::ComponentType::MINUTE:
      return CreateDatetimeComponent(
          component_type, GetRelativeQualifier(), minute,
          (relative_match != nullptr) ? relative_match->minute : NO_VAL);
    case DatetimeComponent::ComponentType::SECOND:
      return CreateDatetimeComponent(
          component_type, GetRelativeQualifier(), second,
          (relative_match != nullptr) ? relative_match->second : NO_VAL);
    case DatetimeComponent::ComponentType::DAY_OF_WEEK:
      // Day of week needs extra reconciliation; see CreateDayOfWeekComponent.
      return CreateDayOfWeekComponent(relative_match, GetRelativeQualifier(),
                                      day_of_week);
    case DatetimeComponent::ComponentType::MERIDIEM:
      return CreateDatetimeComponent(component_type, GetRelativeQualifier(),
                                     GetMeridiemValue(time_span_code), NO_VAL);
    case DatetimeComponent::ComponentType::ZONE_OFFSET:
      // Zone offsets have no relative interpretation; emit them directly.
      if (HasTimeZoneOffset()) {
        return Optional<DatetimeComponent>(DatetimeComponent(
            component_type, DatetimeComponent::RelativeQualifier::UNSPECIFIED,
            time_zone_offset, /*arg_relative_count=*/0));
      }
      return Optional<DatetimeComponent>();
    case DatetimeComponent::ComponentType::WEEK:
      // WEEK only exists as a relative notion ("next week"); there is no
      // absolute week field in DateMatch.
      return CreateDatetimeComponent(
          component_type, GetRelativeQualifier(), NO_VAL,
          HasRelativeDate() ? relative_match->week : NO_VAL);
    default:
      return Optional<DatetimeComponent>();
  }
}
+
// Checks structural consistency of the match. The rules below reject
// combinations that cannot denote a real date/time, e.g. a partial date
// with a "hole" in it (year + day but no month) or a finer time field
// without the coarser one (seconds without minutes).
bool DateMatch::IsValid() const {
  // BC/AD only makes sense together with a year.
  if (!HasYear() && HasBcAd()) {
    return false;
  }
  // No gaps: a year combined with a day (or day of week) needs a month.
  if (!HasMonth() && HasYear() && (HasDay() || HasDayOfWeek())) {
    return false;
  }
  // A day of week cannot anchor a year/month without a day of month.
  if (!HasDay() && HasDayOfWeek() && (HasYear() || HasMonth())) {
    return false;
  }
  // An hour combined with year/month needs a day-level anchor.
  if (!HasDay() && !HasDayOfWeek() && HasHour() && (HasYear() || HasMonth())) {
    return false;
  }
  // Finer time fields require all coarser ones:
  // hour > minute > second > fraction.
  if (!HasHour() && (HasMinute() || HasSecond() || HasFractionSecond())) {
    return false;
  }
  if (!HasMinute() && (HasSecond() || HasFractionSecond())) {
    return false;
  }
  if (!HasSecond() && HasFractionSecond()) {
    return false;
  }
  // Check whether day exists in a month, to exclude cases like "April 31".
  if (HasDay() && HasMonth() && day > GetLastDayOfMonth(year, month)) {
    return false;
  }
  // A match that carries no date, time, or relative information is empty.
  return (HasDateFields() || HasTimeFields() || HasRelativeDate());
}
+
+void DateMatch::FillDatetimeComponents(
+ std::vector<DatetimeComponent>* datetime_component) const {
+ static const std::vector<DatetimeComponent::ComponentType>*
+ kDatetimeComponents = new std::vector<DatetimeComponent::ComponentType>{
+ DatetimeComponent::ComponentType::ZONE_OFFSET,
+ DatetimeComponent::ComponentType::MERIDIEM,
+ DatetimeComponent::ComponentType::SECOND,
+ DatetimeComponent::ComponentType::MINUTE,
+ DatetimeComponent::ComponentType::HOUR,
+ DatetimeComponent::ComponentType::DAY_OF_MONTH,
+ DatetimeComponent::ComponentType::DAY_OF_WEEK,
+ DatetimeComponent::ComponentType::WEEK,
+ DatetimeComponent::ComponentType::MONTH,
+ DatetimeComponent::ComponentType::YEAR};
+
+ for (const DatetimeComponent::ComponentType& component_type :
+ *kDatetimeComponents) {
+ Optional<DatetimeComponent> date_time =
+ GetDatetimeComponent(component_type);
+ if (date_time.has_value()) {
+ datetime_component->emplace_back(date_time.value());
+ }
+ }
+}
+
+std::string DateRangeMatch::DebugString() const {
+ std::string res;
+ // The method is only called for debugging purposes.
+#if !defined(NDEBUG)
+ if (begin >= 0 && end >= 0) {
+ SStringAppendF(&res, 0, "[%u,%u)\n", begin, end);
+ }
+ SStringAppendF(&res, 0, "from: %s \n", from.DebugString().c_str());
+ SStringAppendF(&res, 0, "to: %s\n", to.DebugString().c_str());
+#endif // !defined(NDEBUG)
+ return res;
+}
+
+} // namespace dates
+} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/utils/date-match.h b/native/annotator/grammar/dates/utils/date-match.h
new file mode 100644
index 0000000..285e9b3
--- /dev/null
+++ b/native/annotator/grammar/dates/utils/date-match.h
@@ -0,0 +1,537 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_MATCH_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_MATCH_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <algorithm>
+#include <vector>
+
+#include "annotator/grammar/dates/dates_generated.h"
+#include "annotator/grammar/dates/timezone-code_generated.h"
+#include "utils/grammar/match.h"
+
+namespace libtextclassifier3 {
+namespace dates {
+
+static constexpr int NO_VAL = -1;
+
+// POD match data structure.
+struct MatchBase : public grammar::Match {
+ void Reset() { type = MatchType::MatchType_UNKNOWN; }
+};
+
+struct ExtractionMatch : public MatchBase {
+ const ExtractionRuleParameter* extraction_rule;
+
+ void Reset() {
+ MatchBase::Reset();
+ type = MatchType::MatchType_DATETIME_RULE;
+ extraction_rule = nullptr;
+ }
+};
+
+struct TermValueMatch : public MatchBase {
+ const TermValue* term_value;
+
+ void Reset() {
+ MatchBase::Reset();
+ type = MatchType::MatchType_TERM_VALUE;
+ term_value = nullptr;
+ }
+};
+
+struct NonterminalMatch : public MatchBase {
+ const NonterminalValue* nonterminal;
+
+ void Reset() {
+ MatchBase::Reset();
+ type = MatchType::MatchType_NONTERMINAL;
+ nonterminal = nullptr;
+ }
+};
+
// Nonterminal match carrying a single integer value (the base for the
// year/month/day/hour/minute/second matches below).
struct IntegerMatch : public NonterminalMatch {
  int value;
  int8 count_of_digits;   // When expression is in digits format.
  bool is_zero_prefixed;  // When expression is in digits format.

  void Reset() {
    NonterminalMatch::Reset();
    value = NO_VAL;
    count_of_digits = 0;
    is_zero_prefixed = false;
  }
};
+
+struct DigitsMatch : public IntegerMatch {
+ void Reset() {
+ IntegerMatch::Reset();
+ type = MatchType::MatchType_DIGITS;
+ }
+
+ static bool IsValid(int x) { return true; }
+};
+
+struct YearMatch : public IntegerMatch {
+ void Reset() {
+ IntegerMatch::Reset();
+ type = MatchType::MatchType_YEAR;
+ }
+
+ static bool IsValid(int x) { return x >= 1; }
+};
+
+struct MonthMatch : public IntegerMatch {
+ void Reset() {
+ IntegerMatch::Reset();
+ type = MatchType::MatchType_MONTH;
+ }
+
+ static bool IsValid(int x) { return (x >= 1 && x <= 12); }
+};
+
+struct DayMatch : public IntegerMatch {
+ void Reset() {
+ IntegerMatch::Reset();
+ type = MatchType::MatchType_DAY;
+ }
+
+ static bool IsValid(int x) { return (x >= 1 && x <= 31); }
+};
+
+struct HourMatch : public IntegerMatch {
+ void Reset() {
+ IntegerMatch::Reset();
+ type = MatchType::MatchType_HOUR;
+ }
+
+ static bool IsValid(int x) { return (x >= 0 && x <= 24); }
+};
+
+struct MinuteMatch : public IntegerMatch {
+ void Reset() {
+ IntegerMatch::Reset();
+ type = MatchType::MatchType_MINUTE;
+ }
+
+ static bool IsValid(int x) { return (x >= 0 && x <= 59); }
+};
+
+struct SecondMatch : public IntegerMatch {
+ void Reset() {
+ IntegerMatch::Reset();
+ type = MatchType::MatchType_SECOND;
+ }
+
+ static bool IsValid(int x) { return (x >= 0 && x <= 60); }
+};
+
+struct DecimalMatch : public NonterminalMatch {
+ double value;
+ int8 count_of_digits; // When expression is in digits format.
+
+ void Reset() {
+ NonterminalMatch::Reset();
+ value = NO_VAL;
+ count_of_digits = 0;
+ }
+};
+
+struct FractionSecondMatch : public DecimalMatch {
+ void Reset() {
+ DecimalMatch::Reset();
+ type = MatchType::MatchType_FRACTION_SECOND;
+ }
+
+ static bool IsValid(double x) { return (x >= 0.0 && x < 1.0); }
+};
+
+// CombinedIntegersMatch<N> is used for expressions containing multiple (up
+// to N) matches of integers without delimeters between them (because
+// CFG-grammar is based on tokenizer, it could not split a token into several
+// pieces like using regular-expression). For example, "1130" contains "11"
+// and "30" meaning November 30.
+template <int N>
+struct CombinedIntegersMatch : public NonterminalMatch {
+ enum {
+ SIZE = N,
+ };
+
+ int values[SIZE];
+ int8 count_of_digits; // When expression is in digits format.
+ bool is_zero_prefixed; // When expression is in digits format.
+
+ void Reset() {
+ NonterminalMatch::Reset();
+ for (int i = 0; i < SIZE; ++i) {
+ values[i] = NO_VAL;
+ }
+ count_of_digits = 0;
+ is_zero_prefixed = false;
+ }
+};
+
// Match of a digit run that encodes up to six date/time fields at once
// (e.g. "20140126"), stored in fixed year/month/day/hour/minute/second
// slot order.
struct CombinedDigitsMatch : public CombinedIntegersMatch<6> {
  enum Index {
    INDEX_YEAR = 0,
    INDEX_MONTH = 1,
    INDEX_DAY = 2,
    INDEX_HOUR = 3,
    INDEX_MINUTE = 4,
    INDEX_SECOND = 5,
  };

  bool HasYear() const { return values[INDEX_YEAR] != NO_VAL; }
  bool HasMonth() const { return values[INDEX_MONTH] != NO_VAL; }
  bool HasDay() const { return values[INDEX_DAY] != NO_VAL; }
  bool HasHour() const { return values[INDEX_HOUR] != NO_VAL; }
  bool HasMinute() const { return values[INDEX_MINUTE] != NO_VAL; }
  bool HasSecond() const { return values[INDEX_SECOND] != NO_VAL; }

  int GetYear() const { return values[INDEX_YEAR]; }
  int GetMonth() const { return values[INDEX_MONTH]; }
  int GetDay() const { return values[INDEX_DAY]; }
  int GetHour() const { return values[INDEX_HOUR]; }
  int GetMinute() const { return values[INDEX_MINUTE]; }
  int GetSecond() const { return values[INDEX_SECOND]; }

  void Reset() {
    CombinedIntegersMatch<SIZE>::Reset();
    type = MatchType::MatchType_COMBINED_DIGITS;
  }

  // Validates the value in slot `i` using the validity rule of the
  // corresponding single-field match type.
  static bool IsValid(int i, int x) {
    switch (i) {
      case INDEX_YEAR:
        return YearMatch::IsValid(x);
      case INDEX_MONTH:
        return MonthMatch::IsValid(x);
      case INDEX_DAY:
        return DayMatch::IsValid(x);
      case INDEX_HOUR:
        return HourMatch::IsValid(x);
      case INDEX_MINUTE:
        return MinuteMatch::IsValid(x);
      case INDEX_SECOND:
        return SecondMatch::IsValid(x);
      default:
        return false;
    }
  }
};
+
// Assembled clock-time match (e.g. "12:30:59.99"): the parsed field values
// plus pointers to the sub-matches they came from.
struct TimeValueMatch : public NonterminalMatch {
  const HourMatch* hour_match;
  const MinuteMatch* minute_match;
  const SecondMatch* second_match;
  const FractionSecondMatch* fraction_second_match;

  // Facts about the written form ("09" vs "9", "9:5" vs "9:05").
  bool is_hour_zero_prefixed : 1;
  bool is_minute_one_digit : 1;
  bool is_second_one_digit : 1;

  int8 hour;
  int8 minute;
  int8 second;
  double fraction_second;

  void Reset() {
    NonterminalMatch::Reset();
    type = MatchType::MatchType_TIME_VALUE;
    hour_match = nullptr;
    minute_match = nullptr;
    second_match = nullptr;
    fraction_second_match = nullptr;
    is_hour_zero_prefixed = false;
    is_minute_one_digit = false;
    is_second_one_digit = false;
    hour = NO_VAL;
    minute = NO_VAL;
    second = NO_VAL;
    fraction_second = NO_VAL;
  }
};
+
+struct TimeSpanMatch : public NonterminalMatch {
+ const TimeSpanSpec* time_span_spec;
+ TimespanCode time_span_code;
+
+ void Reset() {
+ NonterminalMatch::Reset();
+ type = MatchType::MatchType_TIME_SPAN;
+ time_span_spec = nullptr;
+ time_span_code = TimespanCode_TIMESPAN_CODE_NONE;
+ }
+};
+
+struct TimeZoneNameMatch : public NonterminalMatch {
+ const TimeZoneNameSpec* time_zone_name_spec;
+ TimezoneCode time_zone_code;
+
+ void Reset() {
+ NonterminalMatch::Reset();
+ type = MatchType::MatchType_TIME_ZONE_NAME;
+ time_zone_name_spec = nullptr;
+ time_zone_code = TimezoneCode_TIMEZONE_CODE_NONE;
+ }
+};
+
+struct TimeZoneOffsetMatch : public NonterminalMatch {
+ const TimeZoneOffsetParameter* time_zone_offset_param;
+ int16 time_zone_offset;
+
+ void Reset() {
+ NonterminalMatch::Reset();
+ type = MatchType::MatchType_TIME_ZONE_OFFSET;
+ time_zone_offset_param = nullptr;
+ time_zone_offset = 0;
+ }
+};
+
+struct DayOfWeekMatch : public IntegerMatch {
+ void Reset() {
+ IntegerMatch::Reset();
+ type = MatchType::MatchType_DAY_OF_WEEK;
+ }
+
+ static bool IsValid(int x) {
+ return (x > DayOfWeek_DOW_NONE && x <= DayOfWeek_MAX);
+ }
+};
+
+struct TimePeriodMatch : public NonterminalMatch {
+ int value;
+
+ void Reset() {
+ NonterminalMatch::Reset();
+ type = MatchType::MatchType_TIME_PERIOD;
+ value = NO_VAL;
+ }
+};
+
// Match of a relative date/time expression (e.g. "tomorrow", "next week",
// "3 days ago"). `existing` is a bitmask recording which of the value
// fields below were actually populated.
struct RelativeMatch : public NonterminalMatch {
  enum {
    HAS_NONE = 0,
    HAS_YEAR = 1 << 0,
    HAS_MONTH = 1 << 1,
    HAS_DAY = 1 << 2,
    HAS_WEEK = 1 << 3,
    HAS_HOUR = 1 << 4,
    HAS_MINUTE = 1 << 5,
    HAS_SECOND = 1 << 6,
    HAS_DAY_OF_WEEK = 1 << 7,
    // Set when the direction (past vs. future) of the expression is known;
    // the direction itself is stored in is_future_date.
    HAS_IS_FUTURE = 1 << 31,
  };
  uint32 existing;

  int year;
  int month;
  int day;
  int week;
  int hour;
  int minute;
  int second;
  const NonterminalValue* day_of_week_nonterminal;
  int8 day_of_week;
  bool is_future_date;

  bool HasDay() const { return existing & HAS_DAY; }

  // True if either a day count or a day of week is present.
  bool HasDayFields() const { return existing & (HAS_DAY | HAS_DAY_OF_WEEK); }

  bool HasTimeValueFields() const {
    return existing & (HAS_HOUR | HAS_MINUTE | HAS_SECOND);
  }

  // True when the match consists of a day of week and nothing else.
  bool IsStandaloneRelativeDayOfWeek() const {
    return (existing & HAS_DAY_OF_WEEK) && (existing & ~HAS_DAY_OF_WEEK) == 0;
  }

  void Reset() {
    NonterminalMatch::Reset();
    type = MatchType::MatchType_RELATIVE_DATE;
    existing = HAS_NONE;
    year = NO_VAL;
    month = NO_VAL;
    day = NO_VAL;
    week = NO_VAL;
    hour = NO_VAL;
    minute = NO_VAL;
    second = NO_VAL;
    day_of_week = NO_VAL;
    is_future_date = false;
  }
};
+
// This is not necessarily POD, it is used to keep the final matched result.
struct DateMatch {
  // Sub-matches in the date match.
  const YearMatch* year_match = nullptr;
  const MonthMatch* month_match = nullptr;
  const DayMatch* day_match = nullptr;
  const DayOfWeekMatch* day_of_week_match = nullptr;
  const TimeValueMatch* time_value_match = nullptr;
  const TimeSpanMatch* time_span_match = nullptr;
  const TimeZoneNameMatch* time_zone_name_match = nullptr;
  const TimeZoneOffsetMatch* time_zone_offset_match = nullptr;
  const RelativeMatch* relative_match = nullptr;
  const CombinedDigitsMatch* combined_digits_match = nullptr;

  // [begin, end) indicates the Document position where the date or date range
  // was found.
  int begin = -1;
  int end = -1;
  int priority = 0;
  float annotator_priority_score = 0.0;

  // Extracted field values; NO_VAL (or the *_NONE enum value) marks an
  // absent field.
  int year = NO_VAL;
  int8 month = NO_VAL;
  int8 day = NO_VAL;
  DayOfWeek day_of_week = DayOfWeek_DOW_NONE;
  BCAD bc_ad = BCAD_BCAD_NONE;
  int8 hour = NO_VAL;
  int8 minute = NO_VAL;
  int8 second = NO_VAL;
  double fraction_second = NO_VAL;
  TimespanCode time_span_code = TimespanCode_TIMESPAN_CODE_NONE;
  int time_zone_code = TimezoneCode_TIMEZONE_CODE_NONE;
  // INT16_MIN is the "unset" sentinel; see HasTimeZoneOffset().
  // NOTE(review): std::numeric_limits needs <limits>, which this header does
  // not include directly — presumably pulled in transitively; verify.
  int16 time_zone_offset = std::numeric_limits<int16>::min();

  // Fields about ambiguous hours. These fields are used to interpret the
  // possible values of ambiguous hours. Since all kinds of known ambiguities
  // are in the form of arithmetic progression (starting from .hour field),
  // we can use "ambiguous_hour_count" to denote the count of ambiguous hours,
  // and use "ambiguous_hour_interval" to denote the distance between a pair
  // of adjacent possible hours. Values in the arithmetic progression are
  // shrunk into [0, 23] (MOD 24). One can use the GetPossibleHourValues()
  // method for the complete list of possible hours.
  uint8 ambiguous_hour_count = 0;
  uint8 ambiguous_hour_interval = 0;

  bool is_inferred = false;

  // This field is set in function PerformRefinements to remove some DateMatch
  // like overlapped, duplicated, etc.
  bool is_removed = false;

  std::string DebugString() const;

  bool HasYear() const { return year != NO_VAL; }
  bool HasMonth() const { return month != NO_VAL; }
  bool HasDay() const { return day != NO_VAL; }
  bool HasDayOfWeek() const { return day_of_week != DayOfWeek_DOW_NONE; }
  bool HasBcAd() const { return bc_ad != BCAD_BCAD_NONE; }
  bool HasHour() const { return hour != NO_VAL; }
  bool HasMinute() const { return minute != NO_VAL; }
  bool HasSecond() const { return second != NO_VAL; }
  bool HasFractionSecond() const { return fraction_second != NO_VAL; }
  bool HasTimeSpanCode() const {
    return time_span_code != TimespanCode_TIMESPAN_CODE_NONE;
  }
  bool HasTimeZoneCode() const {
    return time_zone_code != TimezoneCode_TIMEZONE_CODE_NONE;
  }
  bool HasTimeZoneOffset() const {
    return time_zone_offset != std::numeric_limits<int16>::min();
  }

  bool HasRelativeDate() const { return relative_match != nullptr; }

  bool IsHourAmbiguous() const { return ambiguous_hour_count >= 2; }

  // True when the match is a pure clock time with no date anchor.
  bool IsStandaloneTime() const {
    return (HasHour() || HasMinute()) && !HasDayOfWeek() && !HasDay() &&
           !HasMonth() && !HasYear();
  }

  void SetAmbiguousHourProperties(uint8 count, uint8 interval) {
    ambiguous_hour_count = count;
    ambiguous_hour_interval = interval;
  }

  // Outputs all the possible hour values. If current DateMatch does not
  // contain an hour, nothing will be output. If the hour is not ambiguous,
  // only one value (= .hour) will be output. This method clears the vector
  // "values" first, and it is not guaranteed that the values in the vector
  // are in a sorted order.
  void GetPossibleHourValues(std::vector<int8>* values) const;

  int GetPriority() const { return priority; }

  float GetAnnotatorPriorityScore() const { return annotator_priority_score; }

  // True when the match is nothing but a relative day of week.
  bool IsStandaloneRelativeDayOfWeek() const {
    return (HasRelativeDate() &&
            relative_match->IsStandaloneRelativeDayOfWeek() &&
            !HasDateFields() && !HasTimeFields() && !HasTimeSpanCode());
  }

  bool HasDateFields() const {
    return (HasYear() || HasMonth() || HasDay() || HasDayOfWeek() || HasBcAd());
  }
  bool HasTimeValueFields() const {
    return (HasHour() || HasMinute() || HasSecond() || HasFractionSecond());
  }
  bool HasTimeSpanFields() const { return HasTimeSpanCode(); }
  bool HasTimeZoneFields() const {
    return (HasTimeZoneCode() || HasTimeZoneOffset());
  }
  bool HasTimeFields() const {
    return (HasTimeValueFields() || HasTimeSpanFields() || HasTimeZoneFields());
  }

  bool IsValid() const;

  // Overall relative qualifier of the DateMatch e.g. 2 year ago is 'PAST' and
  // next week is 'FUTURE'.
  DatetimeComponent::RelativeQualifier GetRelativeQualifier() const;

  // Getter method to get the 'DatetimeComponent' of given 'ComponentType'.
  Optional<DatetimeComponent> GetDatetimeComponent(
      const DatetimeComponent::ComponentType& component_type) const;

  // Appends a DatetimeComponent for every populated field, from finest
  // (zone offset) to coarsest (year).
  void FillDatetimeComponents(
      std::vector<DatetimeComponent>* datetime_component) const;
};
+
// Represent a matched date range which includes the from and to matched date.
struct DateRangeMatch {
  // [begin, end) position of the whole range in the document.
  int begin = -1;
  int end = -1;

  DateMatch from;
  DateMatch to;

  std::string DebugString() const;

  // The range inherits the higher priority/score of its two endpoints.
  int GetPriority() const {
    return std::max(from.GetPriority(), to.GetPriority());
  }

  float GetAnnotatorPriorityScore() const {
    return std::max(from.GetAnnotatorPriorityScore(),
                    to.GetAnnotatorPriorityScore());
  }
};
+
+} // namespace dates
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_MATCH_H_
diff --git a/native/annotator/grammar/dates/utils/date-match_test.cc b/native/annotator/grammar/dates/utils/date-match_test.cc
new file mode 100644
index 0000000..f10f32a
--- /dev/null
+++ b/native/annotator/grammar/dates/utils/date-match_test.cc
@@ -0,0 +1,397 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/dates/utils/date-match.h"
+
+#include <stdint.h>
+
+#include <string>
+
+#include "annotator/grammar/dates/dates_generated.h"
+#include "annotator/grammar/dates/timezone-code_generated.h"
+#include "annotator/grammar/dates/utils/date-utils.h"
+#include "utils/strings/append.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace dates {
+namespace {
+
+class DateMatchTest : public ::testing::Test {
+ protected:
+ enum {
+ X = NO_VAL,
+ };
+
+ static DayOfWeek DOW_X() { return DayOfWeek_DOW_NONE; }
+ static DayOfWeek SUN() { return DayOfWeek_SUNDAY; }
+
+ static BCAD BCAD_X() { return BCAD_BCAD_NONE; }
+ static BCAD BC() { return BCAD_BC; }
+
+ DateMatch& SetDate(DateMatch* date, int year, int8 month, int8 day,
+ DayOfWeek day_of_week = DOW_X(), BCAD bc_ad = BCAD_X()) {
+ date->year = year;
+ date->month = month;
+ date->day = day;
+ date->day_of_week = day_of_week;
+ date->bc_ad = bc_ad;
+ return *date;
+ }
+
+ DateMatch& SetTimeValue(DateMatch* date, int8 hour, int8 minute = X,
+ int8 second = X, double fraction_second = X) {
+ date->hour = hour;
+ date->minute = minute;
+ date->second = second;
+ date->fraction_second = fraction_second;
+ return *date;
+ }
+
+ DateMatch& SetTimeSpan(DateMatch* date, TimespanCode time_span_code) {
+ date->time_span_code = time_span_code;
+ return *date;
+ }
+
+ DateMatch& SetTimeZone(DateMatch* date, TimezoneCode time_zone_code,
+ int16 time_zone_offset = INT16_MIN) {
+ date->time_zone_code = time_zone_code;
+ date->time_zone_offset = time_zone_offset;
+ return *date;
+ }
+
+ bool SameDate(const DateMatch& a, const DateMatch& b) {
+ return (a.day == b.day && a.month == b.month && a.year == b.year &&
+ a.day_of_week == b.day_of_week);
+ }
+
+ DateMatch& SetDayOfWeek(DateMatch* date, DayOfWeek dow) {
+ date->day_of_week = dow;
+ return *date;
+ }
+};
+
+TEST_F(DateMatchTest, BitFieldWidth) {
+ // For DateMatch::day_of_week (:8).
+ EXPECT_GE(DayOfWeek_MIN, INT8_MIN);
+ EXPECT_LE(DayOfWeek_MAX, INT8_MAX);
+
+ // For DateMatch::bc_ad (:8).
+ EXPECT_GE(BCAD_MIN, INT8_MIN);
+ EXPECT_LE(BCAD_MAX, INT8_MAX);
+
+ // For DateMatch::time_span_code (:16).
+ EXPECT_GE(TimespanCode_MIN, INT16_MIN);
+ EXPECT_LE(TimespanCode_MAX, INT16_MAX);
+}
+
+TEST_F(DateMatchTest, IsValid) {
+ // Valid: dates.
+ {
+ DateMatch d;
+ SetDate(&d, 2014, 1, 26);
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, 2014, 1, X);
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, 2014, X, X);
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, X, 1, 26);
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, X, 1, X);
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, X, X, 26);
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, 2014, 1, 26, SUN());
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, X, 1, 26, SUN());
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, X, X, 26, SUN());
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, 2014, 1, 26, DOW_X(), BC());
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ // Valid: times.
+ {
+ DateMatch d;
+ SetTimeValue(&d, 12, 30, 59, 0.99);
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetTimeValue(&d, 12, 30, 59);
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetTimeValue(&d, 12, 30);
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetTimeValue(&d, 12);
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ // Valid: mixed.
+ {
+ DateMatch d;
+ SetDate(&d, 2014, 1, 26);
+ SetTimeValue(&d, 12, 30, 59, 0.99);
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, X, 1, 26);
+ SetTimeValue(&d, 12, 30, 59);
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, X, X, X, SUN());
+ SetTimeValue(&d, 12, 30);
+ EXPECT_TRUE(d.IsValid()) << d.DebugString();
+ }
+ // Invalid: dates.
+ {
+ DateMatch d;
+ SetDate(&d, X, 1, 26, DOW_X(), BC());
+ EXPECT_FALSE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, 2014, X, 26);
+ EXPECT_FALSE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, 2014, X, X, SUN());
+ EXPECT_FALSE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetDate(&d, X, 1, X, SUN());
+ EXPECT_FALSE(d.IsValid()) << d.DebugString();
+ }
+ // Invalid: times.
+ {
+ DateMatch d;
+ SetTimeValue(&d, 12, X, 59);
+ EXPECT_FALSE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetTimeValue(&d, 12, X, X, 0.99);
+ EXPECT_FALSE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetTimeValue(&d, 12, 30, X, 0.99);
+ EXPECT_FALSE(d.IsValid()) << d.DebugString();
+ }
+ {
+ DateMatch d;
+ SetTimeValue(&d, X, 30);
+ EXPECT_FALSE(d.IsValid()) << d.DebugString();
+ }
+ // Invalid: mixed.
+ {
+ DateMatch d;
+ SetDate(&d, 2014, 1, X);
+ SetTimeValue(&d, 12);
+ EXPECT_FALSE(d.IsValid()) << d.DebugString();
+ }
+ // Invalid: empty.
+ {
+ DateMatch d;
+ EXPECT_FALSE(d.IsValid()) << d.DebugString();
+ }
+}
+
+std::string DebugStrings(const std::vector<DateMatch>& instances) {
+ std::string res;
+ for (int i = 0; i < instances.size(); ++i) {
+ ::libtextclassifier3::strings::SStringAppendF(
+ &res, 0, "[%d] == %s\n", i, instances[i].DebugString().c_str());
+ }
+ return res;
+}
+
+TEST_F(DateMatchTest, IsRefinement) {
+ {
+ DateMatch a;
+ SetDate(&a, 2014, 2, X);
+ DateMatch b;
+ SetDate(&b, 2014, X, X);
+ EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
+ }
+ {
+ DateMatch a;
+ SetDate(&a, 2014, 2, 24);
+ DateMatch b;
+ SetDate(&b, 2014, 2, X);
+ EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
+ }
+ {
+ DateMatch a;
+ SetDate(&a, 2014, 2, 24);
+ DateMatch b;
+ SetDate(&b, X, 2, 24);
+ EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
+ }
+ {
+ DateMatch a;
+ SetDate(&a, 2014, 2, 24);
+ SetTimeValue(&a, 9, X, X);
+ DateMatch b;
+ SetDate(&b, 2014, 2, 24);
+ EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
+ }
+ {
+ DateMatch a;
+ SetDate(&a, 2014, 2, 24);
+ SetTimeValue(&a, 9, 0, X);
+ DateMatch b;
+ SetDate(&b, 2014, 2, 24);
+ SetTimeValue(&b, 9, X, X);
+ EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
+ }
+ {
+ DateMatch a;
+ SetDate(&a, 2014, 2, 24);
+ SetTimeValue(&a, 9, 0, 0);
+ DateMatch b;
+ SetDate(&b, 2014, 2, 24);
+ SetTimeValue(&b, 9, 0, X);
+ EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
+ }
+ {
+ DateMatch a;
+ SetDate(&a, 2014, 2, 24);
+ SetTimeValue(&a, 9, X, X);
+ SetTimeSpan(&a, TimespanCode_AM);
+ DateMatch b;
+ SetDate(&b, 2014, 2, 24);
+ SetTimeValue(&b, 9, X, X);
+ EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
+ }
+ {
+ DateMatch a;
+ SetDate(&a, 2014, 2, 24);
+ SetTimeValue(&a, 9, X, X);
+ SetTimeZone(&a, TimezoneCode_PST8PDT);
+ DateMatch b;
+ SetDate(&b, 2014, 2, 24);
+ SetTimeValue(&b, 9, X, X);
+ EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
+ }
+ {
+ DateMatch a;
+ SetDate(&a, 2014, 2, 24);
+ SetTimeValue(&a, 9, X, X);
+ a.priority += 10;
+ DateMatch b;
+ SetDate(&b, 2014, 2, 24);
+ SetTimeValue(&b, 9, X, X);
+ EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
+ }
+ {
+ DateMatch a;
+ SetDate(&a, 2014, 2, 24);
+ SetTimeValue(&a, 9, X, X);
+ DateMatch b;
+ SetDate(&b, 2014, 2, 24);
+ SetTimeValue(&b, 9, X, X);
+ EXPECT_TRUE(IsRefinement(a, b)) << DebugStrings({a, b});
+ }
+ {
+ DateMatch a;
+ SetDate(&a, 2014, 2, 24);
+ SetTimeValue(&a, 9, X, X);
+ DateMatch b;
+ SetDate(&b, X, 2, 24);
+ SetTimeValue(&b, 9, 0, X);
+ EXPECT_FALSE(IsRefinement(a, b)) << DebugStrings({a, b});
+ }
+ {
+ DateMatch a;
+ SetDate(&a, X, 2, 24);
+ SetTimeValue(&a, 9, X, X);
+ DateMatch b;
+ SetDate(&b, 2014, 2, 24);
+ EXPECT_FALSE(IsRefinement(a, b)) << DebugStrings({a, b});
+ }
+ {
+ DateMatch a;
+ SetTimeValue(&a, 9, 0, 0);
+ DateMatch b;
+ SetTimeValue(&b, 9, X, X);
+ SetTimeSpan(&b, TimespanCode_AM);
+ EXPECT_FALSE(IsRefinement(a, b)) << DebugStrings({a, b});
+ }
+}
+
+TEST_F(DateMatchTest, FillDateInstance_AnnotatorPriorityScore) {
+ DateMatch date_match;
+ SetDate(&date_match, 2014, 2, X);
+ date_match.annotator_priority_score = 0.5;
+ DatetimeParseResultSpan datetime_parse_result_span;
+ FillDateInstance(date_match, &datetime_parse_result_span);
+ EXPECT_FLOAT_EQ(datetime_parse_result_span.priority_score, 0.5)
+ << DebugStrings({date_match});
+}
+
+TEST_F(DateMatchTest, MergeDateMatch_AnnotatorPriorityScore) {
+ DateMatch a;
+ SetDate(&a, 2014, 2, 4);
+ a.annotator_priority_score = 0.5;
+
+ DateMatch b;
+ SetTimeValue(&b, 10, 45, 23);
+ b.annotator_priority_score = 1.0;
+
+ MergeDateMatch(b, &a, false);
+ EXPECT_FLOAT_EQ(a.annotator_priority_score, 1.0);
+}
+
+} // namespace
+} // namespace dates
+} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/utils/date-utils.cc b/native/annotator/grammar/dates/utils/date-utils.cc
new file mode 100644
index 0000000..ea8015d
--- /dev/null
+++ b/native/annotator/grammar/dates/utils/date-utils.cc
@@ -0,0 +1,399 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/dates/utils/date-utils.h"
+
+#include <algorithm>
+#include <ctime>
+
+#include "annotator/grammar/dates/annotations/annotation-util.h"
+#include "annotator/grammar/dates/dates_generated.h"
+#include "annotator/grammar/dates/utils/annotation-keys.h"
+#include "annotator/grammar/dates/utils/date-match.h"
+#include "annotator/types.h"
+#include "utils/base/macros.h"
+
+namespace libtextclassifier3 {
+namespace dates {
+
// Decides whether `year` is a leap year (divisible by 4, except
// centuries not divisible by 400).
bool IsLeapYear(int year) {
  // Leap-year status repeats with a period of 400 years, so negative
  // (BCE) years are shifted up by a multiple of 400 (here 8000) to avoid
  // taking the modulus of a negative number, which is not well-defined or
  // portable across platforms.
  if (year < 0) {
    year += 8000;
  }
  const bool divisible_by_4 = (year % 4 == 0);
  const bool is_century = (year % 100 == 0);
  const bool divisible_by_400 = (year % 400 == 0);
  return divisible_by_4 && (!is_century || divisible_by_400);
}
+
namespace {
// Calendar/time constants, following the conventional tzcode names.
#define SECSPERMIN (60)
#define MINSPERHOUR (60)
#define HOURSPERDAY (24)
#define DAYSPERWEEK (7)
#define DAYSPERNYEAR (365)
#define DAYSPERLYEAR (366)
#define MONSPERYEAR (12)

// Days per month, indexed as [is_leap_year][month]. Months are 1-based,
// so index 0 holds a -1 sentinel.
const int8 kDaysPerMonth[2][1 + MONSPERYEAR] = {
    {-1, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31},
    {-1, 31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31},
};
}  // namespace
+
+int8 GetLastDayOfMonth(int year, int month) {
+ if (year == 0) { // No year specified
+ return kDaysPerMonth[1][month];
+ }
+ return kDaysPerMonth[IsLeapYear(year)][month];
+}
+
+namespace {
+inline bool IsHourInSegment(const TimeSpanSpec_::Segment* segment, int8 hour,
+ bool is_exact) {
+ return (hour >= segment->begin() &&
+ (hour < segment->end() ||
+ (hour == segment->end() && is_exact && segment->is_closed())));
+}
+
// Returns a pointer to the datetime property of `inst`, creating it with
// all components unset (-1) if it does not exist yet. The returned
// pointer is owned by `inst` and is invalidated when its property list
// reallocates.
Property* FindOrCreateDefaultDateTime(AnnotationData* inst) {
  // Refer comments for kDateTime in annotation-keys.h to see the format.
  static constexpr int kDefault[] = {-1, -1, -1, -1, -1, -1, -1, -1};

  int idx = GetPropertyIndex(kDateTime, *inst);
  if (idx < 0) {
    idx = AddRepeatedIntProperty(kDateTime, kDefault, TC3_ARRAYSIZE(kDefault),
                                 inst);
  }
  return &inst->properties[idx];
}
+
+void IncrementDayOfWeek(DayOfWeek* dow) {
+ static const DayOfWeek dow_ring[] = {DayOfWeek_MONDAY, DayOfWeek_TUESDAY,
+ DayOfWeek_WEDNESDAY, DayOfWeek_THURSDAY,
+ DayOfWeek_FRIDAY, DayOfWeek_SATURDAY,
+ DayOfWeek_SUNDAY, DayOfWeek_MONDAY};
+ const auto& cur_dow =
+ std::find(std::begin(dow_ring), std::end(dow_ring), *dow);
+ if (cur_dow != std::end(dow_ring)) {
+ *dow = *std::next(cur_dow);
+ }
+}
+} // namespace
+
// Normalizes the hour of `date` against the time-span spec (e.g. "pm").
// Returns true if the (possibly absent) hour is compatible with some
// segment of the span; may mutate date->hour.
bool NormalizeHourByTimeSpan(const TimeSpanSpec* ts_spec, DateMatch* date) {
  if (ts_spec->segment() == nullptr) {
    return false;
  }
  if (date->HasHour()) {
    // The time is "exact" when everything below the hour is absent or
    // zero (e.g. "9:00:00.0"); only exact times may land on the closed
    // end point of a segment (see IsHourInSegment).
    const bool is_exact =
        (!date->HasMinute() ||
         (date->minute == 0 &&
          (!date->HasSecond() ||
           (date->second == 0 &&
            (!date->HasFractionSecond() || date->fraction_second == 0.0)))));
    for (const TimeSpanSpec_::Segment* segment : *ts_spec->segment()) {
      // Prefer shifting the hour by the segment offset (e.g. 9 -> 21 for
      // a PM segment with offset 12).
      if (IsHourInSegment(segment, date->hour + segment->offset(), is_exact)) {
        date->hour += segment->offset();
        return true;
      }
      // Non-strict segments also accept the hour unshifted.
      if (!segment->is_strict() &&
          IsHourInSegment(segment, date->hour, is_exact)) {
        return true;
      }
    }
  } else {
    // Without an hour only stand-alone segments are meaningful; a point
    // segment (begin == end) determines the hour outright.
    for (const TimeSpanSpec_::Segment* segment : *ts_spec->segment()) {
      if (segment->is_stand_alone()) {
        if (segment->begin() == segment->end()) {
          date->hour = segment->begin();
        }
        // Allow stand-alone time-span points and ranges.
        return true;
      }
    }
  }
  return false;
}
+
// Returns true iff "a" is a refinement of "b": every field that "b" sets
// must be set to a compatible value in "a", and "a" must either carry
// extra fields (count > 0) or have at least b's priority.
bool IsRefinement(const DateMatch& a, const DateMatch& b) {
  // Number of fields "a" has that "b" lacks.
  int count = 0;
  if (b.HasBcAd()) {
    if (!a.HasBcAd() || a.bc_ad != b.bc_ad) return false;
  } else if (a.HasBcAd()) {
    // An extra explicit BC qualifier is never a refinement.
    if (a.bc_ad == BCAD_BC) return false;
    ++count;
  }
  if (b.HasYear()) {
    if (!a.HasYear() || a.year != b.year) return false;
  } else if (a.HasYear()) {
    ++count;
  }
  if (b.HasMonth()) {
    if (!a.HasMonth() || a.month != b.month) return false;
  } else if (a.HasMonth()) {
    ++count;
  }
  if (b.HasDay()) {
    if (!a.HasDay() || a.day != b.day) return false;
  } else if (a.HasDay()) {
    ++count;
  }
  if (b.HasDayOfWeek()) {
    if (!a.HasDayOfWeek() || a.day_of_week != b.day_of_week) return false;
  } else if (a.HasDayOfWeek()) {
    ++count;
  }
  if (b.HasHour()) {
    if (!a.HasHour()) return false;
    // Hours are compatible if a's hour matches any interpretation of b's
    // hour (e.g. "9" may mean 9 or 21 depending on b's time spans).
    std::vector<int8> possible_hours;
    b.GetPossibleHourValues(&possible_hours);
    if (std::find(possible_hours.begin(), possible_hours.end(), a.hour) ==
        possible_hours.end()) {
      return false;
    }
  } else if (a.HasHour()) {
    ++count;
  }
  if (b.HasMinute()) {
    if (!a.HasMinute() || a.minute != b.minute) return false;
  } else if (a.HasMinute()) {
    ++count;
  }
  if (b.HasSecond()) {
    if (!a.HasSecond() || a.second != b.second) return false;
  } else if (a.HasSecond()) {
    ++count;
  }
  if (b.HasFractionSecond()) {
    if (!a.HasFractionSecond() || a.fraction_second != b.fraction_second)
      return false;
  } else if (a.HasFractionSecond()) {
    ++count;
  }
  if (b.HasTimeSpanCode()) {
    if (!a.HasTimeSpanCode() || a.time_span_code != b.time_span_code)
      return false;
  } else if (a.HasTimeSpanCode()) {
    ++count;
  }
  if (b.HasTimeZoneCode()) {
    if (!a.HasTimeZoneCode() || a.time_zone_code != b.time_zone_code)
      return false;
  } else if (a.HasTimeZoneCode()) {
    ++count;
  }
  if (b.HasTimeZoneOffset()) {
    if (!a.HasTimeZoneOffset() || a.time_zone_offset != b.time_zone_offset)
      return false;
  } else if (a.HasTimeZoneOffset()) {
    ++count;
  }
  // "a" refines "b" if it is strictly more specific, or equally specific
  // with equal or higher priority.
  return (count > 0 || a.priority >= b.priority);
}
+
// Refinement between date ranges is not implemented; ranges never refine
// each other.
bool IsRefinement(const DateRangeMatch& a, const DateRangeMatch& b) {
  return false;
}
+
+bool IsPrecedent(const DateMatch& a, const DateMatch& b) {
+ if (a.HasYear() && b.HasYear()) {
+ if (a.year < b.year) return true;
+ if (a.year > b.year) return false;
+ }
+
+ if (a.HasMonth() && b.HasMonth()) {
+ if (a.month < b.month) return true;
+ if (a.month > b.month) return false;
+ }
+
+ if (a.HasDay() && b.HasDay()) {
+ if (a.day < b.day) return true;
+ if (a.day > b.day) return false;
+ }
+
+ if (a.HasHour() && b.HasHour()) {
+ if (a.hour < b.hour) return true;
+ if (a.hour > b.hour) return false;
+ }
+
+ if (a.HasMinute() && b.HasHour()) {
+ if (a.minute < b.hour) return true;
+ if (a.minute > b.hour) return false;
+ }
+
+ if (a.HasSecond() && b.HasSecond()) {
+ if (a.second < b.hour) return true;
+ if (a.second > b.hour) return false;
+ }
+
+ return false;
+}
+
+void FillDateInstance(const DateMatch& date,
+ DatetimeParseResultSpan* instance) {
+ instance->span.first = date.begin;
+ instance->span.second = date.end;
+ instance->priority_score = date.GetAnnotatorPriorityScore();
+ DatetimeParseResult datetime_parse_result;
+ date.FillDatetimeComponents(&datetime_parse_result.datetime_components);
+ instance->data.emplace_back(datetime_parse_result);
+}
+
// Converts a DateRangeMatch into a DatetimeParseResultSpan carrying two
// parse results: the "from" end of the range followed by the "to" end.
void FillDateRangeInstance(const DateRangeMatch& range,
                           DatetimeParseResultSpan* instance) {
  instance->span.first = range.begin;
  instance->span.second = range.end;
  instance->priority_score = range.GetAnnotatorPriorityScore();

  // Filling from DatetimeParseResult.
  instance->data.emplace_back();
  range.from.FillDatetimeComponents(&instance->data.back().datetime_components);

  // Filling to DatetimeParseResult.
  instance->data.emplace_back();
  range.to.FillDatetimeComponents(&instance->data.back().datetime_components);
}
+
+namespace {
+bool AnyOverlappedField(const DateMatch& prev, const DateMatch& next) {
+#define Field(f) \
+ if (prev.f && next.f) return true
+ Field(year_match);
+ Field(month_match);
+ Field(day_match);
+ Field(day_of_week_match);
+ Field(time_value_match);
+ Field(time_span_match);
+ Field(time_zone_name_match);
+ Field(time_zone_offset_match);
+ Field(relative_match);
+ Field(combined_digits_match);
+#undef Field
+ return false;
+}
+
// Copies every field that is set in `prev` but unset in `next` into
// `next`, and takes the maximum of the priorities/scores. Callers ensure
// the two matches have no overlapping fields (AnyOverlappedField).
void MergeDateMatchImpl(const DateMatch& prev, DateMatch* next,
                        bool update_span) {
// Copy underlying grammar-match pointers that are null in `next`.
#define RM(f) \
  if (!next->f) next->f = prev.f
  RM(year_match);
  RM(month_match);
  RM(day_match);
  RM(day_of_week_match);
  RM(time_value_match);
  RM(time_span_match);
  RM(time_zone_name_match);
  RM(time_zone_offset_match);
  RM(relative_match);
  RM(combined_digits_match);
#undef RM

// Copy raw field values still at the unset sentinel (NO_VAL).
#define RV(f) \
  if (next->f == NO_VAL) next->f = prev.f
  RV(year);
  RV(month);
  RV(day);
  RV(hour);
  RV(minute);
  RV(second);
  RV(fraction_second);
#undef RV

// Copy enum-valued fields still at their "none" value.
#define RE(f, v) \
  if (next->f == v) next->f = prev.f
  RE(day_of_week, DayOfWeek_DOW_NONE);
  RE(bc_ad, BCAD_BCAD_NONE);
  RE(time_span_code, TimespanCode_TIMESPAN_CODE_NONE);
  RE(time_zone_code, TimezoneCode_TIMEZONE_CODE_NONE);
#undef RE

  // Time-zone offset uses int16 min as its unset sentinel.
  if (next->time_zone_offset == std::numeric_limits<int16>::min()) {
    next->time_zone_offset = prev.time_zone_offset;
  }

  // The merged match is as strong as the stronger of the two.
  next->priority = std::max(next->priority, prev.priority);
  next->annotator_priority_score =
      std::max(next->annotator_priority_score, prev.annotator_priority_score);
  if (update_span) {
    // Extend the span to cover both matches.
    next->begin = std::min(next->begin, prev.begin);
    next->end = std::max(next->end, prev.end);
  }
}
+} // namespace
+
// Returns true if `prev` and `next` can be merged into one date: they
// must not populate the same field, and together must form a date+time
// (or relative-date plus day/time) combination.
bool IsDateMatchMergeable(const DateMatch& prev, const DateMatch& next) {
  // Do not merge if they share the same field.
  if (AnyOverlappedField(prev, next)) {
    return false;
  }

  // It's impossible that both prev and next have relative date since it's
  // excluded by overlapping check before.
  if (prev.HasRelativeDate() || next.HasRelativeDate()) {
    // If one of them is relative date, then we merge:
    // - if relative match shouldn't have time, and always has DOW or day.
    // - if not both relative match and non relative match has day.
    // - if non relative match has time or day.
    const DateMatch* rm = &prev;
    const DateMatch* non_rm = &prev;
    if (prev.HasRelativeDate()) {
      non_rm = &next;
    } else {
      rm = &next;
    }

    const RelativeMatch* relative_match = rm->relative_match;
    // Relative Match should have day or DOW but no time.
    if (!relative_match->HasDayFields() ||
        relative_match->HasTimeValueFields()) {
      return false;
    }
    // Check if both relative match and non relative match has day.
    if (non_rm->HasDateFields() && relative_match->HasDay()) {
      return false;
    }
    // Non relative match should have either hour (time) or day (date).
    if (!non_rm->HasHour() && !non_rm->HasDay()) {
      return false;
    }
  } else {
    // Only one match has date and another has time.
    if ((prev.HasDateFields() && next.HasDateFields()) ||
        (prev.HasTimeFields() && next.HasTimeFields())) {
      return false;
    }
    // DOW never be extracted as a single DateMatch except in RelativeMatch. So
    // here, we always merge one with day and another one with hour.
    if (!(prev.HasDay() || next.HasDay()) ||
        !(prev.HasHour() || next.HasHour())) {
      return false;
    }
  }
  return true;
}
+
+void MergeDateMatch(const DateMatch& prev, DateMatch* next, bool update_span) {
+ if (IsDateMatchMergeable(prev, *next)) {
+ MergeDateMatchImpl(prev, next, update_span);
+ }
+}
+
+} // namespace dates
+} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/dates/utils/date-utils.h b/native/annotator/grammar/dates/utils/date-utils.h
new file mode 100644
index 0000000..2fcda92
--- /dev/null
+++ b/native/annotator/grammar/dates/utils/date-utils.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_UTILS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_UTILS_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <ctime>
+#include <vector>
+
+#include "annotator/grammar/dates/annotations/annotation.h"
+#include "annotator/grammar/dates/utils/date-match.h"
+#include "utils/base/casts.h"
+
+namespace libtextclassifier3 {
+namespace dates {
+
+bool IsLeapYear(int year);
+
+int8 GetLastDayOfMonth(int year, int month);
+
+// Normalizes hour value of the specified date using the specified time-span
+// specification. Returns true if the original hour value (can be no-value)
+// is compatible with the time-span and gets normalized successfully, or
+// false otherwise.
+bool NormalizeHourByTimeSpan(const TimeSpanSpec* ts_spec, DateMatch* date);
+
+// Returns true iff "a" is considered as a refinement of "b". For example,
+// besides fully compatible fields, having more fields or higher priority.
+bool IsRefinement(const DateMatch& a, const DateMatch& b);
+bool IsRefinement(const DateRangeMatch& a, const DateRangeMatch& b);
+
+// Returns true iff "a" occurs strictly before "b"
+bool IsPrecedent(const DateMatch& a, const DateMatch& b);
+
+// Fill DatetimeParseResult based on DateMatch object which is created from
+// matched rule. The matched string is extracted from tokenizer which provides
+// an interface to access the clean text based on the matched range.
+void FillDateInstance(const DateMatch& date, DatetimeParseResult* instance);
+
+// Fill DatetimeParseResultSpan based on DateMatch object which is created from
+// matched rule. The matched string is extracted from tokenizer which provides
+// an interface to access the clean text based on the matched range.
+void FillDateInstance(const DateMatch& date, DatetimeParseResultSpan* instance);
+
+// Fill DatetimeParseResultSpan based on DateRangeMatch object which is
+// created from matched rule.
+void FillDateRangeInstance(const DateRangeMatch& range,
+ DatetimeParseResultSpan* instance);
+
+// Merge the fields in DateMatch prev to next if there is no overlapped field.
+// If update_span is true, the span of next is also updated.
+// e.g.: prev is 11am, next is: May 1, then the merged next is May 1, 11am
+void MergeDateMatch(const DateMatch& prev, DateMatch* next, bool update_span);
+
+// If DateMatches have no overlapped field, then they could be merged as the
+// following rules:
+// -- If both don't have relative match and one DateMatch has day but another
+// DateMatch has hour.
+// -- If one have relative match then follow the rules in code.
+// It's impossible to get DateMatch which only has DOW and not in relative
+// match according to current rules.
+bool IsDateMatchMergeable(const DateMatch& prev, const DateMatch& next);
+} // namespace dates
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_UTILS_H_
diff --git a/native/annotator/grammar/grammar-annotator.cc b/native/annotator/grammar/grammar-annotator.cc
new file mode 100644
index 0000000..baa3fac
--- /dev/null
+++ b/native/annotator/grammar/grammar-annotator.cc
@@ -0,0 +1,479 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/grammar-annotator.h"
+
+#include "annotator/feature-processor.h"
+#include "annotator/grammar/utils.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/grammar/callback-delegate.h"
+#include "utils/grammar/match.h"
+#include "utils/grammar/matcher.h"
+#include "utils/grammar/rules-utils.h"
+#include "utils/grammar/types.h"
+#include "utils/normalization.h"
+#include "utils/optional.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Returns the unicode codepoint offsets in a utf8 encoded text.
+std::vector<UnicodeText::const_iterator> UnicodeCodepointOffsets(
+ const UnicodeText& text) {
+ std::vector<UnicodeText::const_iterator> offsets;
+ for (auto it = text.begin(); it != text.end(); it++) {
+ offsets.push_back(it);
+ }
+ offsets.push_back(text.end());
+ return offsets;
+}
+
+} // namespace
+
+class GrammarAnnotatorCallbackDelegate : public grammar::CallbackDelegate {
+ public:
+ explicit GrammarAnnotatorCallbackDelegate(
+ const UniLib* unilib, const GrammarModel* model,
+ const ReflectiveFlatbufferBuilder* entity_data_builder,
+ const ModeFlag mode)
+ : unilib_(*unilib),
+ model_(model),
+ entity_data_builder_(entity_data_builder),
+ mode_(mode) {}
+
+ // Handles a grammar rule match in the annotator grammar.
+ void MatchFound(const grammar::Match* match, grammar::CallbackId type,
+ int64 value, grammar::Matcher* matcher) override {
+ switch (static_cast<GrammarAnnotator::Callback>(type)) {
+ case GrammarAnnotator::Callback::kRuleMatch: {
+ HandleRuleMatch(match, /*rule_id=*/value);
+ return;
+ }
+ default:
+ grammar::CallbackDelegate::MatchFound(match, type, value, matcher);
+ }
+ }
+
+ // Deduplicate and populate annotations from grammar matches.
+ bool GetAnnotations(const std::vector<UnicodeText::const_iterator>& text,
+ std::vector<AnnotatedSpan>* annotations) const {
+ for (const grammar::Derivation& candidate :
+ grammar::DeduplicateDerivations(candidates_)) {
+ // Check that assertions are fulfilled.
+ if (!grammar::VerifyAssertions(candidate.match)) {
+ continue;
+ }
+ if (!AddAnnotatedSpanFromMatch(text, candidate, annotations)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ bool GetTextSelection(const std::vector<UnicodeText::const_iterator>& text,
+ const CodepointSpan& selection, AnnotatedSpan* result) {
+ std::vector<grammar::Derivation> selection_candidates;
+ // Deduplicate and verify matches.
+ auto maybe_interpretation = GetBestValidInterpretation(
+ grammar::DeduplicateDerivations(GetOverlappingRuleMatches(
+ selection, candidates_, /*only_exact_overlap=*/false)));
+ if (!maybe_interpretation.has_value()) {
+ return false;
+ }
+ const GrammarModel_::RuleClassificationResult* interpretation;
+ const grammar::Match* match;
+ std::tie(interpretation, match) = maybe_interpretation.value();
+ return InstantiateAnnotatedSpanFromInterpretation(text, interpretation,
+ match, result);
+ }
+
+ // Provides a classification results from the grammar matches.
+ bool GetClassification(const std::vector<UnicodeText::const_iterator>& text,
+ const CodepointSpan& selection,
+ ClassificationResult* classification) const {
+ // Deduplicate and verify matches.
+ auto maybe_interpretation = GetBestValidInterpretation(
+ grammar::DeduplicateDerivations(GetOverlappingRuleMatches(
+ selection, candidates_, /*only_exact_overlap=*/true)));
+ if (!maybe_interpretation.has_value()) {
+ return false;
+ }
+
+ // Instantiate result.
+ const GrammarModel_::RuleClassificationResult* interpretation;
+ const grammar::Match* match;
+ std::tie(interpretation, match) = maybe_interpretation.value();
+ return InstantiateClassificationInterpretation(text, interpretation, match,
+ classification);
+ }
+
+ private:
+ // Handles annotation/selection/classification rule matches.
+ void HandleRuleMatch(const grammar::Match* match, const int64 rule_id) {
+ if ((model_->rule_classification_result()->Get(rule_id)->enabled_modes() &
+ mode_) != 0) {
+ candidates_.push_back(grammar::Derivation{match, rule_id});
+ }
+ }
+
+ // Computes the selection boundaries from a grammar match.
+ CodepointSpan MatchSelectionBoundaries(
+ const grammar::Match* match,
+ const GrammarModel_::RuleClassificationResult* classification) const {
+ if (classification->capturing_group() == nullptr) {
+ // Use full match as selection span.
+ return match->codepoint_span;
+ }
+
+ // Set information from capturing matches.
+ CodepointSpan span{kInvalidIndex, kInvalidIndex};
+ // Gather active capturing matches.
+ std::unordered_map<uint16, const grammar::Match*> capturing_matches;
+ for (const grammar::MappingMatch* match :
+ grammar::SelectAllOfType<grammar::MappingMatch>(
+ match, grammar::Match::kMappingMatch)) {
+ capturing_matches[match->id] = match;
+ }
+
+ // Compute span boundaries.
+ for (int i = 0; i < classification->capturing_group()->size(); i++) {
+ auto it = capturing_matches.find(i);
+ if (it == capturing_matches.end()) {
+ // Capturing group is not active, skip.
+ continue;
+ }
+ const CapturingGroup* group = classification->capturing_group()->Get(i);
+ if (group->extend_selection()) {
+ if (span.first == kInvalidIndex) {
+ span = it->second->codepoint_span;
+ } else {
+ span.first = std::min(span.first, it->second->codepoint_span.first);
+ span.second =
+ std::max(span.second, it->second->codepoint_span.second);
+ }
+ }
+ }
+ return span;
+ }
+
+ // Filters out results that do not overlap with a reference span.
+ std::vector<grammar::Derivation> GetOverlappingRuleMatches(
+ const CodepointSpan& selection,
+ const std::vector<grammar::Derivation>& candidates,
+ const bool only_exact_overlap) const {
+ std::vector<grammar::Derivation> result;
+ for (const grammar::Derivation& candidate : candidates) {
+ // Discard matches that do not match the selection.
+ // Simple check.
+ if (!SpansOverlap(selection, candidate.match->codepoint_span)) {
+ continue;
+ }
+
+ // Compute exact selection boundaries (without assertions and
+ // non-capturing parts).
+ const CodepointSpan span = MatchSelectionBoundaries(
+ candidate.match,
+ model_->rule_classification_result()->Get(candidate.rule_id));
+ if (!SpansOverlap(selection, span) ||
+ (only_exact_overlap && span != selection)) {
+ continue;
+ }
+ result.push_back(candidate);
+ }
+ return result;
+ }
+
+ // Returns the best valid interpretation of a set of candidate matches.
+ Optional<std::pair<const GrammarModel_::RuleClassificationResult*,
+ const grammar::Match*>>
+ GetBestValidInterpretation(
+ const std::vector<grammar::Derivation>& candidates) const {
+ const GrammarModel_::RuleClassificationResult* best_interpretation =
+ nullptr;
+ const grammar::Match* best_match = nullptr;
+ for (const grammar::Derivation& candidate : candidates) {
+ if (!grammar::VerifyAssertions(candidate.match)) {
+ continue;
+ }
+ const GrammarModel_::RuleClassificationResult*
+ rule_classification_result =
+ model_->rule_classification_result()->Get(candidate.rule_id);
+ if (best_interpretation == nullptr ||
+ best_interpretation->priority_score() <
+ rule_classification_result->priority_score()) {
+ best_interpretation = rule_classification_result;
+ best_match = candidate.match;
+ }
+ }
+
+ // No valid interpretation found.
+ Optional<std::pair<const GrammarModel_::RuleClassificationResult*,
+ const grammar::Match*>>
+ result;
+ if (best_interpretation != nullptr) {
+ result = {best_interpretation, best_match};
+ }
+ return result;
+ }
+
+ // Instantiates an annotated span from a rule match and appends it to the
+ // result.
+ bool AddAnnotatedSpanFromMatch(
+ const std::vector<UnicodeText::const_iterator>& text,
+ const grammar::Derivation& candidate,
+ std::vector<AnnotatedSpan>* result) const {
+ if (candidate.rule_id < 0 ||
+ candidate.rule_id >= model_->rule_classification_result()->size()) {
+ TC3_LOG(INFO) << "Invalid rule id.";
+ return false;
+ }
+ const GrammarModel_::RuleClassificationResult* interpretation =
+ model_->rule_classification_result()->Get(candidate.rule_id);
+ result->emplace_back();
+ return InstantiateAnnotatedSpanFromInterpretation(
+ text, interpretation, candidate.match, &result->back());
+ }
+
+ bool InstantiateAnnotatedSpanFromInterpretation(
+ const std::vector<UnicodeText::const_iterator>& text,
+ const GrammarModel_::RuleClassificationResult* interpretation,
+ const grammar::Match* match, AnnotatedSpan* result) const {
+ result->span = MatchSelectionBoundaries(match, interpretation);
+ ClassificationResult classification;
+ if (!InstantiateClassificationInterpretation(text, interpretation, match,
+ &classification)) {
+ return false;
+ }
+ result->classification.push_back(classification);
+ return true;
+ }
+
+ // Instantiates a classification result from a rule match.
+ bool InstantiateClassificationInterpretation(
+ const std::vector<UnicodeText::const_iterator>& text,
+ const GrammarModel_::RuleClassificationResult* interpretation,
+ const grammar::Match* match, ClassificationResult* classification) const {
+ classification->collection = interpretation->collection_name()->str();
+ classification->score = interpretation->target_classification_score();
+ classification->priority_score = interpretation->priority_score();
+
+ // Assemble entity data.
+ if (entity_data_builder_ == nullptr) {
+ return true;
+ }
+ std::unique_ptr<ReflectiveFlatbuffer> entity_data =
+ entity_data_builder_->NewRoot();
+ if (interpretation->serialized_entity_data() != nullptr) {
+ entity_data->MergeFromSerializedFlatbuffer(
+ StringPiece(interpretation->serialized_entity_data()->data(),
+ interpretation->serialized_entity_data()->size()));
+ }
+ if (interpretation->entity_data() != nullptr) {
+ entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
+ interpretation->entity_data()));
+ }
+
+ // Populate entity data from the capturing matches.
+ if (interpretation->capturing_group() != nullptr) {
+ // Gather active capturing matches.
+ std::unordered_map<uint16, const grammar::Match*> capturing_matches;
+ for (const grammar::MappingMatch* match :
+ grammar::SelectAllOfType<grammar::MappingMatch>(
+ match, grammar::Match::kMappingMatch)) {
+ capturing_matches[match->id] = match;
+ }
+ for (int i = 0; i < interpretation->capturing_group()->size(); i++) {
+ auto it = capturing_matches.find(i);
+ if (it == capturing_matches.end()) {
+ // Capturing group is not active, skip.
+ continue;
+ }
+ const CapturingGroup* group = interpretation->capturing_group()->Get(i);
+
+ // Add static entity data.
+ if (group->serialized_entity_data() != nullptr) {
+ entity_data->MergeFromSerializedFlatbuffer(
+ StringPiece(interpretation->serialized_entity_data()->data(),
+ interpretation->serialized_entity_data()->size()));
+ }
+
+ // Set entity field from captured text.
+ if (group->entity_field_path() != nullptr) {
+ const grammar::Match* capturing_match = it->second;
+ StringPiece group_text = StringPiece(
+ text[capturing_match->codepoint_span.first].utf8_data(),
+ text[capturing_match->codepoint_span.second].utf8_data() -
+ text[capturing_match->codepoint_span.first].utf8_data());
+ UnicodeText normalized_group_text =
+ UTF8ToUnicodeText(group_text, /*do_copy=*/false);
+ if (group->normalization_options() != nullptr) {
+ normalized_group_text = NormalizeText(
+ unilib_, group->normalization_options(), normalized_group_text);
+ }
+ if (!entity_data->ParseAndSet(group->entity_field_path(),
+ normalized_group_text.ToUTF8String())) {
+ TC3_LOG(ERROR) << "Could not set entity data from capturing match.";
+ return false;
+ }
+ }
+ }
+ }
+
+ if (entity_data && entity_data->HasExplicitlySetFields()) {
+ classification->serialized_entity_data = entity_data->Serialize();
+ }
+ return true;
+ }
+
+ const UniLib& unilib_;
+ const GrammarModel* model_;
+ const ReflectiveFlatbufferBuilder* entity_data_builder_;
+ const ModeFlag mode_;
+
+ // All annotation/selection/classification rule match candidates.
+ // Grammar rule matches are recorded, deduplicated and then instantiated.
+ std::vector<grammar::Derivation> candidates_;
+};
+
// Constructs the annotator from a grammar model. Locale shards are parsed
// once here so per-call annotation only needs to filter them.
// NOTE(review): `unilib`, `model` and `entity_data_builder` are stored by
// reference/pointer — presumably they must outlive this instance; confirm
// against the owning Annotator.
GrammarAnnotator::GrammarAnnotator(
    const UniLib* unilib, const GrammarModel* model,
    const ReflectiveFlatbufferBuilder* entity_data_builder)
    : unilib_(*unilib),
      model_(model),
      lexer_(unilib, model->rules()),
      tokenizer_(BuildTokenizer(unilib, model->tokenizer_options())),
      entity_data_builder_(entity_data_builder),
      rules_locales_(grammar::ParseRulesLocales(model->rules())) {}
+
// Annotates `text` with all annotation-enabled grammar rules that match
// one of `locales`. Annotations are appended to `result`. Returns true on
// success, including when there are no applicable rules.
bool GrammarAnnotator::Annotate(const std::vector<Locale>& locales,
                                const UnicodeText& text,
                                std::vector<AnnotatedSpan>* result) const {
  if (model_ == nullptr || model_->rules() == nullptr) {
    // Nothing to do.
    return true;
  }

  // Select locale matching rules.
  std::vector<const grammar::RulesSet_::Rules*> locale_rules =
      SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales);
  if (locale_rules.empty()) {
    // Nothing to do.
    return true;
  }

  // Run the grammar over the tokenized text.
  GrammarAnnotatorCallbackDelegate callback_handler(
      &unilib_, model_, entity_data_builder_,
      /*mode=*/ModeFlag_ANNOTATION);
  grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules,
                           &callback_handler);
  lexer_.Process(text, tokenizer_.Tokenize(text), /*annotations=*/nullptr,
                 &matcher);

  // Populate results from the recorded rule matches.
  return callback_handler.GetAnnotations(UnicodeCodepointOffsets(text), result);
}
+
+bool GrammarAnnotator::SuggestSelection(const std::vector<Locale>& locales,
+ const UnicodeText& text,
+ const CodepointSpan& selection,
+ AnnotatedSpan* result) const {
+ if (model_ == nullptr || model_->rules() == nullptr ||
+ selection == CodepointSpan{kInvalidIndex, kInvalidIndex}) {
+ // Nothing to do.
+ return false;
+ }
+
+ // Select locale matching rules.
+ std::vector<const grammar::RulesSet_::Rules*> locale_rules =
+ SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales);
+ if (locale_rules.empty()) {
+ // Nothing to do.
+ return true;
+ }
+
+ // Run the grammar.
+ GrammarAnnotatorCallbackDelegate callback_handler(
+ &unilib_, model_, entity_data_builder_,
+ /*mode=*/ModeFlag_SELECTION);
+ grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules,
+ &callback_handler);
+ lexer_.Process(text, tokenizer_.Tokenize(text), /*annotations=*/nullptr,
+ &matcher);
+
+ // Populate the result.
+ return callback_handler.GetTextSelection(UnicodeCodepointOffsets(text),
+ selection, result);
+}
+
// Classifies the span `selection` of `text` using the
// classification-enabled grammar rules. Returns true iff a rule produced
// a classification whose boundaries match `selection` exactly.
bool GrammarAnnotator::ClassifyText(
    const std::vector<Locale>& locales, const UnicodeText& text,
    const CodepointSpan& selection,
    ClassificationResult* classification_result) const {
  if (model_ == nullptr || model_->rules() == nullptr ||
      selection == CodepointSpan{kInvalidIndex, kInvalidIndex}) {
    // Nothing to do.
    return false;
  }

  // Select locale matching rules.
  std::vector<const grammar::RulesSet_::Rules*> locale_rules =
      SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales);
  if (locale_rules.empty()) {
    // Nothing to do.
    return false;
  }

  // Run the grammar.
  GrammarAnnotatorCallbackDelegate callback_handler(
      &unilib_, model_, entity_data_builder_,
      /*mode=*/ModeFlag_CLASSIFICATION);
  grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules,
                           &callback_handler);

  const std::vector<Token> tokens = tokenizer_.Tokenize(text);
  if (model_->context_left_num_tokens() == -1 &&
      model_->context_right_num_tokens() == -1) {
    // Use all tokens.
    // NOTE(review): other call-sites pass /*annotations=*/nullptr; confirm
    // `{}` selects the same overload/behavior here.
    lexer_.Process(text, tokens, /*annotations=*/{}, &matcher);
  } else {
    // Restrict processing to a window of context_left/right tokens around
    // the selection; -1 means unbounded on that side.
    TokenSpan context_span = CodepointSpanToTokenSpan(
        tokens, selection, /*snap_boundaries_to_containing_tokens=*/true);
    std::vector<Token>::const_iterator begin = tokens.begin();
    std::vector<Token>::const_iterator end = tokens.begin();
    if (model_->context_left_num_tokens() != -1) {
      // Clamp the left edge of the window to the start of the text.
      std::advance(begin, std::max(0, context_span.first -
                                          model_->context_left_num_tokens()));
    }
    if (model_->context_right_num_tokens() == -1) {
      end = tokens.end();
    } else {
      // Clamp the right edge of the window to the end of the text.
      std::advance(end, std::min(static_cast<int>(tokens.size()),
                                 context_span.second +
                                     model_->context_right_num_tokens()));
    }
    lexer_.Process(text, begin, end,
                   /*annotations=*/nullptr, &matcher);
  }

  // Populate result from the exactly-overlapping rule matches.
  return callback_handler.GetClassification(UnicodeCodepointOffsets(text),
                                            selection, classification_result);
}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/grammar-annotator.h b/native/annotator/grammar/grammar-annotator.h
new file mode 100644
index 0000000..365bb44
--- /dev/null
+++ b/native/annotator/grammar/grammar-annotator.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_GRAMMAR_ANNOTATOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_GRAMMAR_ANNOTATOR_H_
+
+#include <vector>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/flatbuffers.h"
+#include "utils/grammar/lexer.h"
+#include "utils/i18n/locale.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// Grammar backed annotator.
+class GrammarAnnotator {
+ public:
+ enum class Callback : grammar::CallbackId {
+ kRuleMatch = 1,
+ };
+
+ explicit GrammarAnnotator(
+ const UniLib* unilib, const GrammarModel* model,
+ const ReflectiveFlatbufferBuilder* entity_data_builder);
+
+ // Annotates a given text.
+ // Returns true if the text was successfully annotated.
+ bool Annotate(const std::vector<Locale>& locales, const UnicodeText& text,
+ std::vector<AnnotatedSpan>* result) const;
+
+ // Classifies a span in a text.
+ // Returns true if the span was classified by a grammar rule.
+ bool ClassifyText(const std::vector<Locale>& locales, const UnicodeText& text,
+ const CodepointSpan& selection,
+ ClassificationResult* classification_result) const;
+
+ // Suggests text selections in a text.
+ // Returns true if a span was suggested by a grammar rule.
+ bool SuggestSelection(const std::vector<Locale>& locales,
+ const UnicodeText& text, const CodepointSpan& selection,
+ AnnotatedSpan* result) const;
+
+ private:
+ const UniLib& unilib_;
+ const GrammarModel* model_;
+ const grammar::Lexer lexer_;
+ const Tokenizer tokenizer_;
+ const ReflectiveFlatbufferBuilder* entity_data_builder_;
+
+ // Pre-parsed locales of the rules.
+ const std::vector<std::vector<Locale>> rules_locales_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_GRAMMAR_ANNOTATOR_H_
diff --git a/native/annotator/grammar/utils.cc b/native/annotator/grammar/utils.cc
new file mode 100644
index 0000000..8b9363d
--- /dev/null
+++ b/native/annotator/grammar/utils.cc
@@ -0,0 +1,66 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/grammar/utils.h"
+
+#include "utils/grammar/utils/rules.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using ::libtextclassifier3::GrammarModel_::RuleClassificationResultT;
+
+} // namespace
+
+Tokenizer BuildTokenizer(const UniLib* unilib,
+ const GrammarTokenizerOptions* options) {
+ TC3_CHECK(options != nullptr);
+
+ std::vector<const TokenizationCodepointRange*> codepoint_config;
+ if (options->tokenization_codepoint_config() != nullptr) {
+ codepoint_config.insert(codepoint_config.end(),
+ options->tokenization_codepoint_config()->begin(),
+ options->tokenization_codepoint_config()->end());
+ }
+ std::vector<const CodepointRange*> internal_codepoint_config;
+ if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
+ internal_codepoint_config.insert(
+ internal_codepoint_config.end(),
+ options->internal_tokenizer_codepoint_ranges()->begin(),
+ options->internal_tokenizer_codepoint_ranges()->end());
+ }
+
+ const bool tokenize_on_script_change =
+ options->tokenization_codepoint_config() != nullptr &&
+ options->tokenize_on_script_change();
+ return Tokenizer(options->tokenization_type(), unilib, codepoint_config,
+ internal_codepoint_config, tokenize_on_script_change,
+ /*icu_preserve_whitespace_tokens=*/false);
+}
+
+int AddRuleClassificationResult(const std::string& collection,
+ const ModeFlag& enabled_modes,
+ GrammarModelT* model) {
+ const int result_id = model->rule_classification_result.size();
+ model->rule_classification_result.emplace_back(new RuleClassificationResultT);
+ RuleClassificationResultT* result =
+ model->rule_classification_result.back().get();
+ result->collection_name = collection;
+ result->enabled_modes = enabled_modes;
+ return result_id;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/grammar/utils.h b/native/annotator/grammar/utils.h
new file mode 100644
index 0000000..4d870fd
--- /dev/null
+++ b/native/annotator/grammar/utils.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Common utility functions for grammar annotators.
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_UTILS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_UTILS_H_
+
+#include "annotator/model_generated.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// Builds a tokenizer instance from options.
+Tokenizer BuildTokenizer(const UniLib* unilib,
+ const GrammarTokenizerOptions* options);
+
+// Adds a rule classification result to the |model|.
+// collection: the classification entity detected.
+// enabled_modes: the target to apply the given rule.
+// Returns the ID associated with the created classification rule.
+int AddRuleClassificationResult(const std::string& collection,
+ const ModeFlag& enabled_modes,
+ GrammarModelT* model);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_UTILS_H_
diff --git a/annotator/installed_app/installed-app-engine-dummy.h b/native/annotator/installed_app/installed-app-engine-dummy.h
similarity index 100%
rename from annotator/installed_app/installed-app-engine-dummy.h
rename to native/annotator/installed_app/installed-app-engine-dummy.h
diff --git a/annotator/installed_app/installed-app-engine.h b/native/annotator/installed_app/installed-app-engine.h
similarity index 100%
rename from annotator/installed_app/installed-app-engine.h
rename to native/annotator/installed_app/installed-app-engine.h
diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h
new file mode 100644
index 0000000..e9f688a
--- /dev/null
+++ b/native/annotator/knowledge/knowledge-engine-dummy.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_
+
+#include <string>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/base/status.h"
+#include "utils/optional.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// A dummy implementation of the knowledge engine.
+class KnowledgeEngine {
+ public:
+ bool Initialize(const std::string& serialized_config, const UniLib* unilib) {
+ return true;
+ }
+
+ void SetPriorityScore(float priority_score) {}
+
+ bool ClassifyText(const std::string& text, CodepointSpan selection_indices,
+ AnnotationUsecase annotation_usecase,
+ const Optional<LocationContext>& location_context,
+ const Permissions& permissions,
+ ClassificationResult* classification_result) const {
+ return false;
+ }
+
+ bool Chunk(const std::string& text, AnnotationUsecase annotation_usecase,
+ const Optional<LocationContext>& location_context,
+ const Permissions& permissions,
+ std::vector<AnnotatedSpan>* result) const {
+ return true;
+ }
+
+ Status ChunkMultipleSpans(
+ const std::vector<std::string>& text_fragments,
+ AnnotationUsecase annotation_usecase,
+ const Optional<LocationContext>& location_context,
+ const Permissions& permissions,
+ std::vector<std::vector<AnnotatedSpan>>* results) const {
+ return Status::OK;
+ }
+
+ bool LookUpEntity(const std::string& id,
+ std::string* serialized_knowledge_result) const {
+ return false;
+ }
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_DUMMY_H_
diff --git a/annotator/knowledge/knowledge-engine.h b/native/annotator/knowledge/knowledge-engine.h
similarity index 100%
rename from annotator/knowledge/knowledge-engine.h
rename to native/annotator/knowledge/knowledge-engine.h
diff --git a/annotator/model-executor.cc b/native/annotator/model-executor.cc
similarity index 100%
rename from annotator/model-executor.cc
rename to native/annotator/model-executor.cc
diff --git a/native/annotator/model-executor.h b/native/annotator/model-executor.h
new file mode 100644
index 0000000..5d6c4a7
--- /dev/null
+++ b/native/annotator/model-executor.h
@@ -0,0 +1,127 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Contains classes that can execute different models/parts of a model.
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
+
+#include <memory>
+
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/tensor-view.h"
+#include "utils/tflite-model-executor.h"
+
+namespace libtextclassifier3 {
+
+// Executor for the text selection prediction and classification models.
+class ModelExecutor : public TfLiteModelExecutor {
+ public:
+ static std::unique_ptr<ModelExecutor> FromModelSpec(
+ const tflite::Model* model_spec) {
+ auto model = TfLiteModelFromModelSpec(model_spec);
+ if (!model) {
+ return nullptr;
+ }
+ return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
+ }
+
+ static std::unique_ptr<ModelExecutor> FromBuffer(
+ const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
+ auto model = TfLiteModelFromBuffer(model_spec_buffer);
+ if (!model) {
+ return nullptr;
+ }
+ return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
+ }
+
+ TensorView<float> ComputeLogits(const TensorView<float>& features,
+ tflite::Interpreter* interpreter) const;
+
+ protected:
+ explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)
+ : TfLiteModelExecutor(std::move(model)) {}
+
+ static constexpr int kInputIndexFeatures = 0;
+ static constexpr int kOutputIndexLogits = 0;
+};
+
+// Executor for embedding sparse features into a dense vector.
+class EmbeddingExecutor {
+ public:
+ virtual ~EmbeddingExecutor() {}
+
+ // Embeds the sparse_features into a dense embedding and adds (+) it
+ // element-wise to the dest vector.
+ virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+ int dest_size) const = 0;
+
+ // Returns true when the model is ready to be used, false otherwise.
+ virtual bool IsReady() const { return true; }
+};
+
+class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
+ public:
+ static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer(
+ const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
+ int quantization_bits,
+ const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
+
+ // Embeds the sparse_features into a dense embedding and adds (+) it
+ // element-wise to the dest vector.
+ bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
+ int dest_size) const;
+
+ // Auxiliary function for computing prefixes used in implementation of
+ // efficient mask indexing data structure.
+ void ComputePrefixCounts();
+
+ // Function implementing mask indexing based on efficient data structure
+ int PruneBucketId(int bucket_id) const;
+
+ protected:
+ explicit TFLiteEmbeddingExecutor(
+ std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
+ int num_buckets, int bytes_per_embedding, int output_embedding_size,
+ const TfLiteTensor* scales, const TfLiteTensor* embeddings,
+ std::unique_ptr<tflite::Interpreter> interpreter,
+ const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
+
+ std::unique_ptr<TfLiteModelExecutor> executor_;
+
+ int quantization_bits_;
+ int num_buckets_ = -1;
+ int bytes_per_embedding_ = -1;
+ int output_embedding_size_ = -1;
+ const TfLiteTensor* scales_ = nullptr;
+ const TfLiteTensor* embeddings_ = nullptr;
+
+ // NOTE: This interpreter is used in a read-only way (as a storage for the
+ // model params), thus is still thread-safe.
+ std::unique_ptr<tflite::Interpreter> interpreter_;
+
+ std::vector<uint64> pruning_mask_;
+ std::vector<uint16> prefix_counts_;
+ int full_num_buckets_ = -1;
+
+ // Index of row of embedding table corresponding to all pruned buckets.
+ int pruned_row_bucket_id_ = -1;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
new file mode 100755
index 0000000..bdb7a17
--- /dev/null
+++ b/native/annotator/model.fbs
@@ -0,0 +1,988 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+include "annotator/entity-data.fbs";
+include "annotator/experimental/experimental.fbs";
+include "annotator/grammar/dates/dates.fbs";
+include "utils/codepoint-range.fbs";
+include "utils/flatbuffers.fbs";
+include "utils/grammar/rules.fbs";
+include "utils/intents/intent-config.fbs";
+include "utils/normalization.fbs";
+include "utils/resources.fbs";
+include "utils/tokenizer.fbs";
+include "utils/zlib/buffer.fbs";
+
+file_identifier "TC2 ";
+
+// The possible model modes, represents a bit field.
+namespace libtextclassifier3;
+enum ModeFlag : int {
+ NONE = 0,
+ ANNOTATION = 1,
+ CLASSIFICATION = 2,
+ ANNOTATION_AND_CLASSIFICATION = 3,
+ SELECTION = 4,
+ ANNOTATION_AND_SELECTION = 5,
+ CLASSIFICATION_AND_SELECTION = 6,
+ ALL = 7,
+}
+
+// Enum for specifying the annotation usecase.
+namespace libtextclassifier3;
+enum AnnotationUsecase : int {
+ // Results are optimized for Smart{Select,Share,Linkify}.
+ ANNOTATION_USECASE_SMART = 0,
+ // Smart{Select,Share,Linkify}
+
+ // Results are optimized for using TextClassifier as an infrastructure that
+ // annotates as much as possible.
+ ANNOTATION_USECASE_RAW = 1,
+}
+
+namespace libtextclassifier3;
+enum DatetimeExtractorType : int {
+ UNKNOWN_DATETIME_EXTRACTOR_TYPE = 0,
+ AM = 1,
+ PM = 2,
+ JANUARY = 3,
+ FEBRUARY = 4,
+ MARCH = 5,
+ APRIL = 6,
+ MAY = 7,
+ JUNE = 8,
+ JULY = 9,
+ AUGUST = 10,
+ SEPTEMBER = 11,
+ OCTOBER = 12,
+ NOVEMBER = 13,
+ DECEMBER = 14,
+ NEXT = 15,
+ NEXT_OR_SAME = 16,
+ LAST = 17,
+ NOW = 18,
+ TOMORROW = 19,
+ YESTERDAY = 20,
+ PAST = 21,
+ FUTURE = 22,
+ DAY = 23,
+ WEEK = 24,
+ MONTH = 25,
+ YEAR = 26,
+ MONDAY = 27,
+ TUESDAY = 28,
+ WEDNESDAY = 29,
+ THURSDAY = 30,
+ FRIDAY = 31,
+ SATURDAY = 32,
+ SUNDAY = 33,
+ DAYS = 34,
+ WEEKS = 35,
+ MONTHS = 36,
+
+ // TODO(zilka): Make the following 3 values singular for consistency.
+ HOURS = 37,
+
+ MINUTES = 38,
+ SECONDS = 39,
+ YEARS = 40,
+ DIGITS = 41,
+ SIGNEDDIGITS = 42,
+ ZERO = 43,
+ ONE = 44,
+ TWO = 45,
+ THREE = 46,
+ FOUR = 47,
+ FIVE = 48,
+ SIX = 49,
+ SEVEN = 50,
+ EIGHT = 51,
+ NINE = 52,
+ TEN = 53,
+ ELEVEN = 54,
+ TWELVE = 55,
+ THIRTEEN = 56,
+ FOURTEEN = 57,
+ FIFTEEN = 58,
+ SIXTEEN = 59,
+ SEVENTEEN = 60,
+ EIGHTEEN = 61,
+ NINETEEN = 62,
+ TWENTY = 63,
+ THIRTY = 64,
+ FORTY = 65,
+ FIFTY = 66,
+ SIXTY = 67,
+ SEVENTY = 68,
+ EIGHTY = 69,
+ NINETY = 70,
+ HUNDRED = 71,
+ THOUSAND = 72,
+}
+
+namespace libtextclassifier3;
+enum DatetimeGroupType : int {
+ GROUP_UNKNOWN = 0,
+ GROUP_UNUSED = 1,
+ GROUP_YEAR = 2,
+ GROUP_MONTH = 3,
+ GROUP_DAY = 4,
+ GROUP_HOUR = 5,
+ GROUP_MINUTE = 6,
+ GROUP_SECOND = 7,
+ GROUP_AMPM = 8,
+ GROUP_RELATIONDISTANCE = 9,
+ GROUP_RELATION = 10,
+ GROUP_RELATIONTYPE = 11,
+
+ // Dummy groups serve just as an inflator of the selection. E.g. we might want
+ // to select more text than was contained in an envelope of all extractor
+ // spans.
+ GROUP_DUMMY1 = 12,
+
+ GROUP_DUMMY2 = 13,
+}
+
+// Options for the model that predicts text selection.
+namespace libtextclassifier3;
+table SelectionModelOptions {
+ // If true, before the selection is returned, the unpaired brackets contained
+ // in the predicted selection are stripped from the both selection ends.
+ // The bracket codepoints are defined in the Unicode standard:
+ // http://www.unicode.org/Public/UNIDATA/BidiBrackets.txt
+ strip_unpaired_brackets:bool = true;
+
+ // Number of hypothetical click positions on either side of the actual click
+ // to consider in order to enforce symmetry.
+ symmetry_context_size:int;
+
+ // Number of examples to bundle in one batch for inference.
+ batch_size:int = 1024;
+
+ // Whether to always classify a suggested selection or only on demand.
+ always_classify_suggested_selection:bool = false;
+}
+
+// Options for the model that classifies a text selection.
+namespace libtextclassifier3;
+table ClassificationModelOptions {
+ // Limits for phone numbers.
+ phone_min_num_digits:int = 7;
+
+ phone_max_num_digits:int = 15;
+
+ // Limits for addresses.
+ address_min_num_tokens:int;
+
+ // Maximum number of tokens to attempt a classification (-1 is unlimited).
+ max_num_tokens:int = -1;
+}
+
+// Options for post-checks, checksums and verification to apply on a match.
+namespace libtextclassifier3;
+table VerificationOptions {
+ verify_luhn_checksum:bool = false;
+
+ // Lua verifier to use.
+ // Index of the lua verifier in the model.
+ lua_verifier:int = -1;
+}
+
+// Behaviour of rule capturing groups.
+// This specifies how the text and span of a capturing group, in a regular
+// expression or from a capturing match in a grammar rule, should be handled.
+namespace libtextclassifier3;
+table CapturingGroup {
+ // If true, the span of the capturing group will be used to
+ // extend the selection.
+ extend_selection:bool = true;
+
+ // If set, the text of the capturing group will be used to set a field in
+ // the classfication result entity data.
+ entity_field_path:FlatbufferFieldPath;
+
+ // If set, the flatbuffer entity data will be merged with the
+ // classification result entity data.
+ serialized_entity_data:string (shared);
+
+ // If set, normalization to apply before text is used in entity data.
+ normalization_options:NormalizationOptions;
+
+ entity_data:EntityData;
+}
+
+// List of regular expression matchers to check.
+namespace libtextclassifier3.RegexModel_;
+table Pattern {
+ // The name of the collection of a match.
+ collection_name:string (shared);
+
+ // The pattern to check.
+ pattern:string (shared);
+
+ // The modes for which to apply the patterns.
+ enabled_modes:ModeFlag = ALL;
+
+ // The final score to assign to the results of this pattern.
+ target_classification_score:float = 1;
+
+ // Priority score used for conflict resolution with the other models.
+ priority_score:float = 0;
+
+ // If true, will use an approximate matching implementation implemented
+ // using Find() instead of the true Match(). This approximate matching will
+ // use the first Find() result and then check that it spans the whole input.
+ use_approximate_matching:bool = false;
+
+ compressed_pattern:CompressedBuffer;
+
+ // Verification to apply on a match.
+ verification_options:VerificationOptions;
+
+ capturing_group:[CapturingGroup];
+
+ // Entity data to set for a match.
+ serialized_entity_data:string (shared);
+
+ entity_data:EntityData;
+}
+
+namespace libtextclassifier3;
+table RegexModel {
+ patterns:[RegexModel_.Pattern];
+
+ // If true, will compile the regexes only on first use.
+ lazy_regex_compilation:bool = true;
+
+ // Lua scripts for match verification.
+ // The verifier can access:
+ // * `context`: The context as a string.
+ // * `match`: The groups of the regex match as an array, each group gives
+ // * `begin`: span start
+ // * `end`: span end
+ // * `text`: the text
+ // The verifier is expected to return a boolean, indicating whether the
+ // verification succeeded or not.
+ lua_verifier:[string];
+}
+
+// List of regex patterns.
+namespace libtextclassifier3.DatetimeModelPattern_;
+table Regex {
+ pattern:string (shared);
+
+ // The ith entry specifies the type of the ith capturing group.
+ // This is used to decide how the matched content has to be parsed.
+ groups:[DatetimeGroupType];
+
+ compressed_pattern:CompressedBuffer;
+}
+
+namespace libtextclassifier3;
+table DatetimeModelPattern {
+ regexes:[DatetimeModelPattern_.Regex];
+
+ // List of locale indices in DatetimeModel that represent the locales that
+ // these patterns should be used for. If empty, can be used for all locales.
+ locales:[int];
+
+ // The final score to assign to the results of this pattern.
+ target_classification_score:float = 1;
+
+ // Priority score used for conflict resolution with the other models.
+ priority_score:float = 0;
+
+ // The modes for which to apply the patterns.
+ enabled_modes:ModeFlag = ALL;
+
+ // The annotation usecases for which to apply the patterns.
+ // This is a flag field for values of AnnotationUsecase.
+ enabled_annotation_usecases:uint = 4294967295;
+}
+
+namespace libtextclassifier3;
+table DatetimeModelExtractor {
+ extractor:DatetimeExtractorType;
+ pattern:string (shared);
+ locales:[int];
+ compressed_pattern:CompressedBuffer;
+}
+
+namespace libtextclassifier3;
+table DatetimeModel {
+ // List of BCP 47 locale strings representing all locales supported by the
+ // model. The individual patterns refer back to them using an index.
+ locales:[string];
+
+ patterns:[DatetimeModelPattern];
+ extractors:[DatetimeModelExtractor];
+
+ // If true, will use the extractors for determining the match location as
+ // opposed to using the location where the global pattern matched.
+ use_extractors_for_locating:bool = true;
+
+ // List of locale ids, rules of whose are always run, after the requested
+ // ones.
+ default_locales:[int];
+
+ // If true, will generate the alternative interpretations for ambiguous
+ // datetime expressions.
+ generate_alternative_interpretations_when_ambiguous:bool = false;
+
+ // If true, will compile the regexes only on first use.
+ lazy_regex_compilation:bool = true;
+
+ // If true, will give only future dates (when the day is not specified).
+ prefer_future_for_unspecified_date:bool = false;
+}
+
+// Configuration for the tokenizer.
+namespace libtextclassifier3;
+table GrammarTokenizerOptions {
+ tokenization_type:TokenizationType = ICU;
+
+ // If true, white space tokens will be kept when using the icu tokenizer.
+ icu_preserve_whitespace_tokens:bool = false;
+
+ // Codepoint ranges that determine what role the different codepoints play
+ // during tokenized. The ranges must not overlap.
+ tokenization_codepoint_config:[TokenizationCodepointRange];
+
+ // A set of codepoint ranges to use in the mixed tokenization mode to identify
+ // stretches of tokens to re-tokenize using the internal tokenizer.
+ internal_tokenizer_codepoint_ranges:[CodepointRange];
+
+ // If true, tokens will be also split when the codepoint's script_id changes
+ // as defined in TokenizationCodepointRange.
+ tokenize_on_script_change:bool = false;
+}
+
+// Options for grammar date/datetime/date range annotations.
+namespace libtextclassifier3.GrammarDatetimeModel_;
+table AnnotationOptions {
+ // If enabled, extract special day offset like today, yesterday, etc.
+ enable_special_day_offset:bool = true;
+
+ // If true, merge the adjacent day of week, time and date. e.g.
+ // "20/2/2016 at 8pm" is extracted as a single instance instead of two
+ // instance: "20/2/2016" and "8pm".
+ merge_adjacent_components:bool = true;
+
+ // List the extra id of requested dates.
+ extra_requested_dates:[string];
+
+ // If true, try to include preposition to the extracted annotation. e.g.
+ // "at 6pm". if it's false, only 6pm is included. offline-actions has
+ // special requirements to include preposition.
+ include_preposition:bool = true;
+
+ // If enabled, extract range in date annotator.
+ // input: Monday, 5-6pm
+ // If the flag is true, The extracted annotation only contains 1 range
+ // instance which is from Monday 5pm to 6pm.
+ // If the flag is false, The extracted annotation contains two date
+ // instance: "Monday" and "6pm".
+ enable_date_range:bool = true;
+ reserved_6:int16 (deprecated);
+
+ // If enabled, the rule priority score is used to set the priority score of
+ // the annotation.
+ // In case of false the annotation priority score is set from
+ // GrammarDatetimeModel's priority_score
+ use_rule_priority_score:bool = false;
+
+ // If enabled, annotator will try to resolve the ambiguity by generating
+ // possible alternative interpretations of the input text
+ // e.g. '9:45' will be resolved to '9:45 AM' and '9:45 PM'.
+ generate_alternative_interpretations_when_ambiguous:bool;
+
+ // List of spans which grammar will ignore during the match e.g. if
+ // “@” is in the allowed span list and input is “12 March @ 12PM” then “@”
+ // will be ignored and 12 March @ 12PM will be translate to
+ // {Day:12 Month: March Hour: 12 MERIDIAN: PM}.
+ // This can also be achieved by adding additional rules e.g.
+ // <Digit_Day> <Month> <Time>
+ // <Digit_Day> <Month> @ <Time>
+ // Though this is doable in the grammar but requires multiple rules, this
+ // list enables the rule to represent multiple rules.
+ ignored_spans:[string];
+}
+
+namespace libtextclassifier3;
+table GrammarDatetimeModel {
+ // List of BCP 47 locale strings representing all locales supported by the
+ // model.
+ locales:[string];
+
+ // If true, will give only future dates (when the day is not specified).
+ prefer_future_for_unspecified_date:bool = false;
+
+ // Grammar specific tokenizer options.
+ grammar_tokenizer_options:GrammarTokenizerOptions;
+
+ // The modes for which to apply the grammars.
+ enabled_modes:ModeFlag = ALL;
+
+ // The datetime grammar rules.
+ datetime_rules:dates.DatetimeRules;
+
+ // The final score to assign to the results of grammar model
+ target_classification_score:float = 1;
+
+ // The priority score used for conflict resolution with the other models.
+ priority_score:float = 0;
+
+ // Options for grammar annotations.
+ annotation_options:GrammarDatetimeModel_.AnnotationOptions;
+}
+
+namespace libtextclassifier3.DatetimeModelLibrary_;
+table Item {
+ key:string (shared);
+ value:DatetimeModel;
+}
+
+// A set of named DateTime models.
+namespace libtextclassifier3;
+table DatetimeModelLibrary {
+ models:[DatetimeModelLibrary_.Item];
+}
+
+// Classification result to instantiate for a rule match.
+namespace libtextclassifier3.GrammarModel_;
+table RuleClassificationResult {
+ // The name of the collection.
+ collection_name:string (shared);
+
+ // The score.
+ target_classification_score:float = 1;
+
+ // The priority score used for conflict resolution with the other models.
+ priority_score:float = 0;
+
+ // Behaviour of capturing matches.
+ capturing_group:[CapturingGroup];
+
+ // Entity data to set for a match.
+ serialized_entity_data:string (shared);
+
+ // Enabled modes.
+ enabled_modes:ModeFlag = ALL;
+
+ entity_data:EntityData;
+}
+
+// Configuration for grammar based annotators.
+namespace libtextclassifier3;
+table GrammarModel {
+ // The grammar rules.
+ rules:grammar.RulesSet;
+
+ rule_classification_result:[GrammarModel_.RuleClassificationResult];
+
+ // Number of tokens in the context to use for classification and text
+ // selection suggestion.
+ // A value -1 uses the full context.
+ context_left_num_tokens:int;
+
+ context_right_num_tokens:int;
+
+ // Grammar specific tokenizer options.
+ tokenizer_options:GrammarTokenizerOptions;
+}
+
+namespace libtextclassifier3;
+table MoneyParsingOptions {
+ // Separators (codepoints) marking decimal or thousand in the money amount.
+ separators:[int];
+}
+
+namespace libtextclassifier3.ModelTriggeringOptions_;
+table CollectionToPriorityEntry {
+ key:string (key, shared);
+ value:float;
+}
+
+// Options controlling the output of the Tensorflow Lite models.
+namespace libtextclassifier3;
+table ModelTriggeringOptions {
+ // Lower bound threshold for filtering annotation model outputs.
+ min_annotate_confidence:float = 0;
+
+ // The modes for which to enable the models.
+ enabled_modes:ModeFlag = ALL;
+
+ // Comma-separated list of locales (BCP 47 tags) that dictionary
+ // classification supports.
+ dictionary_locales:string (shared);
+
+ // Comma-separated list of locales (BCP 47 tags) that the model supports, that
+ // are used to prevent triggering on input in unsupported languages. If
+ // empty, the model will trigger on all inputs.
+ locales:string (shared);
+
+ // Priority score assigned to the "other" class from ML model.
+ other_collection_priority_score:float = -1000;
+
+ // Priority score assigned to knowledge engine annotations.
+ knowledge_priority_score:float = 0;
+ reserved_7:int16 (deprecated);
+
+ // Apply a factor to the priority score for entities that are added to this
+ // map. Key: collection type e.g. "address", "phone"..., Value: float number.
+ // NOTE: The entries here need to be sorted since we use LookupByKey.
+ collection_to_priority:[ModelTriggeringOptions_.CollectionToPriorityEntry];
+}
+
+// Options controlling the output of the classifier.
+namespace libtextclassifier3;
+table OutputOptions {
+ // Lists of collection names that will be filtered out at the output:
+ // - For annotation, the spans of given collection are simply dropped.
+ // - For classification, the result is mapped to the class "other".
+ // - For selection, the spans of given class are returned as
+ // single-selection.
+ filtered_collections_annotation:[string];
+
+ filtered_collections_classification:[string];
+ filtered_collections_selection:[string];
+}
+
+namespace libtextclassifier3.Model_;
+table EmbeddingPruningMask {
+ // If true, use pruning mask. In this case, we use mask
+ // pruning_mask to determine the mapping of hashed-charactergrams.
+ enabled:bool;
+
+ // Packing of the binary pruning mask into uint64 values.
+ pruning_mask:[ulong] (force_align: 16);
+
+ // Number of buckets before pruning.
+ full_num_buckets:int;
+
+ // Index of row of compressed embedding matrix to which all pruned buckets
+ // are mapped.
+ pruned_row_bucket_id:int;
+}
+
+namespace libtextclassifier3.Model_;
+table ConflictResolutionOptions {
+ // If true, will prioritize the longest annotation during conflict
+ // resolution.
+ prioritize_longest_annotation:bool = false;
+
+ // If true, the annotator will perform conflict resolution between the
+ // different sub-annotators also in the RAW mode. If false, no conflict
+ // resolution will be performed in RAW mode.
+ do_conflict_resolution_in_raw_mode:bool = true;
+}
+
+namespace libtextclassifier3;
+table Model {
+ // Comma-separated list of locales supported by the model as BCP 47 tags.
+ locales:string (shared);
+
+ version:int;
+
+ // A name for the model that can be used for e.g. logging.
+ name:string (shared);
+
+ selection_feature_options:FeatureProcessorOptions;
+ classification_feature_options:FeatureProcessorOptions;
+
+ // Tensorflow Lite models.
+ selection_model:[ubyte] (force_align: 16);
+
+ classification_model:[ubyte] (force_align: 16);
+ embedding_model:[ubyte] (force_align: 16);
+
+ // Options for the different models.
+ selection_options:SelectionModelOptions;
+
+ classification_options:ClassificationModelOptions;
+ regex_model:RegexModel;
+ datetime_model:DatetimeModel;
+
+ // Options controlling the output of the models.
+ triggering_options:ModelTriggeringOptions;
+
+ // Global switch that controls if SuggestSelection(), ClassifyText() and
+ // Annotate() will run. If a mode is disabled it returns empty/no-op results.
+ enabled_modes:ModeFlag = ALL;
+
+ // If true, will snap the selections that consist only of whitespaces to the
+ // containing suggested span. Otherwise, no suggestion is proposed, since the
+ // selections are not part of any token.
+ snap_whitespace_selections:bool = true;
+
+ // Global configuration for the output of SuggestSelection(), ClassifyText()
+ // and Annotate().
+ output_options:OutputOptions;
+
+ // Configures how Intents should be generated on Android.
+ android_intent_options:AndroidIntentFactoryOptions;
+
+ intent_options:IntentFactoryModel;
+
+ // Model resources.
+ resources:ResourcePool;
+
+ // Schema data for handling entity data.
+ entity_data_schema:[ubyte];
+
+ number_annotator_options:NumberAnnotatorOptions;
+ duration_annotator_options:DurationAnnotatorOptions;
+
+ // Comma-separated list of locales (BCP 47 tags) that the model supports, that
+ // are used to prevent triggering on input in unsupported languages. If
+ // empty, the model will trigger on all inputs.
+ triggering_locales:string (shared);
+
+ embedding_pruning_mask:Model_.EmbeddingPruningMask;
+ grammar_datetime_model:GrammarDatetimeModel;
+ contact_annotator_options:ContactAnnotatorOptions;
+ money_parsing_options:MoneyParsingOptions;
+ translate_annotator_options:TranslateAnnotatorOptions;
+ grammar_model:GrammarModel;
+ conflict_resolution_options:Model_.ConflictResolutionOptions;
+ experimental_model:ExperimentalModel;
+}
+
+// Method for selecting the center token.
+namespace libtextclassifier3.FeatureProcessorOptions_;
+enum CenterTokenSelectionMethod : int {
+  // Invalid option.
+  DEFAULT_CENTER_TOKEN_METHOD = 0,
+
+ // Use click indices to determine the center token.
+ CENTER_TOKEN_FROM_CLICK = 1,
+
+ // Use selection indices to get a token range, and select the middle of it
+ // as the center token.
+ CENTER_TOKEN_MIDDLE_OF_SELECTION = 2,
+}
+
+// Bounds-sensitive feature extraction configuration.
+namespace libtextclassifier3.FeatureProcessorOptions_;
+table BoundsSensitiveFeatures {
+ // Enables the extraction of bounds-sensitive features, instead of the click
+ // context features.
+ enabled:bool;
+
+ // The numbers of tokens to extract in specific locations relative to the
+ // bounds.
+ // Immediately before the span.
+ num_tokens_before:int;
+
+ // Inside the span, aligned with the beginning.
+ num_tokens_inside_left:int;
+
+ // Inside the span, aligned with the end.
+ num_tokens_inside_right:int;
+
+ // Immediately after the span.
+ num_tokens_after:int;
+
+ // If true, also extracts the tokens of the entire span and adds up their
+ // features forming one "token" to include in the extracted features.
+ include_inside_bag:bool;
+
+ // If true, includes the selection length (in the number of tokens) as a
+ // feature.
+ include_inside_length:bool;
+
+ // If true, for selection, single token spans are not run through the model
+ // and their score is assumed to be zero.
+ score_single_token_spans_as_zero:bool;
+}
+
+namespace libtextclassifier3;
+table FeatureProcessorOptions {
+ // Number of buckets used for hashing charactergrams.
+ num_buckets:int = -1;
+
+ // Size of the embedding.
+ embedding_size:int = -1;
+
+ // Number of bits for quantization for embeddings.
+ embedding_quantization_bits:int = 8;
+
+ // Context size defines the number of words to the left and to the right of
+ // the selected word to be used as context. For example, if context size is
+ // N, then we take N words to the left and N words to the right of the
+ // selected word as its context.
+ context_size:int = -1;
+
+ // Maximum number of words of the context to select in total.
+ max_selection_span:int = -1;
+
+ // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3
+ // character trigrams etc.
+ chargram_orders:[int];
+
+ // Maximum length of a word, in codepoints.
+ max_word_length:int = 20;
+
+ // If true, will use the unicode-aware functionality for extracting features.
+ unicode_aware_features:bool = false;
+
+ // Whether to extract the token case feature.
+ extract_case_feature:bool = false;
+
+ // Whether to extract the selection mask feature.
+ extract_selection_mask_feature:bool = false;
+
+ // List of regexps to run over each token. For each regexp, if there is a
+ // match, a dense feature of 1.0 is emitted. Otherwise -1.0 is used.
+ regexp_feature:[string];
+
+ // Whether to remap all digits to a single number.
+ remap_digits:bool = false;
+
+ // Whether to lower-case each token before generating hashgrams.
+ lowercase_tokens:bool;
+
+ // If true, the selection classifier output will contain only the selections
+ // that are feasible (e.g., those that are shorter than max_selection_span),
+ // if false, the output will be a complete cross-product of possible
+ // selections to the left and possible selections to the right, including the
+ // infeasible ones.
+ // NOTE: Exists mainly for compatibility with older models that were trained
+ // with the non-reduced output space.
+ selection_reduced_output_space:bool = true;
+
+ // Collection names.
+ collections:[string];
+
+ // An index of collection in collections to be used if a collection name can't
+ // be mapped to an id.
+ default_collection:int = -1;
+
+ // If true, will split the input by lines, and only use the line that contains
+ // the clicked token.
+ only_use_line_with_click:bool = false;
+
+ // If true, will split tokens that contain the selection boundary, at the
+ // position of the boundary.
+ // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
+ split_tokens_on_selection_boundaries:bool = false;
+
+ // Codepoint ranges that determine how different codepoints are tokenized.
+ // The ranges must not overlap.
+ tokenization_codepoint_config:[TokenizationCodepointRange];
+
+ center_token_selection_method:FeatureProcessorOptions_.CenterTokenSelectionMethod;
+
+ // If true, span boundaries will be snapped to containing tokens and not
+ // required to exactly match token boundaries.
+ snap_label_span_boundaries_to_containing_tokens:bool;
+
+ // A set of codepoint ranges supported by the model.
+ supported_codepoint_ranges:[CodepointRange];
+
+ // A set of codepoint ranges to use in the mixed tokenization mode to identify
+ // stretches of tokens to re-tokenize using the internal tokenizer.
+ internal_tokenizer_codepoint_ranges:[CodepointRange];
+
+ // Minimum ratio of supported codepoints in the input context. If the ratio
+ // is lower than this, the feature computation will fail.
+ min_supported_codepoint_ratio:float = 0;
+
+ // Used for versioning the format of features the model expects.
+ // - feature_version == 0:
+ // For each token the features consist of:
+ // - chargram embeddings
+ // - dense features
+ // Chargram embeddings for tokens are concatenated first together,
+ // and at the end, the dense features for the tokens are concatenated
+ // to it. So the resulting feature vector has two regions.
+ feature_version:int = 0;
+
+ tokenization_type:TokenizationType = INTERNAL_TOKENIZER;
+ icu_preserve_whitespace_tokens:bool = false;
+
+ // List of codepoints that will be stripped from beginning and end of
+ // predicted spans.
+ ignored_span_boundary_codepoints:[int];
+
+ bounds_sensitive_features:FeatureProcessorOptions_.BoundsSensitiveFeatures;
+
+ // List of allowed charactergrams. The extracted charactergrams are filtered
+ // using this list, and charactergrams that are not present are interpreted as
+ // out-of-vocabulary.
+ // If no allowed_chargrams are specified, all charactergrams are allowed.
+ // The field is typed as bytes type to allow non-UTF8 chargrams.
+ allowed_chargrams:[string];
+
+ // If true, tokens will be also split when the codepoint's script_id changes
+ // as defined in TokenizationCodepointRange.
+ tokenize_on_script_change:bool = false;
+
+ // If true, the pipe character '|' will be used as a newline character when
+ // splitting lines.
+ use_pipe_character_for_newline:bool = true;
+}
+
+namespace libtextclassifier3;
+table NumberAnnotatorOptions {
+ // If true, number and percentage annotations will be produced.
+ enabled:bool = false;
+
+ // Score to assign to the annotated numbers and percentages in the annotator.
+ score:float = 1;
+
+ // Number priority score used for conflict resolution with the other models.
+ priority_score:float = 0;
+
+ // The modes in which to enable number and percentage annotations.
+ enabled_modes:ModeFlag = ALL;
+
+ // The annotation usecases for which to produce number annotations.
+ // This is a flag field for values of AnnotationUsecase.
+ enabled_annotation_usecases:uint = 4294967295;
+
+ // [Deprecated] A list of codepoints that can form a prefix of a valid number.
+ allowed_prefix_codepoints:[int];
+
+ // [Deprecated] A list of codepoints that can form a suffix of a valid number.
+ allowed_suffix_codepoints:[int];
+
+ // [Deprecated] List of codepoints that will be stripped from beginning of
+ // predicted spans.
+ ignored_prefix_span_boundary_codepoints:[int];
+
+ // [Deprecated] List of codepoints that will be stripped from end of predicted
+ // spans.
+ ignored_suffix_span_boundary_codepoints:[int];
+
+ // [Deprecated] If true, percent annotations will be produced.
+ enable_percentage:bool = false;
+
+ // Zero separated and ordered list of suffixes that mark a percent.
+ percentage_pieces_string:string (shared);
+
+  // [Deprecated] List of suffix offsets in the percentage_pieces_string string.
+ percentage_pieces_offsets:[int];
+
+ // Priority score for the percentage annotation.
+ percentage_priority_score:float = 1;
+
+ // Float number priority score used for conflict resolution with the other
+ // models.
+ float_number_priority_score:float = 0;
+
+ // The maximum number of digits an annotated number can have. Requirement:
+ // the value should be less or equal to 20.
+ max_number_of_digits:int = 20;
+
+ // The annotation usecases for which to produce percentage annotations.
+ // This is a flag field for values of AnnotationUsecase.
+ percentage_annotation_usecases:uint = 2;
+}
+
+// DurationAnnotator is so far tailored for English and Japanese only.
+namespace libtextclassifier3;
+table DurationAnnotatorOptions {
+ // If true, duration annotations will be produced.
+ enabled:bool = false;
+
+ // Score to assign to the annotated durations from the annotator.
+ score:float = 1;
+
+ // Priority score used for conflict resolution with the other models.
+ priority_score:float = 0;
+
+ // The modes in which to enable duration annotations.
+ enabled_modes:ModeFlag = ALL;
+
+ // The annotation usecases for which to produce duration annotations.
+ enabled_annotation_usecases:uint = 4294967295;
+
+ // Durations typically look like XX hours and XX minutes etc... The list of
+ // strings below enumerate variants of "hours", "minutes", etc. in these
+ // expressions. These are verbatim strings that are matched against tokens in
+ // the input.
+ week_expressions:[string];
+
+ day_expressions:[string];
+ hour_expressions:[string];
+ minute_expressions:[string];
+ second_expressions:[string];
+
+  // List of expressions that don't break a duration expression (they can
+  // become a part of it) but have no semantic meaning.
+ filler_expressions:[string];
+
+ // List of expressions that mean half of a unit of duration (e.g. "half an
+ // hour").
+ half_expressions:[string];
+
+  // Set of codepoints that can split the Annotator tokens to sub-tokens for
+ // sub-token matching.
+ sub_token_separator_codepoints:[int];
+
+ // If this is true, unit must be associated with quantity. For example, a
+ // phrase "minute" is not parsed as one minute duration if this is true.
+ require_quantity:bool;
+
+ // If this is true, dangling quantity is included in the annotation. For
+ // example, "10 minutes 20" is interpreted as 10 minutes and 20 seconds.
+ enable_dangling_quantity_interpretation:bool = true;
+}
+
+namespace libtextclassifier3;
+table ContactAnnotatorOptions {
+ // Supported for English genitives only so far.
+ enable_declension:bool;
+
+ // For each language there is a customized list of supported declensions.
+ language:string (shared);
+}
+
+namespace libtextclassifier3.TranslateAnnotatorOptions_;
+enum Algorithm : int {
+ DEFAULT_ALGORITHM = 0,
+ BACKOFF = 1,
+}
+
+// Backoff is the algorithm shipped with Android Q.
+namespace libtextclassifier3.TranslateAnnotatorOptions_;
+table BackoffOptions {
+ // The minimum size of text to prefer for detection (in codepoints).
+ min_text_size:int = 20;
+
+ // For reducing the score when text is less than the preferred size.
+ penalize_ratio:float = 1;
+
+ // Original detection score to surrounding text detection score ratios.
+ subject_text_score_ratio:float = 0.4;
+}
+
+namespace libtextclassifier3;
+table TranslateAnnotatorOptions {
+ enabled:bool = false;
+
+ // Score to assign to the classification results.
+ score:float = 1;
+
+ // Priority score used for conflict resolution with the other models.
+ priority_score:float;
+
+ algorithm:TranslateAnnotatorOptions_.Algorithm;
+ backoff_options:TranslateAnnotatorOptions_.BackoffOptions;
+}
+
+root_type libtextclassifier3.Model;
diff --git a/native/annotator/number/number.cc b/native/annotator/number/number.cc
new file mode 100644
index 0000000..3be6ad8
--- /dev/null
+++ b/native/annotator/number/number.cc
@@ -0,0 +1,314 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/number/number.h"
+
+#include <climits>
+#include <cstdlib>
+#include <string>
+
+#include "annotator/collections.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/strings/split.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+bool NumberAnnotator::ClassifyText(
+ const UnicodeText& context, CodepointSpan selection_indices,
+ AnnotationUsecase annotation_usecase,
+ ClassificationResult* classification_result) const {
+ TC3_CHECK(classification_result != nullptr);
+
+ const UnicodeText substring_selected = UnicodeText::Substring(
+ context, selection_indices.first, selection_indices.second);
+
+ std::vector<AnnotatedSpan> results;
+ if (!FindAll(substring_selected, annotation_usecase, &results)) {
+ return false;
+ }
+
+ for (const AnnotatedSpan& result : results) {
+ if (result.classification.empty()) {
+ continue;
+ }
+
+ // We make sure that the result span is equal to the stripped selection span
+ // to avoid validating cases like "23 asdf 3.14 pct asdf". FindAll will
+ // anyway only find valid numbers and percentages and a given selection with
+ // more than two tokens won't pass this check.
+ if (result.span.first + selection_indices.first ==
+ selection_indices.first &&
+ result.span.second + selection_indices.first ==
+ selection_indices.second) {
+ *classification_result = result.classification[0];
+ return true;
+ }
+ }
+ return false;
+}
+
+bool NumberAnnotator::IsCJTterm(UnicodeText::const_iterator token_begin_it,
+ const int token_length) const {
+ auto token_end_it = token_begin_it;
+ std::advance(token_end_it, token_length);
+ for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
+ if (!unilib_->IsCJTletter(*char_it)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool NumberAnnotator::TokensAreValidStart(const std::vector<Token>& tokens,
+ const int start_index) const {
+ if (start_index < 0 || tokens[start_index].is_whitespace) {
+ return true;
+ }
+ return false;
+}
+
+bool NumberAnnotator::TokensAreValidNumberPrefix(
+ const std::vector<Token>& tokens, const int prefix_end_index) const {
+ if (TokensAreValidStart(tokens, prefix_end_index)) {
+ return true;
+ }
+
+ auto prefix_begin_it =
+ UTF8ToUnicodeText(tokens[prefix_end_index].value, /*do_copy=*/false)
+ .begin();
+ const int token_length =
+ tokens[prefix_end_index].end - tokens[prefix_end_index].start;
+ if (token_length == 1 && unilib_->IsOpeningBracket(*prefix_begin_it) &&
+ TokensAreValidStart(tokens, prefix_end_index - 1)) {
+ return true;
+ }
+ if (token_length == 1 && unilib_->IsNumberSign(*prefix_begin_it) &&
+ TokensAreValidStart(tokens, prefix_end_index - 1)) {
+ return true;
+ }
+ if (token_length == 1 && unilib_->IsSlash(*prefix_begin_it) &&
+ prefix_end_index >= 1 &&
+ TokensAreValidStart(tokens, prefix_end_index - 2)) {
+ int64 int_val;
+ double double_val;
+ return TryParseNumber(UTF8ToUnicodeText(tokens[prefix_end_index - 1].value,
+ /*do_copy=*/false),
+ false, &int_val, &double_val);
+ }
+ if (IsCJTterm(prefix_begin_it, token_length)) {
+ return true;
+ }
+
+ return false;
+}
+
+bool NumberAnnotator::TokensAreValidEnding(const std::vector<Token>& tokens,
+ const int ending_index) const {
+ if (ending_index >= tokens.size() || tokens[ending_index].is_whitespace) {
+ return true;
+ }
+
+ auto ending_begin_it =
+ UTF8ToUnicodeText(tokens[ending_index].value, /*do_copy=*/false).begin();
+ if (ending_index == tokens.size() - 1 &&
+ tokens[ending_index].end - tokens[ending_index].start == 1 &&
+ unilib_->IsPunctuation(*ending_begin_it)) {
+ return true;
+ }
+ if (ending_index < tokens.size() - 1 &&
+ tokens[ending_index].end - tokens[ending_index].start == 1 &&
+ unilib_->IsPunctuation(*ending_begin_it) &&
+ tokens[ending_index + 1].is_whitespace) {
+ return true;
+ }
+
+ return false;
+}
+
+bool NumberAnnotator::TokensAreValidNumberSuffix(
+ const std::vector<Token>& tokens, const int suffix_start_index) const {
+ if (TokensAreValidEnding(tokens, suffix_start_index)) {
+ return true;
+ }
+
+ auto suffix_begin_it =
+ UTF8ToUnicodeText(tokens[suffix_start_index].value, /*do_copy=*/false)
+ .begin();
+
+ if (percent_suffixes_.find(tokens[suffix_start_index].value) !=
+ percent_suffixes_.end() &&
+ TokensAreValidEnding(tokens, suffix_start_index + 1)) {
+ return true;
+ }
+
+ const int token_length =
+ tokens[suffix_start_index].end - tokens[suffix_start_index].start;
+ if (token_length == 1 && unilib_->IsSlash(*suffix_begin_it) &&
+ suffix_start_index <= tokens.size() - 2 &&
+ TokensAreValidEnding(tokens, suffix_start_index + 2)) {
+ int64 int_val;
+ double double_val;
+ return TryParseNumber(
+ UTF8ToUnicodeText(tokens[suffix_start_index + 1].value,
+ /*do_copy=*/false),
+ false, &int_val, &double_val);
+ }
+ if (IsCJTterm(suffix_begin_it, token_length)) {
+ return true;
+ }
+
+ return false;
+}
+
+int NumberAnnotator::FindPercentSuffixEndCodepoint(
+ const std::vector<Token>& tokens,
+ const int suffix_token_start_index) const {
+ if (suffix_token_start_index >= tokens.size()) {
+ return -1;
+ }
+
+ if (percent_suffixes_.find(tokens[suffix_token_start_index].value) !=
+ percent_suffixes_.end() &&
+ TokensAreValidEnding(tokens, suffix_token_start_index + 1)) {
+ return tokens[suffix_token_start_index].end;
+ }
+ if (tokens[suffix_token_start_index].is_whitespace) {
+ return FindPercentSuffixEndCodepoint(tokens, suffix_token_start_index + 1);
+ }
+
+ return -1;
+}
+
+bool NumberAnnotator::TryParseNumber(const UnicodeText& token_text,
+ const bool is_negative,
+ int64* parsed_int_value,
+ double* parsed_double_value) const {
+ if (token_text.ToUTF8String().size() >= max_number_of_digits_) {
+ return false;
+ }
+ const bool is_double = unilib_->ParseDouble(token_text, parsed_double_value);
+ if (!is_double) {
+ return false;
+ }
+ *parsed_int_value = std::trunc(*parsed_double_value);
+ if (is_negative) {
+ *parsed_int_value *= -1;
+ *parsed_double_value *= -1;
+ }
+
+ return true;
+}
+
+bool NumberAnnotator::FindAll(const UnicodeText& context,
+ AnnotationUsecase annotation_usecase,
+ std::vector<AnnotatedSpan>* result) const {
+ if (!options_->enabled()) {
+ return true;
+ }
+
+ const std::vector<Token> tokens = tokenizer_.Tokenize(context);
+ for (int i = 0; i < tokens.size(); ++i) {
+ const Token token = tokens[i];
+ if (tokens[i].value.empty() ||
+ !unilib_->IsDigit(
+ *UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false).begin())) {
+ continue;
+ }
+
+ const UnicodeText token_text =
+ UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+ int64 parsed_int_value;
+ double parsed_double_value;
+ bool is_negative =
+ (i > 0) &&
+ unilib_->IsMinus(
+ *UTF8ToUnicodeText(tokens[i - 1].value, /*do_copy=*/false).begin());
+ if (!TryParseNumber(token_text, is_negative, &parsed_int_value,
+ &parsed_double_value)) {
+ continue;
+ }
+ if (!TokensAreValidNumberPrefix(tokens, is_negative ? i - 2 : i - 1) ||
+ !TokensAreValidNumberSuffix(tokens, i + 1)) {
+ continue;
+ }
+
+ const bool has_decimal = !(parsed_int_value == parsed_double_value);
+ const int new_start_codepoint = is_negative ? token.start - 1 : token.start;
+
+ if (((1 << annotation_usecase) & options_->enabled_annotation_usecases()) !=
+ 0) {
+ result->push_back(CreateAnnotatedSpan(
+ new_start_codepoint, token.end, parsed_int_value, parsed_double_value,
+ Collections::Number(), options_->score(),
+ /*priority_score=*/
+ has_decimal ? options_->float_number_priority_score()
+ : options_->priority_score()));
+ }
+
+ const int percent_end_codepoint =
+ FindPercentSuffixEndCodepoint(tokens, i + 1);
+ if (percent_end_codepoint != -1 &&
+ ((1 << annotation_usecase) &
+ options_->percentage_annotation_usecases()) != 0) {
+ result->push_back(CreateAnnotatedSpan(
+ new_start_codepoint, percent_end_codepoint, parsed_int_value,
+ parsed_double_value, Collections::Percentage(), options_->score(),
+ options_->percentage_priority_score()));
+ }
+ }
+
+ return true;
+}
+
+AnnotatedSpan NumberAnnotator::CreateAnnotatedSpan(
+ const int start, const int end, const int int_value,
+ const double double_value, const std::string collection, const float score,
+ const float priority_score) const {
+ ClassificationResult classification{collection, score};
+ classification.numeric_value = int_value;
+ classification.numeric_double_value = double_value;
+ classification.priority_score = priority_score;
+
+ AnnotatedSpan annotated_span;
+ annotated_span.span = {start, end};
+ annotated_span.classification.push_back(classification);
+ return annotated_span;
+}
+
+std::unordered_set<std::string>
+NumberAnnotator::FromFlatbufferStringToUnordredSet(
+ const flatbuffers::String* flatbuffer_percent_strings) {
+ std::unordered_set<std::string> strings_set;
+ if (flatbuffer_percent_strings == nullptr) {
+ return strings_set;
+ }
+
+ const std::string percent_strings = flatbuffer_percent_strings->str();
+ for (StringPiece suffix : strings::Split(percent_strings, '\0')) {
+ std::string percent_suffix = suffix.ToString();
+ percent_suffix.erase(
+ std::remove_if(percent_suffix.begin(), percent_suffix.end(),
+ [](unsigned char x) { return std::isspace(x); }),
+ percent_suffix.end());
+ strings_set.insert(percent_suffix);
+ }
+
+ return strings_set;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/number/number.h b/native/annotator/number/number.h
new file mode 100644
index 0000000..d83bea0
--- /dev/null
+++ b/native/annotator/number/number.h
@@ -0,0 +1,124 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_H_
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/container/sorted-strings-table.h"
+#include "utils/tokenizer.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+// Annotator of numbers in text.
+//
+// Integer supported values are in range [-1 000 000 000, 1 000 000 000].
+// Double supported values are in range [-999999999.999999999,
+// 999999999.999999999].
+class NumberAnnotator {
+ public:
+ explicit NumberAnnotator(const NumberAnnotatorOptions* options,
+ const UniLib* unilib)
+ : options_(options),
+ unilib_(unilib),
+ tokenizer_(Tokenizer(TokenizationType_LETTER_DIGIT, unilib,
+ /*codepoint_ranges=*/{},
+ /*internal_tokenizer_codepoint_ranges=*/{},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/true)),
+ percent_suffixes_(FromFlatbufferStringToUnordredSet(
+ options_->percentage_pieces_string())),
+ max_number_of_digits_(options->max_number_of_digits()) {}
+
+ // Classifies given text, and if it is a number, it passes the result in
+ // 'classification_result' and returns true, otherwise returns false.
+ bool ClassifyText(const UnicodeText& context, CodepointSpan selection_indices,
+ AnnotationUsecase annotation_usecase,
+ ClassificationResult* classification_result) const;
+
+ // Finds all number instances in the input text. Returns true in any case.
+ bool FindAll(const UnicodeText& context_unicode,
+ AnnotationUsecase annotation_usecase,
+ std::vector<AnnotatedSpan>* result) const;
+
+ private:
+ // Converts a Flatbuffer string containing zero-separated percent suffixes
+ // to an unordered set.
+ static std::unordered_set<std::string> FromFlatbufferStringToUnordredSet(
+ const flatbuffers::String* flatbuffer_percent_strings);
+
+ // Checks if the annotated numbers from the context represent percentages.
+ // If yes, replaces the collection type and the annotation boundary in the
+ // result.
+ void FindPercentages(const UnicodeText& context,
+ std::vector<AnnotatedSpan>* result) const;
+
+  // Checks if the tokens in the interval [start_index-2, start_index] are
+  // valid characters that can precede a number context.
+ bool TokensAreValidStart(const std::vector<Token>& tokens,
+ int start_index) const;
+
+ // Checks if the tokens in the interval (..., prefix_end_index] are a valid
+ // number prefix.
+ bool TokensAreValidNumberPrefix(const std::vector<Token>& tokens,
+ int prefix_end_index) const;
+
+  // Checks if the tokens in the interval [ending_index, ending_index+2]
+ // are valid characters that can follow a number context.
+ bool TokensAreValidEnding(const std::vector<Token>& tokens,
+ int ending_index) const;
+
+ // Checks if the tokens in the interval [suffix_start_index, ...) are a valid
+ // number suffix.
+ bool TokensAreValidNumberSuffix(const std::vector<Token>& tokens,
+ int suffix_start_index) const;
+
+ // Checks if the tokens in the interval [suffix_start_index, ...) are a valid
+ // percent suffix. If false, returns -1, else returns the end codepoint.
+ int FindPercentSuffixEndCodepoint(const std::vector<Token>& tokens,
+ int suffix_token_start_index) const;
+
+ // Checks if the given text represents a number (either int or double).
+ bool TryParseNumber(const UnicodeText& token_text, bool is_negative,
+ int64* parsed_int_value,
+ double* parsed_double_value) const;
+
+ // Checks if a word contains only CJT characters.
+ bool IsCJTterm(UnicodeText::const_iterator token_begin_it,
+ int token_length) const;
+
+ AnnotatedSpan CreateAnnotatedSpan(int start, int end, int int_value,
+ double double_value,
+ const std::string collection, float score,
+ float priority_score) const;
+
+ const NumberAnnotatorOptions* options_;
+ const UniLib* unilib_;
+ const Tokenizer tokenizer_;
+ const std::unordered_set<std::string> percent_suffixes_;
+ const int max_number_of_digits_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_NUMBER_NUMBER_H_
diff --git a/native/annotator/person_name/person-name-engine-dummy.h b/native/annotator/person_name/person-name-engine-dummy.h
new file mode 100644
index 0000000..9c83241
--- /dev/null
+++ b/native/annotator/person_name/person-name-engine-dummy.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_PERSON_NAME_PERSON_NAME_ENGINE_DUMMY_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_PERSON_NAME_PERSON_NAME_ENGINE_DUMMY_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/feature-processor.h"
+#include "annotator/person_name/person_name_model_generated.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// A dummy implementation of the person name engine.
+class PersonNameEngine {
+ public:
+ explicit PersonNameEngine(const FeatureProcessor* feature_processor,
+ const UniLib* unilib) {}
+
+ bool Initialize(const PersonNameModel* model) {
+ TC3_LOG(ERROR) << "No person name engine to initialize.";
+ return false;
+ }
+
+ bool ClassifyText(const std::string& context, CodepointSpan selection_indices,
+ ClassificationResult* classification_result) const {
+ return false;
+ }
+
+ bool Chunk(const UnicodeText& context_unicode,
+ const std::vector<Token>& tokens,
+ std::vector<AnnotatedSpan>* result) const {
+ return true;
+ }
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_PERSON_NAME_PERSON_NAME_ENGINE_DUMMY_H_
diff --git a/native/annotator/person_name/person-name-engine.h b/native/annotator/person_name/person-name-engine.h
new file mode 100644
index 0000000..988fce3
--- /dev/null
+++ b/native/annotator/person_name/person-name-engine.h
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_PERSON_NAME_PERSON_NAME_ENGINE_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_PERSON_NAME_PERSON_NAME_ENGINE_H_
+
+#include "annotator/person_name/person-name-engine-dummy.h"
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_PERSON_NAME_PERSON_NAME_ENGINE_H_
diff --git a/native/annotator/person_name/person_name_model.fbs b/native/annotator/person_name/person_name_model.fbs
new file mode 100755
index 0000000..b15543f
--- /dev/null
+++ b/native/annotator/person_name/person_name_model.fbs
@@ -0,0 +1,57 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// FlatBuffers schema for the person name annotator model.
+file_identifier "TC2 ";
+
+// Next ID: 2
+namespace libtextclassifier3.PersonNameModel_;
+table PersonName {
+ // Person name which is considered by the person name annotator. This
+ // attribute should contain 'atomic' person names, e.g., 'John' and 'Doe'
+ // should be two separate person names.
+ // required
+ person_name:string (shared);
+}
+
+// Next ID: 6
+namespace libtextclassifier3;
+table PersonNameModel {
+ // Decides if the person name annotator is enabled.
+ // required
+ enabled:bool;
+
+ // List of all person names which are considered by the person name annotator.
+ person_names:[PersonNameModel_.PersonName];
+
+ // Decides if the English genitive ending 's is stripped, e.g., if Peter's is
+ // stripped to Peter before looking for the name in the dictionary.
+ // required
+ strip_english_genitive_ending:bool;
+
+ // List of codepoints that are considered as 'end of person name' indicator in
+ // the heuristic to find the longest person name match.
+ // required
+ end_of_person_name_indicators:[int];
+
+ // Decides if only capitalized names should be annotated. In general, a
+ // capitalized name starts with an uppercase character and continues with
+ // lower case characters. In order to capture names such as O'Conell and
+ // McFee, this heuristic considers names as capitalized if they start with an
+ // upper case character and have at least one lower case character.
+ // required
+ annotate_capitalized_names_only:bool;
+}
+
+root_type libtextclassifier3.PersonNameModel;
diff --git a/annotator/quantization.cc b/native/annotator/quantization.cc
similarity index 100%
rename from annotator/quantization.cc
rename to native/annotator/quantization.cc
diff --git a/annotator/quantization.h b/native/annotator/quantization.h
similarity index 100%
rename from annotator/quantization.h
rename to native/annotator/quantization.h
diff --git a/annotator/quantization_test.cc b/native/annotator/quantization_test.cc
similarity index 100%
rename from annotator/quantization_test.cc
rename to native/annotator/quantization_test.cc
diff --git a/annotator/strip-unpaired-brackets.cc b/native/annotator/strip-unpaired-brackets.cc
similarity index 100%
rename from annotator/strip-unpaired-brackets.cc
rename to native/annotator/strip-unpaired-brackets.cc
diff --git a/annotator/strip-unpaired-brackets.h b/native/annotator/strip-unpaired-brackets.h
similarity index 100%
rename from annotator/strip-unpaired-brackets.h
rename to native/annotator/strip-unpaired-brackets.h
diff --git a/annotator/strip-unpaired-brackets_test.cc b/native/annotator/strip-unpaired-brackets_test.cc
similarity index 100%
rename from annotator/strip-unpaired-brackets_test.cc
rename to native/annotator/strip-unpaired-brackets_test.cc
diff --git a/native/annotator/translate/translate.cc b/native/annotator/translate/translate.cc
new file mode 100644
index 0000000..640ceec
--- /dev/null
+++ b/native/annotator/translate/translate.cc
@@ -0,0 +1,201 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/translate/translate.h"
+
+#include <memory>
+
+#include "annotator/collections.h"
+#include "annotator/entity-data_generated.h"
+#include "annotator/types.h"
+#include "lang_id/lang-id-wrapper.h"
+#include "utils/base/logging.h"
+#include "utils/i18n/locale.h"
+#include "utils/utf8/unicodetext.h"
+#include "lang_id/lang-id.h"
+
+namespace libtextclassifier3 {
+
+// Classifies the selection as "translate" iff the detected language of the
+// selected text is not among the user's familiar languages.
+// Fills |classification_result| and returns true only in that case; returns
+// false when detection fails, no user languages are known, or the user
+// already understands the detected language.
+bool TranslateAnnotator::ClassifyText(
+ const UnicodeText& context, CodepointSpan selection_indices,
+ const std::string& user_familiar_language_tags,
+ ClassificationResult* classification_result) const {
+ std::vector<TranslateAnnotator::LanguageConfidence> confidences;
+ // Only the BACKOFF detection algorithm is implemented here; for any other
+ // algorithm value |confidences| stays empty and we bail out below.
+ if (options_->algorithm() ==
+ TranslateAnnotatorOptions_::Algorithm::Algorithm_BACKOFF) {
+ if (options_->backoff_options() == nullptr) {
+ TC3_LOG(WARNING) << "No backoff options specified. Returning.";
+ return false;
+ }
+ confidences = BackoffDetectLanguages(context, selection_indices);
+ }
+
+ if (confidences.empty()) {
+ return false;
+ }
+
+ std::vector<Locale> user_familiar_languages;
+ if (!ParseLocales(user_familiar_language_tags, &user_familiar_languages)) {
+ TC3_LOG(WARNING) << "Couldn't parse the user-understood languages.";
+ return false;
+ }
+ if (user_familiar_languages.empty()) {
+ TC3_VLOG(INFO) << "user_familiar_languages is not set, not suggesting "
+ "translate action.";
+ return false;
+ }
+ // Only the top-ranked detected language (confidences[0], sorted by score in
+ // BackoffDetectLanguages) is compared against the user's languages.
+ bool user_can_understand_language_of_text = false;
+ for (const Locale& locale : user_familiar_languages) {
+ if (locale.Language() == confidences[0].language) {
+ user_can_understand_language_of_text = true;
+ break;
+ }
+ }
+
+ if (!user_can_understand_language_of_text) {
+ classification_result->collection = Collections::Translate();
+ classification_result->score = options_->score();
+ classification_result->priority_score = options_->priority_score();
+ classification_result->serialized_entity_data =
+ CreateSerializedEntityData(confidences);
+ return true;
+ }
+
+ return false;
+}
+
+// Packs the per-language prediction results into an EntityData flatbuffer and
+// returns its raw bytes as a string.
+std::string TranslateAnnotator::CreateSerializedEntityData(
+ const std::vector<TranslateAnnotator::LanguageConfidence>& confidences)
+ const {
+ EntityDataT entity_data;
+ entity_data.translate.reset(new EntityData_::TranslateT());
+
+ for (const LanguageConfidence& confidence : confidences) {
+ // NOTE(review): the raw pointer is handed straight to emplace_back;
+ // presumably the flatbuffers object-API vector stores unique_ptr and thus
+ // takes ownership — confirm against the generated LanguagePredictionResultT
+ // container type.
+ EntityData_::Translate_::LanguagePredictionResultT*
+ language_prediction_result =
+ new EntityData_::Translate_::LanguagePredictionResultT();
+ language_prediction_result->language_tag = confidence.language;
+ language_prediction_result->confidence_score = confidence.confidence;
+ entity_data.translate->language_prediction_results.emplace_back(
+ language_prediction_result);
+ }
+ flatbuffers::FlatBufferBuilder builder;
+ FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+// Detects languages of the selection using the "backoff" heuristic: runs
+// LangId on the selection itself and, if configured, on a token-aligned
+// window around it, then merges the two score sets into a single ranking
+// (highest confidence first).
+std::vector<TranslateAnnotator::LanguageConfidence>
+TranslateAnnotator::BackoffDetectLanguages(
+ const UnicodeText& context, CodepointSpan selection_indices) const {
+ const float penalize_ratio = options_->backoff_options()->penalize_ratio();
+ const int min_text_size = options_->backoff_options()->min_text_size();
+ // NOTE(review): the early return fires only when the selection is short AND
+ // penalize_ratio is non-positive — looks intentional (short selections are
+ // only rejected when context scores cannot contribute), but confirm.
+ if (selection_indices.second - selection_indices.first < min_text_size &&
+ penalize_ratio <= 0) {
+ return {};
+ }
+
+ // Predictions on the selected text only (a non-copying view).
+ const UnicodeText entity =
+ UnicodeText::Substring(context, selection_indices.first,
+ selection_indices.second, /*do_copy=*/false);
+ const std::vector<std::pair<std::string, float>> lang_id_result =
+ langid::GetPredictions(langid_model_, entity.data(), entity.size_bytes());
+
+ // Optionally also predict on a widened, token-aligned window so that very
+ // short selections get more signal.
+ const float more_text_score_ratio =
+ 1.0f - options_->backoff_options()->subject_text_score_ratio();
+ std::vector<std::pair<std::string, float>> more_lang_id_results;
+ if (more_text_score_ratio >= 0) {
+ const UnicodeText entity_with_context = TokenAlignedSubstringAroundSpan(
+ context, selection_indices, min_text_size);
+ more_lang_id_results =
+ langid::GetPredictions(langid_model_, entity_with_context.data(),
+ entity_with_context.size_bytes());
+ }
+
+ const float subject_text_score_ratio =
+ options_->backoff_options()->subject_text_score_ratio();
+
+ // Weighted merge: selection scores weighted by subject_text_score_ratio,
+ // context-window scores by (1 - ratio) * penalize_ratio.
+ std::map<std::string, float> result_map;
+ for (const auto& [language, score] : lang_id_result) {
+ result_map[language] = subject_text_score_ratio * score;
+ }
+ for (const auto& [language, score] : more_lang_id_results) {
+ result_map[language] += more_text_score_ratio * score * penalize_ratio;
+ }
+
+ std::vector<TranslateAnnotator::LanguageConfidence> result;
+ result.reserve(result_map.size());
+ for (const auto& [key, value] : result_map) {
+ result.push_back({key, value});
+ }
+
+ // Sort by confidence, descending.
+ std::sort(result.begin(), result.end(),
+ [](TranslateAnnotator::LanguageConfidence& a,
+ TranslateAnnotator::LanguageConfidence& b) {
+ return a.confidence > b.confidence;
+ });
+ return result;
+}
+
+// Scans from |start_index| in the given |direction| (+1 forward, -1 backward)
+// and returns the iterator of the first whitespace or punctuation codepoint,
+// or the text boundary if none is found.
+UnicodeText::const_iterator
+TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation(
+ const UnicodeText& text, int start_index, int direction) const {
+ TC3_CHECK(direction == 1 || direction == -1);
+ auto it = text.begin();
+ std::advance(it, start_index);
+ // NOTE(review): because of `it > text.begin()`, the loop body never executes
+ // when start_index == 0 — even for a forward scan. Presumably benign for the
+ // backward use (stop at string start), but confirm for direction == 1.
+ while (it > text.begin() && it < text.end()) {
+ if (unilib_->IsWhitespace(*it) || unilib_->IsPunctuation(*it)) {
+ break;
+ }
+ std::advance(it, direction);
+ }
+ return it;
+}
+
+// Returns a non-copying view of |text| centered on |indices| and at least
+// |minimum_length| codepoints long (when the text allows), expanded outward
+// to whitespace/punctuation boundaries so words are not split.
+UnicodeText TranslateAnnotator::TokenAlignedSubstringAroundSpan(
+ const UnicodeText& text, CodepointSpan indices, int minimum_length) const {
+ const int text_size_codepoints = text.size_codepoints();
+ // Whole text is shorter than the requested window: return it all.
+ if (text_size_codepoints < minimum_length) {
+ return UnicodeText(text, /*do_copy=*/false);
+ }
+
+ const int start = indices.first;
+ const int end = indices.second;
+ const int length = end - start;
+ // The span itself is already long enough: no widening needed.
+ if (length >= minimum_length) {
+ return UnicodeText::Substring(text, start, end, /*do_copy=*/false);
+ }
+
+ // Center the window on the span, clamped to the text boundaries.
+ const int offset = (minimum_length - length) / 2;
+ const int iter_start = std::max(
+ 0, std::min(start - offset, text_size_codepoints - minimum_length));
+ const int iter_end =
+ std::min(text_size_codepoints, iter_start + minimum_length);
+
+ // Snap both window edges to the nearest token boundary.
+ auto it_start = FindIndexOfNextWhitespaceOrPunctuation(text, iter_start, -1);
+ const auto it_end = FindIndexOfNextWhitespaceOrPunctuation(text, iter_end, 1);
+
+ // The it_start now points to whitespace/punctuation (unless it reached the
+ // beginning of the string). So we'll move it one position forward to point to
+ // the actual text.
+ if (it_start != it_end && unilib_->IsWhitespace(*it_start)) {
+ std::advance(it_start, 1);
+ }
+
+ return UnicodeText::Substring(it_start, it_end, /*do_copy=*/false);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/translate/translate.h b/native/annotator/translate/translate.h
new file mode 100644
index 0000000..97e994d
--- /dev/null
+++ b/native/annotator/translate/translate.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TRANSLATE_TRANSLATE_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_TRANSLATE_TRANSLATE_H_
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "lang_id/lang-id.h"
+
+namespace libtextclassifier3 {
+
+// Returns classification with "translate" when the input text is in a language
+// not understood by the user.
+class TranslateAnnotator {
+ public:
+ // All pointers are supplied by the caller; presumably not owned by this
+ // class (lifetime must outlive the annotator) — confirm at the call site.
+ TranslateAnnotator(const TranslateAnnotatorOptions* options,
+ const libtextclassifier3::mobile::lang_id::LangId* langid_model,
+ const UniLib* unilib)
+ : options_(options), langid_model_(langid_model), unilib_(unilib) {}
+
+ // Returns true if a classification_result was filled with "translate"
+ // classification.
+ bool ClassifyText(const UnicodeText& context, CodepointSpan selection_indices,
+ const std::string& user_familiar_language_tags,
+ ClassificationResult* classification_result) const;
+
+ protected:
+ // A detected language tag together with its detection confidence.
+ struct LanguageConfidence {
+ std::string language;
+ float confidence = -1.0;
+ };
+
+ // Detects language of the selection in given context using the "Backoff
+ // algorithm", sorted by score in descending order. It is based on several
+ // heuristics, see the code. This is the same algorithm that TextClassifier
+ // uses in Android Q.
+ std::vector<LanguageConfidence> BackoffDetectLanguages(
+ const UnicodeText& context, CodepointSpan selection_indices) const;
+
+ // Returns the iterator of the next whitespace/punctuation character in given
+ // text, starting from given position and going forward (iff direction == 1),
+ // and backward (iff direction == -1).
+ UnicodeText::const_iterator FindIndexOfNextWhitespaceOrPunctuation(
+ const UnicodeText& text, int start_index, int direction) const;
+
+ // Returns substring from given text, centered around the specified indices,
+ // of certain minimum length. The substring is token aligned, so it is
+ // guaranteed that the words won't be broken down.
+ UnicodeText TokenAlignedSubstringAroundSpan(const UnicodeText& text,
+ CodepointSpan indices,
+ int minimum_length) const;
+
+ private:
+ // Serializes the detected-language predictions into EntityData bytes.
+ std::string CreateSerializedEntityData(
+ const std::vector<TranslateAnnotator::LanguageConfidence>& confidences)
+ const;
+
+ const TranslateAnnotatorOptions* options_;
+ const libtextclassifier3::mobile::lang_id::LangId* langid_model_;
+ const UniLib* unilib_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TRANSLATE_TRANSLATE_H_
diff --git a/native/annotator/types-test-util.h b/native/annotator/types-test-util.h
new file mode 100644
index 0000000..1d018a1
--- /dev/null
+++ b/native/annotator/types-test-util.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_
+
+#include <ostream>
+
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+// Defines a std::ostream operator<< for TYPE_NAME by routing the value through
+// the project's LoggingStringStream operator<<, so tests can print annotator
+// types with standard streams.
+#define TC3_DECLARE_PRINT_OPERATOR(TYPE_NAME) \
+ inline std::ostream& operator<<(std::ostream& stream, \
+ const TYPE_NAME& value) { \
+ logging::LoggingStringStream tmp_stream; \
+ tmp_stream << value; \
+ return stream << tmp_stream.message; \
+ }
+
+TC3_DECLARE_PRINT_OPERATOR(AnnotatedSpan)
+TC3_DECLARE_PRINT_OPERATOR(ClassificationResult)
+TC3_DECLARE_PRINT_OPERATOR(DatetimeParsedData)
+TC3_DECLARE_PRINT_OPERATOR(DatetimeParseResultSpan)
+TC3_DECLARE_PRINT_OPERATOR(Token)
+
+// The macro is only needed for the declarations above.
+#undef TC3_DECLARE_PRINT_OPERATOR
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_TEST_UTIL_H_
diff --git a/native/annotator/types.cc b/native/annotator/types.cc
new file mode 100644
index 0000000..be542d3
--- /dev/null
+++ b/native/annotator/types.cc
@@ -0,0 +1,419 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/types.h"
+
+#include <vector>
+
+#include "utils/optional.h"
+
+namespace libtextclassifier3 {
+
+// Debug-prints a Token as Token("value", start, end), or Token() for padding.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Token& token) {
+ if (!token.is_padding) {
+ return stream << "Token(\"" << token.value << "\", " << token.start << ", "
+ << token.end << ")";
+ } else {
+ return stream << "Token()";
+ }
+}
+
+// Returns true iff the datetime value for this component should be rounded to
+// its granularity, i.e. for relative qualifiers that name a point in time
+// (NEXT/TOMORROW/...) but not for distance-style ones (PAST/FUTURE).
+bool DatetimeComponent::ShouldRoundToGranularity() const {
+ // Don't round to the granularity for relative expressions that specify the
+ // distance. So that, e.g. "in 2 hours" when it's 8:35:03 will result in
+ // 10:35:03.
+ if (relative_qualifier == RelativeQualifier::UNSPECIFIED) {
+ return false;
+ }
+ if (relative_qualifier == RelativeQualifier::NEXT ||
+ relative_qualifier == RelativeQualifier::TOMORROW ||
+ relative_qualifier == RelativeQualifier::YESTERDAY ||
+ relative_qualifier == RelativeQualifier::LAST ||
+ relative_qualifier == RelativeQualifier::THIS ||
+ relative_qualifier == RelativeQualifier::NOW) {
+ return true;
+ }
+ return false;
+}
+
+namespace {
+// Formats a UTC-millis timestamp as a human-readable local-time string, for
+// debug output only.
+// NOTE(review): uses non-reentrant localtime() (not localtime_r) and a `long`
+// for seconds — fine for logging, but not thread-safe; confirm this is only
+// called from debug/stream paths.
+std::string FormatMillis(int64 time_ms_utc) {
+ long time_seconds = time_ms_utc / 1000; // NOLINT
+ char buffer[512];
+ strftime(buffer, sizeof(buffer), "%a %Y-%m-%d %H:%M:%S %Z",
+ localtime(&time_seconds));
+ return std::string(buffer);
+}
+} // namespace
+
+// Returns the enum name of a ComponentType, or "" for unknown values.
+std::string ComponentTypeToString(
+ const DatetimeComponent::ComponentType& component_type) {
+ switch (component_type) {
+ case DatetimeComponent::ComponentType::UNSPECIFIED:
+ return "UNSPECIFIED";
+ case DatetimeComponent::ComponentType::YEAR:
+ return "YEAR";
+ case DatetimeComponent::ComponentType::MONTH:
+ return "MONTH";
+ case DatetimeComponent::ComponentType::WEEK:
+ return "WEEK";
+ case DatetimeComponent::ComponentType::DAY_OF_WEEK:
+ return "DAY_OF_WEEK";
+ case DatetimeComponent::ComponentType::DAY_OF_MONTH:
+ return "DAY_OF_MONTH";
+ case DatetimeComponent::ComponentType::HOUR:
+ return "HOUR";
+ case DatetimeComponent::ComponentType::MINUTE:
+ return "MINUTE";
+ case DatetimeComponent::ComponentType::SECOND:
+ return "SECOND";
+ case DatetimeComponent::ComponentType::MERIDIEM:
+ return "MERIDIEM";
+ case DatetimeComponent::ComponentType::ZONE_OFFSET:
+ return "ZONE_OFFSET";
+ case DatetimeComponent::ComponentType::DST_OFFSET:
+ return "DST_OFFSET";
+ default:
+ return "";
+ }
+}
+
+// Returns the enum name of a RelativeQualifier, or "" for unknown values.
+std::string RelativeQualifierToString(
+ const DatetimeComponent::RelativeQualifier& relative_qualifier) {
+ switch (relative_qualifier) {
+ case DatetimeComponent::RelativeQualifier::UNSPECIFIED:
+ return "UNSPECIFIED";
+ case DatetimeComponent::RelativeQualifier::NEXT:
+ return "NEXT";
+ case DatetimeComponent::RelativeQualifier::THIS:
+ return "THIS";
+ case DatetimeComponent::RelativeQualifier::LAST:
+ return "LAST";
+ case DatetimeComponent::RelativeQualifier::NOW:
+ return "NOW";
+ case DatetimeComponent::RelativeQualifier::TOMORROW:
+ return "TOMORROW";
+ case DatetimeComponent::RelativeQualifier::YESTERDAY:
+ return "YESTERDAY";
+ case DatetimeComponent::RelativeQualifier::PAST:
+ return "PAST";
+ case DatetimeComponent::RelativeQualifier::FUTURE:
+ return "FUTURE";
+ default:
+ return "";
+ }
+}
+
+// Debug-prints a DatetimeParseResultSpan, including each parse result's
+// timestamp (with a human-readable rendering) and datetime components.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const DatetimeParseResultSpan& value) {
+ stream << "DatetimeParseResultSpan({" << value.span.first << ", "
+ << value.span.second << "}, "
+ << "/*target_classification_score=*/ "
+ << value.target_classification_score << "/*priority_score=*/"
+ << value.priority_score << " {";
+ for (const DatetimeParseResult& data : value.data) {
+ stream << "{/*time_ms_utc=*/ " << data.time_ms_utc << " /* "
+ << FormatMillis(data.time_ms_utc) << " */, /*granularity=*/ "
+ << data.granularity << ", /*datetime_components=*/ ";
+ for (const DatetimeComponent& datetime_comp : data.datetime_components) {
+ stream << "{/*component_type=*/ "
+ << ComponentTypeToString(datetime_comp.component_type)
+ << " /*relative_qualifier=*/ "
+ << RelativeQualifierToString(datetime_comp.relative_qualifier)
+ << " /*value=*/ " << datetime_comp.value << " /*relative_count=*/ "
+ << datetime_comp.relative_count << "}, ";
+ }
+ stream << "}, ";
+ }
+ stream << "})";
+ return stream;
+}
+
+// Full equality: field-wise equality plus scores within a 0.001 tolerance and
+// byte-equal serialized entity data.
+bool ClassificationResult::operator==(const ClassificationResult& other) const {
+ return ClassificationResultsEqualIgnoringScoresAndSerializedEntityData(
+ *this, other) &&
+ fabs(score - other.score) < 0.001 &&
+ fabs(priority_score - other.priority_score) < 0.001 &&
+ serialized_entity_data == other.serialized_entity_data;
+}
+
+// Compares all non-score, non-serialized-entity-data fields; doubles are
+// compared with a 0.001 tolerance.
+bool ClassificationResultsEqualIgnoringScoresAndSerializedEntityData(
+ const ClassificationResult& a, const ClassificationResult& b) {
+ return a.collection == b.collection &&
+ a.datetime_parse_result == b.datetime_parse_result &&
+ a.serialized_knowledge_result == b.serialized_knowledge_result &&
+ a.contact_pointer == b.contact_pointer &&
+ a.contact_name == b.contact_name &&
+ a.contact_given_name == b.contact_given_name &&
+ a.contact_family_name == b.contact_family_name &&
+ a.contact_nickname == b.contact_nickname &&
+ a.contact_email_address == b.contact_email_address &&
+ a.contact_phone_number == b.contact_phone_number &&
+ a.contact_id == b.contact_id &&
+ a.app_package_name == b.app_package_name &&
+ a.numeric_value == b.numeric_value &&
+ fabs(a.numeric_double_value - b.numeric_double_value) < 0.001 &&
+ a.duration_ms == b.duration_ms;
+}
+
+// Debug-prints a ClassificationResult (collection and scores only).
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const ClassificationResult& result) {
+ return stream << "ClassificationResult(" << result.collection
+ << ", /*score=*/ " << result.score << ", /*priority_score=*/ "
+ << result.priority_score << ")";
+}
+
+// Debug-prints a list of ClassificationResults, one per line.
+logging::LoggingStringStream& operator<<(
+ logging::LoggingStringStream& stream,
+ const std::vector<ClassificationResult>& results) {
+ stream = stream << "{\n";
+ for (const ClassificationResult& result : results) {
+ stream = stream << " " << result << "\n";
+ }
+ stream = stream << "}";
+ return stream;
+}
+
+// Debug-prints an AnnotatedSpan using its best (first) classification, or an
+// empty class / score -1 if it has none.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const AnnotatedSpan& span) {
+ std::string best_class;
+ float best_score = -1;
+ if (!span.classification.empty()) {
+ best_class = span.classification[0].collection;
+ best_score = span.classification[0].score;
+ }
+ return stream << "Span(" << span.span.first << ", " << span.span.second
+ << ", " << best_class << ", " << best_score << ")";
+}
+
+// Debug-prints all datetime components held by a DatetimeParsedData.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const DatetimeParsedData& data) {
+ std::vector<DatetimeComponent> date_time_components;
+ data.GetDatetimeComponents(&date_time_components);
+ stream = stream << "DatetimeParsedData { \n";
+ for (const DatetimeComponent& c : date_time_components) {
+ stream = stream << " DatetimeComponent { \n";
+ stream = stream << " Component Type:" << static_cast<int>(c.component_type)
+ << "\n";
+ stream = stream << " Value:" << c.value << "\n";
+ stream = stream << " Relative Qualifier:"
+ << static_cast<int>(c.relative_qualifier) << "\n";
+ stream = stream << " Relative Count:" << c.relative_count << "\n";
+ stream = stream << " } \n";
+ }
+ stream = stream << "}";
+ return stream;
+}
+
+// Sets the absolute value of the component for |field_type|, creating the
+// component if it does not yet exist.
+void DatetimeParsedData::SetAbsoluteValue(
+ const DatetimeComponent::ComponentType& field_type, int value) {
+ GetOrCreateDatetimeComponent(field_type).value = value;
+}
+
+// Sets the relative qualifier of the component for |field_type|, creating the
+// component if needed.
+void DatetimeParsedData::SetRelativeValue(
+ const DatetimeComponent::ComponentType& field_type,
+ const DatetimeComponent::RelativeQualifier& relative_value) {
+ GetOrCreateDatetimeComponent(field_type).relative_qualifier = relative_value;
+}
+
+// Sets the relative count of the component for |field_type|, creating the
+// component if needed.
+void DatetimeParsedData::SetRelativeCount(
+ const DatetimeComponent::ComponentType& field_type, int relative_count) {
+ GetOrCreateDatetimeComponent(field_type).relative_count = relative_count;
+}
+
+// Inserts the given components keyed by their component type.
+// NOTE: std::map::insert keeps the existing entry when a component of the same
+// type is already present — incoming duplicates are silently dropped.
+void DatetimeParsedData::AddDatetimeComponents(
+ const std::vector<DatetimeComponent>& datetime_components) {
+ for (const DatetimeComponent& datetime_component : datetime_components) {
+ date_time_components_.insert(
+ {datetime_component.component_type, datetime_component});
+ }
+}
+
+// Returns true iff a component for |field_type| exists.
+bool DatetimeParsedData::HasFieldType(
+ const DatetimeComponent::ComponentType& field_type) const {
+ if (date_time_components_.find(field_type) == date_time_components_.end()) {
+ return false;
+ }
+ return true;
+}
+
+// Copies the component's absolute value into |field_value|; returns false if
+// the component is absent (|field_value| is left untouched).
+bool DatetimeParsedData::GetFieldValue(
+ const DatetimeComponent::ComponentType& field_type,
+ int* field_value) const {
+ if (HasFieldType(field_type)) {
+ *field_value = date_time_components_.at(field_type).value;
+ return true;
+ }
+ return false;
+}
+
+// Copies the component's relative qualifier into |relative_value|; returns
+// false if the component is absent.
+bool DatetimeParsedData::GetRelativeValue(
+ const DatetimeComponent::ComponentType& field_type,
+ DatetimeComponent::RelativeQualifier* relative_value) const {
+ if (HasFieldType(field_type)) {
+ *relative_value = date_time_components_.at(field_type).relative_qualifier;
+ return true;
+ }
+ return false;
+}
+
+// Returns true iff the component exists AND has a non-UNSPECIFIED relative
+// qualifier.
+bool DatetimeParsedData::HasRelativeValue(
+ const DatetimeComponent::ComponentType& field_type) const {
+ if (HasFieldType(field_type)) {
+ return date_time_components_.at(field_type).relative_qualifier !=
+ DatetimeComponent::RelativeQualifier::UNSPECIFIED;
+ }
+ return false;
+}
+
+// Returns true iff the component exists and is purely absolute (no relative
+// qualifier).
+bool DatetimeParsedData::HasAbsoluteValue(
+ const DatetimeComponent::ComponentType& field_type) const {
+ return HasFieldType(field_type) && !HasRelativeValue(field_type);
+}
+
+// Returns true iff no components have been set.
+bool DatetimeParsedData::IsEmpty() const {
+ return date_time_components_.empty();
+}
+
+// Appends only the components with a non-UNSPECIFIED relative qualifier.
+void DatetimeParsedData::GetRelativeDatetimeComponents(
+ std::vector<DatetimeComponent>* date_time_components) const {
+ for (auto it = date_time_components_.begin();
+ it != date_time_components_.end(); it++) {
+ if (it->second.relative_qualifier !=
+ DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
+ date_time_components->push_back(it->second);
+ }
+ }
+}
+
+// Appends all components, in map (component-type) order.
+void DatetimeParsedData::GetDatetimeComponents(
+ std::vector<DatetimeComponent>* date_time_components) const {
+ for (auto it = date_time_components_.begin();
+ it != date_time_components_.end(); it++) {
+ date_time_components->push_back(it->second);
+ }
+}
+
+// Returns the component for |component_type|, default-constructing (value 0,
+// UNSPECIFIED qualifier, count 0) and inserting it if absent.
+DatetimeComponent& DatetimeParsedData::GetOrCreateDatetimeComponent(
+ const DatetimeComponent::ComponentType& component_type) {
+ auto result =
+ date_time_components_
+ .insert(
+ {component_type,
+ DatetimeComponent(
+ component_type,
+ DatetimeComponent::RelativeQualifier::UNSPECIFIED, 0, 0)})
+ .first;
+ return result->second;
+}
+
+namespace {
+// Maps a set of component types to the finest (most precise) matching
+// granularity; MERIDIEM/ZONE_OFFSET/DST_OFFSET do not affect the result.
+DatetimeGranularity GetFinestGranularityFromComponentTypes(
+ const std::vector<DatetimeComponent::ComponentType>&
+ datetime_component_types) {
+ DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_UNKNOWN;
+ for (const auto& component_type : datetime_component_types) {
+ switch (component_type) {
+ case DatetimeComponent::ComponentType::YEAR:
+ if (granularity < DatetimeGranularity::GRANULARITY_YEAR) {
+ granularity = DatetimeGranularity::GRANULARITY_YEAR;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::MONTH:
+ if (granularity < DatetimeGranularity::GRANULARITY_MONTH) {
+ granularity = DatetimeGranularity::GRANULARITY_MONTH;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::WEEK:
+ if (granularity < DatetimeGranularity::GRANULARITY_WEEK) {
+ granularity = DatetimeGranularity::GRANULARITY_WEEK;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::DAY_OF_WEEK:
+ case DatetimeComponent::ComponentType::DAY_OF_MONTH:
+ if (granularity < DatetimeGranularity::GRANULARITY_DAY) {
+ granularity = DatetimeGranularity::GRANULARITY_DAY;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::HOUR:
+ if (granularity < DatetimeGranularity::GRANULARITY_HOUR) {
+ granularity = DatetimeGranularity::GRANULARITY_HOUR;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::MINUTE:
+ if (granularity < DatetimeGranularity::GRANULARITY_MINUTE) {
+ granularity = DatetimeGranularity::GRANULARITY_MINUTE;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::SECOND:
+ if (granularity < DatetimeGranularity::GRANULARITY_SECOND) {
+ granularity = DatetimeGranularity::GRANULARITY_SECOND;
+ }
+ break;
+
+ case DatetimeComponent::ComponentType::MERIDIEM:
+ case DatetimeComponent::ComponentType::ZONE_OFFSET:
+ case DatetimeComponent::ComponentType::DST_OFFSET:
+ default:
+ break;
+ }
+ }
+ return granularity;
+}
+} // namespace
+
+// Returns the finest granularity among this instance's component types.
+DatetimeGranularity DatetimeParsedData::GetFinestGranularity() const {
+ std::vector<DatetimeComponent::ComponentType> component_types;
+ std::transform(date_time_components_.begin(), date_time_components_.end(),
+ std::back_inserter(component_types),
+ [](const std::map<DatetimeComponent::ComponentType,
+ DatetimeComponent>::value_type& pair) {
+ return pair.first;
+ });
+ return GetFinestGranularityFromComponentTypes(component_types);
+}
+
+// Returns the first component matching |component_type|, or an empty Optional
+// if none is present.
+// NOTE(review): the loop copies each DatetimeComponent by value; a const&
+// would avoid the copies.
+Optional<DatetimeComponent> GetDatetimeComponent(
+ const std::vector<DatetimeComponent>& datetime_components,
+ const DatetimeComponent::ComponentType& component_type) {
+ for (auto datetime_component : datetime_components) {
+ if (datetime_component.component_type == component_type) {
+ return Optional<DatetimeComponent>(datetime_component);
+ }
+ }
+ return Optional<DatetimeComponent>();
+}
+
+// Returns the granularity of the DatetimeComponents.
+DatetimeGranularity GetFinestGranularity(
+ const std::vector<DatetimeComponent>& datetime_component) {
+ std::vector<DatetimeComponent::ComponentType> component_types;
+ std::transform(datetime_component.begin(), datetime_component.end(),
+ std::back_inserter(component_types),
+ [](const DatetimeComponent& component) {
+ return component.component_type;
+ });
+ return GetFinestGranularityFromComponentTypes(component_types);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/types.h b/native/annotator/types.h
new file mode 100644
index 0000000..665d4b6
--- /dev/null
+++ b/native/annotator/types.h
@@ -0,0 +1,692 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
+
+#include <time.h>
+
+#include <algorithm>
+#include <cmath>
+#include <functional>
+#include <map>
+#include <set>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "annotator/entity-data_generated.h"
+#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+#include "utils/flatbuffers.h"
+#include "utils/optional.h"
+#include "utils/variant.h"
+
+namespace libtextclassifier3 {
+
+constexpr int kInvalidIndex = -1;
+constexpr int kSunday = 1;
+constexpr int kMonday = 2;
+constexpr int kTuesday = 3;
+constexpr int kWednesday = 4;
+constexpr int kThursday = 5;
+constexpr int kFriday = 6;
+constexpr int kSaturday = 7;
+
+// Index for a 0-based array of tokens.
+using TokenIndex = int;
+
+// Index for a 0-based array of codepoints.
+using CodepointIndex = int;
+
+// Marks a span in a sequence of codepoints. The first element is the index of
+// the first codepoint of the span, and the second element is the index of the
+// codepoint one past the end of the span.
+// TODO(b/71982294): Make it a struct.
+using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
+
+inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
+ return a.first < b.second && b.first < a.second;
+}
+
+inline bool ValidNonEmptySpan(const CodepointSpan& span) {
+ return span.first < span.second && span.first >= 0 && span.second >= 0;
+}
+
+template <typename T>
+bool DoesCandidateConflict(
+ const int considered_candidate, const std::vector<T>& candidates,
+ const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
+ if (chosen_indices_set.empty()) {
+ return false;
+ }
+
+ auto conflicting_it = chosen_indices_set.lower_bound(considered_candidate);
+ // Check conflict on the right.
+ if (conflicting_it != chosen_indices_set.end() &&
+ SpansOverlap(candidates[considered_candidate].span,
+ candidates[*conflicting_it].span)) {
+ return true;
+ }
+
+ // Check conflict on the left.
+ // If we can't go more left, there can't be a conflict:
+ if (conflicting_it == chosen_indices_set.begin()) {
+ return false;
+ }
+ // Otherwise move one span left and insert if it doesn't overlap with the
+ // candidate.
+ --conflicting_it;
+ if (!SpansOverlap(candidates[considered_candidate].span,
+ candidates[*conflicting_it].span)) {
+ return false;
+ }
+
+ return true;
+}
+
+// Marks a span in a sequence of tokens. The first element is the index of the
+// first token in the span, and the second element is the index of the token one
+// past the end of the span.
+// TODO(b/71982294): Make it a struct.
+using TokenSpan = std::pair<TokenIndex, TokenIndex>;
+
+// Returns the size of the token span. Assumes that the span is valid.
+inline int TokenSpanSize(const TokenSpan& token_span) {
+ return token_span.second - token_span.first;
+}
+
+// Returns a token span consisting of one token.
+inline TokenSpan SingleTokenSpan(int token_index) {
+ return {token_index, token_index + 1};
+}
+
+// Returns an intersection of two token spans. Assumes that both spans are valid
+// and overlapping.
+inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
+ const TokenSpan& token_span2) {
+ return {std::max(token_span1.first, token_span2.first),
+ std::min(token_span1.second, token_span2.second)};
+}
+
+// Returns an expanded token span by adding a certain number of tokens on its
+// left and on its right.
+inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span,
+ int num_tokens_left, int num_tokens_right) {
+ return {token_span.first - num_tokens_left,
+ token_span.second + num_tokens_right};
+}
+
+// Token holds a token, its position in the original string and whether it was
+// part of the input span.
+struct Token {
+ std::string value;
+ CodepointIndex start;
+ CodepointIndex end;
+
+ // Whether the token is a padding token.
+ bool is_padding;
+
+  // Whether the token contains only whitespace characters.
+ bool is_whitespace;
+
+ // Default constructor constructs the padding-token.
+ Token()
+ : Token(/*arg_value=*/"", /*arg_start=*/kInvalidIndex,
+ /*arg_end=*/kInvalidIndex, /*is_padding=*/true,
+ /*is_whitespace=*/false) {}
+
+ Token(const std::string& arg_value, CodepointIndex arg_start,
+ CodepointIndex arg_end)
+ : Token(/*arg_value=*/arg_value, /*arg_start=*/arg_start,
+ /*arg_end=*/arg_end, /*is_padding=*/false,
+ /*is_whitespace=*/false) {}
+
+ Token(const std::string& arg_value, CodepointIndex arg_start,
+ CodepointIndex arg_end, bool is_padding, bool is_whitespace)
+ : value(arg_value),
+ start(arg_start),
+ end(arg_end),
+ is_padding(is_padding),
+ is_whitespace(is_whitespace) {}
+
+ bool operator==(const Token& other) const {
+ return value == other.value && start == other.start && end == other.end &&
+ is_padding == other.is_padding;
+ }
+
+ bool IsContainedInSpan(CodepointSpan span) const {
+ return start >= span.first && end <= span.second;
+ }
+};
+
+// Pretty-printing function for Token.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Token& token);
+
+enum DatetimeGranularity {
+ GRANULARITY_UNKNOWN = -1, // GRANULARITY_UNKNOWN is used as a proxy for this
+ // structure being uninitialized.
+ GRANULARITY_YEAR = 0,
+ GRANULARITY_MONTH = 1,
+ GRANULARITY_WEEK = 2,
+ GRANULARITY_DAY = 3,
+ GRANULARITY_HOUR = 4,
+ GRANULARITY_MINUTE = 5,
+ GRANULARITY_SECOND = 6
+};
+
+// This struct represents a unit of date and time expression.
+// Examples include:
+// - In {March 21, 2019} datetime components are month: {March},
+// day of month: {21} and year: {2019}.
+// - {8:00 am} contains hour: {8}, minutes: {0} and am/pm: {am}
+struct DatetimeComponent {
+ enum class ComponentType {
+ UNSPECIFIED = 0,
+ // Year of the date seen in the text match.
+ YEAR = 1,
+ // Month of the year starting with January = 1.
+ MONTH = 2,
+ // Week (7 days).
+ WEEK = 3,
+ // Day of week, start of the week is Sunday & its value is 1.
+ DAY_OF_WEEK = 4,
+ // Day of the month starting with 1.
+ DAY_OF_MONTH = 5,
+ // Hour of the day with a range of 0-23,
+ // values less than 12 need the AMPM field below or heuristics
+ // to definitively determine the time.
+ HOUR = 6,
+ // Minute of the hour with a range of 0-59.
+ MINUTE = 7,
+ // Seconds of the minute with a range of 0-59.
+ SECOND = 8,
+ // Meridiem field where 0 == AM, 1 == PM.
+ MERIDIEM = 9,
+ // Number of hours offset from UTC this date time is in.
+ ZONE_OFFSET = 10,
+    // Number of hours offset for DST.
+ DST_OFFSET = 11,
+ };
+
+ // TODO(hassan): Remove RelativeQualifier as in the presence of relative
+ // count RelativeQualifier is redundant.
+ // Enum to represent the relative DateTimeComponent e.g. "next Monday",
+ // "the following day", "tomorrow".
+ enum class RelativeQualifier {
+ UNSPECIFIED = 0,
+ NEXT = 1,
+ THIS = 2,
+ LAST = 3,
+ NOW = 4,
+ TOMORROW = 5,
+ YESTERDAY = 6,
+ PAST = 7,
+ FUTURE = 8
+ };
+
+ bool operator==(const DatetimeComponent& other) const {
+ return component_type == other.component_type &&
+ relative_qualifier == other.relative_qualifier &&
+ relative_count == other.relative_count && value == other.value;
+ }
+
+ bool ShouldRoundToGranularity() const;
+
+ ComponentType component_type = ComponentType::UNSPECIFIED;
+ RelativeQualifier relative_qualifier = RelativeQualifier::UNSPECIFIED;
+
+ // Represents the absolute value of DateTime components.
+ int value = 0;
+ // The number of units of change present in the relative DateTimeComponent.
+ int relative_count = 0;
+
+ DatetimeComponent() = default;
+
+ explicit DatetimeComponent(ComponentType arg_component_type,
+ RelativeQualifier arg_relative_qualifier,
+ int arg_value, int arg_relative_count)
+ : component_type(arg_component_type),
+ relative_qualifier(arg_relative_qualifier),
+ value(arg_value),
+ relative_count(arg_relative_count) {}
+};
+
+// Utility method that returns the finest granularity of the given
+// DatetimeComponents.
+DatetimeGranularity GetFinestGranularity(
+ const std::vector<DatetimeComponent>& datetime_component);
+
+// Returns the 'DatetimeComponent' from the collection, filtered by component type.
+Optional<DatetimeComponent> GetDatetimeComponent(
+ const std::vector<DatetimeComponent>& datetime_components,
+ const DatetimeComponent::ComponentType& component_type);
+
+struct DatetimeParseResult {
+ // The absolute time in milliseconds since the epoch in UTC.
+ int64 time_ms_utc;
+
+  // The precision of the estimate used in calculating the milliseconds.
+ DatetimeGranularity granularity;
+
+ // List of parsed DateTimeComponent.
+ std::vector<DatetimeComponent> datetime_components;
+
+ DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {}
+
+ DatetimeParseResult(int64 arg_time_ms_utc,
+ DatetimeGranularity arg_granularity,
+ std::vector<DatetimeComponent> arg_datetime__components)
+ : time_ms_utc(arg_time_ms_utc),
+ granularity(arg_granularity),
+ datetime_components(arg_datetime__components) {}
+
+ bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; }
+
+ bool operator==(const DatetimeParseResult& other) const {
+ return granularity == other.granularity &&
+ time_ms_utc == other.time_ms_utc &&
+ datetime_components == other.datetime_components;
+ }
+};
+
+const float kFloatCompareEpsilon = 1e-5;
+
+struct DatetimeParseResultSpan {
+ CodepointSpan span;
+ std::vector<DatetimeParseResult> data;
+ float target_classification_score;
+ float priority_score;
+
+ DatetimeParseResultSpan()
+ : target_classification_score(-1.0), priority_score(-1.0) {}
+
+ DatetimeParseResultSpan(const CodepointSpan& span,
+ const std::vector<DatetimeParseResult>& data,
+ const float target_classification_score,
+ const float priority_score) {
+ this->span = span;
+ this->data = data;
+ this->target_classification_score = target_classification_score;
+ this->priority_score = priority_score;
+ }
+
+ bool operator==(const DatetimeParseResultSpan& other) const {
+ return span == other.span && data == other.data &&
+ std::abs(target_classification_score -
+ other.target_classification_score) < kFloatCompareEpsilon &&
+ std::abs(priority_score - other.priority_score) <
+ kFloatCompareEpsilon;
+ }
+};
+
+// Pretty-printing function for DatetimeParseResultSpan.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const DatetimeParseResultSpan& value);
+
+// This struct contains information intended to uniquely identify a device
+// contact. Instances are created by the Knowledge Engine, and dereferenced by
+// the Contact Engine.
+struct ContactPointer {
+ std::string focus_contact_id;
+ std::string device_id;
+ std::string device_contact_id;
+ std::string contact_name;
+ std::string contact_name_hash;
+
+ bool operator==(const ContactPointer& other) const {
+ return focus_contact_id == other.focus_contact_id &&
+ device_id == other.device_id &&
+ device_contact_id == other.device_contact_id &&
+ contact_name == other.contact_name &&
+ contact_name_hash == other.contact_name_hash;
+ }
+};
+
+struct ClassificationResult {
+ std::string collection;
+ float score;
+ DatetimeParseResult datetime_parse_result;
+ std::string serialized_knowledge_result;
+ ContactPointer contact_pointer;
+ std::string contact_name, contact_given_name, contact_family_name,
+ contact_nickname, contact_email_address, contact_phone_number, contact_id;
+ std::string app_name, app_package_name;
+ int64 numeric_value;
+ double numeric_double_value;
+
+ // Length of the parsed duration in milliseconds.
+ int64 duration_ms;
+
+ // Internal score used for conflict resolution.
+ float priority_score;
+
+
+ // Entity data information.
+ std::string serialized_entity_data;
+ const EntityData* entity_data() const {
+ return LoadAndVerifyFlatbuffer<EntityData>(serialized_entity_data.data(),
+ serialized_entity_data.size());
+ }
+
+ explicit ClassificationResult()
+ : score(-1.0f),
+ numeric_value(0),
+ numeric_double_value(0.),
+ duration_ms(0),
+ priority_score(-1.0) {}
+
+ ClassificationResult(const std::string& arg_collection, float arg_score)
+ : collection(arg_collection),
+ score(arg_score),
+ numeric_value(0),
+ numeric_double_value(0.),
+ duration_ms(0),
+ priority_score(arg_score) {}
+
+ ClassificationResult(const std::string& arg_collection, float arg_score,
+ float arg_priority_score)
+ : collection(arg_collection),
+ score(arg_score),
+ numeric_value(0),
+ numeric_double_value(0.),
+ duration_ms(0),
+ priority_score(arg_priority_score) {}
+
+ bool operator!=(const ClassificationResult& other) const {
+ return !(*this == other);
+ }
+
+ bool operator==(const ClassificationResult& other) const;
+};
+
+// Aliases for long enum values.
+const AnnotationUsecase ANNOTATION_USECASE_SMART =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART;
+const AnnotationUsecase ANNOTATION_USECASE_RAW =
+ AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
+struct LocationContext {
+ // User location latitude in degrees.
+ double user_location_lat = 180.;
+
+ // User location longitude in degrees.
+ double user_location_lng = 360.;
+
+ // The estimated horizontal accuracy of the user location in meters.
+ // Analogous to android.location.Location accuracy.
+ float user_location_accuracy_meters = 0.f;
+
+ bool operator==(const LocationContext& other) const {
+ return std::fabs(this->user_location_lat - other.user_location_lat) <
+ 1e-8 &&
+ std::fabs(this->user_location_lng - other.user_location_lng) <
+ 1e-8 &&
+ std::fabs(this->user_location_accuracy_meters -
+ other.user_location_accuracy_meters) < 1e-8;
+ }
+};
+
+struct BaseOptions {
+ // Comma-separated list of locale specification for the input text (BCP 47
+ // tags).
+ std::string locales;
+
+ // Comma-separated list of BCP 47 language tags.
+ std::string detected_text_language_tags;
+
+ // Tailors the output annotations according to the specified use-case.
+ AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
+
+ // The location context passed along with each annotation.
+ Optional<LocationContext> location_context;
+
+ bool operator==(const BaseOptions& other) const {
+ bool location_context_equality = this->location_context.has_value() ==
+ other.location_context.has_value();
+ if (this->location_context.has_value() &&
+ other.location_context.has_value()) {
+ location_context_equality =
+ this->location_context.value() == other.location_context.value();
+ }
+ return this->locales == other.locales &&
+ this->annotation_usecase == other.annotation_usecase &&
+ this->detected_text_language_tags ==
+ other.detected_text_language_tags &&
+ location_context_equality;
+ }
+};
+
+struct DatetimeOptions {
+ // For parsing relative datetimes, the reference now time against which the
+ // relative datetimes get resolved.
+ // UTC milliseconds since epoch.
+ int64 reference_time_ms_utc = 0;
+
+ // Timezone in which the input text was written (format as accepted by ICU).
+ std::string reference_timezone;
+
+ bool operator==(const DatetimeOptions& other) const {
+ return this->reference_time_ms_utc == other.reference_time_ms_utc &&
+ this->reference_timezone == other.reference_timezone;
+ }
+};
+
+struct SelectionOptions : public BaseOptions {};
+
+struct ClassificationOptions : public BaseOptions, public DatetimeOptions {
+ // Comma-separated list of language tags which the user can read and
+ // understand (BCP 47).
+ std::string user_familiar_language_tags;
+
+ bool operator==(const ClassificationOptions& other) const {
+ return this->user_familiar_language_tags ==
+ other.user_familiar_language_tags &&
+ BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
+ }
+};
+
+struct Permissions {
+ // If true the user location can be used to provide better annotations.
+ bool has_location_permission = true;
+ // If true, annotators can use personal data to provide personalized
+ // annotations.
+ bool has_personalization_permission = true;
+
+ bool operator==(const Permissions& other) const {
+ return this->has_location_permission == other.has_location_permission &&
+ this->has_personalization_permission ==
+ other.has_personalization_permission;
+ }
+};
+
+struct AnnotationOptions : public BaseOptions, public DatetimeOptions {
+ // List of entity types that should be used for annotation.
+ std::unordered_set<std::string> entity_types;
+
+  // If true, serialized_entity_data in the results is populated.
+ bool is_serialized_entity_data_enabled = false;
+
+ // Defines the permissions for the annotators.
+ Permissions permissions;
+
+ bool operator==(const AnnotationOptions& other) const {
+ return this->is_serialized_entity_data_enabled ==
+ other.is_serialized_entity_data_enabled &&
+ this->permissions == other.permissions &&
+ this->entity_types == other.entity_types &&
+ BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
+ }
+};
+
+// Returns true when ClassificationResults are equal up to scores.
+bool ClassificationResultsEqualIgnoringScoresAndSerializedEntityData(
+ const ClassificationResult& a, const ClassificationResult& b);
+
+// Pretty-printing function for ClassificationResult.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const ClassificationResult& result);
+
+// Pretty-printing function for std::vector<ClassificationResult>.
+logging::LoggingStringStream& operator<<(
+ logging::LoggingStringStream& stream,
+ const std::vector<ClassificationResult>& results);
+
+// Represents a result of Annotate call.
+struct AnnotatedSpan {
+ enum class Source { OTHER, KNOWLEDGE, DURATION, DATETIME, PERSON_NAME };
+
+ // Unicode codepoint indices in the input string.
+ CodepointSpan span = {kInvalidIndex, kInvalidIndex};
+
+ // Classification result for the span.
+ std::vector<ClassificationResult> classification;
+
+ // The source of the annotation, used in conflict resolution.
+ Source source = Source::OTHER;
+
+ AnnotatedSpan() = default;
+
+ AnnotatedSpan(CodepointSpan arg_span,
+ std::vector<ClassificationResult> arg_classification)
+ : span(arg_span), classification(std::move(arg_classification)) {}
+
+ AnnotatedSpan(CodepointSpan arg_span,
+ std::vector<ClassificationResult> arg_classification,
+ Source arg_source)
+ : span(arg_span),
+ classification(std::move(arg_classification)),
+ source(arg_source) {}
+};
+
+struct InputFragment {
+ std::string text;
+
+ // If present will override the AnnotationOptions reference time and timezone
+ // when annotating this specific string fragment.
+ Optional<DatetimeOptions> datetime_options;
+};
+
+// Pretty-printing function for AnnotatedSpan.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const AnnotatedSpan& span);
+
+// StringPiece analogue for std::vector<T>.
+template <class T>
+class VectorSpan {
+ public:
+ VectorSpan() : begin_(), end_() {}
+ VectorSpan(const std::vector<T>& v) // NOLINT(runtime/explicit)
+ : begin_(v.begin()), end_(v.end()) {}
+ VectorSpan(typename std::vector<T>::const_iterator begin,
+ typename std::vector<T>::const_iterator end)
+ : begin_(begin), end_(end) {}
+
+ const T& operator[](typename std::vector<T>::size_type i) const {
+ return *(begin_ + i);
+ }
+
+ int size() const { return end_ - begin_; }
+ typename std::vector<T>::const_iterator begin() const { return begin_; }
+ typename std::vector<T>::const_iterator end() const { return end_; }
+ const float* data() const { return &(*begin_); }
+
+ private:
+ typename std::vector<T>::const_iterator begin_;
+ typename std::vector<T>::const_iterator end_;
+};
+
+// Class to provide representation of date and time expressions
+class DatetimeParsedData {
+ public:
+ // Function to set the absolute value of DateTimeComponent for the given
+ // FieldType, if the field is not present it will create the field and set
+ // the value.
+ void SetAbsoluteValue(const DatetimeComponent::ComponentType& field_type,
+ int value);
+
+ // Function to set the relative value of DateTimeComponent, if the field is
+ // not present the function will create the field and set the relative value.
+ void SetRelativeValue(
+ const DatetimeComponent::ComponentType& field_type,
+ const DatetimeComponent::RelativeQualifier& relative_value);
+
+ // Add collection of 'DatetimeComponent' to 'DatetimeParsedData'.
+ void AddDatetimeComponents(
+ const std::vector<DatetimeComponent>& datetime_components);
+
+ // Function to set the relative count of DateTimeComponent, if the field is
+ // not present the function will create the field and set the count.
+ void SetRelativeCount(const DatetimeComponent::ComponentType& field_type,
+ int relative_count);
+
+ // Function to populate the absolute value of the FieldType and return true.
+ // In case of no FieldType function will return false.
+ bool GetFieldValue(const DatetimeComponent::ComponentType& field_type,
+ int* field_value) const;
+
+ // Function to populate the relative value of the FieldType and return true.
+ // In case of no relative value function will return false.
+ bool GetRelativeValue(
+ const DatetimeComponent::ComponentType& field_type,
+ DatetimeComponent::RelativeQualifier* relative_value) const;
+
+ // Returns relative DateTimeComponent from the parsed DateTime span.
+ void GetRelativeDatetimeComponents(
+ std::vector<DatetimeComponent>* date_time_components) const;
+
+ // Returns DateTimeComponent from the parsed DateTime span.
+ void GetDatetimeComponents(
+ std::vector<DatetimeComponent>* date_time_components) const;
+
+  // Represents the granularity of the parsed DateTime span. The function will
+  // return "GRANULARITY_UNKNOWN" if no datetime field is set.
+ DatetimeGranularity GetFinestGranularity() const;
+
+ // Utility function to check if DateTimeParsedData has FieldType initialized.
+ bool HasFieldType(const DatetimeComponent::ComponentType& field_type) const;
+
+ // Function to check if DateTimeParsedData has relative DateTimeComponent for
+ // given FieldType.
+ bool HasRelativeValue(
+ const DatetimeComponent::ComponentType& field_type) const;
+
+ // Function to check if DateTimeParsedData has absolute value
+ // DateTimeComponent for given FieldType.
+ bool HasAbsoluteValue(
+ const DatetimeComponent::ComponentType& field_type) const;
+
+ // Function to check if DateTimeParsedData has any DateTimeComponent.
+ bool IsEmpty() const;
+
+ private:
+ DatetimeComponent& GetOrCreateDatetimeComponent(
+
+ const DatetimeComponent::ComponentType& component_type);
+
+ std::map<DatetimeComponent::ComponentType, DatetimeComponent>
+ date_time_components_;
+};
+
+// Pretty-printing function for DateTimeParsedData.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const DatetimeParsedData& data);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
diff --git a/native/annotator/zlib-utils.cc b/native/annotator/zlib-utils.cc
new file mode 100644
index 0000000..c3c2cf1
--- /dev/null
+++ b/native/annotator/zlib-utils.cc
@@ -0,0 +1,150 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/zlib-utils.h"
+
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/intents/zlib-utils.h"
+#include "utils/resources.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Compress rule fields in the model.
+bool CompressModel(ModelT* model) {
+ std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance();
+ if (!zlib_compressor) {
+ TC3_LOG(ERROR) << "Cannot compress model.";
+ return false;
+ }
+
+ // Compress regex rules.
+ if (model->regex_model != nullptr) {
+ for (int i = 0; i < model->regex_model->patterns.size(); i++) {
+ RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get();
+ pattern->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(pattern->pattern,
+ pattern->compressed_pattern.get());
+ pattern->pattern.clear();
+ }
+ }
+
+ // Compress date-time rules.
+ if (model->datetime_model != nullptr) {
+ for (int i = 0; i < model->datetime_model->patterns.size(); i++) {
+ DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get();
+ for (int j = 0; j < pattern->regexes.size(); j++) {
+ DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get();
+ regex->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(regex->pattern,
+ regex->compressed_pattern.get());
+ regex->pattern.clear();
+ }
+ }
+ for (int i = 0; i < model->datetime_model->extractors.size(); i++) {
+ DatetimeModelExtractorT* extractor =
+ model->datetime_model->extractors[i].get();
+ extractor->compressed_pattern.reset(new CompressedBufferT);
+ zlib_compressor->Compress(extractor->pattern,
+ extractor->compressed_pattern.get());
+ extractor->pattern.clear();
+ }
+ }
+
+ // Compress resources.
+ if (model->resources != nullptr) {
+ CompressResources(model->resources.get());
+ }
+
+ // Compress intent generator.
+ if (model->intent_options != nullptr) {
+ CompressIntentModel(model->intent_options.get());
+ }
+
+ return true;
+}
+
+bool DecompressModel(ModelT* model) {
+ std::unique_ptr<ZlibDecompressor> zlib_decompressor =
+ ZlibDecompressor::Instance();
+ if (!zlib_decompressor) {
+ TC3_LOG(ERROR) << "Cannot initialize decompressor.";
+ return false;
+ }
+
+ // Decompress regex rules.
+ if (model->regex_model != nullptr) {
+ for (int i = 0; i < model->regex_model->patterns.size(); i++) {
+ RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get();
+ if (!zlib_decompressor->MaybeDecompress(pattern->compressed_pattern.get(),
+ &pattern->pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
+ return false;
+ }
+ pattern->compressed_pattern.reset(nullptr);
+ }
+ }
+
+ // Decompress date-time rules.
+ if (model->datetime_model != nullptr) {
+ for (int i = 0; i < model->datetime_model->patterns.size(); i++) {
+ DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get();
+ for (int j = 0; j < pattern->regexes.size(); j++) {
+ DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get();
+ if (!zlib_decompressor->MaybeDecompress(regex->compressed_pattern.get(),
+ ®ex->pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern: " << i << " " << j;
+ return false;
+ }
+ regex->compressed_pattern.reset(nullptr);
+ }
+ }
+ for (int i = 0; i < model->datetime_model->extractors.size(); i++) {
+ DatetimeModelExtractorT* extractor =
+ model->datetime_model->extractors[i].get();
+ if (!zlib_decompressor->MaybeDecompress(
+ extractor->compressed_pattern.get(), &extractor->pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
+ return false;
+ }
+ extractor->compressed_pattern.reset(nullptr);
+ }
+ }
+
+ if (model->resources != nullptr) {
+ DecompressResources(model->resources.get());
+ }
+
+ if (model->intent_options != nullptr) {
+ DecompressIntentModel(model->intent_options.get());
+ }
+
+ return true;
+}
+
+std::string CompressSerializedModel(const std::string& model) {
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str());
+ TC3_CHECK(unpacked_model != nullptr);
+ TC3_CHECK(CompressModel(unpacked_model.get()));
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/zlib-utils.h b/native/annotator/zlib-utils.h
similarity index 100%
rename from annotator/zlib-utils.h
rename to native/annotator/zlib-utils.h
diff --git a/native/annotator/zlib-utils_test.cc b/native/annotator/zlib-utils_test.cc
new file mode 100644
index 0000000..df33ea1
--- /dev/null
+++ b/native/annotator/zlib-utils_test.cc
@@ -0,0 +1,152 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/zlib-utils.h"
+
+#include <memory>
+
+#include "annotator/model_generated.h"
+#include "utils/zlib/zlib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+TEST(AnnotatorZlibUtilsTest, CompressModel) {
+ ModelT model;
+ model.regex_model.reset(new RegexModelT);
+ model.regex_model->patterns.emplace_back(new RegexModel_::PatternT);
+ model.regex_model->patterns.back()->pattern = "this is a test pattern";
+ model.regex_model->patterns.emplace_back(new RegexModel_::PatternT);
+ model.regex_model->patterns.back()->pattern = "this is a second test pattern";
+
+ model.datetime_model.reset(new DatetimeModelT);
+ model.datetime_model->patterns.emplace_back(new DatetimeModelPatternT);
+ model.datetime_model->patterns.back()->regexes.emplace_back(
+ new DatetimeModelPattern_::RegexT);
+ model.datetime_model->patterns.back()->regexes.back()->pattern =
+ "an example datetime pattern";
+ model.datetime_model->extractors.emplace_back(new DatetimeModelExtractorT);
+ model.datetime_model->extractors.back()->pattern =
+ "an example datetime extractor";
+
+ model.intent_options.reset(new IntentFactoryModelT);
+ model.intent_options->generator.emplace_back(
+ new IntentFactoryModel_::IntentGeneratorT);
+ const std::string intent_generator1 = "lua generator 1";
+ model.intent_options->generator.back()->lua_template_generator =
+ std::vector<uint8_t>(intent_generator1.begin(), intent_generator1.end());
+ model.intent_options->generator.emplace_back(
+ new IntentFactoryModel_::IntentGeneratorT);
+ const std::string intent_generator2 = "lua generator 2";
+ model.intent_options->generator.back()->lua_template_generator =
+ std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end());
+
+ // NOTE: The resource strings contain some repetition, so that the compressed
+ // version is smaller than the uncompressed one. Because the compression code
+ // looks at that as well.
+ model.resources.reset(new ResourcePoolT);
+ model.resources->resource_entry.emplace_back(new ResourceEntryT);
+ model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
+ model.resources->resource_entry.back()->resource.back()->content =
+ "rrrrrrrrrrrrr1.1";
+ model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
+ model.resources->resource_entry.back()->resource.back()->content =
+ "rrrrrrrrrrrrr1.2";
+ model.resources->resource_entry.emplace_back(new ResourceEntryT);
+ model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
+ model.resources->resource_entry.back()->resource.back()->content =
+ "rrrrrrrrrrrrr2.1";
+ model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
+ model.resources->resource_entry.back()->resource.back()->content =
+ "rrrrrrrrrrrrr2.2";
+
+ // Compress the model.
+ EXPECT_TRUE(CompressModel(&model));
+
+ // Sanity check that uncompressed field is removed.
+ EXPECT_TRUE(model.regex_model->patterns[0]->pattern.empty());
+ EXPECT_TRUE(model.regex_model->patterns[1]->pattern.empty());
+ EXPECT_TRUE(model.datetime_model->patterns[0]->regexes[0]->pattern.empty());
+ EXPECT_TRUE(model.datetime_model->extractors[0]->pattern.empty());
+ EXPECT_TRUE(
+ model.intent_options->generator[0]->lua_template_generator.empty());
+ EXPECT_TRUE(
+ model.intent_options->generator[1]->lua_template_generator.empty());
+ EXPECT_TRUE(model.resources->resource_entry[0]->resource[0]->content.empty());
+ EXPECT_TRUE(model.resources->resource_entry[0]->resource[1]->content.empty());
+ EXPECT_TRUE(model.resources->resource_entry[1]->resource[0]->content.empty());
+ EXPECT_TRUE(model.resources->resource_entry[1]->resource[1]->content.empty());
+
+ // Pack and load the model.
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(Model::Pack(builder, &model));
+ const Model* compressed_model =
+ GetModel(reinterpret_cast<const char*>(builder.GetBufferPointer()));
+ ASSERT_TRUE(compressed_model != nullptr);
+
+ // Decompress the fields again and check that they match the original.
+ std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
+ ASSERT_TRUE(decompressor != nullptr);
+ std::string uncompressed_pattern;
+ EXPECT_TRUE(decompressor->MaybeDecompress(
+ compressed_model->regex_model()->patterns()->Get(0)->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, "this is a test pattern");
+ EXPECT_TRUE(decompressor->MaybeDecompress(
+ compressed_model->regex_model()->patterns()->Get(1)->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, "this is a second test pattern");
+ EXPECT_TRUE(decompressor->MaybeDecompress(compressed_model->datetime_model()
+ ->patterns()
+ ->Get(0)
+ ->regexes()
+ ->Get(0)
+ ->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, "an example datetime pattern");
+ EXPECT_TRUE(decompressor->MaybeDecompress(compressed_model->datetime_model()
+ ->extractors()
+ ->Get(0)
+ ->compressed_pattern(),
+ &uncompressed_pattern));
+ EXPECT_EQ(uncompressed_pattern, "an example datetime extractor");
+
+ EXPECT_TRUE(DecompressModel(&model));
+ EXPECT_EQ(model.regex_model->patterns[0]->pattern, "this is a test pattern");
+ EXPECT_EQ(model.regex_model->patterns[1]->pattern,
+ "this is a second test pattern");
+ EXPECT_EQ(model.datetime_model->patterns[0]->regexes[0]->pattern,
+ "an example datetime pattern");
+ EXPECT_EQ(model.datetime_model->extractors[0]->pattern,
+ "an example datetime extractor");
+ EXPECT_EQ(
+ model.intent_options->generator[0]->lua_template_generator,
+ std::vector<uint8_t>(intent_generator1.begin(), intent_generator1.end()));
+ EXPECT_EQ(
+ model.intent_options->generator[1]->lua_template_generator,
+ std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end()));
+ EXPECT_EQ(model.resources->resource_entry[0]->resource[0]->content,
+ "rrrrrrrrrrrrr1.1");
+ EXPECT_EQ(model.resources->resource_entry[0]->resource[1]->content,
+ "rrrrrrrrrrrrr1.2");
+ EXPECT_EQ(model.resources->resource_entry[1]->resource[0]->content,
+ "rrrrrrrrrrrrr2.1");
+ EXPECT_EQ(model.resources->resource_entry[1]->resource[1]->content,
+ "rrrrrrrrrrrrr2.2");
+}
+
+} // namespace libtextclassifier3
diff --git a/jni.lds b/native/jni.lds
similarity index 100%
rename from jni.lds
rename to native/jni.lds
diff --git a/native/lang_id/common/embedding-feature-extractor.cc b/native/lang_id/common/embedding-feature-extractor.cc
new file mode 100644
index 0000000..a2e3cdf
--- /dev/null
+++ b/native/lang_id/common/embedding-feature-extractor.cc
@@ -0,0 +1,73 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/embedding-feature-extractor.h"
+
+#include <stddef.h>
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/numbers.h"
+#include "lang_id/common/lite_strings/str-split.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+bool GenericEmbeddingFeatureExtractor::Setup(TaskContext *context) {
+ // Don't use version to determine how to get feature FML.
+ const std::string features = context->Get(GetParamName("features"), "");
+ const std::string embedding_names =
+ context->Get(GetParamName("embedding_names"), "");
+ const std::string embedding_dims =
+ context->Get(GetParamName("embedding_dims"), "");
+
+ // NOTE: unfortunately, LiteStrSplit returns a vector of StringPieces pointing
+ // to the original string, in this case |features|, which is local to this
+ // method. We need to explicitly create new strings.
+ for (StringPiece sp : LiteStrSplit(features, ';')) {
+ embedding_fml_.emplace_back(sp);
+ }
+
+ // Same here.
+ for (StringPiece sp : LiteStrSplit(embedding_names, ';')) {
+ embedding_names_.emplace_back(sp);
+ }
+
+ std::vector<StringPiece> dim_strs = LiteStrSplit(embedding_dims, ';');
+ for (const auto &dim_str : dim_strs) {
+ int dim = 0;
+ if (!LiteAtoi(dim_str, &dim)) {
+ SAFTM_LOG(ERROR) << "Unable to parse " << dim_str;
+ return false;
+ }
+ embedding_dims_.push_back(dim);
+ }
+ return true;
+}
+
+bool GenericEmbeddingFeatureExtractor::Init(TaskContext *context) {
+ return true;
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/native/lang_id/common/embedding-feature-extractor.h b/native/lang_id/common/embedding-feature-extractor.h
new file mode 100644
index 0000000..ba4f858
--- /dev/null
+++ b/native/lang_id/common/embedding-feature-extractor.h
@@ -0,0 +1,176 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/common/lite_base/attributes.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// An EmbeddingFeatureExtractor manages the extraction of features for
+// embedding-based models. It wraps a sequence of underlying classes of feature
+// extractors, along with associated predicate maps. Each class of feature
+// extractors is associated with a name, e.g., "words", "labels", "tags".
+//
+// The class is split between a generic abstract version,
+// GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
+// signature of the ExtractFeatures method) and a typed version.
+//
+// The predicate maps must be initialized before use: they can be loaded using
+// Read() or updated via UpdateMapsForExample.
+class GenericEmbeddingFeatureExtractor {
+ public:
+ // Constructs this GenericEmbeddingFeatureExtractor.
+ //
+ // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
+ // avoid name clashes. See GetParamName().
+ explicit GenericEmbeddingFeatureExtractor(const std::string &arg_prefix)
+ : arg_prefix_(arg_prefix) {}
+
+ virtual ~GenericEmbeddingFeatureExtractor() {}
+
+ // Sets/inits up predicate maps and embedding space names that are common for
+ // all embedding based feature extractors.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context);
+ SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context);
+
+ // Requests workspace for the underlying feature extractors. This is
+ // implemented in the typed class.
+ virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
+
+ // Returns number of embedding spaces.
+ int NumEmbeddings() const { return embedding_dims_.size(); }
+
+ const std::vector<std::string> &embedding_fml() const {
+ return embedding_fml_;
+ }
+
+ // Get parameter name by concatenating the prefix and the original name.
+  std::string GetParamName(const std::string &param_name) const {
+ std::string full_name = arg_prefix_;
+ full_name.push_back('_');
+ full_name.append(param_name);
+ return full_name;
+ }
+
+ private:
+ // Prefix for TaskContext parameters.
+ const std::string arg_prefix_;
+
+ // Embedding space names for parameter sharing.
+ std::vector<std::string> embedding_names_;
+
+ // FML strings for each feature extractor.
+ std::vector<std::string> embedding_fml_;
+
+ // Size of each of the embedding spaces (maximum predicate id).
+ std::vector<int> embedding_sizes_;
+
+ // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
+ std::vector<int> embedding_dims_;
+};
+
+// Templated, object-specific implementation of the
+// EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
+// ARGS...> class that has the appropriate FeatureTraits() to ensure that
+// locator type features work.
+//
+// Note: for backwards compatibility purposes, this always reads the FML spec
+// from "<prefix>_features".
+template <class EXTRACTOR, class OBJ, class... ARGS>
+class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
+ public:
+ // Constructs this EmbeddingFeatureExtractor.
+ //
+ // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
+ // avoid name clashes. See GetParamName().
+ explicit EmbeddingFeatureExtractor(const std::string &arg_prefix)
+ : GenericEmbeddingFeatureExtractor(arg_prefix) {}
+
+ // Sets up all predicate maps, feature extractors, and flags.
+ SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override {
+ if (!GenericEmbeddingFeatureExtractor::Setup(context)) {
+ return false;
+ }
+ feature_extractors_.resize(embedding_fml().size());
+ for (int i = 0; i < embedding_fml().size(); ++i) {
+ feature_extractors_[i].reset(new EXTRACTOR());
+ if (!feature_extractors_[i]->Parse(embedding_fml()[i])) return false;
+ if (!feature_extractors_[i]->Setup(context)) return false;
+ }
+ return true;
+ }
+
+ // Initializes resources needed by the feature extractors.
+ SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override {
+ if (!GenericEmbeddingFeatureExtractor::Init(context)) return false;
+ for (auto &feature_extractor : feature_extractors_) {
+ if (!feature_extractor->Init(context)) return false;
+ }
+ return true;
+ }
+
+ // Requests workspaces from the registry. Must be called after Init(), and
+ // before Preprocess().
+ void RequestWorkspaces(WorkspaceRegistry *registry) override {
+ for (auto &feature_extractor : feature_extractors_) {
+ feature_extractor->RequestWorkspaces(registry);
+ }
+ }
+
+ // Must be called on the object one state for each sentence, before any
+ // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures).
+ void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
+ for (auto &feature_extractor : feature_extractors_) {
+ feature_extractor->Preprocess(workspaces, obj);
+ }
+ }
+
+ // Extracts features using the extractors. Note that features must already
+ // be initialized to the correct number of feature extractors. No predicate
+ // mapping is applied.
+ void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
+ ARGS... args,
+ std::vector<FeatureVector> *features) const {
+ // DCHECK(features != nullptr);
+ // DCHECK_EQ(features->size(), feature_extractors_.size());
+ for (int i = 0; i < feature_extractors_.size(); ++i) {
+ (*features)[i].clear();
+ feature_extractors_[i]->ExtractFeatures(workspaces, obj, args...,
+ &(*features)[i]);
+ }
+ }
+
+ private:
+ // Templated feature extractor class.
+ std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_;
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
diff --git a/native/lang_id/common/embedding-feature-interface.h b/native/lang_id/common/embedding-feature-interface.h
new file mode 100644
index 0000000..75d0c98
--- /dev/null
+++ b/native/lang_id/common/embedding-feature-interface.h
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/embedding-feature-extractor.h"
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/common/lite_base/attributes.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+template <class EXTRACTOR, class OBJ, class... ARGS>
+class EmbeddingFeatureInterface {
+ public:
+ // Constructs this EmbeddingFeatureInterface.
+ //
+ // |arg_prefix| is a string prefix for the TaskContext parameters, passed to
+  // the underlying EmbeddingFeatureExtractor.
+ explicit EmbeddingFeatureInterface(const std::string &arg_prefix)
+ : feature_extractor_(arg_prefix) {}
+
+ // Sets up feature extractors and flags for processing (inference).
+ SAFTM_MUST_USE_RESULT bool SetupForProcessing(TaskContext *context) {
+ return feature_extractor_.Setup(context);
+ }
+
+ // Initializes feature extractor resources for processing (inference)
+ // including requesting a workspace for caching extracted features.
+ SAFTM_MUST_USE_RESULT bool InitForProcessing(TaskContext *context) {
+ if (!feature_extractor_.Init(context)) return false;
+ feature_extractor_.RequestWorkspaces(&workspace_registry_);
+ return true;
+ }
+
+ // Preprocesses *obj using the internal workspace registry.
+ void Preprocess(WorkspaceSet *workspace, OBJ *obj) const {
+ workspace->Reset(workspace_registry_);
+ feature_extractor_.Preprocess(workspace, obj);
+ }
+
+ // Extract features from |obj|. On return, FeatureVector features[i]
+ // contains the features for the embedding space #i.
+ //
+ // This function uses the precomputed info from |workspace|. Usage pattern:
+ //
+ // EmbeddingFeatureInterface<...> feature_interface;
+ // ...
+ // OBJ obj;
+ // WorkspaceSet workspace;
+ // feature_interface.Preprocess(&workspace, &obj);
+ //
+ // // For the same obj, but with different args:
+ // std::vector<FeatureVector> features;
+ // feature_interface.GetFeatures(obj, args, workspace, &features);
+ //
+ // This pattern is useful (more efficient) if you can pre-compute some info
+ // for the entire |obj|, which is reused by the feature extraction performed
+ // for different args. If that is not the case, you can use the simpler
+ // version GetFeaturesNoCaching below.
+ void GetFeatures(const OBJ &obj, ARGS... args, const WorkspaceSet &workspace,
+ std::vector<FeatureVector> *features) const {
+ feature_extractor_.ExtractFeatures(workspace, obj, args..., features);
+ }
+
+ // Simpler version of GetFeatures(), for cases when there is no opportunity to
+ // reuse computation between feature extractions for the same |obj|, but with
+ // different |args|. Returns the extracted features. For more info, see the
+ // doc for GetFeatures().
+ std::vector<FeatureVector> GetFeaturesNoCaching(OBJ *obj,
+ ARGS... args) const {
+ // Technically, we still use a workspace, because
+ // feature_extractor_.ExtractFeatures requires one. But there is no real
+ // caching here, as we start from scratch for each call to ExtractFeatures.
+ WorkspaceSet workspace;
+ Preprocess(&workspace, obj);
+ std::vector<FeatureVector> features(NumEmbeddings());
+ GetFeatures(*obj, args..., workspace, &features);
+ return features;
+ }
+
+ // Returns number of embedding spaces.
+ int NumEmbeddings() const { return feature_extractor_.NumEmbeddings(); }
+
+ private:
+ // Typed feature extractor for embeddings.
+ EmbeddingFeatureExtractor<EXTRACTOR, OBJ, ARGS...> feature_extractor_;
+
+ // The registry of shared workspaces in the feature extractor.
+ WorkspaceRegistry workspace_registry_;
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
diff --git a/native/lang_id/common/embedding-network-params.cc b/native/lang_id/common/embedding-network-params.cc
new file mode 100644
index 0000000..8b48fce
--- /dev/null
+++ b/native/lang_id/common/embedding-network-params.cc
@@ -0,0 +1,46 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/embedding-network-params.h"
+
+#include <string>
+
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+
+QuantizationType ParseQuantizationType(const std::string &s) {
+ if (s == "NONE") {
+ return QuantizationType::NONE;
+ }
+ if (s == "UINT8") {
+ return QuantizationType::UINT8;
+ }
+ if (s == "UINT4") {
+ return QuantizationType::UINT4;
+ }
+ if (s == "FLOAT16") {
+ return QuantizationType::FLOAT16;
+ }
+ SAFTM_LOG(FATAL) << "Unsupported quantization type: " << s;
+
+ // Execution should never reach this point; just to keep the compiler happy.
+ // TODO(salcianu): implement SAFTM_LOG(FATAL) in a way that doesn't require
+ // this trick.
+ return QuantizationType::NONE;
+}
+
+}  // namespace libtextclassifier3
diff --git a/native/lang_id/common/embedding-network-params.h b/native/lang_id/common/embedding-network-params.h
new file mode 100755
index 0000000..6ad147c
--- /dev/null
+++ b/native/lang_id/common/embedding-network-params.h
@@ -0,0 +1,316 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
+
+#include <string>
+
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/lite_base/float16.h"
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+
+enum class QuantizationType {
+ NONE = 0,
+
+ // Quantization to 8 bit unsigned ints.
+ UINT8,
+
+ // Quantization to 4 bit unsigned ints.
+ UINT4,
+
+ // Quantization to 16 bit floats, the type defined in
+ // lang_id/common/float16.h
+ FLOAT16,
+
+ // NOTE: for backward compatibility, if you add a new value to this enum, add
+ // it *at the end*, such that you do not change the integer values of the
+ // existing enum values.
+};
+
+// Converts "UINT8" -> QuantizationType::UINT8, and so on.
+QuantizationType ParseQuantizationType(const std::string &s);
+
+// API for accessing parameters for a feed-forward neural network with
+// embeddings.
+//
+//
+// In fact, we provide two APIs: a high-level (and highly-recommented) API, with
+// methods named using the BigCamel notation (e.g., GetEmbeddingMatrix()) and a
+// low-level API, using C-style names (e.g., softmax_num_cols()).
+//
+// Note: the API below is meant to allow the inference code (the class
+// libtextclassifier3::mobile::EmbeddingNetwork) to use the data directly, with no need
+// for transposing any matrix (which would require extra overhead on mobile
+// devices). Hence, as indicated by the comments for the API methods, some of
+// the matrices below are the transposes of the corresponding matrices from the
+// original proto.
+class EmbeddingNetworkParams {
+ public:
+ virtual ~EmbeddingNetworkParams() {}
+
+ // Returns true if these params are valid. False otherwise (e.g., if the
+ // underlying data is corrupted). If is_valid() returns false, clients should
+ // not call any other method on that instance of EmbeddingNetworkParams. If
+ // is_valid() returns true, then calls to the API methods below should not
+ // crash *if they are called with index parameters in bounds*. E.g., if
+ // is_valid() and 0 <= i < embeddings_size(), then GetEmbeddingMatrix(i)
+ // should not crash.
+ virtual bool is_valid() const = 0;
+
+ // **** High-level API.
+
+ // Simple representation of a matrix. This small struct that doesn't own any
+ // resource intentionally supports copy / assign, to simplify our APIs.
+ struct Matrix {
+ // Number of rows.
+ int rows = 0;
+
+ // Number of columns.
+ int cols = 0;
+
+ QuantizationType quant_type = QuantizationType::NONE;
+
+ // Pointer to matrix elements, in row-major order
+ // (https://en.wikipedia.org/wiki/Row-major_order) Not owned.
+ const void *elements = nullptr;
+
+ // Quantization scales: one scale for each row.
+ const ::libtextclassifier3::mobile::float16 *quant_scales = nullptr;
+ };
+
+ // Returns i-th embedding matrix. Crashes on out of bounds indices.
+ //
+ // This is the transpose of the corresponding matrix from the original proto.
+ Matrix GetEmbeddingMatrix(int i) const {
+ CheckIndex(i, embeddings_size(), "embedding matrix");
+ Matrix matrix;
+ matrix.rows = embeddings_num_rows(i);
+ matrix.cols = embeddings_num_cols(i);
+ matrix.elements = embeddings_weights(i);
+ matrix.quant_type = embeddings_quant_type(i);
+ matrix.quant_scales = embeddings_quant_scales(i);
+ return matrix;
+ }
+
+ // Returns weight matrix for i-th hidden layer. Crashes on out of bounds
+ // indices.
+ //
+ // This is the transpose of the corresponding matrix from the original proto.
+ Matrix GetHiddenLayerMatrix(int i) const {
+ CheckIndex(i, hidden_size(), "hidden layer");
+ Matrix matrix;
+ matrix.rows = hidden_num_rows(i);
+ matrix.cols = hidden_num_cols(i);
+
+ // Quantization not supported here.
+ matrix.quant_type = hidden_weights_quant_type(i);
+ matrix.elements = hidden_weights(i);
+ return matrix;
+ }
+
+ // Returns bias for i-th hidden layer. Technically a Matrix, but we expect it
+ // to be a row/column vector (i.e., num rows or num cols is 1). However, we
+ // don't CHECK for that: we just provide access to underlying data. Crashes
+ // on out of bounds indices.
+ Matrix GetHiddenLayerBias(int i) const {
+ CheckIndex(i, hidden_bias_size(), "hidden layer bias");
+ Matrix matrix;
+ matrix.rows = hidden_bias_num_rows(i);
+ matrix.cols = hidden_bias_num_cols(i);
+
+ // Quantization not supported here.
+ matrix.quant_type = QuantizationType::NONE;
+ matrix.elements = hidden_bias_weights(i);
+ return matrix;
+ }
+
+ // Returns true if a softmax layer exists.
+ bool HasSoftmax() const {
+ return softmax_size() == 1;
+ }
+
+ // Returns weight matrix for the softmax layer. Note: should be called only
+ // if HasSoftmax() is true.
+ //
+ // This is the transpose of the corresponding matrix from the original proto.
+ Matrix GetSoftmaxMatrix() const {
+ SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
+ Matrix matrix;
+ matrix.rows = softmax_num_rows(0);
+ matrix.cols = softmax_num_cols(0);
+
+ // Quantization not supported here.
+ matrix.quant_type = softmax_weights_quant_type(0);
+ matrix.elements = softmax_weights(0);
+ return matrix;
+ }
+
+ // Returns bias for the softmax layer. Technically a Matrix, but we expect it
+ // to be a row/column vector (i.e., num rows or num cols is 1). However, we
+ // don't CHECK for that: we just provide access to underlying data.
+ Matrix GetSoftmaxBias() const {
+ SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
+ Matrix matrix;
+ matrix.rows = softmax_bias_num_rows(0);
+ matrix.cols = softmax_bias_num_cols(0);
+
+ // Quantization not supported here.
+ matrix.quant_type = QuantizationType::NONE;
+ matrix.elements = softmax_bias_weights(0);
+ return matrix;
+ }
+
+ // Updates the EmbeddingNetwork-related parameters from task_context. Returns
+ // true on success, false on error.
+ virtual bool UpdateTaskContextParameters(
+ mobile::TaskContext *task_context) = 0;
+
+ // **** Low-level API.
+ //
+ // * Most low-level API methods are documented by giving an equivalent
+ // function call on proto, the original proto (of type
+ // EmbeddingNetworkProto) which was used to generate the C++ code.
+ //
+ // * To simplify our generation code, optional proto fields of message type
+ // are treated as repeated fields with 0 or 1 instances. As such, we have
+ // *_size() methods for such optional fields: they return 0 or 1.
+ //
+ // * "transpose(M)" denotes the transpose of a matrix M.
+
+ // ** Access methods for repeated MatrixParams embeddings.
+ //
+ // Returns proto.embeddings_size().
+ virtual int embeddings_size() const = 0;
+
+ // Returns number of rows of transpose(proto.embeddings(i)).
+ virtual int embeddings_num_rows(int i) const = 0;
+
+ // Returns number of columns of transpose(proto.embeddings(i)).
+ virtual int embeddings_num_cols(int i) const = 0;
+
+ // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major
+ // order. NOTE: for unquantized embeddings, this returns a pointer to float;
+ // for quantized embeddings, this returns a pointer to uint8.
+ virtual const void *embeddings_weights(int i) const = 0;
+
+ virtual QuantizationType embeddings_quant_type(int i) const {
+ return QuantizationType::NONE;
+ }
+
+ virtual const ::libtextclassifier3::mobile::float16 *embeddings_quant_scales(
+ int i) const {
+ return nullptr;
+ }
+
+ // ** Access methods for repeated MatrixParams hidden.
+ //
+ // Returns embedding_network_proto.hidden_size().
+ virtual int hidden_size() const = 0;
+
+ // Returns embedding_network_proto.hidden(i).rows().
+ virtual int hidden_num_rows(int i) const = 0;
+
+ // Returns embedding_network_proto.hidden(i).rows().
+ virtual int hidden_num_cols(int i) const = 0;
+
+ // Returns quantization mode for the weights of the i-th hidden layer.
+ virtual QuantizationType hidden_weights_quant_type(int i) const {
+ return QuantizationType::NONE;
+ }
+
+ // Returns pointer to beginning of array of floats with all values from
+ // embedding_network_proto.hidden(i).
+ virtual const void *hidden_weights(int i) const = 0;
+
+ // ** Access methods for repeated MatrixParams hidden_bias.
+ //
+ // Returns proto.hidden_bias_size().
+ virtual int hidden_bias_size() const = 0;
+
+ // Returns number of rows of proto.hidden_bias(i).
+ virtual int hidden_bias_num_rows(int i) const = 0;
+
+ // Returns number of columns of proto.hidden_bias(i).
+ virtual int hidden_bias_num_cols(int i) const = 0;
+
+ // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
+ virtual const void *hidden_bias_weights(int i) const = 0;
+
+ // ** Access methods for optional MatrixParams softmax.
+ //
+ // Returns 1 if proto has optional field softmax, 0 otherwise.
+ virtual int softmax_size() const = 0;
+
+ // Returns number of rows of transpose(proto.softmax()).
+ virtual int softmax_num_rows(int i) const = 0;
+
+ // Returns number of columns of transpose(proto.softmax()).
+ virtual int softmax_num_cols(int i) const = 0;
+
+ // Returns quantization mode for the softmax weights.
+ virtual QuantizationType softmax_weights_quant_type(int i) const {
+ return QuantizationType::NONE;
+ }
+
+ // Returns pointer to elements of transpose(proto.softmax()), in row-major
+ // order.
+ virtual const void *softmax_weights(int i) const = 0;
+
+ // ** Access methods for optional MatrixParams softmax_bias.
+ //
+ // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
+ virtual int softmax_bias_size() const = 0;
+
+ // Returns number of rows of proto.softmax_bias().
+ virtual int softmax_bias_num_rows(int i) const = 0;
+
+ // Returns number of columns of proto.softmax_bias().
+ virtual int softmax_bias_num_cols(int i) const = 0;
+
+ // Returns pointer to elements of proto.softmax_bias(), in row-major order.
+ virtual const void *softmax_bias_weights(int i) const = 0;
+
+ // ** Access methods for repeated int32 embedding_num_features.
+ //
+ // Returns proto.embedding_num_features_size().
+ virtual int embedding_num_features_size() const = 0;
+
+ // Returns proto.embedding_num_features(i).
+ virtual int embedding_num_features(int i) const = 0;
+
+ // ** Access methods for is_precomputed
+ //
+ // Returns proto.has_is_precomputed().
+ virtual bool has_is_precomputed() const = 0;
+
+ // Returns proto.is_precomputed().
+ virtual bool is_precomputed() const = 0;
+
+ protected:
+ void CheckIndex(int index, int size, const std::string &description) const {
+ SAFTM_CHECK_GE(index, 0)
+ << "Out-of-range index for " << description << ": " << index;
+ SAFTM_CHECK_LT(index, size)
+ << "Out-of-range index for " << description << ": " << index;
+ }
+}; // class EmbeddingNetworkParams
+
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
diff --git a/lang_id/common/embedding-network.cc b/native/lang_id/common/embedding-network.cc
similarity index 100%
rename from lang_id/common/embedding-network.cc
rename to native/lang_id/common/embedding-network.cc
diff --git a/lang_id/common/embedding-network.h b/native/lang_id/common/embedding-network.h
similarity index 100%
rename from lang_id/common/embedding-network.h
rename to native/lang_id/common/embedding-network.h
diff --git a/native/lang_id/common/fel/feature-descriptors.cc b/native/lang_id/common/fel/feature-descriptors.cc
new file mode 100644
index 0000000..1293399
--- /dev/null
+++ b/native/lang_id/common/fel/feature-descriptors.cc
@@ -0,0 +1,81 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/feature-descriptors.h"
+
+#include <string>
+
+#include "lang_id/common/lite_strings/str-cat.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+void ToFELFunction(const FeatureFunctionDescriptor &function,
+ std::string *output) {
+ LiteStrAppend(output, function.type());
+ if (function.argument() != 0 || function.parameter_size() > 0) {
+ LiteStrAppend(output, "(");
+ bool first = true;
+ if (function.argument() != 0) {
+ LiteStrAppend(output, function.argument());
+ first = false;
+ }
+ for (int i = 0; i < function.parameter_size(); ++i) {
+ if (!first) LiteStrAppend(output, ",");
+ LiteStrAppend(output, function.parameter(i).name(), "=\"",
+ function.parameter(i).value(), "\"");
+ first = false;
+ }
+ LiteStrAppend(output, ")");
+ }
+}
+
+void ToFEL(const FeatureFunctionDescriptor &function, std::string *output) {
+ ToFELFunction(function, output);
+ if (function.feature_size() == 1) {
+ LiteStrAppend(output, ".");
+ ToFEL(function.feature(0), output);
+ } else if (function.feature_size() > 1) {
+ LiteStrAppend(output, " { ");
+ for (int i = 0; i < function.feature_size(); ++i) {
+ if (i > 0) LiteStrAppend(output, " ");
+ ToFEL(function.feature(i), output);
+ }
+ LiteStrAppend(output, " } ");
+ }
+}
+
+void ToFEL(const FeatureExtractorDescriptor &extractor, std::string *output) {
+ for (int i = 0; i < extractor.feature_size(); ++i) {
+ ToFEL(extractor.feature(i), output);
+ LiteStrAppend(output, "\n");
+ }
+}
+
+std::string FeatureFunctionDescriptor::DebugString() const {
+ std::string str;
+ ToFEL(*this, &str);
+ return str;
+}
+
+std::string FeatureExtractorDescriptor::DebugString() const {
+ std::string str;
+ ToFEL(*this, &str);
+ return str;
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/native/lang_id/common/fel/feature-descriptors.h b/native/lang_id/common/fel/feature-descriptors.h
new file mode 100644
index 0000000..3bdc2fa
--- /dev/null
+++ b/native/lang_id/common/fel/feature-descriptors.h
@@ -0,0 +1,160 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_DESCRIPTORS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_DESCRIPTORS_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Named feature parameter.
+class Parameter {
+ public:
+ Parameter() {}
+
+ void set_name(const std::string &name) { name_ = name; }
+ const std::string &name() const { return name_; }
+
+ void set_value(const std::string &value) { value_ = value; }
+ const std::string &value() const { return value_; }
+
+ private:
+ std::string name_;
+ std::string value_;
+};
+
+// Descriptor for a feature function. Used to store the results of parsing one
+// feature function.
+class FeatureFunctionDescriptor {
+ public:
+ FeatureFunctionDescriptor() {}
+
+ // Accessors for the feature function type. The function type is the string
+ // that the feature extractor code is registered under.
+ void set_type(const std::string &type) { type_ = type; }
+ const std::string &type() const { return type_; }
+
+ // Accessors for the feature function name. The function name (if available)
+ // is used for some log messages. Otherwise, a more precise, but also more
+ // verbose name based on the feature specification is used.
+ void set_name(const std::string &name) { name_ = name; }
+ const std::string &name() const { return name_; }
+
+ // Accessors for the default (name-less) parameter.
+ void set_argument(int32 argument) { argument_ = argument; }
+ bool has_argument() const {
+ // If argument has not been specified, clients should treat it as 0. This
+ // makes the test below correct, without having a separate has_argument_
+ // bool field.
+ return argument_ != 0;
+ }
+ int32 argument() const { return argument_; }
+
+ // Accessors for the named parameters.
+ Parameter *add_parameter() {
+ parameters_.emplace_back();
+ return &(parameters_.back());
+ }
+ int parameter_size() const { return parameters_.size(); }
+ const Parameter ¶meter(int i) const {
+ SAFTM_DCHECK((i >= 0) && (i < parameter_size()));
+ return parameters_[i];
+ }
+
+ // Accessors for the sub (i.e., nested) features. Nested features: as in
+ // offset(1).label.
+ FeatureFunctionDescriptor *add_feature() {
+ sub_features_.emplace_back(new FeatureFunctionDescriptor());
+ return sub_features_.back().get();
+ }
+ int feature_size() const { return sub_features_.size(); }
+ const FeatureFunctionDescriptor &feature(int i) const {
+ SAFTM_DCHECK((i >= 0) && (i < feature_size()));
+ return *(sub_features_[i].get());
+ }
+
+ // Returns human-readable representation of this FeatureFunctionDescriptor.
+ std::string DebugString() const;
+
+ private:
+ // See comments for set_type().
+ std::string type_;
+
+ // See comments for set_name().
+ std::string name_;
+
+ // See comments for set_argument().
+ int32 argument_ = 0;
+
+  // See comments for add_parameter().
+ std::vector<Parameter> parameters_;
+
+ // See comments for add_feature().
+ std::vector<std::unique_ptr<FeatureFunctionDescriptor>> sub_features_;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureFunctionDescriptor);
+};
+
+// List of FeatureFunctionDescriptors. Used to store the result of parsing the
+// spec for several feature functions.
+class FeatureExtractorDescriptor {
+ public:
+ FeatureExtractorDescriptor() {}
+
+ int feature_size() const { return features_.size(); }
+
+ FeatureFunctionDescriptor *add_feature() {
+ features_.emplace_back(new FeatureFunctionDescriptor());
+ return features_.back().get();
+ }
+
+ const FeatureFunctionDescriptor &feature(int i) const {
+ SAFTM_DCHECK((i >= 0) && (i < feature_size()));
+ return *(features_[i].get());
+ }
+
+ // Returns human-readable representation of this FeatureExtractorDescriptor.
+ std::string DebugString() const;
+
+ private:
+ std::vector<std::unique_ptr<FeatureFunctionDescriptor>> features_;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureExtractorDescriptor);
+};
+
+// Appends to |*output| the FEL representation of the top-level feature from
+// |function|, without diving into the nested features.
+void ToFELFunction(const FeatureFunctionDescriptor &function,
+ std::string *output);
+
+// Appends to |*output| the FEL representation of |function|.
+void ToFEL(const FeatureFunctionDescriptor &function, std::string *output);
+
+// Appends to |*output| the FEL representation of |extractor|.
+void ToFEL(const FeatureExtractorDescriptor &extractor, std::string *output);
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_DESCRIPTORS_H_
diff --git a/native/lang_id/common/fel/feature-extractor.cc b/native/lang_id/common/fel/feature-extractor.cc
new file mode 100644
index 0000000..ab8a1a6
--- /dev/null
+++ b/native/lang_id/common/fel/feature-extractor.cc
@@ -0,0 +1,141 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/feature-extractor.h"
+
+#include <string>
+
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/fel-parser.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/numbers.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+constexpr FeatureValue GenericFeatureFunction::kNone;
+
+GenericFeatureExtractor::GenericFeatureExtractor() {}
+
+GenericFeatureExtractor::~GenericFeatureExtractor() {}
+
+bool GenericFeatureExtractor::Parse(const std::string &source) {
+ // Parse feature specification into descriptor.
+ FELParser parser;
+
+ if (!parser.Parse(source, mutable_descriptor())) {
+ SAFTM_LOG(ERROR) << "Error parsing the FEL spec " << source;
+ return false;
+ }
+
+ // Initialize feature extractor from descriptor.
+ return InitializeFeatureFunctions();
+}
+
+bool GenericFeatureExtractor::InitializeFeatureTypes() {
+ // Register all feature types.
+ GetFeatureTypes(&feature_types_);
+ for (size_t i = 0; i < feature_types_.size(); ++i) {
+ FeatureType *ft = feature_types_[i];
+ ft->set_base(i);
+
+ // Check for feature space overflow.
+ double domain_size = ft->GetDomainSize();
+ if (domain_size < 0) {
+ SAFTM_LOG(ERROR) << "Illegal domain size for feature " << ft->name()
+ << ": " << domain_size;
+ return false;
+ }
+ }
+ return true;
+}
+
+std::string GenericFeatureFunction::GetParameter(
+ const std::string &name, const std::string &default_value) const {
+ // Find named parameter in feature descriptor.
+ for (int i = 0; i < descriptor_->parameter_size(); ++i) {
+ if (name == descriptor_->parameter(i).name()) {
+ return descriptor_->parameter(i).value();
+ }
+ }
+ return default_value;
+}
+
+GenericFeatureFunction::GenericFeatureFunction() {}
+
+GenericFeatureFunction::~GenericFeatureFunction() { delete feature_type_; }
+
+int GenericFeatureFunction::GetIntParameter(const std::string &name,
+ int default_value) const {
+ std::string value_str = GetParameter(name, "");
+ if (value_str.empty()) {
+ // Parameter not specified, use default value for it.
+ return default_value;
+ }
+ int value = 0;
+ if (!LiteAtoi(value_str, &value)) {
+ SAFTM_LOG(DFATAL) << "Unable to parse '" << value_str
+ << "' as int for parameter " << name;
+ return default_value;
+ }
+ return value;
+}
+
+bool GenericFeatureFunction::GetBoolParameter(const std::string &name,
+ bool default_value) const {
+ std::string value = GetParameter(name, "");
+ if (value.empty()) return default_value;
+ if (value == "true") return true;
+ if (value == "false") return false;
+ SAFTM_LOG(DFATAL) << "Illegal value '" << value << "' for bool parameter "
+ << name;
+ return default_value;
+}
+
+void GenericFeatureFunction::GetFeatureTypes(
+ std::vector<FeatureType *> *types) const {
+ if (feature_type_ != nullptr) types->push_back(feature_type_);
+}
+
+FeatureType *GenericFeatureFunction::GetFeatureType() const {
+ // If a single feature type has been registered return it.
+ if (feature_type_ != nullptr) return feature_type_;
+
+ // Get feature types for function.
+ std::vector<FeatureType *> types;
+ GetFeatureTypes(&types);
+
+ // If there is exactly one feature type return this, else return null.
+ if (types.size() == 1) return types[0];
+ return nullptr;
+}
+
+std::string GenericFeatureFunction::name() const {
+ std::string output;
+ if (descriptor_->name().empty()) {
+ if (!prefix_.empty()) {
+ output.append(prefix_);
+ output.append(".");
+ }
+ ToFEL(*descriptor_, &output);
+ } else {
+ output = descriptor_->name();
+ }
+ return output;
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/native/lang_id/common/fel/feature-extractor.h b/native/lang_id/common/fel/feature-extractor.h
new file mode 100644
index 0000000..c09e1eb
--- /dev/null
+++ b/native/lang_id/common/fel/feature-extractor.h
@@ -0,0 +1,652 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Generic feature extractor for extracting features from objects. The feature
+// extractor can be used for extracting features from any object. The feature
+// extractor and feature function classes are template classes that have to
+// be instantiated for extracting feature from a specific object type.
+//
+// A feature extractor consists of a hierarchy of feature functions. Each
+// feature function extracts one or more feature type and value pairs from the
+// object.
+//
+// The feature extractor has a modular design where new feature functions can be
+// registered as components. The feature extractor is initialized from a
+// descriptor represented by a protocol buffer. The feature extractor can also
+// be initialized from a text-based source specification of the feature
+// extractor. Feature specification parsers can be added as components. By
+// default the feature extractor can be read from an ASCII protocol buffer or in
+// a simple feature modeling language (fml).
+
+// A feature function is invoked with a focus. Nested feature function can be
+// invoked with another focus determined by the parent feature function.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_
+
+#include <stddef.h>
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/feature-descriptors.h"
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/common/lite_base/attributes.h"
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+#include "lang_id/common/registry.h"
+#include "lang_id/common/stl-util.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// TODO(djweiss) Clean this up as well.
+// Use the same type for feature values as is used for predicates.
+typedef int64 Predicate;
+typedef Predicate FeatureValue;
+
+// A union used to represent discrete and continuous feature values.
+union FloatFeatureValue {
+ public:
+ explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {}
+ FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {}
+ FeatureValue discrete_value;
+ struct {
+ uint32 id;
+ float weight;
+ };
+};
+
+// A feature vector contains feature type and value pairs.
+class FeatureVector {
+ public:
+ FeatureVector() {}
+
+ // Adds feature type and value pair to feature vector.
+ void add(FeatureType *type, FeatureValue value) {
+ features_.emplace_back(type, value);
+ }
+
+ // Removes all elements from the feature vector.
+ void clear() { features_.clear(); }
+
+ // Returns the number of elements in the feature vector.
+ int size() const { return features_.size(); }
+
+ // Reserves space in the underlying feature vector.
+ void reserve(int n) { features_.reserve(n); }
+
+ // Returns feature type for an element in the feature vector.
+ FeatureType *type(int index) const { return features_[index].type; }
+
+ // Returns feature value for an element in the feature vector.
+ FeatureValue value(int index) const { return features_[index].value; }
+
+ private:
+ // Structure for holding feature type and value pairs.
+ struct Element {
+ Element() : type(nullptr), value(-1) {}
+ Element(FeatureType *t, FeatureValue v) : type(t), value(v) {}
+
+ FeatureType *type;
+ FeatureValue value;
+ };
+
+ // Array for storing feature vector elements.
+ std::vector<Element> features_;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureVector);
+};
+
+// The generic feature extractor is the type-independent part of a feature
+// extractor. This holds the descriptor for the feature extractor and the
+// collection of feature types used in the feature extractor. The feature
+// types are not available until FeatureExtractor<>::Init() has been called.
+class GenericFeatureExtractor {
+ public:
+ GenericFeatureExtractor();
+ virtual ~GenericFeatureExtractor();
+
+ // Initializes the feature extractor from the FEL specification |source|.
+ //
+ // Returns true on success, false otherwise (e.g., FEL syntax error).
+ SAFTM_MUST_USE_RESULT bool Parse(const std::string &source);
+
+ // Returns the feature extractor descriptor.
+ const FeatureExtractorDescriptor &descriptor() const { return descriptor_; }
+ FeatureExtractorDescriptor *mutable_descriptor() { return &descriptor_; }
+
+ // Returns the number of feature types in the feature extractor. Invalid
+ // before Init() has been called.
+ int feature_types() const { return feature_types_.size(); }
+
+ protected:
+ // Initializes the feature types used by the extractor. Called from
+ // FeatureExtractor<>::Init().
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool InitializeFeatureTypes();
+
+ private:
+ // Initializes the top-level feature functions.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT virtual bool InitializeFeatureFunctions() = 0;
+
+ // Returns all feature types used by the extractor. The feature types are
+ // added to the result array.
+ virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const = 0;
+
+ // Descriptor for the feature extractor. This is a protocol buffer that
+ // contains all the information about the feature extractor. The feature
+ // functions are initialized from the information in the descriptor.
+ FeatureExtractorDescriptor descriptor_;
+
+ // All feature types used by the feature extractor. The collection of all the
+ // feature types describes the feature space of the feature set produced by
+ // the feature extractor. Not owned.
+ std::vector<FeatureType *> feature_types_;
+};
+
+// The generic feature function is the type-independent part of a feature
+// function. Each feature function is associated with the descriptor that it is
+// instantiated from. The feature types associated with this feature function
+// will be established by the time FeatureExtractor<>::Init() completes.
+class GenericFeatureFunction {
+ public:
+ // A feature value that represents the absence of a value.
+ static constexpr FeatureValue kNone = -1;
+
+ GenericFeatureFunction();
+ virtual ~GenericFeatureFunction();
+
+ // Sets up the feature function. NB: FeatureTypes of nested functions are not
+ // guaranteed to be available until Init().
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context) {
+ return true;
+ }
+
+ // Initializes the feature function. NB: The FeatureType of this function must
+ // be established when this method completes.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context) { return true; }
+
+ // Requests workspaces from a registry to obtain indices into a WorkspaceSet
+ // for any Workspace objects used by this feature function. NB: This will be
+ // called after Init(), so it can depend on resources and arguments.
+ virtual void RequestWorkspaces(WorkspaceRegistry *registry) {}
+
+ // Appends the feature types produced by the feature function to types. The
+ // default implementation appends feature_type(), if non-null. Invalid
+ // before Init() has been called.
+ virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const;
+
+ // Returns the feature type for feature produced by this feature function. If
+ // the feature function produces features of different types this returns
+ // null. Invalid before Init() has been called.
+ virtual FeatureType *GetFeatureType() const;
+
+ // Returns value of parameter |name| from the feature function descriptor.
+ // If the parameter is not present, returns the indicated |default_value|.
+ std::string GetParameter(const std::string &name,
+ const std::string &default_value) const;
+
+ // Returns value of int parameter |name| from feature function descriptor.
+ // If the parameter is not present, or its value can't be parsed as an int,
+ // returns |default_value|.
+ int GetIntParameter(const std::string &name, int default_value) const;
+
+ // Returns value of bool parameter |name| from feature function descriptor.
+ // If the parameter is not present, or its value is not "true" or "false",
+ // returns |default_value|. NOTE: this method is case sensitive, it doesn't
+ // do any lower-casing.
+ bool GetBoolParameter(const std::string &name, bool default_value) const;
+
+ // Returns the FEL function description for the feature function, i.e. the
+ // name and parameters without the nested features.
+ std::string FunctionName() const {
+ std::string output;
+ ToFELFunction(*descriptor_, &output);
+ return output;
+ }
+
+ // Returns the prefix for nested feature functions. This is the prefix of this
+ // feature function concatenated with the feature function name.
+ std::string SubPrefix() const {
+ return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName();
+ }
+
+ // Returns/sets the feature extractor this function belongs to.
+ const GenericFeatureExtractor *extractor() const { return extractor_; }
+ void set_extractor(const GenericFeatureExtractor *extractor) {
+ extractor_ = extractor;
+ }
+
+ // Returns/sets the feature function descriptor.
+ const FeatureFunctionDescriptor *descriptor() const { return descriptor_; }
+ void set_descriptor(const FeatureFunctionDescriptor *descriptor) {
+ descriptor_ = descriptor;
+ }
+
+ // Returns a descriptive name for the feature function. The name is taken from
+ // the descriptor for the feature function. If the name is empty or the
+ // feature function is a variable the name is the FEL representation of the
+ // feature, including the prefix.
+ std::string name() const;
+
+ // Returns the argument from the feature function descriptor. It defaults to
+ // 0 if the argument has not been specified.
+ int argument() const {
+ return descriptor_->has_argument() ? descriptor_->argument() : 0;
+ }
+
+ // Returns/sets/clears function name prefix.
+ const std::string &prefix() const { return prefix_; }
+ void set_prefix(const std::string &prefix) { prefix_ = prefix; }
+
+ protected:
+ // Returns the feature type for single-type feature functions.
+ FeatureType *feature_type() const { return feature_type_; }
+
+ // Sets the feature type for single-type feature functions. This takes
+ // ownership of feature_type. Can only be called once.
+ void set_feature_type(FeatureType *feature_type) {
+ SAFTM_CHECK_EQ(feature_type_, nullptr);
+ feature_type_ = feature_type;
+ }
+
+ private:
+ // Feature extractor this feature function belongs to. Not owned. Set to a
+ // pointer != nullptr as soon as this object is created by Instantiate().
+ // Normal methods can safely assume this is != nullptr.
+ const GenericFeatureExtractor *extractor_ = nullptr;
+
+ // Descriptor for feature function. Not owned. Set to a pointer != nullptr
+ // as soon as this object is created by Instantiate(). Normal methods can
+ // safely assume this is != nullptr.
+ const FeatureFunctionDescriptor *descriptor_ = nullptr;
+
+ // Feature type for features produced by this feature function. If the
+ // feature function produces features of multiple feature types this is null
+  // and the feature function must return its feature types in
+ // GetFeatureTypes(). Owned.
+ FeatureType *feature_type_ = nullptr;
+
+ // Prefix used for sub-feature types of this function.
+ std::string prefix_;
+};
+
+// Feature function that can extract features from an object. Templated on
+// two type arguments:
+//
+// OBJ: The "object" from which features are extracted; e.g., a sentence. This
+// should be a plain type, rather than a reference or pointer.
+//
+// ARGS: A set of 0 or more types that are used to "index" into some part of the
+// object that should be extracted, e.g. an int token index for a sentence
+// object. This should not be a reference type.
+template <class OBJ, class... ARGS>
+class FeatureFunction
+ : public GenericFeatureFunction,
+ public RegisterableClass<FeatureFunction<OBJ, ARGS...> > {
+ public:
+ using Self = FeatureFunction<OBJ, ARGS...>;
+
+ // Preprocesses the object. This will be called prior to calling Evaluate()
+ // or Compute() on that object.
+ virtual void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const {}
+
+ // Appends features computed from the object and focus to the result. The
+ // default implementation delegates to Compute(), adding a single value if
+ // available. Multi-valued feature functions must override this method.
+ virtual void Evaluate(const WorkspaceSet &workspaces, const OBJ &object,
+ ARGS... args, FeatureVector *result) const {
+ FeatureValue value = Compute(workspaces, object, args...);
+ if (value != kNone) result->add(feature_type(), value);
+ }
+
+ // Returns a feature value computed from the object and focus, or kNone if no
+ // value is computed. Single-valued feature functions only need to override
+ // this method.
+ virtual FeatureValue Compute(const WorkspaceSet &workspaces,
+ const OBJ &object, ARGS... args) const {
+ return kNone;
+ }
+
+ // Instantiates a new feature function in a feature extractor from a feature
+ // descriptor.
+ //
+ // Returns a pointer to the newly-created object if everything goes well.
+ // Returns nullptr if the feature function could not be instantiated (e.g., if
+ // the function with that name is not registered; this usually happens because
+ // the relevant cc_library was not linked-in).
+ static Self *Instantiate(const GenericFeatureExtractor *extractor,
+ const FeatureFunctionDescriptor *fd,
+ const std::string &prefix) {
+ Self *f = Self::Create(fd->type());
+ if (f != nullptr) {
+ f->set_extractor(extractor);
+ f->set_descriptor(fd);
+ f->set_prefix(prefix);
+ }
+ return f;
+ }
+
+ private:
+ // Special feature function class for resolving variable references. The type
+ // of the feature function is used for resolving the variable reference. When
+ // evaluated it will either get the feature value(s) from the variable portion
+ // of the feature vector, if present, or otherwise it will call the referenced
+ // feature extractor function directly to extract the feature(s).
+ class Reference;
+};
+
+// Base class for features with nested feature functions. The nested functions
+// are of type NES, which may be different from the type of the parent function.
+// NB: NestedFeatureFunction will ensure that all initialization of nested
+// functions takes place during Setup() and Init() -- after the nested features
+// are initialized, the parent feature is initialized via SetupNested() and
+// InitNested(). Alternatively, a derived classes that overrides Setup() and
+// Init() directly should call Parent::Setup(), Parent::Init(), etc. first.
+//
+// Note: NestedFeatureFunction cannot know how to call Preprocess, Evaluate, or
+// Compute, since the nested functions may be of a different type.
+template <class NES, class OBJ, class... ARGS>
+class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> {
+ public:
+ using Parent = NestedFeatureFunction<NES, OBJ, ARGS...>;
+
+ // Clean up nested functions.
+ ~NestedFeatureFunction() override { utils::STLDeleteElements(&nested_); }
+
+ // By default, just appends the nested feature types.
+ void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
+ SAFTM_CHECK(!this->nested().empty())
+ << "Nested features require nested features to be defined.";
+ for (auto *function : nested_) function->GetFeatureTypes(types);
+ }
+
+ // Sets up the nested features.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override {
+ bool success = CreateNested(this->extractor(), this->descriptor(), &nested_,
+ this->SubPrefix());
+ if (!success) return false;
+ for (auto *function : nested_) {
+ if (!function->Setup(context)) return false;
+ }
+ if (!SetupNested(context)) return false;
+ return true;
+ }
+
+ // Sets up this NestedFeatureFunction specifically.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT virtual bool SetupNested(TaskContext *context) {
+ return true;
+ }
+
+ // Initializes the nested features.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override {
+ for (auto *function : nested_) {
+ if (!function->Init(context)) return false;
+ }
+ if (!InitNested(context)) return false;
+ return true;
+ }
+
+ // Initializes this NestedFeatureFunction specifically.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT virtual bool InitNested(TaskContext *context) {
+ return true;
+ }
+
+ // Gets all the workspaces needed for the nested functions.
+ void RequestWorkspaces(WorkspaceRegistry *registry) override {
+ for (auto *function : nested_) function->RequestWorkspaces(registry);
+ }
+
+ // Returns the list of nested feature functions.
+ const std::vector<NES *> &nested() const { return nested_; }
+
+ // Instantiates nested feature functions for a feature function. Creates and
+ // initializes one feature function for each sub-descriptor in the feature
+ // descriptor.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT static bool CreateNested(
+ const GenericFeatureExtractor *extractor,
+ const FeatureFunctionDescriptor *fd, std::vector<NES *> *functions,
+ const std::string &prefix) {
+ for (int i = 0; i < fd->feature_size(); ++i) {
+ const FeatureFunctionDescriptor &sub = fd->feature(i);
+ NES *f = NES::Instantiate(extractor, &sub, prefix);
+ if (f == nullptr) return false;
+ functions->push_back(f);
+ }
+ return true;
+ }
+
+ protected:
+ // The nested feature functions, if any, in order of declaration in the
+ // feature descriptor. Owned.
+ std::vector<NES *> nested_;
+};
+
+// Base class for a nested feature function that takes nested features with the
+// same signature as these features, i.e. a meta feature. For this class, we can
+// provide preprocessing of the nested features.
+template <class OBJ, class... ARGS>
+class MetaFeatureFunction
+ : public NestedFeatureFunction<FeatureFunction<OBJ, ARGS...>, OBJ,
+ ARGS...> {
+ public:
+ // Preprocesses using the nested features.
+ void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const override {
+ for (auto *function : this->nested_) {
+ function->Preprocess(workspaces, object);
+ }
+ }
+};
+
+// Template for a special type of locator: The locator of type
+// FeatureFunction<OBJ, ARGS...> calls nested functions of type
+// FeatureFunction<OBJ, IDX, ARGS...>, where the derived class DER is
+// responsible for translating by providing the following:
+//
+// // Gets the new additional focus.
+// IDX GetFocus(const WorkspaceSet &workspaces, const OBJ &object);
+//
+// This is useful to e.g. add a token focus to a parser state based on some
+// desired property of that state.
+template <class DER, class OBJ, class IDX, class... ARGS>
+class FeatureAddFocusLocator
+ : public NestedFeatureFunction<FeatureFunction<OBJ, IDX, ARGS...>, OBJ,
+ ARGS...> {
+ public:
+ void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const override {
+ for (auto *function : this->nested_) {
+ function->Preprocess(workspaces, object);
+ }
+ }
+
+ void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
+ FeatureVector *result) const override {
+ IDX focus =
+ static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
+ for (auto *function : this->nested()) {
+ function->Evaluate(workspaces, object, focus, args..., result);
+ }
+ }
+
+ // Returns the first nested feature's computed value.
+ FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
+ ARGS... args) const override {
+ IDX focus =
+ static_cast<const DER *>(this)->GetFocus(workspaces, object, args...);
+ return this->nested()[0]->Compute(workspaces, object, focus, args...);
+ }
+};
+
+// CRTP feature locator class. This is a meta feature that modifies ARGS and
+// then calls the nested feature functions with the modified ARGS. Note that in
+// order for this template to work correctly, all of ARGS must be types for
+// which the reference operator & can be interpreted as a pointer to the
+// argument. The derived class DER must implement the UpdateFocus method which
+// takes pointers to the ARGS arguments:
+//
+// // Updates the current arguments.
+// void UpdateArgs(const OBJ &object, ARGS *...args) const;
+template <class DER, class OBJ, class... ARGS>
+class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> {
+ public:
+ // Feature locators have an additional check that there is no intrinsic type.
+ void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
+ SAFTM_CHECK_EQ(this->feature_type(), nullptr)
+ << "FeatureLocators should not have an intrinsic type.";
+ MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types);
+ }
+
+ // Evaluates the locator.
+ void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args,
+ FeatureVector *result) const override {
+ static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
+ for (auto *function : this->nested()) {
+ function->Evaluate(workspaces, object, args..., result);
+ }
+ }
+
+ // Returns the first nested feature's computed value.
+ FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
+ ARGS... args) const override {
+ static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
+ return this->nested()[0]->Compute(workspaces, object, args...);
+ }
+};
+
+// Feature extractor for extracting features from objects of a certain class.
+// Template type parameters are as defined for FeatureFunction.
+template <class OBJ, class... ARGS>
+class FeatureExtractor : public GenericFeatureExtractor {
+ public:
+ // Feature function type for top-level functions in the feature extractor.
+ typedef FeatureFunction<OBJ, ARGS...> Function;
+ typedef FeatureExtractor<OBJ, ARGS...> Self;
+
+ // Feature locator type for the feature extractor.
+ template <class DER>
+ using Locator = FeatureLocator<DER, OBJ, ARGS...>;
+
+ // Initializes feature extractor.
+ FeatureExtractor() {}
+
+ ~FeatureExtractor() override { utils::STLDeleteElements(&functions_); }
+
+ // Sets up the feature extractor. Note that only top-level functions exist
+ // until Setup() is called. This does not take ownership over the context,
+ // which must outlive this.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) {
+ for (Function *function : functions_) {
+ if (!function->Setup(context)) return false;
+ }
+ return true;
+ }
+
+ // Initializes the feature extractor. Must be called after Setup(). This
+ // does not take ownership over the context, which must outlive this.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) {
+ for (Function *function : functions_) {
+ if (!function->Init(context)) return false;
+ }
+ if (!this->InitializeFeatureTypes()) return false;
+ return true;
+ }
+
+ // Requests workspaces from the registry. Must be called after Init(), and
+ // before Preprocess(). Does not take ownership over registry. This should be
+ // the same registry used to initialize the WorkspaceSet used in Preprocess()
+ // and ExtractFeatures(). NB: This is a different ordering from that used in
+ // SentenceFeatureRepresentation style feature computation.
+ void RequestWorkspaces(WorkspaceRegistry *registry) {
+ for (auto *function : functions_) function->RequestWorkspaces(registry);
+ }
+
+ // Preprocesses the object using feature functions for the phase. Must be
+ // called before any calls to ExtractFeatures() on that object and phase.
+ void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const {
+ for (Function *function : functions_) {
+ function->Preprocess(workspaces, object);
+ }
+ }
+
+ // Extracts features from an object with a focus. This invokes all the
+ // top-level feature functions in the feature extractor. Only feature
+ // functions belonging to the specified phase are invoked.
+ void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &object,
+ ARGS... args, FeatureVector *result) const {
+ result->reserve(this->feature_types());
+
+ // Extract features.
+ for (int i = 0; i < functions_.size(); ++i) {
+ functions_[i]->Evaluate(workspaces, object, args..., result);
+ }
+ }
+
+ private:
+ // Creates and initializes all feature functions in the feature extractor.
+ //
+ // Returns true on success, false otherwise.
+ SAFTM_MUST_USE_RESULT bool InitializeFeatureFunctions() override {
+ // Create all top-level feature functions.
+ for (int i = 0; i < descriptor().feature_size(); ++i) {
+ const FeatureFunctionDescriptor &fd = descriptor().feature(i);
+ Function *function = Function::Instantiate(this, &fd, "");
+ if (function == nullptr) return false;
+ functions_.push_back(function);
+ }
+ return true;
+ }
+
+ // Collect all feature types used in the feature extractor.
+ void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
+ for (int i = 0; i < functions_.size(); ++i) {
+ functions_[i]->GetFeatureTypes(types);
+ }
+ }
+
+ // Top-level feature functions (and variables) in the feature extractor.
+ // Owned.
+ std::vector<Function *> functions_;
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_
diff --git a/native/lang_id/common/fel/feature-types.h b/native/lang_id/common/fel/feature-types.h
new file mode 100644
index 0000000..ae422af
--- /dev/null
+++ b/native/lang_id/common/fel/feature-types.h
@@ -0,0 +1,190 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Common feature types for parser components.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_TYPES_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_TYPES_H_
+
+#include <algorithm>
+#include <map>
+#include <string>
+#include <utility>
+
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/str-cat.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// TODO(djweiss) Clean this up as well.
+// Use the same type for feature values as is used for predicates.
+typedef int64 Predicate;
+typedef Predicate FeatureValue;
+
+// Each feature value in a feature vector has a feature type. The feature type
+// is used for converting feature type and value pairs to predicate values. The
+// feature type can also return names for feature values and calculate the size
+// of the feature value domain. The FeatureType class is abstract and must be
+// specialized for the concrete feature types.
+class FeatureType {
+ public:
+ // Initializes a feature type.
+ explicit FeatureType(const std::string &name)
+ : name_(name),
+ base_(0),
+ is_continuous_(name.find("continuous") != std::string::npos) {}
+
+ virtual ~FeatureType() {}
+
+ // Converts a feature value to a name.
+ virtual std::string GetFeatureValueName(FeatureValue value) const = 0;
+
+ // Returns the size of the feature values domain.
+ virtual int64 GetDomainSize() const = 0;
+
+ // Returns the feature type name.
+ const std::string &name() const { return name_; }
+
+ Predicate base() const { return base_; }
+ void set_base(Predicate base) { base_ = base; }
+
+ // Returns true iff this feature is continuous; see FloatFeatureValue.
+ bool is_continuous() const { return is_continuous_; }
+
+ private:
+ // Feature type name.
+ std::string name_;
+
+ // "Base" feature value: i.e. a "slot" in a global ordering of features.
+ Predicate base_;
+
+ // See doc for is_continuous().
+ bool is_continuous_;
+};
+
+// Feature type that is defined using an explicit map from FeatureValue to
+// string values. This can reduce some of the boilerplate when defining
+// features that generate enum values. Example usage:
+//
+// class BeverageSizeFeature : public FeatureFunction<Beverage>
+// enum FeatureValue { SMALL, MEDIUM, LARGE }; // values for this feature
+// void Init(TaskContext *context) override {
+// set_feature_type(new EnumFeatureType("beverage_size",
+// {{SMALL, "SMALL"}, {MEDIUM, "MEDIUM"}, {LARGE, "LARGE"}});
+// }
+// [...]
+// };
+class EnumFeatureType : public FeatureType {
+ public:
+ EnumFeatureType(const std::string &name,
+ const std::map<FeatureValue, std::string> &value_names)
+ : FeatureType(name), value_names_(value_names) {
+ for (const auto &pair : value_names) {
+ SAFTM_CHECK_GE(pair.first, 0)
+ << "Invalid feature value: " << pair.first << ", " << pair.second;
+ domain_size_ = std::max(domain_size_, pair.first + 1);
+ }
+ }
+
+ // Returns the feature name for a given feature value.
+ std::string GetFeatureValueName(FeatureValue value) const override {
+ auto it = value_names_.find(value);
+ if (it == value_names_.end()) {
+ SAFTM_LOG(ERROR) << "Invalid feature value " << value << " for "
+ << name();
+ return "<INVALID>";
+ }
+ return it->second;
+ }
+
+ // Returns the number of possible values for this feature type. This is one
+ // greater than the largest value in the value_names map.
+ FeatureValue GetDomainSize() const override { return domain_size_; }
+
+ protected:
+ // Maximum possible value this feature could take.
+ FeatureValue domain_size_ = 0;
+
+ // Names of feature values.
+ std::map<FeatureValue, std::string> value_names_;
+};
+
+// Feature type for binary features.
+class BinaryFeatureType : public FeatureType {
+ public:
+ BinaryFeatureType(const std::string &name, const std::string &off,
+ const std::string &on)
+ : FeatureType(name), off_(off), on_(on) {}
+
+ // Returns the feature name for a given feature value.
+ std::string GetFeatureValueName(FeatureValue value) const override {
+ if (value == 0) return off_;
+ if (value == 1) return on_;
+ return "";
+ }
+
+ // Binary features always have two feature values.
+ FeatureValue GetDomainSize() const override { return 2; }
+
+ private:
+ // Feature value names for on and off.
+ std::string off_;
+ std::string on_;
+};
+
+// Feature type for numeric features.
+class NumericFeatureType : public FeatureType {
+ public:
+ // Initializes numeric feature.
+ NumericFeatureType(const std::string &name, FeatureValue size)
+ : FeatureType(name), size_(size) {}
+
+ // Returns numeric feature value.
+ std::string GetFeatureValueName(FeatureValue value) const override {
+ if (value < 0) return "";
+ return LiteStrCat(value);
+ }
+
+ // Returns the number of feature values.
+ FeatureValue GetDomainSize() const override { return size_; }
+
+ private:
+ // The underlying size of the numeric feature.
+ FeatureValue size_;
+};
+
+// Feature type for byte features, including an "outside" value.
+class ByteFeatureType : public NumericFeatureType {
+ public:
+ explicit ByteFeatureType(const std::string &name)
+ : NumericFeatureType(name, 257) {}
+
+ std::string GetFeatureValueName(FeatureValue value) const override {
+ if (value == 256) {
+ return "<NULL>";
+ }
+ std::string result;
+ result += static_cast<char>(value);
+ return result;
+ }
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_TYPES_H_
diff --git a/native/lang_id/common/fel/fel-parser.cc b/native/lang_id/common/fel/fel-parser.cc
new file mode 100644
index 0000000..2682941
--- /dev/null
+++ b/native/lang_id/common/fel/fel-parser.cc
@@ -0,0 +1,290 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/fel-parser.h"
+
+#include <ctype.h>
+
+#include <string>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/numbers.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
namespace {
// The <ctype.h> classification functions have undefined behavior for argument
// values that are neither representable as unsigned char nor equal to EOF, so
// each helper first widens the (possibly signed) char to unsigned char.

// Returns true iff character c can appear at the start of an identifier.
inline bool IsValidCharAtStartOfIdentifier(char c) {
  const unsigned char uc = static_cast<unsigned char>(c);
  return isalpha(uc) || (c == '_') || (c == '/');
}

// Returns true iff character c can appear inside an identifier.
inline bool IsValidCharInsideIdentifier(char c) {
  const unsigned char uc = static_cast<unsigned char>(c);
  return isalnum(uc) || (c == '_') || (c == '-') || (c == '/');
}

// Returns true iff character c can appear at the beginning of a number.
inline bool IsValidCharAtStartOfNumber(char c) {
  const unsigned char uc = static_cast<unsigned char>(c);
  return isdigit(uc) || (c == '+') || (c == '-');
}

// Returns true iff character c can appear inside a number.
inline bool IsValidCharInsideNumber(char c) {
  const unsigned char uc = static_cast<unsigned char>(c);
  return isdigit(uc) || (c == '.');
}
}  // namespace
+
+bool FELParser::Initialize(const std::string &source) {
+ // Initialize parser state.
+ source_ = source;
+ current_ = source_.begin();
+ item_start_ = line_start_ = current_;
+ line_number_ = item_line_number_ = 1;
+
+ // Read first input item.
+ return NextItem();
+}
+
+void FELParser::ReportError(const std::string &error_message) {
+ const int position = item_start_ - line_start_ + 1;
+ const std::string line(line_start_, current_);
+
+ SAFTM_LOG(ERROR) << "Error in feature model, line " << item_line_number_
+ << ", position " << position << ": " << error_message
+ << "\n " << line << " <--HERE";
+}
+
+void FELParser::Next() {
+ // Move to the next input character. If we are at a line break update line
+ // number and line start position.
+ if (CurrentChar() == '\n') {
+ ++line_number_;
+ ++current_;
+ line_start_ = current_;
+ } else {
+ ++current_;
+ }
+}
+
// Scans past whitespace and '#' comments, then tokenizes the next item into
// item_text_ / item_type_ (NUMBER, STRING, NAME, a single character, or END).
// Returns false only on a malformed item (e.g., an unterminated string).
bool FELParser::NextItem() {
  // Skip white space and comments.
  while (!eos()) {
    if (CurrentChar() == '#') {
      // Skip comment.
      while (!eos() && CurrentChar() != '\n') Next();
    } else if (isspace(CurrentChar())) {
      // Skip whitespace.
      while (!eos() && isspace(CurrentChar())) Next();
    } else {
      break;
    }
  }

  // Record start position for next item.
  item_start_ = current_;
  item_line_number_ = line_number_;

  // Check for end of input.
  if (eos()) {
    item_type_ = END;
    return true;
  }

  // Parse number: a sign or digit, then digits and dots.
  if (IsValidCharAtStartOfNumber(CurrentChar())) {
    std::string::iterator start = current_;
    Next();
    while (!eos() && IsValidCharInsideNumber(CurrentChar())) Next();
    item_text_.assign(start, current_);
    item_type_ = NUMBER;
    return true;
  }

  // Parse string.  The surrounding quotes are not part of item_text_.
  if (CurrentChar() == '"') {
    Next();
    std::string::iterator start = current_;
    while (CurrentChar() != '"') {
      if (eos()) {
        ReportError("Unterminated string");
        return false;
      }
      Next();
    }
    item_text_.assign(start, current_);
    item_type_ = STRING;
    Next();  // Consume the closing quote.
    return true;
  }

  // Parse identifier name.
  if (IsValidCharAtStartOfIdentifier(CurrentChar())) {
    std::string::iterator start = current_;
    while (!eos() && IsValidCharInsideIdentifier(CurrentChar())) {
      Next();
    }
    item_text_.assign(start, current_);
    item_type_ = NAME;
    return true;
  }

  // Single character item: item_type_ holds the character itself (positive).
  item_type_ = CurrentChar();
  Next();
  return true;
}
+
// Parses |source| into |result| as a sequence of top-level feature
// declarations, each starting with a feature type NAME.  Returns false
// (after ReportError) on any syntax error.
bool FELParser::Parse(const std::string &source,
                      FeatureExtractorDescriptor *result) {
  // Initialize parser.
  if (!Initialize(source)) {
    return false;
  }

  while (item_type_ != END) {
    // Current item should be a feature name.
    if (item_type_ != NAME) {
      ReportError("Feature type name expected");
      return false;
    }
    std::string name = item_text_;
    if (!NextItem()) {
      return false;
    }

    if (item_type_ == '=') {
      // Top-level "name = value" assignments are not supported.
      ReportError("Invalid syntax: feature expected");
      return false;
    } else {
      // Parse feature.
      FeatureFunctionDescriptor *descriptor = result->add_feature();
      descriptor->set_type(name);
      if (!ParseFeature(descriptor)) {
        return false;
      }
    }
  }

  return true;
}
+
// Parses one feature after its type name has been consumed: an optional
// '(' parameter list ')', an optional ':' feature name, and then either a
// dotted sub-feature or a '{ ... }' block of sub-features (recursively).
bool FELParser::ParseFeature(FeatureFunctionDescriptor *result) {
  // Parse argument and parameters.
  if (item_type_ == '(') {
    if (!NextItem()) return false;
    if (!ParseParameter(result)) return false;
    while (item_type_ == ',') {
      if (!NextItem()) return false;
      if (!ParseParameter(result)) return false;
    }

    if (item_type_ != ')') {
      ReportError(") expected");
      return false;
    }
    if (!NextItem()) return false;
  }

  // Parse feature name.
  if (item_type_ == ':') {
    if (!NextItem()) return false;
    if (item_type_ != NAME && item_type_ != STRING) {
      ReportError("Feature name expected");
      return false;
    }
    std::string name = item_text_;
    if (!NextItem()) return false;

    // Set feature name.
    result->set_name(name);
  }

  // Parse sub-features.
  if (item_type_ == '.') {
    // Parse dotted sub-feature.
    if (!NextItem()) return false;
    if (item_type_ != NAME) {
      ReportError("Feature type name expected");
      return false;
    }
    std::string type = item_text_;
    if (!NextItem()) return false;

    // Parse sub-feature (recursively).
    FeatureFunctionDescriptor *subfeature = result->add_feature();
    subfeature->set_type(type);
    if (!ParseFeature(subfeature)) return false;
  } else if (item_type_ == '{') {
    // Parse sub-feature block.
    if (!NextItem()) return false;
    while (item_type_ != '}') {
      if (item_type_ != NAME) {
        ReportError("Feature type name expected");
        return false;
      }
      std::string type = item_text_;
      if (!NextItem()) return false;

      // Parse sub-feature (recursively).
      FeatureFunctionDescriptor *subfeature = result->add_feature();
      subfeature->set_type(type);
      if (!ParseFeature(subfeature)) return false;
    }
    // Consume the closing '}'.
    if (!NextItem()) return false;
  }
  return true;
}
+
+bool FELParser::ParseParameter(FeatureFunctionDescriptor *result) {
+ if (item_type_ == NUMBER) {
+ int argument;
+ if (!LiteAtoi(item_text_, &argument)) {
+ ReportError("Unable to parse number");
+ return false;
+ }
+ if (!NextItem()) return false;
+
+ // Set default argument for feature.
+ result->set_argument(argument);
+ } else if (item_type_ == NAME) {
+ std::string name = item_text_;
+ if (!NextItem()) return false;
+ if (item_type_ != '=') {
+ ReportError("= expected");
+ return false;
+ }
+ if (!NextItem()) return false;
+ if (item_type_ >= END) {
+ ReportError("Parameter value expected");
+ return false;
+ }
+ std::string value = item_text_;
+ if (!NextItem()) return false;
+
+ // Add parameter to feature.
+ Parameter *parameter;
+ parameter = result->add_parameter();
+ parameter->set_name(name);
+ parameter->set_value(value);
+ } else {
+ ReportError("Syntax error in parameter list");
+ return false;
+ }
+ return true;
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/native/lang_id/common/fel/fel-parser.h b/native/lang_id/common/fel/fel-parser.h
new file mode 100644
index 0000000..d2c454c
--- /dev/null
+++ b/native/lang_id/common/fel/fel-parser.h
@@ -0,0 +1,135 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Feature extraction language (FEL) parser.
+//
+// BNF grammar for FEL:
+//
+// <feature model> ::= { <feature extractor> }
+//
+// <feature extractor> ::= <extractor spec> |
+// <extractor spec> '.' <feature extractor> |
+// <extractor spec> '{' { <feature extractor> } '}'
+//
+// <extractor spec> ::= <extractor type>
+// [ '(' <parameter list> ')' ]
+// [ ':' <extractor name> ]
+//
+// <parameter list> = ( <parameter> | <argument> ) { ',' <parameter> }
+//
+// <parameter> ::= <parameter name> '=' <parameter value>
+//
+// <extractor type> ::= NAME
+// <extractor name> ::= NAME | STRING
+// <argument> ::= NUMBER
+// <parameter name> ::= NAME
+// <parameter value> ::= NUMBER | STRING | NAME
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEL_PARSER_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEL_PARSER_H_
+
+#include <string>
+
+#include "lang_id/common/fel/feature-descriptors.h"
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+class FELParser {
+ public:
+ // Parses fml specification into feature extractor descriptor.
+ // Returns true on success, false on error (e.g., syntax errors).
+ bool Parse(const std::string &source, FeatureExtractorDescriptor *result);
+
+ private:
+ // Initializes the parser with the source text.
+ // Returns true on success, false on syntax error.
+ bool Initialize(const std::string &source);
+
+ // Outputs an error message, with context info.
+ void ReportError(const std::string &error_message);
+
+ // Moves to the next input character.
+ void Next();
+
+ // Moves to the next input item. Sets item_text_ and item_type_ accordingly.
+ // Returns true on success, false on syntax error.
+ bool NextItem();
+
+ // Parses a feature descriptor.
+ // Returns true on success, false on syntax error.
+ bool ParseFeature(FeatureFunctionDescriptor *result);
+
+ // Parses a parameter specification.
+ // Returns true on success, false on syntax error.
+ bool ParseParameter(FeatureFunctionDescriptor *result);
+
+ // Returns true if end of source input has been reached.
+ bool eos() const { return current_ >= source_.end(); }
+
+ // Returns current character. Other methods should access the current
+ // character through this method (instead of using *current_ directly): this
+ // method performs extra safety checks.
+ //
+ // In case of an unsafe access, returns '\0'.
+ char CurrentChar() const {
+ if ((current_ >= source_.begin()) && (current_ < source_.end())) {
+ return *current_;
+ } else {
+ SAFTM_LOG(ERROR) << "Unsafe char read";
+ return '\0';
+ }
+ }
+
+ // Item types.
+ enum ItemTypes {
+ END = 0,
+ NAME = -1,
+ NUMBER = -2,
+ STRING = -3,
+ };
+
+ // Source text.
+ std::string source_;
+
+ // Current input position.
+ std::string::iterator current_;
+
+ // Line number for current input position.
+ int line_number_;
+
+ // Start position for current item.
+ std::string::iterator item_start_;
+
+ // Start position for current line.
+ std::string::iterator line_start_;
+
+ // Line number for current item.
+ int item_line_number_;
+
+ // Item type for current item. If this is positive it is interpreted as a
+ // character. If it is negative it is interpreted as an item type.
+ int item_type_;
+
+ // Text for current item.
+ std::string item_text_;
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEL_PARSER_H_
diff --git a/native/lang_id/common/fel/task-context.cc b/native/lang_id/common/fel/task-context.cc
new file mode 100644
index 0000000..5e1d7f6
--- /dev/null
+++ b/native/lang_id/common/fel/task-context.cc
@@ -0,0 +1,77 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/task-context.h"
+
+#include <string>
+
+#include "lang_id/common/lite_strings/numbers.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+std::string TaskContext::GetInputPath(const std::string &name) const {
+ auto it = inputs_.find(name);
+ if (it != inputs_.end()) {
+ return it->second;
+ }
+ return "";
+}
+
+void TaskContext::SetInputPath(const std::string &name,
+ const std::string &path) {
+ inputs_[name] = path;
+}
+
+std::string TaskContext::Get(const std::string &name,
+ const char *defval) const {
+ auto it = parameters_.find(name);
+ if (it != parameters_.end()) {
+ return it->second;
+ }
+ return defval;
+}
+
+int TaskContext::Get(const std::string &name, int defval) const {
+ const std::string s = Get(name, "");
+ int value = defval;
+ if (LiteAtoi(s, &value)) {
+ return value;
+ }
+ return defval;
+}
+
+float TaskContext::Get(const std::string &name, float defval) const {
+ const std::string s = Get(name, "");
+ float value = defval;
+ if (LiteAtof(s, &value)) {
+ return value;
+ }
+ return defval;
+}
+
+bool TaskContext::Get(const std::string &name, bool defval) const {
+ std::string value = Get(name, "");
+ return value.empty() ? defval : value == "true";
+}
+
+void TaskContext::SetParameter(const std::string &name,
+ const std::string &value) {
+ parameters_[name] = value;
+}
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/native/lang_id/common/fel/task-context.h b/native/lang_id/common/fel/task-context.h
new file mode 100644
index 0000000..b6bcd92
--- /dev/null
+++ b/native/lang_id/common/fel/task-context.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
// Import std::string (and basic_string) into the project namespace, guarded
// so the aliases are declared at most once per translation unit.
#ifndef TC3_STD_STRING_IMPORT
#define TC3_STD_STRING_IMPORT
#include <string>

namespace libtextclassifier3 {
using string = std::string;
template <class CharT, class Traits = std::char_traits<CharT>,
          class Allocator = std::allocator<CharT> >
using basic_string = std::basic_string<CharT, Traits, Allocator>;
}  // namespace libtextclassifier3
#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_TASK_CONTEXT_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_TASK_CONTEXT_H_
+
+#include <map>
+#include <string>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
// Class that provides access to model parameters and inputs.
//
// Note: This class is related to the server-side nlp_saft::TaskContext, but it
// has been simplified to reduce code dependencies.
class TaskContext {
 public:
  // Returns path for the input named |name|. Returns empty string ("") if
  // there is no input with that name. Note: this can be a standard file path,
  // or a path in a more special file system.
  std::string GetInputPath(const std::string &name) const;

  // Sets path for input |name|. Previous path, if any, is overwritten.
  void SetInputPath(const std::string &name, const std::string &path);

  // Returns parameter value. If the parameter is not specified in this
  // context, the default value is returned.  The int/float overloads parse
  // the stored string and fall back to |defval| when parsing fails; the bool
  // overload treats any present value other than "true" as false.
  std::string Get(const std::string &name, const char *defval) const;
  int Get(const std::string &name, int defval) const;
  float Get(const std::string &name, float defval) const;
  bool Get(const std::string &name, bool defval) const;

  // Sets value of parameter |name| to |value|.
  void SetParameter(const std::string &name, const std::string &value);

 private:
  // Maps input name -> path.
  std::map<std::string, std::string> inputs_;

  // Maps parameter name -> value.
  std::map<std::string, std::string> parameters_;
};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_TASK_CONTEXT_H_
diff --git a/native/lang_id/common/fel/workspace.cc b/native/lang_id/common/fel/workspace.cc
new file mode 100644
index 0000000..af41e29
--- /dev/null
+++ b/native/lang_id/common/fel/workspace.cc
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/fel/workspace.h"
+
+#include <atomic>
+#include <string>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
// Returns a new, strictly increasing int every time it is invoked.
int GetFreshTypeId() {
  // Thread-safe: the counter is a function-local std::atomic, initialized
  // the first time this function runs.
  static std::atomic<int> counter(0);
  return counter.fetch_add(1);
}
+
+std::string WorkspaceRegistry::DebugString() const {
+ std::string str;
+ for (auto &it : workspace_names_) {
+ const std::string &type_name = workspace_types_.at(it.first);
+ for (size_t index = 0; index < it.second.size(); ++index) {
+ const std::string &workspace_name = it.second[index];
+ str.append("\n ");
+ str.append(type_name);
+ str.append(" :: ");
+ str.append(workspace_name);
+ }
+ }
+ return str;
+}
+
// Creates a vector of |size| zero-initialized elements.
VectorIntWorkspace::VectorIntWorkspace(int size) : elements_(size) {}

// Creates a vector of |size| elements, each initialized to |value|.
VectorIntWorkspace::VectorIntWorkspace(int size, int value)
    : elements_(size, value) {}

// Creates a vector that copies |elements|.
VectorIntWorkspace::VectorIntWorkspace(const std::vector<int> &elements)
    : elements_(elements) {}

// Human-readable type name used by the workspace registry.
std::string VectorIntWorkspace::TypeName() { return "Vector"; }
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/native/lang_id/common/fel/workspace.h b/native/lang_id/common/fel/workspace.h
new file mode 100644
index 0000000..f13d802
--- /dev/null
+++ b/native/lang_id/common/fel/workspace.h
@@ -0,0 +1,205 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Notes on thread-safety: All of the classes here are thread-compatible. More
+// specifically, the registry machinery is thread-safe, as long as each thread
+// performs feature extraction on a different Sentence object.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_WORKSPACE_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_WORKSPACE_H_
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// A base class for shared workspaces. Derived classes implement a static member
+// function TypeName() which returns a human readable string name for the class.
+class Workspace {
+ public:
+  // Polymorphic destructor.
+  virtual ~Workspace() {}
+
+ protected:
+  // Create an empty workspace.
+  Workspace() {}
+
+ private:
+  // Workspaces are owned (and deleted) by a WorkspaceSet; copying is banned.
+  SAFTM_DISALLOW_COPY_AND_ASSIGN(Workspace);
+};
+
+// Returns a new, strictly increasing int every time it is invoked.
+int GetFreshTypeId();
+
+// Struct to simulate typeid, but without RTTI.
+template <typename T>
+struct TypeId {
+  // Process-unique id for the type T (see definition below).
+  static int type_id;
+};
+
+// One id is minted per distinct T, at dynamic-initialization time.
+// NOTE(review): assumes a single definition of this static is linked in; with
+// multiple shared objects the same T could receive different ids — confirm.
+template <typename T>
+int TypeId<T>::type_id = GetFreshTypeId();
+
+// A registry that keeps track of workspaces.
+class WorkspaceRegistry {
+ public:
+  // Create an empty registry.
+  WorkspaceRegistry() {}
+
+  // Returns the index of a named workspace, adding it to the registry first
+  // if necessary.
+  template <class W>
+  int Request(const std::string &name) {
+    const int id = TypeId<W>::type_id;
+    max_workspace_id_ = std::max(id, max_workspace_id_);
+    workspace_types_[id] = W::TypeName();
+    std::vector<std::string> &names = workspace_names_[id];
+    // If a workspace with this name was already registered under this type,
+    // reuse its index.  (size_t loop avoids the signed/unsigned comparison of
+    // the previous int-indexed version.)
+    for (size_t i = 0; i < names.size(); ++i) {
+      if (names[i] == name) return static_cast<int>(i);
+    }
+    names.push_back(name);
+    // Explicit cast: the index space is int by contract of this API.
+    return static_cast<int>(names.size()) - 1;
+  }
+
+  // Returns the maximum workspace id that has been registered.
+  int MaxId() const {
+    return max_workspace_id_;
+  }
+
+  // Returns the map from type id to the names registered under that type.
+  const std::unordered_map<int, std::vector<std::string> > &WorkspaceNames()
+      const {
+    return workspace_names_;
+  }
+
+  // Returns a string describing the registered workspaces.
+  std::string DebugString() const;
+
+ private:
+  // Workspace type names, indexed as workspace_types_[typeid].
+  std::unordered_map<int, std::string> workspace_types_;
+
+  // Workspace names, indexed as workspace_names_[typeid][workspace].
+  std::unordered_map<int, std::vector<std::string> > workspace_names_;
+
+  // The maximum workspace id that has been registered.
+  int max_workspace_id_ = 0;
+
+  SAFTM_DISALLOW_COPY_AND_ASSIGN(WorkspaceRegistry);
+};
+
+// A typed collected of workspaces. The workspaces are indexed according to an
+// external WorkspaceRegistry. If the WorkspaceSet is const, the contents are
+// also immutable.
+class WorkspaceSet {
+ public:
+  // Deallocates all owned workspaces (Reset against an empty registry).
+  ~WorkspaceSet() { Reset(WorkspaceRegistry()); }
+
+  // Returns true if a workspace has been set.  Unlike the previous version,
+  // which only DCHECK-ed the bounds (a no-op in release builds), this returns
+  // false instead of indexing out of bounds for unregistered ids / indices.
+  template <class W>
+  bool Has(int index) const {
+    const int id = TypeId<W>::type_id;
+    SAFTM_DCHECK_GE(id, 0);
+    SAFTM_DCHECK_LT(id, workspaces_.size());
+    SAFTM_DCHECK_GE(index, 0);
+    if (id < 0 || static_cast<size_t>(id) >= workspaces_.size()) return false;
+    SAFTM_DCHECK_LT(index, workspaces_[id].size());
+    if (index < 0 || static_cast<size_t>(index) >= workspaces_[id].size()) {
+      return false;
+    }
+    return workspaces_[id][index] != nullptr;
+  }
+
+  // Returns an indexed workspace; the workspace must have been set.
+  template <class W>
+  const W &Get(int index) const {
+    SAFTM_DCHECK(Has<W>(index));
+    const int id = TypeId<W>::type_id;
+    // The cast is sound because Set<W>() stored a W* under this type id.
+    const Workspace *w = workspaces_[id][index];
+    return reinterpret_cast<const W &>(*w);
+  }
+
+  // Sets an indexed workspace; this takes ownership of the workspace, which
+  // must have been new-allocated. It is an error to set a workspace twice.
+  // Precondition: (id, index) was registered via the registry last passed to
+  // Reset(); violations are caught only in debug builds.
+  template <class W>
+  void Set(int index, W *workspace) {
+    const int id = TypeId<W>::type_id;
+    SAFTM_DCHECK_GE(id, 0);
+    SAFTM_DCHECK_LT(id, workspaces_.size());
+    SAFTM_DCHECK_GE(index, 0);
+    SAFTM_DCHECK_LT(index, workspaces_[id].size());
+    SAFTM_DCHECK(workspaces_[id][index] == nullptr);
+    SAFTM_DCHECK(workspace != nullptr);
+    workspaces_[id][index] = workspace;
+  }
+
+  // Deallocates all owned workspaces and re-dimensions this set to match the
+  // workspaces currently registered in |registry|.
+  void Reset(const WorkspaceRegistry &registry) {
+    // Deallocate current workspaces.
+    for (auto &it : workspaces_) {
+      for (size_t index = 0; index < it.size(); ++index) {
+        delete it[index];
+      }
+    }
+    workspaces_.clear();
+    workspaces_.resize(registry.MaxId() + 1, std::vector<Workspace *>());
+    for (auto &it : registry.WorkspaceNames()) {
+      workspaces_[it.first].resize(it.second.size());
+    }
+  }
+
+ private:
+  // The set of workspaces, indexed as workspaces_[typeid][index].
+  std::vector<std::vector<Workspace *> > workspaces_;
+};
+
+// A workspace that wraps around a vector of int.
+class VectorIntWorkspace : public Workspace {
+ public:
+  // Creates a vector of the given size.
+  explicit VectorIntWorkspace(int size);
+
+  // Creates a vector initialized with the given array.
+  explicit VectorIntWorkspace(const std::vector<int> &elements);
+
+  // Creates a vector of the given size, with each element initialized to the
+  // given value.
+  VectorIntWorkspace(int size, int value);
+
+  // Returns the name of this type of workspace.
+  static std::string TypeName();
+
+  // Returns the i'th element.  Note: no bounds checking; the caller must
+  // ensure i is in [0, size()).
+  int element(int i) const { return elements_[i]; }
+
+  // Sets the i'th element.  Same unchecked-bounds caveat as element().
+  void set_element(int i, int value) { elements_[i] = value; }
+
+  // Returns the size of the underlying vector.
+  int size() const { return elements_.size(); }
+
+ private:
+  // The enclosed vector.
+  std::vector<int> elements_;
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_WORKSPACE_H_
diff --git a/native/lang_id/common/file/file-utils.cc b/native/lang_id/common/file/file-utils.cc
new file mode 100644
index 0000000..1ee229f
--- /dev/null
+++ b/native/lang_id/common/file/file-utils.cc
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/file/file-utils.h"
+
+#include <fcntl.h>
+#include <stdio.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+
+#include <string>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace file_utils {
+
+// Reads the entire file into |content| via a scoped mmap.  Returns false (and
+// logs) if the file cannot be opened or mapped.
+bool GetFileContent(const std::string &filename, std::string *content) {
+  ScopedMmap scoped_mmap(filename);
+  const MmapHandle &handle = scoped_mmap.handle();
+  if (!handle.ok()) {
+    SAFTM_LOG(ERROR) << "Error opening " << filename;
+    return false;
+  }
+  // Copy the mapped bytes out before ScopedMmap unmaps them at scope exit.
+  StringPiece sp = handle.to_stringpiece();
+  content->assign(sp.data(), sp.size());
+  return true;
+}
+
+// Returns true if |filename| names an existing *regular* file.
+bool FileExists(const std::string &filename) {
+  struct stat s = {0};
+  if (stat(filename.c_str(), &s) != 0) {
+    // stat failed, e.g. the path does not exist or is not accessible.
+    return false;
+  }
+  // Use S_ISREG, which compares the whole file-type field (S_IFMT).  The
+  // previous bitwise test (s.st_mode & S_IFREG) also matched sockets, whose
+  // type code S_IFSOCK (0140000) contains all bits of S_IFREG (0100000).
+  return S_ISREG(s.st_mode);
+}
+
+// Returns true if |dirpath| names an existing directory.
+bool DirectoryExists(const std::string &dirpath) {
+  struct stat s = {0};
+  if (stat(dirpath.c_str(), &s) != 0) {
+    // stat failed, e.g. the path does not exist or is not accessible.
+    return false;
+  }
+  // Use S_ISDIR, which compares the whole file-type field (S_IFMT).  The
+  // previous bitwise test (s.st_mode & S_IFDIR) also matched block devices,
+  // whose type code S_IFBLK (0060000) contains all bits of S_IFDIR (0040000).
+  return S_ISDIR(s.st_mode);
+}
+
+} // namespace file_utils
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/native/lang_id/common/file/file-utils.h b/native/lang_id/common/file/file-utils.h
new file mode 100644
index 0000000..e8b0fef
--- /dev/null
+++ b/native/lang_id/common/file/file-utils.h
@@ -0,0 +1,84 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_FILE_UTILS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_FILE_UTILS_H_
+
+#include <stddef.h>
+
+#include <string>
+
+#include "lang_id/common/file/mmap.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace file_utils {
+
+// Reads the entire content of a file into a string. Returns true on success,
+// false on error.
+bool GetFileContent(const std::string &filename, std::string *content);
+
+// Parses a proto from its serialized representation in memory. That
+// representation starts at address |data| and should contain exactly
+// |num_bytes| bytes. Returns true on success, false otherwise.
+template <class Proto>
+bool ParseProtoFromMemory(const char *data, size_t num_bytes, Proto *proto) {
+  // ParseFromArray must never receive a nullptr, so guard against it first:
+  // a null buffer simply means the parse fails.
+  return (data != nullptr) && proto->ParseFromArray(data, num_bytes);
+}
+
+// Convenience StringPiece-based version of ParseProtoFromMemory.  The
+// StringPiece is a non-owning view; it only needs to remain valid for the
+// duration of this call.
+template <class Proto>
+inline bool ParseProtoFromMemory(StringPiece sp, Proto *proto) {
+  return ParseProtoFromMemory(sp.data(), sp.size(), proto);
+}
+
+// Parses a proto from a file. Returns true on success, false otherwise.
+//
+// Note: the entire content of the file should be the binary (not
+// human-readable) serialization of a protocol buffer.
+//
+// Note: when we compile for Android, the proto parsing methods need to know the
+// type of the message they are parsing. We use template polymorphism for that.
+template <class Proto>
+bool ReadProtoFromFile(const std::string &filename, Proto *proto) {
+  ScopedMmap scoped_mmap(filename);
+  const MmapHandle &handle = scoped_mmap.handle();
+  if (!handle.ok()) {
+    // Map failed; MmapFile already logged the underlying error.
+    return false;
+  }
+  // The parse copies what it needs; the mapping is released at scope exit.
+  return ParseProtoFromMemory(handle.to_stringpiece(), proto);
+}
+
+// Returns true if filename is the name of an existing file, and false
+// otherwise.
+bool FileExists(const std::string &filename);
+
+// Returns true if dirpath is the path to an existing directory, and false
+// otherwise.
+bool DirectoryExists(const std::string &dirpath);
+
+} // namespace file_utils
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_FILE_UTILS_H_
diff --git a/native/lang_id/common/file/mmap.cc b/native/lang_id/common/file/mmap.cc
new file mode 100644
index 0000000..3dcdd3b
--- /dev/null
+++ b/native/lang_id/common/file/mmap.cc
@@ -0,0 +1,246 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/file/mmap.h"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <stdint.h>
+#include <string.h>
+#ifdef _WIN32
+#include <winbase.h>
+#include <windows.h>
+#else
+#include <sys/mman.h>
+#include <unistd.h>
+#endif
+#include <sys/stat.h>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace {
+// Canonical error value: null start, zero bytes, so handle.ok() is false.
+inline MmapHandle GetErrorMmapHandle() { return MmapHandle(nullptr, 0); }
+}  // anonymous namespace
+
+#ifdef _WIN32
+
+namespace {
+// Returns a human-readable message for the calling thread's GetLastError().
+// NOTE(review): FormatMessage/LPTSTR are TCHAR-based; building a std::string
+// from the buffer assumes a non-UNICODE (ANSI) build — confirm.
+inline std::string GetLastSystemError() {
+  LPTSTR message_buffer;
+  DWORD error_code = GetLastError();
+  FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
+                FORMAT_MESSAGE_IGNORE_INSERTS,
+                NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
+                (LPTSTR)&message_buffer, 0, NULL);
+  // FORMAT_MESSAGE_ALLOCATE_BUFFER allocated |message_buffer|; copy it into
+  // the result and free it with LocalFree.
+  std::string result(message_buffer);
+  LocalFree(message_buffer);
+  return result;
+}
+
+// Class for automatically closing a Win32 HANDLE on exit from a scope.
+// Failures to close are logged but otherwise ignored.
+class Win32HandleCloser {
+ public:
+  explicit Win32HandleCloser(HANDLE handle) : handle_(handle) {}
+  ~Win32HandleCloser() {
+    bool result = CloseHandle(handle_);
+    if (!result) {
+      const DWORD last_error = GetLastError();
+      SAFTM_LOG(ERROR) << "Error closing handle: " << last_error << ": "
+                       << GetLastSystemError();
+    }
+  }
+
+ private:
+  const HANDLE handle_;
+
+  SAFTM_DISALLOW_COPY_AND_ASSIGN(Win32HandleCloser);
+};
+}  // namespace
+
+// Opens |filename| for reading and maps it via the HANDLE overload below.
+MmapHandle MmapFile(const std::string &filename) {
+  HANDLE handle =
+      CreateFile(filename.c_str(),  // File to open.
+                 GENERIC_READ,      // Open for reading.
+                 FILE_SHARE_READ,   // Share for reading.
+                 NULL,              // Default security.
+                 OPEN_EXISTING,     // Existing file only.
+                 // NOTE(review): FILE_FLAG_OVERLAPPED enables async I/O and
+                 // looks unnecessary for a file that is only mapped — confirm.
+                 FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED,  // Normal file.
+                 NULL);             // No attr. template.
+  if (handle == INVALID_HANDLE_VALUE) {
+    const std::string last_error = GetLastSystemError();
+    SAFTM_LOG(ERROR) << "Error opening " << filename << ": " << last_error;
+    return GetErrorMmapHandle();
+  }
+
+  // Make sure we close handle no matter how we exit this function.  Per the
+  // Win32 docs, the view created inside MmapFile(HANDLE) stays valid after
+  // the file handle is closed.
+  Win32HandleCloser handle_closer(handle);
+
+  return MmapFile(handle);
+}
+
+// Maps the whole file behind |file_handle| read-only into memory.
+MmapHandle MmapFile(HANDLE file_handle) {
+  // Get the file size.  INVALID_FILE_SIZE is a legal low dword for very large
+  // files, so the failure case must also consult GetLastError().
+  DWORD file_size_high = 0;
+  DWORD file_size_low = GetFileSize(file_handle, &file_size_high);
+  if (file_size_low == INVALID_FILE_SIZE && GetLastError() != NO_ERROR) {
+    const std::string last_error = GetLastSystemError();
+    SAFTM_LOG(ERROR) << "Unable to stat fd: " << last_error;
+    return GetErrorMmapHandle();
+  }
+  // Combine the two 32-bit halves into one byte count.
+  // NOTE(review): the << 32 assumes a 64-bit size_t; on a 32-bit build this
+  // shift is invalid — confirm target platforms.
+  size_t file_size_in_bytes = (static_cast<size_t>(file_size_high) << 32) +
+                              static_cast<size_t>(file_size_low);
+
+  // Create a file mapping object that refers to the file.
+  HANDLE file_mapping_object =
+      CreateFileMappingA(file_handle, nullptr, PAGE_READONLY, 0, 0, nullptr);
+  if (file_mapping_object == NULL) {
+    const std::string last_error = GetLastSystemError();
+    SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error;
+    return GetErrorMmapHandle();
+  }
+  // The mapping-object handle can be closed once the view exists.
+  Win32HandleCloser handle_closer(file_mapping_object);
+
+  // Map the file mapping object into memory.
+  void *mmap_addr =
+      MapViewOfFile(file_mapping_object, FILE_MAP_READ, 0, 0,  // File offset.
+                    0  // Number of bytes to map; 0 means map the whole file.
+      );
+  if (mmap_addr == nullptr) {
+    const std::string last_error = GetLastSystemError();
+    SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error;
+    return GetErrorMmapHandle();
+  }
+
+  return MmapHandle(mmap_addr, file_size_in_bytes);
+}
+
+// Releases a view created by MmapFile.  A not-ok() handle is a successful
+// no-op.
+bool Unmap(MmapHandle mmap_handle) {
+  if (!mmap_handle.ok()) {
+    // Unmapping something that hasn't been mapped is trivially successful.
+    return true;
+  }
+  bool succeeded = UnmapViewOfFile(mmap_handle.start());
+  if (!succeeded) {
+    const std::string last_error = GetLastSystemError();
+    SAFTM_LOG(ERROR) << "Error during Unmap / UnmapViewOfFile: " << last_error;
+    return false;
+  }
+  return true;
+}
+
+#else
+
+namespace {
+// Returns strerror(errno) for the most recent failed system call.
+inline std::string GetLastSystemError() { return std::string(strerror(errno)); }
+
+// Closes a file descriptor on exit from a scope.  Close failures are logged
+// but otherwise ignored.
+class FileCloser {
+ public:
+  explicit FileCloser(int fd) : fd_(fd) {}
+  ~FileCloser() {
+    int result = close(fd_);
+    if (result != 0) {
+      const std::string last_error = GetLastSystemError();
+      SAFTM_LOG(ERROR) << "Error closing file descriptor: " << last_error;
+    }
+  }
+ private:
+  const int fd_;
+
+  SAFTM_DISALLOW_COPY_AND_ASSIGN(FileCloser);
+};
+}  // namespace
+
+// Opens |filename| read-only and maps it via the fd-based overload below.
+MmapHandle MmapFile(const std::string &filename) {
+  int fd = open(filename.c_str(), O_RDONLY);
+
+  if (fd < 0) {
+    const std::string last_error = GetLastSystemError();
+    SAFTM_LOG(ERROR) << "Error opening " << filename << ": " << last_error;
+    return GetErrorMmapHandle();
+  }
+
+  // Make sure we close fd no matter how we exit this function. As the man page
+  // for mmap clearly states: "closing the file descriptor does not unmap the
+  // region." Hence, we can close fd as soon as we return from here.
+  FileCloser file_closer(fd);
+
+  return MmapFile(fd);
+}
+
+// Maps the whole file behind |fd| into memory as a private (copy-on-write)
+// region.  NOTE(review): a zero-length file yields file_size_in_bytes == 0,
+// which makes mmap fail with EINVAL — confirm callers never map empty files.
+MmapHandle MmapFile(int fd) {
+  // Get file stats to obtain file size.
+  struct stat sb;
+  if (fstat(fd, &sb) != 0) {
+    const std::string last_error = GetLastSystemError();
+    SAFTM_LOG(ERROR) << "Unable to stat fd: " << last_error;
+    return GetErrorMmapHandle();
+  }
+  size_t file_size_in_bytes = static_cast<size_t>(sb.st_size);
+
+  // Perform actual mmap.
+  void *mmap_addr = mmap(
+
+      // Let system pick address for mmapp-ed data.
+      nullptr,
+
+      // Mmap all bytes from the file.
+      file_size_in_bytes,
+
+      // One can read / write the mapped data (but see MAP_PRIVATE below).
+      // Normally, we expect only to read it, but in the future, we may want to
+      // write it, to fix e.g., endianness differences.
+      PROT_READ | PROT_WRITE,
+
+      // Updates to mmaped data are *not* propagated to actual file.
+      // AFAIK(salcianu) that's anyway not possible on Android.
+      MAP_PRIVATE,
+
+      // Descriptor of file to mmap.
+      fd,
+
+      // Map bytes right from the beginning of the file. This, and
+      // file_size_in_bytes (2nd argument) means we map all bytes from the file.
+      0);
+  if (mmap_addr == MAP_FAILED) {
+    const std::string last_error = GetLastSystemError();
+    SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error;
+    return GetErrorMmapHandle();
+  }
+
+  return MmapHandle(mmap_addr, file_size_in_bytes);
+}
+
+// Releases a mapping created by MmapFile().  A not-ok() handle is a no-op
+// that counts as success.
+bool Unmap(MmapHandle mmap_handle) {
+  if (!mmap_handle.ok()) {
+    // Nothing was mapped, so there is nothing to undo.
+    return true;
+  }
+  const int result = munmap(mmap_handle.start(), mmap_handle.num_bytes());
+  if (result == 0) return true;
+  const std::string last_error = GetLastSystemError();
+  SAFTM_LOG(ERROR) << "Error during Unmap / munmap: " << last_error;
+  return false;
+}
+
+#endif
+
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/native/lang_id/common/file/mmap.h b/native/lang_id/common/file/mmap.h
new file mode 100644
index 0000000..f785465
--- /dev/null
+++ b/native/lang_id/common/file/mmap.h
@@ -0,0 +1,130 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_MMAP_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_MMAP_H_
+
+#include <stddef.h>
+
+#include <string>
+
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+#ifdef _WIN32
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#endif
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Handle for a memory area where a file has been mmapped.
+//
+// Similar to a pointer: you "allocate" it using MmapFile(filename) and "delete"
+// it using Unmap(). Just like a pointer, it is passed around by value (see
+// signature of MmapFile and Unmap; fortunately, it's a small class, so there
+// shouldn't be any significant performance penalty) and its usage is not
+// necessarily scoped (that's why the destructor is not performing the unmap).
+//
+// Note: on program termination, each still unmapped file is automatically
+// unmapped. Hence, it is not an error if you don't call Unmap() (provided you
+// are ok keeping that file in memory the whole time).
+class MmapHandle {
+ public:
+  MmapHandle(void *start, size_t num_bytes)
+      : start_(start), num_bytes_(num_bytes) {}
+
+  // Returns start address for the memory area where a file has been mmapped.
+  void *start() const { return start_; }
+
+  // Returns number of bytes of the memory area from start().
+  size_t num_bytes() const { return num_bytes_; }
+
+  // Shortcut to simplify checking success of MmapFile(). See usage example
+  // from the doc of that function.
+  bool ok() const { return start() != nullptr; }
+
+  // Returns a StringPiece pointing to the same underlying bytes.  The view is
+  // non-owning: it stays valid only until the region is Unmap()-ed.
+  StringPiece to_stringpiece() const {
+    return StringPiece(reinterpret_cast<char *>(start_), num_bytes_);
+  }
+
+ private:
+  // See doc for start(). Not owned.
+  void *const start_;
+
+  // See doc for num_bytes().
+  const size_t num_bytes_;
+};
+
+// Maps the full content of a file in memory (using mmap).
+//
+// When done using the file content, one can unmap using Unmap(). Otherwise,
+// all mapped files are unmapped when the program terminates.
+//
+// Sample usage:
+//
+// MmapHandle mmap_handle = MmapFile(filename);
+// CHECK(mmap_handle.ok()) << "Unable to mmap " << filename;
+//
+// ... use data from addresses
+// ... [mmap_handle.start, mmap_handle.start + mmap_handle.num_bytes)
+//
+// Unmap(mmap_handle);  // Unmap logs errors internally.
+//
+// Note: one can read *and* write the num_bytes bytes from start, but those
+// writes are not propagated to the underlying file, nor to other processes that
+// may have mmapped that file (all changes are local to current process).
+MmapHandle MmapFile(const std::string &filename);
+
+// Platform-neutral alias: Win32 identifies an open file by a HANDLE, POSIX by
+// a plain int file descriptor.
+#ifdef _WIN32
+using FileDescriptorOrHandle = HANDLE;
+#else
+using FileDescriptorOrHandle = int;
+#endif
+
+// Like MmapFile(const std::string &filename), but uses a file descriptor.
+MmapHandle MmapFile(FileDescriptorOrHandle fd);
+
+// Unmaps a file mapped using MmapFile. Returns true on success, false
+// otherwise.
+bool Unmap(MmapHandle mmap_handle);
+
+// Scoped mmapping of a file. Mmaps a file on construction, unmaps it on
+// destruction (RAII).
+class ScopedMmap {
+ public:
+  explicit ScopedMmap(const std::string &filename)
+      : handle_(MmapFile(filename)) {}
+
+  explicit ScopedMmap(FileDescriptorOrHandle fd) : handle_(MmapFile(fd)) {}
+
+  // Non-copyable: a copy would Unmap the same region twice on destruction.
+  ScopedMmap(const ScopedMmap &) = delete;
+  ScopedMmap &operator=(const ScopedMmap &) = delete;
+
+  ~ScopedMmap() {
+    if (handle_.ok()) {
+      Unmap(handle_);
+    }
+  }
+
+  // Returns the underlying handle; valid only while this object is alive.
+  const MmapHandle &handle() { return handle_; }
+
+ private:
+  MmapHandle handle_;
+};
+
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FILE_MMAP_H_
diff --git a/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc b/native/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc
similarity index 100%
rename from lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc
rename to native/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.cc
diff --git a/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h b/native/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h
similarity index 100%
rename from lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h
rename to native/lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h
diff --git a/lang_id/common/flatbuffers/embedding-network.fbs b/native/lang_id/common/flatbuffers/embedding-network.fbs
similarity index 100%
rename from lang_id/common/flatbuffers/embedding-network.fbs
rename to native/lang_id/common/flatbuffers/embedding-network.fbs
diff --git a/native/lang_id/common/flatbuffers/model-utils.cc b/native/lang_id/common/flatbuffers/model-utils.cc
new file mode 100644
index 0000000..66f7f38
--- /dev/null
+++ b/native/lang_id/common/flatbuffers/model-utils.cc
@@ -0,0 +1,210 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/flatbuffers/model-utils.h"
+
+#include <string.h>
+
+#include <string>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/math/checksum.h"
+
+namespace libtextclassifier3 {
+namespace saft_fbs {
+
+namespace {
+
+// Returns true if we have clear evidence that |model| fails its checksum.
+//
+// E.g., if |model| has the crc32 field, and the value of that field does not
+// match the checksum, then this function returns true. If there is no crc32
+// field, then we don't know what the original (at build time) checksum was, so
+// we don't know anything clear and this function returns false.
+//
+// The recomputed checksum deliberately excludes the crc32 field itself (see
+// ComputeCrc2Checksum below).
+bool ClearlyFailsChecksum(const Model &model) {
+  if (!flatbuffers::IsFieldPresent(&model, Model::VT_CRC32)) {
+    SAFTM_LOG(WARNING)
+        << "No CRC32, most likely an old model; skip CRC32 check";
+    return false;
+  }
+  const mobile::uint32 expected_crc32 = model.crc32();
+  const mobile::uint32 actual_crc32 = ComputeCrc2Checksum(&model);
+  if (actual_crc32 != expected_crc32) {
+    SAFTM_LOG(ERROR) << "Corrupt model: different CRC32: " << actual_crc32
+                     << " vs " << expected_crc32;
+    return true;
+  }
+  SAFTM_DLOG(INFO) << "Successfully checked CRC32 " << actual_crc32;
+  return false;
+}
+}  // namespace
+
+// Returns the Model serialized in the |num_bytes| bytes at |data|, or nullptr
+// if the bytes are not a valid Model flatbuffer or fail the CRC32 check.  The
+// returned Model points into |data|, which must outlive it.
+const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes) {
+  if ((data == nullptr) || (num_bytes == 0)) {
+    SAFTM_LOG(ERROR) << "GetModel called on an empty sequence of bytes";
+    return nullptr;
+  }
+  const uint8_t *start = reinterpret_cast<const uint8_t *>(data);
+  // Structural validation: offsets, bounds and vtables of the flatbuffer.
+  flatbuffers::Verifier verifier(start, num_bytes);
+  if (!VerifyModelBuffer(verifier)) {
+    SAFTM_LOG(ERROR) << "Not a valid Model flatbuffer";
+    return nullptr;
+  }
+  const Model *model = GetModel(start);
+  if (model == nullptr) {
+    return nullptr;
+  }
+  // Content validation: reject models whose stored crc32 does not match.
+  if (ClearlyFailsChecksum(*model)) {
+    return nullptr;
+  }
+  return model;
+}
+
+// Returns the first input of |model| whose name is |name|, or nullptr if no
+// such input exists (or |model| is null / has no input list).
+const ModelInput *GetInputByName(const Model *model, const std::string &name) {
+  if (model == nullptr) {
+    SAFTM_LOG(ERROR) << "GetInputByName called with model == nullptr";
+    return nullptr;
+  }
+  const auto *inputs = model->inputs();
+  if (inputs == nullptr) {
+    // We should always have a list of inputs; maybe an empty one, if no inputs,
+    // but the list should be there.
+    SAFTM_LOG(ERROR) << "null inputs";
+    return nullptr;
+  }
+  for (const ModelInput *input : *inputs) {
+    if (input != nullptr) {
+      // Skip inputs without a name; compare the rest against |name|.
+      const flatbuffers::String *input_name = input->name();
+      if (input_name && input_name->str() == name) {
+        return input;
+      }
+    }
+  }
+  return nullptr;
+}
+
+// Returns a non-owning view of the payload bytes of |input|, or an empty
+// (nullptr, 0) StringPiece if |input| or its data vector is missing.
+mobile::StringPiece GetInputBytes(const ModelInput *input) {
+  if ((input == nullptr) || (input->data() == nullptr)) {
+    SAFTM_LOG(ERROR) << "ModelInput has no content";
+    return mobile::StringPiece(nullptr, 0);
+  }
+  // The condition above already guarantees data() is non-null; the previous
+  // version re-checked it here, which was dead code.
+  const flatbuffers::Vector<uint8_t> *input_data = input->data();
+  return mobile::StringPiece(
+      reinterpret_cast<const char *>(input_data->data()), input_data->size());
+}
+
+// Copies every (name, value) parameter from |model| into |context|.  Returns
+// false (and logs) on malformed input: null context, missing parameter list,
+// or a parameter with a null pointer / null or empty name / null value.
+bool FillParameters(const Model &model, mobile::TaskContext *context) {
+  if (context == nullptr) {
+    SAFTM_LOG(ERROR) << "null context";
+    return false;
+  }
+  const auto *parameters = model.parameters();
+  if (parameters == nullptr) {
+    // We should always have a list of parameters; maybe an empty one, if no
+    // parameters, but the list should be there.
+    SAFTM_LOG(ERROR) << "null list of parameters";
+    return false;
+  }
+  for (const ModelParameter *p : *parameters) {
+    if (p == nullptr) {
+      SAFTM_LOG(ERROR) << "null parameter";
+      return false;
+    }
+    if (p->name() == nullptr) {
+      SAFTM_LOG(ERROR) << "null parameter name";
+      return false;
+    }
+    const std::string name = p->name()->str();
+    if (name.empty()) {
+      SAFTM_LOG(ERROR) << "empty parameter name";
+      return false;
+    }
+    if (p->value() == nullptr) {
+      // Fixed copy-paste bug: this message previously said "name".
+      SAFTM_LOG(ERROR) << "null parameter value";
+      return false;
+    }
+    context->SetParameter(name, p->value()->str());
+  }
+  return true;
+}
+
+namespace {
+// Updates |*crc| with the information from |s|. Auxiliary for
+// ComputeCrc2Checksum.
+//
+// The bytes from |info| are also used to update the CRC32 checksum. |info|
+// should be a brief tag that indicates what |s| represents. The idea is to add
+// some structure to the information that goes into the CRC32 computation.
+template <typename T>
+void UpdateCrc(mobile::Crc32 *crc, const flatbuffers::Vector<T> *s,
+               mobile::StringPiece info) {
+  // "|<info>:" delimiter, so adjacent fields can't alias each other.
+  crc->Update("|");
+  crc->Update(info.data(), info.size());
+  crc->Update(":");
+  if (s == nullptr) {
+    // A missing vector feeds a fixed marker rather than zero bytes.
+    crc->Update("empty");
+  } else {
+    // Feed the raw element bytes of the flatbuffer vector.
+    crc->Update(reinterpret_cast<const char *>(s->data()),
+                s->size() * sizeof(T));
+  }
+}
+}  // namespace
+
+// Computes the model checksum over all parameters and inputs, excluding the
+// stored crc32 field.  NOTE(review): the name says "Crc2" but this is a CRC32
+// (mobile::Crc32); kept as-is to match the existing header declaration.
+mobile::uint32 ComputeCrc2Checksum(const Model *model) {
+  // Implementation note: originally, I (salcianu@) thought we can just compute
+  // a CRC32 checksum of the model bytes. Unfortunately, the expected checksum
+  // is there too (and because we don't control the flatbuffer format, we can't
+  // "arrange" for it to be placed at the head / tail of those bytes). Instead,
+  // we traverse |model| and feed into the CRC32 computation those parts we are
+  // interested in (which excludes the crc32 field).
+  //
+  // Note: storing the checksum outside the Model would be too disruptive for
+  // the way we currently ship our models.
+  mobile::Crc32 crc;
+  if (model == nullptr) {
+    return crc.Get();
+  }
+  crc.Update("|Parameters:");
+  const auto *parameters = model->parameters();
+  if (parameters != nullptr) {
+    for (const ModelParameter *p : *parameters) {
+      if (p != nullptr) {
+        UpdateCrc(&crc, p->name(), "name");
+        UpdateCrc(&crc, p->value(), "value");
+      }
+    }
+  }
+  crc.Update("|Inputs:");
+  const auto *inputs = model->inputs();
+  if (inputs != nullptr) {
+    for (const ModelInput *input : *inputs) {
+      if (input != nullptr) {
+        UpdateCrc(&crc, input->name(), "name");
+        UpdateCrc(&crc, input->type(), "type");
+        UpdateCrc(&crc, input->sub_type(), "sub-type");
+        UpdateCrc(&crc, input->data(), "data");
+      }
+    }
+  }
+  return crc.Get();
+}
+
+} // namespace saft_fbs
+}  // namespace libtextclassifier3
diff --git a/native/lang_id/common/flatbuffers/model-utils.h b/native/lang_id/common/flatbuffers/model-utils.h
new file mode 100644
index 0000000..197e1e3
--- /dev/null
+++ b/native/lang_id/common/flatbuffers/model-utils.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_MODEL_UTILS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_MODEL_UTILS_H_
+
+#include <stddef.h>
+
+#include <string>
+
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/flatbuffers/model_generated.h"
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace saft_fbs {
+
+// Verifies that the |num_bytes| bytes that start at |data| represent a valid
+// Model flatbuffer. If so, returns that Model. Otherwise, returns nullptr.
+//
+// Note: if the Model has the crc32 field, this method checks that the Model
+// checksum matches that field; if they don't match, the Model is considered
+// invalid, and this function returns nullptr. The checksum test is in addition
+// to the standard flatbuffer validity checking.
+const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes);
+
+// Convenience StringPiece version of GetVerifiedModelFromBytes.
+inline const Model *GetVerifiedModelFromBytes(mobile::StringPiece bytes) {
+ return GetVerifiedModelFromBytes(bytes.data(), bytes.size());
+}
+
+// Returns the |model| input with specified |name|. Returns nullptr if no such
+// input exists. If |model| contains multiple inputs with that |name|, returns
+// the first one (model builders should avoid building such models).
+const ModelInput *GetInputByName(const Model *model, const std::string &name);
+
+// Returns a StringPiece pointing to the bytes for the content of |input|. In
+// case of errors, returns StringPiece(nullptr, 0).
+mobile::StringPiece GetInputBytes(const ModelInput *input);
+
+// Fills parameters from |context|, based on the parameters from |model|.
+// Returns false if any error is encountered, true otherwise. In the case of an
+// error, some parameters may have been added to |context| (e.g., if we find a
+// problem with the 3rd parameter, the first 2 have been added).
+bool FillParameters(const Model &model, mobile::TaskContext *context);
+
+// Returns the CRC32 checksum of |model|. This checksum is computed over the
+// entire information from the model (including the bytes of the inputs),
+// *except* the crc32 field. Hence, when a model is built, one can store the
+// result of this function into that field; on the user side, one can check that
+// the result of this function matches the crc32 field, to guard against model
+// corruption. GetVerifiedModelFromBytes performs this check.
+mobile::uint32 ComputeCrc2Checksum(const Model *model);
+
+} // namespace saft_fbs
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_MODEL_UTILS_H_
diff --git a/lang_id/common/flatbuffers/model.fbs b/native/lang_id/common/flatbuffers/model.fbs
similarity index 100%
rename from lang_id/common/flatbuffers/model.fbs
rename to native/lang_id/common/flatbuffers/model.fbs
diff --git a/lang_id/common/lite_base/attributes.h b/native/lang_id/common/lite_base/attributes.h
similarity index 100%
rename from lang_id/common/lite_base/attributes.h
rename to native/lang_id/common/lite_base/attributes.h
diff --git a/lang_id/common/lite_base/casts.h b/native/lang_id/common/lite_base/casts.h
similarity index 100%
rename from lang_id/common/lite_base/casts.h
rename to native/lang_id/common/lite_base/casts.h
diff --git a/lang_id/common/lite_base/compact-logging-levels.h b/native/lang_id/common/lite_base/compact-logging-levels.h
similarity index 100%
rename from lang_id/common/lite_base/compact-logging-levels.h
rename to native/lang_id/common/lite_base/compact-logging-levels.h
diff --git a/native/lang_id/common/lite_base/compact-logging-raw.cc b/native/lang_id/common/lite_base/compact-logging-raw.cc
new file mode 100644
index 0000000..27c6446
--- /dev/null
+++ b/native/lang_id/common/lite_base/compact-logging-raw.cc
@@ -0,0 +1,104 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/lite_base/compact-logging-raw.h"
+
+#include <stdio.h>
+
+#include <string>
+
+// NOTE: this file contains two implementations: one for Android, one for all
+// other cases. We always build exactly one implementation.
+#if defined(__ANDROID__)
+
+// Compiled as part of Android.
+#include <android/log.h>
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+namespace {
+// Converts LogSeverity to level for __android_log_write.
+int GetAndroidLogLevel(LogSeverity severity) {
+ switch (severity) {
+ case FATAL:
+ return ANDROID_LOG_FATAL;
+ case ERROR:
+ return ANDROID_LOG_ERROR;
+ case WARNING:
+ return ANDROID_LOG_WARN;
+ case INFO:
+ return ANDROID_LOG_INFO;
+ default:
+ return ANDROID_LOG_DEBUG;
+ }
+}
+} // namespace
+
+void LowLevelLogging(LogSeverity severity, const std::string &tag,
+ const std::string &message) {
+ const int android_log_level = GetAndroidLogLevel(severity);
+#if !defined(SAFTM_DEBUG_LOGGING)
+ if (android_log_level != ANDROID_LOG_ERROR &&
+ android_log_level != ANDROID_LOG_FATAL) {
+ return;
+ }
+#endif
+ __android_log_write(android_log_level, tag.c_str(), message.c_str());
+}
+
+} // namespace internal_logging
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#else // if defined(__ANDROID__)
+
+// Not on Android: implement LowLevelLogging to print to stderr (see below).
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+namespace {
+// Converts LogSeverity to human-readable text.
+const char *LogSeverityToString(LogSeverity severity) {
+ switch (severity) {
+ case INFO:
+ return "INFO";
+ case WARNING:
+ return "WARNING";
+ case ERROR:
+ return "ERROR";
+ case FATAL:
+ return "FATAL";
+ default:
+ return "UNKNOWN";
+ }
+}
+} // namespace
+
+void LowLevelLogging(LogSeverity severity, const std::string &tag,
+ const std::string &message) {
+ fprintf(stderr, "[%s] %s : %s\n", LogSeverityToString(severity), tag.c_str(),
+ message.c_str());
+ fflush(stderr);
+}
+
+} // namespace internal_logging
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // if defined(__ANDROID__)
diff --git a/native/lang_id/common/lite_base/compact-logging-raw.h b/native/lang_id/common/lite_base/compact-logging-raw.h
new file mode 100644
index 0000000..d77a990
--- /dev/null
+++ b/native/lang_id/common/lite_base/compact-logging-raw.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_RAW_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_RAW_H_
+
+#include <string>
+
+#include "lang_id/common/lite_base/compact-logging-levels.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+// Low-level logging primitive. Logs a message, with the indicated log
+// severity. From android/log.h: "the tag normally corresponds to the component
+// that emits the log message, and should be reasonably small".
+void LowLevelLogging(LogSeverity severity, const std::string &tag,
+ const std::string &message);
+
+} // namespace internal_logging
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_RAW_H_
diff --git a/lang_id/common/lite_base/compact-logging.cc b/native/lang_id/common/lite_base/compact-logging.cc
similarity index 100%
rename from lang_id/common/lite_base/compact-logging.cc
rename to native/lang_id/common/lite_base/compact-logging.cc
diff --git a/native/lang_id/common/lite_base/compact-logging.h b/native/lang_id/common/lite_base/compact-logging.h
new file mode 100644
index 0000000..29450b1
--- /dev/null
+++ b/native/lang_id/common/lite_base/compact-logging.h
@@ -0,0 +1,177 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_H_
+
+#include <cassert>
+#include <string>
+
+#include "lang_id/common/lite_base/attributes.h"
+#include "lang_id/common/lite_base/compact-logging-levels.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace internal_logging {
+
+// A tiny code footprint string stream for assembling log messages.
+struct LoggingStringStream {
+ LoggingStringStream() {}
+ LoggingStringStream &stream() { return *this; }
+
+ // Needed for invocation in SAFTM_CHECK macro.
+ explicit operator bool() const { return true; }
+
+ std::string message;
+};
+
+template <typename T>
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ const T &entry) {
+ stream.message.append(std::to_string(entry));
+ return stream;
+}
+
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ const char *message) {
+ stream.message.append(message);
+ return stream;
+}
+
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ const std::string &message) {
+ stream.message.append(message);
+ return stream;
+}
+
+inline LoggingStringStream &operator<<(LoggingStringStream &stream,
+ StringPiece sp) {
+ stream.message.append(sp.data(), sp.size());
+ return stream;
+}
+
+// The class that does all the work behind our SAFTM_LOG(severity) macros. Each
+// SAFTM_LOG(severity) << obj1 << obj2 << ...; logging statement creates a
+// LogMessage temporary object containing a stringstream. Each operator<< adds
+// info to that stringstream and the LogMessage destructor performs the actual
+// logging. The reason this works is that in C++, "all temporary objects are
+// destroyed as the last step in evaluating the full-expression that (lexically)
+// contains the point where they were created." For more info, see
+// http://en.cppreference.com/w/cpp/language/lifetime. Hence, the destructor is
+// invoked after the last << from that logging statement.
+class LogMessage {
+ public:
+ LogMessage(LogSeverity severity, const char *file_name,
+ int line_number) SAFTM_ATTRIBUTE_NOINLINE;
+
+ ~LogMessage() SAFTM_ATTRIBUTE_NOINLINE;
+
+ // Returns the stream associated with the logger object.
+ LoggingStringStream &stream() { return stream_; }
+
+ private:
+ const LogSeverity severity_;
+
+ // Stream that "prints" all info into a string (not to a file). We construct
+ // here the entire logging message and next print it in one operation.
+ LoggingStringStream stream_;
+};
+
+// Pseudo-stream that "eats" the tokens <<-pumped into it, without printing
+// anything.
+class NullStream {
+ public:
+ NullStream() {}
+ NullStream &stream() { return *this; }
+};
+template <typename T>
+inline NullStream &operator<<(NullStream &str, const T &) {
+ return str;
+}
+
+} // namespace internal_logging
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#define SAFTM_LOG(severity) \
+ ::libtextclassifier3::mobile::internal_logging::LogMessage( \
+ ::libtextclassifier3::mobile::internal_logging::severity, __FILE__, __LINE__) \
+ .stream()
+
+// If condition x is true, does nothing. Otherwise, crashes the program (like
+// LOG(FATAL)) with an informative message. Can be continued with extra
+// messages, via <<, like any logging macro, e.g.,
+//
+// SAFTM_CHECK(my_cond) << "I think we hit a problem";
+#define SAFTM_CHECK(x) \
+ (x) || SAFTM_LOG(FATAL) << __FILE__ << ":" << __LINE__ \
+ << ": check failed: \"" << #x
+
+#define SAFTM_CHECK_EQ(x, y) SAFTM_CHECK((x) == (y))
+#define SAFTM_CHECK_LT(x, y) SAFTM_CHECK((x) < (y))
+#define SAFTM_CHECK_GT(x, y) SAFTM_CHECK((x) > (y))
+#define SAFTM_CHECK_LE(x, y) SAFTM_CHECK((x) <= (y))
+#define SAFTM_CHECK_GE(x, y) SAFTM_CHECK((x) >= (y))
+#define SAFTM_CHECK_NE(x, y) SAFTM_CHECK((x) != (y))
+
+#define SAFTM_NULLSTREAM \
+ ::libtextclassifier3::mobile::internal_logging::NullStream().stream()
+
+// Debug checks: a SAFTM_DCHECK<suffix> macro should behave like
+// SAFTM_CHECK<suffix> in debug mode and not check / not print anything in
+// non-debug mode.
+#ifdef NDEBUG
+
+#define SAFTM_DCHECK(x) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_EQ(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_LT(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_GT(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_LE(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_GE(x, y) SAFTM_NULLSTREAM
+#define SAFTM_DCHECK_NE(x, y) SAFTM_NULLSTREAM
+
+// In non-debug mode, SAFTM_DLOG statements do not generate any logging.
+#define SAFTM_DLOG(severity) SAFTM_NULLSTREAM
+
+#else // NDEBUG
+
+// In debug mode, each SAFTM_DCHECK<suffix> is equivalent to
+// SAFTM_CHECK<suffix>, i.e., a real check that crashes when the condition is
+// not true.
+#define SAFTM_DCHECK(x) SAFTM_CHECK(x)
+#define SAFTM_DCHECK_EQ(x, y) SAFTM_CHECK_EQ(x, y)
+#define SAFTM_DCHECK_LT(x, y) SAFTM_CHECK_LT(x, y)
+#define SAFTM_DCHECK_GT(x, y) SAFTM_CHECK_GT(x, y)
+#define SAFTM_DCHECK_LE(x, y) SAFTM_CHECK_LE(x, y)
+#define SAFTM_DCHECK_GE(x, y) SAFTM_CHECK_GE(x, y)
+#define SAFTM_DCHECK_NE(x, y) SAFTM_CHECK_NE(x, y)
+
+// In debug mode, SAFTM_DLOG statements are like SAFTM_LOG.
+#define SAFTM_DLOG SAFTM_LOG
+
+#endif // NDEBUG
+
+#ifdef LIBTEXTCLASSIFIER_VLOG
+#define SAFTM_VLOG(severity) \
+ ::libtextclassifier3::mobile::internal_logging::LogMessage( \
+ ::libtextclassifier3::mobile::internal_logging::INFO, __FILE__, __LINE__) \
+ .stream()
+#else
+#define SAFTM_VLOG(severity) SAFTM_NULLSTREAM
+#endif
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_BASE_COMPACT_LOGGING_H_
diff --git a/lang_id/common/lite_base/endian.h b/native/lang_id/common/lite_base/endian.h
similarity index 100%
rename from lang_id/common/lite_base/endian.h
rename to native/lang_id/common/lite_base/endian.h
diff --git a/lang_id/common/lite_base/float16.h b/native/lang_id/common/lite_base/float16.h
similarity index 100%
rename from lang_id/common/lite_base/float16.h
rename to native/lang_id/common/lite_base/float16.h
diff --git a/lang_id/common/lite_base/integral-types.h b/native/lang_id/common/lite_base/integral-types.h
similarity index 100%
rename from lang_id/common/lite_base/integral-types.h
rename to native/lang_id/common/lite_base/integral-types.h
diff --git a/lang_id/common/lite_base/logging.h b/native/lang_id/common/lite_base/logging.h
similarity index 100%
rename from lang_id/common/lite_base/logging.h
rename to native/lang_id/common/lite_base/logging.h
diff --git a/lang_id/common/lite_base/macros.h b/native/lang_id/common/lite_base/macros.h
similarity index 100%
rename from lang_id/common/lite_base/macros.h
rename to native/lang_id/common/lite_base/macros.h
diff --git a/native/lang_id/common/lite_strings/numbers.cc b/native/lang_id/common/lite_strings/numbers.cc
new file mode 100644
index 0000000..f933f04
--- /dev/null
+++ b/native/lang_id/common/lite_strings/numbers.cc
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/common/lite_strings/numbers.h"
+
+#include <ctype.h>
+#include <stdlib.h>
+
+#include <climits>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Returns true if the characters that start at address ptr (inclusive) and stop
+// at the first '\0' consist of only whitespaces, as determined by isspace().
+// Note: this function returns false if ptr is nullptr.
+static bool OnlyWhitespaces(const char *ptr) {
+ if (ptr == nullptr) {
+ return false;
+ }
+ for (; *ptr != '\0'; ++ptr) {
+ if (!isspace(*ptr)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool LiteAtoi(const char *c_str, int *value) {
+ if (c_str == nullptr) {
+ return false;
+ }
+
+ // Short version of man strtol:
+ //
+ // strtol parses some optional whitespaces, an optional +/- sign, and next a
+ // succession of digits. If it finds some digits, it sets temp to point to
+ // the first character after that succession of digits and returns the parsed
+ // integer.
+ //
+ // If there were no digits at all, strtol() sets temp to be c_str (the start
+ // address) and returns 0.
+ char *temp = nullptr;
+ const long int parsed_value = strtol(c_str, &temp, 0); // NOLINT
+
+ // Check for overflow. Note: to simplify the code, we assume that LONG_MIN /
+ // LONG_MAX means that strtol encountered an overflow (normally, in that case,
+ // one should also inspect errno). Hence, we maybe give up the possibility to
+ // parse one extreme value on each side (min/max). That should be ok.
+ if ((parsed_value == LONG_MIN) || (parsed_value == LONG_MAX) ||
+ (parsed_value < INT_MIN) || (parsed_value > INT_MAX)) {
+ return false;
+ }
+ *value = static_cast<int>(parsed_value);
+
+ // First part of the expression below means that the input string contained at
+ // least one digit. The other part checks that what remains after the number
+ // (if anything) consists only of whitespaces.
+ return (temp != c_str) && OnlyWhitespaces(temp);
+}
+
+bool LiteAtof(const char *c_str, float *value) {
+ if (c_str == nullptr) {
+ return false;
+ }
+
+ // strtof is similar to strtol, see more detailed comments inside LiteAtoi.
+ char *temp = nullptr;
+ *value = strtof(c_str, &temp);
+ return (temp != c_str) && OnlyWhitespaces(temp);
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/native/lang_id/common/lite_strings/numbers.h b/native/lang_id/common/lite_strings/numbers.h
new file mode 100644
index 0000000..f832a96
--- /dev/null
+++ b/native/lang_id/common/lite_strings/numbers.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_NUMBERS_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_NUMBERS_H_
+
+#include <string>
+
+#include "lang_id/common/lite_strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Parses an int from a C-style string; similar to absl::SimpleAtoi.
+//
+// c_str should point to a zero-terminated array of chars that contains the
+// number representation as (a) "<radix-10-number>" (e.g., "721"), (b)
+// "0x<radix-16-number>" (e.g., "0xa1"), or (c) "0<radix-8-number>" (e.g.,
+// "017201"). Whitespaces (as determined by isspace()) are allowed before and
+// after the number representation (but obviously not in the middle).
+//
+// Stores parsed number into *value. Returns true on success, false on error.
+// Note: presence of extra non-whitespace characters after the number counts as
+// an error: e.g., parsing "123a" will return false due to the extra "a" (which
+// is not a valid radix-10 digit). This function also returns false for strings
+// that do not contain any digit (e.g., ""), as well as for overflows /
+// underflows.
+bool LiteAtoi(const char *c_str, int *value);
+
+inline bool LiteAtoi(const std::string &s, int *value) {
+ return LiteAtoi(s.c_str(), value);
+}
+
+inline bool LiteAtoi(StringPiece sp, int *value) {
+ // Unfortunately, we can't directly call LiteAtoi(sp.data()): LiteAtoi(const
+ // char *) needs a zero-terminated string.
+ const std::string temp(sp.data(), sp.size());
+ return LiteAtoi(temp.c_str(), value);
+}
+
+// Like LiteAtoi, but for float; similar to absl::SimpleAtof.
+//
+// NOTE: currently, does not properly handle overflow / underflow.
+// TODO(salcianu): fix that.
+bool LiteAtof(const char *c_str, float *value);
+
+inline bool LiteAtof(const std::string &s, float *value) {
+ return LiteAtof(s.c_str(), value);
+}
+
+inline bool LiteAtof(StringPiece sp, float *value) {
+  // Unfortunately, we can't directly call LiteAtof(sp.data()): LiteAtof(const
+ // char *) needs a zero-terminated string.
+ const std::string temp(sp.data(), sp.size());
+ return LiteAtof(temp.c_str(), value);
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_NUMBERS_H_
diff --git a/native/lang_id/common/lite_strings/str-cat.h b/native/lang_id/common/lite_strings/str-cat.h
new file mode 100644
index 0000000..25cec4d
--- /dev/null
+++ b/native/lang_id/common/lite_strings/str-cat.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_CAT_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_CAT_H_
+
+// Less efficient but more compact versions of several absl string utils.
+//
+// "More compact" means "pulls in fewer code dependencies". That's useful if
+// one tries to minimize the code size.
+//
+// Note: the name and the signature of the functions from this header were
+// chosen to minimize the effort of converting code that uses absl::LiteStrCat &
+// co to our more compact functions.
+
+#include <string>
+
+#ifdef COMPILER_MSVC
+#include <sstream>
+#endif // COMPILER_MSVC
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Less efficient but more compact version of absl::LiteStrCat().
+//
+// Given a value v (see supported types below) LiteStrCat(v) returns a new
+// string that contains the representation of v. For examples, see
+// str-cat_test.cc.
+template <typename T>
+inline std::string LiteStrCat(T v) {
+#ifdef COMPILER_MSVC
+ std::stringstream stream;
+ stream << v;
+ return stream.str();
+#else
+ return std::to_string(v);
+#endif
+}
+
+template <>
+inline std::string LiteStrCat(const char *v) {
+ return std::string(v);
+}
+
+// TODO(salcianu): use a reference type (const std::string &). For some reason,
+// I couldn't get that to work on a first try.
+template <>
+inline std::string LiteStrCat(std::string v) {
+ return v;
+}
+
+template <>
+inline std::string LiteStrCat(char v) {
+ return std::string(1, v);
+}
+
+// Less efficient but more compact version of absl::LiteStrAppend().
+template <typename T>
+inline void LiteStrAppend(std::string *dest, T v) {
+ dest->append(LiteStrCat(v)); // NOLINT
+}
+
+template <typename T1, typename T2>
+inline void LiteStrAppend(std::string *dest, T1 v1, T2 v2) {
+ dest->append(LiteStrCat(v1)); // NOLINT
+ dest->append(LiteStrCat(v2)); // NOLINT
+}
+
+template <typename T1, typename T2, typename T3>
+inline void LiteStrAppend(std::string *dest, T1 v1, T2 v2, T3 v3) {
+ LiteStrAppend(dest, v1, v2);
+ dest->append(LiteStrCat(v3)); // NOLINT
+}
+
+template <typename T1, typename T2, typename T3, typename T4>
+inline void LiteStrAppend(std::string *dest, T1 v1, T2 v2, T3 v3, T4 v4) {
+ LiteStrAppend(dest, v1, v2, v3);
+ dest->append(LiteStrCat(v4)); // NOLINT
+}
+
+} // namespace mobile
+} // namespace nlp_saft
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STR_CAT_H_
diff --git a/lang_id/common/lite_strings/str-split.cc b/native/lang_id/common/lite_strings/str-split.cc
similarity index 100%
rename from lang_id/common/lite_strings/str-split.cc
rename to native/lang_id/common/lite_strings/str-split.cc
diff --git a/lang_id/common/lite_strings/str-split.h b/native/lang_id/common/lite_strings/str-split.h
similarity index 100%
rename from lang_id/common/lite_strings/str-split.h
rename to native/lang_id/common/lite_strings/str-split.h
diff --git a/native/lang_id/common/lite_strings/stringpiece.h b/native/lang_id/common/lite_strings/stringpiece.h
new file mode 100644
index 0000000..6565053
--- /dev/null
+++ b/native/lang_id/common/lite_strings/stringpiece.h
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TC3_STD_STRING_IMPORT
+#define TC3_STD_STRING_IMPORT
+#include <string>
+
+namespace libtextclassifier3 {
+using string = std::string;
+template <class CharT, class Traits = std::char_traits<CharT>,
+ class Allocator = std::allocator<CharT> >
+using basic_string = std::basic_string<CharT, Traits, Allocator>;
+} // namespace libtextclassifier3
+#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STRINGPIECE_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STRINGPIECE_H_
+
+#include <stddef.h>
+#include <string.h>
+
+#include <ostream>
+#include <string>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Read-only "view" of a piece of data. Does not own the underlying data.
+class StringPiece {
+ public:
+ StringPiece() : StringPiece(nullptr, 0) {}
+
+ StringPiece(const char *str) // NOLINT
+ : start_(str), size_(strlen(str)) {}
+
+ StringPiece(const char *start, size_t size) : start_(start), size_(size) {}
+
+ // Intentionally no "explicit" keyword: in function calls, we want strings to
+ // be converted to StringPiece implicitly.
+ StringPiece(const std::string &s) // NOLINT
+ : StringPiece(s.data(), s.size()) {}
+
+ StringPiece(const std::string &s, int offset, int len)
+ : StringPiece(s.data() + offset, len) {}
+
+ char operator[](size_t i) const { return start_[i]; }
+
+ // Returns start address of underlying data.
+ const char *data() const { return start_; }
+
+ // Returns number of bytes of underlying data.
+ size_t size() const { return size_; }
+
+ // Returns true if this StringPiece does not refer to any characters.
+ bool empty() const { return size() == 0; }
+
+ template <typename A>
+ explicit operator std::basic_string<char, std::char_traits<char>, A>() const {
+ if (!data()) return {};
+ return std::basic_string<char, std::char_traits<char>, A>(data(), size());
+ }
+
+ private:
+ const char *start_; // Not owned.
+ size_t size_;
+};
+
+inline std::ostream &operator<<(std::ostream &out, StringPiece sp) {
+ return out.write(sp.data(), sp.size());
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_LITE_STRINGS_STRINGPIECE_H_
diff --git a/native/lang_id/common/math/algorithm.h b/native/lang_id/common/math/algorithm.h
new file mode 100644
index 0000000..5c8596b
--- /dev/null
+++ b/native/lang_id/common/math/algorithm.h
@@ -0,0 +1,148 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Generic utils similar to those from the C++ header <algorithm>.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
+
+#include <algorithm>
+#include <queue>
+#include <vector>
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+// Returns index of max element from the vector |elements|. Returns 0 if
+// |elements| is empty. T should be a type that can be compared by operator<.
+template<typename T>
+inline int GetArgMax(const std::vector<T> &elements) {
+ return std::distance(  // NOTE(review): result narrowed from ptrdiff_t to int.
+ elements.begin(),
+ std::max_element(elements.begin(), elements.end()));  // end() when empty -> 0.
+}
+
+// Returns index of min element from the vector |elements|. Returns 0 if
+// |elements| is empty. T should be a type that can be compared by operator<.
+template<typename T>
+inline int GetArgMin(const std::vector<T> &elements) {
+ return std::distance(  // NOTE(review): result narrowed from ptrdiff_t to int.
+ elements.begin(),
+ std::min_element(elements.begin(), elements.end()));  // end() when empty -> 0.
+}
+
+// Returns indices of greatest k elements from |v|.
+//
+// The order between elements is indicated by |smaller|, which should be an
+// object like std::less<T>, std::greater<T>, etc. If smaller(a, b) is true,
+// that means that "a is smaller than b". Intuitively, |smaller| is a
+// generalization of operator<. Formally, it is a strict weak ordering, see
+// https://en.cppreference.com/w/cpp/named_req/Compare
+//
+// Calling this function with std::less<T>() returns the indices of the larger k
+// elements; calling it with std::greater<T>() returns the indices of the
+// smallest k elements. This is similar to e.g., std::priority_queue: using the
+// default std::less gives you a max-heap, while using std::greater results in a
+// min-heap.
+//
+// Returned indices are sorted in decreasing order of the corresponding elements
+// (e.g., first element of the returned array is the index of the largest
+// element). In case of ties (e.g., equal elements) we select the one with the
+// smallest index. E.g., getting the indices of the top-2 elements from [3, 2,
+// 1, 3, 0, 3] returns [0, 3] (the indices of the first and the second 3).
+//
+// Corner cases: If k <= 0, this function returns an empty vector. If |v| has
+// only n < k elements, this function returns all n indices [0, 1, 2, ..., n -
+// 1], sorted according to the comp order of the indicated elements.
+//
+// Assuming each comparison is O(1), this function uses O(k) auxiliary space,
+// and runs in O(n * log k) time. Note: it is possible to use std::nth_element
+// and obtain an O(n + k * log k) time algorithm, but that uses O(n) auxiliary
+// space. In our case, k << n, e.g., we may want to select the top-3 most
+// likely classes from a set of 100 classes, so the time complexity difference
+// should not matter in practice.
+template <typename T, typename Smaller>
+std::vector<int> GetTopKIndices(int k, const std::vector<T> &v,
+ Smaller smaller) {
+ if (k <= 0) {
+ return std::vector<int>();
+ }
+
+ if (k > v.size()) {  // k > 0 here, so the int vs. size_t compare is safe.
+ k = v.size();
+ }
+
+ // An order between indices. Intuitively, rev_vcomp(i1, i2) iff v[i2] is
+ // smaller than v[i1]. No typo: this inversion is necessary for Invariant B
+ // below. "vcomp" stands for "value comparator" (we compare the values
+ // indicated by the two indices) and "rev_" stands for the reverse order.
+ const auto rev_vcomp = [&v, &smaller](int i1, int i2) -> bool {
+ if (smaller(v[i2], v[i1])) return true;
+ if (smaller(v[i1], v[i2])) return false;
+
+ // Break ties in favor of earlier elements.
+ return i1 < i2;
+ };
+
+ // Indices of the top-k elements seen so far.
+ std::vector<int> heap(k);
+
+ // First, we fill |heap| with the first k indices.
+ for (int i = 0; i < k; ++i) {
+ heap[i] = i;
+ }
+ std::make_heap(heap.begin(), heap.end(), rev_vcomp);
+
+ // Next, we explore the rest of the vector v. Loop invariants:
+ //
+ // Invariant A: |heap| contains the indices of the top-k elements from v[0:i].
+ //
+ // Invariant B: heap[0] is the index of the smallest element from all elements
+ // indicated by the indices from |heap|.
+ //
+ // Invariant C: |heap| is a max heap, according to order rev_vcomp.
+ for (int i = k; i < v.size(); ++i) {
+ // We have to update |heap| iff v[i] is larger than the smallest of the
+ // top-k seen so far. This test is easy to do, due to Invariant B above.
+ if (smaller(v[heap[0]], v[i])) {
+ // Next lines replace heap[0] with i and re-"heapify" heap[0:k-1].
+ heap.push_back(i);
+ std::pop_heap(heap.begin(), heap.end(), rev_vcomp);
+ heap.pop_back();
+ }
+ }
+
+ // Arrange indices from |heap| in decreasing order of corresponding elements.
+ //
+ // More info: in iteration #0, we extract the largest heap element (according
+ // to rev_vcomp, i.e., the index of the smallest of the top-k elements) and
+ // place it at the end of heap, i.e., in heap[k-1]. In iteration #1, we
+ // extract the second largest and place it in heap[k-2], etc.
+ for (int i = 0; i < k; ++i) {
+ std::pop_heap(heap.begin(), heap.end() - i, rev_vcomp);
+ }
+ return heap;
+}
+
+template <typename T>
+std::vector<int> GetTopKIndices(int k, const std::vector<T> &elements) {
+ return GetTopKIndices(k, elements, std::less<T>());  // Default order: operator<.
+}
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
diff --git a/lang_id/common/math/checksum.cc b/native/lang_id/common/math/checksum.cc
similarity index 100%
rename from lang_id/common/math/checksum.cc
rename to native/lang_id/common/math/checksum.cc
diff --git a/lang_id/common/math/checksum.h b/native/lang_id/common/math/checksum.h
similarity index 100%
rename from lang_id/common/math/checksum.h
rename to native/lang_id/common/math/checksum.h
diff --git a/lang_id/common/math/fastexp.cc b/native/lang_id/common/math/fastexp.cc
similarity index 100%
rename from lang_id/common/math/fastexp.cc
rename to native/lang_id/common/math/fastexp.cc
diff --git a/native/lang_id/common/math/fastexp.h b/native/lang_id/common/math/fastexp.h
new file mode 100644
index 0000000..761e9ac
--- /dev/null
+++ b/native/lang_id/common/math/fastexp.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Fast approximation for exp.
+//
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
+
+#include <cassert>
+#include <cmath>
+#include <limits>
+
+#include "lang_id/common/lite_base/casts.h"
+#include "lang_id/common/lite_base/integral-types.h"
+#include "lang_id/common/lite_base/logging.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+class FastMathClass {
+ private:
+ static constexpr int kBits = 7;  // Table index bits; table has 1 << 7 entries.
+ static constexpr int kMask1 = (1 << kBits) - 1;  // Selects the low kBits bits.
+ static constexpr int kMask2 = 0xFF << kBits;  // Selects the next 8 bits.
+ static constexpr float kLogBase2OfE = 1.44269504088896340736f;  // log2(e).
+
+ struct Table {
+ int32 exp1[1 << kBits];
+ };
+
+ public:
+ float VeryFastExp2(float f) const {
+ SAFTM_DCHECK_LE(fabs(f), 126);  // Approximation only valid for |f| <= 126.
+ const float g = f + (127 + (1 << (23 - kBits)));
+ const int32 x = bit_cast<int32>(g);
+ int32 ret = ((x & kMask2) << (23 - kBits))
+ | cache_.exp1[x & kMask1];
+ return bit_cast<float>(ret);
+ }
+
+ float VeryFastExp(float f) const {
+ return VeryFastExp2(f * kLogBase2OfE);  // exp(f) == exp2(f * log2(e)).
+ }
+
+ private:
+ static const Table cache_;  // Lookup table; defined out of line.
+};
+
+extern FastMathClass FastMathInstance;  // Global instance, defined elsewhere.
+
+inline float VeryFastExp(float f) { return FastMathInstance.VeryFastExp(f); }  // Convenience wrapper.
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
diff --git a/lang_id/common/math/hash.cc b/native/lang_id/common/math/hash.cc
similarity index 100%
rename from lang_id/common/math/hash.cc
rename to native/lang_id/common/math/hash.cc
diff --git a/native/lang_id/common/math/hash.h b/native/lang_id/common/math/hash.h
new file mode 100644
index 0000000..a1c24d5
--- /dev/null
+++ b/native/lang_id/common/math/hash.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TC3_STD_STRING_IMPORT
+#define TC3_STD_STRING_IMPORT
+#include <string>
+
+namespace libtextclassifier3 {
+using string = std::string;
+template <class CharT, class Traits = std::char_traits<CharT>,
+ class Allocator = std::allocator<CharT> >
+using basic_string = std::basic_string<CharT, Traits, Allocator>;
+} // namespace libtextclassifier3
+#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_HASH_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_HASH_H_
+
+#include <string>
+
+#include "lang_id/common/lite_base/integral-types.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace utils {
+
+// Returns a 32 bit hash of the |n| bytes that start at |data|, using |seed| for
+// internal initialization. By changing the seed, one effectively gets
+// different hash functions.
+//
+// NOTE: this function is guaranteed not to change in the future.
+//
+// IMPORTANT: for speed reasons, this method does not check its parameters
+// |data| and |n|. The caller should ensure that n >= 0 and that one can read
+// from the memory area [data, data + n).
+uint32 Hash32(const char *data, size_t n, uint32 seed);
+
+static inline uint32 Hash32WithDefaultSeed(const char *data, size_t n) {
+ return Hash32(data, n, 0xBEEF);  // Fixed default seed.
+}
+
+static inline uint32 Hash32WithDefaultSeed(const std::string &input) {
+ return Hash32WithDefaultSeed(input.data(), input.size());  // Hashes all bytes of |input|.
+}
+
+} // namespace utils
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_HASH_H_
diff --git a/lang_id/common/math/softmax.cc b/native/lang_id/common/math/softmax.cc
similarity index 100%
rename from lang_id/common/math/softmax.cc
rename to native/lang_id/common/math/softmax.cc
diff --git a/lang_id/common/math/softmax.h b/native/lang_id/common/math/softmax.h
similarity index 100%
rename from lang_id/common/math/softmax.h
rename to native/lang_id/common/math/softmax.h
diff --git a/native/lang_id/common/registry.h b/native/lang_id/common/registry.h
new file mode 100644
index 0000000..632f917
--- /dev/null
+++ b/native/lang_id/common/registry.h
@@ -0,0 +1,321 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Mechanism to instantiate classes by name.
+//
+// This mechanism is useful if the concrete classes to be instantiated are not
+// statically known (e.g., if their names are read from a dynamically-provided
+// config).
+//
+// In that case, the first step is to define the API implemented by the
+// instantiated classes. E.g.,
+//
+// // In a header file function.h:
+//
+// // Abstract function that takes a double and returns a double.
+// class Function : public RegisterableClass<Function> {
+// public:
+// virtual ~Function() {}
+// virtual double Evaluate(double x) = 0;
+// };
+//
+// // Should be inside namespace libtextclassifier3::mobile.
+// SAFTM_DECLARE_CLASS_REGISTRY_NAME(Function);
+//
+// Notice the inheritance from RegisterableClass<Function>. RegisterableClass
+// is defined by this file (registry.h). Under the hood, this inheritance
+// defines a "registry" that maps names (zero-terminated arrays of chars) to
+// factory methods that create Functions. You should give a human-readable name
+// to this registry. To do that, use the following macro in a .cc file (it has
+// to be a .cc file, as it defines some static data):
+//
+// // Inside function.cc
+// // Should be inside namespace libtextclassifier3::mobile.
+// SAFTM_DEFINE_CLASS_REGISTRY_NAME("function", Function);
+//
+// Now, let's define a few concrete Functions: e.g.,
+//
+// class Cos : public Function {
+// public:
+// double Evaluate(double x) override { return cos(x); }
+// SAFTM_DEFINE_REGISTRATION_METHOD("cos", Cos);
+// };
+//
+// class Exp : public Function {
+// public:
+// double Evaluate(double x) override { return exp(x); }
+// SAFTM_DEFINE_REGISTRATION_METHOD("exp", Exp);
+// };
+//
+// Each concrete Function implementation should have (in the public section) the
+// macro
+//
+// SAFTM_DEFINE_REGISTRATION_METHOD("name", implementation_class);
+//
+// This defines a RegisterClass static method that, when invoked, associates
+// "name" with a factory method that creates instances of implementation_class.
+//
+// Before instantiating Functions by name, we need to tell our system which
+// Functions we may be interested in. This is done by calling the
+// Foo::RegisterClass() for each relevant Foo implementation of Function. It is
+// ok to call Foo::RegisterClass() multiple times (even in parallel): only the
+// first call will perform something, the others will return immediately.
+//
+// Cos::RegisterClass();
+// Exp::RegisterClass();
+//
+// Now, let's instantiate a Function based on its name. This gets a lot more
+// interesting if the Function name is not statically known (e.g., if it is
+// read from an input proto):
+//
+// std::unique_ptr<Function> f(Function::Create("cos"));
+// double result = f->Evaluate(arg);
+//
+// NOTE: the same binary can use this mechanism for different APIs. E.g., one
+// can also have (in the binary with Function, Sin, Cos, etc):
+//
+// class IntFunction : public RegisterableClass<IntFunction> {
+// public:
+// virtual ~IntFunction() {}
+// virtual int Evaluate(int k) = 0;
+// };
+//
+// SAFTM_DECLARE_CLASS_REGISTRY_NAME(IntFunction);
+//
+// SAFTM_DEFINE_CLASS_REGISTRY_NAME("int function", IntFunction);
+//
+// class Inc : public IntFunction {
+// public:
+// int Evaluate(int k) override { return k + 1; }
+// SAFTM_DEFINE_REGISTRATION_METHOD("inc", Inc);
+// };
+//
+// RegisterableClass<Function> and RegisterableClass<IntFunction> define their
+// own registries: each maps string names to implementation of the corresponding
+// API.
+//
+// NOTE: the mechanism described above requires you to explicitly call
+// RegisterClass() for all relevant classes before instantiating them. You can
+// do this in the main() function or in any other function that is guaranteed to
+// run before the code that instantiates those classes. Alternatively, you can
+// use the macro SAFTM_STATIC_REGISTRATION to perform this registration in a
+// decentralized fashion. Just use that macro in a .cc file, outside any
+// function / class, e.g.,
+//
+// SAFTM_STATIC_REGISTRATION(Cos);
+//
+// and make sure you link in all symbols from that .cc file; e.g., in bazel, use
+// alwayslink = 1 for the corresponding cc_library. Still, please be aware that
+// using alwayslink = 1 limits the ability of the linker to perform dead code
+// elimination.
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_REGISTRY_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_REGISTRY_H_
+
+#include <stdlib.h>
+#include <string.h>
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_base/macros.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+namespace internal {
+// Registry that associates keys (zero-terminated array of chars) with values.
+// Values are pointers to type T (the template parameter). This is used to
+// store the association between component names and factory methods that
+// produce those components; the error messages are focused on that case.
+//
+// Internally, this registry uses a linked list of (key, value) pairs. We do
+// not use an STL map, list, etc because we aim for small code size.
+template <class T>
+class ComponentRegistry {
+ public:
+ explicit ComponentRegistry(const char *name) : name_(name), head_(nullptr) {}
+
+ // Adds the (key, value) pair to this registry (if the key does not already
+ // exist in this registry) and returns true. If the registry already has a
+ // mapping for key, returns false and does not modify the registry. NOTE: the
+ // error (false) case happens even if the existing value for key is equal to
+ // the new one.
+ //
+ // This method does not take ownership of key, nor of value.
+ bool Add(const char *key, T *value) {
+ const Cell *old_cell = FindCell(key);
+ if (old_cell != nullptr) {
+ SAFTM_LOG(ERROR) << "Duplicate component: " << key;
+ return false;
+ }
+ Cell *new_cell = new Cell(key, value, head_);  // Prepended; cells are never deleted here.
+ head_ = new_cell;
+ return true;
+ }
+
+ // Returns the value attached to a key in this registry. Returns nullptr on
+ // error (e.g., unknown key).
+ T *Lookup(const char *key) const {
+ const Cell *cell = FindCell(key);
+ if (cell == nullptr) {
+ SAFTM_LOG(ERROR) << "Unknown " << name() << " component: " << key;
+ }
+ return (cell == nullptr) ? nullptr : cell->value();
+ }
+
+ T *Lookup(const std::string &key) const { return Lookup(key.c_str()); }
+
+ // Returns name of this ComponentRegistry.
+ const char *name() const { return name_; }
+
+ // Fills *names with names of all components registered in this
+ // ComponentRegistry. Previous content of *names is cleared out.
+ void GetComponentNames(std::vector<std::string> *names) {
+ names->clear();
+ for (const Cell *c = head_; c!= nullptr; c = c->next()) {
+ names->emplace_back(c->key());
+ }
+ }
+
+ private:
+ // Cell for the singly-linked list underlying this ComponentRegistry. Each
+ // cell contains a key, the value for that key, as well as a pointer to the
+ // next Cell from the list.
+ class Cell {
+ public:
+ // Constructs a new Cell.
+ Cell(const char *key, T *value, Cell *next)
+ : key_(key), value_(value), next_(next) {}
+
+ const char *key() const { return key_; }
+ T *value() const { return value_; }
+ Cell *next() const { return next_; }
+
+ private:
+ const char *const key_;
+ T *const value_;
+ Cell *const next_;
+ };
+
+ // Finds Cell for indicated key in the singly-linked list pointed to by head_.
+ // Returns pointer to that first Cell with that key, or nullptr if no such
+ // Cell (i.e., unknown key).
+ //
+ // Caller does NOT own the returned pointer.
+ const Cell *FindCell(const char *key) const {
+ const Cell *c = head_;
+ while (c != nullptr && strcmp(key, c->key()) != 0) {
+ c = c->next();
+ }
+ return c;
+ }
+
+ // Human-readable description for this ComponentRegistry. For debug purposes.
+ const char *const name_;
+
+ // Pointer to the first Cell from the underlying list of (key, value) pairs.
+ Cell *head_;
+};
+} // namespace internal
+
+// Base class for classes whose concrete subclasses can be created by name.
+template <class T>
+class RegisterableClass {
+ public:
+ // Factory function type.
+ typedef T *(Factory)();
+
+ // Registry type.
+ typedef internal::ComponentRegistry<Factory> Registry;
+
+ // Creates a new instance of T. Returns pointer to new instance or nullptr in
+ // case of errors (e.g., unknown component).
+ //
+ // Passes ownership of the returned pointer to the caller.
+ static T *Create(const std::string &name) { // NOLINT
+ auto *factory = registry()->Lookup(name);
+ if (factory == nullptr) {
+ SAFTM_LOG(ERROR) << "Unknown RegisterableClass " << name;
+ return nullptr;
+ }
+ return factory();
+ }
+
+ // Returns registry for class.
+ static Registry *registry() {
+ static Registry *registry_for_type_t = new Registry(kRegistryName);  // Never deleted.
+ return registry_for_type_t;
+ }
+
+ protected:
+ // Factory method for subclass ComponentClass. Used internally by the static
+ // method RegisterClass() defined by SAFTM_DEFINE_REGISTRATION_METHOD.
+ template <class ComponentClass>
+ static T *_internal_component_factory() {
+ return new ComponentClass();
+ }
+
+ private:
+ // Human-readable name for the registry for this class.
+ static const char kRegistryName[];
+};
+
+// Defines the static method component_class::RegisterClass() that should be
+// called before trying to instantiate component_class by name. Should be used
+// inside the public section of the declaration of component_class. See
+// comments at the top-level of this file.
+#define SAFTM_DEFINE_REGISTRATION_METHOD(component_name, component_class) \
+ static void RegisterClass() { \
+ static bool once = registry()->Add( \
+ component_name, &_internal_component_factory<component_class>); \
+ if (!once) { \
+ SAFTM_LOG(ERROR) << "Problem registering " << component_name; \
+ } \
+ SAFTM_DCHECK(once); \
+ }
+
+// Declares the human-readable name of the registry associated with base_class.
+#define SAFTM_DECLARE_CLASS_REGISTRY_NAME(base_class) \
+ template <> \
+ const char ::libtextclassifier3::mobile::RegisterableClass<base_class>::kRegistryName[]
+
+// Defines the human-readable name of the registry associated with base_class.
+#define SAFTM_DEFINE_CLASS_REGISTRY_NAME(registry_name, base_class) \
+ template <> \
+ const char \
+ ::libtextclassifier3::mobile::RegisterableClass<base_class>::kRegistryName[] \
+ = registry_name
+
+// Register component_name, by calling component_class::RegisterClass() on
+// program start-up, before main. NOTE: this macro should be used in
+// conjunction with something like alwayslink = 1 from bazel. That is
+// discouraged, as it prevents the linker from doing dead code elimination, so
+// please use this macro only in special cases. Instead, if you care about code
+// size, then you should aim to explicitly call RegisterClass from your code
+// (e.g., from the main method, or from the constructor of the class that may
+// need those registered components).
+#define SAFTM_STATIC_REGISTRATION(component_class) \
+ static bool SAFTM_UNIQUE_ID(_kRegistrationDummy) = [] { \
+ component_class::RegisterClass(); \
+ return true; \
+ }()
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_REGISTRY_H_
diff --git a/lang_id/common/stl-util.h b/native/lang_id/common/stl-util.h
similarity index 100%
rename from lang_id/common/stl-util.h
rename to native/lang_id/common/stl-util.h
diff --git a/lang_id/common/utf8.cc b/native/lang_id/common/utf8.cc
similarity index 100%
rename from lang_id/common/utf8.cc
rename to native/lang_id/common/utf8.cc
diff --git a/native/lang_id/common/utf8.h b/native/lang_id/common/utf8.h
new file mode 100644
index 0000000..6103bdd
--- /dev/null
+++ b/native/lang_id/common/utf8.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TC3_STD_STRING_IMPORT
+#define TC3_STD_STRING_IMPORT
+#include <string>
+
+namespace libtextclassifier3 {
+using string = std::string;
+template <class CharT, class Traits = std::char_traits<CharT>,
+ class Allocator = std::allocator<CharT> >
+using basic_string = std::basic_string<CharT, Traits, Allocator>;
+} // namespace libtextclassifier3
+#endif
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_UTF8_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_UTF8_H_
+
+#include <stddef.h>
+
+#include <string>
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace utils {
+
+// Returns the length (number of bytes) of the UTF8 code point starting at src,
+// by reading only the byte from address src.
+//
+// The result is a number from the set {1, 2, 3, 4}.
+static inline int OneCharLen(const char *src) {
+ // On most platforms, char is unsigned by default, but iOS is an exception.
+ // The cast below makes sure we always interpret *src as an unsigned char.
+ return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"  // Lookup table indexed by top 4 bits.
+ [(*(reinterpret_cast<const unsigned char *>(src)) & 0xFF) >> 4];
+}
+
+// Returns a pointer "end" inside [data, data + size) such that the prefix from
+// [data, end) is the largest one that does not contain '\0' and offers the
+// following guarantee: if one starts with
+//
+// curr = text.data()
+//
+// and keeps executing
+//
+// curr += OneCharLen(curr)
+//
+// one would eventually reach curr == end (the pointer returned by this
+// function) without accessing data outside the string. This guards against
+// scenarios like a broken UTF8 string which has only e.g., the first 2 bytes
+// from a 3-byte UTF8 sequence.
+//
+// Preconditions: data != nullptr.
+const char *GetSafeEndOfUtf8String(const char *data, size_t size);
+
+static inline const char *GetSafeEndOfUtf8String(const std::string &text) {
+ return GetSafeEndOfUtf8String(text.data(), text.size());  // std::string overload.
+}
+
+} // namespace utils
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_UTF8_H_
diff --git a/native/lang_id/custom-tokenizer.cc b/native/lang_id/custom-tokenizer.cc
new file mode 100644
index 0000000..46a64b2
--- /dev/null
+++ b/native/lang_id/custom-tokenizer.cc
@@ -0,0 +1,162 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/custom-tokenizer.h"
+
+#include <ctype.h>
+
+#include <string>
+
+#include "lang_id/common/lite_base/attributes.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/utf8.h"
+#include "utf.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+namespace {
+inline bool IsTokenSeparator(int num_bytes, const char *curr) {
+ if (num_bytes != 1) {
+ return false;  // Multi-byte UTF8 characters are never separators.
+ }
+ return !isalpha(*curr);  // NOTE(review): *curr may be negative; cast to unsigned char for isalpha.
+}
+
+// Appends to *word the UTF8 encoding for the lowercase version of the UTF8
+// character that starts at |curr| and has |num_bytes| bytes.
+//
+// NOTE: if the current UTF8 character does not have a lowercase version, then
+// we append the original UTF8 character.
+inline SAFTM_ATTRIBUTE_ALWAYS_INLINE void AppendLowerCase(const char *curr,
+ int num_bytes,
+ std::string *word) {
+ if (num_bytes == 1) {
+ // Optimize the ASCII case.
+ word->push_back(tolower(*curr));  // NOTE(review): cast to unsigned char for tolower.
+ return;
+ }
+
+ // Harder, general case.
+ //
+ // NOTE: for lowercasing, we use the utils from utf.h:
+ // charntorune + tolowerrune + runetochar. Unfortunately, that library does
+ // not contain any fast util for determining the number of bytes for the UTF8
+ // character that starts at a given address *without* converting to a full
+ // codepoint (like our utils::OneCharLen, which is used intensively by the
+ // rest of our code, including by the performance-critical char ngram
+ // feature). Hence, the rest of our code continues to use utils::OneCharLen,
+ // and here, when we append the bytes to *word, we make sure that's consistent
+ // with utils::OneCharLen.
+
+ // charntorune() below reads the UTF8 character that starts at curr (using at
+ // most num_bytes bytes) and stores the corresponding codepoint into rune.
+ Rune rune;
+ charntorune(&rune, curr, num_bytes);
+ if (rune != Runeerror) {
+ Rune lower = tolowerrune(rune);
+ char lower_buf[UTFmax];
+ runetochar(lower_buf, &lower);
+
+ // When appending the UTF8 bytes to word, we do not use the number of bytes
+ // returned by runetochar(); instead, we use utils::OneCharLen(), the same
+ // method used by the char ngram feature. We expect them to be equal, but
+ // just in case.
+ int lower_num_bytes = utils::OneCharLen(lower_buf);
+
+ // Using lower_num_bytes below is safe, because, by definition of UTFmax,
+ SAFTM_DCHECK_GE(UTFmax, 4);
+
+ // And, by implementation of utils::OneCharLen():
+ SAFTM_DCHECK_GT(lower_num_bytes, 0);
+ SAFTM_DCHECK_LE(lower_num_bytes, 4);
+ word->append(lower_buf, lower_num_bytes);
+ } else {
+ // There are sequences of bytes that charntorune() can't convert into a
+ // valid Rune (a special case is [0xEF, 0xBF, 0xBD], the UTF8 encoding for
+ // the U+FFFD special Unicode character, which is also the value of
+ // Runeerror). We keep those bytes unchanged.
+ word->append(curr, num_bytes);
+ }
+}
+} // namespace
+
+void TokenizerForLangId::Setup(TaskContext *context) {
+ lowercase_input_ = context->Get("lang_id_lowercase_input", false);  // Defaults to no lowercasing.
+}
+
+void TokenizerForLangId::Tokenize(StringPiece text,
+ LightSentence *sentence) const {
+ const char *const start = text.data();
+ const char *curr = start;
+ const char *end = utils::GetSafeEndOfUtf8String(start, text.size());
+
+ // Corner case: the safe part of the text is empty ("").
+ if (curr >= end) {
+ return;
+ }
+
+ // Number of bytes for UTF8 character starting at *curr. Note: the loop below
+ // is guaranteed to terminate because in each iteration, we move curr by at
+ // least num_bytes, and num_bytes is guaranteed to be > 0.
+ int num_bytes = utils::OneCharLen(curr);
+ while (curr < end) {
+ // Jump over consecutive token separators.
+ while (IsTokenSeparator(num_bytes, curr)) {
+ curr += num_bytes;
+ if (curr >= end) {
+ return;
+ }
+ num_bytes = utils::OneCharLen(curr);
+ }
+
+ // If control reaches this point, we are at beginning of a non-empty token.
+ sentence->emplace_back();
+ std::string *word = &(sentence->back());
+
+ // Add special token-start character.
+ word->push_back('^');
+
+ // Add UTF8 characters to word, until we hit the end of the safe text or a
+ // token separator.
+ while (true) {
+ if (lowercase_input_) {
+ AppendLowerCase(curr, num_bytes, word);
+ } else {
+ word->append(curr, num_bytes);
+ }
+ curr += num_bytes;
+ if (curr >= end) {
+ break;
+ }
+ num_bytes = utils::OneCharLen(curr);
+ if (IsTokenSeparator(num_bytes, curr)) {
+ curr += num_bytes;  // Consume the separator before ending the token.
+ if (curr >= end) {
+ break;
+ }
+ num_bytes = utils::OneCharLen(curr);
+ break;
+ }
+ }
+ word->push_back('$');  // Add special token-end character.
+ }
+}
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/custom-tokenizer.h b/native/lang_id/custom-tokenizer.h
similarity index 100%
rename from lang_id/custom-tokenizer.h
rename to native/lang_id/custom-tokenizer.h
diff --git a/native/lang_id/fb_model/lang-id-from-fb.cc b/native/lang_id/fb_model/lang-id-from-fb.cc
new file mode 100644
index 0000000..b2163eb
--- /dev/null
+++ b/native/lang_id/fb_model/lang-id-from-fb.cc
@@ -0,0 +1,59 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/fb_model/lang-id-from-fb.h"
+
+#include <string>
+
+#include "lang_id/fb_model/model-provider-from-fb.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFile(
+ const std::string &filename) {
+ std::unique_ptr<ModelProvider> model_provider(
+ new ModelProviderFromFlatbuffer(filename));
+
+ // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
+ return std::unique_ptr<LangId>( // NOLINT
+ new LangId(std::move(model_provider)));
+}
+
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(
+ FileDescriptorOrHandle fd) {
+ std::unique_ptr<ModelProvider> model_provider(
+ new ModelProviderFromFlatbuffer(fd));
+
+ // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
+ return std::unique_ptr<LangId>( // NOLINT
+ new LangId(std::move(model_provider)));
+}
+
+std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(const char *data,
+ size_t num_bytes) {
+ std::unique_ptr<ModelProvider> model_provider(
+ new ModelProviderFromFlatbuffer(data, num_bytes));
+
+ // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
+ return std::unique_ptr<LangId>( // NOLINT
+ new LangId(std::move(model_provider)));
+}
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/native/lang_id/fb_model/lang-id-from-fb.h b/native/lang_id/fb_model/lang-id-from-fb.h
new file mode 100644
index 0000000..061247b
--- /dev/null
+++ b/native/lang_id/fb_model/lang-id-from-fb.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_LANG_ID_FROM_FB_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_LANG_ID_FROM_FB_H_
+
+#include <stddef.h>
+
+#include <memory>
+#include <string>
+
+#include "lang_id/common/file/mmap.h"
+#include "lang_id/lang-id.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Returns a LangId built using the SAFT model in flatbuffer format from
+// |filename|.
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFile(
+ const std::string &filename);
+
+// Returns a LangId built using the SAFT model in flatbuffer format from
+// given file descriptor.
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(
+ FileDescriptorOrHandle fd);
+
+// Returns a LangId built using the SAFT model in flatbuffer format from
+// the |num_bytes| bytes that start at address |data|.
+//
+// IMPORTANT: the model bytes must be alive during the lifetime of the returned
+// LangId. To avoid overhead (e.g., heap allocation), this method does not make
+// a private copy of the model bytes. Avoiding overhead is the main reason we
+// use flatbuffers.
+std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(const char *data,
+ size_t num_bytes);
+
+// Convenience string-based version of GetLangIdFromFlatbufferBytes.
+//
+// IMPORTANT: |bytes| must be alive during the lifetime of the returned LangId.
+inline std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(
+ const std::string &bytes) {
+ return GetLangIdFromFlatbufferBytes(bytes.data(), bytes.size());
+}
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_LANG_ID_FROM_FB_H_
diff --git a/native/lang_id/fb_model/model-provider-from-fb.cc b/native/lang_id/fb_model/model-provider-from-fb.cc
new file mode 100644
index 0000000..c81b116
--- /dev/null
+++ b/native/lang_id/fb_model/model-provider-from-fb.cc
@@ -0,0 +1,108 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/fb_model/model-provider-from-fb.h"
+
+#include <string>
+
+#include "lang_id/common/file/file-utils.h"
+#include "lang_id/common/file/mmap.h"
+#include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
+#include "lang_id/common/flatbuffers/model-utils.h"
+#include "lang_id/common/lite_strings/str-split.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
+ const std::string &filename)
+
+ // Using mmap as a fast way to read the model bytes. As the file is
+ // unmapped only when the field scoped_mmap_ is destructed, the model bytes
+ // stay alive for the entire lifetime of this object.
+ : scoped_mmap_(new ScopedMmap(filename)) {
+ Initialize(scoped_mmap_->handle().to_stringpiece());
+}
+
+ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
+ FileDescriptorOrHandle fd)
+
+ // Using mmap as a fast way to read the model bytes. As the file is
+ // unmapped only when the field scoped_mmap_ is destructed, the model bytes
+ // stay alive for the entire lifetime of this object.
+ : scoped_mmap_(new ScopedMmap(fd)) {
+ Initialize(scoped_mmap_->handle().to_stringpiece());
+}
+
+void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
+ // Note: valid_ was initialized to false. In the code below, we set valid_ to
+ // true only if all initialization steps completed successfully. Otherwise,
+ // we return early, leaving valid_ to its default value false.
+ model_ = saft_fbs::GetVerifiedModelFromBytes(model_bytes);
+ if (model_ == nullptr) {
+ SAFTM_LOG(ERROR) << "Unable to initialize ModelProviderFromFlatbuffer";
+ return;
+ }
+
+ // Initialize context_ parameters.
+ if (!saft_fbs::FillParameters(*model_, &context_)) {
+ // FillParameters already performs error logging.
+ return;
+ }
+
+ // Init languages_.
+ const std::string known_languages_str =
+ context_.Get("supported_languages", "");
+ for (StringPiece sp : LiteStrSplit(known_languages_str, ',')) {
+ languages_.emplace_back(sp);
+ }
+ if (languages_.empty()) {
+ SAFTM_LOG(ERROR) << "Unable to find list of supported_languages";
+ return;
+ }
+
+ // Init nn_params_.
+ if (!InitNetworkParams()) {
+ // InitNetworkParams already performs error logging.
+ return;
+ }
+
+ // Everything looks fine.
+ valid_ = true;
+}
+
+bool ModelProviderFromFlatbuffer::InitNetworkParams() {
+ const std::string kInputName = "language-identifier-network";
+ StringPiece bytes =
+ saft_fbs::GetInputBytes(saft_fbs::GetInputByName(model_, kInputName));
+ if ((bytes.data() == nullptr) || bytes.empty()) {
+ SAFTM_LOG(ERROR) << "Unable to get bytes for model input " << kInputName;
+ return false;
+ }
+ std::unique_ptr<EmbeddingNetworkParamsFromFlatbuffer> nn_params_from_fb(
+ new EmbeddingNetworkParamsFromFlatbuffer(bytes));
+ if (!nn_params_from_fb->is_valid()) {
+ SAFTM_LOG(ERROR) << "EmbeddingNetworkParamsFromFlatbuffer not valid";
+ return false;
+ }
+ nn_params_ = std::move(nn_params_from_fb);
+ return true;
+}
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/native/lang_id/fb_model/model-provider-from-fb.h b/native/lang_id/fb_model/model-provider-from-fb.h
new file mode 100644
index 0000000..c3def49
--- /dev/null
+++ b/native/lang_id/fb_model/model-provider-from-fb.h
@@ -0,0 +1,116 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_
+
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/file/mmap.h"
+#include "lang_id/common/flatbuffers/model_generated.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+#include "lang_id/model-provider.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// ModelProvider for LangId, based on a SAFT model in flatbuffer format.
+class ModelProviderFromFlatbuffer : public ModelProvider {
+ public:
+ // Constructs a model provider based on a flatbuffer-format SAFT model from
+ // |filename|.
+ explicit ModelProviderFromFlatbuffer(const std::string &filename);
+
+ // Constructs a model provider based on a flatbuffer-format SAFT model from
+ // file descriptor |fd|.
+ explicit ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd);
+
+ // Constructs a model provider from a flatbuffer-format SAFT model the bytes
+ // of which are already in RAM (size bytes starting from address data).
+ // Useful if you "transport" these bytes otherwise than via a normal file
+ // (e.g., if you embed them somehow in your binary).
+ //
+ // IMPORTANT: |data| should be alive during the lifetime of the
+ // newly-constructed ModelProviderFromFlatbuffer. This is trivial to ensure
+ // for data that's statically embedded in your binary, but more complex in
+ // other cases. To avoid overhead (e.g., heap allocation), this method does
+ // not make a private copy of the data. In general, the ownership of the
+ // newly-constructed ModelProviderFromFlatbuffer is immediately passed to a
+ // LangId object (which doesn't pass it further); hence, one needs to make
+ // sure |data| is alive during the lifetime of that LangId object.
+ ModelProviderFromFlatbuffer(const char *data, std::size_t size) {
+ StringPiece model_bytes(data, size);
+ Initialize(model_bytes);
+ }
+
+ ~ModelProviderFromFlatbuffer() override = default;
+
+ const TaskContext *GetTaskContext() const override {
+ return &context_;
+ }
+
+ const EmbeddingNetworkParams *GetNnParams() const override {
+ return nn_params_.get();
+ }
+
+ std::vector<std::string> GetLanguages() const override { return languages_; }
+
+ private:
+ // Initializes the fields of this class based on the flatbuffer from
+ // |model_bytes|. These bytes are supposed to be the representation of a
+ // Model flatbuffer and should be alive during the lifetime of this object.
+ void Initialize(StringPiece model_bytes);
+
+ // Initializes nn_params_ based on model_.
+ bool InitNetworkParams();
+
+ // If filename-based constructor is used, scoped_mmap_ keeps the file mmapped
+ // during the lifetime of this object, such that references inside the Model
+ // flatbuffer from those bytes remain valid.
+ const std::unique_ptr<ScopedMmap> scoped_mmap_;
+
+ // Pointer to the flatbuffer from
+ //
+ // (a) [if filename constructor was used:] the bytes mmapped by scoped_mmap_
+ // (for safety considerations, see comment for that field), or
+ //
+ // (b) [if (data, size) constructor was used:] the bytes from [data,
+ // data+size). Please read carefully the doc for that constructor.
+ const saft_fbs::Model *model_;
+
+ // Context returned by this model provider. We set its parameters based on
+ // model_, at construction time.
+ TaskContext context_;
+
+ // List of supported languages, see GetLanguages(). We expect this list to be
+ // specified by the ModelParameter named "supported_languages" from model_.
+ std::vector<std::string> languages_;
+
+ // EmbeddingNetworkParams, see GetNnParams(). Set based on the ModelInput
+ // named "language-identifier-network" from model_.
+ std::unique_ptr<EmbeddingNetworkParams> nn_params_;
+};
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_FB_MODEL_MODEL_PROVIDER_FROM_FB_H_
diff --git a/native/lang_id/features/char-ngram-feature.cc b/native/lang_id/features/char-ngram-feature.cc
new file mode 100644
index 0000000..31faf2f
--- /dev/null
+++ b/native/lang_id/features/char-ngram-feature.cc
@@ -0,0 +1,157 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/features/char-ngram-feature.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/math/hash.h"
+#include "lang_id/common/utf8.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) {
+ // Parameters in the feature function descriptor.
+ bool include_terminators = GetBoolParameter("include_terminators", false);
+ if (!include_terminators) {
+ SAFTM_LOG(ERROR) << "No support for include_terminators=true";
+ return false;
+ }
+
+ bool include_spaces = GetBoolParameter("include_spaces", false);
+ if (include_spaces) {
+ SAFTM_LOG(ERROR) << "No support for include_spaces=true";
+ return false;
+ }
+
+ bool use_equal_ngram_weight = GetBoolParameter("use_equal_weight", false);
+ if (use_equal_ngram_weight) {
+ SAFTM_LOG(ERROR) << "No support for use_equal_weight=true";
+ return false;
+ }
+
+ ngram_id_dimension_ = GetIntParameter("id_dim", 10000);
+ ngram_size_ = GetIntParameter("size", 3);
+
+ counts_.assign(ngram_id_dimension_, 0);
+ return true;
+}
+
+bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) {
+ set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_));
+ return true;
+}
+
+int ContinuousBagOfNgramsFunction::ComputeNgramCounts(
+ const LightSentence &sentence) const {
+ SAFTM_CHECK_EQ(counts_.size(), ngram_id_dimension_);
+ SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0);
+
+ int total_count = 0;
+
+ for (const std::string &word : sentence) {
+ const char *const word_end = word.data() + word.size();
+
+ // Set ngram_start at the start of the current token (word).
+ const char *ngram_start = word.data();
+
+ // Set ngram_end ngram_size UTF8 characters after ngram_start. Note: each
+ // UTF8 character contains between 1 and 4 bytes.
+ const char *ngram_end = ngram_start;
+ int num_utf8_chars = 0;
+ do {
+ ngram_end += utils::OneCharLen(ngram_end);
+ num_utf8_chars++;
+ } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end));
+
+ if (num_utf8_chars < ngram_size_) {
+ // Current token is so small, it does not contain a single ngram of
+ // ngram_size UTF8 characters. Not much we can do in this case ...
+ continue;
+ }
+
+ // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size
+ // UTF8 characters from current token.
+ while (true) {
+ // Compute ngram id: hash(ngram) % ngram_id_dimension
+ int ngram_id = (
+ utils::Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start)
+ % ngram_id_dimension_);
+
+ // Use a reference to the actual count, such that we can both test whether
+ // the count was 0 and increment it without performing two lookups.
+ int &ref_to_count_for_ngram = counts_[ngram_id];
+ if (ref_to_count_for_ngram == 0) {
+ non_zero_count_indices_.push_back(ngram_id);
+ }
+ ref_to_count_for_ngram++;
+ total_count++;
+ if (ngram_end >= word_end) {
+ break;
+ }
+
+ // Advance both ngram_start and ngram_end by one UTF8 character. This
+ // way, the number of UTF8 characters between them remains constant
+ // (ngram_size).
+ ngram_start += utils::OneCharLen(ngram_start);
+ ngram_end += utils::OneCharLen(ngram_end);
+ }
+ } // end of loop over tokens.
+
+ return total_count;
+}
+
+void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces,
+ const LightSentence &sentence,
+ FeatureVector *result) const {
+ // NOTE: we use std::* constructs (instead of absl::Mutex & co) to simplify
+ // porting to Android and to avoid pulling in absl (which increases our code
+ // size).
+ std::lock_guard<std::mutex> mlock(state_mutex_);
+
+ // Find the char ngram counts.
+ int total_count = ComputeNgramCounts(sentence);
+
+ // Populate the feature vector.
+ const float norm = static_cast<float>(total_count);
+
+ // TODO(salcianu): explore treating dense vectors (i.e., many non-zero
+ // elements) separately.
+ for (int ngram_id : non_zero_count_indices_) {
+ const float weight = counts_[ngram_id] / norm;
+ FloatFeatureValue value(ngram_id, weight);
+ result->add(feature_type(), value.discrete_value);
+
+ // Clear up counts_, for the next invocation of Evaluate().
+ counts_[ngram_id] = 0;
+ }
+
+ // Clear up non_zero_count_indices_, for the next invocation of Evaluate().
+ non_zero_count_indices_.clear();
+}
+
+SAFTM_STATIC_REGISTRATION(ContinuousBagOfNgramsFunction);
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/features/char-ngram-feature.h b/native/lang_id/features/char-ngram-feature.h
similarity index 100%
rename from lang_id/features/char-ngram-feature.h
rename to native/lang_id/features/char-ngram-feature.h
diff --git a/lang_id/features/light-sentence-features.cc b/native/lang_id/features/light-sentence-features.cc
similarity index 100%
rename from lang_id/features/light-sentence-features.cc
rename to native/lang_id/features/light-sentence-features.cc
diff --git a/lang_id/features/light-sentence-features.h b/native/lang_id/features/light-sentence-features.h
similarity index 100%
rename from lang_id/features/light-sentence-features.h
rename to native/lang_id/features/light-sentence-features.h
diff --git a/native/lang_id/features/relevant-script-feature.cc b/native/lang_id/features/relevant-script-feature.cc
new file mode 100644
index 0000000..e88b328
--- /dev/null
+++ b/native/lang_id/features/relevant-script-feature.cc
@@ -0,0 +1,109 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/features/relevant-script-feature.h"
+
+#include <string>
+
+#include "lang_id/common/fel/feature-types.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/utf8.h"
+#include "lang_id/script/script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+bool RelevantScriptFeature::Setup(TaskContext *context) {
+ std::string script_detector_name = GetParameter(
+ "script_detector_name", /* default_value = */ "tiny-script-detector");
+
+ // We don't use absl::WrapUnique, nor the rest of absl, see http://b/71873194
+ script_detector_.reset(ScriptDetector::Create(script_detector_name));
+ if (script_detector_ == nullptr) {
+ // This means ScriptDetector::Create() could not find the requested
+ // script_detector_name. In that case, Create() already logged an error
+ // message.
+ return false;
+ }
+
+ // We use default value 172 because this is the number of scripts supported by
+ // the first model we trained with this feature. See http://b/70617713.
+ // Newer models may support more scripts.
+ num_supported_scripts_ = GetIntParameter("num_supported_scripts", 172);
+ return true;
+}
+
+bool RelevantScriptFeature::Init(TaskContext *context) {
+ set_feature_type(new NumericFeatureType(name(), num_supported_scripts_));
+ return true;
+}
+
+void RelevantScriptFeature::Evaluate(
+ const WorkspaceSet &workspaces, const LightSentence &sentence,
+ FeatureVector *result) const {
+ // counts[s] is the number of characters with script s.
+ std::vector<int> counts(num_supported_scripts_);
+ int total_count = 0;
+ for (const std::string &word : sentence) {
+ const char *const word_end = word.data() + word.size();
+ const char *curr = word.data();
+
+ // Skip over token start '^'.
+ SAFTM_DCHECK_EQ(*curr, '^');
+ curr += utils::OneCharLen(curr);
+ while (true) {
+ const int num_bytes = utils::OneCharLen(curr);
+
+ int script = script_detector_->GetScript(curr, num_bytes);
+
+ // We do this update and the if (...) break below *before* incrementing
+ // counts[script] in order to skip the token end '$'.
+ curr += num_bytes;
+ if (curr >= word_end) {
+ SAFTM_DCHECK_EQ(*(curr - num_bytes), '$');
+ break;
+ }
+ SAFTM_DCHECK_GE(script, 0);
+
+ if (script < num_supported_scripts_) {
+ counts[script]++;
+ total_count++;
+ } else {
+ // Unsupported script: this usually indicates a script that is
+ // recognized by newer versions of the code, after the model was
+ // trained. E.g., new code running with old model.
+ }
+ }
+ }
+
+ for (int script_id = 0; script_id < num_supported_scripts_; ++script_id) {
+ int count = counts[script_id];
+ if (count > 0) {
+ const float weight = static_cast<float>(count) / total_count;
+ FloatFeatureValue value(script_id, weight);
+ result->add(feature_type(), value.discrete_value);
+ }
+ }
+}
+
+SAFTM_STATIC_REGISTRATION(RelevantScriptFeature);
+
+} // namespace lang_id
+} // namespace mobile
+} // namespace libtextclassifier3
diff --git a/lang_id/features/relevant-script-feature.h b/native/lang_id/features/relevant-script-feature.h
similarity index 100%
rename from lang_id/features/relevant-script-feature.h
rename to native/lang_id/features/relevant-script-feature.h
diff --git a/native/lang_id/lang-id-wrapper.cc b/native/lang_id/lang-id-wrapper.cc
new file mode 100644
index 0000000..4246cce
--- /dev/null
+++ b/native/lang_id/lang-id-wrapper.cc
@@ -0,0 +1,102 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/lang-id-wrapper.h"
+
+#include <fcntl.h>
+
+#include "lang_id/fb_model/lang-id-from-fb.h"
+#include "lang_id/lang-id.h"
+
+namespace libtextclassifier3 {
+
+namespace langid {
+
+std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromPath(
+ const std::string& langid_model_path) {
+ std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model =
+ libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(langid_model_path);
+ return langid_model;
+}
+
+std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromDescriptor(
+ const int langid_fd) {
+ std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model =
+ libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor(
+ langid_fd);
+ return langid_model;
+}
+
+std::vector<std::pair<std::string, float>> GetPredictions(
+ const libtextclassifier3::mobile::lang_id::LangId* model, const std::string& text) {
+ return GetPredictions(model, text.data(), text.size());
+}
+
+std::vector<std::pair<std::string, float>> GetPredictions(
+ const libtextclassifier3::mobile::lang_id::LangId* model, const char* text,
+ int text_size) {
+ std::vector<std::pair<std::string, float>> prediction_results;
+ if (model == nullptr) {
+ return prediction_results;
+ }
+
+ const float noise_threshold =
+ model->GetFloatProperty("text_classifier_langid_noise_threshold", -1.0f);
+
+ // Speed things up by specifying the max results we want. For example, if
+ // the noise threshold is 0.1, we don't need more than 10 results.
+ const int max_results =
+ noise_threshold < 0.01
+ ? -1 // -1 means FindLanguages returns all predictions
+ : static_cast<int>(1 / noise_threshold) + 1;
+
+ libtextclassifier3::mobile::lang_id::LangIdResult langid_result;
+ model->FindLanguages(text, text_size, &langid_result, max_results);
+ for (int i = 0; i < langid_result.predictions.size(); i++) {
+ const auto& prediction = langid_result.predictions[i];
+ if (prediction.second >= noise_threshold && prediction.first != "und") {
+ prediction_results.push_back({prediction.first, prediction.second});
+ }
+ }
+ return prediction_results;
+}
+
+std::string GetLanguageTags(const libtextclassifier3::mobile::lang_id::LangId* model,
+ const std::string& text) {
+ const std::vector<std::pair<std::string, float>>& predictions =
+ GetPredictions(model, text);
+ const float threshold =
+ model->GetFloatProperty("text_classifier_langid_threshold", -1.0f);
+ std::string detected_language_tags = "";
+ bool first_accepted_language = true;
+ for (int i = 0; i < predictions.size(); i++) {
+ const auto& prediction = predictions[i];
+ if (threshold >= 0.f && prediction.second < threshold) {
+ continue;
+ }
+ if (first_accepted_language) {
+ first_accepted_language = false;
+ } else {
+ detected_language_tags += ",";
+ }
+ detected_language_tags += prediction.first;
+ }
+ return detected_language_tags;
+}
+
+} // namespace langid
+
+} // namespace libtextclassifier3
diff --git a/native/lang_id/lang-id-wrapper.h b/native/lang_id/lang-id-wrapper.h
new file mode 100644
index 0000000..47e6f44
--- /dev/null
+++ b/native/lang_id/lang-id-wrapper.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_WRAPPER_H_
+#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_WRAPPER_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "lang_id/lang-id.h"
+
+namespace libtextclassifier3 {
+
+namespace langid {
+
+// Loads the LangId model from a given path.
+std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromPath(
+ const std::string& path);
+
+// Loads the LangId model from a file descriptor.
+std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromDescriptor(
+ const int fd);
+
+// Returns the LangId predictions (locale, confidence) from the given LangId
+// model. The maximum number of predictions returned will be computed internally
+// relatively to the noise threshold.
+std::vector<std::pair<std::string, float>> GetPredictions(
+ const libtextclassifier3::mobile::lang_id::LangId* model, const std::string& text);
+
+// Same as above but takes a char pointer and byte length.
+std::vector<std::pair<std::string, float>> GetPredictions(
+ const libtextclassifier3::mobile::lang_id::LangId* model, const char* text,
+ int text_size);
+
+// Returns the language tags string from the given LangId model. The language
+// tags will be filtered internally by the LangId threshold.
+std::string GetLanguageTags(const libtextclassifier3::mobile::lang_id::LangId* model,
+ const std::string& text);
+
+} // namespace langid
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_WRAPPER_H_
diff --git a/native/lang_id/lang-id.cc b/native/lang_id/lang-id.cc
new file mode 100644
index 0000000..92359a9
--- /dev/null
+++ b/native/lang_id/lang-id.cc
@@ -0,0 +1,329 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/lang-id.h"
+
+#include <stdio.h>
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "lang_id/common/embedding-feature-interface.h"
+#include "lang_id/common/embedding-network-params.h"
+#include "lang_id/common/embedding-network.h"
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/lite_base/logging.h"
+#include "lang_id/common/lite_strings/numbers.h"
+#include "lang_id/common/lite_strings/str-split.h"
+#include "lang_id/common/lite_strings/stringpiece.h"
+#include "lang_id/common/math/algorithm.h"
+#include "lang_id/common/math/softmax.h"
+#include "lang_id/custom-tokenizer.h"
+#include "lang_id/features/light-sentence-features.h"
+// The two features/ headers below are needed only for RegisterClass().
+#include "lang_id/features/char-ngram-feature.h"
+#include "lang_id/features/relevant-script-feature.h"
+#include "lang_id/light-sentence.h"
+// The two script/ headers below are needed only for RegisterClass().
+#include "lang_id/script/approx-script.h"
+#include "lang_id/script/tiny-script-detector.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+namespace {
+// Default value for the confidence threshold. If the confidence of the top
+// prediction is below this threshold, then FindLanguage() returns
+// LangId::kUnknownLanguageCode. Note: this is just a default value; if the
+// TaskSpec from the model specifies a "reliability_thresh" parameter, then we
+// use that value instead. Note: for legacy reasons, our code and comments use
+// the terms "confidence", "probability" and "reliability" equivalently.
+static const float kDefaultConfidenceThreshold = 0.50f;
+} // namespace
+
+// Class that performs all work behind LangId.
+class LangIdImpl {
+ public:
+ explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider)
+ : model_provider_(std::move(model_provider)),
+ lang_id_brain_interface_("language_identifier") {
+ // Note: in the code below, we set valid_ to true only if all initialization
+ // steps completed successfully. Otherwise, we return early, leaving valid_
+ // to its default value false.
+ if (!model_provider_ || !model_provider_->is_valid()) {
+ SAFTM_LOG(ERROR) << "Invalid model provider";
+ return;
+ }
+
+ auto *nn_params = model_provider_->GetNnParams();
+ if (!nn_params) {
+ SAFTM_LOG(ERROR) << "No NN params";
+ return;
+ }
+ network_.reset(new EmbeddingNetwork(nn_params));
+
+ languages_ = model_provider_->GetLanguages();
+ if (languages_.empty()) {
+ SAFTM_LOG(ERROR) << "No known languages";
+ return;
+ }
+
+ TaskContext context = *model_provider_->GetTaskContext();
+ if (!Setup(&context)) {
+ SAFTM_LOG(ERROR) << "Unable to Setup() LangId";
+ return;
+ }
+ if (!Init(&context)) {
+ SAFTM_LOG(ERROR) << "Unable to Init() LangId";
+ return;
+ }
+ valid_ = true;
+ }
+
+ std::string FindLanguage(StringPiece text) const {
+ LangIdResult lang_id_result;
+ FindLanguages(text, &lang_id_result, /* max_results = */ 1);
+ if (lang_id_result.predictions.empty()) {
+ return LangId::kUnknownLanguageCode;
+ }
+
+ const std::string &language = lang_id_result.predictions[0].first;
+ const float probability = lang_id_result.predictions[0].second;
+ SAFTM_DLOG(INFO) << "Predicted " << language
+ << " with prob: " << probability << " for \"" << text
+ << "\"";
+
+ // Find confidence threshold for language.
+ float threshold = default_threshold_;
+ auto it = per_lang_thresholds_.find(language);
+ if (it != per_lang_thresholds_.end()) {
+ threshold = it->second;
+ }
+ if (probability < threshold) {
+ SAFTM_DLOG(INFO) << " below threshold => "
+ << LangId::kUnknownLanguageCode;
+ return LangId::kUnknownLanguageCode;
+ }
+ return language;
+ }
+
+ void FindLanguages(StringPiece text, LangIdResult *result,
+ int max_results) const {
+ if (result == nullptr) return;
+
+ if (max_results <= 0) {
+ max_results = languages_.size();
+ }
+ result->predictions.clear();
+ if (!is_valid() || (max_results == 0)) {
+ result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
+ return;
+ }
+
+ // Tokenize the input text (this also does some pre-processing, like
+ // removing ASCII digits, punctuation, etc).
+ LightSentence sentence;
+ tokenizer_.Tokenize(text, &sentence);
+
+ // Test input size here, after pre-processing removed irrelevant chars.
+ if (IsTooShort(sentence)) {
+ result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
+ return;
+ }
+
+ // Extract features from the tokenized text.
+ std::vector<FeatureVector> features =
+ lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);
+
+ // Run feed-forward neural network to compute scores (softmax logits).
+ std::vector<float> scores;
+ network_->ComputeFinalScores(features, &scores);
+
+ if (max_results == 1) {
+ // Optimization for the case when the user wants only the top result.
+ // Computing argmax is faster than the general top-k code.
+ int prediction_id = GetArgMax(scores);
+ const std::string language = GetLanguageForSoftmaxLabel(prediction_id);
+ float probability = ComputeSoftmaxProbability(scores, prediction_id);
+ result->predictions.emplace_back(language, probability);
+ } else {
+ // Compute and sort softmax in descending order by probability and convert
+ // IDs to language code strings. When probabilities are equal, we sort by
+ // language code string in ascending order.
+ const std::vector<float> softmax = ComputeSoftmax(scores);
+ const std::vector<int> indices = GetTopKIndices(max_results, softmax);
+ for (const int index : indices) {
+ result->predictions.emplace_back(GetLanguageForSoftmaxLabel(index),
+ softmax[index]);
+ }
+ }
+ }
+
+ bool is_valid() const { return valid_; }
+
+ int GetModelVersion() const { return model_version_; }
+
+ // Returns a property stored in the model file.
+ template <typename T, typename R>
+ R GetProperty(const std::string &property, T default_value) const {
+ return model_provider_->GetTaskContext()->Get(property, default_value);
+ }
+
+ // Perform any necessary static initialization.
+ // This function is thread-safe.
+ // It's also safe to call this function multiple times.
+ //
+ // We explicitly call RegisterClass() rather than relying on alwayslink=1 in
+ // the BUILD file, because the build process for some users of this code
+ // doesn't support any equivalent to alwayslink=1 (in particular the
+ // Firebase C++ SDK build uses a Kokoro-based CMake build). While it might
+ // be possible to add such support, avoiding the need for an equivalent to
+ // alwayslink=1 is preferable because it avoids unnecessarily bloating code
+ // size in apps that link against this code but don't use it.
+ static void RegisterClasses() {
+ static bool initialized = []() -> bool {
+ libtextclassifier3::mobile::ApproxScriptDetector::RegisterClass();
+ libtextclassifier3::mobile::lang_id::ContinuousBagOfNgramsFunction::RegisterClass();
+ libtextclassifier3::mobile::lang_id::TinyScriptDetector::RegisterClass();
+ libtextclassifier3::mobile::lang_id::RelevantScriptFeature::RegisterClass();
+ return true;
+ }();
+ (void)initialized; // Variable used only for initializer's side effects.
+ }
+
+ private:
+ bool Setup(TaskContext *context) {
+ tokenizer_.Setup(context);
+ if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;
+
+ min_text_size_in_bytes_ = context->Get("min_text_size_in_bytes", 0);
+ default_threshold_ =
+ context->Get("reliability_thresh", kDefaultConfidenceThreshold);
+
+ // Parse task parameter "per_lang_reliability_thresholds", fill
+ // per_lang_thresholds_.
+ const std::string thresholds_str =
+ context->Get("per_lang_reliability_thresholds", "");
+ std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ',');
+ for (const auto &token : tokens) {
+ if (token.empty()) continue;
+ std::vector<StringPiece> parts = LiteStrSplit(token, '=');
+ float threshold = 0.0f;
+ if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) {
+ per_lang_thresholds_[std::string(parts[0])] = threshold;
+ } else {
+ SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\"";
+ }
+ }
+ model_version_ = context->Get("model_version", model_version_);
+ return true;
+ }
+
+ bool Init(TaskContext *context) {
+ return lang_id_brain_interface_.InitForProcessing(context);
+ }
+
+ // Returns language code for a softmax label. See comments for languages_
+ // field. If label is out of range, returns LangId::kUnknownLanguageCode.
+ std::string GetLanguageForSoftmaxLabel(int label) const {
+ if ((label >= 0) && (label < languages_.size())) {
+ return languages_[label];
+ } else {
+ SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
+ << languages_.size() << ")";
+ return LangId::kUnknownLanguageCode;
+ }
+ }
+
+ bool IsTooShort(const LightSentence &sentence) const {
+ int text_size = 0;
+ for (const std::string &token : sentence) {
+ // Each token has the form ^...$: we subtract 2 because we want to count
+ // only the real text, not the chars added by us.
+ text_size += token.size() - 2;
+ }
+ return text_size < min_text_size_in_bytes_;
+ }
+
+ std::unique_ptr<ModelProvider> model_provider_;
+
+ TokenizerForLangId tokenizer_;
+
+ EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence>
+ lang_id_brain_interface_;
+
+ // Neural network to use for scoring.
+ std::unique_ptr<EmbeddingNetwork> network_;
+
+ // True if this object is ready to perform language predictions.
+ bool valid_ = false;
+
+ // The model returns LangId::kUnknownLanguageCode for input text that has
+ // fewer than min_text_size_in_bytes_ bytes (excluding ASCII whitespaces,
+ // digits, and punctuation).
+ int min_text_size_in_bytes_ = 0;
+
+ // Only predictions with a probability (confidence) above this threshold are
+ // reported. Otherwise, we report LangId::kUnknownLanguageCode.
+ float default_threshold_ = kDefaultConfidenceThreshold;
+
+ std::unordered_map<std::string, float> per_lang_thresholds_;
+
+ // Recognized languages: softmax label i means languages_[i] (something like
+ // "en", "fr", "ru", etc).
+ std::vector<std::string> languages_;
+
+ // Version of the model used by this LangIdImpl object. Zero means that the
+ // model version could not be determined.
+ int model_version_ = 0;
+};
+
+const char LangId::kUnknownLanguageCode[] = "und";
+
+LangId::LangId(std::unique_ptr<ModelProvider> model_provider)
+ : pimpl_(new LangIdImpl(std::move(model_provider))) {
+ LangIdImpl::RegisterClasses();
+}
+
+LangId::~LangId() = default;
+
+std::string LangId::FindLanguage(const char *data, size_t num_bytes) const {
+ StringPiece text(data, num_bytes);
+ return pimpl_->FindLanguage(text);
+}
+
+void LangId::FindLanguages(const char *data, size_t num_bytes,
+ LangIdResult *result, int max_results) const {
+ SAFTM_DCHECK(result) << "LangIdResult must not be null.";
+ StringPiece text(data, num_bytes);
+ pimpl_->FindLanguages(text, result, max_results);
+}
+
+bool LangId::is_valid() const { return pimpl_->is_valid(); }
+
+int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); }
+
+float LangId::GetFloatProperty(const std::string &property,
+ float default_value) const {
+ return pimpl_->GetProperty<float, float>(property, default_value);
+}
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
diff --git a/native/lang_id/lang-id.h b/native/lang_id/lang-id.h
new file mode 100644
index 0000000..18c6e77
--- /dev/null
+++ b/native/lang_id/lang-id.h
@@ -0,0 +1,144 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_H_
+
+
+#include <stddef.h>
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "lang_id/common/lite_base/macros.h"
+#include "lang_id/model-provider.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Forward-declaration of the class that performs all underlying work.
+class LangIdImpl;
+
+struct LangIdResult {
+ // An n-best list of possible language codes for a given input sorted in
+ // descending order according to each code's respective probability.
+ //
+ // This list is guaranteed to be non-empty after calling
+ // LangId::FindLanguages. The most likely language code is always the first
+ // item in this array.
+ //
+ // If the model cannot make a prediction, this array contains a single result:
+ // a language code LangId::kUnknownLanguageCode with probability 1.
+ std::vector<std::pair<std::string, float>> predictions;
+};
+
+// Class for detecting the language of a document.
+//
+// Note: this class does not handle the details of loading the actual model.
+// Those details have been "outsourced" to the ModelProvider class.
+//
+// This class is thread safe.
+class LangId {
+ public:
+ // Standard BCP-47 language code for Unknown/Undetermined language.
+ static const char kUnknownLanguageCode[];
+
+ // Constructs a LangId object, based on |model_provider|.
+ //
+ // Note: we don't crash if we detect a problem at construction time (e.g., the
+ // model provider can't read an underlying file). Instead, we mark the
+ // newly-constructed object as invalid; clients can invoke FindLanguage() on
+ // an invalid object: nothing crashes, but accuracy will be bad.
+ explicit LangId(std::unique_ptr<ModelProvider> model_provider);
+
+ virtual ~LangId();
+
+ // Computes the n-best list of language codes and probabilities corresponding
+ // to the most likely languages the given input text is written in. That list
+ // includes the most likely |max_results| languages and is sorted in
+ // descending order by language probability.
+ //
+ // The input text consists of the |num_bytes| bytes that starts at |data|.
+ //
+ // If max_results <= 0, we report probabilities for all languages known by
+ // this LangId object (as always, in decreasing order of their probabilities).
+ //
+ // Note: If this LangId object is not valid (see is_valid()) or if this LangId
+ // object can't make a prediction, this method sets the LangIdResult to
+ // contain a single entry with kUnknownLanguageCode with probability 1.
+ //
+ void FindLanguages(const char *data, size_t num_bytes, LangIdResult *result,
+ int max_results = 0) const;
+
+ // Convenience version of FindLanguages(const char *, size_t, LangIdResult *).
+ void FindLanguages(const std::string &text, LangIdResult *result,
+ int max_results = 0) const {
+ FindLanguages(text.data(), text.size(), result, max_results);
+ }
+
+ // Returns language code for the most likely language for a piece of text.
+ //
+ // The input text consists of the |num_bytes| bytes that start at |data|.
+ //
+ // Note: this method reports the most likely (1-best) language only if its
+ // probability is high enough; otherwise, it returns
+ // LangId::kUnknownLanguageCode. The specific probability threshold is tuned
+ // to the needs of an early client. If you need a different threshold, you
+ // can use FindLanguages (plural) to get the full LangIdResult, and apply your
+ // own threshold.
+ //
+ // Note: if this LangId object is not valid (see is_valid()) or if this LangId
+ // object can't make a prediction, then this method returns
+ // LangId::kUnknownLanguageCode.
+ //
+ std::string FindLanguage(const char *data, size_t num_bytes) const;
+
+ // Convenience version of FindLanguage(const char *, size_t).
+ std::string FindLanguage(const std::string &text) const {
+ return FindLanguage(text.data(), text.size());
+ }
+
+ // Returns true if this object has been correctly initialized and is ready to
+ // perform predictions. For more info, see doc for LangId
+ // constructor above.
+ bool is_valid() const;
+
+ // Returns the version of the model used by this LangId object. On success,
+ // the returned version number is a strictly positive integer. Returns 0 if
+ // the model version can not be determined (e.g., for old models that do not
+ // specify a version number).
+ int GetModelVersion() const;
+
+ // Returns a typed property stored in the model file.
+ float GetFloatProperty(const std::string &property,
+ float default_value) const;
+
+ private:
+ // Pimpl ("pointer to implementation") pattern, to hide all internals from our
+ // clients.
+ std::unique_ptr<LangIdImpl> pimpl_;
+
+ SAFTM_DISALLOW_COPY_AND_ASSIGN(LangId);
+};
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_H_
diff --git a/native/lang_id/lang-id_jni.cc b/native/lang_id/lang-id_jni.cc
new file mode 100644
index 0000000..30753dc
--- /dev/null
+++ b/native/lang_id/lang-id_jni.cc
@@ -0,0 +1,166 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/lang-id_jni.h"
+
+#include <jni.h>
+
+#include <type_traits>
+#include <vector>
+
+#include "lang_id/lang-id-wrapper.h"
+#include "utils/base/logging.h"
+#include "utils/java/jni-helper.h"
+#include "lang_id/fb_model/lang-id-from-fb.h"
+#include "lang_id/lang-id.h"
+
+using libtextclassifier3::JniHelper;
+using libtextclassifier3::ScopedLocalRef;
+using libtextclassifier3::StatusOr;
+using libtextclassifier3::ToStlString;
+using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile;
+using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor;
+using libtextclassifier3::mobile::lang_id::LangId;
+using libtextclassifier3::mobile::lang_id::LangIdResult;
+
+namespace {
+
+StatusOr<ScopedLocalRef<jobjectArray>> LangIdResultToJObjectArray(
+ JNIEnv* env,
+ const std::vector<std::pair<std::string, float>>& lang_id_predictions) {
+ TC3_ASSIGN_OR_RETURN(
+ const ScopedLocalRef<jclass> result_class,
+ JniHelper::FindClass(
+ env, TC3_PACKAGE_PATH TC3_LANG_ID_CLASS_NAME_STR "$LanguageResult"));
+
+ TC3_ASSIGN_OR_RETURN(const jmethodID result_class_constructor,
+ JniHelper::GetMethodID(env, result_class.get(), "<init>",
+ "(Ljava/lang/String;F)V"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> results,
+ JniHelper::NewObjectArray(env, lang_id_predictions.size(),
+ result_class.get(), nullptr));
+ for (int i = 0; i < lang_id_predictions.size(); i++) {
+ TC3_ASSIGN_OR_RETURN(
+ const ScopedLocalRef<jstring> predicted_language,
+ JniHelper::NewStringUTF(env, lang_id_predictions[i].first.c_str()));
+ TC3_ASSIGN_OR_RETURN(
+ const ScopedLocalRef<jobject> result,
+ JniHelper::NewObject(
+ env, result_class.get(), result_class_constructor,
+ predicted_language.get(),
+ static_cast<jfloat>(lang_id_predictions[i].second)));
+ env->SetObjectArrayElement(results.get(), i, result.get());
+ }
+ return results;
+}
+
+float GetNoiseThreshold(const LangId& model) {
+ return model.GetFloatProperty("text_classifier_langid_noise_threshold", -1.0);
+}
+} // namespace
+
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
+(JNIEnv* env, jobject thiz, jint fd) {
+ std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
+ if (!lang_id->is_valid()) {
+ return reinterpret_cast<jlong>(nullptr);
+ }
+ return reinterpret_cast<jlong>(lang_id.release());
+}
+
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
+(JNIEnv* env, jobject thiz, jstring path) {
+ TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path));
+ std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFile(path_str);
+ if (!lang_id->is_valid()) {
+ return reinterpret_cast<jlong>(nullptr);
+ }
+ return reinterpret_cast<jlong>(lang_id.release());
+}
+
+TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
+(JNIEnv* env, jobject clazz, jlong ptr, jstring text) {
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ if (!model) {
+ return nullptr;
+ }
+
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string text_str, ToStlString(env, text));
+
+ const std::vector<std::pair<std::string, float>>& prediction_results =
+ libtextclassifier3::langid::GetPredictions(model, text_str);
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> results,
+ LangIdResultToJObjectArray(env, prediction_results));
+ return results.release();
+}
+
+TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
+(JNIEnv* env, jobject clazz, jlong ptr) {
+ if (!ptr) {
+ TC3_LOG(ERROR) << "Trying to close null LangId.";
+ return;
+ }
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ delete model;
+}
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jlong ptr) {
+ if (!ptr) {
+ return -1;
+ }
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ return model->GetModelVersion();
+}
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
+(JNIEnv* env, jobject clazz, jint fd) {
+ std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
+ if (!lang_id->is_valid()) {
+ return -1;
+ }
+ return lang_id->GetModelVersion();
+}
+
+TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdThreshold)
+(JNIEnv* env, jobject thizz, jlong ptr) {
+ if (!ptr) {
+ return -1.0;
+ }
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ return model->GetFloatProperty("text_classifier_langid_threshold", -1.0);
+}
+
+TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdNoiseThreshold)
+(JNIEnv* env, jobject thizz, jlong ptr) {
+ if (!ptr) {
+ return -1.0;
+ }
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ return GetNoiseThreshold(*model);
+}
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetMinTextSizeInBytes)
+(JNIEnv* env, jobject thizz, jlong ptr) {
+ if (!ptr) {
+ return 0;
+ }
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ return model->GetFloatProperty("min_text_size_in_bytes", 0);
+}
diff --git a/native/lang_id/lang-id_jni.h b/native/lang_id/lang-id_jni.h
new file mode 100644
index 0000000..219349c
--- /dev/null
+++ b/native/lang_id/lang-id_jni.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// JNI wrapper for LangId.
+
+#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
+#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
+
+#include <jni.h>
+#include <string>
+#include "utils/java/jni-base.h"
+
+#ifndef TC3_LANG_ID_CLASS_NAME
+#define TC3_LANG_ID_CLASS_NAME LangIdModel
+#endif
+
+#define TC3_LANG_ID_CLASS_NAME_STR TC3_ADD_QUOTES(TC3_LANG_ID_CLASS_NAME)
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
+(JNIEnv* env, jobject clazz, jstring path);
+
+TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
+(JNIEnv* env, jobject clazz, jlong ptr, jstring text);
+
+TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
+(JNIEnv* env, jobject clazz, jlong ptr);
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jlong ptr);
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdThreshold)
+(JNIEnv* env, jobject thizz, jlong ptr);
+
+TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdNoiseThreshold)
+(JNIEnv* env, jobject thizz, jlong ptr);
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetMinTextSizeInBytes)
+(JNIEnv* env, jobject thizz, jlong ptr);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
diff --git a/native/lang_id/light-sentence.h b/native/lang_id/light-sentence.h
new file mode 100644
index 0000000..2aee2ea
--- /dev/null
+++ b/native/lang_id/light-sentence.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LIGHT_SENTENCE_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LIGHT_SENTENCE_H_
+
+#include <string>
+#include <vector>
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Very simplified alternative to heavy sentence.proto, for the purpose of
+// LangId. It turns out that in this case, all we need is a vector of strings,
+// which uses a lot less code size than a Sentence proto.
+using LightSentence = std::vector<std::string>;
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LIGHT_SENTENCE_H_
diff --git a/native/lang_id/model-provider.h b/native/lang_id/model-provider.h
new file mode 100644
index 0000000..bf250ed
--- /dev/null
+++ b/native/lang_id/model-provider.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
+#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/embedding-network-params.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace lang_id {
+
+// Interface for accessing parameters for the LangId model.
+//
+// Note: some clients prefer to include the model parameters in the binary,
+// others prefer loading them from a separate file. This file provides a common
+// interface for these alternative mechanisms.
+class ModelProvider {
+ public:
+ virtual ~ModelProvider() = default;
+
+  // Returns true if this ModelProvider has been successfully constructed (e.g.,
+ // can return false if an underlying model file could not be read). Clients
+ // should not use invalid ModelProviders.
+ bool is_valid() { return valid_; }
+
+ // Returns the TaskContext with parameters for the LangId model. E.g., one
+ // important parameter specifies the features to use.
+ virtual const TaskContext *GetTaskContext() const = 0;
+
+ // Returns parameters for the underlying Neurosis feed-forward neural network.
+ virtual const EmbeddingNetworkParams *GetNnParams() const = 0;
+
+ // Returns list of languages recognized by the model. Each element of the
+ // returned vector should be a BCP-47 language code (e.g., "en", "ro", etc).
+ // Language at index i from the returned vector corresponds to softmax label
+ // i.
+ virtual std::vector<std::string> GetLanguages() const = 0;
+
+ protected:
+ bool valid_ = false;
+};
+
+} // namespace lang_id
+} // namespace mobile
+}  // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_
diff --git a/native/lang_id/script/approx-script-data.cc b/native/lang_id/script/approx-script-data.cc
new file mode 100755
index 0000000..233653f
--- /dev/null
+++ b/native/lang_id/script/approx-script-data.cc
@@ -0,0 +1,1173 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Internal data for approx-script.cc; see approx-script-data.h
+//
+// DO NOT EDIT BY HAND
+//
+// Generated by
+// lang_id/script/update-script-data.sh
+
+#include "lang_id/script/approx-script-data.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+namespace approx_script_internal {
+
+const int kNumRanges = 376;
+
+const uint32 kRangeFirst[] = {
+ 65, // Range #0: [65, 90, Latin]
+ 97, // Range #1: [97, 122, Latin]
+ 170, // Range #2: [170, 170, Latin]
+ 186, // Range #3: [186, 186, Latin]
+ 192, // Range #4: [192, 214, Latin]
+ 216, // Range #5: [216, 246, Latin]
+ 248, // Range #6: [248, 696, Latin]
+ 736, // Range #7: [736, 740, Latin]
+ 746, // Range #8: [746, 747, Bopomofo]
+ 880, // Range #9: [880, 883, Greek]
+ 885, // Range #10: [885, 893, Greek]
+ 895, // Range #11: [895, 900, Greek]
+ 902, // Range #12: [902, 902, Greek]
+ 904, // Range #13: [904, 993, Greek]
+ 994, // Range #14: [994, 1007, Coptic]
+ 1008, // Range #15: [1008, 1023, Greek]
+ 1024, // Range #16: [1024, 1156, Cyrillic]
+ 1159, // Range #17: [1159, 1327, Cyrillic]
+ 1329, // Range #18: [1329, 1423, Armenian]
+ 1425, // Range #19: [1425, 1479, Hebrew]
+ 1488, // Range #20: [1488, 1524, Hebrew]
+ 1536, // Range #21: [1536, 1540, Arabic]
+ 1542, // Range #22: [1542, 1547, Arabic]
+ 1549, // Range #23: [1549, 1562, Arabic]
+ 1564, // Range #24: [1564, 1566, Arabic]
+ 1568, // Range #25: [1568, 1599, Arabic]
+ 1601, // Range #26: [1601, 1610, Arabic]
+ 1622, // Range #27: [1622, 1647, Arabic]
+ 1649, // Range #28: [1649, 1756, Arabic]
+ 1758, // Range #29: [1758, 1791, Arabic]
+ 1792, // Range #30: [1792, 1871, Syriac]
+ 1872, // Range #31: [1872, 1919, Arabic]
+ 1920, // Range #32: [1920, 1969, Thaana]
+ 1984, // Range #33: [1984, 2047, Nko]
+ 2048, // Range #34: [2048, 2110, Samaritan]
+ 2112, // Range #35: [2112, 2142, Mandaic]
+ 2144, // Range #36: [2144, 2154, Syriac]
+ 2208, // Range #37: [2208, 2247, Arabic]
+ 2259, // Range #38: [2259, 2273, Arabic]
+ 2275, // Range #39: [2275, 2303, Arabic]
+ 2304, // Range #40: [2304, 2384, Devanagari]
+ 2389, // Range #41: [2389, 2403, Devanagari]
+ 2406, // Range #42: [2406, 2431, Devanagari]
+ 2432, // Range #43: [2432, 2510, Bengali]
+ 2519, // Range #44: [2519, 2558, Bengali]
+ 2561, // Range #45: [2561, 2641, Gurmukhi]
+ 2649, // Range #46: [2649, 2654, Gurmukhi]
+ 2662, // Range #47: [2662, 2678, Gurmukhi]
+ 2689, // Range #48: [2689, 2768, Gujarati]
+ 2784, // Range #49: [2784, 2801, Gujarati]
+ 2809, // Range #50: [2809, 2815, Gujarati]
+ 2817, // Range #51: [2817, 2893, Oriya]
+ 2901, // Range #52: [2901, 2935, Oriya]
+ 2946, // Range #53: [2946, 3024, Tamil]
+ 3031, // Range #54: [3031, 3031, Tamil]
+ 3046, // Range #55: [3046, 3066, Tamil]
+ 3072, // Range #56: [3072, 3149, Telugu]
+ 3157, // Range #57: [3157, 3162, Telugu]
+ 3168, // Range #58: [3168, 3183, Telugu]
+ 3191, // Range #59: [3191, 3199, Telugu]
+ 3200, // Range #60: [3200, 3277, Kannada]
+ 3285, // Range #61: [3285, 3286, Kannada]
+ 3294, // Range #62: [3294, 3314, Kannada]
+ 3328, // Range #63: [3328, 3455, Malayalam]
+ 3457, // Range #64: [3457, 3551, Sinhala]
+ 3558, // Range #65: [3558, 3572, Sinhala]
+ 3585, // Range #66: [3585, 3642, Thai]
+ 3648, // Range #67: [3648, 3675, Thai]
+ 3713, // Range #68: [3713, 3807, Lao]
+ 3840, // Range #69: [3840, 4052, Tibetan]
+ 4057, // Range #70: [4057, 4058, Tibetan]
+ 4096, // Range #71: [4096, 4255, Myanmar]
+ 4256, // Range #72: [4256, 4295, Georgian]
+ 4301, // Range #73: [4301, 4346, Georgian]
+ 4348, // Range #74: [4348, 4351, Georgian]
+ 4352, // Range #75: [4352, 4607, Hangul]
+ 4608, // Range #76: [4608, 5017, Ethiopic]
+ 5024, // Range #77: [5024, 5117, Cherokee]
+ 5120, // Range #78: [5120, 5759, Canadian_Aboriginal]
+ 5760, // Range #79: [5760, 5788, Ogham]
+ 5792, // Range #80: [5792, 5866, Runic]
+ 5870, // Range #81: [5870, 5880, Runic]
+ 5888, // Range #82: [5888, 5908, Tagalog]
+ 5920, // Range #83: [5920, 5940, Hanunoo]
+ 5952, // Range #84: [5952, 5971, Buhid]
+ 5984, // Range #85: [5984, 6003, Tagbanwa]
+ 6016, // Range #86: [6016, 6121, Khmer]
+ 6128, // Range #87: [6128, 6137, Khmer]
+ 6144, // Range #88: [6144, 6145, Mongolian]
+ 6148, // Range #89: [6148, 6148, Mongolian]
+ 6150, // Range #90: [6150, 6169, Mongolian]
+ 6176, // Range #91: [6176, 6264, Mongolian]
+ 6272, // Range #92: [6272, 6314, Mongolian]
+ 6320, // Range #93: [6320, 6389, Canadian_Aboriginal]
+ 6400, // Range #94: [6400, 6479, Limbu]
+ 6480, // Range #95: [6480, 6516, Tai_Le]
+ 6528, // Range #96: [6528, 6601, New_Tai_Lue]
+ 6608, // Range #97: [6608, 6623, New_Tai_Lue]
+ 6624, // Range #98: [6624, 6655, Khmer]
+ 6656, // Range #99: [6656, 6687, Buginese]
+ 6688, // Range #100: [6688, 6793, Tai_Tham]
+ 6800, // Range #101: [6800, 6809, Tai_Tham]
+ 6816, // Range #102: [6816, 6829, Tai_Tham]
+ 6912, // Range #103: [6912, 7036, Balinese]
+ 7040, // Range #104: [7040, 7103, Sundanese]
+ 7104, // Range #105: [7104, 7155, Batak]
+ 7164, // Range #106: [7164, 7167, Batak]
+ 7168, // Range #107: [7168, 7247, Lepcha]
+ 7248, // Range #108: [7248, 7295, Ol_Chiki]
+ 7296, // Range #109: [7296, 7304, Cyrillic]
+ 7312, // Range #110: [7312, 7359, Georgian]
+ 7360, // Range #111: [7360, 7367, Sundanese]
+ 7424, // Range #112: [7424, 7461, Latin]
+ 7462, // Range #113: [7462, 7466, Greek]
+ 7467, // Range #114: [7467, 7467, Cyrillic]
+ 7468, // Range #115: [7468, 7516, Latin]
+ 7517, // Range #116: [7517, 7521, Greek]
+ 7522, // Range #117: [7522, 7525, Latin]
+ 7526, // Range #118: [7526, 7530, Greek]
+ 7531, // Range #119: [7531, 7543, Latin]
+ 7544, // Range #120: [7544, 7544, Cyrillic]
+ 7545, // Range #121: [7545, 7614, Latin]
+ 7615, // Range #122: [7615, 7615, Greek]
+ 7680, // Range #123: [7680, 7935, Latin]
+ 7936, // Range #124: [7936, 8190, Greek]
+ 8305, // Range #125: [8305, 8305, Latin]
+ 8319, // Range #126: [8319, 8319, Latin]
+ 8336, // Range #127: [8336, 8348, Latin]
+ 8486, // Range #128: [8486, 8486, Greek]
+ 8490, // Range #129: [8490, 8491, Latin]
+ 8498, // Range #130: [8498, 8498, Latin]
+ 8526, // Range #131: [8526, 8526, Latin]
+ 8544, // Range #132: [8544, 8584, Latin]
+ 10240, // Range #133: [10240, 10495, Braille]
+ 11264, // Range #134: [11264, 11358, Glagolitic]
+ 11360, // Range #135: [11360, 11391, Latin]
+ 11392, // Range #136: [11392, 11507, Coptic]
+ 11513, // Range #137: [11513, 11519, Coptic]
+ 11520, // Range #138: [11520, 11559, Georgian]
+ 11565, // Range #139: [11565, 11565, Georgian]
+ 11568, // Range #140: [11568, 11623, Tifinagh]
+ 11631, // Range #141: [11631, 11632, Tifinagh]
+ 11647, // Range #142: [11647, 11647, Tifinagh]
+ 11648, // Range #143: [11648, 11670, Ethiopic]
+ 11680, // Range #144: [11680, 11742, Ethiopic]
+ 11744, // Range #145: [11744, 11775, Cyrillic]
+ 11904, // Range #146: [11904, 12019, Han]
+ 12032, // Range #147: [12032, 12245, Han]
+ 12293, // Range #148: [12293, 12293, Han]
+ 12295, // Range #149: [12295, 12295, Han]
+ 12321, // Range #150: [12321, 12329, Han]
+ 12334, // Range #151: [12334, 12335, Hangul]
+ 12344, // Range #152: [12344, 12347, Han]
+ 12353, // Range #153: [12353, 12438, Hiragana]
+ 12445, // Range #154: [12445, 12447, Hiragana]
+ 12449, // Range #155: [12449, 12538, Katakana]
+ 12541, // Range #156: [12541, 12543, Katakana]
+ 12549, // Range #157: [12549, 12591, Bopomofo]
+ 12593, // Range #158: [12593, 12686, Hangul]
+ 12704, // Range #159: [12704, 12735, Bopomofo]
+ 12784, // Range #160: [12784, 12799, Katakana]
+ 12800, // Range #161: [12800, 12830, Hangul]
+ 12896, // Range #162: [12896, 12926, Hangul]
+ 13008, // Range #163: [13008, 13054, Katakana]
+ 13056, // Range #164: [13056, 13143, Katakana]
+ 13312, // Range #165: [13312, 19903, Han]
+ 19968, // Range #166: [19968, 40956, Han]
+ 40960, // Range #167: [40960, 42182, Yi]
+ 42192, // Range #168: [42192, 42239, Lisu]
+ 42240, // Range #169: [42240, 42539, Vai]
+ 42560, // Range #170: [42560, 42655, Cyrillic]
+ 42656, // Range #171: [42656, 42743, Bamum]
+ 42786, // Range #172: [42786, 42887, Latin]
+ 42891, // Range #173: [42891, 42954, Latin]
+ 42997, // Range #174: [42997, 43007, Latin]
+ 43008, // Range #175: [43008, 43052, Syloti_Nagri]
+ 43072, // Range #176: [43072, 43127, Phags_Pa]
+ 43136, // Range #177: [43136, 43205, Saurashtra]
+ 43214, // Range #178: [43214, 43225, Saurashtra]
+ 43232, // Range #179: [43232, 43263, Devanagari]
+ 43264, // Range #180: [43264, 43309, Kayah_Li]
+ 43311, // Range #181: [43311, 43311, Kayah_Li]
+ 43312, // Range #182: [43312, 43347, Rejang]
+ 43359, // Range #183: [43359, 43359, Rejang]
+ 43360, // Range #184: [43360, 43388, Hangul]
+ 43392, // Range #185: [43392, 43469, Javanese]
+ 43472, // Range #186: [43472, 43487, Javanese]
+ 43488, // Range #187: [43488, 43518, Myanmar]
+ 43520, // Range #188: [43520, 43574, Cham]
+ 43584, // Range #189: [43584, 43615, Cham]
+ 43616, // Range #190: [43616, 43647, Myanmar]
+ 43648, // Range #191: [43648, 43714, Tai_Viet]
+ 43739, // Range #192: [43739, 43743, Tai_Viet]
+ 43744, // Range #193: [43744, 43766, Meetei_Mayek]
+ 43777, // Range #194: [43777, 43798, Ethiopic]
+ 43808, // Range #195: [43808, 43822, Ethiopic]
+ 43824, // Range #196: [43824, 43866, Latin]
+ 43868, // Range #197: [43868, 43876, Latin]
+ 43877, // Range #198: [43877, 43877, Greek]
+ 43878, // Range #199: [43878, 43881, Latin]
+ 43888, // Range #200: [43888, 43967, Cherokee]
+ 43968, // Range #201: [43968, 44025, Meetei_Mayek]
+ 44032, // Range #202: [44032, 55203, Hangul]
+ 55216, // Range #203: [55216, 55291, Hangul]
+ 63744, // Range #204: [63744, 64217, Han]
+ 64256, // Range #205: [64256, 64262, Latin]
+ 64275, // Range #206: [64275, 64279, Armenian]
+ 64285, // Range #207: [64285, 64335, Hebrew]
+ 64336, // Range #208: [64336, 64449, Arabic]
+ 64467, // Range #209: [64467, 64829, Arabic]
+ 64848, // Range #210: [64848, 64967, Arabic]
+ 65008, // Range #211: [65008, 65021, Arabic]
+ 65070, // Range #212: [65070, 65071, Cyrillic]
+ 65136, // Range #213: [65136, 65276, Arabic]
+ 65313, // Range #214: [65313, 65338, Latin]
+ 65345, // Range #215: [65345, 65370, Latin]
+ 65382, // Range #216: [65382, 65391, Katakana]
+ 65393, // Range #217: [65393, 65437, Katakana]
+ 65440, // Range #218: [65440, 65500, Hangul]
+ 65536, // Range #219: [65536, 65629, Linear_B]
+ 65664, // Range #220: [65664, 65786, Linear_B]
+ 65856, // Range #221: [65856, 65934, Greek]
+ 65952, // Range #222: [65952, 65952, Greek]
+ 66176, // Range #223: [66176, 66204, Lycian]
+ 66208, // Range #224: [66208, 66256, Carian]
+ 66304, // Range #225: [66304, 66339, Old_Italic]
+ 66349, // Range #226: [66349, 66351, Old_Italic]
+ 66352, // Range #227: [66352, 66378, Gothic]
+ 66384, // Range #228: [66384, 66426, Old_Permic]
+ 66432, // Range #229: [66432, 66463, Ugaritic]
+ 66464, // Range #230: [66464, 66517, Old_Persian]
+ 66560, // Range #231: [66560, 66639, Deseret]
+ 66640, // Range #232: [66640, 66687, Shavian]
+ 66688, // Range #233: [66688, 66729, Osmanya]
+ 66736, // Range #234: [66736, 66811, Osage]
+ 66816, // Range #235: [66816, 66855, Elbasan]
+ 66864, // Range #236: [66864, 66915, Caucasian_Albanian]
+ 66927, // Range #237: [66927, 66927, Caucasian_Albanian]
+ 67072, // Range #238: [67072, 67382, Linear_A]
+ 67392, // Range #239: [67392, 67413, Linear_A]
+ 67424, // Range #240: [67424, 67431, Linear_A]
+ 67584, // Range #241: [67584, 67647, Cypriot]
+ 67648, // Range #242: [67648, 67679, Imperial_Aramaic]
+ 67680, // Range #243: [67680, 67711, Palmyrene]
+ 67712, // Range #244: [67712, 67742, Nabataean]
+ 67751, // Range #245: [67751, 67759, Nabataean]
+ 67808, // Range #246: [67808, 67829, Hatran]
+ 67835, // Range #247: [67835, 67839, Hatran]
+ 67840, // Range #248: [67840, 67871, Phoenician]
+ 67872, // Range #249: [67872, 67897, Lydian]
+ 67903, // Range #250: [67903, 67903, Lydian]
+ 67968, // Range #251: [67968, 67999, Meroitic_Hieroglyphs]
+ 68000, // Range #252: [68000, 68095, Meroitic_Cursive]
+ 68096, // Range #253: [68096, 68102, Kharoshthi]
+ 68108, // Range #254: [68108, 68168, Kharoshthi]
+ 68176, // Range #255: [68176, 68184, Kharoshthi]
+ 68192, // Range #256: [68192, 68223, Old_South_Arabian]
+ 68224, // Range #257: [68224, 68255, Old_North_Arabian]
+ 68288, // Range #258: [68288, 68342, Manichaean]
+ 68352, // Range #259: [68352, 68415, Avestan]
+ 68416, // Range #260: [68416, 68447, Inscriptional_Parthian]
+ 68448, // Range #261: [68448, 68466, Inscriptional_Pahlavi]
+ 68472, // Range #262: [68472, 68479, Inscriptional_Pahlavi]
+ 68480, // Range #263: [68480, 68497, Psalter_Pahlavi]
+ 68505, // Range #264: [68505, 68508, Psalter_Pahlavi]
+ 68521, // Range #265: [68521, 68527, Psalter_Pahlavi]
+ 68608, // Range #266: [68608, 68680, Old_Turkic]
+ 68736, // Range #267: [68736, 68786, Old_Hungarian]
+ 68800, // Range #268: [68800, 68850, Old_Hungarian]
+ 68858, // Range #269: [68858, 68863, Old_Hungarian]
+ 68864, // Range #270: [68864, 68903, Hanifi_Rohingya]
+ 68912, // Range #271: [68912, 68921, Hanifi_Rohingya]
+ 69216, // Range #272: [69216, 69246, Arabic]
+ 69248, // Range #273: [69248, 69297, Yezidi]
+ 69376, // Range #274: [69376, 69415, Old_Sogdian]
+ 69424, // Range #275: [69424, 69465, Sogdian]
+ 69552, // Range #276: [69552, 69579, Chorasmian]
+ 69600, // Range #277: [69600, 69622, Elymaic]
+ 69632, // Range #278: [69632, 69743, Brahmi]
+ 69759, // Range #279: [69759, 69759, Brahmi]
+ 69760, // Range #280: [69760, 69825, Kaithi]
+ 69837, // Range #281: [69837, 69837, Kaithi]
+ 69840, // Range #282: [69840, 69864, Sora_Sompeng]
+ 69872, // Range #283: [69872, 69881, Sora_Sompeng]
+ 69888, // Range #284: [69888, 69959, Chakma]
+ 69968, // Range #285: [69968, 70006, Mahajani]
+ 70016, // Range #286: [70016, 70111, Sharada]
+ 70113, // Range #287: [70113, 70132, Sinhala]
+ 70144, // Range #288: [70144, 70206, Khojki]
+ 70272, // Range #289: [70272, 70313, Multani]
+ 70320, // Range #290: [70320, 70378, Khudawadi]
+ 70384, // Range #291: [70384, 70393, Khudawadi]
+ 70400, // Range #292: [70400, 70457, Grantha]
+ 70460, // Range #293: [70460, 70480, Grantha]
+ 70487, // Range #294: [70487, 70487, Grantha]
+ 70493, // Range #295: [70493, 70516, Grantha]
+ 70656, // Range #296: [70656, 70753, Newa]
+ 70784, // Range #297: [70784, 70855, Tirhuta]
+ 70864, // Range #298: [70864, 70873, Tirhuta]
+ 71040, // Range #299: [71040, 71133, Siddham]
+ 71168, // Range #300: [71168, 71236, Modi]
+ 71248, // Range #301: [71248, 71257, Modi]
+ 71264, // Range #302: [71264, 71276, Mongolian]
+ 71296, // Range #303: [71296, 71352, Takri]
+ 71360, // Range #304: [71360, 71369, Takri]
+ 71424, // Range #305: [71424, 71487, Ahom]
+ 71680, // Range #306: [71680, 71739, Dogra]
+ 71840, // Range #307: [71840, 71922, Warang_Citi]
+ 71935, // Range #308: [71935, 71935, Warang_Citi]
+ 71936, // Range #309: [71936, 72006, Dives_Akuru]
+ 72016, // Range #310: [72016, 72025, Dives_Akuru]
+ 72096, // Range #311: [72096, 72164, Nandinagari]
+ 72192, // Range #312: [72192, 72263, Zanabazar_Square]
+ 72272, // Range #313: [72272, 72354, Soyombo]
+ 72384, // Range #314: [72384, 72440, Pau_Cin_Hau]
+ 72704, // Range #315: [72704, 72773, Bhaiksuki]
+ 72784, // Range #316: [72784, 72812, Bhaiksuki]
+ 72816, // Range #317: [72816, 72886, Marchen]
+ 72960, // Range #318: [72960, 73031, Masaram_Gondi]
+ 73040, // Range #319: [73040, 73049, Masaram_Gondi]
+ 73056, // Range #320: [73056, 73112, Gunjala_Gondi]
+ 73120, // Range #321: [73120, 73129, Gunjala_Gondi]
+ 73440, // Range #322: [73440, 73464, Makasar]
+ 73648, // Range #323: [73648, 73648, Lisu]
+ 73664, // Range #324: [73664, 73713, Tamil]
+ 73727, // Range #325: [73727, 73727, Tamil]
+ 73728, // Range #326: [73728, 74649, Cuneiform]
+ 74752, // Range #327: [74752, 74868, Cuneiform]
+ 74880, // Range #328: [74880, 75075, Cuneiform]
+ 77824, // Range #329: [77824, 78904, Egyptian_Hieroglyphs]
+ 82944, // Range #330: [82944, 83526, Anatolian_Hieroglyphs]
+ 92160, // Range #331: [92160, 92728, Bamum]
+ 92736, // Range #332: [92736, 92783, Mro]
+ 92880, // Range #333: [92880, 92917, Bassa_Vah]
+ 92928, // Range #334: [92928, 92997, Pahawh_Hmong]
+ 93008, // Range #335: [93008, 93047, Pahawh_Hmong]
+ 93053, // Range #336: [93053, 93071, Pahawh_Hmong]
+ 93760, // Range #337: [93760, 93850, Medefaidrin]
+ 93952, // Range #338: [93952, 94087, Miao]
+ 94095, // Range #339: [94095, 94111, Miao]
+ 94176, // Range #340: [94176, 94176, Tangut]
+ 94177, // Range #341: [94177, 94177, Nushu]
+ 94180, // Range #342: [94180, 94180, Khitan_Small_Script]
+ 94192, // Range #343: [94192, 94193, Han]
+ 94208, // Range #344: [94208, 100343, Tangut]
+ 100352, // Range #345: [100352, 101119, Tangut]
+ 101120, // Range #346: [101120, 101589, Khitan_Small_Script]
+ 101632, // Range #347: [101632, 101640, Tangut]
+ 110592, // Range #348: [110592, 110592, Katakana]
+ 110593, // Range #349: [110593, 110878, Hiragana]
+ 110928, // Range #350: [110928, 110930, Hiragana]
+ 110948, // Range #351: [110948, 110951, Katakana]
+ 110960, // Range #352: [110960, 111355, Nushu]
+ 113664, // Range #353: [113664, 113770, Duployan]
+ 113776, // Range #354: [113776, 113800, Duployan]
+ 113808, // Range #355: [113808, 113823, Duployan]
+ 119296, // Range #356: [119296, 119365, Greek]
+ 120832, // Range #357: [120832, 121483, SignWriting]
+ 121499, // Range #358: [121499, 121519, SignWriting]
+ 122880, // Range #359: [122880, 122922, Glagolitic]
+ 123136, // Range #360: [123136, 123215, Nyiakeng_Puachue_Hmong]
+ 123584, // Range #361: [123584, 123641, Wancho]
+ 123647, // Range #362: [123647, 123647, Wancho]
+ 124928, // Range #363: [124928, 125142, Mende_Kikakui]
+ 125184, // Range #364: [125184, 125279, Adlam]
+ 126464, // Range #365: [126464, 126523, Arabic]
+ 126530, // Range #366: [126530, 126619, Arabic]
+ 126625, // Range #367: [126625, 126651, Arabic]
+ 126704, // Range #368: [126704, 126705, Arabic]
+ 127488, // Range #369: [127488, 127488, Hiragana]
+ 131072, // Range #370: [131072, 173789, Han]
+ 173824, // Range #371: [173824, 177972, Han]
+ 177984, // Range #372: [177984, 183969, Han]
+ 183984, // Range #373: [183984, 191456, Han]
+ 194560, // Range #374: [194560, 195101, Han]
+ 196608, // Range #375: [196608, 201546, Han]
+};
+
+const uint16 kRangeSizeMinusOne[] = {
+ 25, // Range #0: [65, 90, Latin]
+ 25, // Range #1: [97, 122, Latin]
+ 0, // Range #2: [170, 170, Latin]
+ 0, // Range #3: [186, 186, Latin]
+ 22, // Range #4: [192, 214, Latin]
+ 30, // Range #5: [216, 246, Latin]
+ 448, // Range #6: [248, 696, Latin]
+ 4, // Range #7: [736, 740, Latin]
+ 1, // Range #8: [746, 747, Bopomofo]
+ 3, // Range #9: [880, 883, Greek]
+ 8, // Range #10: [885, 893, Greek]
+ 5, // Range #11: [895, 900, Greek]
+ 0, // Range #12: [902, 902, Greek]
+ 89, // Range #13: [904, 993, Greek]
+ 13, // Range #14: [994, 1007, Coptic]
+ 15, // Range #15: [1008, 1023, Greek]
+ 132, // Range #16: [1024, 1156, Cyrillic]
+ 168, // Range #17: [1159, 1327, Cyrillic]
+ 94, // Range #18: [1329, 1423, Armenian]
+ 54, // Range #19: [1425, 1479, Hebrew]
+ 36, // Range #20: [1488, 1524, Hebrew]
+ 4, // Range #21: [1536, 1540, Arabic]
+ 5, // Range #22: [1542, 1547, Arabic]
+ 13, // Range #23: [1549, 1562, Arabic]
+ 2, // Range #24: [1564, 1566, Arabic]
+ 31, // Range #25: [1568, 1599, Arabic]
+ 9, // Range #26: [1601, 1610, Arabic]
+ 25, // Range #27: [1622, 1647, Arabic]
+ 107, // Range #28: [1649, 1756, Arabic]
+ 33, // Range #29: [1758, 1791, Arabic]
+ 79, // Range #30: [1792, 1871, Syriac]
+ 47, // Range #31: [1872, 1919, Arabic]
+ 49, // Range #32: [1920, 1969, Thaana]
+ 63, // Range #33: [1984, 2047, Nko]
+ 62, // Range #34: [2048, 2110, Samaritan]
+ 30, // Range #35: [2112, 2142, Mandaic]
+ 10, // Range #36: [2144, 2154, Syriac]
+ 39, // Range #37: [2208, 2247, Arabic]
+ 14, // Range #38: [2259, 2273, Arabic]
+ 28, // Range #39: [2275, 2303, Arabic]
+ 80, // Range #40: [2304, 2384, Devanagari]
+ 14, // Range #41: [2389, 2403, Devanagari]
+ 25, // Range #42: [2406, 2431, Devanagari]
+ 78, // Range #43: [2432, 2510, Bengali]
+ 39, // Range #44: [2519, 2558, Bengali]
+ 80, // Range #45: [2561, 2641, Gurmukhi]
+ 5, // Range #46: [2649, 2654, Gurmukhi]
+ 16, // Range #47: [2662, 2678, Gurmukhi]
+ 79, // Range #48: [2689, 2768, Gujarati]
+ 17, // Range #49: [2784, 2801, Gujarati]
+ 6, // Range #50: [2809, 2815, Gujarati]
+ 76, // Range #51: [2817, 2893, Oriya]
+ 34, // Range #52: [2901, 2935, Oriya]
+ 78, // Range #53: [2946, 3024, Tamil]
+ 0, // Range #54: [3031, 3031, Tamil]
+ 20, // Range #55: [3046, 3066, Tamil]
+ 77, // Range #56: [3072, 3149, Telugu]
+ 5, // Range #57: [3157, 3162, Telugu]
+ 15, // Range #58: [3168, 3183, Telugu]
+ 8, // Range #59: [3191, 3199, Telugu]
+ 77, // Range #60: [3200, 3277, Kannada]
+ 1, // Range #61: [3285, 3286, Kannada]
+ 20, // Range #62: [3294, 3314, Kannada]
+ 127, // Range #63: [3328, 3455, Malayalam]
+ 94, // Range #64: [3457, 3551, Sinhala]
+ 14, // Range #65: [3558, 3572, Sinhala]
+ 57, // Range #66: [3585, 3642, Thai]
+ 27, // Range #67: [3648, 3675, Thai]
+ 94, // Range #68: [3713, 3807, Lao]
+ 212, // Range #69: [3840, 4052, Tibetan]
+ 1, // Range #70: [4057, 4058, Tibetan]
+ 159, // Range #71: [4096, 4255, Myanmar]
+ 39, // Range #72: [4256, 4295, Georgian]
+ 45, // Range #73: [4301, 4346, Georgian]
+ 3, // Range #74: [4348, 4351, Georgian]
+ 255, // Range #75: [4352, 4607, Hangul]
+ 409, // Range #76: [4608, 5017, Ethiopic]
+ 93, // Range #77: [5024, 5117, Cherokee]
+ 639, // Range #78: [5120, 5759, Canadian_Aboriginal]
+ 28, // Range #79: [5760, 5788, Ogham]
+ 74, // Range #80: [5792, 5866, Runic]
+ 10, // Range #81: [5870, 5880, Runic]
+ 20, // Range #82: [5888, 5908, Tagalog]
+ 20, // Range #83: [5920, 5940, Hanunoo]
+ 19, // Range #84: [5952, 5971, Buhid]
+ 19, // Range #85: [5984, 6003, Tagbanwa]
+ 105, // Range #86: [6016, 6121, Khmer]
+ 9, // Range #87: [6128, 6137, Khmer]
+ 1, // Range #88: [6144, 6145, Mongolian]
+ 0, // Range #89: [6148, 6148, Mongolian]
+ 19, // Range #90: [6150, 6169, Mongolian]
+ 88, // Range #91: [6176, 6264, Mongolian]
+ 42, // Range #92: [6272, 6314, Mongolian]
+ 69, // Range #93: [6320, 6389, Canadian_Aboriginal]
+ 79, // Range #94: [6400, 6479, Limbu]
+ 36, // Range #95: [6480, 6516, Tai_Le]
+ 73, // Range #96: [6528, 6601, New_Tai_Lue]
+ 15, // Range #97: [6608, 6623, New_Tai_Lue]
+ 31, // Range #98: [6624, 6655, Khmer]
+ 31, // Range #99: [6656, 6687, Buginese]
+ 105, // Range #100: [6688, 6793, Tai_Tham]
+ 9, // Range #101: [6800, 6809, Tai_Tham]
+ 13, // Range #102: [6816, 6829, Tai_Tham]
+ 124, // Range #103: [6912, 7036, Balinese]
+ 63, // Range #104: [7040, 7103, Sundanese]
+ 51, // Range #105: [7104, 7155, Batak]
+ 3, // Range #106: [7164, 7167, Batak]
+ 79, // Range #107: [7168, 7247, Lepcha]
+ 47, // Range #108: [7248, 7295, Ol_Chiki]
+ 8, // Range #109: [7296, 7304, Cyrillic]
+ 47, // Range #110: [7312, 7359, Georgian]
+ 7, // Range #111: [7360, 7367, Sundanese]
+ 37, // Range #112: [7424, 7461, Latin]
+ 4, // Range #113: [7462, 7466, Greek]
+ 0, // Range #114: [7467, 7467, Cyrillic]
+ 48, // Range #115: [7468, 7516, Latin]
+ 4, // Range #116: [7517, 7521, Greek]
+ 3, // Range #117: [7522, 7525, Latin]
+ 4, // Range #118: [7526, 7530, Greek]
+ 12, // Range #119: [7531, 7543, Latin]
+ 0, // Range #120: [7544, 7544, Cyrillic]
+ 69, // Range #121: [7545, 7614, Latin]
+ 0, // Range #122: [7615, 7615, Greek]
+ 255, // Range #123: [7680, 7935, Latin]
+ 254, // Range #124: [7936, 8190, Greek]
+ 0, // Range #125: [8305, 8305, Latin]
+ 0, // Range #126: [8319, 8319, Latin]
+ 12, // Range #127: [8336, 8348, Latin]
+ 0, // Range #128: [8486, 8486, Greek]
+ 1, // Range #129: [8490, 8491, Latin]
+ 0, // Range #130: [8498, 8498, Latin]
+ 0, // Range #131: [8526, 8526, Latin]
+ 40, // Range #132: [8544, 8584, Latin]
+ 255, // Range #133: [10240, 10495, Braille]
+ 94, // Range #134: [11264, 11358, Glagolitic]
+ 31, // Range #135: [11360, 11391, Latin]
+ 115, // Range #136: [11392, 11507, Coptic]
+ 6, // Range #137: [11513, 11519, Coptic]
+ 39, // Range #138: [11520, 11559, Georgian]
+ 0, // Range #139: [11565, 11565, Georgian]
+ 55, // Range #140: [11568, 11623, Tifinagh]
+ 1, // Range #141: [11631, 11632, Tifinagh]
+ 0, // Range #142: [11647, 11647, Tifinagh]
+ 22, // Range #143: [11648, 11670, Ethiopic]
+ 62, // Range #144: [11680, 11742, Ethiopic]
+ 31, // Range #145: [11744, 11775, Cyrillic]
+ 115, // Range #146: [11904, 12019, Han]
+ 213, // Range #147: [12032, 12245, Han]
+ 0, // Range #148: [12293, 12293, Han]
+ 0, // Range #149: [12295, 12295, Han]
+ 8, // Range #150: [12321, 12329, Han]
+ 1, // Range #151: [12334, 12335, Hangul]
+ 3, // Range #152: [12344, 12347, Han]
+ 85, // Range #153: [12353, 12438, Hiragana]
+ 2, // Range #154: [12445, 12447, Hiragana]
+ 89, // Range #155: [12449, 12538, Katakana]
+ 2, // Range #156: [12541, 12543, Katakana]
+ 42, // Range #157: [12549, 12591, Bopomofo]
+ 93, // Range #158: [12593, 12686, Hangul]
+ 31, // Range #159: [12704, 12735, Bopomofo]
+ 15, // Range #160: [12784, 12799, Katakana]
+ 30, // Range #161: [12800, 12830, Hangul]
+ 30, // Range #162: [12896, 12926, Hangul]
+ 46, // Range #163: [13008, 13054, Katakana]
+ 87, // Range #164: [13056, 13143, Katakana]
+ 6591, // Range #165: [13312, 19903, Han]
+ 20988, // Range #166: [19968, 40956, Han]
+ 1222, // Range #167: [40960, 42182, Yi]
+ 47, // Range #168: [42192, 42239, Lisu]
+ 299, // Range #169: [42240, 42539, Vai]
+ 95, // Range #170: [42560, 42655, Cyrillic]
+ 87, // Range #171: [42656, 42743, Bamum]
+ 101, // Range #172: [42786, 42887, Latin]
+ 63, // Range #173: [42891, 42954, Latin]
+ 10, // Range #174: [42997, 43007, Latin]
+ 44, // Range #175: [43008, 43052, Syloti_Nagri]
+ 55, // Range #176: [43072, 43127, Phags_Pa]
+ 69, // Range #177: [43136, 43205, Saurashtra]
+ 11, // Range #178: [43214, 43225, Saurashtra]
+ 31, // Range #179: [43232, 43263, Devanagari]
+ 45, // Range #180: [43264, 43309, Kayah_Li]
+ 0, // Range #181: [43311, 43311, Kayah_Li]
+ 35, // Range #182: [43312, 43347, Rejang]
+ 0, // Range #183: [43359, 43359, Rejang]
+ 28, // Range #184: [43360, 43388, Hangul]
+ 77, // Range #185: [43392, 43469, Javanese]
+ 15, // Range #186: [43472, 43487, Javanese]
+ 30, // Range #187: [43488, 43518, Myanmar]
+ 54, // Range #188: [43520, 43574, Cham]
+ 31, // Range #189: [43584, 43615, Cham]
+ 31, // Range #190: [43616, 43647, Myanmar]
+ 66, // Range #191: [43648, 43714, Tai_Viet]
+ 4, // Range #192: [43739, 43743, Tai_Viet]
+ 22, // Range #193: [43744, 43766, Meetei_Mayek]
+ 21, // Range #194: [43777, 43798, Ethiopic]
+ 14, // Range #195: [43808, 43822, Ethiopic]
+ 42, // Range #196: [43824, 43866, Latin]
+ 8, // Range #197: [43868, 43876, Latin]
+ 0, // Range #198: [43877, 43877, Greek]
+ 3, // Range #199: [43878, 43881, Latin]
+ 79, // Range #200: [43888, 43967, Cherokee]
+ 57, // Range #201: [43968, 44025, Meetei_Mayek]
+ 11171, // Range #202: [44032, 55203, Hangul]
+ 75, // Range #203: [55216, 55291, Hangul]
+ 473, // Range #204: [63744, 64217, Han]
+ 6, // Range #205: [64256, 64262, Latin]
+ 4, // Range #206: [64275, 64279, Armenian]
+ 50, // Range #207: [64285, 64335, Hebrew]
+ 113, // Range #208: [64336, 64449, Arabic]
+ 362, // Range #209: [64467, 64829, Arabic]
+ 119, // Range #210: [64848, 64967, Arabic]
+ 13, // Range #211: [65008, 65021, Arabic]
+ 1, // Range #212: [65070, 65071, Cyrillic]
+ 140, // Range #213: [65136, 65276, Arabic]
+ 25, // Range #214: [65313, 65338, Latin]
+ 25, // Range #215: [65345, 65370, Latin]
+ 9, // Range #216: [65382, 65391, Katakana]
+ 44, // Range #217: [65393, 65437, Katakana]
+ 60, // Range #218: [65440, 65500, Hangul]
+ 93, // Range #219: [65536, 65629, Linear_B]
+ 122, // Range #220: [65664, 65786, Linear_B]
+ 78, // Range #221: [65856, 65934, Greek]
+ 0, // Range #222: [65952, 65952, Greek]
+ 28, // Range #223: [66176, 66204, Lycian]
+ 48, // Range #224: [66208, 66256, Carian]
+ 35, // Range #225: [66304, 66339, Old_Italic]
+ 2, // Range #226: [66349, 66351, Old_Italic]
+ 26, // Range #227: [66352, 66378, Gothic]
+ 42, // Range #228: [66384, 66426, Old_Permic]
+ 31, // Range #229: [66432, 66463, Ugaritic]
+ 53, // Range #230: [66464, 66517, Old_Persian]
+ 79, // Range #231: [66560, 66639, Deseret]
+ 47, // Range #232: [66640, 66687, Shavian]
+ 41, // Range #233: [66688, 66729, Osmanya]
+ 75, // Range #234: [66736, 66811, Osage]
+ 39, // Range #235: [66816, 66855, Elbasan]
+ 51, // Range #236: [66864, 66915, Caucasian_Albanian]
+ 0, // Range #237: [66927, 66927, Caucasian_Albanian]
+ 310, // Range #238: [67072, 67382, Linear_A]
+ 21, // Range #239: [67392, 67413, Linear_A]
+ 7, // Range #240: [67424, 67431, Linear_A]
+ 63, // Range #241: [67584, 67647, Cypriot]
+ 31, // Range #242: [67648, 67679, Imperial_Aramaic]
+ 31, // Range #243: [67680, 67711, Palmyrene]
+ 30, // Range #244: [67712, 67742, Nabataean]
+ 8, // Range #245: [67751, 67759, Nabataean]
+ 21, // Range #246: [67808, 67829, Hatran]
+ 4, // Range #247: [67835, 67839, Hatran]
+ 31, // Range #248: [67840, 67871, Phoenician]
+ 25, // Range #249: [67872, 67897, Lydian]
+ 0, // Range #250: [67903, 67903, Lydian]
+ 31, // Range #251: [67968, 67999, Meroitic_Hieroglyphs]
+ 95, // Range #252: [68000, 68095, Meroitic_Cursive]
+ 6, // Range #253: [68096, 68102, Kharoshthi]
+ 60, // Range #254: [68108, 68168, Kharoshthi]
+ 8, // Range #255: [68176, 68184, Kharoshthi]
+ 31, // Range #256: [68192, 68223, Old_South_Arabian]
+ 31, // Range #257: [68224, 68255, Old_North_Arabian]
+ 54, // Range #258: [68288, 68342, Manichaean]
+ 63, // Range #259: [68352, 68415, Avestan]
+ 31, // Range #260: [68416, 68447, Inscriptional_Parthian]
+ 18, // Range #261: [68448, 68466, Inscriptional_Pahlavi]
+ 7, // Range #262: [68472, 68479, Inscriptional_Pahlavi]
+ 17, // Range #263: [68480, 68497, Psalter_Pahlavi]
+ 3, // Range #264: [68505, 68508, Psalter_Pahlavi]
+ 6, // Range #265: [68521, 68527, Psalter_Pahlavi]
+ 72, // Range #266: [68608, 68680, Old_Turkic]
+ 50, // Range #267: [68736, 68786, Old_Hungarian]
+ 50, // Range #268: [68800, 68850, Old_Hungarian]
+ 5, // Range #269: [68858, 68863, Old_Hungarian]
+ 39, // Range #270: [68864, 68903, Hanifi_Rohingya]
+ 9, // Range #271: [68912, 68921, Hanifi_Rohingya]
+ 30, // Range #272: [69216, 69246, Arabic]
+ 49, // Range #273: [69248, 69297, Yezidi]
+ 39, // Range #274: [69376, 69415, Old_Sogdian]
+ 41, // Range #275: [69424, 69465, Sogdian]
+ 27, // Range #276: [69552, 69579, Chorasmian]
+ 22, // Range #277: [69600, 69622, Elymaic]
+ 111, // Range #278: [69632, 69743, Brahmi]
+ 0, // Range #279: [69759, 69759, Brahmi]
+ 65, // Range #280: [69760, 69825, Kaithi]
+ 0, // Range #281: [69837, 69837, Kaithi]
+ 24, // Range #282: [69840, 69864, Sora_Sompeng]
+ 9, // Range #283: [69872, 69881, Sora_Sompeng]
+ 71, // Range #284: [69888, 69959, Chakma]
+ 38, // Range #285: [69968, 70006, Mahajani]
+ 95, // Range #286: [70016, 70111, Sharada]
+ 19, // Range #287: [70113, 70132, Sinhala]
+ 62, // Range #288: [70144, 70206, Khojki]
+ 41, // Range #289: [70272, 70313, Multani]
+ 58, // Range #290: [70320, 70378, Khudawadi]
+ 9, // Range #291: [70384, 70393, Khudawadi]
+ 57, // Range #292: [70400, 70457, Grantha]
+ 20, // Range #293: [70460, 70480, Grantha]
+ 0, // Range #294: [70487, 70487, Grantha]
+ 23, // Range #295: [70493, 70516, Grantha]
+ 97, // Range #296: [70656, 70753, Newa]
+ 71, // Range #297: [70784, 70855, Tirhuta]
+ 9, // Range #298: [70864, 70873, Tirhuta]
+ 93, // Range #299: [71040, 71133, Siddham]
+ 68, // Range #300: [71168, 71236, Modi]
+ 9, // Range #301: [71248, 71257, Modi]
+ 12, // Range #302: [71264, 71276, Mongolian]
+ 56, // Range #303: [71296, 71352, Takri]
+ 9, // Range #304: [71360, 71369, Takri]
+ 63, // Range #305: [71424, 71487, Ahom]
+ 59, // Range #306: [71680, 71739, Dogra]
+ 82, // Range #307: [71840, 71922, Warang_Citi]
+ 0, // Range #308: [71935, 71935, Warang_Citi]
+ 70, // Range #309: [71936, 72006, Dives_Akuru]
+ 9, // Range #310: [72016, 72025, Dives_Akuru]
+ 68, // Range #311: [72096, 72164, Nandinagari]
+ 71, // Range #312: [72192, 72263, Zanabazar_Square]
+ 82, // Range #313: [72272, 72354, Soyombo]
+ 56, // Range #314: [72384, 72440, Pau_Cin_Hau]
+ 69, // Range #315: [72704, 72773, Bhaiksuki]
+ 28, // Range #316: [72784, 72812, Bhaiksuki]
+ 70, // Range #317: [72816, 72886, Marchen]
+ 71, // Range #318: [72960, 73031, Masaram_Gondi]
+ 9, // Range #319: [73040, 73049, Masaram_Gondi]
+ 56, // Range #320: [73056, 73112, Gunjala_Gondi]
+ 9, // Range #321: [73120, 73129, Gunjala_Gondi]
+ 24, // Range #322: [73440, 73464, Makasar]
+ 0, // Range #323: [73648, 73648, Lisu]
+ 49, // Range #324: [73664, 73713, Tamil]
+ 0, // Range #325: [73727, 73727, Tamil]
+ 921, // Range #326: [73728, 74649, Cuneiform]
+ 116, // Range #327: [74752, 74868, Cuneiform]
+ 195, // Range #328: [74880, 75075, Cuneiform]
+ 1080, // Range #329: [77824, 78904, Egyptian_Hieroglyphs]
+ 582, // Range #330: [82944, 83526, Anatolian_Hieroglyphs]
+ 568, // Range #331: [92160, 92728, Bamum]
+ 47, // Range #332: [92736, 92783, Mro]
+ 37, // Range #333: [92880, 92917, Bassa_Vah]
+ 69, // Range #334: [92928, 92997, Pahawh_Hmong]
+ 39, // Range #335: [93008, 93047, Pahawh_Hmong]
+ 18, // Range #336: [93053, 93071, Pahawh_Hmong]
+ 90, // Range #337: [93760, 93850, Medefaidrin]
+ 135, // Range #338: [93952, 94087, Miao]
+ 16, // Range #339: [94095, 94111, Miao]
+ 0, // Range #340: [94176, 94176, Tangut]
+ 0, // Range #341: [94177, 94177, Nushu]
+ 0, // Range #342: [94180, 94180, Khitan_Small_Script]
+ 1, // Range #343: [94192, 94193, Han]
+ 6135, // Range #344: [94208, 100343, Tangut]
+ 767, // Range #345: [100352, 101119, Tangut]
+ 469, // Range #346: [101120, 101589, Khitan_Small_Script]
+ 8, // Range #347: [101632, 101640, Tangut]
+ 0, // Range #348: [110592, 110592, Katakana]
+ 285, // Range #349: [110593, 110878, Hiragana]
+ 2, // Range #350: [110928, 110930, Hiragana]
+ 3, // Range #351: [110948, 110951, Katakana]
+ 395, // Range #352: [110960, 111355, Nushu]
+ 106, // Range #353: [113664, 113770, Duployan]
+ 24, // Range #354: [113776, 113800, Duployan]
+ 15, // Range #355: [113808, 113823, Duployan]
+ 69, // Range #356: [119296, 119365, Greek]
+ 651, // Range #357: [120832, 121483, SignWriting]
+ 20, // Range #358: [121499, 121519, SignWriting]
+ 42, // Range #359: [122880, 122922, Glagolitic]
+ 79, // Range #360: [123136, 123215, Nyiakeng_Puachue_Hmong]
+ 57, // Range #361: [123584, 123641, Wancho]
+ 0, // Range #362: [123647, 123647, Wancho]
+ 214, // Range #363: [124928, 125142, Mende_Kikakui]
+ 95, // Range #364: [125184, 125279, Adlam]
+ 59, // Range #365: [126464, 126523, Arabic]
+ 89, // Range #366: [126530, 126619, Arabic]
+ 26, // Range #367: [126625, 126651, Arabic]
+ 1, // Range #368: [126704, 126705, Arabic]
+ 0, // Range #369: [127488, 127488, Hiragana]
+ 42717, // Range #370: [131072, 173789, Han]
+ 4148, // Range #371: [173824, 177972, Han]
+ 5985, // Range #372: [177984, 183969, Han]
+ 7472, // Range #373: [183984, 191456, Han]
+ 541, // Range #374: [194560, 195101, Han]
+ 4938, // Range #375: [196608, 201546, Han]
+};
+
+const uint8 kRangeScript[] = {
+ 25, // Range #0: [65, 90, Latin]
+ 25, // Range #1: [97, 122, Latin]
+ 25, // Range #2: [170, 170, Latin]
+ 25, // Range #3: [186, 186, Latin]
+ 25, // Range #4: [192, 214, Latin]
+ 25, // Range #5: [216, 246, Latin]
+ 25, // Range #6: [248, 696, Latin]
+ 25, // Range #7: [736, 740, Latin]
+ 5, // Range #8: [746, 747, Bopomofo]
+ 14, // Range #9: [880, 883, Greek]
+ 14, // Range #10: [885, 893, Greek]
+ 14, // Range #11: [895, 900, Greek]
+ 14, // Range #12: [902, 902, Greek]
+ 14, // Range #13: [904, 993, Greek]
+ 7, // Range #14: [994, 1007, Coptic]
+ 14, // Range #15: [1008, 1023, Greek]
+ 8, // Range #16: [1024, 1156, Cyrillic]
+ 8, // Range #17: [1159, 1327, Cyrillic]
+ 3, // Range #18: [1329, 1423, Armenian]
+ 19, // Range #19: [1425, 1479, Hebrew]
+ 19, // Range #20: [1488, 1524, Hebrew]
+ 2, // Range #21: [1536, 1540, Arabic]
+ 2, // Range #22: [1542, 1547, Arabic]
+ 2, // Range #23: [1549, 1562, Arabic]
+ 2, // Range #24: [1564, 1566, Arabic]
+ 2, // Range #25: [1568, 1599, Arabic]
+ 2, // Range #26: [1601, 1610, Arabic]
+ 2, // Range #27: [1622, 1647, Arabic]
+ 2, // Range #28: [1649, 1756, Arabic]
+ 2, // Range #29: [1758, 1791, Arabic]
+ 34, // Range #30: [1792, 1871, Syriac]
+ 2, // Range #31: [1872, 1919, Arabic]
+ 37, // Range #32: [1920, 1969, Thaana]
+ 87, // Range #33: [1984, 2047, Nko]
+ 126, // Range #34: [2048, 2110, Samaritan]
+ 84, // Range #35: [2112, 2142, Mandaic]
+ 34, // Range #36: [2144, 2154, Syriac]
+ 2, // Range #37: [2208, 2247, Arabic]
+ 2, // Range #38: [2259, 2273, Arabic]
+ 2, // Range #39: [2275, 2303, Arabic]
+ 10, // Range #40: [2304, 2384, Devanagari]
+ 10, // Range #41: [2389, 2403, Devanagari]
+ 10, // Range #42: [2406, 2431, Devanagari]
+ 4, // Range #43: [2432, 2510, Bengali]
+ 4, // Range #44: [2519, 2558, Bengali]
+ 16, // Range #45: [2561, 2641, Gurmukhi]
+ 16, // Range #46: [2649, 2654, Gurmukhi]
+ 16, // Range #47: [2662, 2678, Gurmukhi]
+ 15, // Range #48: [2689, 2768, Gujarati]
+ 15, // Range #49: [2784, 2801, Gujarati]
+ 15, // Range #50: [2809, 2815, Gujarati]
+ 31, // Range #51: [2817, 2893, Oriya]
+ 31, // Range #52: [2901, 2935, Oriya]
+ 35, // Range #53: [2946, 3024, Tamil]
+ 35, // Range #54: [3031, 3031, Tamil]
+ 35, // Range #55: [3046, 3066, Tamil]
+ 36, // Range #56: [3072, 3149, Telugu]
+ 36, // Range #57: [3157, 3162, Telugu]
+ 36, // Range #58: [3168, 3183, Telugu]
+ 36, // Range #59: [3191, 3199, Telugu]
+ 21, // Range #60: [3200, 3277, Kannada]
+ 21, // Range #61: [3285, 3286, Kannada]
+ 21, // Range #62: [3294, 3314, Kannada]
+ 26, // Range #63: [3328, 3455, Malayalam]
+ 33, // Range #64: [3457, 3551, Sinhala]
+ 33, // Range #65: [3558, 3572, Sinhala]
+ 38, // Range #66: [3585, 3642, Thai]
+ 38, // Range #67: [3648, 3675, Thai]
+ 24, // Range #68: [3713, 3807, Lao]
+ 39, // Range #69: [3840, 4052, Tibetan]
+ 39, // Range #70: [4057, 4058, Tibetan]
+ 28, // Range #71: [4096, 4255, Myanmar]
+ 12, // Range #72: [4256, 4295, Georgian]
+ 12, // Range #73: [4301, 4346, Georgian]
+ 12, // Range #74: [4348, 4351, Georgian]
+ 18, // Range #75: [4352, 4607, Hangul]
+ 11, // Range #76: [4608, 5017, Ethiopic]
+ 6, // Range #77: [5024, 5117, Cherokee]
+ 40, // Range #78: [5120, 5759, Canadian_Aboriginal]
+ 29, // Range #79: [5760, 5788, Ogham]
+ 32, // Range #80: [5792, 5866, Runic]
+ 32, // Range #81: [5870, 5880, Runic]
+ 42, // Range #82: [5888, 5908, Tagalog]
+ 43, // Range #83: [5920, 5940, Hanunoo]
+ 44, // Range #84: [5952, 5971, Buhid]
+ 45, // Range #85: [5984, 6003, Tagbanwa]
+ 23, // Range #86: [6016, 6121, Khmer]
+ 23, // Range #87: [6128, 6137, Khmer]
+ 27, // Range #88: [6144, 6145, Mongolian]
+ 27, // Range #89: [6148, 6148, Mongolian]
+ 27, // Range #90: [6150, 6169, Mongolian]
+ 27, // Range #91: [6176, 6264, Mongolian]
+ 27, // Range #92: [6272, 6314, Mongolian]
+ 40, // Range #93: [6320, 6389, Canadian_Aboriginal]
+ 48, // Range #94: [6400, 6479, Limbu]
+ 52, // Range #95: [6480, 6516, Tai_Le]
+ 59, // Range #96: [6528, 6601, New_Tai_Lue]
+ 59, // Range #97: [6608, 6623, New_Tai_Lue]
+ 23, // Range #98: [6624, 6655, Khmer]
+ 55, // Range #99: [6656, 6687, Buginese]
+ 106, // Range #100: [6688, 6793, Tai_Tham]
+ 106, // Range #101: [6800, 6809, Tai_Tham]
+ 106, // Range #102: [6816, 6829, Tai_Tham]
+ 62, // Range #103: [6912, 7036, Balinese]
+ 113, // Range #104: [7040, 7103, Sundanese]
+ 63, // Range #105: [7104, 7155, Batak]
+ 63, // Range #106: [7164, 7167, Batak]
+ 82, // Range #107: [7168, 7247, Lepcha]
+ 109, // Range #108: [7248, 7295, Ol_Chiki]
+ 8, // Range #109: [7296, 7304, Cyrillic]
+ 12, // Range #110: [7312, 7359, Georgian]
+ 113, // Range #111: [7360, 7367, Sundanese]
+ 25, // Range #112: [7424, 7461, Latin]
+ 14, // Range #113: [7462, 7466, Greek]
+ 8, // Range #114: [7467, 7467, Cyrillic]
+ 25, // Range #115: [7468, 7516, Latin]
+ 14, // Range #116: [7517, 7521, Greek]
+ 25, // Range #117: [7522, 7525, Latin]
+ 14, // Range #118: [7526, 7530, Greek]
+ 25, // Range #119: [7531, 7543, Latin]
+ 8, // Range #120: [7544, 7544, Cyrillic]
+ 25, // Range #121: [7545, 7614, Latin]
+ 14, // Range #122: [7615, 7615, Greek]
+ 25, // Range #123: [7680, 7935, Latin]
+ 14, // Range #124: [7936, 8190, Greek]
+ 25, // Range #125: [8305, 8305, Latin]
+ 25, // Range #126: [8319, 8319, Latin]
+ 25, // Range #127: [8336, 8348, Latin]
+ 14, // Range #128: [8486, 8486, Greek]
+ 25, // Range #129: [8490, 8491, Latin]
+ 25, // Range #130: [8498, 8498, Latin]
+ 25, // Range #131: [8526, 8526, Latin]
+ 25, // Range #132: [8544, 8584, Latin]
+ 46, // Range #133: [10240, 10495, Braille]
+ 56, // Range #134: [11264, 11358, Glagolitic]
+ 25, // Range #135: [11360, 11391, Latin]
+ 7, // Range #136: [11392, 11507, Coptic]
+ 7, // Range #137: [11513, 11519, Coptic]
+ 12, // Range #138: [11520, 11559, Georgian]
+ 12, // Range #139: [11565, 11565, Georgian]
+ 60, // Range #140: [11568, 11623, Tifinagh]
+ 60, // Range #141: [11631, 11632, Tifinagh]
+ 60, // Range #142: [11647, 11647, Tifinagh]
+ 11, // Range #143: [11648, 11670, Ethiopic]
+ 11, // Range #144: [11680, 11742, Ethiopic]
+ 8, // Range #145: [11744, 11775, Cyrillic]
+ 17, // Range #146: [11904, 12019, Han]
+ 17, // Range #147: [12032, 12245, Han]
+ 17, // Range #148: [12293, 12293, Han]
+ 17, // Range #149: [12295, 12295, Han]
+ 17, // Range #150: [12321, 12329, Han]
+ 18, // Range #151: [12334, 12335, Hangul]
+ 17, // Range #152: [12344, 12347, Han]
+ 20, // Range #153: [12353, 12438, Hiragana]
+ 20, // Range #154: [12445, 12447, Hiragana]
+ 22, // Range #155: [12449, 12538, Katakana]
+ 22, // Range #156: [12541, 12543, Katakana]
+ 5, // Range #157: [12549, 12591, Bopomofo]
+ 18, // Range #158: [12593, 12686, Hangul]
+ 5, // Range #159: [12704, 12735, Bopomofo]
+ 22, // Range #160: [12784, 12799, Katakana]
+ 18, // Range #161: [12800, 12830, Hangul]
+ 18, // Range #162: [12896, 12926, Hangul]
+ 22, // Range #163: [13008, 13054, Katakana]
+ 22, // Range #164: [13056, 13143, Katakana]
+ 17, // Range #165: [13312, 19903, Han]
+ 17, // Range #166: [19968, 40956, Han]
+ 41, // Range #167: [40960, 42182, Yi]
+ 131, // Range #168: [42192, 42239, Lisu]
+ 99, // Range #169: [42240, 42539, Vai]
+ 8, // Range #170: [42560, 42655, Cyrillic]
+ 130, // Range #171: [42656, 42743, Bamum]
+ 25, // Range #172: [42786, 42887, Latin]
+ 25, // Range #173: [42891, 42954, Latin]
+ 25, // Range #174: [42997, 43007, Latin]
+ 58, // Range #175: [43008, 43052, Syloti_Nagri]
+ 90, // Range #176: [43072, 43127, Phags_Pa]
+ 111, // Range #177: [43136, 43205, Saurashtra]
+ 111, // Range #178: [43214, 43225, Saurashtra]
+ 10, // Range #179: [43232, 43263, Devanagari]
+ 79, // Range #180: [43264, 43309, Kayah_Li]
+ 79, // Range #181: [43311, 43311, Kayah_Li]
+ 110, // Range #182: [43312, 43347, Rejang]
+ 110, // Range #183: [43359, 43359, Rejang]
+ 18, // Range #184: [43360, 43388, Hangul]
+ 78, // Range #185: [43392, 43469, Javanese]
+ 78, // Range #186: [43472, 43487, Javanese]
+ 28, // Range #187: [43488, 43518, Myanmar]
+ 66, // Range #188: [43520, 43574, Cham]
+ 66, // Range #189: [43584, 43615, Cham]
+ 28, // Range #190: [43616, 43647, Myanmar]
+ 127, // Range #191: [43648, 43714, Tai_Viet]
+ 127, // Range #192: [43739, 43743, Tai_Viet]
+ 115, // Range #193: [43744, 43766, Meetei_Mayek]
+ 11, // Range #194: [43777, 43798, Ethiopic]
+ 11, // Range #195: [43808, 43822, Ethiopic]
+ 25, // Range #196: [43824, 43866, Latin]
+ 25, // Range #197: [43868, 43876, Latin]
+ 14, // Range #198: [43877, 43877, Greek]
+ 25, // Range #199: [43878, 43881, Latin]
+ 6, // Range #200: [43888, 43967, Cherokee]
+ 115, // Range #201: [43968, 44025, Meetei_Mayek]
+ 18, // Range #202: [44032, 55203, Hangul]
+ 18, // Range #203: [55216, 55291, Hangul]
+ 17, // Range #204: [63744, 64217, Han]
+ 25, // Range #205: [64256, 64262, Latin]
+ 3, // Range #206: [64275, 64279, Armenian]
+ 19, // Range #207: [64285, 64335, Hebrew]
+ 2, // Range #208: [64336, 64449, Arabic]
+ 2, // Range #209: [64467, 64829, Arabic]
+ 2, // Range #210: [64848, 64967, Arabic]
+ 2, // Range #211: [65008, 65021, Arabic]
+ 8, // Range #212: [65070, 65071, Cyrillic]
+ 2, // Range #213: [65136, 65276, Arabic]
+ 25, // Range #214: [65313, 65338, Latin]
+ 25, // Range #215: [65345, 65370, Latin]
+ 22, // Range #216: [65382, 65391, Katakana]
+ 22, // Range #217: [65393, 65437, Katakana]
+ 18, // Range #218: [65440, 65500, Hangul]
+ 49, // Range #219: [65536, 65629, Linear_B]
+ 49, // Range #220: [65664, 65786, Linear_B]
+ 14, // Range #221: [65856, 65934, Greek]
+ 14, // Range #222: [65952, 65952, Greek]
+ 107, // Range #223: [66176, 66204, Lycian]
+ 104, // Range #224: [66208, 66256, Carian]
+ 30, // Range #225: [66304, 66339, Old_Italic]
+ 30, // Range #226: [66349, 66351, Old_Italic]
+ 13, // Range #227: [66352, 66378, Gothic]
+ 89, // Range #228: [66384, 66426, Old_Permic]
+ 53, // Range #229: [66432, 66463, Ugaritic]
+ 61, // Range #230: [66464, 66517, Old_Persian]
+ 9, // Range #231: [66560, 66639, Deseret]
+ 51, // Range #232: [66640, 66687, Shavian]
+ 50, // Range #233: [66688, 66729, Osmanya]
+ 171, // Range #234: [66736, 66811, Osage]
+ 136, // Range #235: [66816, 66855, Elbasan]
+ 159, // Range #236: [66864, 66915, Caucasian_Albanian]
+ 159, // Range #237: [66927, 66927, Caucasian_Albanian]
+ 83, // Range #238: [67072, 67382, Linear_A]
+ 83, // Range #239: [67392, 67413, Linear_A]
+ 83, // Range #240: [67424, 67431, Linear_A]
+ 47, // Range #241: [67584, 67647, Cypriot]
+ 116, // Range #242: [67648, 67679, Imperial_Aramaic]
+ 144, // Range #243: [67680, 67711, Palmyrene]
+ 143, // Range #244: [67712, 67742, Nabataean]
+ 143, // Range #245: [67751, 67759, Nabataean]
+ 162, // Range #246: [67808, 67829, Hatran]
+ 162, // Range #247: [67835, 67839, Hatran]
+ 91, // Range #248: [67840, 67871, Phoenician]
+ 108, // Range #249: [67872, 67897, Lydian]
+ 108, // Range #250: [67903, 67903, Lydian]
+ 86, // Range #251: [67968, 67999, Meroitic_Hieroglyphs]
+ 141, // Range #252: [68000, 68095, Meroitic_Cursive]
+ 57, // Range #253: [68096, 68102, Kharoshthi]
+ 57, // Range #254: [68108, 68168, Kharoshthi]
+ 57, // Range #255: [68176, 68184, Kharoshthi]
+ 133, // Range #256: [68192, 68223, Old_South_Arabian]
+ 142, // Range #257: [68224, 68255, Old_North_Arabian]
+ 121, // Range #258: [68288, 68342, Manichaean]
+ 117, // Range #259: [68352, 68415, Avestan]
+ 125, // Range #260: [68416, 68447, Inscriptional_Parthian]
+ 122, // Range #261: [68448, 68466, Inscriptional_Pahlavi]
+ 122, // Range #262: [68472, 68479, Inscriptional_Pahlavi]
+ 123, // Range #263: [68480, 68497, Psalter_Pahlavi]
+ 123, // Range #264: [68505, 68508, Psalter_Pahlavi]
+ 123, // Range #265: [68521, 68527, Psalter_Pahlavi]
+ 88, // Range #266: [68608, 68680, Old_Turkic]
+ 76, // Range #267: [68736, 68786, Old_Hungarian]
+ 76, // Range #268: [68800, 68850, Old_Hungarian]
+ 76, // Range #269: [68858, 68863, Old_Hungarian]
+ 182, // Range #270: [68864, 68903, Hanifi_Rohingya]
+ 182, // Range #271: [68912, 68921, Hanifi_Rohingya]
+ 2, // Range #272: [69216, 69246, Arabic]
+ 192, // Range #273: [69248, 69297, Yezidi]
+ 184, // Range #274: [69376, 69415, Old_Sogdian]
+ 183, // Range #275: [69424, 69465, Sogdian]
+ 189, // Range #276: [69552, 69579, Chorasmian]
+ 185, // Range #277: [69600, 69622, Elymaic]
+ 65, // Range #278: [69632, 69743, Brahmi]
+ 65, // Range #279: [69759, 69759, Brahmi]
+ 120, // Range #280: [69760, 69825, Kaithi]
+ 120, // Range #281: [69837, 69837, Kaithi]
+ 152, // Range #282: [69840, 69864, Sora_Sompeng]
+ 152, // Range #283: [69872, 69881, Sora_Sompeng]
+ 118, // Range #284: [69888, 69959, Chakma]
+ 160, // Range #285: [69968, 70006, Mahajani]
+ 151, // Range #286: [70016, 70111, Sharada]
+ 33, // Range #287: [70113, 70132, Sinhala]
+ 157, // Range #288: [70144, 70206, Khojki]
+ 164, // Range #289: [70272, 70313, Multani]
+ 145, // Range #290: [70320, 70378, Khudawadi]
+ 145, // Range #291: [70384, 70393, Khudawadi]
+ 137, // Range #292: [70400, 70457, Grantha]
+ 137, // Range #293: [70460, 70480, Grantha]
+ 137, // Range #294: [70487, 70487, Grantha]
+ 137, // Range #295: [70493, 70516, Grantha]
+ 170, // Range #296: [70656, 70753, Newa]
+ 158, // Range #297: [70784, 70855, Tirhuta]
+ 158, // Range #298: [70864, 70873, Tirhuta]
+ 166, // Range #299: [71040, 71133, Siddham]
+ 163, // Range #300: [71168, 71236, Modi]
+ 163, // Range #301: [71248, 71257, Modi]
+ 27, // Range #302: [71264, 71276, Mongolian]
+ 153, // Range #303: [71296, 71352, Takri]
+ 153, // Range #304: [71360, 71369, Takri]
+ 161, // Range #305: [71424, 71487, Ahom]
+ 178, // Range #306: [71680, 71739, Dogra]
+ 146, // Range #307: [71840, 71922, Warang_Citi]
+ 146, // Range #308: [71935, 71935, Warang_Citi]
+ 190, // Range #309: [71936, 72006, Dives_Akuru]
+ 190, // Range #310: [72016, 72025, Dives_Akuru]
+ 187, // Range #311: [72096, 72164, Nandinagari]
+ 177, // Range #312: [72192, 72263, Zanabazar_Square]
+ 176, // Range #313: [72272, 72354, Soyombo]
+ 165, // Range #314: [72384, 72440, Pau_Cin_Hau]
+ 168, // Range #315: [72704, 72773, Bhaiksuki]
+ 168, // Range #316: [72784, 72812, Bhaiksuki]
+ 169, // Range #317: [72816, 72886, Marchen]
+ 175, // Range #318: [72960, 73031, Masaram_Gondi]
+ 175, // Range #319: [73040, 73049, Masaram_Gondi]
+ 179, // Range #320: [73056, 73112, Gunjala_Gondi]
+ 179, // Range #321: [73120, 73129, Gunjala_Gondi]
+ 180, // Range #322: [73440, 73464, Makasar]
+ 131, // Range #323: [73648, 73648, Lisu]
+ 35, // Range #324: [73664, 73713, Tamil]
+ 35, // Range #325: [73727, 73727, Tamil]
+ 101, // Range #326: [73728, 74649, Cuneiform]
+ 101, // Range #327: [74752, 74868, Cuneiform]
+ 101, // Range #328: [74880, 75075, Cuneiform]
+ 71, // Range #329: [77824, 78904, Egyptian_Hieroglyphs]
+ 156, // Range #330: [82944, 83526, Anatolian_Hieroglyphs]
+ 130, // Range #331: [92160, 92728, Bamum]
+ 149, // Range #332: [92736, 92783, Mro]
+ 134, // Range #333: [92880, 92917, Bassa_Vah]
+ 75, // Range #334: [92928, 92997, Pahawh_Hmong]
+ 75, // Range #335: [93008, 93047, Pahawh_Hmong]
+ 75, // Range #336: [93053, 93071, Pahawh_Hmong]
+ 181, // Range #337: [93760, 93850, Medefaidrin]
+ 92, // Range #338: [93952, 94087, Miao]
+ 92, // Range #339: [94095, 94111, Miao]
+ 154, // Range #340: [94176, 94176, Tangut]
+ 150, // Range #341: [94177, 94177, Nushu]
+ 191, // Range #342: [94180, 94180, Khitan_Small_Script]
+ 17, // Range #343: [94192, 94193, Han]
+ 154, // Range #344: [94208, 100343, Tangut]
+ 154, // Range #345: [100352, 101119, Tangut]
+ 191, // Range #346: [101120, 101589, Khitan_Small_Script]
+ 154, // Range #347: [101632, 101640, Tangut]
+ 22, // Range #348: [110592, 110592, Katakana]
+ 20, // Range #349: [110593, 110878, Hiragana]
+ 20, // Range #350: [110928, 110930, Hiragana]
+ 22, // Range #351: [110948, 110951, Katakana]
+ 150, // Range #352: [110960, 111355, Nushu]
+ 135, // Range #353: [113664, 113770, Duployan]
+ 135, // Range #354: [113776, 113800, Duployan]
+ 135, // Range #355: [113808, 113823, Duployan]
+ 14, // Range #356: [119296, 119365, Greek]
+ 112, // Range #357: [120832, 121483, SignWriting]
+ 112, // Range #358: [121499, 121519, SignWriting]
+ 56, // Range #359: [122880, 122922, Glagolitic]
+ 186, // Range #360: [123136, 123215, Nyiakeng_Puachue_Hmong]
+ 188, // Range #361: [123584, 123641, Wancho]
+ 188, // Range #362: [123647, 123647, Wancho]
+ 140, // Range #363: [124928, 125142, Mende_Kikakui]
+ 167, // Range #364: [125184, 125279, Adlam]
+ 2, // Range #365: [126464, 126523, Arabic]
+ 2, // Range #366: [126530, 126619, Arabic]
+ 2, // Range #367: [126625, 126651, Arabic]
+ 2, // Range #368: [126704, 126705, Arabic]
+ 20, // Range #369: [127488, 127488, Hiragana]
+ 17, // Range #370: [131072, 173789, Han]
+ 17, // Range #371: [173824, 177972, Han]
+ 17, // Range #372: [177984, 183969, Han]
+ 17, // Range #373: [183984, 191456, Han]
+ 17, // Range #374: [194560, 195101, Han]
+ 17, // Range #375: [196608, 201546, Han]
+};
+
+const uint8 kMaxScript = 192;
+
+} // namespace approx_script_internal
+} // namespace mobile
+} // namespace nlp_saft
diff --git a/lang_id/script/approx-script-data.h b/native/lang_id/script/approx-script-data.h
similarity index 100%
rename from lang_id/script/approx-script-data.h
rename to native/lang_id/script/approx-script-data.h
diff --git a/lang_id/script/approx-script.cc b/native/lang_id/script/approx-script.cc
similarity index 100%
rename from lang_id/script/approx-script.cc
rename to native/lang_id/script/approx-script.cc
diff --git a/lang_id/script/approx-script.h b/native/lang_id/script/approx-script.h
similarity index 100%
rename from lang_id/script/approx-script.h
rename to native/lang_id/script/approx-script.h
diff --git a/lang_id/script/script-detector.cc b/native/lang_id/script/script-detector.cc
similarity index 100%
rename from lang_id/script/script-detector.cc
rename to native/lang_id/script/script-detector.cc
diff --git a/lang_id/script/script-detector.h b/native/lang_id/script/script-detector.h
similarity index 100%
rename from lang_id/script/script-detector.h
rename to native/lang_id/script/script-detector.h
diff --git a/lang_id/script/tiny-script-detector.cc b/native/lang_id/script/tiny-script-detector.cc
similarity index 100%
rename from lang_id/script/tiny-script-detector.cc
rename to native/lang_id/script/tiny-script-detector.cc
diff --git a/lang_id/script/tiny-script-detector.h b/native/lang_id/script/tiny-script-detector.h
similarity index 100%
rename from lang_id/script/tiny-script-detector.h
rename to native/lang_id/script/tiny-script-detector.h
diff --git a/native/models/actions_suggestions.en.model b/native/models/actions_suggestions.en.model
new file mode 100755
index 0000000..d4b0ced
--- /dev/null
+++ b/native/models/actions_suggestions.en.model
Binary files differ
diff --git a/native/models/actions_suggestions.universal.model b/native/models/actions_suggestions.universal.model
new file mode 100755
index 0000000..2ee546c
--- /dev/null
+++ b/native/models/actions_suggestions.universal.model
Binary files differ
diff --git a/native/models/lang_id.model b/native/models/lang_id.model
new file mode 100644
index 0000000..e94dada
--- /dev/null
+++ b/native/models/lang_id.model
Binary files differ
diff --git a/native/models/textclassifier.ar.model b/native/models/textclassifier.ar.model
new file mode 100755
index 0000000..dbd685b
--- /dev/null
+++ b/native/models/textclassifier.ar.model
Binary files differ
diff --git a/native/models/textclassifier.en.model b/native/models/textclassifier.en.model
new file mode 100755
index 0000000..c930fe6
--- /dev/null
+++ b/native/models/textclassifier.en.model
Binary files differ
diff --git a/native/models/textclassifier.es.model b/native/models/textclassifier.es.model
new file mode 100755
index 0000000..26e3908
--- /dev/null
+++ b/native/models/textclassifier.es.model
Binary files differ
diff --git a/native/models/textclassifier.fr.model b/native/models/textclassifier.fr.model
new file mode 100755
index 0000000..9746ec9
--- /dev/null
+++ b/native/models/textclassifier.fr.model
Binary files differ
diff --git a/native/models/textclassifier.it.model b/native/models/textclassifier.it.model
new file mode 100755
index 0000000..1ce898c
--- /dev/null
+++ b/native/models/textclassifier.it.model
Binary files differ
diff --git a/native/models/textclassifier.ja.model b/native/models/textclassifier.ja.model
new file mode 100755
index 0000000..bc61400
--- /dev/null
+++ b/native/models/textclassifier.ja.model
Binary files differ
diff --git a/native/models/textclassifier.ko.model b/native/models/textclassifier.ko.model
new file mode 100755
index 0000000..59a9cde
--- /dev/null
+++ b/native/models/textclassifier.ko.model
Binary files differ
diff --git a/native/models/textclassifier.nl.model b/native/models/textclassifier.nl.model
new file mode 100755
index 0000000..aa95ca4
--- /dev/null
+++ b/native/models/textclassifier.nl.model
Binary files differ
diff --git a/native/models/textclassifier.pl.model b/native/models/textclassifier.pl.model
new file mode 100755
index 0000000..10e36e1
--- /dev/null
+++ b/native/models/textclassifier.pl.model
Binary files differ
diff --git a/native/models/textclassifier.pt.model b/native/models/textclassifier.pt.model
new file mode 100755
index 0000000..c76e430
--- /dev/null
+++ b/native/models/textclassifier.pt.model
Binary files differ
diff --git a/native/models/textclassifier.ru.model b/native/models/textclassifier.ru.model
new file mode 100755
index 0000000..b9a3ffd
--- /dev/null
+++ b/native/models/textclassifier.ru.model
Binary files differ
diff --git a/native/models/textclassifier.th.model b/native/models/textclassifier.th.model
new file mode 100755
index 0000000..a67237a
--- /dev/null
+++ b/native/models/textclassifier.th.model
Binary files differ
diff --git a/native/models/textclassifier.tr.model b/native/models/textclassifier.tr.model
new file mode 100755
index 0000000..e3cfd68
--- /dev/null
+++ b/native/models/textclassifier.tr.model
Binary files differ
diff --git a/native/models/textclassifier.universal.model b/native/models/textclassifier.universal.model
new file mode 100755
index 0000000..7f7476c
--- /dev/null
+++ b/native/models/textclassifier.universal.model
Binary files differ
diff --git a/native/models/textclassifier.zh.model b/native/models/textclassifier.zh.model
new file mode 100755
index 0000000..fe11975
--- /dev/null
+++ b/native/models/textclassifier.zh.model
Binary files differ
diff --git a/models/update.sh b/native/models/update.sh
similarity index 100%
rename from models/update.sh
rename to native/models/update.sh
diff --git a/util/hash/hash.cc b/native/util/hash/hash.cc
similarity index 100%
rename from util/hash/hash.cc
rename to native/util/hash/hash.cc
diff --git a/util/hash/hash.h b/native/util/hash/hash.h
similarity index 100%
rename from util/hash/hash.h
rename to native/util/hash/hash.h
diff --git a/native/utils/base/arena.cc b/native/utils/base/arena.cc
new file mode 100644
index 0000000..fcaed8e
--- /dev/null
+++ b/native/utils/base/arena.cc
@@ -0,0 +1,513 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This approach to arenas overcomes many of the limitations described
+// in the "Specialized allocators" section of
+// http://www.pdos.lcs.mit.edu/~dm/c++-new.html
+//
+// A somewhat similar approach to Gladiator, but for heap-detection, was
+// suggested by Ron van der Wal and Scott Meyers at
+// http://www.aristeia.com/BookErrata/M27Comments_frames.html
+
+#include "utils/base/arena.h"
+
+#include "utils/base/logging.h"
+#include "utils/base/macros.h"
+
+namespace libtextclassifier3 {
+
+static void *aligned_malloc(size_t size, int minimum_alignment) {
+ void *ptr = nullptr;
+ // posix_memalign requires that the requested alignment be at least
+ // sizeof(void*). In this case, fall back on malloc which should return memory
+ // aligned to at least the size of a pointer.
+ const int required_alignment = sizeof(void*);
+ if (minimum_alignment < required_alignment)
+ return malloc(size);
+ if (posix_memalign(&ptr, static_cast<size_t>(minimum_alignment), size) != 0)
+ return nullptr;
+ else
+ return ptr;
+}
+
+// The value here doesn't matter until page_aligned_ is supported.
+static const int kPageSize = 8192; // should be getpagesize()
+
+// We used to only keep track of how much space has been allocated in
+// debug mode. Now we track this for optimized builds, as well. If you
+// want to play with the old scheme to see if this helps performance,
+// change this TC3_ARENASET() macro to a NOP. However, NOTE: some
+// applications of arenas depend on this space information (exported
+// via bytes_allocated()).
+#define TC3_ARENASET(x) (x)
+
+namespace {
+
+#ifdef __cpp_aligned_new
+
+char* AllocateBytes(size_t size) {
+ return static_cast<char*>(::operator new(size));
+}
+
+// REQUIRES: alignment > __STDCPP_DEFAULT_NEW_ALIGNMENT__
+//
+// For alignments <=__STDCPP_DEFAULT_NEW_ALIGNMENT__, AllocateBytes() will
+// provide the correct alignment.
+char* AllocateAlignedBytes(size_t size, size_t alignment) {
+ TC3_CHECK_GT(alignment, __STDCPP_DEFAULT_NEW_ALIGNMENT__);
+ return static_cast<char*>(::operator new(size, std::align_val_t(alignment)));
+}
+
+void DeallocateBytes(void* ptr, size_t size, size_t alignment) {
+ if (alignment > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
+#ifdef __cpp_sized_deallocation
+ ::operator delete(ptr, size, std::align_val_t(alignment));
+#else // !__cpp_sized_deallocation
+ ::operator delete(ptr, std::align_val_t(alignment));
+#endif // !__cpp_sized_deallocation
+ } else {
+#ifdef __cpp_sized_deallocation
+ ::operator delete(ptr, size);
+#else // !__cpp_sized_deallocation
+ ::operator delete(ptr);
+#endif // !__cpp_sized_deallocation
+ }
+}
+
+#else // !__cpp_aligned_new
+
+char* AllocateBytes(size_t size) {
+ return static_cast<char*>(malloc(size));
+}
+
+char* AllocateAlignedBytes(size_t size, size_t alignment) {
+ return static_cast<char*>(aligned_malloc(size, alignment));
+}
+
+void DeallocateBytes(void* ptr, size_t size, size_t alignment) {
+ free(ptr);
+}
+
+#endif // !__cpp_aligned_new
+
+} // namespace
+
+const int BaseArena::kDefaultAlignment;
+
+// ----------------------------------------------------------------------
+// BaseArena::BaseArena()
+// BaseArena::~BaseArena()
+// Destroying the arena automatically calls Reset()
+// ----------------------------------------------------------------------
+
+BaseArena::BaseArena(char* first, const size_t orig_block_size,
+ bool align_to_page)
+ : remaining_(0),
+ block_size_(orig_block_size),
+ freestart_(nullptr), // set for real in Reset()
+ last_alloc_(nullptr),
+ overflow_blocks_(nullptr),
+ first_block_externally_owned_(first != nullptr),
+ page_aligned_(align_to_page),
+ blocks_alloced_(1) {
+ // Trivial check that aligned objects can actually be allocated.
+ TC3_CHECK_GT(block_size_, kDefaultAlignment)
+ << "orig_block_size = " << orig_block_size;
+ if (page_aligned_) {
+ // kPageSize must be power of 2, so make sure of this.
+ TC3_CHECK(kPageSize > 0 && 0 == (kPageSize & (kPageSize - 1)))
+ << "kPageSize[ " << kPageSize << "] is not "
+ << "correctly initialized: not a power of 2.";
+ }
+
+ if (first) {
+ TC3_CHECK(!page_aligned_ ||
+ (reinterpret_cast<uintptr_t>(first) & (kPageSize - 1)) == 0);
+ first_blocks_[0].mem = first;
+ first_blocks_[0].size = orig_block_size;
+ } else {
+ if (page_aligned_) {
+ // Make sure the blocksize is page multiple, as we need to end on a page
+ // boundary.
+ TC3_CHECK_EQ(block_size_ & (kPageSize - 1), 0) << "block_size is not a"
+ << "multiple of kPageSize";
+ first_blocks_[0].mem = AllocateAlignedBytes(block_size_, kPageSize);
+ first_blocks_[0].alignment = kPageSize;
+ TC3_CHECK(nullptr != first_blocks_[0].mem);
+ } else {
+ first_blocks_[0].mem = AllocateBytes(block_size_);
+ first_blocks_[0].alignment = 0;
+ }
+ first_blocks_[0].size = block_size_;
+ }
+
+ Reset();
+}
+
+BaseArena::~BaseArena() {
+ FreeBlocks();
+ assert(overflow_blocks_ == nullptr); // FreeBlocks() should do that
+#ifdef ADDRESS_SANITIZER
+ if (first_block_externally_owned_) {
+ ASAN_UNPOISON_MEMORY_REGION(first_blocks_[0].mem, first_blocks_[0].size);
+ }
+#endif
+ // The first X blocks stay allocated always by default. Delete them now.
+ for (int i = first_block_externally_owned_ ? 1 : 0;
+ i < blocks_alloced_; ++i) {
+ DeallocateBytes(first_blocks_[i].mem, first_blocks_[i].size,
+ first_blocks_[i].alignment);
+ }
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::block_count()
+// Only reason this is in .cc file is because it involves STL.
+// ----------------------------------------------------------------------
+
+int BaseArena::block_count() const {
+ return (blocks_alloced_ +
+ (overflow_blocks_ ? static_cast<int>(overflow_blocks_->size()) : 0));
+}
+
+// Returns true iff it advances freestart_ to the first position
+// satisfying alignment without exhausting the current block.
+bool BaseArena::SatisfyAlignment(size_t alignment) {
+ const size_t overage =
+ reinterpret_cast<size_t>(freestart_) & (alignment - 1);
+ if (overage > 0) {
+ const size_t waste = alignment - overage;
+ if (waste >= remaining_) {
+ return false;
+ }
+ freestart_ += waste;
+ remaining_ -= waste;
+ }
+ TC3_DCHECK_EQ(0, reinterpret_cast<size_t>(freestart_) & (alignment - 1));
+ return true;
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::Reset()
+// Clears all the memory an arena is using.
+// ----------------------------------------------------------------------
+
+void BaseArena::Reset() {
+ FreeBlocks();
+ freestart_ = first_blocks_[0].mem;
+ remaining_ = first_blocks_[0].size;
+ last_alloc_ = nullptr;
+#ifdef ADDRESS_SANITIZER
+ ASAN_POISON_MEMORY_REGION(freestart_, remaining_);
+#endif
+
+ TC3_ARENASET(status_.bytes_allocated_ = block_size_);
+
+ // There is no guarantee the first block is properly aligned, so
+ // enforce that now.
+ TC3_CHECK(SatisfyAlignment(kDefaultAlignment));
+
+ freestart_when_empty_ = freestart_;
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::MakeNewBlock()
+// Our sbrk() equivalent. We always make blocks of the same size
+// (though GetMemory() can also make a new block for really big
+// data.)
+// ----------------------------------------------------------------------
+
+void BaseArena::MakeNewBlock(const uint32 alignment) {
+ AllocatedBlock *block = AllocNewBlock(block_size_, alignment);
+ freestart_ = block->mem;
+ remaining_ = block->size;
+ TC3_CHECK(SatisfyAlignment(alignment));
+}
+
+// The following simple numeric routines also exist in util/math/mathutil.h
+// but we don't want to depend on that library.
+
+// Euclid's algorithm for Greatest Common Divisor.
+static uint32 GCD(uint32 x, uint32 y) {
+ while (y != 0) {
+ uint32 r = x % y;
+ x = y;
+ y = r;
+ }
+ return x;
+}
+
+static uint32 LeastCommonMultiple(uint32 a, uint32 b) {
+ if (a > b) {
+ return (a / GCD(a, b)) * b;
+ } else if (a < b) {
+ return (b / GCD(b, a)) * a;
+ } else {
+ return a;
+ }
+}
+
+// -------------------------------------------------------------
+// BaseArena::AllocNewBlock()
+// Adds and returns an AllocatedBlock.
+// The returned AllocatedBlock* is valid until the next call
+// to AllocNewBlock or Reset. (i.e. anything that might
+// affect overflow_blocks_).
+// -------------------------------------------------------------
+
+BaseArena::AllocatedBlock* BaseArena::AllocNewBlock(const size_t block_size,
+ const uint32 alignment) {
+ AllocatedBlock *block;
+ // Find the next block.
+ if (blocks_alloced_ < TC3_ARRAYSIZE(first_blocks_)) {
+ // Use one of the pre-allocated blocks
+ block = &first_blocks_[blocks_alloced_++];
+ } else { // oops, out of space, move to the vector
+ if (overflow_blocks_ == nullptr)
+ overflow_blocks_ = new std::vector<AllocatedBlock>;
+ // Adds another block to the vector.
+ overflow_blocks_->resize(overflow_blocks_->size()+1);
+ // block points to the last block of the vector.
+ block = &overflow_blocks_->back();
+ }
+
+ // NOTE(tucker): this utility is made slightly more complex by
+ // not disallowing the case where alignment > block_size.
+ // Can we, without breaking existing code?
+
+ // If page_aligned_, then alignment must be a multiple of page size.
+ // Otherwise, must be a multiple of kDefaultAlignment, unless
+ // requested alignment is 1, in which case we don't care at all.
+ const uint32 adjusted_alignment =
+ page_aligned_ ? LeastCommonMultiple(kPageSize, alignment)
+ : (alignment > 1 ? LeastCommonMultiple(alignment, kDefaultAlignment) : 1);
+ TC3_CHECK_LE(adjusted_alignment, 1 << 20)
+ << "Alignment on boundaries greater than 1MB not supported.";
+
+ // If block_size > alignment we force block_size to be a multiple
+ // of alignment; if block_size < alignment we make no adjustment, unless
+ // page_aligned_ is true, in which case it must be a multiple of
+ // kPageSize because SetProtect() will assume that.
+ size_t adjusted_block_size = block_size;
+#ifdef __STDCPP_DEFAULT_NEW_ALIGNMENT__
+ if (adjusted_alignment > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
+#else
+ if (adjusted_alignment > 1) {
+#endif
+ if (adjusted_block_size > adjusted_alignment) {
+ const uint32 excess = adjusted_block_size % adjusted_alignment;
+ adjusted_block_size += (excess > 0 ? adjusted_alignment - excess : 0);
+ }
+ if (page_aligned_) {
+ size_t num_pages = ((adjusted_block_size - 1)/kPageSize) + 1;
+ adjusted_block_size = num_pages * kPageSize;
+ }
+ block->mem = AllocateAlignedBytes(adjusted_block_size, adjusted_alignment);
+ } else {
+ block->mem = AllocateBytes(adjusted_block_size);
+ }
+ block->size = adjusted_block_size;
+ block->alignment = adjusted_alignment;
+ TC3_CHECK(nullptr != block->mem)
+ << "block_size=" << block_size
+ << " adjusted_block_size=" << adjusted_block_size
+ << " alignment=" << alignment
+ << " adjusted_alignment=" << adjusted_alignment;
+
+ TC3_ARENASET(status_.bytes_allocated_ += adjusted_block_size);
+
+#ifdef ADDRESS_SANITIZER
+ ASAN_POISON_MEMORY_REGION(block->mem, block->size);
+#endif
+ return block;
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::IndexToBlock()
+// Index encoding is as follows:
+// For blocks in the first_blocks_ array, we use index of the block in
+// the array.
+// For blocks in the overflow_blocks_ vector, we use the index of the
+// block in overflow_blocks_, plus the size of the first_blocks_ array.
+// ----------------------------------------------------------------------
+
+const BaseArena::AllocatedBlock *BaseArena::IndexToBlock(int index) const {
+ if (index < TC3_ARRAYSIZE(first_blocks_)) {
+ return &first_blocks_[index];
+ }
+ TC3_CHECK(overflow_blocks_ != nullptr);
+ int index_in_overflow_blocks = index - TC3_ARRAYSIZE(first_blocks_);
+ TC3_CHECK_GE(index_in_overflow_blocks, 0);
+ TC3_CHECK_LT(static_cast<size_t>(index_in_overflow_blocks),
+ overflow_blocks_->size());
+ return &(*overflow_blocks_)[index_in_overflow_blocks];
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::GetMemoryFallback()
+// We take memory out of our pool, aligned on the byte boundary
+// requested. If we don't have space in our current pool, we
+// allocate a new block (wasting the remaining space in the
+// current block) and give you that. If your memory needs are
+// too big for a single block, we make a special your-memory-only
+// allocation -- this is equivalent to not using the arena at all.
+// ----------------------------------------------------------------------
+
+void* BaseArena::GetMemoryFallback(const size_t size, const int alignment) {
+ if (0 == size) {
+ return nullptr; // stl/stl_alloc.h says this is okay
+ }
+
+ // alignment must be a positive power of 2.
+ TC3_CHECK(alignment > 0 && 0 == (alignment & (alignment - 1)));
+
+ // If the object is more than a quarter of the block size, allocate
+ // it separately to avoid wasting too much space in leftover bytes.
+ if (block_size_ == 0 || size > block_size_/4) {
+ // Use a block separate from all other allocations; in particular
+ // we don't update last_alloc_ so you can't reclaim space on this block.
+ AllocatedBlock* b = AllocNewBlock(size, alignment);
+#ifdef ADDRESS_SANITIZER
+ ASAN_UNPOISON_MEMORY_REGION(b->mem, b->size);
+#endif
+ return b->mem;
+ }
+
+ // Enforce alignment on freestart_ then check for adequate space,
+ // which may require starting a new block.
+ if (!SatisfyAlignment(alignment) || size > remaining_) {
+ MakeNewBlock(alignment);
+ }
+ TC3_CHECK_LE(size, remaining_);
+
+ remaining_ -= size;
+ last_alloc_ = freestart_;
+ freestart_ += size;
+
+#ifdef ADDRESS_SANITIZER
+ ASAN_UNPOISON_MEMORY_REGION(last_alloc_, size);
+#endif
+ return reinterpret_cast<void*>(last_alloc_);
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::ReturnMemoryFallback()
+// BaseArena::FreeBlocks()
+// Unlike GetMemory(), which does actual work, ReturnMemory() is a
+// no-op: we don't "free" memory until Reset() is called. We do
+// update some stats, though. Note we do no checking that the
+// pointer you pass in was actually allocated by us, or that it
+// was allocated for the size you say, so be careful here!
+// FreeBlocks() does the work for Reset(), actually freeing all
+// memory allocated in one fell swoop.
+// ----------------------------------------------------------------------
+
+void BaseArena::FreeBlocks() {
+ for ( int i = 1; i < blocks_alloced_; ++i ) { // keep first block alloced
+ DeallocateBytes(first_blocks_[i].mem, first_blocks_[i].size,
+ first_blocks_[i].alignment);
+ first_blocks_[i].mem = nullptr;
+ first_blocks_[i].size = 0;
+ }
+ blocks_alloced_ = 1;
+ if (overflow_blocks_ != nullptr) {
+ std::vector<AllocatedBlock>::iterator it;
+ for (it = overflow_blocks_->begin(); it != overflow_blocks_->end(); ++it) {
+ DeallocateBytes(it->mem, it->size, it->alignment);
+ }
+ delete overflow_blocks_; // These should be used very rarely
+ overflow_blocks_ = nullptr;
+ }
+}
+
+// ----------------------------------------------------------------------
+// BaseArena::AdjustLastAlloc()
+// If you realize you didn't want your last alloc to be for
+// the size you asked, after all, you can fix it by calling
+// this. We'll grow or shrink the last-alloc region if we
+// can (we can always shrink, but we might not be able to
+// grow if you want to grow too big).
+// RETURNS true if we successfully modified the last-alloc
+// region, false if the pointer you passed in wasn't actually
+// the last alloc or if you tried to grow bigger than we could.
+// ----------------------------------------------------------------------
+
+bool BaseArena::AdjustLastAlloc(void *last_alloc, const size_t newsize) {
+ // It's only legal to call this on the last thing you alloced.
+ if (last_alloc == nullptr || last_alloc != last_alloc_) return false;
+ // last_alloc_ should never point into a "big" block, w/ size >= block_size_
+ assert(freestart_ >= last_alloc_ && freestart_ <= last_alloc_ + block_size_);
+ assert(remaining_ >= 0); // should be: it's a size_t!
+ if (newsize > (freestart_ - last_alloc_) + remaining_)
+ return false; // not enough room, even after we get back last_alloc_ space
+ const char* old_freestart = freestart_; // where last alloc used to end
+ freestart_ = last_alloc_ + newsize; // where last alloc ends now
+ remaining_ -= (freestart_ - old_freestart); // how much new space we've taken
+
+#ifdef ADDRESS_SANITIZER
+ ASAN_UNPOISON_MEMORY_REGION(last_alloc_, newsize);
+ ASAN_POISON_MEMORY_REGION(freestart_, remaining_);
+#endif
+ return true;
+}
+
+// ----------------------------------------------------------------------
+// UnsafeArena::Realloc()
+// SafeArena::Realloc()
+// If you decide you want to grow -- or shrink -- a memory region,
+// we'll do it for you here. Typically this will involve copying
+// the existing memory to somewhere else on the arena that has
+// more space reserved. But if you're reallocing the last-allocated
+// block, we may be able to accommodate you just by updating a
+// pointer. In any case, we return a pointer to the new memory
+// location, which may be the same as the pointer you passed in.
+// Here's an example of how you might use Realloc():
+//
+// compr_buf = arena->Alloc(uncompr_size); // get too-much space
+// int compr_size;
+// zlib.Compress(uncompr_buf, uncompr_size, compr_buf, &compr_size);
+// compr_buf = arena->Realloc(compr_buf, uncompr_size, compr_size);
+// ----------------------------------------------------------------------
+
+char* UnsafeArena::Realloc(char* original, size_t oldsize, size_t newsize) {
+ assert(oldsize >= 0 && newsize >= 0);
+ // if original happens to be the last allocation we can avoid fragmentation.
+ if (AdjustLastAlloc(original, newsize)) {
+ return original;
+ }
+
+ char* resized = original;
+ if (newsize > oldsize) {
+ resized = Alloc(newsize);
+ memcpy(resized, original, oldsize);
+ } else {
+ // no need to do anything; we aren't reclaiming any memory!
+ }
+
+#ifdef ADDRESS_SANITIZER
+ // Alloc already returns unpoisoned memory, but handling both cases here
+ // allows us to poison the old memory without worrying about whether or not it
+ // overlaps with the new memory. Thus, we must poison the old memory first.
+ ASAN_POISON_MEMORY_REGION(original, oldsize);
+ ASAN_UNPOISON_MEMORY_REGION(resized, newsize);
+#endif
+ return resized;
+}
+
+// Avoid weak vtables by defining a dummy key method.
+void UnsafeArena::UnusedKeyMethod() {}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/base/arena.h b/native/utils/base/arena.h
new file mode 100644
index 0000000..28b6f6c
--- /dev/null
+++ b/native/utils/base/arena.h
@@ -0,0 +1,287 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Sometimes it is necessary to allocate a large number of small
+// objects. Doing this the usual way (malloc, new) is slow,
+// especially for multithreaded programs. A BaseArena provides a
+// mark/release method of memory management: it asks for a large chunk
+// from the operating system and doles it out bit by bit as required.
+// Then you free all the memory at once by calling BaseArena::Reset().
+//
+//
+// --Example Uses Of UnsafeArena
+// This is the simplest way. Just create an arena, and whenever you
+// need a block of memory to put something in, call BaseArena::Alloc(). eg
+// s = arena.Alloc(100);
+// snprintf(s, 100, "%s:%d", host, port);
+// arena.Shrink(strlen(s)+1); // optional; see below for use
+//
+// You'll probably use the convenience routines more often:
+// s = arena.Strdup(host); // a copy of host lives in the arena
+// s = arena.Strndup(host, 100); // we guarantee to NUL-terminate!
+// s = arena.Memdup(protobuf, sizeof(protobuf));
+//
+// If you go the Alloc() route, you'll probably allocate too-much-space.
+// You can reclaim the extra space by calling Shrink() before the next
+// Alloc() (or Strdup(), or whatever), with the #bytes you actually used.
+// If you use this method, memory management is easy: just call Alloc()
+// and friends a lot, and call Reset() when you're done with the data.
+//
+// FOR STRINGS: --Uses UnsafeArena
+// This is a special case of STL (below), but is simpler. Use an
+// astring, which acts like a string but allocates from the passed-in
+// arena:
+// astring s(arena); // or "sastring" to use a SafeArena
+// s.assign(host);
+// astring s2(host, hostlen, arena);
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_ARENA_H_
+#define LIBTEXTCLASSIFIER_UTILS_BASE_ARENA_H_
+
+#include <assert.h>
+#include <string.h>
+
+#include <vector>
+#ifdef ADDRESS_SANITIZER
+#include <sanitizer/asan_interface.h>
+#endif
+
+#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+// This class is "thread-compatible": different threads can access the
+// arena at the same time without locking, as long as they use only
+// const methods.
+class BaseArena {
+ protected: // You can't make an arena directly; only a subclass of one
+ BaseArena(char* first_block, const size_t block_size, bool align_to_page);
+
+ public:
+ virtual ~BaseArena();
+
+ virtual void Reset();
+
+ // they're "slow" only 'cause they're virtual (subclasses define "fast" ones)
+ virtual char* SlowAlloc(size_t size) = 0;
+ virtual void SlowFree(void* memory, size_t size) = 0;
+ virtual char* SlowRealloc(char* memory, size_t old_size, size_t new_size) = 0;
+
+ class Status {
+ private:
+ friend class BaseArena;
+ size_t bytes_allocated_;
+
+ public:
+ Status() : bytes_allocated_(0) {}
+ size_t bytes_allocated() const { return bytes_allocated_; }
+ };
+
+ // Accessors and stats counters
+ // This accessor isn't so useful here, but is included so we can be
+ // type-compatible with ArenaAllocator (in arena_allocator.h). That is,
+ // we define arena() because ArenaAllocator does, and that way you
+ // can template on either of these and know it's safe to call arena().
+ virtual BaseArena* arena() { return this; }
+ size_t block_size() const { return block_size_; }
+ int block_count() const;
+ bool is_empty() const {
+ // must check block count in case we allocated a block larger than blksize
+ return freestart_ == freestart_when_empty_ && 1 == block_count();
+ }
+
+ // The alignment that ArenaAllocator uses except for 1-byte objects.
+ static constexpr int kDefaultAlignment = 8;
+
+ protected:
+ bool SatisfyAlignment(const size_t alignment);
+ void MakeNewBlock(const uint32 alignment);
+ void* GetMemoryFallback(const size_t size, const int align);
+ void* GetMemory(const size_t size, const int align) {
+ assert(remaining_ <= block_size_); // an invariant
+ if (size > 0 && size <= remaining_ && align == 1) { // common case
+ last_alloc_ = freestart_;
+ freestart_ += size;
+ remaining_ -= size;
+#ifdef ADDRESS_SANITIZER
+ ASAN_UNPOISON_MEMORY_REGION(last_alloc_, size);
+#endif
+ return reinterpret_cast<void*>(last_alloc_);
+ }
+ return GetMemoryFallback(size, align);
+ }
+
+ // This doesn't actually free any memory except for the last piece allocated
+ void ReturnMemory(void* memory, const size_t size) {
+ if (memory == last_alloc_ &&
+ size == static_cast<size_t>(freestart_ - last_alloc_)) {
+ remaining_ += size;
+ freestart_ = last_alloc_;
+ }
+#ifdef ADDRESS_SANITIZER
+ ASAN_POISON_MEMORY_REGION(memory, size);
+#endif
+ }
+
+ // This is used by Realloc() -- usually we Realloc just by copying to a
+ // bigger space, but for the last alloc we can realloc by growing the region.
+ bool AdjustLastAlloc(void* last_alloc, const size_t newsize);
+
+ Status status_;
+ size_t remaining_;
+
+ private:
+ struct AllocatedBlock {
+ char* mem;
+ size_t size;
+ size_t alignment;
+ };
+
+ // Allocate a new block of at least block_size, with the specified
+ // alignment.
+ // The returned AllocatedBlock* is valid until the next call to AllocNewBlock
+ // or Reset (i.e. anything that might affect overflow_blocks_).
+ AllocatedBlock* AllocNewBlock(const size_t block_size,
+ const uint32 alignment);
+
+ const AllocatedBlock* IndexToBlock(int index) const;
+
+ const size_t block_size_;
+ char* freestart_; // beginning of the free space in most recent block
+ char* freestart_when_empty_; // beginning of the free space when we're empty
+ char* last_alloc_; // used to make sure ReturnMemory() is safe
+ // if the first_blocks_ aren't enough, expand into overflow_blocks_.
+ std::vector<AllocatedBlock>* overflow_blocks_;
+ // STL vector isn't as efficient as it could be, so we use an array at first
+ const bool first_block_externally_owned_; // true if they pass in 1st block
+ const bool page_aligned_; // when true, all blocks need to be page aligned
+ int8_t blocks_alloced_; // how many of the first_blocks_ have been allocated
+ AllocatedBlock first_blocks_[16]; // the length of this array is arbitrary
+
+ void FreeBlocks(); // Frees all except first block
+
+ BaseArena(const BaseArena&) = delete;
+ BaseArena& operator=(const BaseArena&) = delete;
+};
+
+class UnsafeArena : public BaseArena {
+ public:
+ // Allocates a thread-compatible arena with the specified block size.
+ explicit UnsafeArena(const size_t block_size)
+ : BaseArena(nullptr, block_size, false) {}
+ UnsafeArena(const size_t block_size, bool align)
+ : BaseArena(nullptr, block_size, align) {}
+
+ // Allocates a thread-compatible arena with the specified block
+ // size. "first_block" must have size "block_size". Memory is
+ // allocated from "first_block" until it is exhausted; after that
+ // memory is allocated by allocating new blocks from the heap.
+ UnsafeArena(char* first_block, const size_t block_size)
+ : BaseArena(first_block, block_size, false) {}
+ UnsafeArena(char* first_block, const size_t block_size, bool align)
+ : BaseArena(first_block, block_size, align) {}
+
+ char* Alloc(const size_t size) {
+ return reinterpret_cast<char*>(GetMemory(size, 1));
+ }
+ void* AllocAligned(const size_t size, const int align) {
+ return GetMemory(size, align);
+ }
+
+ // Allocates and initializes an object on the arena.
+ template <typename T, typename... Args>
+ T* AllocAndInit(Args... args) {
+ return new (reinterpret_cast<T*>(AllocAligned(sizeof(T), alignof(T))))
+ T(std::forward<Args>(args)...);
+ }
+
+ char* Calloc(const size_t size) {
+ void* return_value = Alloc(size);
+ memset(return_value, 0, size);
+ return reinterpret_cast<char*>(return_value);
+ }
+
+ void* CallocAligned(const size_t size, const int align) {
+ void* return_value = AllocAligned(size, align);
+ memset(return_value, 0, size);
+ return return_value;
+ }
+
+ // Free does nothing except for the last piece allocated.
+ void Free(void* memory, size_t size) { ReturnMemory(memory, size); }
+ char* SlowAlloc(size_t size) override { // "slow" 'cause it's virtual
+ return Alloc(size);
+ }
+ void SlowFree(void* memory,
+ size_t size) override { // "slow" 'cause it's virt
+ Free(memory, size);
+ }
+ char* SlowRealloc(char* memory, size_t old_size, size_t new_size) override {
+ return Realloc(memory, old_size, new_size);
+ }
+
+ char* Memdup(const char* s, size_t bytes) {
+ char* newstr = Alloc(bytes);
+ memcpy(newstr, s, bytes);
+ return newstr;
+ }
+ char* MemdupPlusNUL(const char* s, size_t bytes) { // like "string(s, len)"
+ char* newstr = Alloc(bytes + 1);
+ memcpy(newstr, s, bytes);
+ newstr[bytes] = '\0';
+ return newstr;
+ }
+ char* Strdup(const char* s) { return Memdup(s, strlen(s) + 1); }
+ // Unlike libc's strncpy, I always NUL-terminate. libc's semantics are dumb.
+ // This will allocate at most n+1 bytes (+1 is for the nul terminator).
+ char* Strndup(const char* s, size_t n) {
+ // Use memchr so we don't walk past n.
+ // We can't use the one in //strings since this is the base library,
+ // so we have to reinterpret_cast from the libc void*.
+ const char* eos = reinterpret_cast<const char*>(memchr(s, '\0', n));
+ // if no null terminator found, use full n
+ const size_t bytes = (eos == nullptr) ? n : eos - s;
+ return MemdupPlusNUL(s, bytes);
+ }
+
+ // You can realloc a previously-allocated string either bigger or smaller.
+ // We can be more efficient if you realloc a string right after you allocate
+ // it (eg allocate way-too-much space, fill it, realloc to just-big-enough)
+ char* Realloc(char* original, size_t oldsize, size_t newsize);
+ // If you know the new size is smaller (or equal), you don't need to know
+ // oldsize. We don't check that newsize is smaller, so you'd better be sure!
+ char* Shrink(char* s, size_t newsize) {
+ AdjustLastAlloc(s, newsize); // reclaim space if we can
+ return s; // never need to move if we go smaller
+ }
+
+ // We make a copy so you can keep track of status at a given point in time
+ Status status() const { return status_; }
+
+ // Number of bytes remaining before the arena has to allocate another block.
+ size_t bytes_until_next_allocation() const { return remaining_; }
+
+ private:
+ UnsafeArena(const UnsafeArena&) = delete;
+ UnsafeArena& operator=(const UnsafeArena&) = delete;
+
+ virtual void UnusedKeyMethod(); // Dummy key method to avoid weak vtable.
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_BASE_ARENA_H_
diff --git a/native/utils/base/arena_leakage_unittest.cc b/native/utils/base/arena_leakage_unittest.cc
new file mode 100644
index 0000000..642dacd
--- /dev/null
+++ b/native/utils/base/arena_leakage_unittest.cc
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/base/arena.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+TEST(Arena, Leakage) {
+ UnsafeArena arena(32);
+ // Grab just 10 bytes.
+ EXPECT_EQ(arena.bytes_until_next_allocation(), 32);
+ const char* block = arena.Alloc(10);
+ EXPECT_NE(block, nullptr);
+ EXPECT_EQ(arena.bytes_until_next_allocation(), 22);
+ // Grab the rest.
+ const char* expected_next_block = block + 10;
+ const char* next_block = arena.Alloc(22);
+ // If the below test fails, a new block has been allocated for "next_block".
+ // This means that the last 22 bytes of the previous block have been lost.
+ EXPECT_EQ(next_block, expected_next_block);
+ EXPECT_EQ(arena.bytes_until_next_allocation(), 0);
+ // Try allocating a 0 bytes block. Arena should remain unchanged.
+ const char* null_block = arena.Alloc(0);
+ EXPECT_EQ(null_block, nullptr);
+ EXPECT_EQ(arena.bytes_until_next_allocation(), 0);
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/base/casts.h b/native/utils/base/casts.h
similarity index 100%
rename from utils/base/casts.h
rename to native/utils/base/casts.h
diff --git a/utils/base/config.h b/native/utils/base/config.h
similarity index 100%
rename from utils/base/config.h
rename to native/utils/base/config.h
diff --git a/utils/base/endian.h b/native/utils/base/endian.h
similarity index 100%
rename from utils/base/endian.h
rename to native/utils/base/endian.h
diff --git a/native/utils/base/integral_types.h b/native/utils/base/integral_types.h
new file mode 100644
index 0000000..7cb417a
--- /dev/null
+++ b/native/utils/base/integral_types.h
@@ -0,0 +1,84 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Basic integer type definitions.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_INTEGRAL_TYPES_H_
+#define LIBTEXTCLASSIFIER_UTILS_BASE_INTEGRAL_TYPES_H_
+
+#include "utils/base/config.h"
+
+namespace libtextclassifier3 {
+
+typedef unsigned int uint;
+typedef unsigned int uint32;
+typedef unsigned long long uint64;
+
+#ifndef SWIG
+typedef int int32;
+typedef unsigned char uint8; // NOLINT
+typedef signed char int8; // NOLINT
+typedef unsigned short uint16; // NOLINT
+typedef signed short int16; // NOLINT
+
+// A type to represent a Unicode code-point value. As of Unicode 4.0,
+// such values require up to 21 bits.
+// (For type-checking on pointers, make this explicitly signed,
+// and it should always be the signed version of whatever int32 is.)
+typedef signed int char32;
+#endif // SWIG
+
+#ifdef COMPILER_MSVC
+typedef __int64 int64;
+#else
+typedef long long int64; // NOLINT
+#endif // COMPILER_MSVC
+
+// Some compile-time assertions that our new types have the intended size.
+// static_assert exists only since C++11, so we need an ifdef.
+#ifdef LANG_CXX11
+static_assert(sizeof(int) == 4, "Our typedefs depend on int being 32 bits");
+static_assert(sizeof(uint32) == 4, "wrong size");
+static_assert(sizeof(int32) == 4, "wrong size");
+static_assert(sizeof(uint8) == 1, "wrong size");
+static_assert(sizeof(int8) == 1, "wrong size");
+static_assert(sizeof(uint16) == 2, "wrong size");
+static_assert(sizeof(int16) == 2, "wrong size");
+static_assert(sizeof(char32) == 4, "wrong size");
+static_assert(sizeof(int64) == 8, "wrong size");
+#endif // LANG_CXX11
+
+// There are still some requirements that we build these headers in
+// C-compatibility mode. Unfortunately, -Wall doesn't like c-style
+// casts, and C doesn't know how to read braced-initialization for
+// integers.
+#if defined(__cplusplus)
+const uint16 kuint16max{0xFFFF};
+const int16 kint16max{0x7FFF};
+const int16 kint16min{~0x7FFF};
+const uint32 kuint32max{0xFFFFFFFF};
+const int32 kint32max{0x7FFFFFFF};
+#else // not __cplusplus, this branch exists only for C-compat
+static const uint16 kuint16max = ((uint16)0xFFFF);
+static const int16 kint16min = ((int16)~0x7FFF);
+static const int16 kint16max = ((int16)0x7FFF);
+static const uint32 kuint32max = ((uint32)0xFFFFFFFF);
+static const int32 kint32max = ((int32)0x7FFFFFFF);
+#endif // __cplusplus
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_BASE_INTEGRAL_TYPES_H_
diff --git a/native/utils/base/logging.cc b/native/utils/base/logging.cc
new file mode 100644
index 0000000..ddd1170
--- /dev/null
+++ b/native/utils/base/logging.cc
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/base/logging.h"
+
+#include <stdlib.h>
+
+#include <exception>
+
+#include "utils/base/logging_raw.h"
+
+namespace libtextclassifier3 {
+namespace logging {
+
+namespace {
+// Returns pointer to beginning of last /-separated token from file_name.
+// file_name should be a pointer to a zero-terminated array of chars.
+// E.g., "foo/bar.cc" -> "bar.cc", "foo/" -> "", "foo" -> "foo".
+const char *JumpToBasename(const char *file_name) {
+ if (file_name == nullptr) {
+ return nullptr;
+ }
+
+ // Points to the beginning of the last encountered token.
+ const char *last_token_start = file_name;
+ while (*file_name != '\0') {
+ if (*file_name == '/') {
+ // Found token separator. A new (potentially empty) token starts after
+ // this position. Notice that if file_name is a valid zero-terminated
+ // string, file_name + 1 is a valid pointer (there is at least one char
+ // after address file_name, the zero terminator).
+ last_token_start = file_name + 1;
+ }
+ file_name++;
+ }
+ return last_token_start;
+}
+} // namespace
+
+LogMessage::LogMessage(LogSeverity severity, const char *file_name,
+ int line_number)
+ : severity_(severity) {
+ stream_ << JumpToBasename(file_name) << ":" << line_number << ": ";
+}
+
+LogMessage::~LogMessage() {
+ LowLevelLogging(severity_, /* tag = */ "txtClsf", stream_.message);
+ if (severity_ == FATAL) {
+ std::terminate(); // Will print a stacktrace (stdout or logcat).
+ }
+}
+
+} // namespace logging
+} // namespace libtextclassifier3
diff --git a/native/utils/base/logging.h b/native/utils/base/logging.h
new file mode 100644
index 0000000..eae71b9
--- /dev/null
+++ b/native/utils/base/logging.h
@@ -0,0 +1,180 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_H_
+#define LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_H_
+
+#include <cassert>
+#include <string>
+
+#include "utils/base/integral_types.h"
+#include "utils/base/logging_levels.h"
+#include "utils/base/port.h"
+
+
+namespace libtextclassifier3 {
+namespace logging {
+
+// A tiny code footprint string stream for assembling log messages.
+struct LoggingStringStream {
+ LoggingStringStream() {}
+ LoggingStringStream& stream() { return *this; }
+ // Needed for invocation in TC3_CHECK macro.
+ explicit operator bool() const { return true; }
+
+ std::string message;
+};
+
+template <typename T>
+inline LoggingStringStream& operator<<(LoggingStringStream& stream,
+ const T& entry) {
+ stream.message.append(std::to_string(entry));
+ return stream;
+}
+
+template <typename T>
+inline LoggingStringStream& operator<<(LoggingStringStream& stream,
+ T* const entry) {
+ stream.message.append(std::to_string(reinterpret_cast<const uint64>(entry)));
+ return stream;
+}
+
+inline LoggingStringStream& operator<<(LoggingStringStream& stream,
+ const char* message) {
+ stream.message.append(message);
+ return stream;
+}
+
+inline LoggingStringStream& operator<<(LoggingStringStream& stream,
+ const std::string& message) {
+ stream.message.append(message);
+ return stream;
+}
+
+inline LoggingStringStream& operator<<(LoggingStringStream& stream,
+ const std::string_view message) {
+ stream.message.append(message);
+ return stream;
+}
+
+template <typename T1, typename T2>
+inline LoggingStringStream& operator<<(LoggingStringStream& stream,
+ const std::pair<T1, T2>& entry) {
+ stream << "(" << entry.first << ", " << entry.second << ")";
+ return stream;
+}
+
+// The class that does all the work behind our TC3_LOG(severity) macros. Each
+// TC3_LOG(severity) << obj1 << obj2 << ...; logging statement creates a
+// LogMessage temporary object containing a stringstream. Each operator<< adds
+// info to that stringstream and the LogMessage destructor performs the actual
+// logging. The reason this works is that in C++, "all temporary objects are
+// destroyed as the last step in evaluating the full-expression that (lexically)
+// contains the point where they were created." For more info, see
+// http://en.cppreference.com/w/cpp/language/lifetime. Hence, the destructor is
+// invoked after the last << from that logging statement.
+class LogMessage {
+ public:
+ LogMessage(LogSeverity severity, const char* file_name,
+ int line_number) TC3_ATTRIBUTE_NOINLINE;
+
+ ~LogMessage() TC3_ATTRIBUTE_NOINLINE;
+
+ // Returns the stream associated with the logger object.
+ LoggingStringStream& stream() { return stream_; }
+
+ private:
+ const LogSeverity severity_;
+
+ // Stream that "prints" all info into a string (not to a file). We construct
+ // here the entire logging message and next print it in one operation.
+ LoggingStringStream stream_;
+};
+
+// Pseudo-stream that "eats" the tokens <<-pumped into it, without printing
+// anything.
+class NullStream {
+ public:
+ NullStream() {}
+ NullStream& stream() { return *this; }
+};
+template <typename T>
+inline NullStream& operator<<(NullStream& str, const T&) {
+ return str;
+}
+
+} // namespace logging
+} // namespace libtextclassifier3
+
+#define TC3_LOG(severity) \
+ ::libtextclassifier3::logging::LogMessage( \
+ ::libtextclassifier3::logging::severity, __FILE__, __LINE__) \
+ .stream()
+
+// If condition x is true, does nothing. Otherwise, crashes the program (like
+// LOG(FATAL)) with an informative message. Can be continued with extra
+// messages, via <<, like any logging macro, e.g.,
+//
+// TC3_CHECK(my_cond) << "I think we hit a problem";
+#define TC3_CHECK(x) \
+ (x) || TC3_LOG(FATAL) << __FILE__ << ":" << __LINE__ << ": check failed: \"" \
+ << #x << "\" "
+
+#define TC3_CHECK_EQ(x, y) TC3_CHECK((x) == (y))
+#define TC3_CHECK_LT(x, y) TC3_CHECK((x) < (y))
+#define TC3_CHECK_GT(x, y) TC3_CHECK((x) > (y))
+#define TC3_CHECK_LE(x, y) TC3_CHECK((x) <= (y))
+#define TC3_CHECK_GE(x, y) TC3_CHECK((x) >= (y))
+#define TC3_CHECK_NE(x, y) TC3_CHECK((x) != (y))
+
+#define TC3_NULLSTREAM ::libtextclassifier3::logging::NullStream().stream()
+
+// Debug checks: a TC3_DCHECK<suffix> macro should behave like TC3_CHECK<suffix>
+// in debug mode an don't check / don't print anything in non-debug mode.
+#if defined(NDEBUG) && !defined(TC3_DEBUG_LOGGING) && !defined(TC3_DEBUG_CHECKS)
+
+#define TC3_DCHECK(x) TC3_NULLSTREAM
+#define TC3_DCHECK_EQ(x, y) TC3_NULLSTREAM
+#define TC3_DCHECK_LT(x, y) TC3_NULLSTREAM
+#define TC3_DCHECK_GT(x, y) TC3_NULLSTREAM
+#define TC3_DCHECK_LE(x, y) TC3_NULLSTREAM
+#define TC3_DCHECK_GE(x, y) TC3_NULLSTREAM
+#define TC3_DCHECK_NE(x, y) TC3_NULLSTREAM
+
+#else // NDEBUG
+
+// In debug mode, each TC3_DCHECK<suffix> is equivalent to TC3_CHECK<suffix>,
+// i.e., a real check that crashes when the condition is not true.
+#define TC3_DCHECK(x) TC3_CHECK(x)
+#define TC3_DCHECK_EQ(x, y) TC3_CHECK_EQ(x, y)
+#define TC3_DCHECK_LT(x, y) TC3_CHECK_LT(x, y)
+#define TC3_DCHECK_GT(x, y) TC3_CHECK_GT(x, y)
+#define TC3_DCHECK_LE(x, y) TC3_CHECK_LE(x, y)
+#define TC3_DCHECK_GE(x, y) TC3_CHECK_GE(x, y)
+#define TC3_DCHECK_NE(x, y) TC3_CHECK_NE(x, y)
+
+#endif // NDEBUG
+
+#ifdef TC3_ENABLE_VLOG
+#define TC3_VLOG(severity) \
+ ::libtextclassifier3::logging::LogMessage( \
+ ::libtextclassifier3::logging::INFO, __FILE__, __LINE__) \
+ .stream()
+#else
+#define TC3_VLOG(severity) TC3_NULLSTREAM
+#endif
+
+#endif // LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_H_
diff --git a/utils/base/logging_levels.h b/native/utils/base/logging_levels.h
similarity index 100%
rename from utils/base/logging_levels.h
rename to native/utils/base/logging_levels.h
diff --git a/native/utils/base/logging_raw.cc b/native/utils/base/logging_raw.cc
new file mode 100644
index 0000000..e3a73e2
--- /dev/null
+++ b/native/utils/base/logging_raw.cc
@@ -0,0 +1,105 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/base/logging_raw.h"
+
+#include <stdio.h>
+
+#include <string>
+
+#define TC3_RETURN_IF_NOT_ERROR_OR_FATAL \
+ if (severity != ERROR && severity != FATAL) { \
+ return; \
+ }
+
+// NOTE: this file contains two implementations: one for Android, one for all
+// other cases. We always build exactly one implementation.
+#if defined(__ANDROID__)
+
+// Compiled as part of Android.
+#include <android/log.h>
+
+namespace libtextclassifier3 {
+namespace logging {
+
+namespace {
+// Converts LogSeverity to level for __android_log_write.
+int GetAndroidLogLevel(LogSeverity severity) {
+ switch (severity) {
+ case FATAL:
+ return ANDROID_LOG_FATAL;
+ case ERROR:
+ return ANDROID_LOG_ERROR;
+ case WARNING:
+ return ANDROID_LOG_WARN;
+ case INFO:
+ return ANDROID_LOG_INFO;
+ default:
+ return ANDROID_LOG_DEBUG;
+ }
+}
+} // namespace
+
+void LowLevelLogging(LogSeverity severity, const std::string& tag,
+ const std::string& message) {
+#if !defined(TC3_DEBUG_LOGGING)
+ TC3_RETURN_IF_NOT_ERROR_OR_FATAL
+#endif
+ const int android_log_level = GetAndroidLogLevel(severity);
+ __android_log_write(android_log_level, tag.c_str(), message.c_str());
+}
+
+} // namespace logging
+} // namespace libtextclassifier3
+
+#else // if defined(__ANDROID__)
+
+// Not on Android: implement LowLevelLogging to print to stderr (see below).
+namespace libtextclassifier3 {
+namespace logging {
+
+namespace {
+// Converts LogSeverity to human-readable text.
+const char *LogSeverityToString(LogSeverity severity) {
+ switch (severity) {
+ case INFO:
+ return "INFO";
+ case WARNING:
+ return "WARNING";
+ case ERROR:
+ return "ERROR";
+ case FATAL:
+ return "FATAL";
+ default:
+ return "UNKNOWN";
+ }
+}
+} // namespace
+
+void LowLevelLogging(LogSeverity severity, const std::string &tag,
+ const std::string &message) {
+#if !defined(TC3_DEBUG_LOGGING)
+ TC3_RETURN_IF_NOT_ERROR_OR_FATAL
+#endif
+ fprintf(stderr, "[%s] %s : %s\n", LogSeverityToString(severity), tag.c_str(),
+ message.c_str());
+ fflush(stderr);
+}
+
+} // namespace logging
+} // namespace libtextclassifier3
+
+#endif // if defined(__ANDROID__)
diff --git a/utils/base/logging_raw.h b/native/utils/base/logging_raw.h
similarity index 100%
rename from utils/base/logging_raw.h
rename to native/utils/base/logging_raw.h
diff --git a/native/utils/base/macros.h b/native/utils/base/macros.h
new file mode 100644
index 0000000..260f3a9
--- /dev/null
+++ b/native/utils/base/macros.h
@@ -0,0 +1,138 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_MACROS_H_
+#define LIBTEXTCLASSIFIER_UTILS_BASE_MACROS_H_
+
+#include "utils/base/config.h"
+
+namespace libtextclassifier3 {
+
+#define TC3_ARRAYSIZE(a) \
+ ((sizeof(a) / sizeof(*(a))) / (size_t)(!(sizeof(a) % sizeof(*(a)))))
+
+#if LANG_CXX11
+#define TC3_DISALLOW_COPY_AND_ASSIGN(TypeName) \
+ TypeName(const TypeName &) = delete; \
+ TypeName &operator=(const TypeName &) = delete
+#else // C++98 case follows
+
+// Note that these C++98 implementations cannot completely disallow copying,
+// as members and friends can still accidentally make elided copies without
+// triggering a linker error.
+#define TC3_DISALLOW_COPY_AND_ASSIGN(TypeName) \
+ TypeName(const TypeName &); \
+ TypeName &operator=(const TypeName &)
+#endif // LANG_CXX11
+
+// The TC3_FALLTHROUGH_INTENDED macro can be used to annotate implicit
+// fall-through between switch labels:
+//
+// switch (x) {
+// case 40:
+// case 41:
+// if (truth_is_out_there) {
+// ++x;
+// TC3_FALLTHROUGH_INTENDED; // Use instead of/along with annotations in
+// // comments.
+// } else {
+// return x;
+// }
+// case 42:
+// ...
+//
+// As shown in the example above, the TC3_FALLTHROUGH_INTENDED macro should be
+// followed by a semicolon. It is designed to mimic control-flow statements
+// like 'break;', so it can be placed in most places where 'break;' can, but
+// only if there are no statements on the execution path between it and the
+// next switch label.
+//
+// When compiled with clang in C++11 mode, the TC3_FALLTHROUGH_INTENDED macro
+// is expanded to [[clang::fallthrough]] attribute, which is analysed when
+// performing switch labels fall-through diagnostic ('-Wimplicit-fallthrough').
+// See clang documentation on language extensions for details:
+// http://clang.llvm.org/docs/AttributeReference.html#fallthrough-clang-fallthrough
+//
+// When used with unsupported compilers, the TC3_FALLTHROUGH_INTENDED macro has
+// no effect on diagnostics.
+//
+// In either case this macro has no effect on runtime behavior and performance
+// of code.
+#if defined(__clang__) && defined(__has_warning)
+#if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough")
+#define TC3_FALLTHROUGH_INTENDED [[clang::fallthrough]]
+#endif
+#elif defined(__GNUC__) && __GNUC__ >= 7
+#define TC3_FALLTHROUGH_INTENDED [[gnu::fallthrough]]
+#endif
+
+#ifndef TC3_FALLTHROUGH_INTENDED
+#define TC3_FALLTHROUGH_INTENDED \
+ do { \
+ } while (0)
+#endif
+
+#ifdef __has_builtin
+#define TC3_HAS_BUILTIN(x) __has_builtin(x)
+#else
+#define TC3_HAS_BUILTIN(x) 0
+#endif
+
+// Compilers can be told that a certain branch is not likely to be taken
+// (for instance, a CHECK failure), and use that information in static
+// analysis. Giving it this information can help it optimize for the
+// common case in the absence of better information (ie.
+// -fprofile-arcs).
+//
+// We need to disable this for GPU builds, though, since nvcc8 and older
+// don't recognize `__builtin_expect` as a builtin, and fail compilation.
+#if (!defined(__NVCC__)) && (TC3_HAS_BUILTIN(__builtin_expect) || \
+ (defined(__GNUC__) && __GNUC__ >= 3))
+#define TC3_PREDICT_FALSE(x) (__builtin_expect(x, 0))
+#define TC3_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1))
+#else
+#define TC3_PREDICT_FALSE(x) (x)
+#define TC3_PREDICT_TRUE(x) (x)
+#endif
+
+// TC3_HAVE_ATTRIBUTE
+//
+// A function-like feature checking macro that is a wrapper around
+// `__has_attribute`, which is defined by GCC 5+ and Clang and evaluates to a
+// nonzero constant integer if the attribute is supported or 0 if not.
+//
+// It evaluates to zero if `__has_attribute` is not defined by the compiler.
+//
+// GCC: https://gcc.gnu.org/gcc-5/changes.html
+// Clang: https://clang.llvm.org/docs/LanguageExtensions.html
+#ifdef __has_attribute
+#define TC3_HAVE_ATTRIBUTE(x) __has_attribute(x)
+#else
+#define TC3_HAVE_ATTRIBUTE(x) 0
+#endif
+
+// TC3_ATTRIBUTE_PACKED
+//
+// Prevents the compiler from padding a structure to natural alignment
+#if TC3_HAVE_ATTRIBUTE(packed) || (defined(__GNUC__) && !defined(__clang__))
+#define TC3_ATTRIBUTE_PACKED __attribute__((__packed__))
+#else
+#define TC3_ATTRIBUTE_PACKED
+#endif
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_BASE_MACROS_H_
diff --git a/utils/base/port.h b/native/utils/base/port.h
similarity index 100%
rename from utils/base/port.h
rename to native/utils/base/port.h
diff --git a/native/utils/base/status.cc b/native/utils/base/status.cc
new file mode 100644
index 0000000..ee9204d
--- /dev/null
+++ b/native/utils/base/status.cc
@@ -0,0 +1,34 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/base/status.h"
+
+namespace libtextclassifier3 {
+
+const Status& Status::OK = *new Status(StatusCode::OK, "");
+const Status& Status::UNKNOWN = *new Status(StatusCode::UNKNOWN, "");
+
+Status::Status() : code_(StatusCode::OK) {}
+Status::Status(StatusCode error, const std::string& message)
+ : code_(error), message_(message) {}
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Status& status) {
+ stream << status.error_code();
+ return stream;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/base/status.h b/native/utils/base/status.h
new file mode 100644
index 0000000..865c2df
--- /dev/null
+++ b/native/utils/base/status.h
@@ -0,0 +1,90 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_STATUS_H_
+#define LIBTEXTCLASSIFIER_UTILS_BASE_STATUS_H_
+
+#include <string>
+
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+enum class StatusCode {
+ // Not an error; returned on success
+ OK = 0,
+
+ // All of the following StatusCodes represent errors.
+ CANCELLED = 1,
+ UNKNOWN = 2,
+ INVALID_ARGUMENT = 3,
+ DEADLINE_EXCEEDED = 4,
+ NOT_FOUND = 5,
+ ALREADY_EXISTS = 6,
+ PERMISSION_DENIED = 7,
+ RESOURCE_EXHAUSTED = 8,
+ FAILED_PRECONDITION = 9,
+ ABORTED = 10,
+ OUT_OF_RANGE = 11,
+ UNIMPLEMENTED = 12,
+ INTERNAL = 13,
+ UNAVAILABLE = 14,
+ DATA_LOSS = 15,
+ UNAUTHENTICATED = 16
+};
+
+// A Status is a combination of an error code and a string message (for non-OK
+// error codes).
+class Status {
+ public:
+ // Creates an OK status
+ Status();
+
+ // Make a Status from the specified error and message.
+ Status(StatusCode error, const std::string& error_message);
+
+ // Some pre-defined Status objects
+ static const Status& OK;
+ static const Status& UNKNOWN;
+
+ // Accessors
+ bool ok() const { return code_ == StatusCode::OK; }
+ int error_code() const { return static_cast<int>(code_); }
+
+ StatusCode CanonicalCode() const { return code_; }
+
+ const std::string& error_message() const { return message_; }
+
+ // Noop function provided to allow callers to suppress compiler warnings about
+ // ignored return values.
+ void IgnoreError() const {}
+
+ bool operator==(const Status& x) const;
+ bool operator!=(const Status& x) const;
+
+ std::string ToString() const;
+
+ private:
+ StatusCode code_;
+ std::string message_;
+};
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Status& status);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_BASE_STATUS_H_
diff --git a/native/utils/base/status_macros.h b/native/utils/base/status_macros.h
new file mode 100644
index 0000000..40159fe
--- /dev/null
+++ b/native/utils/base/status_macros.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_STATUS_MACROS_H_
+#define LIBTEXTCLASSIFIER_UTILS_BASE_STATUS_MACROS_H_
+
+#include <utility>
+
+#include "utils/base/status.h"
+#include "utils/base/statusor.h"
+
+namespace libtextclassifier3 {
+
+// An adapter to enable TC3_RETURN_IF_ERROR to be used with either Status or
+// StatusOr.
+class StatusAdapter {
+ public:
+ explicit StatusAdapter(const Status& s) : s_(s) {}
+ explicit StatusAdapter(Status&& s) : s_(std::move(s)) {}
+ template <typename T>
+ explicit StatusAdapter(const StatusOr<T>& s) : s_(s.status()) {}
+ template <typename T>
+ explicit StatusAdapter(StatusOr<T>&& s) : s_(std::move(s).status()) {}
+
+ bool ok() const { return s_.ok(); }
+ explicit operator bool() const { return ok(); }
+
+ const Status& status() const& { return s_; }
+ Status status() && { return std::move(s_); }
+
+ private:
+ Status s_;
+};
+
+} // namespace libtextclassifier3
+
+// Evaluates an expression that produces a `libtextclassifier3::Status`. If the
+// status is not ok, returns it from the current function.
+//
+// For example:
+// libtextclassifier3::Status MultiStepFunction() {
+// TC3_RETURN_IF_ERROR(Function(args...));
+// TC3_RETURN_IF_ERROR(foo.Method(args...));
+// return libtextclassifier3::Status();
+// }
+#define TC3_RETURN_IF_ERROR(expr) \
+ TC3_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \
+ if (::libtextclassifier3::StatusAdapter adapter{expr}) { \
+ } else /* NOLINT */ \
+ return std::move(adapter).status()
+
+// The GNU compiler emits a warning for code like:
+//
+// if (foo)
+// if (bar) { } else baz;
+//
+// because it thinks you might want the else to bind to the first if. This
+// leads to problems with code like:
+//
+// if (do_expr) TC3_RETURN_IF_ERROR(expr);
+//
+// The "switch (0) case 0:" idiom is used to suppress this.
+#define TC3_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \
+ switch (0) \
+ case 0: \
+ default: // NOLINT
+
+#endif // LIBTEXTCLASSIFIER_UTILS_BASE_STATUS_MACROS_H_
diff --git a/native/utils/base/status_test.cc b/native/utils/base/status_test.cc
new file mode 100644
index 0000000..82d5aad
--- /dev/null
+++ b/native/utils/base/status_test.cc
@@ -0,0 +1,156 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/base/status.h"
+
+#include "utils/base/logging.h"
+#include "utils/base/status_macros.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(StatusTest, PrintsAbortedStatus) {
+ logging::LoggingStringStream stream;
+ stream << Status::UNKNOWN;
+ EXPECT_EQ(Status::UNKNOWN.error_code(), 2);
+ EXPECT_EQ(Status::UNKNOWN.CanonicalCode(), StatusCode::UNKNOWN);
+ EXPECT_EQ(Status::UNKNOWN.error_message(), "");
+ EXPECT_EQ(stream.message, "2");
+}
+
+TEST(StatusTest, PrintsOKStatus) {
+ logging::LoggingStringStream stream;
+ stream << Status::OK;
+ EXPECT_EQ(Status::OK.error_code(), 0);
+ EXPECT_EQ(Status::OK.CanonicalCode(), StatusCode::OK);
+ EXPECT_EQ(Status::OK.error_message(), "");
+ EXPECT_EQ(stream.message, "0");
+}
+
+TEST(StatusTest, UnknownStatusHasRightAttributes) {
+ EXPECT_EQ(Status::UNKNOWN.error_code(), 2);
+ EXPECT_EQ(Status::UNKNOWN.CanonicalCode(), StatusCode::UNKNOWN);
+ EXPECT_EQ(Status::UNKNOWN.error_message(), "");
+}
+
+TEST(StatusTest, OkStatusHasRightAttributes) {
+ EXPECT_EQ(Status::OK.error_code(), 0);
+ EXPECT_EQ(Status::OK.CanonicalCode(), StatusCode::OK);
+ EXPECT_EQ(Status::OK.error_message(), "");
+}
+
+TEST(StatusTest, CustomStatusHasRightAttributes) {
+ Status status(StatusCode::INVALID_ARGUMENT, "You can't put this here!");
+ EXPECT_EQ(status.error_code(), 3);
+ EXPECT_EQ(status.CanonicalCode(), StatusCode::INVALID_ARGUMENT);
+ EXPECT_EQ(status.error_message(), "You can't put this here!");
+}
+
+TEST(StatusTest, AssignmentPreservesMembers) {
+ Status status(StatusCode::INVALID_ARGUMENT, "You can't put this here!");
+
+ Status status2 = status;
+
+ EXPECT_EQ(status2.error_code(), 3);
+ EXPECT_EQ(status2.CanonicalCode(), StatusCode::INVALID_ARGUMENT);
+ EXPECT_EQ(status2.error_message(), "You can't put this here!");
+}
+
+TEST(StatusTest, ReturnIfErrorOkStatus) {
+ bool returned_due_to_error = true;
+ auto lambda = [&returned_due_to_error](const Status& s) {
+ TC3_RETURN_IF_ERROR(s);
+ returned_due_to_error = false;
+ return Status::OK;
+ };
+
+ // OK should allow execution to continue and the returned status should also
+ // be OK.
+ Status status = lambda(Status());
+ EXPECT_EQ(status.error_code(), 0);
+ EXPECT_EQ(status.CanonicalCode(), StatusCode::OK);
+ EXPECT_EQ(status.error_message(), "");
+ EXPECT_FALSE(returned_due_to_error);
+}
+
+TEST(StatusTest, ReturnIfErrorInvalidArgumentStatus) {
+ bool returned_due_to_error = true;
+ auto lambda = [&returned_due_to_error](const Status& s) {
+ TC3_RETURN_IF_ERROR(s);
+ returned_due_to_error = false;
+ return Status::OK;
+ };
+
+ // INVALID_ARGUMENT should cause an early return.
+ Status invalid_arg_status(StatusCode::INVALID_ARGUMENT, "You can't do that!");
+ Status status = lambda(invalid_arg_status);
+ EXPECT_EQ(status.error_code(), 3);
+ EXPECT_EQ(status.CanonicalCode(), StatusCode::INVALID_ARGUMENT);
+ EXPECT_EQ(status.error_message(), "You can't do that!");
+ EXPECT_TRUE(returned_due_to_error);
+}
+
+TEST(StatusTest, ReturnIfErrorUnknownStatus) {
+ bool returned_due_to_error = true;
+ auto lambda = [&returned_due_to_error](const Status& s) {
+ TC3_RETURN_IF_ERROR(s);
+ returned_due_to_error = false;
+ return Status::OK;
+ };
+
+ // UNKNOWN should cause an early return.
+ Status unknown_status(StatusCode::UNKNOWN,
+ "We also know there are known unknowns.");
+ libtextclassifier3::Status status = lambda(unknown_status);
+ EXPECT_EQ(status.error_code(), 2);
+ EXPECT_EQ(status.CanonicalCode(), StatusCode::UNKNOWN);
+ EXPECT_EQ(status.error_message(), "We also know there are known unknowns.");
+ EXPECT_TRUE(returned_due_to_error);
+}
+
+TEST(StatusTest, ReturnIfErrorOnlyInvokesExpressionOnce) {
+ int num_invocations = 0;
+ auto ok_internal_expr = [&num_invocations]() {
+ ++num_invocations;
+ return Status::OK;
+ };
+ auto ok_lambda = [&ok_internal_expr]() {
+ TC3_RETURN_IF_ERROR(ok_internal_expr());
+ return Status::OK;
+ };
+
+ libtextclassifier3::Status status = ok_lambda();
+ EXPECT_EQ(status.CanonicalCode(), StatusCode::OK);
+ EXPECT_EQ(num_invocations, 1);
+
+ num_invocations = 0;
+ auto error_internal_expr = [&num_invocations]() {
+ ++num_invocations;
+ return Status::UNKNOWN;
+ };
+ auto error_lambda = [&error_internal_expr]() {
+ TC3_RETURN_IF_ERROR(error_internal_expr());
+ return Status::OK;
+ };
+
+ status = error_lambda();
+ EXPECT_EQ(status.CanonicalCode(), StatusCode::UNKNOWN);
+ EXPECT_EQ(num_invocations, 1);
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/base/statusor.h b/native/utils/base/statusor.h
new file mode 100644
index 0000000..dde9ecd
--- /dev/null
+++ b/native/utils/base/statusor.h
@@ -0,0 +1,344 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_STATUSOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_BASE_STATUSOR_H_
+
+#include <type_traits>
+#include <utility>
+
+#include "utils/base/logging.h"
+#include "utils/base/macros.h"
+#include "utils/base/status.h"
+
+namespace libtextclassifier3 {
+
+// A StatusOr holds a Status (in the case of an error), or a value T.
+template <typename T>
+class StatusOr {
+ public:
+ // Has status UNKNOWN.
+ inline StatusOr();
+
+ // Builds from a non-OK status. Crashes if an OK status is specified.
+ inline StatusOr(const Status& status); // NOLINT
+
+ // Builds from the specified value.
+ inline StatusOr(const T& value); // NOLINT
+ inline StatusOr(T&& value); // NOLINT
+
+ // Copy constructor.
+ inline StatusOr(const StatusOr& other);
+ // Move constructor.
+ inline StatusOr(StatusOr&& other);
+
+ // Conversion copy constructor, T must be copy constructible from U.
+ template <typename U,
+ std::enable_if_t<
+ std::conjunction<std::negation<std::is_same<T, U>>,
+ std::is_constructible<T, const U&>,
+ std::is_convertible<const U&, T>>::value,
+ int> = 0>
+ inline StatusOr(const StatusOr<U>& other); // NOLINT
+
+ // Conversion move constructor, T must be move constructible from U.
+ template <
+ typename U,
+ std::enable_if_t<std::conjunction<std::negation<std::is_same<T, U>>,
+ std::is_constructible<T, U&&>,
+ std::is_convertible<U&&, T>>::value,
+ int> = 0>
+ inline StatusOr(StatusOr<U>&& other); // NOLINT
+
+ // Value conversion copy constructor, T must be copy constructible from U.
+ template <typename U,
+ std::enable_if_t<
+ std::conjunction<std::negation<std::is_same<T, U>>,
+ std::is_constructible<T, const U&>,
+ std::is_convertible<const U&, T>>::value,
+ int> = 0>
+ inline StatusOr(const U& value); // NOLINT
+
+ // Value conversion move constructor, T must be move constructible from U.
+ template <
+ typename U,
+ std::enable_if_t<std::conjunction<std::negation<std::is_same<T, U>>,
+ std::is_constructible<T, U&&>,
+ std::is_convertible<U&&, T>>::value,
+ int> = 0>
+ inline StatusOr(U&& value); // NOLINT
+
+ // Assignment operator.
+ inline StatusOr& operator=(const StatusOr& other);
+ inline StatusOr& operator=(StatusOr&& other);
+
+ // Conversion assignment operator, T must be assignable from U
+ template <typename U>
+ inline StatusOr& operator=(const StatusOr<U>& other);
+
+ inline ~StatusOr();
+
+ // Accessors.
+ inline const Status& status() const& { return status_; }
+ inline Status status() && { return std::move(status_); }
+
+ // Shorthand for status().ok().
+ inline bool ok() const { return status_.ok(); }
+
+ // Returns value or crashes if ok() is false.
+ inline const T& ValueOrDie() const& {
+ if (!ok()) {
+ TC3_LOG(FATAL) << "Attempting to fetch value of non-OK StatusOr: "
+ << status();
+ exit(1);
+ }
+ return value_;
+ }
+ inline T& ValueOrDie() & {
+ if (!ok()) {
+ TC3_LOG(FATAL) << "Attempting to fetch value of non-OK StatusOr: "
+ << status();
+ exit(1);
+ }
+ return value_;
+ }
+ inline const T&& ValueOrDie() const&& {
+ if (!ok()) {
+ TC3_LOG(FATAL) << "Attempting to fetch value of non-OK StatusOr: "
+ << status();
+ exit(1);
+ }
+ return std::move(value_);
+ }
+ inline T&& ValueOrDie() && {
+ if (!ok()) {
+ TC3_LOG(FATAL) << "Attempting to fetch value of non-OK StatusOr: "
+ << status();
+ exit(1);
+ }
+ return std::move(value_);
+ }
+
+ template <typename U>
+ friend class StatusOr;
+
+ private:
+ Status status_;
+ // The members of unions do not require initialization and are not destructed
+ // unless specifically called. This allows us to construct instances of
+ // StatusOr with only error statuses where T is not default constructible.
+ union {
+ // value_ is active iff status_.ok()==true
+ // WARNING: The destructor of value_ is called ONLY if status_ is OK.
+ T value_;
+ };
+};
+
+// Implementation.
+
+template <typename T>
+inline StatusOr<T>::StatusOr() : status_(StatusCode::UNKNOWN, "") {}
+
+template <typename T>
+inline StatusOr<T>::StatusOr(const Status& status) : status_(status) {
+ if (status.ok()) {
+ TC3_LOG(FATAL) << "OkStatus() is not a valid argument to StatusOr";
+ exit(1);
+ }
+}
+
+template <typename T>
+inline StatusOr<T>::StatusOr(const T& value) : value_(value) {}
+
+template <typename T>
+inline StatusOr<T>::StatusOr(T&& value) : value_(std::move(value)) {}
+
+template <typename T>
+inline StatusOr<T>::StatusOr(const StatusOr& other)
+ : status_(other.status_), value_(other.value_) {}
+
+template <typename T>
+inline StatusOr<T>::StatusOr(StatusOr&& other)
+ : status_(other.status_), value_(std::move(other.value_)) {}
+
+template <typename T>
+template <
+ typename U,
+ std::enable_if_t<std::conjunction<std::negation<std::is_same<T, U>>,
+ std::is_constructible<T, const U&>,
+ std::is_convertible<const U&, T>>::value,
+ int>>
+inline StatusOr<T>::StatusOr(const StatusOr<U>& other)
+ : status_(other.status_), value_(other.value_) {}
+
+template <typename T>
+template <typename U,
+ std::enable_if_t<std::conjunction<std::negation<std::is_same<T, U>>,
+ std::is_constructible<T, U&&>,
+ std::is_convertible<U&&, T>>::value,
+ int>>
+inline StatusOr<T>::StatusOr(StatusOr<U>&& other)
+ : status_(other.status_), value_(std::move(other.value_)) {}
+
+template <typename T>
+template <
+ typename U,
+ std::enable_if_t<std::conjunction<std::negation<std::is_same<T, U>>,
+ std::is_constructible<T, const U&>,
+ std::is_convertible<const U&, T>>::value,
+ int>>
+inline StatusOr<T>::StatusOr(const U& value) : StatusOr(T(value)) {}
+
+template <typename T>
+template <typename U,
+ std::enable_if_t<std::conjunction<std::negation<std::is_same<T, U>>,
+ std::is_constructible<T, U&&>,
+ std::is_convertible<U&&, T>>::value,
+ int>>
+inline StatusOr<T>::StatusOr(U&& value) : StatusOr(T(std::forward<U>(value))) {}
+
+template <typename T>
+inline StatusOr<T>& StatusOr<T>::operator=(const StatusOr& other) {
+ status_ = other.status_;
+ if (status_.ok()) {
+ value_ = other.value_;
+ }
+ return *this;
+}
+
+template <typename T>
+inline StatusOr<T>& StatusOr<T>::operator=(StatusOr&& other) {
+ status_ = other.status_;
+ if (status_.ok()) {
+ value_ = std::move(other.value_);
+ }
+ return *this;
+}
+
+template <typename T>
+inline StatusOr<T>::~StatusOr() {
+ if (ok()) {
+ value_.~T();
+ }
+}
+
+template <typename T>
+template <typename U>
+inline StatusOr<T>& StatusOr<T>::operator=(const StatusOr<U>& other) {
+ status_ = other.status_;
+ if (status_.ok()) {
+ value_ = other.value_;
+ }
+ return *this;
+}
+
+} // namespace libtextclassifier3
+
+#define TC3_ASSIGN_OR_RETURN(...) \
+ TC_STATUS_MACROS_IMPL_GET_VARIADIC_( \
+ (__VA_ARGS__, TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_, \
+ TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_)) \
+ (__VA_ARGS__)
+
+#define TC3_ASSIGN_OR_RETURN_NULL(lhs, rexpr) \
+ TC3_ASSIGN_OR_RETURN(lhs, rexpr, nullptr)
+
+#define TC3_ASSIGN_OR_RETURN_FALSE(lhs, rexpr) \
+ TC3_ASSIGN_OR_RETURN(lhs, rexpr, false)
+
+#define TC3_ASSIGN_OR_RETURN_0(lhs, rexpr) TC3_ASSIGN_OR_RETURN(lhs, rexpr, 0)
+
+// =================================================================
+// == Implementation details, do not rely on anything below here. ==
+// =================================================================
+
+// Some builds do not support C++14 fully yet, using C++11 constexpr technique.
+constexpr bool HasPossiblyConditionalOperator(const char* lhs, int index) {
+ return (index == -1 ? false
+ : (lhs[index] == '?'
+ ? true
+ : HasPossiblyConditionalOperator(lhs, index - 1)));
+}
+
+// MSVC incorrectly expands variadic macros, splice together a macro call to
+// work around the bug.
+#define TC_STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, ...) NAME
+#define TC_STATUS_MACROS_IMPL_GET_VARIADIC_(args) \
+ TC_STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_ args
+
+#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \
+ TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, _)
+#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, \
+ error_expression) \
+ TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \
+ TC_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \
+ error_expression)
+#define TC_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \
+ error_expression) \
+ auto statusor = (rexpr); \
+ if (!statusor.ok()) { \
+ ::libtextclassifier3::Status _(std::move(statusor).status()); \
+ (void)_; /* error_expression is allowed to not use this variable */ \
+ return (error_expression); \
+ } \
+ { \
+ static_assert(#lhs[0] != '(' || #lhs[sizeof(#lhs) - 2] != ')' || \
+ !HasPossiblyConditionalOperator(#lhs, sizeof(#lhs) - 2), \
+ "Identified potential conditional operator, consider not " \
+ "using ASSIGN_OR_RETURN"); \
+ } \
+ TC_STATUS_MACROS_IMPL_UNPARENTHESIZE_IF_PARENTHESIZED(lhs) = \
+ std::move(statusor).ValueOrDie()
+
+// Internal helpers for macro expansion.
+#define TC_STATUS_MACROS_IMPL_EAT(...)
+#define TC_STATUS_MACROS_IMPL_REM(...) __VA_ARGS__
+#define TC_STATUS_MACROS_IMPL_EMPTY()
+
+// Internal helpers for the emptiness check of arguments.
+#define TC_STATUS_MACROS_IMPL_IS_EMPTY_INNER(...) \
+ TC_STATUS_MACROS_IMPL_IS_EMPTY_INNER_I(__VA_ARGS__, 0, 1)
+#define TC_STATUS_MACROS_IMPL_IS_EMPTY_INNER_I(e0, e1, is_empty, ...) is_empty
+
+#define TC_STATUS_MACROS_IMPL_IS_EMPTY(...) \
+ TC_STATUS_MACROS_IMPL_IS_EMPTY_I(__VA_ARGS__)
+#define TC_STATUS_MACROS_IMPL_IS_EMPTY_I(...) \
+ TC_STATUS_MACROS_IMPL_IS_EMPTY_INNER(_, ##__VA_ARGS__)
+
+// Internal helpers for if statement.
+#define TC_STATUS_MACROS_IMPL_IF_1(_Then, _Else) _Then
+#define TC_STATUS_MACROS_IMPL_IF_0(_Then, _Else) _Else
+#define TC_STATUS_MACROS_IMPL_IF(_Cond, _Then, _Else) \
+ TC_STATUS_MACROS_IMPL_CONCAT_(TC_STATUS_MACROS_IMPL_IF_, _Cond)(_Then, _Else)
+
+// Expands to 1 if the input is parenthesized. Otherwise expands to 0.
+#define TC_STATUS_MACROS_IMPL_IS_PARENTHESIZED(...) \
+ TC_STATUS_MACROS_IMPL_IS_EMPTY(TC_STATUS_MACROS_IMPL_EAT __VA_ARGS__)
+
+// If the input is parenthesized, removes the parentheses. Otherwise expands to
+// the input unchanged.
+#define TC_STATUS_MACROS_IMPL_UNPARENTHESIZE_IF_PARENTHESIZED(...) \
+ TC_STATUS_MACROS_IMPL_IF( \
+ TC_STATUS_MACROS_IMPL_IS_PARENTHESIZED(__VA_ARGS__), \
+ TC_STATUS_MACROS_IMPL_REM, TC_STATUS_MACROS_IMPL_EMPTY()) \
+ __VA_ARGS__
+
+// Internal helper for concatenating macro values.
+#define TC_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y
+#define TC_STATUS_MACROS_IMPL_CONCAT_(x, y) \
+ TC_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y)
+
+#endif // LIBTEXTCLASSIFIER_UTILS_BASE_STATUSOR_H_
diff --git a/native/utils/base/statusor_test.cc b/native/utils/base/statusor_test.cc
new file mode 100644
index 0000000..23165b0
--- /dev/null
+++ b/native/utils/base/statusor_test.cc
@@ -0,0 +1,133 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/base/statusor.h"
+
+#include "utils/base/logging.h"
+#include "utils/base/status.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(StatusOrTest, DoesntDieWhenOK) {
+ StatusOr<std::string> status_or_string = std::string("Hello World");
+ EXPECT_TRUE(status_or_string.ok());
+ EXPECT_EQ(status_or_string.ValueOrDie(), "Hello World");
+}
+
+TEST(StatusOrTest, DiesWhenNotOK) {
+ StatusOr<std::string> status_or_string = {Status::UNKNOWN};
+ EXPECT_FALSE(status_or_string.ok());
+ // Android does not print the error message to stderr, so we are not checking
+ // the error message here.
+ EXPECT_DEATH(status_or_string.ValueOrDie(), "");
+}
+
+// Foo is NOT default constructible and can be implicitly converted to from int.
+class Foo {
+ public:
+ // Copy value conversion
+ Foo(int i) : i_(i) {} // NOLINT
+ int i() const { return i_; }
+
+ private:
+ int i_;
+};
+
+TEST(StatusOrTest, HandlesNonDefaultConstructibleValues) {
+ StatusOr<Foo> foo_or(Foo(7));
+ EXPECT_TRUE(foo_or.ok());
+ EXPECT_EQ(foo_or.ValueOrDie().i(), 7);
+
+ StatusOr<Foo> error_or(Status::UNKNOWN);
+ EXPECT_FALSE(error_or.ok());
+ EXPECT_EQ(error_or.status().CanonicalCode(), StatusCode::UNKNOWN);
+}
+
+class Bar {
+ public:
+ // Move value conversion
+ Bar(Foo&& f) : i_(2 * f.i()) {} // NOLINT
+
+ // Movable, but not copyable.
+ Bar(const Bar& other) = delete;
+ Bar& operator=(const Bar& rhs) = delete;
+ Bar(Bar&& other) = default;
+ Bar& operator=(Bar&& rhs) = default;
+
+ int i() const { return i_; }
+
+ private:
+ int i_;
+};
+
+TEST(StatusOrTest, HandlesValueConversion) {
+ // Copy value conversion constructor : StatusOr<Foo>(const int&)
+ StatusOr<Foo> foo_status(19);
+ EXPECT_TRUE(foo_status.ok());
+ EXPECT_EQ(foo_status.ValueOrDie().i(), 19);
+
+ // Move value conversion constructor : StatusOr<Bar>(Foo&&)
+ StatusOr<Bar> bar_status(std::move(foo_status));
+ EXPECT_TRUE(bar_status.ok());
+ EXPECT_EQ(bar_status.ValueOrDie().i(), 38);
+
+ StatusOr<int> int_status(19);
+ // Copy conversion constructor : StatusOr<Foo>(const StatusOr<int>&)
+ StatusOr<Foo> copied_status(int_status);
+ EXPECT_TRUE(copied_status.ok());
+ EXPECT_EQ(copied_status.ValueOrDie().i(), 19);
+
+ // Move conversion constructor : StatusOr<Bar>(StatusOr<Foo>&&)
+ StatusOr<Bar> moved_status(std::move(copied_status));
+ EXPECT_TRUE(moved_status.ok());
+ EXPECT_EQ(moved_status.ValueOrDie().i(), 38);
+
+ // Move conversion constructor with error : StatusOr<Bar>(StatusOr<Foo>&&)
+ StatusOr<Foo> error_status(Status::UNKNOWN);
+ StatusOr<Bar> moved_error_status(std::move(error_status));
+ EXPECT_FALSE(moved_error_status.ok());
+}
+
+struct OkFn {
+ StatusOr<int> operator()() { return 42; }
+};
+TEST(StatusOrTest, AssignOrReturnValOk) {
+ auto lambda = []() {
+ TC3_ASSIGN_OR_RETURN(int i, OkFn()(), -1);
+ return i;
+ };
+
+ // OkFn() should return a valid integer, so lambda should return that integer.
+ EXPECT_EQ(lambda(), 42);
+}
+
+struct FailFn {
+ StatusOr<int> operator()() { return Status::UNKNOWN; }
+};
+TEST(StatusOrTest, AssignOrReturnValError) {
+ auto lambda = []() {
+ TC3_ASSIGN_OR_RETURN(int i, FailFn()(), -1);
+ return i;
+ };
+
+ // FailFn() should return an error, so lambda should return -1.
+ EXPECT_EQ(lambda(), -1);
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/calendar/calendar-common.h b/native/utils/calendar/calendar-common.h
new file mode 100644
index 0000000..e6fd076
--- /dev/null
+++ b/native/utils/calendar/calendar-common.h
@@ -0,0 +1,351 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_COMMON_H_
+#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_COMMON_H_
+
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+#include "utils/base/macros.h"
+
+namespace libtextclassifier3 {
+namespace calendar {
+
+// Macro to reduce the amount of boilerplate needed for propagating errors.
+#define TC3_CALENDAR_CHECK(EXPR) \
+ if (!(EXPR)) { \
+ return false; \
+ }
+
+// An implementation of CalendarLib that is independent of the particular
+// calendar implementation used (implementation type is passed as template
+// argument).
+template <class TCalendar>
+class CalendarLibTempl {
+ public:
+ bool InterpretParseData(const DatetimeParsedData& parse_data,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale,
+ bool prefer_future_for_unspecified_date,
+ TCalendar* calendar,
+ DatetimeGranularity* granularity) const;
+
+ DatetimeGranularity GetGranularity(const DatetimeParsedData& data) const;
+
+ private:
+ // Adjusts the calendar's time instant according to a relative date reference
+ // in the parsed data.
+ bool ApplyRelationField(const DatetimeComponent& relative_date_time_component,
+ TCalendar* calendar) const;
+
+ // Round the time instant's precision down to the given granularity.
+ bool RoundToGranularity(DatetimeGranularity granularity,
+ TCalendar* calendar) const;
+
+ // Adjusts time in steps of relation_type, by distance steps.
+ // For example:
+ // - Adjusting by -2 MONTHS will return the beginning of the 1st
+ // two months ago.
+ // - Adjusting by +4 Wednesdays will return the beginning of the next
+ // Wednesday at least 4 weeks from now.
+ // If allow_today is true, the same day of the week may be kept
+ // if it already matches the relation type.
+ bool AdjustByRelation(DatetimeComponent date_time_component, int distance,
+ bool allow_today, TCalendar* calendar) const;
+};
+
+inline bool HasOnlyTimeComponents(const DatetimeParsedData& parse_data) {
+ std::vector<DatetimeComponent> components;
+ parse_data.GetDatetimeComponents(&components);
+
+ for (const DatetimeComponent& component : components) {
+ if (!(component.component_type == DatetimeComponent::ComponentType::HOUR ||
+ component.component_type ==
+ DatetimeComponent::ComponentType::MINUTE ||
+ component.component_type ==
+ DatetimeComponent::ComponentType::SECOND ||
+ component.component_type ==
+ DatetimeComponent::ComponentType::MERIDIEM)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+template <class TCalendar>
+bool CalendarLibTempl<TCalendar>::InterpretParseData(
+ const DatetimeParsedData& parse_data, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const std::string& reference_locale,
+ bool prefer_future_for_unspecified_date, TCalendar* calendar,
+ DatetimeGranularity* granularity) const {
+ TC3_CALENDAR_CHECK(calendar->Initialize(reference_timezone, reference_locale,
+ reference_time_ms_utc))
+
+ bool should_round_to_granularity = true;
+ *granularity = GetGranularity(parse_data);
+
+ // Apply each of the parsed fields in order of increasing granularity.
+ static const int64 kMillisInMinute = 1000 * 60;
+ if (parse_data.HasFieldType(DatetimeComponent::ComponentType::ZONE_OFFSET)) {
+ int zone_offset;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::ZONE_OFFSET,
+ &zone_offset);
+ TC3_CALENDAR_CHECK(calendar->SetZoneOffset(zone_offset * kMillisInMinute))
+ }
+ static const int64 kMillisInHour = 1000 * 60 * 60;
+ if (parse_data.HasFieldType(DatetimeComponent::ComponentType::DST_OFFSET)) {
+ int dst_offset;
+ if (parse_data.GetFieldValue(DatetimeComponent::ComponentType::DST_OFFSET,
+ &dst_offset)) {
+ TC3_CALENDAR_CHECK(calendar->SetDstOffset(dst_offset * kMillisInHour))
+ }
+ }
+ std::vector<DatetimeComponent> relative_components;
+ parse_data.GetRelativeDatetimeComponents(&relative_components);
+ if (!relative_components.empty()) {
+ // Currently only one relative date time component is possible.
+ const DatetimeComponent& relative_component = relative_components.back();
+ TC3_CALENDAR_CHECK(ApplyRelationField(relative_component, calendar));
+ should_round_to_granularity = relative_component.ShouldRoundToGranularity();
+ } else {
+ // By default, the parsed time is interpreted to be on the reference day.
+ // But a parsed date should have time 0:00:00 unless specified.
+ TC3_CALENDAR_CHECK(calendar->SetHourOfDay(0))
+ TC3_CALENDAR_CHECK(calendar->SetMinute(0))
+ TC3_CALENDAR_CHECK(calendar->SetSecond(0))
+ TC3_CALENDAR_CHECK(calendar->SetMillisecond(0))
+ }
+ if (parse_data.HasAbsoluteValue(DatetimeComponent::ComponentType::YEAR)) {
+ int year;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::YEAR, &year);
+ TC3_CALENDAR_CHECK(calendar->SetYear(year))
+ }
+ if (parse_data.HasAbsoluteValue(DatetimeComponent::ComponentType::MONTH)) {
+ int month;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::MONTH, &month);
+ // ICU has months starting at 0, Java and Datetime parser at 1, so we
+ // need to subtract 1.
+ TC3_CALENDAR_CHECK(calendar->SetMonth(month - 1))
+ }
+
+ if (parse_data.HasAbsoluteValue(
+ DatetimeComponent::ComponentType::DAY_OF_MONTH)) {
+ int day_of_month;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::DAY_OF_MONTH,
+ &day_of_month);
+ TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(day_of_month))
+ }
+ if (parse_data.HasAbsoluteValue(DatetimeComponent::ComponentType::HOUR)) {
+ int hour;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::HOUR, &hour);
+ if (parse_data.HasFieldType(DatetimeComponent::ComponentType::MERIDIEM)) {
+ int meridiem;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::MERIDIEM,
+ &meridiem);
+ if (meridiem == 1 && hour < 12) {
+ TC3_CALENDAR_CHECK(calendar->SetHourOfDay(hour + 12))
+ } else if (meridiem == 0 && hour == 12) {
+ // Set hour of the day's value to zero (12am == 0:00 in 24 hour format).
+ // Please see issue b/139923083.
+ TC3_CALENDAR_CHECK(calendar->SetHourOfDay(0));
+ } else {
+ TC3_CALENDAR_CHECK(calendar->SetHourOfDay(hour))
+ }
+ } else {
+ TC3_CALENDAR_CHECK(calendar->SetHourOfDay(hour))
+ }
+ }
+ if (parse_data.HasAbsoluteValue(DatetimeComponent::ComponentType::MINUTE)) {
+ int minute;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::MINUTE, &minute);
+ TC3_CALENDAR_CHECK(calendar->SetMinute(minute))
+ }
+ if (parse_data.HasAbsoluteValue(DatetimeComponent::ComponentType::SECOND)) {
+ int second;
+ parse_data.GetFieldValue(DatetimeComponent::ComponentType::SECOND, &second);
+ TC3_CALENDAR_CHECK(calendar->SetSecond(second))
+ }
+ if (should_round_to_granularity) {
+ TC3_CALENDAR_CHECK(RoundToGranularity(*granularity, calendar))
+ }
+
+ int64 calendar_millis;
+ TC3_CALENDAR_CHECK(calendar->GetTimeInMillis(&calendar_millis))
+ if (prefer_future_for_unspecified_date &&
+ calendar_millis < reference_time_ms_utc &&
+ HasOnlyTimeComponents(parse_data)) {
+ calendar->AddDayOfMonth(1);
+ }
+
+ return true;
+}
+
+template <class TCalendar>
+bool CalendarLibTempl<TCalendar>::ApplyRelationField(
+ const DatetimeComponent& relative_date_time_component,
+ TCalendar* calendar) const {
+ switch (relative_date_time_component.relative_qualifier) {
+ case DatetimeComponent::RelativeQualifier::UNSPECIFIED:
+ TC3_LOG(ERROR) << "UNSPECIFIED RelationType.";
+ return false;
+ case DatetimeComponent::RelativeQualifier::NEXT:
+ TC3_CALENDAR_CHECK(AdjustByRelation(relative_date_time_component,
+ /*distance=*/1,
+ /*allow_today=*/false, calendar));
+ return true;
+ case DatetimeComponent::RelativeQualifier::THIS:
+ TC3_CALENDAR_CHECK(AdjustByRelation(relative_date_time_component,
+ /*distance=*/1,
+ /*allow_today=*/true, calendar))
+ return true;
+ case DatetimeComponent::RelativeQualifier::LAST:
+ TC3_CALENDAR_CHECK(AdjustByRelation(relative_date_time_component,
+ /*distance=*/-1,
+ /*allow_today=*/false, calendar))
+ return true;
+ case DatetimeComponent::RelativeQualifier::NOW:
+ return true; // NOOP
+ case DatetimeComponent::RelativeQualifier::TOMORROW:
+ TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(1));
+ return true;
+ case DatetimeComponent::RelativeQualifier::YESTERDAY:
+ TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(-1));
+ return true;
+ case DatetimeComponent::RelativeQualifier::PAST:
+ TC3_CALENDAR_CHECK(
+ AdjustByRelation(relative_date_time_component,
+ relative_date_time_component.relative_count,
+ /*allow_today=*/false, calendar))
+ return true;
+ case DatetimeComponent::RelativeQualifier::FUTURE:
+ TC3_CALENDAR_CHECK(
+ AdjustByRelation(relative_date_time_component,
+ relative_date_time_component.relative_count,
+ /*allow_today=*/false, calendar))
+ return true;
+ }
+ return false;
+}
+
+template <class TCalendar>
+bool CalendarLibTempl<TCalendar>::RoundToGranularity(
+ DatetimeGranularity granularity, TCalendar* calendar) const {
+ // Force recomputation before doing the rounding.
+ int unused;
+ TC3_CALENDAR_CHECK(calendar->GetDayOfWeek(&unused));
+
+ switch (granularity) {
+ case GRANULARITY_YEAR:
+ TC3_CALENDAR_CHECK(calendar->SetMonth(0));
+ TC3_FALLTHROUGH_INTENDED;
+ case GRANULARITY_MONTH:
+ TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(1));
+ TC3_FALLTHROUGH_INTENDED;
+ case GRANULARITY_DAY:
+ TC3_CALENDAR_CHECK(calendar->SetHourOfDay(0));
+ TC3_FALLTHROUGH_INTENDED;
+ case GRANULARITY_HOUR:
+ TC3_CALENDAR_CHECK(calendar->SetMinute(0));
+ TC3_FALLTHROUGH_INTENDED;
+ case GRANULARITY_MINUTE:
+ TC3_CALENDAR_CHECK(calendar->SetSecond(0));
+ break;
+
+ case GRANULARITY_WEEK:
+ int first_day_of_week;
+ TC3_CALENDAR_CHECK(calendar->GetFirstDayOfWeek(&first_day_of_week));
+ TC3_CALENDAR_CHECK(calendar->SetDayOfWeek(first_day_of_week));
+ TC3_CALENDAR_CHECK(calendar->SetHourOfDay(0));
+ TC3_CALENDAR_CHECK(calendar->SetMinute(0));
+ TC3_CALENDAR_CHECK(calendar->SetSecond(0));
+ break;
+
+ case GRANULARITY_UNKNOWN:
+ case GRANULARITY_SECOND:
+ break;
+ }
+ return true;
+}
+
+template <class TCalendar>
+bool CalendarLibTempl<TCalendar>::AdjustByRelation(
+ DatetimeComponent date_time_component, int distance, bool allow_today,
+ TCalendar* calendar) const {
+ const int distance_sign = distance < 0 ? -1 : 1;
+ switch (date_time_component.component_type) {
+ case DatetimeComponent::ComponentType::DAY_OF_WEEK:
+ if (!allow_today) {
+ // If we're not including the same day as the reference, skip it.
+ TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance_sign))
+ }
+ // Keep walking back until we hit the desired day of the week.
+ while (distance != 0) {
+ int day_of_week;
+ TC3_CALENDAR_CHECK(calendar->GetDayOfWeek(&day_of_week))
+ if (day_of_week == (date_time_component.value)) {
+ distance += -distance_sign;
+ if (distance == 0) break;
+ }
+ TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance_sign))
+ }
+ return true;
+ case DatetimeComponent::ComponentType::SECOND:
+ TC3_CALENDAR_CHECK(calendar->AddSecond(distance));
+ return true;
+ case DatetimeComponent::ComponentType::MINUTE:
+ TC3_CALENDAR_CHECK(calendar->AddMinute(distance));
+ return true;
+ case DatetimeComponent::ComponentType::HOUR:
+ TC3_CALENDAR_CHECK(calendar->AddHourOfDay(distance));
+ return true;
+ case DatetimeComponent::ComponentType::DAY_OF_MONTH:
+ TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance));
+ return true;
+ case DatetimeComponent::ComponentType::WEEK:
+ TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(7 * distance))
+ TC3_CALENDAR_CHECK(calendar->SetDayOfWeek(1))
+ return true;
+ case DatetimeComponent::ComponentType::MONTH:
+ TC3_CALENDAR_CHECK(calendar->AddMonth(distance))
+ TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(1))
+ return true;
+ case DatetimeComponent::ComponentType::YEAR:
+ TC3_CALENDAR_CHECK(calendar->AddYear(distance))
+ TC3_CALENDAR_CHECK(calendar->SetDayOfYear(1))
+ return true;
+ default:
+ TC3_LOG(ERROR) << "Unknown relation type: "
+ << static_cast<int>(date_time_component.component_type);
+ return false;
+ }
+ return false;
+}
+
+template <class TCalendar>
+DatetimeGranularity CalendarLibTempl<TCalendar>::GetGranularity(
+ const DatetimeParsedData& data) const {
+ return data.GetFinestGranularity();
+}
+
+} // namespace calendar
+
+#undef TC3_CALENDAR_CHECK
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_COMMON_H_
diff --git a/native/utils/calendar/calendar-javaicu.cc b/native/utils/calendar/calendar-javaicu.cc
new file mode 100644
index 0000000..048df04
--- /dev/null
+++ b/native/utils/calendar/calendar-javaicu.cc
@@ -0,0 +1,218 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/calendar/calendar-javaicu.h"
+
+#include "annotator/types.h"
+#include "utils/base/statusor.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-helper.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Generic version of icu::Calendar::add with error checking.
+bool CalendarAdd(JniCache* jni_cache, JNIEnv* jenv, jobject calendar,
+ jint field, jint value) {
+ return JniHelper::CallVoidMethod(jenv, calendar, jni_cache->calendar_add,
+ field, value)
+ .ok();
+}
+
+// Generic version of icu::Calendar::get with error checking.
+bool CalendarGet(JniCache* jni_cache, JNIEnv* jenv, jobject calendar,
+ jint field, jint* value) {
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ *value,
+ JniHelper::CallIntMethod(jenv, calendar, jni_cache->calendar_get, field));
+ return true;
+}
+
+// Generic version of icu::Calendar::set with error checking.
+bool CalendarSet(JniCache* jni_cache, JNIEnv* jenv, jobject calendar,
+ jint field, jint value) {
+ return JniHelper::CallVoidMethod(jenv, calendar, jni_cache->calendar_set,
+ field, value)
+ .ok();
+}
+
+// Extracts the first tag from a BCP47 tag (e.g. "en" for "en-US").
+std::string GetFirstBcp47Tag(const std::string& tag) {
+ for (size_t i = 0; i < tag.size(); ++i) {
+ if (tag[i] == '_' || tag[i] == '-') {
+ return std::string(tag, 0, i);
+ }
+ }
+ return tag;
+}
+
+} // anonymous namespace
+
+Calendar::Calendar(JniCache* jni_cache)
+ : jni_cache_(jni_cache),
+ jenv_(jni_cache_ ? jni_cache->GetEnv() : nullptr),
+ calendar_(nullptr, jenv_) {}
+
+bool Calendar::Initialize(const std::string& time_zone,
+ const std::string& locale, int64 time_ms_utc) {
+ if (!jni_cache_ || !jenv_) {
+ TC3_LOG(ERROR) << "Initialize without env";
+ return false;
+ }
+
+ // We'll assume the day indices match later on, so verify it here.
+ if (jni_cache_->calendar_sunday != kSunday ||
+ jni_cache_->calendar_monday != kMonday ||
+ jni_cache_->calendar_tuesday != kTuesday ||
+ jni_cache_->calendar_wednesday != kWednesday ||
+ jni_cache_->calendar_thursday != kThursday ||
+ jni_cache_->calendar_friday != kFriday ||
+ jni_cache_->calendar_saturday != kSaturday) {
+ TC3_LOG(ERROR) << "day of the week indices mismatch";
+ return false;
+ }
+
+ // Get the time zone.
+ TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> java_time_zone_str,
+ JniHelper::NewStringUTF(jenv_, time_zone.c_str()));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ ScopedLocalRef<jobject> java_time_zone,
+ JniHelper::CallStaticObjectMethod(jenv_, jni_cache_->timezone_class.get(),
+ jni_cache_->timezone_get_timezone,
+ java_time_zone_str.get()));
+ if (java_time_zone == nullptr) {
+ TC3_LOG(ERROR) << "failed to get timezone";
+ return false;
+ }
+
+ // Get the locale.
+ ScopedLocalRef<jobject> java_locale(nullptr, jenv_);
+ if (jni_cache_->locale_for_language_tag) {
+ // API level 21+, we can actually parse language tags.
+ TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> java_locale_str,
+ JniHelper::NewStringUTF(jenv_, locale.c_str()));
+
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ java_locale,
+ JniHelper::CallStaticObjectMethod(jenv_, jni_cache_->locale_class.get(),
+ jni_cache_->locale_for_language_tag,
+ java_locale_str.get()));
+ } else {
+ // API level <21. We can't parse tags, so we just use the language.
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ ScopedLocalRef<jstring> java_language_str,
+ JniHelper::NewStringUTF(jenv_, GetFirstBcp47Tag(locale).c_str()));
+
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ java_locale, JniHelper::NewObject(jenv_, jni_cache_->locale_class.get(),
+ jni_cache_->locale_init_string,
+ java_language_str.get()));
+ }
+ if (java_locale == nullptr) {
+ TC3_LOG(ERROR) << "failed to get locale";
+ return false;
+ }
+
+ // Get the calendar.
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ calendar_, JniHelper::CallStaticObjectMethod(
+ jenv_, jni_cache_->calendar_class.get(),
+ jni_cache_->calendar_get_instance, java_time_zone.get(),
+ java_locale.get()));
+ if (calendar_ == nullptr) {
+ TC3_LOG(ERROR) << "failed to get calendar";
+ return false;
+ }
+
+ // Set the time.
+ if (!JniHelper::CallVoidMethod(jenv_, calendar_.get(),
+ jni_cache_->calendar_set_time_in_millis,
+ time_ms_utc)
+ .ok()) {
+ TC3_LOG(ERROR) << "failed to set time";
+ return false;
+ }
+ return true;
+}
+
+bool Calendar::GetFirstDayOfWeek(int* value) const {
+ if (!jni_cache_ || !jenv_ || !calendar_) return false;
+
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ *value,
+ JniHelper::CallIntMethod(jenv_, calendar_.get(),
+ jni_cache_->calendar_get_first_day_of_week));
+ return true;
+}
+
+bool Calendar::GetTimeInMillis(int64* value) const {
+ if (!jni_cache_ || !jenv_ || !calendar_) return false;
+
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ *value,
+ JniHelper::CallLongMethod(jenv_, calendar_.get(),
+ jni_cache_->calendar_get_time_in_millis));
+
+ return true;
+}
+
+CalendarLib::CalendarLib() {
+ TC3_LOG(FATAL) << "Java ICU CalendarLib must be initialized with a JniCache.";
+}
+
+CalendarLib::CalendarLib(const std::shared_ptr<JniCache>& jni_cache)
+ : jni_cache_(jni_cache) {}
+
+// Below is the boilerplate code for implementing the specialisations of
+// get/set/add for the various field types.
+#define TC3_DEFINE_FIELD_ACCESSOR(NAME, FIELD, KIND, TYPE) \
+ bool Calendar::KIND##NAME(TYPE value) const { \
+ if (!jni_cache_ || !jenv_ || !calendar_) return false; \
+ return Calendar##KIND(jni_cache_, jenv_, calendar_.get(), \
+ jni_cache_->calendar_##FIELD, value); \
+ }
+#define TC3_DEFINE_ADD(NAME, CONST) \
+ TC3_DEFINE_FIELD_ACCESSOR(NAME, CONST, Add, int)
+#define TC3_DEFINE_SET(NAME, CONST) \
+ TC3_DEFINE_FIELD_ACCESSOR(NAME, CONST, Set, int)
+#define TC3_DEFINE_GET(NAME, CONST) \
+ TC3_DEFINE_FIELD_ACCESSOR(NAME, CONST, Get, int*)
+
+TC3_DEFINE_ADD(Second, second)
+TC3_DEFINE_ADD(Minute, minute)
+TC3_DEFINE_ADD(HourOfDay, hour_of_day)
+TC3_DEFINE_ADD(DayOfMonth, day_of_month)
+TC3_DEFINE_ADD(Year, year)
+TC3_DEFINE_ADD(Month, month)
+TC3_DEFINE_GET(DayOfWeek, day_of_week)
+TC3_DEFINE_SET(ZoneOffset, zone_offset)
+TC3_DEFINE_SET(DstOffset, dst_offset)
+TC3_DEFINE_SET(Year, year)
+TC3_DEFINE_SET(Month, month)
+TC3_DEFINE_SET(DayOfYear, day_of_year)
+TC3_DEFINE_SET(DayOfMonth, day_of_month)
+TC3_DEFINE_SET(DayOfWeek, day_of_week)
+TC3_DEFINE_SET(HourOfDay, hour_of_day)
+TC3_DEFINE_SET(Minute, minute)
+TC3_DEFINE_SET(Second, second)
+TC3_DEFINE_SET(Millisecond, millisecond)
+
+#undef TC3_DEFINE_FIELD_ACCESSOR
+#undef TC3_DEFINE_ADD
+#undef TC3_DEFINE_SET
+#undef TC3_DEFINE_GET
+
+} // namespace libtextclassifier3
diff --git a/native/utils/calendar/calendar-javaicu.h b/native/utils/calendar/calendar-javaicu.h
new file mode 100644
index 0000000..d6e1716
--- /dev/null
+++ b/native/utils/calendar/calendar-javaicu.h
@@ -0,0 +1,99 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_JAVAICU_H_
+#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_JAVAICU_H_
+
+#include <jni.h>
+
+#include <memory>
+#include <string>
+
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/calendar/calendar-common.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-cache.h"
+
+namespace libtextclassifier3 {
+
+class Calendar {
+ public:
+ explicit Calendar(JniCache* jni_cache);
+ bool Initialize(const std::string& time_zone, const std::string& locale,
+ int64 time_ms_utc);
+ bool AddSecond(int value) const;
+ bool AddMinute(int value) const;
+ bool AddHourOfDay(int value) const;
+ bool AddDayOfMonth(int value) const;
+ bool AddYear(int value) const;
+ bool AddMonth(int value) const;
+ bool GetDayOfWeek(int* value) const;
+ bool GetFirstDayOfWeek(int* value) const;
+ bool GetTimeInMillis(int64* value) const;
+ bool SetZoneOffset(int value) const;
+ bool SetDstOffset(int value) const;
+ bool SetYear(int value) const;
+ bool SetMonth(int value) const;
+ bool SetDayOfYear(int value) const;
+ bool SetDayOfMonth(int value) const;
+ bool SetDayOfWeek(int value) const;
+ bool SetHourOfDay(int value) const;
+ bool SetMinute(int value) const;
+ bool SetSecond(int value) const;
+ bool SetMillisecond(int value) const;
+
+ private:
+ JniCache* jni_cache_;
+ JNIEnv* jenv_;
+ ScopedLocalRef<jobject> calendar_;
+};
+
+class CalendarLib {
+ public:
+ CalendarLib();
+ explicit CalendarLib(const std::shared_ptr<JniCache>& jni_cache);
+
+ // Returns false (dummy version).
+ bool InterpretParseData(const DatetimeParsedData& parse_data,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale,
+ bool prefer_future_for_unspecified_date,
+ int64* interpreted_time_ms_utc,
+ DatetimeGranularity* granularity) const {
+ Calendar calendar(jni_cache_.get());
+ if (!impl_.InterpretParseData(parse_data, reference_time_ms_utc,
+ reference_timezone, reference_locale,
+ prefer_future_for_unspecified_date, &calendar,
+ granularity)) {
+ return false;
+ }
+ return calendar.GetTimeInMillis(interpreted_time_ms_utc);
+ }
+
+ DatetimeGranularity GetGranularity(const DatetimeParsedData& data) const {
+ return impl_.GetGranularity(data);
+ }
+
+ private:
+ std::shared_ptr<JniCache> jni_cache_;
+ calendar::CalendarLibTempl<Calendar> impl_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_JAVAICU_H_
diff --git a/native/utils/calendar/calendar.h b/native/utils/calendar/calendar.h
new file mode 100644
index 0000000..a018a53
--- /dev/null
+++ b/native/utils/calendar/calendar.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_H_
+#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_H_
+
+#if defined TC3_CALENDAR_ICU
+#include "utils/calendar/calendar-icu.h"
+#define INIT_CALENDARLIB_FOR_TESTING(VAR) VAR()
+#elif defined TC3_CALENDAR_DUMMY
+#include "utils/calendar/calendar-dummy.h"
+#define INIT_CALENDARLIB_FOR_TESTING(VAR) VAR()
+#elif defined TC3_CALENDAR_APPLE
+#include "utils/calendar/calendar-apple.h"
+#define INIT_CALENDARLIB_FOR_TESTING(VAR) VAR()
+#elif defined TC3_CALENDAR_JAVAICU
+#include "utils/calendar/calendar-javaicu.h"
+#define INIT_CALENDARLIB_FOR_TESTING(VAR) VAR(nullptr)
+#else
+#error No TC3_CALENDAR implementation specified.
+#endif
+
+#endif // LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_H_
diff --git a/utils/checksum.cc b/native/utils/checksum.cc
similarity index 100%
rename from utils/checksum.cc
rename to native/utils/checksum.cc
diff --git a/utils/checksum.h b/native/utils/checksum.h
similarity index 100%
rename from utils/checksum.h
rename to native/utils/checksum.h
diff --git a/utils/checksum_test.cc b/native/utils/checksum_test.cc
similarity index 100%
rename from utils/checksum_test.cc
rename to native/utils/checksum_test.cc
diff --git a/utils/codepoint-range.cc b/native/utils/codepoint-range.cc
similarity index 100%
rename from utils/codepoint-range.cc
rename to native/utils/codepoint-range.cc
diff --git a/utils/codepoint-range.fbs b/native/utils/codepoint-range.fbs
similarity index 100%
rename from utils/codepoint-range.fbs
rename to native/utils/codepoint-range.fbs
diff --git a/utils/codepoint-range.h b/native/utils/codepoint-range.h
similarity index 100%
rename from utils/codepoint-range.h
rename to native/utils/codepoint-range.h
diff --git a/native/utils/container/double-array-trie.cc b/native/utils/container/double-array-trie.cc
new file mode 100644
index 0000000..f06bf92
--- /dev/null
+++ b/native/utils/container/double-array-trie.cc
@@ -0,0 +1,69 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/container/double-array-trie.h"
+
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+bool DoubleArrayTrie::GatherPrefixMatches(
+ StringPiece input, const std::function<void(Match)>& update_fn) const {
+ uint32 pos = 0;
+ if (nodes_length_ == 0) {
+ TC3_LOG(WARNING) << "Trie is empty. Skipping.";
+ return true;
+ }
+ pos = offset(0);
+ for (int i = 0; i < input.size(); i++) {
+ if (input[i] == 0) {
+ break;
+ }
+ pos ^= static_cast<unsigned char>(input[i]);
+ // We exhausted the trie, no more matches possible.
+ if (pos < 0 || pos >= nodes_length_) {
+ break;
+ }
+ if (label(pos) != input[i]) {
+ break;
+ }
+ const bool node_has_leaf = has_leaf(pos);
+ pos ^= offset(pos);
+ if (pos < 0 || pos > nodes_length_) {
+ TC3_LOG(ERROR) << "Out-of-bounds trie search position.";
+ return false;
+ }
+ if (node_has_leaf) {
+ update_fn(Match(/*id=*/value(pos), /*match_length=*/i + 1));
+ }
+ }
+ return true;
+}
+
+bool DoubleArrayTrie::FindAllPrefixMatches(StringPiece input,
+ std::vector<Match>* matches) const {
+ return GatherPrefixMatches(
+ input, [matches](const Match match) { matches->push_back(match); });
+}
+
+bool DoubleArrayTrie::LongestPrefixMatch(StringPiece input,
+ Match* longest_match) const {
+ *longest_match = Match();
+ return GatherPrefixMatches(
+ input, [longest_match](const Match match) { *longest_match = match; });
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/container/double-array-trie.h b/native/utils/container/double-array-trie.h
new file mode 100644
index 0000000..39c8822
--- /dev/null
+++ b/native/utils/container/double-array-trie.h
@@ -0,0 +1,85 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_CONTAINER_DOUBLE_ARRAY_TRIE_H_
+#define LIBTEXTCLASSIFIER_UTILS_CONTAINER_DOUBLE_ARRAY_TRIE_H_
+
+#include <functional>
+#include <vector>
+
+#include "utils/base/endian.h"
+#include "utils/base/integral_types.h"
+#include "utils/container/string-set.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// A trie node specifies a node in the tree, either an intermediate node or
+// a leaf node.
+// A leaf node contains the id as an int of the string match. This id is encoded
+// in the lower 31 bits, thus the number of distinct ids is 2^31.
+// An intermediate node has an associated label and an offset to it's children.
+// The label is encoded in the least significant byte and must match the input
+// character during matching.
+// We account for endianness when using the node values, as they are serialized
+// (in little endian) as bytes in the flatbuffer model.
+typedef uint32 TrieNode;
+
+// A memory mappable trie, compatible with Darts::DoubleArray.
+class DoubleArrayTrie : public StringSet {
+ public:
+ // nodes and nodes_length specify the array of the nodes of the trie.
+ DoubleArrayTrie(const TrieNode* nodes, const int nodes_length)
+ : nodes_(nodes), nodes_length_(nodes_length) {}
+
+ // Find matches that are prefixes of a string.
+ bool FindAllPrefixMatches(StringPiece input,
+ std::vector<Match>* matches) const override;
+ // Find the longest prefix match of a string.
+ bool LongestPrefixMatch(StringPiece input,
+ Match* longest_match) const override;
+
+ private:
+ // Returns whether a node as a leaf as a child.
+ bool has_leaf(uint32 i) const { return nodes_[i] & 0x100; }
+
+ // Available when a node is a leaf.
+ int value(uint32 i) const {
+ return static_cast<int>(LittleEndian::ToHost32(nodes_[i]) & 0x7fffffff);
+ }
+
+ // Label associated with a node.
+ // A leaf node will have the MSB set and thus return an invalid label.
+ uint32 label(uint32 i) const {
+ return LittleEndian::ToHost32(nodes_[i]) & 0x800000ff;
+ }
+
+ // Returns offset to children.
+ uint32 offset(uint32 i) const {
+ const uint32 node = LittleEndian::ToHost32(nodes_[i]);
+ return (node >> 10) << ((node & 0x200) >> 6);
+ }
+
+ bool GatherPrefixMatches(StringPiece input,
+ const std::function<void(Match)>& update_fn) const;
+
+ const TrieNode* nodes_;
+ const int nodes_length_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_CONTAINER_DOUBLE_ARRAY_TRIE_H_
diff --git a/native/utils/container/sorted-strings-table.cc b/native/utils/container/sorted-strings-table.cc
new file mode 100644
index 0000000..f39d976
--- /dev/null
+++ b/native/utils/container/sorted-strings-table.cc
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/container/sorted-strings-table.h"
+
+#include <algorithm>
+
+#include "utils/base/endian.h"
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+void SortedStringsTable::GatherPrefixMatches(
+ StringPiece input, const std::function<void(Match)>& update_fn) const {
+ int left = 0;
+ int right = num_pieces_;
+ int span_size = right - left;
+ int match_length = 0;
+
+ // Loop invariant:
+ // at the ith iteration, all strings from `left` ... `right` match the input
+ // on the first `match_length` characters.
+ while (span_size > use_linear_scan_threshold_) {
+ if (match_length >= input.length()) {
+ return;
+ }
+
+ // We find the possible range of pieces in `left` ... `right` matching the
+ // `match_length` + 1 character with two binary searches:
+ // `lower_bound` to find the start of the range of matching pieces.
+ // `upper_bound` to find the non-inclusive end of the range.
+ left = (std::lower_bound(
+ offsets_ + left, offsets_ + right,
+ static_cast<unsigned char>(input[match_length]),
+ [this, match_length](uint32 piece_offset, uint32 c) -> bool {
+ return static_cast<unsigned char>(
+ pieces_[piece_offset + match_length]) <
+ LittleEndian::ToHost32(c);
+ }) -
+ offsets_);
+ right = (std::upper_bound(
+ offsets_ + left, offsets_ + right,
+ static_cast<unsigned char>(input[match_length]),
+ [this, match_length](uint32 c, uint32 piece_offset) -> bool {
+ return LittleEndian::ToHost32(c) <
+ static_cast<unsigned char>(
+ pieces_[piece_offset + match_length]);
+ }) -
+ offsets_);
+ span_size = right - left;
+ if (span_size <= 0) {
+ return;
+ }
+ ++match_length;
+
+ // Due to the loop invariant and the fact that the strings are sorted, there
+ // can only be one piece matching completely now, namely at left.
+ if (pieces_[LittleEndian::ToHost32(offsets_[left]) + match_length] == 0) {
+ update_fn(Match(/*id=*/left,
+ /*match_length=*/match_length));
+ left++;
+ }
+ }
+
+ // Use linear scan for small problem instances.
+ // By the loop invariant characters 0...`match_length` of all pieces in
+ // in `left`...`right` match the input on 0...`match_length`.
+ for (int i = left; i < right; i++) {
+ bool matches = true;
+ int piece_match_length = match_length;
+ for (int k = LittleEndian::ToHost32(offsets_[i]) + piece_match_length;
+ pieces_[k] != 0; k++) {
+ if (piece_match_length >= input.size() ||
+ input[piece_match_length] != pieces_[k]) {
+ matches = false;
+ break;
+ }
+ piece_match_length++;
+ }
+ if (matches) {
+ update_fn(Match(/*id=*/i,
+ /*match_length=*/piece_match_length));
+ }
+ }
+}
+
+bool SortedStringsTable::FindAllPrefixMatches(
+ StringPiece input, std::vector<Match>* matches) const {
+ GatherPrefixMatches(
+ input, [matches](const Match match) { matches->push_back(match); });
+ return true;
+}
+
+bool SortedStringsTable::LongestPrefixMatch(StringPiece input,
+ Match* longest_match) const {
+ *longest_match = Match();
+ GatherPrefixMatches(
+ input, [longest_match](const Match match) { *longest_match = match; });
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/container/sorted-strings-table.h b/native/utils/container/sorted-strings-table.h
new file mode 100644
index 0000000..7f8a4a2
--- /dev/null
+++ b/native/utils/container/sorted-strings-table.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_CONTAINER_SORTED_STRINGS_TABLE_H_
+#define LIBTEXTCLASSIFIER_UTILS_CONTAINER_SORTED_STRINGS_TABLE_H_
+
+#include <functional>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "utils/container/string-set.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// A matcher to find string pieces matching prefixes of an input string.
+// The list of reference strings are kept in sorted order in a zero separated
+// string.
+// binary search is used to find all prefix matches.
+// num_pieces: Number of sentence pieces.
+// offsets: Offsets into `pieces` where a string starts.
+// pieces: String pieces, concatenated in sorted order and zero byte separated.
+// use_linear_scan_threshold: Minimum size of binary search range before
+// switching to a linear sweep for prefix match testing.
+class SortedStringsTable : public StringSet {
+ public:
+ SortedStringsTable(const int num_pieces, const uint32* offsets,
+ StringPiece pieces,
+ const int use_linear_scan_threshold = 10)
+ : num_pieces_(num_pieces),
+ offsets_(offsets),
+ pieces_(pieces),
+ use_linear_scan_threshold_(use_linear_scan_threshold) {}
+
+ // Find matches that are prefixes of a string.
+ bool FindAllPrefixMatches(StringPiece input,
+ std::vector<Match>* matches) const override;
+ // Find the longest prefix match of a string.
+ bool LongestPrefixMatch(StringPiece input,
+ Match* longest_match) const override;
+
+ private:
+ void GatherPrefixMatches(StringPiece input,
+ const std::function<void(Match)>& update_fn) const;
+
+ const int num_pieces_;
+ const uint32* offsets_;
+ const StringPiece pieces_;
+ const int use_linear_scan_threshold_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_CONTAINER_SORTED_STRINGS_TABLE_H_
diff --git a/native/utils/container/sorted-strings-table_test.cc b/native/utils/container/sorted-strings-table_test.cc
new file mode 100644
index 0000000..a93b197
--- /dev/null
+++ b/native/utils/container/sorted-strings-table_test.cc
@@ -0,0 +1,131 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/container/sorted-strings-table.h"
+
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(SortedStringsTest, Lookup) {
+ const char pieces[] = "hell\0hello\0o\0there\0";
+ const uint32 offsets[] = {0, 5, 11, 13};
+
+ SortedStringsTable table(/*num_pieces=*/4, offsets, StringPiece(pieces, 18),
+ /*use_linear_scan_threshold=*/1);
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches("hello there", &matches));
+ EXPECT_EQ(matches.size(), 2);
+ EXPECT_EQ(matches[0].id, 0 /*hell*/);
+ EXPECT_EQ(matches[0].match_length, 4 /*hell*/);
+ EXPECT_EQ(matches[1].id, 1 /*hello*/);
+ EXPECT_EQ(matches[1].match_length, 5 /*hello*/);
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches("he", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches("he", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches("abcd", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches("", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches("hi there", &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(table.FindAllPrefixMatches(StringPiece("\0", 1), &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ std::vector<StringSet::Match> matches;
+ EXPECT_TRUE(
+ table.FindAllPrefixMatches(StringPiece("\xff, \xfe", 2), &matches));
+ EXPECT_THAT(matches, testing::IsEmpty());
+ }
+
+ {
+ StringSet::Match match;
+ EXPECT_TRUE(table.LongestPrefixMatch("hella there", &match));
+ EXPECT_EQ(match.id, 0 /*hell*/);
+ }
+
+ {
+ StringSet::Match match;
+ EXPECT_TRUE(table.LongestPrefixMatch("hello there", &match));
+ EXPECT_EQ(match.id, 1 /*hello*/);
+ }
+
+ {
+ StringSet::Match match;
+ EXPECT_TRUE(table.LongestPrefixMatch("abcd", &match));
+ EXPECT_EQ(match.id, -1);
+ }
+
+ {
+ StringSet::Match match;
+ EXPECT_TRUE(table.LongestPrefixMatch("", &match));
+ EXPECT_EQ(match.id, -1);
+ }
+
+ {
+ int value;
+ EXPECT_TRUE(table.Find("hell", &value));
+ EXPECT_EQ(value, 0);
+ }
+
+ {
+ int value;
+ EXPECT_FALSE(table.Find("hella", &value));
+ }
+
+ {
+ int value;
+ EXPECT_TRUE(table.Find("hello", &value));
+ EXPECT_EQ(value, 1 /*hello*/);
+ }
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/container/string-set.h b/native/utils/container/string-set.h
new file mode 100644
index 0000000..619b6bd
--- /dev/null
+++ b/native/utils/container/string-set.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_CONTAINER_STRING_SET_H_
+#define LIBTEXTCLASSIFIER_UTILS_CONTAINER_STRING_SET_H_
+
+#include <vector>
+
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+class StringSet {
+ public:
+ struct Match {
+ Match() {}
+ Match(int id, int match_length) : id(id), match_length(match_length) {}
+ int id = -1;
+ int match_length = -1;
+ };
+
+ virtual ~StringSet() {}
+
+ // Find matches that are prefixes of a string.
+ virtual bool FindAllPrefixMatches(StringPiece input,
+ std::vector<Match>* matches) const = 0;
+
+ // Find the longest prefix match of a string.
+ virtual bool LongestPrefixMatch(StringPiece input,
+ Match* longest_match) const = 0;
+
+ // Finds an exact string match.
+ virtual bool Find(StringPiece input, int* value) const {
+ Match match;
+ if (LongestPrefixMatch(input, &match) &&
+ match.match_length == input.length()) {
+ *value = match.id;
+ return true;
+ }
+ return false;
+ }
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_CONTAINER_STRING_SET_H_
diff --git a/native/utils/flatbuffers.cc b/native/utils/flatbuffers.cc
new file mode 100644
index 0000000..cf4c97f
--- /dev/null
+++ b/native/utils/flatbuffers.cc
@@ -0,0 +1,709 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/flatbuffers.h"
+
+#include <vector>
+
+#include "utils/strings/numbers.h"
+#include "utils/variant.h"
+#include "flatbuffers/reflection_generated.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Gets the field information for a field name, returns nullptr if the
+// field was not defined.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+                                        const StringPiece field_name) {
+  TC3_CHECK(type != nullptr && type->fields() != nullptr);
+  // NOTE(review): LookupByKey takes a C string; this relies on
+  // `field_name.data()` being null-terminated — confirm for all callers.
+  return type->fields()->LookupByKey(field_name.data());
+}
+
+// Same as above, but identifies the field by its vtable offset.
+// Linear scan: fields are sorted by name, not by offset.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+                                        const int field_offset) {
+  if (type->fields() == nullptr) {
+    return nullptr;
+  }
+  for (const reflection::Field* field : *type->fields()) {
+    if (field->offset() == field_offset) {
+      return field;
+    }
+  }
+  return nullptr;
+}
+
+// Resolves a field by name when one is given, otherwise by offset.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+                                        const StringPiece field_name,
+                                        const int field_offset) {
+  // Lookup by name might be faster as the fields are sorted by name in the
+  // schema data, so try that first.
+  if (!field_name.empty()) {
+    return GetFieldOrNull(type, field_name.data());
+  }
+  return GetFieldOrNull(type, field_offset);
+}
+
+// Resolves a field from a serialized FlatbufferField descriptor.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+                                        const FlatbufferField* field) {
+  TC3_CHECK(type != nullptr && field != nullptr);
+  if (field->field_name() == nullptr) {
+    return GetFieldOrNull(type, field->field_offset());
+  }
+  return GetFieldOrNull(
+      type,
+      StringPiece(field->field_name()->data(), field->field_name()->size()),
+      field->field_offset());
+}
+
+// Resolves a field from the object-API (mutable) field descriptor.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+                                        const FlatbufferFieldT* field) {
+  TC3_CHECK(type != nullptr && field != nullptr);
+  return GetFieldOrNull(type, field->field_name, field->field_offset);
+}
+
+// Parse overloads: convert a string representation to the target scalar
+// type. Each returns false when the string cannot be parsed.
+bool Parse(const std::string& str_value, float* value) {
+  double double_value;
+  if (!ParseDouble(str_value.data(), &double_value)) {
+    return false;
+  }
+  // Parsed as double and narrowed — there is no dedicated float parser.
+  *value = static_cast<float>(double_value);
+  return true;
+}
+
+bool Parse(const std::string& str_value, double* value) {
+  return ParseDouble(str_value.data(), value);
+}
+
+bool Parse(const std::string& str_value, int64* value) {
+  return ParseInt64(str_value.data(), value);
+}
+
+bool Parse(const std::string& str_value, int32* value) {
+  return ParseInt32(str_value.data(), value);
+}
+
+// Strings pass through unchanged; always succeeds.
+bool Parse(const std::string& str_value, std::string* value) {
+  *value = str_value;
+  return true;
+}
+
+// Parses `str_value` as T and stores it into `buffer`: appended to the
+// repeated field for vector-typed fields, set as the scalar value otherwise.
+template <typename T>
+bool ParseAndSetField(const reflection::Field* field,
+                      const std::string& str_value,
+                      ReflectiveFlatbuffer* buffer) {
+  T value;
+  if (!Parse(str_value, &value)) {
+    TC3_LOG(ERROR) << "Could not parse '" << str_value << "'";
+    return false;
+  }
+  if (field->type()->base_type() == reflection::Vector) {
+    buffer->Repeated(field)->Add(value);
+    return true;
+  } else {
+    return buffer->Set<T>(field, value);
+  }
+}
+
+} // namespace
+
+// The Model flatbuffer carries a file identifier; the generic template in
+// the header returns nullptr for all other message types.
+template <>
+const char* FlatbufferFileIdentifier<Model>() {
+  return ModelIdentifier();
+}
+
+// Creates an empty message of the schema's root table type, or nullptr if
+// the schema declares no root table.
+std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewRoot()
+    const {
+  const reflection::Object* root_type = schema_->root_table();
+  if (root_type == nullptr) {
+    TC3_LOG(ERROR) << "No root table specified.";
+    return nullptr;
+  }
+  return std::unique_ptr<ReflectiveFlatbuffer>(
+      new ReflectiveFlatbuffer(schema_, root_type));
+}
+
+// Creates an empty message for the table named `table_name`, or nullptr if
+// the schema contains no such table. Linear scan over all schema objects.
+std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewTable(
+    StringPiece table_name) const {
+  for (const reflection::Object* object : *schema_->objects()) {
+    if (table_name.Equals(object->name()->str())) {
+      return std::unique_ptr<ReflectiveFlatbuffer>(
+          new ReflectiveFlatbuffer(schema_, object));
+    }
+  }
+  return nullptr;
+}
+
+// Thin wrappers around the free-function helpers, bound to this message's
+// reflection type.
+const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
+    const StringPiece field_name) const {
+  return libtextclassifier3::GetFieldOrNull(type_, field_name);
+}
+
+const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
+    const FlatbufferField* field) const {
+  return libtextclassifier3::GetFieldOrNull(type_, field);
+}
+
+// Walks a (possibly nested) field path. On success, `*parent` is the
+// message that directly owns the final field and `*field` is that field.
+// Intermediate path components must be sub-message (`Obj`) fields —
+// Mutable() returns nullptr otherwise and the walk fails.
+bool ReflectiveFlatbuffer::GetFieldWithParent(
+    const FlatbufferFieldPath* field_path, ReflectiveFlatbuffer** parent,
+    reflection::Field const** field) {
+  const auto* path = field_path->field();
+  if (path == nullptr || path->size() == 0) {
+    return false;
+  }
+
+  for (int i = 0; i < path->size(); i++) {
+    // For i == 0 the walk starts at this message; for i > 0, `*field` holds
+    // the previous component, which is descended into.
+    *parent = (i == 0 ? this : (*parent)->Mutable(*field));
+    if (*parent == nullptr) {
+      return false;
+    }
+    *field = (*parent)->GetFieldOrNull(path->Get(i));
+    if (*field == nullptr) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+// Offset-based lookup wrapper; linear over the type's fields (see helper).
+const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
+    const int field_offset) const {
+  return libtextclassifier3::GetFieldOrNull(type_, field_offset);
+}
+
+// Parses `value` according to the field's (element) type and sets/appends
+// it. For vector fields the dispatch uses the element type so the parsed
+// value is appended rather than set.
+bool ReflectiveFlatbuffer::ParseAndSet(const reflection::Field* field,
+                                       const std::string& value) {
+  switch (field->type()->base_type() == reflection::Vector
+              ? field->type()->element()
+              : field->type()->base_type()) {
+    case reflection::String:
+      return ParseAndSetField<std::string>(field, value, this);
+    case reflection::Int:
+      return ParseAndSetField<int32>(field, value, this);
+    case reflection::Long:
+      return ParseAndSetField<int64>(field, value, this);
+    case reflection::Float:
+      return ParseAndSetField<float>(field, value, this);
+    case reflection::Double:
+      return ParseAndSetField<double>(field, value, this);
+    default:
+      TC3_LOG(ERROR) << "Unhandled field type: " << field->type()->base_type();
+      return false;
+  }
+}
+
+// Resolves the (possibly nested) field first, then delegates to the
+// field-based overload on the message that actually owns the field.
+bool ReflectiveFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path,
+                                       const std::string& value) {
+  ReflectiveFlatbuffer* target_message = nullptr;
+  const reflection::Field* target_field = nullptr;
+  if (GetFieldWithParent(path, &target_message, &target_field)) {
+    return target_message->ParseAndSet(target_field, value);
+  }
+  return false;
+}
+
+// Appends a new sub-message to the named repeated field. Returns nullptr
+// when the field is unknown or not a vector.
+ReflectiveFlatbuffer* ReflectiveFlatbuffer::Add(StringPiece field_name) {
+  const reflection::Field* field = GetFieldOrNull(field_name);
+  if (field == nullptr ||
+      field->type()->base_type() != reflection::BaseType::Vector) {
+    return nullptr;
+  }
+  return Add(field);
+}
+
+// Appends a new sub-message to the repeated field and returns it.
+// NOTE(review): assumes `field` is a vector-typed field — Repeated() returns
+// nullptr otherwise and would be dereferenced here; confirm callers.
+ReflectiveFlatbuffer* ReflectiveFlatbuffer::Add(
+    const reflection::Field* field) {
+  if (field == nullptr) {
+    return nullptr;
+  }
+  return Repeated(field)->Add();
+}
+
+// Looks up the named field and delegates to the field-based Mutable();
+// logs and returns nullptr for an unknown field.
+ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
+    const StringPiece field_name) {
+  const reflection::Field* field = GetFieldOrNull(field_name);
+  if (field == nullptr) {
+    TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
+    return nullptr;
+  }
+  return Mutable(field);
+}
+
+// Returns the sub-message for an `Obj` field, creating and caching an empty
+// one on first access.
+ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
+    const reflection::Field* field) {
+  if (field->type()->base_type() != reflection::Obj) {
+    TC3_LOG(ERROR) << "Field is not of type Object.";
+    return nullptr;
+  }
+  // Return the cached child if this sub-message was created before.
+  const auto entry = children_.find(field);
+  if (entry != children_.end()) {
+    return entry->second.get();
+  }
+  // Otherwise create a new empty message of the field's table type, using
+  // the failed find() position as an insertion hint.
+  const auto it = children_.insert(
+      /*hint=*/entry,
+      std::make_pair(
+          field,
+          std::unique_ptr<ReflectiveFlatbuffer>(new ReflectiveFlatbuffer(
+              schema_, schema_->objects()->Get(field->type()->index())))));
+  return it->second.get();
+}
+
+// Looks up the named field and delegates to the field-based Repeated();
+// logs and returns nullptr for an unknown field.
+RepeatedField* ReflectiveFlatbuffer::Repeated(StringPiece field_name) {
+  if (const reflection::Field* field = GetFieldOrNull(field_name)) {
+    return Repeated(field);
+  }
+  TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
+  return nullptr;
+}
+
+// Returns the repeated-field wrapper for a `Vector` field, creating and
+// caching it on first access.
+RepeatedField* ReflectiveFlatbuffer::Repeated(const reflection::Field* field) {
+  if (field->type()->base_type() != reflection::Vector) {
+    TC3_LOG(ERROR) << "Field is not of type Vector.";
+    return nullptr;
+  }
+
+  // If the repeated field was already set, return its instance.
+  const auto entry = repeated_fields_.find(field);
+  if (entry != repeated_fields_.end()) {
+    return entry->second.get();
+  }
+
+  // Otherwise, create a new instance and store it, using the failed find()
+  // position as an insertion hint.
+  std::unique_ptr<RepeatedField> repeated_field(
+      new RepeatedField(schema_, field));
+  const auto it = repeated_fields_.insert(
+      /*hint=*/entry, std::make_pair(field, std::move(repeated_field)));
+  return it->second.get();
+}
+
+// Serializes this message into `builder` and returns the table's offset.
+// FlatBuffers requires referenced objects (sub-tables, strings, vectors) to
+// be written before the table that points at them, so their offsets are
+// collected first, then the table itself is emitted.
+flatbuffers::uoffset_t ReflectiveFlatbuffer::Serialize(
+    flatbuffers::FlatBufferBuilder* builder) const {
+  // Build all children before we can start with this table.
+  std::vector<
+      std::pair</* field vtable offset */ int,
+                /* field data offset in buffer */ flatbuffers::uoffset_t>>
+      offsets;
+  offsets.reserve(children_.size() + repeated_fields_.size());
+  for (const auto& it : children_) {
+    offsets.push_back({it.first->offset(), it.second->Serialize(builder)});
+  }
+
+  // Create strings.
+  for (const auto& it : fields_) {
+    if (it.second.Has<std::string>()) {
+      offsets.push_back(
+          {it.first->offset(),
+           builder->CreateString(it.second.ConstRefValue<std::string>()).o});
+    }
+  }
+
+  // Build the repeated fields.
+  for (const auto& it : repeated_fields_) {
+    offsets.push_back({it.first->offset(), it.second->Serialize(builder)});
+  }
+
+  // Build the table now.
+  const flatbuffers::uoffset_t table_start = builder->StartTable();
+
+  // Add scalar fields. Each value is written together with the schema
+  // default so flatbuffers can elide fields that equal their default.
+  for (const auto& it : fields_) {
+    switch (it.second.GetType()) {
+      case Variant::TYPE_BOOL_VALUE:
+        builder->AddElement<uint8_t>(
+            it.first->offset(), static_cast<uint8_t>(it.second.Value<bool>()),
+            static_cast<uint8_t>(it.first->default_integer()));
+        continue;
+      case Variant::TYPE_INT8_VALUE:
+        builder->AddElement<int8_t>(
+            it.first->offset(), static_cast<int8_t>(it.second.Value<int8>()),
+            static_cast<int8_t>(it.first->default_integer()));
+        continue;
+      case Variant::TYPE_UINT8_VALUE:
+        builder->AddElement<uint8_t>(
+            it.first->offset(), static_cast<uint8_t>(it.second.Value<uint8>()),
+            static_cast<uint8_t>(it.first->default_integer()));
+        continue;
+      case Variant::TYPE_INT_VALUE:
+        builder->AddElement<int32>(
+            it.first->offset(), it.second.Value<int>(),
+            static_cast<int32>(it.first->default_integer()));
+        continue;
+      case Variant::TYPE_UINT_VALUE:
+        builder->AddElement<uint32>(
+            it.first->offset(), it.second.Value<uint>(),
+            static_cast<uint32>(it.first->default_integer()));
+        continue;
+      case Variant::TYPE_INT64_VALUE:
+        builder->AddElement<int64>(it.first->offset(), it.second.Value<int64>(),
+                                   it.first->default_integer());
+        continue;
+      case Variant::TYPE_UINT64_VALUE:
+        builder->AddElement<uint64>(it.first->offset(),
+                                    it.second.Value<uint64>(),
+                                    it.first->default_integer());
+        continue;
+      case Variant::TYPE_FLOAT_VALUE:
+        builder->AddElement<float>(
+            it.first->offset(), it.second.Value<float>(),
+            static_cast<float>(it.first->default_real()));
+        continue;
+      case Variant::TYPE_DOUBLE_VALUE:
+        builder->AddElement<double>(it.first->offset(),
+                                    it.second.Value<double>(),
+                                    it.first->default_real());
+        continue;
+      default:
+        // Strings were already emitted above; other variant types are not
+        // serialized here.
+        continue;
+    }
+  }
+
+  // Add strings, subtables and repeated fields.
+  for (const auto& it : offsets) {
+    builder->AddOffset(it.first, flatbuffers::Offset<void>(it.second));
+  }
+
+  return builder->EndTable(table_start);
+}
+
+// Serializes this message into a standalone buffer and returns the raw
+// bytes as a string.
+std::string ReflectiveFlatbuffer::Serialize() const {
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(flatbuffers::Offset<void>(Serialize(&builder)));
+  return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+                     builder.GetSize());
+}
+
+// Appends all elements of a string-vector field in `from` to this message's
+// corresponding repeated field. Returns false if `from` has no such vector.
+template <>
+bool ReflectiveFlatbuffer::AppendFromVector<std::string>(
+    const flatbuffers::Table* from, const reflection::Field* field) {
+  auto* from_vector = from->GetPointer<
+      const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
+      field->offset());
+  if (from_vector == nullptr) {
+    return false;
+  }
+
+  RepeatedField* to_repeated = Repeated(field);
+  for (const flatbuffers::String* element : *from_vector) {
+    to_repeated->Add(element->str());
+  }
+  return true;
+}
+
+// Appends all elements of a table-vector field in `from`, deep-merging each
+// source table into a freshly created destination element.
+template <>
+bool ReflectiveFlatbuffer::AppendFromVector<ReflectiveFlatbuffer>(
+    const flatbuffers::Table* from, const reflection::Field* field) {
+  auto* from_vector = from->GetPointer<const flatbuffers::Vector<
+      flatbuffers::Offset<const flatbuffers::Table>>*>(field->offset());
+  if (from_vector == nullptr) {
+    return false;
+  }
+
+  RepeatedField* to_repeated = Repeated(field);
+  for (const flatbuffers::Table* const from_element : *from_vector) {
+    ReflectiveFlatbuffer* to_element = to_repeated->Add();
+    if (to_element == nullptr) {
+      return false;
+    }
+    to_element->MergeFrom(from_element);
+  }
+  return true;
+}
+
+// Merges all explicitly-set fields of `from` into this message: scalars and
+// strings overwrite, sub-messages are merged recursively, and vector
+// elements are appended.
+bool ReflectiveFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
+  // No fields to set.
+  if (type_->fields() == nullptr) {
+    return true;
+  }
+
+  for (const reflection::Field* field : *type_->fields()) {
+    // Skip fields that are not explicitly set.
+    if (!from->CheckField(field->offset())) {
+      continue;
+    }
+    const reflection::BaseType type = field->type()->base_type();
+    switch (type) {
+      case reflection::Bool:
+        Set<bool>(field, from->GetField<uint8_t>(field->offset(),
+                                                 field->default_integer()));
+        break;
+      case reflection::Byte:
+        Set<int8_t>(field, from->GetField<int8_t>(field->offset(),
+                                                  field->default_integer()));
+        break;
+      case reflection::UByte:
+        Set<uint8_t>(field, from->GetField<uint8_t>(field->offset(),
+                                                    field->default_integer()));
+        break;
+      case reflection::Int:
+        Set<int32>(field, from->GetField<int32>(field->offset(),
+                                                field->default_integer()));
+        break;
+      case reflection::UInt:
+        Set<uint32>(field, from->GetField<uint32>(field->offset(),
+                                                  field->default_integer()));
+        break;
+      case reflection::Long:
+        Set<int64>(field, from->GetField<int64>(field->offset(),
+                                                field->default_integer()));
+        break;
+      case reflection::ULong:
+        Set<uint64>(field, from->GetField<uint64>(field->offset(),
+                                                  field->default_integer()));
+        break;
+      case reflection::Float:
+        Set<float>(field, from->GetField<float>(field->offset(),
+                                                field->default_real()));
+        break;
+      case reflection::Double:
+        Set<double>(field, from->GetField<double>(field->offset(),
+                                                  field->default_real()));
+        break;
+      case reflection::String:
+        // CheckField above guarantees the string pointer is present.
+        Set<std::string>(
+            field, from->GetPointer<const flatbuffers::String*>(field->offset())
+                       ->str());
+        break;
+      case reflection::Obj:
+        // Recursively merge the nested table (C++17 if-with-initializer).
+        if (ReflectiveFlatbuffer* nested_field = Mutable(field);
+            nested_field == nullptr ||
+            !nested_field->MergeFrom(
+                from->GetPointer<const flatbuffers::Table* const>(
+                    field->offset()))) {
+          return false;
+        }
+        break;
+      case reflection::Vector:
+        // Dispatch on the element type; elements are appended, not replaced.
+        switch (field->type()->element()) {
+          case reflection::Int:
+            AppendFromVector<int32>(from, field);
+            break;
+          case reflection::UInt:
+            AppendFromVector<uint>(from, field);
+            break;
+          case reflection::Long:
+            AppendFromVector<int64>(from, field);
+            break;
+          case reflection::ULong:
+            AppendFromVector<uint64>(from, field);
+            break;
+          case reflection::Byte:
+            AppendFromVector<int8_t>(from, field);
+            break;
+          case reflection::UByte:
+            AppendFromVector<uint8_t>(from, field);
+            break;
+          case reflection::String:
+            AppendFromVector<std::string>(from, field);
+            break;
+          case reflection::Obj:
+            AppendFromVector<ReflectiveFlatbuffer>(from, field);
+            break;
+          case reflection::Double:
+            AppendFromVector<double>(from, field);
+            break;
+          case reflection::Float:
+            AppendFromVector<float>(from, field);
+            break;
+          default:
+            TC3_LOG(ERROR) << "Repeated unsupported type: "
+                           << field->type()->element()
+                           << " for field: " << field->name()->str();
+            return false;
+            break;
+        }
+        break;
+      default:
+        TC3_LOG(ERROR) << "Unsupported type: " << type
+                       << " for field: " << field->name()->str();
+        return false;
+    }
+  }
+  return true;
+}
+
+// Merges the fields of a serialized table into this message.
+// NOTE(review): the buffer is not verified before being interpreted —
+// callers must pass trusted/pre-verified data.
+bool ReflectiveFlatbuffer::MergeFromSerializedFlatbuffer(StringPiece from) {
+  return MergeFrom(flatbuffers::GetAnyRoot(
+      reinterpret_cast<const unsigned char*>(from.data())));
+}
+
+// Recursive worker for the public AsFlatMap(): flattens this message into
+// `result`, prefixing each key with `key_prefix` and joining nested field
+// names with `key_separator`.
+void ReflectiveFlatbuffer::AsFlatMap(
+    const std::string& key_separator, const std::string& key_prefix,
+    std::map<std::string, Variant>* result) const {
+  // Add direct fields.
+  for (const auto& it : fields_) {
+    (*result)[key_prefix + it.first->name()->str()] = it.second;
+  }
+
+  // Add nested messages.
+  for (const auto& it : children_) {
+    it.second->AsFlatMap(key_separator,
+                         key_prefix + it.first->name()->str() + key_separator,
+                         result);
+  }
+}
+
+// Renders the explicitly-set fields as a comma-separated, textproto-like
+// string; nested messages render as `name { ... }`.
+// NOTE(review): string values are wrapped in single quotes but NOT escaped —
+// embedded quotes produce malformed output; confirm inputs are safe.
+std::string ReflectiveFlatbuffer::ToTextProto() const {
+  std::string result;
+  std::string current_field_separator;
+  // Add direct fields.
+  for (const auto& field_value_pair : fields_) {
+    const std::string field_name = field_value_pair.first->name()->str();
+    const Variant& value = field_value_pair.second;
+    std::string quotes;
+    if (value.GetType() == Variant::TYPE_STRING_VALUE) {
+      quotes = "'";
+    }
+    result.append(current_field_separator + field_name + ": " + quotes +
+                  value.ToString() + quotes);
+    current_field_separator = ", ";
+  }
+
+  // Add nested messages.
+  for (const auto& field_flatbuffer_pair : children_) {
+    const std::string field_name = field_flatbuffer_pair.first->name()->str();
+    result.append(current_field_separator + field_name + " {" +
+                  field_flatbuffer_pair.second->ToTextProto() + "}");
+    current_field_separator = ", ";
+  }
+
+  return result;
+}
+
+// Rewrites a field path in place: resolves each component's name to its
+// vtable offset (starting from the schema's root table) and clears the
+// name, so later lookups can use the faster offset form.
+bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
+                                    FlatbufferFieldPathT* path) {
+  if (schema == nullptr || !schema->root_table()) {
+    TC3_LOG(ERROR) << "Empty schema provided.";
+    return false;
+  }
+
+  reflection::Object const* type = schema->root_table();
+  for (int i = 0; i < path->field.size(); i++) {
+    const reflection::Field* field = GetFieldOrNull(type, path->field[i].get());
+    if (field == nullptr) {
+      TC3_LOG(ERROR) << "Could not find field: " << path->field[i]->field_name;
+      return false;
+    }
+    path->field[i]->field_name.clear();
+    path->field[i]->field_offset = field->offset();
+
+    // Descend. All components but the last must be sub-message fields.
+    if (i < path->field.size() - 1) {
+      if (field->type()->base_type() != reflection::Obj) {
+        TC3_LOG(ERROR) << "Field: " << field->name()->str()
+                       << " is not of type `Object`.";
+        return false;
+      }
+      type = schema->objects()->Get(field->type()->index());
+    }
+  }
+  return true;
+}
+
+//
+// Repeated field methods.
+//
+
+// Appends and returns a new, empty sub-message element. Only valid for
+// object-typed (table element) repeated fields.
+ReflectiveFlatbuffer* RepeatedField::Add() {
+  if (is_primitive_) {
+    TC3_LOG(ERROR) << "Trying to add sub-message on a primitive-typed field.";
+    return nullptr;
+  }
+
+  // Create a new empty element of the vector's table type.
+  object_items_.emplace_back(new ReflectiveFlatbuffer(
+      schema_, schema_->objects()->Get(field_->type()->index())));
+  return object_items_.back().get();
+}
+
+namespace {
+
+// Copies the T-typed values out of the Variant items and serializes them as
+// a flatbuffer vector of T, returning the vector's offset.
+template <typename T>
+flatbuffers::uoffset_t TypedSerialize(const std::vector<Variant>& values,
+                                      flatbuffers::FlatBufferBuilder* builder) {
+  std::vector<T> typed_values;
+  typed_values.reserve(values.size());
+  for (const Variant& item : values) {
+    typed_values.push_back(item.Value<T>());
+  }
+  return builder->CreateVector(typed_values).o;
+}
+
+} // namespace
+
+// Serializes the repeated field into `builder` and returns the offset of
+// the created vector. Dispatches on the vector's element type: strings and
+// objects have dedicated serialization paths; scalars are copied out of the
+// Variant items into a typed vector.
+// (Cleanup: removed the unreachable `break` statements that followed each
+// `return` in the switch.)
+flatbuffers::uoffset_t RepeatedField::Serialize(
+    flatbuffers::FlatBufferBuilder* builder) const {
+  switch (field_->type()->element()) {
+    case reflection::String:
+      return SerializeString(builder);
+    case reflection::Obj:
+      return SerializeObject(builder);
+    case reflection::Bool:
+      return TypedSerialize<bool>(items_, builder);
+    case reflection::Byte:
+      return TypedSerialize<int8_t>(items_, builder);
+    case reflection::UByte:
+      return TypedSerialize<uint8_t>(items_, builder);
+    case reflection::Int:
+      return TypedSerialize<int>(items_, builder);
+    case reflection::UInt:
+      return TypedSerialize<uint>(items_, builder);
+    case reflection::Long:
+      return TypedSerialize<int64>(items_, builder);
+    case reflection::ULong:
+      return TypedSerialize<uint64>(items_, builder);
+    case reflection::Float:
+      return TypedSerialize<float>(items_, builder);
+    case reflection::Double:
+      return TypedSerialize<double>(items_, builder);
+    default:
+      TC3_LOG(FATAL) << "Unsupported type: " << field_->type()->element();
+      break;
+  }
+  // Only reached if TC3_LOG(FATAL) does not terminate the process.
+  TC3_LOG(FATAL) << "Invalid state.";
+  return 0;
+}
+
+// Serializes all string items and returns the offset of the created vector
+// of string offsets.
+// (Fix: use size_t for the loop index — the original `int i` caused a
+// signed/unsigned comparison against items_.size().)
+flatbuffers::uoffset_t RepeatedField::SerializeString(
+    flatbuffers::FlatBufferBuilder* builder) const {
+  std::vector<flatbuffers::Offset<flatbuffers::String>> offsets(items_.size());
+  for (size_t i = 0; i < items_.size(); i++) {
+    offsets[i] = builder->CreateString(items_[i].ConstRefValue<std::string>());
+  }
+  return builder->CreateVector(offsets).o;
+}
+
+// Serializes all sub-message items and returns the offset of the created
+// vector of table offsets.
+// (Fix: use size_t for the loop index — the original `int i` caused a
+// signed/unsigned comparison against object_items_.size().)
+flatbuffers::uoffset_t RepeatedField::SerializeObject(
+    flatbuffers::FlatBufferBuilder* builder) const {
+  std::vector<flatbuffers::Offset<void>> offsets(object_items_.size());
+  for (size_t i = 0; i < object_items_.size(); i++) {
+    offsets[i] = object_items_[i]->Serialize(builder);
+  }
+  return builder->CreateVector(offsets).o;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers.fbs b/native/utils/flatbuffers.fbs
new file mode 100755
index 0000000..155e8f8
--- /dev/null
+++ b/native/utils/flatbuffers.fbs
@@ -0,0 +1,32 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Specifies a field in a flatbuffer message.
+namespace libtextclassifier3;
+table FlatbufferField {
+  // Name of the field.
+  field_name:string (shared);
+
+  // Vtable offset of the field, usable as an alternative to the name (see
+  // SwapFieldNamesForOffsetsInPath in utils/flatbuffers.h).
+  field_offset:int;
+}
+
+// Specifies a (nested) field in a flatbuffer message, as the sequence of
+// fields to traverse from the root message.
+namespace libtextclassifier3;
+table FlatbufferFieldPath {
+  field:[FlatbufferField];
+}
+
diff --git a/native/utils/flatbuffers.h b/native/utils/flatbuffers.h
new file mode 100644
index 0000000..aaf248e
--- /dev/null
+++ b/native/utils/flatbuffers.h
@@ -0,0 +1,449 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility functions for working with FlatBuffers.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
+#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "annotator/model_generated.h"
+#include "utils/base/logging.h"
+#include "utils/flatbuffers_generated.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/variant.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/reflection.h"
+#include "flatbuffers/reflection_generated.h"
+
+namespace libtextclassifier3 {
+
+// Forward declarations (defined below in this header).
+// (Fix: the original declared `ReflectiveFlatBuffer` — capital 'B' — which
+// names a nonexistent type; the actual class is `ReflectiveFlatbuffer`.)
+class ReflectiveFlatbuffer;
+class RepeatedField;
+
+// Loads and interprets the buffer as 'FlatbufferMessage' and verifies its
+// integrity. Returns nullptr if the buffer fails verification.
+template <typename FlatbufferMessage>
+const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) {
+  const FlatbufferMessage* message =
+      flatbuffers::GetRoot<FlatbufferMessage>(buffer);
+  if (message == nullptr) {
+    return nullptr;
+  }
+  // Walk the buffer and check that all offsets stay within bounds.
+  flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer),
+                                 size);
+  return message->Verify(verifier) ? message : nullptr;
+}
+
+// Same as above but takes the serialized bytes as a string.
+template <typename FlatbufferMessage>
+const FlatbufferMessage* LoadAndVerifyFlatbuffer(const std::string& buffer) {
+  return LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer.c_str(),
+                                                    buffer.size());
+}
+
+// Loads and interprets the buffer as 'FlatbufferMessage', verifies its
+// integrity and returns its mutable (object API) version, unpacked onto the
+// heap. Returns nullptr if verification fails.
+template <typename FlatbufferMessage>
+std::unique_ptr<typename FlatbufferMessage::NativeTableType>
+LoadAndVerifyMutableFlatbuffer(const void* buffer, int size) {
+  const FlatbufferMessage* message =
+      LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer, size);
+  if (message == nullptr) {
+    return nullptr;
+  }
+  return std::unique_ptr<typename FlatbufferMessage::NativeTableType>(
+      message->UnPack());
+}
+
+// Same as above but takes the serialized bytes as a string.
+template <typename FlatbufferMessage>
+std::unique_ptr<typename FlatbufferMessage::NativeTableType>
+LoadAndVerifyMutableFlatbuffer(const std::string& buffer) {
+  return LoadAndVerifyMutableFlatbuffer<FlatbufferMessage>(buffer.c_str(),
+                                                           buffer.size());
+}
+
+// File identifier to attach when packing a message of the given type.
+// Default: none. Specialized for Model (see flatbuffers.cc).
+template <typename FlatbufferMessage>
+const char* FlatbufferFileIdentifier() {
+  return nullptr;
+}
+
+template <>
+const char* FlatbufferFileIdentifier<Model>();
+
+// Packs the mutable (object API) flatbuffer message into a serialized
+// buffer, returned as a string, with the type's file identifier (if any).
+template <typename FlatbufferMessage>
+std::string PackFlatbuffer(
+    const typename FlatbufferMessage::NativeTableType* mutable_message) {
+  flatbuffers::FlatBufferBuilder builder;
+  builder.Finish(FlatbufferMessage::Pack(builder, mutable_message),
+                 FlatbufferFileIdentifier<FlatbufferMessage>());
+  return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+                     builder.GetSize());
+}
+
+class ReflectiveFlatbuffer;
+
+// Checks whether a variant value type T agrees with a reflection field
+// type. Strings accept std::string, StringPiece and const char*.
+template <typename T>
+bool IsMatchingType(const reflection::BaseType type) {
+  switch (type) {
+    case reflection::Bool:
+      return std::is_same<T, bool>::value;
+    case reflection::Byte:
+      return std::is_same<T, int8>::value;
+    case reflection::UByte:
+      return std::is_same<T, uint8>::value;
+    case reflection::Int:
+      return std::is_same<T, int32>::value;
+    case reflection::UInt:
+      return std::is_same<T, uint32>::value;
+    case reflection::Long:
+      return std::is_same<T, int64>::value;
+    case reflection::ULong:
+      return std::is_same<T, uint64>::value;
+    case reflection::Float:
+      return std::is_same<T, float>::value;
+    case reflection::Double:
+      return std::is_same<T, double>::value;
+    case reflection::String:
+      return std::is_same<T, std::string>::value ||
+             std::is_same<T, StringPiece>::value ||
+             std::is_same<T, const char*>::value;
+    case reflection::Obj:
+      return std::is_same<T, ReflectiveFlatbuffer>::value;
+    default:
+      return false;
+  }
+}
+
+// A flatbuffer that can be built using flatbuffer reflection data of the
+// schema.
+// Normally, field information is hard-coded in code generated from a flatbuffer
+// schema. Here we lookup the necessary information for building a flatbuffer
+// from the provided reflection meta data.
+// When serializing a flatbuffer, the library requires that the sub messages
+// are already serialized, therefore we explicitly keep the field values and
+// serialize the message in (reverse) topological dependency order.
+class ReflectiveFlatbuffer {
+ public:
+  // `schema` and `type` must outlive this object; `type` is the reflection
+  // descriptor of the table this message represents.
+  ReflectiveFlatbuffer(const reflection::Schema* schema,
+                       const reflection::Object* type)
+      : schema_(schema), type_(type) {}
+
+  // Gets the field information for a field name, returns nullptr if the
+  // field was not defined.
+  const reflection::Field* GetFieldOrNull(const StringPiece field_name) const;
+  const reflection::Field* GetFieldOrNull(const FlatbufferField* field) const;
+  const reflection::Field* GetFieldOrNull(const int field_offset) const;
+
+  // Gets a nested field and the message it is defined on.
+  bool GetFieldWithParent(const FlatbufferFieldPath* field_path,
+                          ReflectiveFlatbuffer** parent,
+                          reflection::Field const** field);
+
+  // Sets a field to a specific value.
+  // Returns true if successful, and false if the field was not found or the
+  // expected type doesn't match.
+  template <typename T>
+  bool Set(StringPiece field_name, T value);
+
+  // Sets a field to a specific value.
+  // Returns true if successful, and false if the expected type doesn't match.
+  // Expects `field` to be non-null.
+  template <typename T>
+  bool Set(const reflection::Field* field, T value);
+
+  // Sets a field to a specific value. Field is specified by path.
+  template <typename T>
+  bool Set(const FlatbufferFieldPath* path, T value);
+
+  // Sets sub-message field (if not set yet), and returns a pointer to it.
+  // Returns nullptr if the field was not found, or the field type was not a
+  // table.
+  ReflectiveFlatbuffer* Mutable(StringPiece field_name);
+  ReflectiveFlatbuffer* Mutable(const reflection::Field* field);
+
+  // Parses the value (according to the type) and sets a primitive field to the
+  // parsed value.
+  bool ParseAndSet(const reflection::Field* field, const std::string& value);
+  bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value);
+
+  // Adds a primitive value to the repeated field.
+  template <typename T>
+  bool Add(StringPiece field_name, T value);
+
+  // Add a sub-message to the repeated field.
+  ReflectiveFlatbuffer* Add(StringPiece field_name);
+
+  template <typename T>
+  bool Add(const reflection::Field* field, T value);
+
+  ReflectiveFlatbuffer* Add(const reflection::Field* field);
+
+  // Gets the reflective flatbuffer for a repeated field.
+  // Returns nullptr if the field was not found, or the field type was not a
+  // vector.
+  RepeatedField* Repeated(StringPiece field_name);
+  RepeatedField* Repeated(const reflection::Field* field);
+
+  // Serializes the flatbuffer: into an existing builder (returning the table
+  // offset), or into a standalone buffer returned as a string.
+  flatbuffers::uoffset_t Serialize(
+      flatbuffers::FlatBufferBuilder* builder) const;
+  std::string Serialize() const;
+
+  // Merges the fields from the given flatbuffer table into this flatbuffer.
+  // Scalar fields will be overwritten, if present in `from`.
+  // Embedded messages will be merged.
+  bool MergeFrom(const flatbuffers::Table* from);
+  bool MergeFromSerializedFlatbuffer(StringPiece from);
+
+  // Flattens the flatbuffer as a flat map.
+  // (Nested) fields names are joined by `key_separator`.
+  std::map<std::string, Variant> AsFlatMap(
+      const std::string& key_separator = ".") const {
+    std::map<std::string, Variant> result;
+    AsFlatMap(key_separator, /*key_prefix=*/"", &result);
+    return result;
+  }
+
+  // Converts the flatbuffer's content to a human-readable textproto
+  // representation.
+  std::string ToTextProto() const;
+
+  // True if any field (scalar, sub-message or repeated) was explicitly set.
+  bool HasExplicitlySetFields() const {
+    return !fields_.empty() || !children_.empty() || !repeated_fields_.empty();
+  }
+
+ private:
+  // Helper function for merging given repeated field from given flatbuffer
+  // table. Appends the elements.
+  template <typename T>
+  bool AppendFromVector(const flatbuffers::Table* from,
+                        const reflection::Field* field);
+
+  const reflection::Schema* const schema_;
+  const reflection::Object* const type_;
+
+  // Cached primitive fields (scalars and strings).
+  std::unordered_map<const reflection::Field*, Variant> fields_;
+
+  // Cached sub-messages.
+  std::unordered_map<const reflection::Field*,
+                     std::unique_ptr<ReflectiveFlatbuffer>>
+      children_;
+
+  // Cached repeated fields.
+  std::unordered_map<const reflection::Field*, std::unique_ptr<RepeatedField>>
+      repeated_fields_;
+
+  // Flattens the flatbuffer as a flat map.
+  // (Nested) fields names are joined by `key_separator` and prefixed by
+  // `key_prefix`.
+  void AsFlatMap(const std::string& key_separator,
+                 const std::string& key_prefix,
+                 std::map<std::string, Variant>* result) const;
+};
+
+// A helper class to build flatbuffers based on schema reflection data.
+// Can be used to build a `ReflectiveFlatbuffer` for the root message of the
+// schema, or any defined table via name.
+class ReflectiveFlatbufferBuilder {
+ public:
+ explicit ReflectiveFlatbufferBuilder(const reflection::Schema* schema)
+ : schema_(schema) {}
+
+ // Starts a new root table message.
+ std::unique_ptr<ReflectiveFlatbuffer> NewRoot() const;
+
+ // Starts a new table message. Returns nullptr if no table with given name is
+ // found in the schema.
+ std::unique_ptr<ReflectiveFlatbuffer> NewTable(
+ const StringPiece table_name) const;
+
+ private:
+ const reflection::Schema* const schema_;
+};
+
+// Encapsulates a repeated field.
+// Serves as a common base class for repeated fields.
+class RepeatedField {
+ public:
+ RepeatedField(const reflection::Schema* const schema,
+ const reflection::Field* field)
+ : schema_(schema),
+ field_(field),
+ is_primitive_(field->type()->element() != reflection::BaseType::Obj) {}
+
+ template <typename T>
+ bool Add(const T value);
+
+ ReflectiveFlatbuffer* Add();
+
+ template <typename T>
+ T Get(int index) const {
+ return items_.at(index).Value<T>();
+ }
+
+ template <>
+ ReflectiveFlatbuffer* Get(int index) const {
+ if (is_primitive_) {
+ TC3_LOG(ERROR) << "Trying to get primitive value out of non-primitive "
+ "repeated field.";
+ return nullptr;
+ }
+ return object_items_.at(index).get();
+ }
+
+ int Size() const {
+ if (is_primitive_) {
+ return items_.size();
+ } else {
+ return object_items_.size();
+ }
+ }
+
+ flatbuffers::uoffset_t Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const;
+
+ private:
+ flatbuffers::uoffset_t SerializeString(
+ flatbuffers::FlatBufferBuilder* builder) const;
+ flatbuffers::uoffset_t SerializeObject(
+ flatbuffers::FlatBufferBuilder* builder) const;
+
+ const reflection::Schema* const schema_;
+ const reflection::Field* field_;
+ bool is_primitive_;
+
+ std::vector<Variant> items_;
+ std::vector<std::unique_ptr<ReflectiveFlatbuffer>> object_items_;
+};
+
+template <typename T>
+bool ReflectiveFlatbuffer::Set(StringPiece field_name, T value) {
+ if (const reflection::Field* field = GetFieldOrNull(field_name)) {
+ if (field->type()->base_type() == reflection::BaseType::Vector ||
+ field->type()->base_type() == reflection::BaseType::Obj) {
+ TC3_LOG(ERROR)
+ << "Trying to set a primitive value on a non-scalar field.";
+ return false;
+ }
+ return Set<T>(field, value);
+ }
+ TC3_LOG(ERROR) << "Couldn't find a field: " << field_name;
+ return false;
+}
+
+template <typename T>
+bool ReflectiveFlatbuffer::Set(const reflection::Field* field, T value) {
+ if (field == nullptr) {
+ TC3_LOG(ERROR) << "Expected non-null field.";
+ return false;
+ }
+ Variant variant_value(value);
+ if (!IsMatchingType<T>(field->type()->base_type())) {
+ TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
+ << "`, expected: " << field->type()->base_type()
+ << ", got: " << variant_value.GetType();
+ return false;
+ }
+ fields_[field] = variant_value;
+ return true;
+}
+
+template <typename T>
+bool ReflectiveFlatbuffer::Set(const FlatbufferFieldPath* path, T value) {
+ ReflectiveFlatbuffer* parent;
+ const reflection::Field* field;
+ if (!GetFieldWithParent(path, &parent, &field)) {
+ return false;
+ }
+ return parent->Set<T>(field, value);
+}
+
+template <typename T>
+bool ReflectiveFlatbuffer::Add(StringPiece field_name, T value) {
+ const reflection::Field* field = GetFieldOrNull(field_name);
+ if (field == nullptr) {
+ return false;
+ }
+
+ if (field->type()->base_type() != reflection::BaseType::Vector) {
+ return false;
+ }
+
+ return Add<T>(field, value);
+}
+
+template <typename T>
+bool ReflectiveFlatbuffer::Add(const reflection::Field* field, T value) {
+ if (field == nullptr) {
+ return false;
+ }
+ Repeated(field)->Add(value);
+ return true;
+}
+
+template <typename T>
+bool RepeatedField::Add(const T value) {
+ if (!is_primitive_ || !IsMatchingType<T>(field_->type()->element())) {
+ TC3_LOG(ERROR) << "Trying to add value of unmatching type.";
+ return false;
+ }
+ items_.push_back(Variant{value});
+ return true;
+}
+
+// Resolves field lookups by name to the concrete field offsets.
+bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
+ FlatbufferFieldPathT* path);
+
+template <typename T>
+bool ReflectiveFlatbuffer::AppendFromVector(const flatbuffers::Table* from,
+ const reflection::Field* field) {
+ const flatbuffers::Vector<T>* from_vector =
+ from->GetPointer<const flatbuffers::Vector<T>*>(field->offset());
+ if (from_vector == nullptr) {
+ return false;
+ }
+
+ RepeatedField* to_repeated = Repeated(field);
+ for (const T element : *from_vector) {
+ to_repeated->Add(element);
+ }
+ return true;
+}
+
+inline logging::LoggingStringStream& operator<<(
+ logging::LoggingStringStream& stream, flatbuffers::String* message) {
+ if (message != nullptr) {
+ stream.message.append(message->c_str(), message->size());
+ }
+ return stream;
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
diff --git a/native/utils/grammar/callback-delegate.h b/native/utils/grammar/callback-delegate.h
new file mode 100644
index 0000000..a5424dd
--- /dev/null
+++ b/native/utils/grammar/callback-delegate.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_CALLBACK_DELEGATE_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_CALLBACK_DELEGATE_H_
+
+#include "utils/base/integral_types.h"
+#include "utils/grammar/match.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/types.h"
+
+namespace libtextclassifier3::grammar {
+
+class Matcher;
+
+// CallbackDelegate is an interface and default implementation used by the
+// grammar matcher to dispatch rule matches.
+class CallbackDelegate {
+ public:
+ virtual ~CallbackDelegate() = default;
+
+ // This is called by the matcher whenever it finds a match for a rule to
+ // which a callback is attached.
+ virtual void MatchFound(const Match* match, const CallbackId callback_id,
+ const int64 callback_param, Matcher* matcher) {}
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_CALLBACK_DELEGATE_H_
diff --git a/native/utils/grammar/lexer.cc b/native/utils/grammar/lexer.cc
new file mode 100644
index 0000000..3a2d0d3
--- /dev/null
+++ b/native/utils/grammar/lexer.cc
@@ -0,0 +1,321 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/lexer.h"
+
+#include <unordered_map>
+
+#include "annotator/types.h"
+#include "utils/zlib/zlib.h"
+#include "utils/zlib/zlib_regex.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+inline bool CheckMemoryUsage(const Matcher* matcher) {
+ // The maximum memory usage for matching.
+ constexpr int kMaxMemoryUsage = 1 << 20;
+ return matcher->ArenaSize() <= kMaxMemoryUsage;
+}
+
+Match* CheckedAddMatch(const Nonterm nonterm,
+ const CodepointSpan codepoint_span,
+ const int match_offset, const int16 type,
+ Matcher* matcher) {
+ if (nonterm == kUnassignedNonterm || !CheckMemoryUsage(matcher)) {
+ return nullptr;
+ }
+ return matcher->AllocateAndInitMatch<Match>(nonterm, codepoint_span,
+ match_offset, type);
+}
+
+void CheckedEmit(const Nonterm nonterm, const CodepointSpan codepoint_span,
+ const int match_offset, int16 type, Matcher* matcher) {
+ if (nonterm != kUnassignedNonterm && CheckMemoryUsage(matcher)) {
+ matcher->AddMatch(matcher->AllocateAndInitMatch<Match>(
+ nonterm, codepoint_span, match_offset, type));
+ }
+}
+
+int MapCodepointToTokenPaddingIfPresent(
+ const std::unordered_map<CodepointIndex, CodepointIndex>& token_alignment,
+ const int start) {
+ const auto it = token_alignment.find(start);
+ if (it != token_alignment.end()) {
+ return it->second;
+ }
+ return start;
+}
+
+} // namespace
+
+Lexer::Lexer(const UniLib* unilib, const RulesSet* rules)
+ : unilib_(*unilib),
+ rules_(rules),
+ regex_annotators_(BuildRegexAnnotator(unilib_, rules)) {}
+
+std::vector<Lexer::RegexAnnotator> Lexer::BuildRegexAnnotator(
+ const UniLib& unilib, const RulesSet* rules) const {
+ std::vector<Lexer::RegexAnnotator> result;
+ if (rules->regex_annotator() != nullptr) {
+ std::unique_ptr<ZlibDecompressor> decompressor =
+ ZlibDecompressor::Instance();
+ result.reserve(rules->regex_annotator()->size());
+ for (const RulesSet_::RegexAnnotator* regex_annotator :
+ *rules->regex_annotator()) {
+ result.push_back(
+ {UncompressMakeRegexPattern(unilib_, regex_annotator->pattern(),
+ regex_annotator->compressed_pattern(),
+ rules->lazy_regex_compilation(),
+ decompressor.get()),
+ regex_annotator->nonterminal()});
+ }
+ }
+ return result;
+}
+
+void Lexer::Emit(const Symbol& symbol, const RulesSet_::Nonterminals* nonterms,
+ Matcher* matcher) const {
+ switch (symbol.type) {
+ case Symbol::Type::TYPE_MATCH: {
+ // Just emit the match.
+ matcher->AddMatch(symbol.match);
+ return;
+ }
+ case Symbol::Type::TYPE_DIGITS: {
+ // Emit <digits> if used by the rules.
+ CheckedEmit(nonterms->digits_nt(), symbol.codepoint_span,
+ symbol.match_offset, Match::kDigitsType, matcher);
+
+ // Emit <n_digits> if used by the rules.
+ if (nonterms->n_digits_nt() != nullptr) {
+ const int num_digits =
+ symbol.codepoint_span.second - symbol.codepoint_span.first;
+ if (num_digits <= nonterms->n_digits_nt()->size()) {
+ CheckedEmit(nonterms->n_digits_nt()->Get(num_digits - 1),
+ symbol.codepoint_span, symbol.match_offset,
+ Match::kDigitsType, matcher);
+ }
+ }
+ break;
+ }
+ case Symbol::Type::TYPE_TERM: {
+ // Emit <uppercase_token> if used by the rules.
+ if (nonterms->uppercase_token_nt() != 0 &&
+ unilib_.IsUpperText(
+ UTF8ToUnicodeText(symbol.lexeme, /*do_copy=*/false))) {
+ CheckedEmit(nonterms->uppercase_token_nt(), symbol.codepoint_span,
+ symbol.match_offset, Match::kTokenType, matcher);
+ }
+ break;
+ }
+ default:
+ break;
+ }
+
+ // Emit the token as terminal.
+ if (CheckMemoryUsage(matcher)) {
+ matcher->AddTerminal(symbol.codepoint_span, symbol.match_offset,
+ symbol.lexeme);
+ }
+
+ // Emit <token> if used by rules.
+ CheckedEmit(nonterms->token_nt(), symbol.codepoint_span, symbol.match_offset,
+ Match::kTokenType, matcher);
+}
+
+Lexer::Symbol::Type Lexer::GetSymbolType(
+ const UnicodeText::const_iterator& it) const {
+ if (unilib_.IsPunctuation(*it)) {
+ return Symbol::Type::TYPE_PUNCTUATION;
+ } else if (unilib_.IsDigit(*it)) {
+ return Symbol::Type::TYPE_DIGITS;
+ } else {
+ return Symbol::Type::TYPE_TERM;
+ }
+}
+
+void Lexer::ProcessToken(const StringPiece value, const int prev_token_end,
+ const CodepointSpan codepoint_span,
+ std::vector<Lexer::Symbol>* symbols) const {
+ // Possibly split token.
+ UnicodeText token_unicode = UTF8ToUnicodeText(value.data(), value.size(),
+ /*do_copy=*/false);
+ int last_end = prev_token_end;
+ auto token_end = token_unicode.end();
+ auto it = token_unicode.begin();
+ Symbol::Type type = GetSymbolType(it);
+ CodepointIndex sub_token_start = codepoint_span.first;
+ while (it != token_end) {
+ auto next = std::next(it);
+ int num_codepoints = 1;
+ Symbol::Type next_type;
+ while (next != token_end) {
+ next_type = GetSymbolType(next);
+ if (type == Symbol::Type::TYPE_PUNCTUATION || next_type != type) {
+ break;
+ }
+ ++next;
+ ++num_codepoints;
+ }
+ symbols->push_back(Symbol{
+ type, CodepointSpan{sub_token_start, sub_token_start + num_codepoints},
+ /*match_offset=*/last_end,
+ /*lexeme=*/
+ StringPiece(it.utf8_data(), next.utf8_data() - it.utf8_data())});
+ last_end = sub_token_start + num_codepoints;
+ it = next;
+ type = next_type;
+ sub_token_start = last_end;
+ }
+}
+
+void Lexer::Process(const UnicodeText& text, const std::vector<Token>& tokens,
+ const std::vector<AnnotatedSpan>* annotations,
+ Matcher* matcher) const {
+ return Process(text, tokens.begin(), tokens.end(), annotations, matcher);
+}
+
+void Lexer::Process(const UnicodeText& text,
+ const std::vector<Token>::const_iterator& begin,
+ const std::vector<Token>::const_iterator& end,
+ const std::vector<AnnotatedSpan>* annotations,
+ Matcher* matcher) const {
+ if (begin == end) {
+ return;
+ }
+
+ const RulesSet_::Nonterminals* nonterminals = rules_->nonterminals();
+
+ // Initialize processing of new text.
+ CodepointIndex prev_token_end = 0;
+ std::vector<Symbol> symbols;
+ matcher->Reset();
+
+ // The matcher expects the terminals and non-terminals it received to be in
+  // non-decreasing end-position order. The sorting below makes sure the
+ // pre-defined matches adhere to that order.
+ // Ideally, we would just have to emit a predefined match whenever we see that
+ // the next token we feed would be ending later.
+ // But as we implicitly ignore whitespace, we have to merge preceding
+ // whitespace to the match start so that tokens and non-terminals fed appear
+ // as next to each other without whitespace.
+  // We keep track of real token starts and preceding whitespace in
+ // `token_match_start`, so that we can extend a predefined match's start to
+ // include the preceding whitespace.
+ std::unordered_map<CodepointIndex, CodepointIndex> token_match_start;
+
+ // Add start symbols.
+ if (Match* match =
+ CheckedAddMatch(nonterminals->start_nt(), CodepointSpan{0, 0},
+ /*match_offset=*/0, Match::kBreakType, matcher)) {
+ symbols.push_back(Symbol(match));
+ }
+ if (Match* match =
+ CheckedAddMatch(nonterminals->wordbreak_nt(), CodepointSpan{0, 0},
+ /*match_offset=*/0, Match::kBreakType, matcher)) {
+ symbols.push_back(Symbol(match));
+ }
+
+ for (auto token_it = begin; token_it != end; token_it++) {
+ const Token& token = *token_it;
+
+ // Record match starts for token boundaries, so that we can snap pre-defined
+ // matches to it.
+ if (prev_token_end != token.start) {
+ token_match_start[token.start] = prev_token_end;
+ }
+
+ ProcessToken(token.value,
+ /*prev_token_end=*/prev_token_end,
+ CodepointSpan{token.start, token.end}, &symbols);
+ prev_token_end = token.end;
+
+ // Add word break symbol if used by the grammar.
+ if (Match* match = CheckedAddMatch(
+ nonterminals->wordbreak_nt(), CodepointSpan{token.end, token.end},
+ /*match_offset=*/token.end, Match::kBreakType, matcher)) {
+ symbols.push_back(Symbol(match));
+ }
+ }
+
+ // Add end symbol if used by the grammar.
+ if (Match* match = CheckedAddMatch(
+ nonterminals->end_nt(), CodepointSpan{prev_token_end, prev_token_end},
+ /*match_offset=*/prev_token_end, Match::kBreakType, matcher)) {
+ symbols.push_back(Symbol(match));
+ }
+
+ // Add matches based on annotations.
+ auto annotation_nonterminals = nonterminals->annotation_nt();
+ if (annotation_nonterminals != nullptr && annotations != nullptr) {
+ for (const AnnotatedSpan& annotated_span : *annotations) {
+ const ClassificationResult& classification =
+ annotated_span.classification.front();
+ if (auto entry = annotation_nonterminals->LookupByKey(
+ classification.collection.c_str())) {
+ AnnotationMatch* match = matcher->AllocateAndInitMatch<AnnotationMatch>(
+ entry->value(), annotated_span.span,
+ /*match_offset=*/
+ MapCodepointToTokenPaddingIfPresent(token_match_start,
+ annotated_span.span.first),
+ Match::kAnnotationMatch);
+ match->annotation = &classification;
+ symbols.push_back(Symbol(match));
+ }
+ }
+ }
+
+ // Add regex annotator matches for the range covered by the tokens.
+ for (const RegexAnnotator& regex_annotator : regex_annotators_) {
+ std::unique_ptr<UniLib::RegexMatcher> regex_matcher =
+ regex_annotator.pattern->Matcher(UnicodeText::Substring(
+ text, begin->start, prev_token_end, /*do_copy=*/false));
+ int status = UniLib::RegexMatcher::kNoError;
+ while (regex_matcher->Find(&status) &&
+ status == UniLib::RegexMatcher::kNoError) {
+ const CodepointSpan span = {
+ regex_matcher->Start(0, &status) + begin->start,
+ regex_matcher->End(0, &status) + begin->start};
+ if (Match* match =
+ CheckedAddMatch(regex_annotator.nonterm, span, /*match_offset=*/
+ MapCodepointToTokenPaddingIfPresent(
+ token_match_start, span.first),
+ Match::kUnknownType, matcher)) {
+ symbols.push_back(Symbol(match));
+ }
+ }
+ }
+
+ std::sort(symbols.begin(), symbols.end(),
+ [](const Symbol& a, const Symbol& b) {
+ // Sort by increasing (end, start) position to guarantee the
+ // matcher requirement that the tokens are fed in non-decreasing
+ // end position order.
+ return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
+ std::tie(b.codepoint_span.second, b.codepoint_span.first);
+ });
+
+ // Emit symbols to matcher.
+ for (const Symbol& symbol : symbols) {
+ Emit(symbol, nonterminals, matcher);
+ }
+
+ // Finish the matching.
+ matcher->Finish();
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/lexer.h b/native/utils/grammar/lexer.h
new file mode 100644
index 0000000..ca31c25
--- /dev/null
+++ b/native/utils/grammar/lexer.h
@@ -0,0 +1,176 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This is a lexer that runs off the tokenizer and outputs the tokens to a
+// grammar matcher. The tokens it forwards are the same as the ones produced
+// by the tokenizer, but possibly further split and normalized (downcased).
+// Examples:
+//
+// - single character tokens for punctuation (e.g., AddTerminal("?"))
+//
+// - a string of letters (e.g., "Foo" -- it calls AddTerminal() on "foo")
+//
+// - a string of digits (e.g., AddTerminal("37"))
+//
+// In addition to the terminal tokens above, it also outputs certain
+// special nonterminals:
+//
+// - a <token> nonterminal, which it outputs in addition to the
+// regular AddTerminal() call for every token
+//
+// - a <digits> nonterminal, which it outputs in addition to
+// the regular AddTerminal() call for each string of digits
+//
+// - <N_digits> nonterminals, where N is the length of the string of
+// digits. By default the maximum N that will be output is 20. This
+// may be changed at compile time by kMaxNDigitsLength. For instance,
+// "123" will produce a <3_digits> nonterminal, "1234567" will produce
+// a <7_digits> nonterminal.
+//
+// It does not output any whitespace. Instead, whitespace gets absorbed into
+// the token that follows them in the text.
+// For example, if the text contains:
+//
+// ...hello there world...
+// | | |
+// offset=16 39 52
+//
+// then the output will be:
+//
+// "hello" [?, 16)
+// "there" [16, 44) <-- note "16" NOT "39"
+// "world" [44, ?) <-- note "44" NOT "52"
+//
+// This makes it appear to the Matcher as if the tokens are adjacent -- so
+// whitespace is simply ignored.
+//
+// A minor optimization: We don't bother to output nonterminals if the grammar
+// rules don't reference them.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_
+
+#include "annotator/types.h"
+#include "utils/grammar/matcher.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/types.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3::grammar {
+
+class Lexer {
+ public:
+ explicit Lexer(const UniLib* unilib, const RulesSet* rules);
+
+ // Processes a tokenized text. Classifies the tokens and feeds them to the
+ // matcher.
+ // The provided annotations will be fed to the matcher alongside the tokens.
+ // NOTE: The `annotations` need to outlive any dependent processing.
+ void Process(const UnicodeText& text, const std::vector<Token>& tokens,
+ const std::vector<AnnotatedSpan>* annotations,
+ Matcher* matcher) const;
+ void Process(const UnicodeText& text,
+ const std::vector<Token>::const_iterator& begin,
+ const std::vector<Token>::const_iterator& end,
+ const std::vector<AnnotatedSpan>* annotations,
+ Matcher* matcher) const;
+
+ private:
+ // A lexical symbol with an identified meaning that represents raw tokens,
+ // token categories or predefined text matches.
+ // It is the unit fed to the grammar matcher.
+ struct Symbol {
+ // The type of the lexical symbol.
+ enum class Type {
+ // A raw token.
+ TYPE_TERM,
+
+ // A symbol representing a string of digits.
+ TYPE_DIGITS,
+
+ // Punctuation characters.
+ TYPE_PUNCTUATION,
+
+ // A predefined match.
+ TYPE_MATCH
+ };
+
+ explicit Symbol() = default;
+
+ // Constructs a symbol of a given type with an anchor in the text.
+ Symbol(const Type type, const CodepointSpan codepoint_span,
+ const int match_offset, StringPiece lexeme)
+ : type(type),
+ codepoint_span(codepoint_span),
+ match_offset(match_offset),
+ lexeme(lexeme) {}
+
+ // Constructs a symbol from a pre-defined match.
+ explicit Symbol(Match* match)
+ : type(Type::TYPE_MATCH),
+ codepoint_span(match->codepoint_span),
+ match_offset(match->match_offset),
+ match(match) {}
+
+    // The type of the symbol.
+ Type type;
+
+ // The span in the text as codepoint offsets.
+ CodepointSpan codepoint_span;
+
+ // The match start offset (including preceding whitespace) as codepoint
+ // offset.
+ int match_offset;
+
+ // The symbol text value.
+ StringPiece lexeme;
+
+ // The predefined match.
+ Match* match;
+ };
+
+ // Processes a single token: the token is split and classified into symbols.
+ void ProcessToken(const StringPiece value, const int prev_token_end,
+ const CodepointSpan codepoint_span,
+ std::vector<Symbol>* symbols) const;
+
+ // Emits a token to the matcher.
+ void Emit(const Symbol& symbol, const RulesSet_::Nonterminals* nonterms,
+ Matcher* matcher) const;
+
+ // Gets the type of a character.
+ Symbol::Type GetSymbolType(const UnicodeText::const_iterator& it) const;
+
+ private:
+ struct RegexAnnotator {
+ std::unique_ptr<UniLib::RegexPattern> pattern;
+ Nonterm nonterm;
+ };
+
+ // Uncompress and build the defined regex annotators.
+ std::vector<RegexAnnotator> BuildRegexAnnotator(const UniLib& unilib,
+ const RulesSet* rules) const;
+
+ const UniLib& unilib_;
+ const RulesSet* rules_;
+ std::vector<RegexAnnotator> regex_annotators_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_
diff --git a/native/utils/grammar/match.cc b/native/utils/grammar/match.cc
new file mode 100644
index 0000000..ecf9874
--- /dev/null
+++ b/native/utils/grammar/match.cc
@@ -0,0 +1,77 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/match.h"
+
+#include <algorithm>
+#include <stack>
+
+namespace libtextclassifier3::grammar {
+
+void Traverse(const Match* root,
+ const std::function<bool(const Match*)>& node_fn) {
+ std::stack<const Match*> open;
+ open.push(root);
+
+ while (!open.empty()) {
+ const Match* node = open.top();
+ open.pop();
+ if (!node_fn(node) || node->IsLeaf()) {
+ continue;
+ }
+ open.push(node->rhs2);
+ if (node->rhs1 != nullptr) {
+ open.push(node->rhs1);
+ }
+ }
+}
+
+const Match* SelectFirst(const Match* root,
+ const std::function<bool(const Match*)>& pred_fn) {
+ std::stack<const Match*> open;
+ open.push(root);
+
+ while (!open.empty()) {
+ const Match* node = open.top();
+ open.pop();
+ if (pred_fn(node)) {
+ return node;
+ }
+ if (node->IsLeaf()) {
+ continue;
+ }
+ open.push(node->rhs2);
+ if (node->rhs1 != nullptr) {
+ open.push(node->rhs1);
+ }
+ }
+
+ return nullptr;
+}
+
+std::vector<const Match*> SelectAll(
+ const Match* root, const std::function<bool(const Match*)>& pred_fn) {
+ std::vector<const Match*> result;
+ Traverse(root, [&result, pred_fn](const Match* node) {
+ if (pred_fn(node)) {
+ result.push_back(node);
+ }
+ return true;
+ });
+ return result;
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/match.h b/native/utils/grammar/match.h
new file mode 100644
index 0000000..97edac9
--- /dev/null
+++ b/native/utils/grammar/match.h
@@ -0,0 +1,172 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCH_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCH_H_
+
+#include <functional>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/grammar/types.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3::grammar {
+
+// Represents a single match that was found for a particular nonterminal.
+// Instances should be created by calling Matcher::AllocateMatch().
+// This uses an arena to allocate matches (and subclasses thereof).
+struct Match {
+ static constexpr int16 kUnknownType = 0;
+ static constexpr int16 kTokenType = -1;
+ static constexpr int16 kDigitsType = -2;
+ static constexpr int16 kBreakType = -3;
+ static constexpr int16 kAssertionMatch = -4;
+ static constexpr int16 kMappingMatch = -5;
+ static constexpr int16 kExclusionMatch = -6;
+ static constexpr int16 kAnnotationMatch = -7;
+
+ void Init(const Nonterm arg_lhs, const CodepointSpan arg_codepoint_span,
+ const int arg_match_offset, const int arg_type = kUnknownType) {
+ lhs = arg_lhs;
+ codepoint_span = arg_codepoint_span;
+ match_offset = arg_match_offset;
+ type = arg_type;
+ rhs1 = nullptr;
+ rhs2 = nullptr;
+ }
+
+ void Init(const Match& other) { *this = other; }
+
+ // For binary rule matches: rhs1 != NULL and rhs2 != NULL
+ // unary rule matches: rhs1 == NULL and rhs2 != NULL
+ // terminal rule matches: rhs1 != NULL and rhs2 == NULL
+ // custom leaves: rhs1 == NULL and rhs2 == NULL
+ bool IsInteriorNode() const { return rhs2 != nullptr; }
+ bool IsLeaf() const { return !rhs2; }
+
+ bool IsBinaryRule() const { return rhs1 && rhs2; }
+ bool IsUnaryRule() const { return !rhs1 && rhs2; }
+ bool IsTerminalRule() const { return rhs1 && !rhs2; }
+ bool HasLeadingWhitespace() const {
+ return codepoint_span.first != match_offset;
+ }
+
+ const Match* unary_rule_rhs() const { return rhs2; }
+
+ // Used in singly-linked queue of matches for processing.
+ Match* next = nullptr;
+
+ // Nonterminal we found a match for.
+ Nonterm lhs = kUnassignedNonterm;
+
+ // Type of the match.
+ int16 type = kUnknownType;
+
+ // The span in codepoints.
+ CodepointSpan codepoint_span;
+
+ // The begin codepoint offset used during matching.
+ // This is usually including any prefix whitespace.
+ int match_offset;
+
+ union {
+ // The first sub match for binary rules.
+ const Match* rhs1 = nullptr;
+
+ // The terminal, for terminal rules.
+ const char* terminal;
+ };
+ // First or second sub-match for interior nodes.
+ const Match* rhs2 = nullptr;
+};
+
+// Match type to keep track of associated values.
+struct MappingMatch : public Match {
+ // The associated id or value.
+ int64 id;
+};
+
+// Match type to keep track of assertions.
+struct AssertionMatch : public Match {
+ // If true, the assertion is negative and will be valid if the input doesn't
+ // match.
+ bool negative;
+};
+
+// Match type to define exclusions.
+struct ExclusionMatch : public Match {
+ // The nonterminal that denotes matches to exclude from a successful match.
+ // So the match is only valid if there is no match of `exclusion_nonterm`
+ // spanning the same text range.
+ Nonterm exclusion_nonterm;
+};
+
+// Match to represent an annotator annotated span in the grammar.
+struct AnnotationMatch : public Match {
+ const ClassificationResult* annotation;
+};
+
+// Utility functions for parse tree traversal.
+
+// Does a preorder traversal, calling `node_fn` on each node.
+// `node_fn` is expected to return whether to continue expanding a node.
+void Traverse(const Match* root,
+ const std::function<bool(const Match*)>& node_fn);
+
+// Does a preorder traversal, calling `pred_fn` and returns the first node
+// on which `pred_fn` returns true.
+const Match* SelectFirst(const Match* root,
+ const std::function<bool(const Match*)>& pred_fn);
+
+// Does a preorder traversal, selecting all nodes where `pred_fn` returns true.
+std::vector<const Match*> SelectAll(
+ const Match* root, const std::function<bool(const Match*)>& pred_fn);
+
+// Selects all terminals from a parse tree.
+inline std::vector<const Match*> SelectTerminals(const Match* root) {
+ return SelectAll(root, &Match::IsTerminalRule);
+}
+
+// Selects all leaves from a parse tree.
+inline std::vector<const Match*> SelectLeaves(const Match* root) {
+ return SelectAll(root, &Match::IsLeaf);
+}
+
+// Retrieves the first child node of a given type.
+template <typename T>
+const T* SelectFirstOfType(const Match* root, const int16 type) {
+ return static_cast<const T*>(SelectFirst(
+ root, [type](const Match* node) { return node->type == type; }));
+}
+
+// Retrieves all nodes of a given type.
+template <typename T>
+const std::vector<const T*> SelectAllOfType(const Match* root,
+ const int16 type) {
+ std::vector<const T*> result;
+ Traverse(root, [&result, type](const Match* node) {
+ if (node->type == type) {
+ result.push_back(static_cast<const T*>(node));
+ }
+ return true;
+ });
+ return result;
+}
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCH_H_
diff --git a/native/utils/grammar/matcher.cc b/native/utils/grammar/matcher.cc
new file mode 100644
index 0000000..a8ebba5
--- /dev/null
+++ b/native/utils/grammar/matcher.cc
@@ -0,0 +1,512 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/matcher.h"
+
+#include <iostream>
+#include <limits>
+
+#include "utils/base/endian.h"
+#include "utils/base/logging.h"
+#include "utils/base/macros.h"
+#include "utils/grammar/types.h"
+#include "utils/strings/utf8.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+// Iterator that just enumerates the bytes in a utf8 text.
+struct ByteIterator {
+ explicit ByteIterator(StringPiece text)
+ : data(text.data()), end(text.data() + text.size()) {}
+
+ // Returns the next byte and advances. Precondition: HasNext().
+ inline char Next() {
+ TC3_DCHECK(HasNext());
+ const char c = data[0];
+ data++;
+ return c;
+ }
+ inline bool HasNext() const { return data < end; }
+
+ // Non-owning window [data, end) into the input text.
+ const char* data;
+ const char* end;
+};
+
+// Iterator that lowercases a utf8 string on the fly and enumerates the bytes.
+// One input codepoint at a time is lower-cased into `buffer` and then drained
+// byte by byte.
+struct LowercasingByteIterator {
+ LowercasingByteIterator(const UniLib* unilib, StringPiece text)
+ : unilib(*unilib),
+ data(text.data()),
+ end(text.data() + text.size()),
+ buffer_pos(0),
+ buffer_size(0) {}
+
+ inline char Next() {
+ // Queue next character.
+ if (buffer_pos >= buffer_size) {
+ buffer_pos = 0;
+ // Lower-case the next character.
+ buffer_size =
+ ValidRuneToChar(unilib.ToLower(ValidCharToRune(data)), buffer);
+ // NOTE(review): `data` is advanced by the byte length of the
+ // *lower-cased* codepoint; this assumes lower-casing preserves the
+ // UTF-8 encoded length of the input character -- confirm.
+ data += buffer_size;
+ }
+ TC3_DCHECK_LT(buffer_pos, buffer_size);
+ return buffer[buffer_pos++];
+ }
+
+ inline bool HasNext() const {
+ // Either we are not at the end of the data or didn't consume all bytes of
+ // the current character.
+ return (data < end || buffer_pos < buffer_size);
+ }
+
+ const UniLib& unilib;
+ // Non-owning window [data, end) into the input text.
+ const char* data;
+ const char* end;
+
+ // Each unicode codepoint can have up to 4 utf8 encoding bytes.
+ char buffer[4];
+ int buffer_pos;
+ int buffer_size;
+};
+
+// Searches a terminal match within a sorted table of terminals.
+// Using `LowercasingByteIterator` allows to lower-case the query string on the
+// fly.
+// `offsets` holds `num_terminals` offsets into the `strings` pool, sorted so
+// that the referenced strings are in lexicographic (byte) order. On success,
+// returns a pointer to the matched pool string and sets `*terminal_index` to
+// its index in `offsets`; returns nullptr otherwise.
+template <typename T>
+const char* FindTerminal(T input_iterator, const char* strings,
+ const uint32* offsets, const int num_terminals,
+ int* terminal_index) {
+ int left = 0;
+ int right = num_terminals;
+ int span_size = right - left;
+ int match_length = 0;
+
+ // Loop invariant:
+ // At the ith iteration, all strings in the range `left` ... `right` match the
+ // input on the first `match_length` characters.
+ while (true) {
+ const unsigned char c =
+ static_cast<const unsigned char>(input_iterator.Next());
+
+ // We find the possible range of strings in `left` ... `right` matching the
+ // `match_length` + 1 character with two binary searches:
+ // 1) `lower_bound` to find the start of the range of matching strings.
+ // 2) `upper_bound` to find the non-inclusive end of the range.
+ // NOTE(review): the comparators apply LittleEndian::ToHost32 to the single
+ // query byte `c` (a no-op for values < 256 on little-endian hosts) but use
+ // `string_offset` directly as an index, while the post-loop code below
+ // applies ToHost32 to offsets[left]. This endian handling looks
+ // inconsistent -- confirm intended behavior on big-endian targets.
+ left =
+ (std::lower_bound(
+ offsets + left, offsets + right, c,
+ [strings, match_length](uint32 string_offset, uint32 c) -> bool {
+ return static_cast<unsigned char>(
+ strings[string_offset + match_length]) <
+ LittleEndian::ToHost32(c);
+ }) -
+ offsets);
+ right =
+ (std::upper_bound(
+ offsets + left, offsets + right, c,
+ [strings, match_length](uint32 c, uint32 string_offset) -> bool {
+ return LittleEndian::ToHost32(c) <
+ static_cast<unsigned char>(
+ strings[string_offset + match_length]);
+ }) -
+ offsets);
+ span_size = right - left;
+ if (span_size <= 0) {
+ return nullptr;
+ }
+ ++match_length;
+
+ // By the loop variant and due to the fact that the strings are sorted,
+ // a matching string will be at `left` now.
+ if (!input_iterator.HasNext()) {
+ const int string_offset = LittleEndian::ToHost32(offsets[left]);
+ // Match only if the candidate string also ends here (NUL terminator).
+ if (strings[string_offset + match_length] == 0) {
+ *terminal_index = left;
+ return &strings[string_offset];
+ }
+ return nullptr;
+ }
+ }
+
+ // No match found. (Unreachable: the loop above only exits via return.)
+ return nullptr;
+}
+
+// Finds terminal matches in the terminal rules hash tables.
+// In case a match is found, `terminal` will be set to point into the
+// terminals string pool.
+// Returns the set of left-hand sides triggered by the terminal, or nullptr if
+// the terminal is not part of `terminal_rules` (including when its length is
+// outside the precomputed [min, max] terminal length bounds).
+template <typename T>
+const RulesSet_::LhsSet* FindTerminalMatches(
+ T input_iterator, const RulesSet* rules_set,
+ const RulesSet_::Rules_::TerminalRulesMap* terminal_rules,
+ StringPiece* terminal) {
+ const int terminal_size = terminal->size();
+ if (terminal_size < terminal_rules->min_terminal_length() ||
+ terminal_size > terminal_rules->max_terminal_length()) {
+ return nullptr;
+ }
+ int terminal_index;
+ if (const char* terminal_match = FindTerminal(
+ input_iterator, rules_set->terminals()->data(),
+ terminal_rules->terminal_offsets()->data(),
+ terminal_rules->terminal_offsets()->size(), &terminal_index)) {
+ // Re-point `terminal` at the stable copy in the rules string pool.
+ *terminal = StringPiece(terminal_match, terminal->length());
+ return rules_set->lhs_set()->Get(
+ terminal_rules->lhs_set_index()->Get(terminal_index));
+ }
+ return nullptr;
+}
+
+// Finds unary rules matches.
+// Returns the lhs set keyed by `nonterminal` in the shard's unary rules map,
+// or nullptr if the shard has no unary rules or no entry for `nonterminal`.
+const RulesSet_::LhsSet* FindUnaryRulesMatches(const RulesSet* rules_set,
+ const RulesSet_::Rules* rules,
+ const Nonterm nonterminal) {
+ if (!rules->unary_rules()) {
+ return nullptr;
+ }
+ if (const RulesSet_::Rules_::UnaryRulesEntry* entry =
+ rules->unary_rules()->LookupByKey(nonterminal)) {
+ return rules_set->lhs_set()->Get(entry->value());
+ }
+ return nullptr;
+}
+
+// Finds binary rules matches.
+// Looks up the (rhs_first, rhs_second) pair in the shard's open-addressed
+// hash table of binary rules and scans the bucket chain for an exact match.
+const RulesSet_::LhsSet* FindBinaryRulesMatches(
+ const RulesSet* rules_set, const RulesSet_::Rules* rules,
+ const TwoNonterms nonterminals) {
+ if (!rules->binary_rules()) {
+ return nullptr;
+ }
+
+ // Lookup in rules hash table.
+ const uint32 bucket_index =
+ BinaryRuleHasher()(nonterminals) % rules->binary_rules()->size();
+
+ // Get hash table bucket.
+ if (const RulesSet_::Rules_::BinaryRuleTableBucket* bucket =
+ rules->binary_rules()->Get(bucket_index)) {
+ if (bucket->rules() == nullptr) {
+ return nullptr;
+ }
+
+ // Check all entries in the chain.
+ for (const RulesSet_::Rules_::BinaryRule* rule : *bucket->rules()) {
+ if (rule->rhs_first() == nonterminals.first &&
+ rule->rhs_second() == nonterminals.second) {
+ return rules_set->lhs_set()->Get(rule->lhs_set_index());
+ }
+ }
+ }
+
+ return nullptr;
+}
+
+// Decodes one lhs set entry.
+// Positive entries directly encode a plain nonterminal (no callback, no
+// whitespace-gap restriction); non-positive entries are negated indices into
+// the rules' lhs table, which carries callback id/param and the gap limit.
+inline void GetLhs(const RulesSet* rules_set, const int lhs_entry,
+ Nonterm* nonterminal, CallbackId* callback, uint64* param,
+ int8* max_whitespace_gap) {
+ if (lhs_entry > 0) {
+ // Direct encoding of the nonterminal.
+ *nonterminal = lhs_entry;
+ *callback = kNoCallback;
+ *param = 0;
+ // -1 means "no limit" on the whitespace gap.
+ *max_whitespace_gap = -1;
+ } else {
+ const RulesSet_::Lhs* lhs = rules_set->lhs()->Get(-lhs_entry);
+ *nonterminal = lhs->nonterminal();
+ *callback = lhs->callback_id();
+ *param = lhs->callback_param();
+ *max_whitespace_gap = lhs->max_whitespace_gap();
+ }
+}
+
+} // namespace
+
+// Resets the matcher to a pristine state: frees all match memory, drops
+// pending work and clears the chart.
+void Matcher::Reset() {
+ state_ = STATE_DEFAULT;
+ arena_.Reset();
+ pending_items_ = nullptr;
+ pending_exclusion_items_ = nullptr;
+ std::fill(chart_.begin(), chart_.end(), nullptr);
+ // Sentinel so that any incoming end position passes the ordering checks.
+ last_end_ = std::numeric_limits<int>().lowest();
+}
+
+// Finalizes matching: flushes exclusion matches that were deferred until all
+// input up to their end position had been seen.
+void Matcher::Finish() {
+ // Check any pending items.
+ ProcessPendingExclusionMatches();
+}
+
+// Prepends `item` to the singly-linked list of matches awaiting chart
+// insertion (processed LIFO by ProcessPendingSet).
+void Matcher::QueueForProcessing(Match* item) {
+ // Push element to the front.
+ item->next = pending_items_;
+ pending_items_ = item;
+}
+
+// Prepends `item` to the list of exclusion matches whose condition can only
+// be verified once all matches up to its end position were added.
+void Matcher::QueueForPostCheck(ExclusionMatch* item) {
+ // Push element to the front.
+ item->next = pending_exclusion_items_;
+ pending_exclusion_items_ = item;
+}
+
+// Feeds one terminal token to the matcher; tries case-sensitive and
+// case-insensitive terminal rules of every active shard, then drains the
+// resulting pending matches. Tokens must arrive with non-decreasing `end`.
+void Matcher::AddTerminal(const CodepointSpan codepoint_span,
+ const int match_offset, StringPiece terminal) {
+ TC3_CHECK_GE(codepoint_span.second, last_end_);
+
+ // Finish any pending post-checks.
+ if (codepoint_span.second > last_end_) {
+ ProcessPendingExclusionMatches();
+ }
+
+ last_end_ = codepoint_span.second;
+ for (const RulesSet_::Rules* shard : rules_shards_) {
+ // Try case-sensitive matches.
+ if (const RulesSet_::LhsSet* lhs_set =
+ FindTerminalMatches(ByteIterator(terminal), rules_,
+ shard->terminal_rules(), &terminal)) {
+ // `terminal` points now into the rules string pool, providing a
+ // stable reference.
+ ExecuteLhsSet(
+ codepoint_span, match_offset,
+ /*whitespace_gap=*/(codepoint_span.first - match_offset),
+ [terminal](Match* match) {
+ match->terminal = terminal.data();
+ match->rhs2 = nullptr;
+ },
+ lhs_set, delegate_);
+ }
+
+ // Try case-insensitive matches.
+ if (const RulesSet_::LhsSet* lhs_set = FindTerminalMatches(
+ LowercasingByteIterator(&unilib_, terminal), rules_,
+ shard->lowercase_terminal_rules(), &terminal)) {
+ // `terminal` points now into the rules string pool, providing a
+ // stable reference.
+ ExecuteLhsSet(
+ codepoint_span, match_offset,
+ /*whitespace_gap=*/(codepoint_span.first - match_offset),
+ [terminal](Match* match) {
+ match->terminal = terminal.data();
+ match->rhs2 = nullptr;
+ },
+ lhs_set, delegate_);
+ }
+ }
+ ProcessPendingSet();
+}
+
+// Adds a nonterminal match (e.g. from the lexer or a filter callback) and
+// immediately processes its consequences. Same left-to-right ordering
+// contract as AddTerminal.
+void Matcher::AddMatch(Match* match) {
+ TC3_CHECK_GE(match->codepoint_span.second, last_end_);
+
+ // Finish any pending post-checks.
+ if (match->codepoint_span.second > last_end_) {
+ ProcessPendingExclusionMatches();
+ }
+
+ last_end_ = match->codepoint_span.second;
+ QueueForProcessing(match);
+ ProcessPendingSet();
+}
+
+// Instantiates matches for every lhs in `lhs_set`: applies the whitespace-gap
+// filter, handles the built-in default callbacks (type/assertion/mapping/
+// exclusion), routes filter callbacks to the delegate without queuing, and
+// otherwise queues one match per distinct lhs, notifying output callbacks.
+// `initializer` fills in the rule-specific fields (terminal or rhs links).
+void Matcher::ExecuteLhsSet(const CodepointSpan codepoint_span,
+ const int match_offset_bytes,
+ const int whitespace_gap,
+ const std::function<void(Match*)>& initializer,
+ const RulesSet_::LhsSet* lhs_set,
+ CallbackDelegate* delegate) {
+ TC3_CHECK(lhs_set);
+ Match* match = nullptr;
+ Nonterm prev_lhs = kUnassignedNonterm;
+ for (const int32 lhs_entry : *lhs_set->lhs()) {
+ Nonterm lhs;
+ CallbackId callback_id;
+ uint64 callback_param;
+ int8 max_whitespace_gap;
+ GetLhs(rules_, lhs_entry, &lhs, &callback_id, &callback_param,
+ &max_whitespace_gap);
+
+ // Check that the allowed whitespace gap limit is followed.
+ if (max_whitespace_gap >= 0 && whitespace_gap > max_whitespace_gap) {
+ continue;
+ }
+
+ // Handle default callbacks.
+ switch (static_cast<DefaultCallback>(callback_id)) {
+ case DefaultCallback::kSetType: {
+ Match* typed_match = AllocateAndInitMatch<Match>(lhs, codepoint_span,
+ match_offset_bytes);
+ initializer(typed_match);
+ typed_match->type = callback_param;
+ QueueForProcessing(typed_match);
+ continue;
+ }
+ case DefaultCallback::kAssertion: {
+ AssertionMatch* assertion_match = AllocateAndInitMatch<AssertionMatch>(
+ lhs, codepoint_span, match_offset_bytes);
+ initializer(assertion_match);
+ assertion_match->type = Match::kAssertionMatch;
+ // Non-zero param marks the assertion as negative.
+ assertion_match->negative = (callback_param != 0);
+ QueueForProcessing(assertion_match);
+ continue;
+ }
+ case DefaultCallback::kMapping: {
+ MappingMatch* mapping_match = AllocateAndInitMatch<MappingMatch>(
+ lhs, codepoint_span, match_offset_bytes);
+ initializer(mapping_match);
+ mapping_match->type = Match::kMappingMatch;
+ mapping_match->id = callback_param;
+ QueueForProcessing(mapping_match);
+ continue;
+ }
+ case DefaultCallback::kExclusion: {
+ // We can only check the exclusion once all matches up to this position
+ // have been processed. Schedule and post check later.
+ ExclusionMatch* exclusion_match = AllocateAndInitMatch<ExclusionMatch>(
+ lhs, codepoint_span, match_offset_bytes);
+ initializer(exclusion_match);
+ exclusion_match->exclusion_nonterm = callback_param;
+ QueueForPostCheck(exclusion_match);
+ continue;
+ }
+ default:
+ break;
+ }
+
+ if (callback_id != kNoCallback && rules_->callback() != nullptr) {
+ const RulesSet_::CallbackEntry* callback_info =
+ rules_->callback()->LookupByKey(callback_id);
+ if (callback_info && callback_info->value().is_filter()) {
+ // Filter callback: the candidate lives on the stack; the delegate
+ // decides whether to materialize it via AddMatch().
+ Match candidate;
+ candidate.Init(lhs, codepoint_span, match_offset_bytes);
+ initializer(&candidate);
+ delegate->MatchFound(&candidate, callback_id, callback_param, this);
+ continue;
+ }
+ }
+
+ // Consecutive entries with the same lhs share a single match object;
+ // relies on lhs sets grouping identical nonterminals together.
+ if (prev_lhs != lhs) {
+ prev_lhs = lhs;
+ match =
+ AllocateAndInitMatch<Match>(lhs, codepoint_span, match_offset_bytes);
+ initializer(match);
+ QueueForProcessing(match);
+ }
+
+ if (callback_id != kNoCallback) {
+ // This is an output callback.
+ delegate->MatchFound(match, callback_id, callback_param, this);
+ }
+ }
+}
+
+// Drains the pending queue: inserts each match into the chart and fires the
+// unary and binary rules it completes, which may queue further matches.
+void Matcher::ProcessPendingSet() {
+ // Avoid recursion caused by:
+ // ProcessPendingSet --> callback --> AddMatch --> ProcessPendingSet --> ...
+ if (state_ == STATE_PROCESSING) {
+ return;
+ }
+ state_ = STATE_PROCESSING;
+ while (pending_items_) {
+ // Process.
+ Match* item = pending_items_;
+ pending_items_ = pending_items_->next;
+
+ // Add it to the chart (bucketed by end position; chains are in
+ // decreasing `end` order because ends arrive non-decreasing and we
+ // prepend).
+ item->next = chart_[item->codepoint_span.second & kChartHashTableBitmask];
+ chart_[item->codepoint_span.second & kChartHashTableBitmask] = item;
+
+ // Check unary rules that trigger.
+ for (const RulesSet_::Rules* shard : rules_shards_) {
+ if (const RulesSet_::LhsSet* lhs_set =
+ FindUnaryRulesMatches(rules_, shard, item->lhs)) {
+ ExecuteLhsSet(
+ item->codepoint_span, item->match_offset,
+ /*whitespace_gap=*/
+ (item->codepoint_span.first - item->match_offset),
+ [item](Match* match) {
+ match->rhs1 = nullptr;
+ match->rhs2 = item;
+ },
+ lhs_set, delegate_);
+ }
+ }
+
+ // Check binary rules that trigger.
+ // Lookup by begin.
+ Match* prev = chart_[item->match_offset & kChartHashTableBitmask];
+ // The chain of items is in decreasing `end` order.
+ // Find the ones that have prev->end == item->begin.
+ while (prev != nullptr &&
+ (prev->codepoint_span.second > item->match_offset)) {
+ prev = prev->next;
+ }
+ for (;
+ prev != nullptr && (prev->codepoint_span.second == item->match_offset);
+ prev = prev->next) {
+ for (const RulesSet_::Rules* shard : rules_shards_) {
+ if (const RulesSet_::LhsSet* lhs_set =
+ FindBinaryRulesMatches(rules_, shard, {prev->lhs, item->lhs})) {
+ ExecuteLhsSet(
+ /*codepoint_span=*/
+ {prev->codepoint_span.first, item->codepoint_span.second},
+ prev->match_offset,
+ /*whitespace_gap=*/
+ (item->codepoint_span.first -
+ item->match_offset), // Whitespace gap is the gap
+ // between the two parts.
+ [prev, item](Match* match) {
+ match->rhs1 = prev;
+ match->rhs2 = item;
+ },
+ lhs_set, delegate_);
+ }
+ }
+ }
+ }
+ state_ = STATE_DEFAULT;
+}
+
+// Replays deferred exclusion matches: each one is materialized via AddMatch
+// only if the excluded nonterminal does NOT cover the same span.
+void Matcher::ProcessPendingExclusionMatches() {
+ while (pending_exclusion_items_) {
+ ExclusionMatch* item = pending_exclusion_items_;
+ pending_exclusion_items_ = static_cast<ExclusionMatch*>(item->next);
+
+ // Check that the exclusion condition is fulfilled.
+ if (!ContainsMatch(item->exclusion_nonterm, item->codepoint_span)) {
+ AddMatch(item);
+ }
+ }
+}
+
+// Returns whether the chart already holds a match for `nonterm` with exactly
+// the given span. Walks the end-position bucket, skipping larger ends, then
+// scans the run of equal ends for a begin/lhs match.
+bool Matcher::ContainsMatch(const Nonterm nonterm,
+ const CodepointSpan& span) const {
+ // Lookup by end.
+ Match* match = chart_[span.second & kChartHashTableBitmask];
+ // The chain of items is in decreasing `end` order.
+ while (match != nullptr && match->codepoint_span.second > span.second) {
+ match = match->next;
+ }
+ while (match != nullptr && match->codepoint_span.second == span.second) {
+ if (match->lhs == nonterm && match->codepoint_span.first == span.first) {
+ return true;
+ }
+ match = match->next;
+ }
+ return false;
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/matcher.h b/native/utils/grammar/matcher.h
new file mode 100644
index 0000000..47bac43
--- /dev/null
+++ b/native/utils/grammar/matcher.h
@@ -0,0 +1,246 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// A token matcher based on context-free grammars.
+//
+// A lexer passes token to the matcher: literal terminal strings and token
+// types. It passes tokens to the matcher by calling AddTerminal() and
+// AddMatch() for literal terminals and token types, respectively.
+// The lexer passes each token along with the [begin, end) position range
+// in which it occurs. So for an input string "Groundhog February 2, 2007", the
+// lexer would tell the matcher that:
+//
+// "Groundhog" occurs at [0, 9)
+// <space> occurs at [9, 10)
+// "February" occurs at [10, 18)
+// <space> occurs at [18, 19)
+// <string_of_digits> occurs at [19, 20)
+// "," occurs at [20, 21)
+// <space> occurs at [21, 22)
+// <string_of_digits> occurs at [22, 26)
+//
+// The lexer passes tokens to the matcher by calling AddTerminal() and
+// AddMatch() for literal terminals and token types, respectively.
+//
+// Although it is unnecessary for this example grammar, a lexer can
+// output multiple tokens for the same input range. So our lexer could
+// additionally output:
+// "2" occurs at [19, 20) // a second token for [19, 20)
+// "2007" occurs at [22, 26)
+// <syllable> occurs at [0, 6) // overlaps with (Groundhog [0, 9))
+// <syllable> occurs at [6, 9)
+// The only constraint on the lexer's output is that it has to pass tokens
+// to the matcher in left-to-right order, strictly speaking, their "end"
+// positions must be nondecreasing. (This constraint allows a more
+// efficient matching algorithm.) The "begin" positions can be in any
+// order.
+//
+// There are two kinds of supported callbacks:
+// (1) OUTPUT: Callbacks are the only output mechanism a matcher has. For each
+// "top-level" rule in your grammar, like the rule for <date> above -- something
+// you're trying to find instances of -- you use a callback which the matcher
+// will invoke every time it finds an instance of <date>.
+// (2) FILTERS:
+// Callbacks allow you to put extra conditions on when a grammar rule
+// applies. In the example grammar, the rule
+//
+// <day> ::= <string_of_digits> // must be between 1 and 31
+//
+// should only apply for *some* <string_of_digits> tokens, not others.
+// By using a filter callback on this rule, you can tell the matcher that
+// an instance of the rule's RHS is only *sometimes* considered an
+// instance of its LHS. The filter callback will get invoked whenever
+// the matcher finds an instance of <string_of_digits>. The callback can
+// look at the digits and decide whether they represent a number between
+// 1 and 31. If so, the callback calls Matcher::AddMatch() to tell the
+// matcher there's a <day> there. If not, the callback simply exits
+// without calling AddMatch().
+//
+// Technically, a FILTER callback can make any number of calls to
+// AddMatch() or even AddTerminal(). But the expected usage is to just
+// make zero or one call to AddMatch(). OUTPUT callbacks are not expected
+// to call either of these -- output callbacks are invoked merely as a
+// side-effect, not in order to decide whether a rule applies or not.
+//
+// In the above example, you would probably use three callbacks. Filter
+// callbacks on the rules for <day> and <year> would check the numeric
+// value of the <string_of_digits>. An output callback on the rule for
+// <date> would simply increment the counter of dates found on the page.
+//
+// Note that callbacks are attached to rules, not to nonterminals. You
+// could have two alternative rules for <date> and use a different
+// callback for each one.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCHER_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCHER_H_
+
+#include <array>
+#include <functional>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/base/arena.h"
+#include "utils/grammar/callback-delegate.h"
+#include "utils/grammar/match.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3::grammar {
+
+// Chart-based matcher for the compiled grammar rules. See the file comment
+// above for the token feeding protocol and callback semantics.
+class Matcher {
+ public:
+ // `rules_shards` restricts matching to the given shards (e.g. the ones
+ // matching the current locale). None of the pointers are owned; `rules`
+ // and the shards must outlive the matcher.
+ explicit Matcher(const UniLib* unilib, const RulesSet* rules,
+ const std::vector<const RulesSet_::Rules*> rules_shards,
+ CallbackDelegate* delegate)
+ : state_(STATE_DEFAULT),
+ unilib_(*unilib), // NOTE(review): copies the UniLib instance.
+ arena_(kBlocksize),
+ rules_(rules),
+ rules_shards_(rules_shards),
+ delegate_(delegate) {
+ TC3_CHECK(rules_ != nullptr);
+ Reset();
+ }
+ // Convenience overload that activates all shards of `rules`.
+ explicit Matcher(const UniLib* unilib, const RulesSet* rules,
+ CallbackDelegate* delegate)
+ : Matcher(unilib, rules, {}, delegate) {
+ rules_shards_.reserve(rules->rules()->size());
+ rules_shards_.insert(rules_shards_.end(), rules->rules()->begin(),
+ rules->rules()->end());
+ }
+
+ // Resets the matcher.
+ void Reset();
+
+ // Finish the matching.
+ void Finish();
+
+ // Tells the matcher that the given terminal was found occupying position
+ // range [begin, end) in the input.
+ // The matcher may invoke callback functions before returning, if this
+ // terminal triggers any new matches for rules in the grammar.
+ // Calls to AddTerminal() and AddMatch() must be in left-to-right order,
+ // that is, the sequence of `end` values must be non-decreasing.
+ void AddTerminal(const CodepointSpan codepoint_span, const int match_offset,
+ StringPiece terminal);
+ void AddTerminal(const CodepointIndex begin, const CodepointIndex end,
+ StringPiece terminal) {
+ AddTerminal(CodepointSpan{begin, end}, begin, terminal);
+ }
+
+ // Adds a nonterminal match to the chart.
+ // This can be invoked by the lexer if the lexer needs to add nonterminals to
+ // the chart.
+ void AddMatch(Match* match);
+
+ // Allocates memory from an area for a new match.
+ // The `size` parameter is there to allow subclassing of the match object
+ // with additional fields. Memory is reclaimed only on Reset().
+ Match* AllocateMatch(const size_t size) {
+ return reinterpret_cast<Match*>(arena_.Alloc(size));
+ }
+
+ template <typename T>
+ T* AllocateMatch() {
+ return reinterpret_cast<T*>(arena_.Alloc(sizeof(T)));
+ }
+
+ // Allocates a match of type `T` and forwards `args` to its Init().
+ template <typename T, typename... Args>
+ T* AllocateAndInitMatch(Args... args) {
+ T* match = AllocateMatch<T>();
+ match->Init(args...);
+ return match;
+ }
+
+ // Returns the current number of bytes allocated for all match objects.
+ size_t ArenaSize() const { return arena_.status().bytes_allocated(); }
+
+ private:
+ static constexpr int kBlocksize = 16 << 10;
+
+ // The state of the matcher.
+ enum State {
+ // The matcher is in the default state.
+ STATE_DEFAULT = 0,
+
+ // The matcher is currently processing queued match items.
+ STATE_PROCESSING = 1,
+ };
+ State state_;
+
+ // Process matches from lhs set.
+ void ExecuteLhsSet(const CodepointSpan codepoint_span, const int match_offset,
+ const int whitespace_gap,
+ const std::function<void(Match*)>& initializer,
+ const RulesSet_::LhsSet* lhs_set,
+ CallbackDelegate* delegate);
+
+ // Queues a newly created match item.
+ void QueueForProcessing(Match* item);
+
+ // Queues a match item for later post checking of the exclusion condition.
+ // For exclusions we need to check that the `item->excluded_nonterminal`
+ // doesn't match the same span. As we cannot know which matches have already
+ // been added, we queue the item for later post checking - once all matches
+ // up to `item->codepoint_span.second` have been added.
+ void QueueForPostCheck(ExclusionMatch* item);
+
+ // Adds pending items to the chart, possibly generating new matches as a
+ // result.
+ void ProcessPendingSet();
+
+ // Returns whether the chart contains a match for a given nonterminal.
+ bool ContainsMatch(const Nonterm nonterm, const CodepointSpan& span) const;
+
+ // Checks all pending exclusion matches that their exclusion condition is
+ // fulfilled.
+ void ProcessPendingExclusionMatches();
+
+ UniLib unilib_;
+
+ // Memory arena for match allocation.
+ UnsafeArena arena_;
+
+ // The end position of the most recent match or terminal, for sanity
+ // checking.
+ int last_end_;
+
+ // Rules.
+ const RulesSet* rules_;
+
+ // The set of items pending to be added to the chart as a singly-linked list.
+ Match* pending_items_;
+
+ // The set of items pending to be post-checked as a singly-linked list.
+ ExclusionMatch* pending_exclusion_items_;
+
+ // The chart data structure: a hashtable containing all matches, indexed by
+ // their end positions.
+ static constexpr int kChartHashTableNumBuckets = 1 << 8;
+ static constexpr int kChartHashTableBitmask = kChartHashTableNumBuckets - 1;
+ std::array<Match*, kChartHashTableNumBuckets> chart_;
+
+ // The active rule shards.
+ std::vector<const RulesSet_::Rules*> rules_shards_;
+
+ // The callback handler.
+ CallbackDelegate* delegate_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCHER_H_
diff --git a/native/utils/grammar/rules-utils.cc b/native/utils/grammar/rules-utils.cc
new file mode 100644
index 0000000..56c928a
--- /dev/null
+++ b/native/utils/grammar/rules-utils.cc
@@ -0,0 +1,123 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/rules-utils.h"
+
+namespace libtextclassifier3::grammar {
+
+// Parses the locales of each rules shard; result is indexed parallel to
+// rules->rules(). Shards without a locale entry get an empty list (meaning
+// "applies to all locales" in SelectLocaleMatchingShards).
+std::vector<std::vector<Locale>> ParseRulesLocales(const RulesSet* rules) {
+ if (rules == nullptr || rules->rules() == nullptr) {
+ return {};
+ }
+ std::vector<std::vector<Locale>> locales(rules->rules()->size());
+ for (int i = 0; i < rules->rules()->size(); i++) {
+ const grammar::RulesSet_::Rules* rules_shard = rules->rules()->Get(i);
+ if (rules_shard->locale() == nullptr) {
+ continue;
+ }
+ for (const LanguageTag* tag : *rules_shard->locale()) {
+ locales[i].push_back(Locale::FromLanguageTag(tag));
+ }
+ }
+ return locales;
+}
+
+// Selects the shards of `rules` whose locales (precomputed in
+// `shard_locales`, parallel to rules->rules()) match any of `locales`.
+// Shards with an empty locale list are always selected.
+std::vector<const grammar::RulesSet_::Rules*> SelectLocaleMatchingShards(
+ const RulesSet* rules,
+ const std::vector<std::vector<Locale>>& shard_locales,
+ const std::vector<Locale>& locales) {
+ std::vector<const grammar::RulesSet_::Rules*> shards;
+ if (rules->rules() == nullptr) {
+ return shards;
+ }
+ for (int i = 0; i < shard_locales.size(); i++) {
+ if (shard_locales[i].empty() ||
+ Locale::IsAnyLocaleSupported(locales,
+ /*supported_locales=*/shard_locales[i],
+ /*default_value=*/false)) {
+ shards.push_back(rules->rules()->Get(i));
+ }
+ }
+ return shards;
+}
+
+// Deduplicates derivations by containing overlap, per rule id: a candidate is
+// dropped if an earlier candidate of the same rule completely contains its
+// span. Relies on the sort below so that a container always precedes the
+// contained candidate.
+std::vector<Derivation> DeduplicateDerivations(
+ const std::vector<Derivation>& derivations) {
+ std::vector<Derivation> sorted_candidates = derivations;
+ std::stable_sort(
+ sorted_candidates.begin(), sorted_candidates.end(),
+ [](const Derivation& a, const Derivation& b) {
+ // Sort by id.
+ if (a.rule_id != b.rule_id) {
+ return a.rule_id < b.rule_id;
+ }
+
+ // Sort by increasing start.
+ if (a.match->codepoint_span.first != b.match->codepoint_span.first) {
+ return a.match->codepoint_span.first < b.match->codepoint_span.first;
+ }
+
+ // Sort by decreasing end.
+ return a.match->codepoint_span.second > b.match->codepoint_span.second;
+ });
+
+ // Deduplicate by overlap.
+ for (int i = 0; i < sorted_candidates.size(); i++) {
+ const Derivation& candidate = sorted_candidates[i];
+ bool eliminated = false;
+
+ // Due to the sorting above, the candidate can only be completely
+ // intersected by a match before it in the sorted order.
+ for (int j = i - 1; j >= 0; j--) {
+ // Only candidates of the same rule can eliminate each other.
+ if (sorted_candidates[j].rule_id != candidate.rule_id) {
+ break;
+ }
+ if (sorted_candidates[j].match->codepoint_span.first <=
+ candidate.match->codepoint_span.first &&
+ sorted_candidates[j].match->codepoint_span.second >=
+ candidate.match->codepoint_span.second) {
+ eliminated = true;
+ break;
+ }
+ }
+
+ if (!eliminated) {
+ result.push_back(candidate);
+ }
+ }
+ return result;
+}
+
+// Checks that all assertions of a match tree are fulfilled: traverses the
+// tree and fails on the first negative assertion node encountered. Traversal
+// short-circuits once `result` is false.
+bool VerifyAssertions(const Match* match) {
+ bool result = true;
+ grammar::Traverse(match, [&result](const Match* node) {
+ if (node->type != Match::kAssertionMatch) {
+ // Only validation if all checks so far passed.
+ return result;
+ }
+
+ // Positive assertions are by definition fulfilled,
+ // fail if the assertion is negative.
+ if (static_cast<const AssertionMatch*>(node)->negative) {
+ result = false;
+ }
+ return result;
+ });
+ return result;
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/rules-utils.h b/native/utils/grammar/rules-utils.h
new file mode 100644
index 0000000..e6ac541
--- /dev/null
+++ b/native/utils/grammar/rules-utils.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Auxiliary methods for using rules.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_RULES_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_RULES_UTILS_H_
+
+#include <unordered_map>
+#include <vector>
+
+#include "utils/grammar/match.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/i18n/locale.h"
+
+namespace libtextclassifier3::grammar {
+
+// Parses the locales of each rules shard.
+// The result is indexed parallel to rules->rules().
+std::vector<std::vector<Locale>> ParseRulesLocales(const RulesSet* rules);
+
+// Selects rules shards that match on any locale.
+// Shards without locale restrictions are always selected.
+std::vector<const grammar::RulesSet_::Rules*> SelectLocaleMatchingShards(
+ const RulesSet* rules,
+ const std::vector<std::vector<Locale>>& shard_locales,
+ const std::vector<Locale>& locales);
+
+// Deduplicates rule derivations by containing overlap.
+// The grammar system can output multiple candidates for optional parts.
+// For example if a rule has an optional suffix, we
+// will get two rule derivations when the suffix is present: one with and one
+// without the suffix. We therefore deduplicate by containing overlap, viz. from
+// two candidates we keep the longer one if it completely contains the shorter.
+struct Derivation {
+ // Non-owning pointer to the root match of the derivation.
+ const Match* match;
+ // Identifier of the rule that produced the derivation.
+ int64 rule_id;
+};
+std::vector<Derivation> DeduplicateDerivations(
+ const std::vector<Derivation>& derivations);
+
+// Checks that all assertions of a match tree are fulfilled.
+bool VerifyAssertions(const Match* match);
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_RULES_UTILS_H_
diff --git a/native/utils/grammar/rules-utils_test.cc b/native/utils/grammar/rules-utils_test.cc
new file mode 100644
index 0000000..6391be1
--- /dev/null
+++ b/native/utils/grammar/rules-utils_test.cc
@@ -0,0 +1,85 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/rules-utils.h"
+
+#include <vector>
+
+#include "utils/grammar/match.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using testing::ElementsAre;
+using testing::Value;
+
+// Creates a match object for testing, spanning the codepoint range
+// `begin`..`end`; `begin` is also used as the match offset.
+Match CreateMatch(const CodepointIndex begin, const CodepointIndex end) {
+ Match match;
+ match.Init(0, CodepointSpan{begin, end},
+ /*arg_match_offset=*/begin);
+ return match;
+}
+
+// Matches a derivation against `candidate` by comparing both the rule id
+// and the match pointer (pointer identity, not structural equality).
+MATCHER_P(IsDerivation, candidate, "") {
+ return Value(arg.rule_id, candidate.rule_id) &&
+ Value(arg.match, candidate.match);
+}
+
+TEST(UtilsTest, DeduplicatesMatches) {
+ // Overlapping matches from the same rule: [0,1] and [1,2] are both
+ // contained in [0,2].
+ Match matches[] = {CreateMatch(0, 1), CreateMatch(1, 2), CreateMatch(0, 2)};
+ const std::vector<Derivation> candidates = {{&matches[0], /*rule_id=*/0},
+ {&matches[1], /*rule_id=*/0},
+ {&matches[2], /*rule_id=*/0}};
+
+ // Keep longest.
+ EXPECT_THAT(DeduplicateDerivations(candidates),
+ ElementsAre(IsDerivation(candidates[2])));
+}
+
+TEST(UtilsTest, DeduplicatesMatchesPerRule) {
+ // Overlapping matches from different rules.
+ Match matches[] = {CreateMatch(0, 1), CreateMatch(1, 2), CreateMatch(0, 2)};
+ // Rule 0 has three overlapping candidates; rule 1 reuses the first match.
+ const std::vector<Derivation> candidates = {{&matches[0], /*rule_id=*/0},
+ {&matches[1], /*rule_id=*/0},
+ {&matches[2], /*rule_id=*/0},
+ {&matches[0], /*rule_id=*/1}};
+
+ // Keep longest for rule 0, but also keep match from rule 1.
+ // Deduplication is per rule id: rule 1's shorter match survives even
+ // though rule 0's match contains it.
+ EXPECT_THAT(
+ DeduplicateDerivations(candidates),
+ ElementsAre(IsDerivation(candidates[2]), IsDerivation(candidates[3])));
+}
+
+TEST(UtilsTest, KeepNonoverlapping) {
+ // Non-overlapping matches: no candidate contains another.
+ Match matches[] = {CreateMatch(0, 1), CreateMatch(1, 2), CreateMatch(2, 3)};
+ const std::vector<Derivation> candidates = {{&matches[0], /*rule_id=*/0},
+ {&matches[1], /*rule_id=*/0},
+ {&matches[2], /*rule_id=*/0}};
+
+ // Keep all matches.
+ EXPECT_THAT(
+ DeduplicateDerivations(candidates),
+ ElementsAre(IsDerivation(candidates[0]), IsDerivation(candidates[1]),
+ IsDerivation(candidates[2])));
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/rules.fbs b/native/utils/grammar/rules.fbs
new file mode 100755
index 0000000..8052c11
--- /dev/null
+++ b/native/utils/grammar/rules.fbs
@@ -0,0 +1,215 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+include "utils/i18n/language-tag.fbs";
+include "utils/zlib/buffer.fbs";
+
+// The terminal rules map as sorted strings table.
+// The sorted terminal strings table is represented as offsets into the
+// global strings pool, this allows to save memory between localized
+// rules sets.
+namespace libtextclassifier3.grammar.RulesSet_.Rules_;
+table TerminalRulesMap {
+ // The offsets into the terminals pool.
+ terminal_offsets:[uint];
+
+ // The lhs set associated with a terminal rule.
+ // This is an offset into the (deduplicated) global `lhs_set` vector.
+ lhs_set_index:[uint];
+
+ // Bounds the lengths of the terminal strings for quick early lookup
+ // abort.
+ min_terminal_length:int;
+
+ max_terminal_length:int;
+}
+
+// One entry in the unary rules map.
+// `key` is the rhs nonterminal; `value` is the index of the lhs set in the
+// (deduplicated) global `lhs_set` vector.
+namespace libtextclassifier3.grammar.RulesSet_.Rules_;
+struct UnaryRulesEntry {
+ key:uint (key);
+ value:uint;
+}
+
+// One key, value pair entry in the binary rules hash map.
+// The key is a pair of nonterminals and the value the index of the lhs set.
+namespace libtextclassifier3.grammar.RulesSet_.Rules_;
+struct BinaryRule {
+ // The two rhs nonterminals.
+ rhs_first:uint;
+
+ rhs_second:uint;
+
+ // The lhs set associated with this binary rule.
+ // This is an offset into the (deduplicated) global `lhs_set` vector.
+ lhs_set_index:uint;
+}
+
+// One bucket in the binary rule hash map that contains all entries for a
+// given hash value.
+namespace libtextclassifier3.grammar.RulesSet_.Rules_;
+table BinaryRuleTableBucket {
+ rules:[BinaryRule];
+}
+
+namespace libtextclassifier3.grammar.RulesSet_;
+table Rules {
+ // The locale this rule set applies to.
+ locale:[LanguageTag];
+
+ terminal_rules:Rules_.TerminalRulesMap;
+ lowercase_terminal_rules:Rules_.TerminalRulesMap;
+
+ // The unary rules map.
+ // This is a map from a nonterminal to an lhs set index into the
+ // (deduplicated) global `lhs_set` vector.
+ unary_rules:[Rules_.UnaryRulesEntry];
+
+ // The binary rules (hash) map.
+ // This is a map from nonterminal pair to an lhs set index into the
+ // (deduplicated) global `lhs_set` vector.
+ binary_rules:[Rules_.BinaryRuleTableBucket];
+}
+
+// A set of lhs nonterminals associated with a rule match.
+// Most commonly, that is just the id of the lhs nonterminal of the rule that
+// is triggered, in this case `lhs` is set to the id of the nonterminal.
+// If a callback needs to be triggered, lhs is the (negated) index into the
+// `lhs` vector below that specifies additionally to the nonterminal, also the
+// callback and parameter to call.
+namespace libtextclassifier3.grammar.RulesSet_;
+table LhsSet {
+ lhs:[int];
+}
+
+namespace libtextclassifier3.grammar.RulesSet_;
+struct Lhs {
+ // The lhs nonterminal.
+ nonterminal:uint;
+
+ // The id of the callback to trigger.
+ callback_id:uint;
+
+ // A parameter to pass when invoking the callback.
+ callback_param:ulong;
+
+ // The maximum amount of whitespace allowed between the two parts.
+ // A value of -1 allows for unbounded whitespace.
+ max_whitespace_gap:byte;
+}
+
+namespace libtextclassifier3.grammar.RulesSet_.Nonterminals_;
+table AnnotationNtEntry {
+ key:string (key, shared);
+ value:int;
+}
+
+// Usage of pre-defined non-terminals that the lexer can generate if used by
+// the grammar.
+namespace libtextclassifier3.grammar.RulesSet_;
+table Nonterminals {
+ // Id of the nonterminal indicating the start of input.
+ start_nt:int;
+
+ // Id of the nonterminal indicating the end of input.
+ end_nt:int;
+
+ // Id of the nonterminal indicating a token.
+ token_nt:int;
+
+ // Id of the nonterminal indicating a string of digits.
+ digits_nt:int;
+
+ // `n_digits_nt[k]` is the id of the nonterminal indicating a string of
+ // `k` digits.
+ n_digits_nt:[int];
+
+ // Id of the nonterminal indicating a word or token boundary.
+ wordbreak_nt:int;
+
+ // Id of the nonterminal indicating an uppercase token.
+ uppercase_token_nt:int;
+
+ // Predefined nonterminals for annotations.
+ // Maps annotation/collection names to non-terminal ids.
+ annotation_nt:[Nonterminals_.AnnotationNtEntry];
+}
+
+// Callback information.
+namespace libtextclassifier3.grammar.RulesSet_;
+struct Callback {
+ // Whether the callback is a filter.
+ is_filter:bool;
+}
+
+// One key, value pair entry in the callback map: maps a callback id to its
+// `Callback` information.
+namespace libtextclassifier3.grammar.RulesSet_;
+struct CallbackEntry {
+ key:uint (key);
+ value:Callback;
+}
+
+namespace libtextclassifier3.grammar.RulesSet_.DebugInformation_;
+table NonterminalNamesEntry {
+ key:int (key);
+ value:string (shared);
+}
+
+// Debug information for e.g. printing parse trees and show match
+// information.
+namespace libtextclassifier3.grammar.RulesSet_;
+table DebugInformation {
+ nonterminal_names:[DebugInformation_.NonterminalNamesEntry];
+}
+
+// Regex annotators.
+namespace libtextclassifier3.grammar.RulesSet_;
+table RegexAnnotator {
+ // The pattern to run.
+ pattern:string (shared);
+
+ compressed_pattern:CompressedBuffer;
+
+ // The nonterminal to trigger.
+ nonterminal:uint;
+}
+
+// Context free grammar rules representation.
+// Rules are represented in (mostly) Chomsky Normal Form, where all rules are
+// of the following form, either:
+// * <nonterm> ::= term
+// * <nonterm> ::= <nonterm>
+// * <nonterm> ::= <nonterm> <nonterm>
+// The `terminals`, `unary_rules` and `binary_rules` maps below represent
+// these sets of rules.
+namespace libtextclassifier3.grammar;
+table RulesSet {
+ rules:[RulesSet_.Rules];
+ lhs_set:[RulesSet_.LhsSet];
+ lhs:[RulesSet_.Lhs];
+
+ // Terminals string pool.
+ // The strings are zero-byte delimited and offset indexed by
+ // `terminal_offsets` in the terminals rules map.
+ terminals:string (shared);
+
+ nonterminals:RulesSet_.Nonterminals;
+ callback:[RulesSet_.CallbackEntry];
+ debug_information:RulesSet_.DebugInformation;
+ regex_annotator:[RulesSet_.RegexAnnotator];
+
+ // If true, will compile the regexes only on first use.
+ lazy_regex_compilation:bool;
+}
+
diff --git a/native/utils/grammar/types.h b/native/utils/grammar/types.h
new file mode 100644
index 0000000..a79532b
--- /dev/null
+++ b/native/utils/grammar/types.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Common definitions used in the grammar system.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TYPES_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TYPES_H_
+
+#include <utility>
+
+#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3::grammar {
+
+// A nonterminal identifier.
+typedef uint32 Nonterm;
+
+// This special Nonterm value is never used as a real Nonterm, but used as
+// a standin of an unassigned or unspecified nonterminal.
+const Nonterm kUnassignedNonterm = 0;
+
+typedef int32 CallbackId; // `kNoCallback` is reserved for "no callback"
+enum class DefaultCallback : CallbackId {
+ kSetType = -1,
+ kAssertion = -2,
+ kMapping = -3,
+ kExclusion = -4,
+ // NOTE(review): positive while the other defaults are negative — confirm
+ // this is intentional and does not collide with real callback ids.
+ kRootRule = 1,
+};
+
+// Special CallbackId indicating that there's no callback associated with a
+// rule.
+const int32 kNoCallback = 0;
+
+// A pair of nonterminals.
+using TwoNonterms = std::pair<Nonterm, Nonterm>;
+
+// Mixes the bits of `a` so that similar inputs map to dissimilar outputs.
+static uint32 hash_int32(uint32 a) {
+ a = (a ^ 61) ^ (a >> 16);
+ a = a + (a << 3);
+ a = a ^ (a >> 4);
+ a = a * 0x27d4eb2d;
+ a = a ^ (a >> 15);
+ return a;
+}
+
+struct BinaryRuleHasher {
+ inline uint64 operator()(const TwoNonterms& x) const {
+ // the hash_int32 maps a int to a random int, then treat two ints as a
+ // rational number, then use cantor pairing function to calculate the
+ // order of rational number.
+ // Note: `t1 + t2` wraps at 32 bits before widening to uint64. The bucket
+ // layout of serialized binary-rule tables depends on this exact function,
+ // so do not change it.
+ uint32 t1 = hash_int32(x.first);
+ uint32 t2 = hash_int32(x.second);
+ uint64 t = t1 + t2;
+ t *= (t + 1);
+ t >>= 1;
+ return t + t1;
+ }
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TYPES_H_
diff --git a/native/utils/grammar/utils/ir.cc b/native/utils/grammar/utils/ir.cc
new file mode 100644
index 0000000..ce074b8
--- /dev/null
+++ b/native/utils/grammar/utils/ir.cc
@@ -0,0 +1,490 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/utils/ir.h"
+
+#include "utils/strings/append.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+// Maximum number of buckets in a serialized binary-rules hash table.
+constexpr size_t kMaxHashTableSize = 100;
+
+// Sorts a vector of pointer-like entries (e.g. unique_ptr to table objects)
+// by their `key` field so the serialized data can be binary searched.
+template <typename T>
+void SortForBinarySearchLookup(T* entries) {
+ std::sort(entries->begin(), entries->end(),
+ [](const auto& a, const auto& b) { return a->key < b->key; });
+}
+
+// Sorts a vector of value-type entries by their `key()` accessor so the
+// serialized data can be binary searched.
+template <typename T>
+void SortStructsForBinarySearchLookup(T* entries) {
+ std::sort(entries->begin(), entries->end(),
+ [](const auto& a, const auto& b) { return a.key() < b.key(); });
+}
+
+// Checks whether an IR lhs equals a serialized lhs: same nonterminal,
+// same callback id and parameter, and same whitespace-gap precondition.
+bool IsSameLhs(const Ir::Lhs& lhs, const RulesSet_::Lhs& other) {
+ return (lhs.nonterminal == other.nonterminal() &&
+ lhs.callback.id == other.callback_id() &&
+ lhs.callback.param == other.callback_param() &&
+ lhs.preconditions.max_whitespace_gap == other.max_whitespace_gap());
+}
+
+// Checks whether `lhs` matches one entry of a serialized lhs set.
+// Entries use a compact encoding: a positive value directly encodes a plain
+// nonterminal (no callback, no precondition); otherwise the negated value is
+// an index into the `candidates` (callback lhs) vector.
+bool IsSameLhsEntry(const Ir::Lhs& lhs, const int32 lhs_entry,
+ const std::vector<RulesSet_::Lhs>& candidates) {
+ // Simple case: direct encoding of the nonterminal.
+ if (lhs_entry > 0) {
+ return (lhs.nonterminal == lhs_entry && lhs.callback.id == kNoCallback &&
+ lhs.preconditions.max_whitespace_gap == -1);
+ }
+
+ // Entry is index into callback lookup.
+ return IsSameLhs(lhs, candidates[-lhs_entry]);
+}
+
+// Checks whether `lhs_set` is element-wise equal to the serialized
+// `candidate` set (entries compared via `IsSameLhsEntry`).
+bool IsSameLhsSet(const Ir::LhsSet& lhs_set,
+ const RulesSet_::LhsSetT& candidate,
+ const std::vector<RulesSet_::Lhs>& candidates) {
+ if (lhs_set.size() != candidate.lhs.size()) {
+ return false;
+ }
+
+ for (int i = 0; i < lhs_set.size(); i++) {
+ // Check that entries are the same.
+ if (!IsSameLhsEntry(lhs_set[i], candidate.lhs[i], candidates)) {
+ return false;
+ }
+ }
+
+ // All entries matched.
+ // Fix: previously this returned `false` unconditionally, which made
+ // `AddLhsSet` never reuse an existing entry and bloated the output.
+ return true;
+}
+
+// Returns a copy of `lhs_set` sorted into a canonical order, so that sets
+// containing the same entries compare equal element-wise.
+Ir::LhsSet SortedLhsSet(const Ir::LhsSet& lhs_set) {
+ Ir::LhsSet sorted_lhs = lhs_set;
+ std::sort(sorted_lhs.begin(), sorted_lhs.end(),
+ [](const Ir::Lhs& a, const Ir::Lhs& b) {
+ return std::tie(a.nonterminal, a.callback.id, a.callback.param,
+ a.preconditions.max_whitespace_gap) <
+ std::tie(b.nonterminal, b.callback.id, b.callback.param,
+ b.preconditions.max_whitespace_gap);
+ });
+ // Fix: return the sorted copy; previously the unsorted input was returned,
+ // making the function a no-op.
+ return sorted_lhs;
+}
+
+// Adds a new lhs match set to the output.
+// Reuses the same set, if it was previously observed.
+// Returns the index of the set in `rules_set->lhs_set`.
+int AddLhsSet(const Ir::LhsSet& lhs_set, RulesSetT* rules_set) {
+ // Canonicalize the order so that order-permuted equal sets are detected as
+ // equal below. Fix: the sorted copy was previously computed but unused.
+ const Ir::LhsSet sorted_lhs = SortedLhsSet(lhs_set);
+ // Check whether we can reuse an entry.
+ const int output_size = rules_set->lhs_set.size();
+ for (int i = 0; i < output_size; i++) {
+ if (IsSameLhsSet(sorted_lhs, *rules_set->lhs_set[i], rules_set->lhs)) {
+ return i;
+ }
+ }
+
+ // Add new entry.
+ rules_set->lhs_set.emplace_back(std::make_unique<RulesSet_::LhsSetT>());
+ RulesSet_::LhsSetT* serialized_lhs_set = rules_set->lhs_set.back().get();
+ for (const Ir::Lhs& lhs : sorted_lhs) {
+ // Simple case: No callback and no special requirements, we directly encode
+ // the nonterminal.
+ if (lhs.callback.id == kNoCallback &&
+ lhs.preconditions.max_whitespace_gap < 0) {
+ serialized_lhs_set->lhs.push_back(lhs.nonterminal);
+ } else {
+ // Check whether we can reuse a callback entry.
+ const int lhs_size = rules_set->lhs.size();
+ bool found_entry = false;
+ for (int i = 0; i < lhs_size; i++) {
+ if (IsSameLhs(lhs, rules_set->lhs[i])) {
+ found_entry = true;
+ serialized_lhs_set->lhs.push_back(-i);
+ break;
+ }
+ }
+
+ // We could reuse an existing entry.
+ if (found_entry) {
+ continue;
+ }
+
+ // Add a new one.
+ rules_set->lhs.push_back(
+ RulesSet_::Lhs(lhs.nonterminal, lhs.callback.id, lhs.callback.param,
+ lhs.preconditions.max_whitespace_gap));
+ serialized_lhs_set->lhs.push_back(-lhs_size);
+ }
+ }
+ return output_size;
+}
+
+// Serializes a unary rules table.
+// Each entry maps an rhs nonterminal to the index of its lhs set in the
+// global `lhs_set` vector; entries are sorted for binary-search lookup.
+void SerializeUnaryRulesShard(
+ const std::unordered_map<Nonterm, Ir::LhsSet>& unary_rules,
+ RulesSetT* rules_set, RulesSet_::RulesT* rules) {
+ for (const auto& it : unary_rules) {
+ rules->unary_rules.push_back(RulesSet_::Rules_::UnaryRulesEntry(
+ it.first, AddLhsSet(it.second, rules_set)));
+ }
+ SortStructsForBinarySearchLookup(&rules->unary_rules);
+}
+
+// Serializes a binary rules table.
+// The table is a fixed-size hash map (at most `kMaxHashTableSize` buckets)
+// with per-bucket chains, keyed by the pair of rhs nonterminals.
+void SerializeBinaryRulesShard(
+ const std::unordered_map<TwoNonterms, Ir::LhsSet, BinaryRuleHasher>&
+ binary_rules,
+ RulesSetT* rules_set, RulesSet_::RulesT* rules) {
+ const size_t num_buckets = std::min(binary_rules.size(), kMaxHashTableSize);
+ for (int i = 0; i < num_buckets; i++) {
+ rules->binary_rules.emplace_back(
+ new RulesSet_::Rules_::BinaryRuleTableBucketT());
+ }
+
+ // Serialize the table.
+ // If `binary_rules` is empty, `num_buckets` is 0 but this loop never runs,
+ // so the modulo below is never taken with 0.
+ BinaryRuleHasher hash;
+ for (const auto& it : binary_rules) {
+ const TwoNonterms key = it.first;
+ uint32 bucket_index = hash(key) % num_buckets;
+
+ // Add entry to bucket chain list.
+ rules->binary_rules[bucket_index]->rules.push_back(
+ RulesSet_::Rules_::BinaryRule(key.first, key.second,
+ AddLhsSet(it.second, rules_set)));
+ }
+}
+
+} // namespace
+
+// Adds `lhs` to `lhs_set`, sharing an existing nonterminal id when that is
+// provably safe (same preconditions, compatible callbacks, neither id marked
+// non-shareable). Returns the nonterminal id assigned to the rule.
+Nonterm Ir::AddToSet(const Lhs& lhs, LhsSet* lhs_set) {
+ const int lhs_set_size = lhs_set->size();
+ Nonterm shareable_nonterm = lhs.nonterminal;
+ for (int i = 0; i < lhs_set_size; i++) {
+ Lhs* candidate = &lhs_set->at(i);
+
+ // Exact match, just reuse rule.
+ if (lhs == *candidate) {
+ return candidate->nonterminal;
+ }
+
+ // Cannot reuse unshareable ids.
+ if (nonshareable_.find(candidate->nonterminal) != nonshareable_.end() ||
+ nonshareable_.find(lhs.nonterminal) != nonshareable_.end()) {
+ continue;
+ }
+
+ // Cannot reuse id if the preconditions are different.
+ if (!(lhs.preconditions == candidate->preconditions)) {
+ continue;
+ }
+
+ // If either callback is a filter, we can't share as we must always run
+ // both filters.
+ if ((lhs.callback.id != kNoCallback &&
+ filters_.find(lhs.callback.id) != filters_.end()) ||
+ (candidate->callback.id != kNoCallback &&
+ filters_.find(candidate->callback.id) != filters_.end())) {
+ continue;
+ }
+
+ // If the nonterminal is already defined, it must match for sharing.
+ if (lhs.nonterminal != kUnassignedNonterm &&
+ lhs.nonterminal != candidate->nonterminal) {
+ continue;
+ }
+
+ // Check whether the callbacks match.
+ if (lhs.callback == candidate->callback) {
+ return candidate->nonterminal;
+ }
+
+ // We can reuse if one of the output callbacks is not used.
+ if (lhs.callback.id == kNoCallback) {
+ return candidate->nonterminal;
+ } else if (candidate->callback.id == kNoCallback) {
+ // Old entry has no output callback, which is redundant now.
+ candidate->callback = lhs.callback;
+ return candidate->nonterminal;
+ }
+
+ // We can share the nonterminal, but we need to
+ // add a new output callback. Defer this as we might find a shareable
+ // nonterminal first.
+ shareable_nonterm = candidate->nonterminal;
+ }
+
+ // We didn't find a redundant entry, so create a new one.
+ shareable_nonterm = DefineNonterminal(shareable_nonterm);
+ lhs_set->push_back(Lhs{shareable_nonterm, lhs.callback, lhs.preconditions});
+ return shareable_nonterm;
+}
+
+// Adds a terminal rule <lhs> ::= terminal to the given `shard`, routing it
+// to the case-sensitive or the lowercase terminal table.
+Nonterm Ir::Add(const Lhs& lhs, const std::string& terminal,
+ const bool case_sensitive, const int shard) {
+ TC3_CHECK_LT(shard, shards_.size());
+ if (case_sensitive) {
+ return AddRule(lhs, terminal, &shards_[shard].terminal_rules);
+ } else {
+ return AddRule(lhs, terminal, &shards_[shard].lowercase_terminal_rules);
+ }
+}
+
+// Adds a rule with an arbitrary-length rhs by rewriting it into a chain of
+// binary rules (Chomsky-Normal-Form style); a single-element rhs becomes a
+// unary rule.
+// NOTE(review): assumes `rhs` is non-empty (`rhs.front()` below).
+Nonterm Ir::Add(const Lhs& lhs, const std::vector<Nonterm>& rhs,
+ const int shard) {
+ // Add a new unary rule.
+ if (rhs.size() == 1) {
+ return Add(lhs, rhs.front(), shard);
+ }
+
+ // Add a chain of (rhs.size() - 1) binary rules.
+ Nonterm prev = rhs.front();
+ for (int i = 1; i < rhs.size() - 1; i++) {
+ prev = Add(kUnassignedNonterm, prev, rhs[i], shard);
+ }
+ return Add(lhs, prev, rhs.back(), shard);
+}
+
+// Registers a regex rule: defines the lhs nonterminal and records the
+// pattern for serialization into the `regex_annotator` table.
+Nonterm Ir::AddRegex(Nonterm lhs, const std::string& regex_pattern) {
+ lhs = DefineNonterminal(lhs);
+ regex_rules_.emplace_back(regex_pattern, lhs);
+ return lhs;
+}
+
+// Associates an annotation/collection name with the nonterminal `lhs`;
+// serialized into `Nonterminals.annotation_nt`.
+void Ir::AddAnnotation(const Nonterm lhs, const std::string& annotation) {
+ annotations_.emplace_back(annotation, lhs);
+}
+
+// Serializes the terminal rules table.
+// All terminal strings of all shards are merged into one shared, sorted
+// string pool (`rules_set->terminals`); each per-shard map only stores
+// offsets into that pool. Terminals that are suffixes of other terminals
+// reuse the containing string's storage.
+void Ir::SerializeTerminalRules(
+ RulesSetT* rules_set,
+ std::vector<std::unique_ptr<RulesSet_::RulesT>>* rules_shards) const {
+ // Use common pool for all terminals.
+ struct TerminalEntry {
+ std::string terminal;
+ // Which of the per-shard terminal maps this entry belongs to (two maps
+ // per shard: case-sensitive and lowercase).
+ int set_index;
+ // Position of this entry within its map, in sorted terminal order.
+ int index;
+ Ir::LhsSet lhs_set;
+ };
+ std::vector<TerminalEntry> terminal_rules;
+
+ // Merge all terminals into a common pool.
+ // We want to use one common pool, but still need to track which set they
+ // belong to.
+ std::vector<const std::unordered_map<std::string, Ir::LhsSet>*>
+ terminal_rules_sets;
+ std::vector<RulesSet_::Rules_::TerminalRulesMapT*> rules_maps;
+ terminal_rules_sets.reserve(2 * shards_.size());
+ rules_maps.reserve(terminal_rules_sets.size());
+ for (int i = 0; i < shards_.size(); i++) {
+ terminal_rules_sets.push_back(&shards_[i].terminal_rules);
+ terminal_rules_sets.push_back(&shards_[i].lowercase_terminal_rules);
+ rules_shards->at(i)->terminal_rules.reset(
+ new RulesSet_::Rules_::TerminalRulesMapT());
+ rules_shards->at(i)->lowercase_terminal_rules.reset(
+ new RulesSet_::Rules_::TerminalRulesMapT());
+ rules_maps.push_back(rules_shards->at(i)->terminal_rules.get());
+ rules_maps.push_back(rules_shards->at(i)->lowercase_terminal_rules.get());
+ }
+ for (int i = 0; i < terminal_rules_sets.size(); i++) {
+ for (const auto& it : *terminal_rules_sets[i]) {
+ terminal_rules.push_back(
+ TerminalEntry{it.first, /*set_index=*/i, /*index=*/0, it.second});
+ }
+ }
+ std::sort(terminal_rules.begin(), terminal_rules.end(),
+ [](const TerminalEntry& a, const TerminalEntry& b) {
+ return a.terminal < b.terminal;
+ });
+
+ // Index the entries in sorted order.
+ std::vector<int> index(terminal_rules_sets.size(), 0);
+ for (int i = 0; i < terminal_rules.size(); i++) {
+ terminal_rules[i].index = index[terminal_rules[i].set_index]++;
+ }
+
+ // We store the terminal strings sorted into a buffer and keep offsets into
+ // that buffer. In this way, we don't need extra space for terminals that are
+ // suffixes of others.
+
+ // Find terminals that are a suffix of others, O(n^2) algorithm.
+ constexpr int kInvalidIndex = -1;
+ std::vector<int> suffix(terminal_rules.size(), kInvalidIndex);
+ for (int i = 0; i < terminal_rules.size(); i++) {
+ const StringPiece terminal(terminal_rules[i].terminal);
+
+ // Check whether the ith terminal is a suffix of another.
+ for (int j = 0; j < terminal_rules.size(); j++) {
+ if (i == j) {
+ continue;
+ }
+ if (StringPiece(terminal_rules[j].terminal).EndsWith(terminal)) {
+ // If both terminals are the same keep the first.
+ // This avoids cyclic dependencies.
+ // This can happen if multiple shards use same terminals, such as
+ // punctuation.
+ if (terminal_rules[j].terminal.size() == terminal.size() && j < i) {
+ continue;
+ }
+ suffix[i] = j;
+ break;
+ }
+ }
+ }
+
+ rules_set->terminals = "";
+
+ for (int i = 0; i < terminal_rules_sets.size(); i++) {
+ rules_maps[i]->terminal_offsets.resize(terminal_rules_sets[i]->size());
+ rules_maps[i]->max_terminal_length = 0;
+ rules_maps[i]->min_terminal_length = std::numeric_limits<int>::max();
+ }
+
+ for (int i = 0; i < terminal_rules.size(); i++) {
+ const TerminalEntry& entry = terminal_rules[i];
+
+ // Update bounds.
+ rules_maps[entry.set_index]->min_terminal_length =
+ std::min(rules_maps[entry.set_index]->min_terminal_length,
+ static_cast<int>(entry.terminal.size()));
+ rules_maps[entry.set_index]->max_terminal_length =
+ std::max(rules_maps[entry.set_index]->max_terminal_length,
+ static_cast<int>(entry.terminal.size()));
+
+ // Only include terminals that are not suffixes of others.
+ if (suffix[i] != kInvalidIndex) {
+ continue;
+ }
+
+ rules_maps[entry.set_index]->terminal_offsets[entry.index] =
+ rules_set->terminals.length();
+ rules_set->terminals += entry.terminal + '\0';
+ }
+
+ // Store just an offset into the existing terminal data for the terminals
+ // that are suffixes of others.
+ for (int i = 0; i < terminal_rules.size(); i++) {
+ int canonical_index = i;
+ if (suffix[canonical_index] == kInvalidIndex) {
+ continue;
+ }
+
+ // Find the overlapping string that was included in the data.
+ // Follow the suffix chain to an entry that was actually emitted.
+ while (suffix[canonical_index] != kInvalidIndex) {
+ canonical_index = suffix[canonical_index];
+ }
+
+ const TerminalEntry& entry = terminal_rules[i];
+ const TerminalEntry& canonical_entry = terminal_rules[canonical_index];
+
+ // The offset is the offset of the overlapping string and the offset within
+ // that string.
+ rules_maps[entry.set_index]->terminal_offsets[entry.index] =
+ rules_maps[canonical_entry.set_index]
+ ->terminal_offsets[canonical_entry.index] +
+ (canonical_entry.terminal.length() - entry.terminal.length());
+ }
+
+ // Emit the lhs set index for each terminal rule, in the pool's sorted order.
+ for (const TerminalEntry& entry : terminal_rules) {
+ rules_maps[entry.set_index]->lhs_set_index.push_back(
+ AddLhsSet(entry.lhs_set, rules_set));
+ }
+}
+
+// Serializes the whole IR into `output`: callback info, predefined
+// nonterminals, optional debug information, regex rules, and the per-shard
+// unary/binary/terminal rule tables.
+void Ir::Serialize(const bool include_debug_information,
+ RulesSetT* output) const {
+ // Set callback information.
+ for (const CallbackId filter_callback_id : filters_) {
+ output->callback.push_back(RulesSet_::CallbackEntry(
+ filter_callback_id, RulesSet_::Callback(/*is_filter=*/true)));
+ }
+ SortStructsForBinarySearchLookup(&output->callback);
+
+ // Add information about predefined nonterminal classes.
+ output->nonterminals.reset(new RulesSet_::NonterminalsT);
+ output->nonterminals->start_nt = GetNonterminalForName(kStartNonterm);
+ output->nonterminals->end_nt = GetNonterminalForName(kEndNonterm);
+ output->nonterminals->wordbreak_nt = GetNonterminalForName(kWordBreakNonterm);
+ output->nonterminals->token_nt = GetNonterminalForName(kTokenNonterm);
+ output->nonterminals->uppercase_token_nt =
+ GetNonterminalForName(kUppercaseTokenNonterm);
+ output->nonterminals->digits_nt = GetNonterminalForName(kDigitsNonterm);
+ for (int i = 1; i <= kMaxNDigitsNontermLength; i++) {
+ if (const Nonterm n_digits_nt =
+ GetNonterminalForName(strings::StringPrintf(kNDigitsNonterm, i))) {
+ // Grow lazily: `n_digits_nt[i - 1]` holds the nonterminal for i digits.
+ output->nonterminals->n_digits_nt.resize(i, kUnassignedNonterm);
+ output->nonterminals->n_digits_nt[i - 1] = n_digits_nt;
+ }
+ }
+ for (const auto& [annotation, annotation_nt] : annotations_) {
+ output->nonterminals->annotation_nt.emplace_back(
+ new RulesSet_::Nonterminals_::AnnotationNtEntryT);
+ output->nonterminals->annotation_nt.back()->key = annotation;
+ output->nonterminals->annotation_nt.back()->value = annotation_nt;
+ }
+ SortForBinarySearchLookup(&output->nonterminals->annotation_nt);
+
+ if (include_debug_information) {
+ output->debug_information.reset(new RulesSet_::DebugInformationT);
+ // Keep original non-terminal names.
+ for (const auto& it : nonterminal_names_) {
+ output->debug_information->nonterminal_names.emplace_back(
+ new RulesSet_::DebugInformation_::NonterminalNamesEntryT);
+ output->debug_information->nonterminal_names.back()->key = it.first;
+ output->debug_information->nonterminal_names.back()->value = it.second;
+ }
+ SortForBinarySearchLookup(&output->debug_information->nonterminal_names);
+ }
+
+ // Add regex rules.
+ std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
+ for (auto [pattern, lhs] : regex_rules_) {
+ output->regex_annotator.emplace_back(new RulesSet_::RegexAnnotatorT);
+ output->regex_annotator.back()->compressed_pattern.reset(
+ new CompressedBufferT);
+ compressor->Compress(
+ pattern, output->regex_annotator.back()->compressed_pattern.get());
+ output->regex_annotator.back()->nonterminal = lhs;
+ }
+
+ // Serialize the unary and binary rules.
+ for (const RulesShard& shard : shards_) {
+ output->rules.emplace_back(std::make_unique<RulesSet_::RulesT>());
+ RulesSet_::RulesT* rules = output->rules.back().get();
+ // Serialize the unary rules.
+ SerializeUnaryRulesShard(shard.unary_rules, output, rules);
+
+ // Serialize the binary rules.
+ SerializeBinaryRulesShard(shard.binary_rules, output, rules);
+ }
+
+ // Serialize the terminal rules.
+ // We keep the rules separate by shard but merge the actual terminals into
+ // one shared string pool to most effectively exploit reuse.
+ SerializeTerminalRules(output, &output->rules);
+}
+
+// Serializes the rules into a flatbuffer byte string (packed `RulesSet`),
+// ready to be written to disk or embedded.
+std::string Ir::SerializeAsFlatbuffer(
+ const bool include_debug_information) const {
+ RulesSetT output;
+ Serialize(include_debug_information, &output);
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(RulesSet::Pack(builder, &output));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/utils/ir.h b/native/utils/grammar/utils/ir.h
new file mode 100644
index 0000000..b05b87f
--- /dev/null
+++ b/native/utils/grammar/utils/ir.h
@@ -0,0 +1,236 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_IR_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_IR_H_
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/types.h"
+
+namespace libtextclassifier3::grammar {
+
// Pre-defined nonterminal classes that the lexer can handle.
// Anchor matched at the start of the input (stripped next to fillers by
// Rules::ResolveAnchors).
constexpr const char* kStartNonterm = "<^>";
// Anchor matched at the end of the input.
constexpr const char* kEndNonterm = "<$>";
// Word-break nonterminal. NOTE(review): the name contains a literal "\b"
// escape (a backspace character) — looks intentional as an unmistakable
// marker, confirm against the lexer.
constexpr const char* kWordBreakNonterm = "<\b>";
// Matches a single token (used by Rules::ResolveFillers to expand fillers).
constexpr const char* kTokenNonterm = "<token>";
// Presumably matches an uppercase token — provided by the lexer, confirm.
constexpr const char* kUppercaseTokenNonterm = "<uppercase_token>";
// Presumably matches a digit string — provided by the lexer, confirm.
constexpr const char* kDigitsNonterm = "<digits>";
// printf-style template for nonterminals matching exactly N digits,
// e.g. "<4_digits>".
constexpr const char* kNDigitsNonterm = "<%d_digits>";
// Largest N for which a "<N_digits>" nonterminal is considered pre-defined.
constexpr const int kMaxNDigitsNontermLength = 20;
+
+// Low-level intermediate rules representation.
+// In this representation, nonterminals are specified simply as integers
+// (Nonterms), rather than strings which is more efficient.
+// Rule set optimizations are done on this representation.
+//
+// Rules are represented in (mostly) Chomsky Normal Form, where all rules are
+// of the following form, either:
+// * <nonterm> ::= term
+// * <nonterm> ::= <nonterm>
+// * <nonterm> ::= <nonterm> <nonterm>
class Ir {
 public:
  // A rule callback as a callback id and parameter pair.
  struct Callback {
    bool operator==(const Callback& other) const {
      return std::tie(id, param) == std::tie(other.id, other.param);
    }

    // The callback to invoke, or kNoCallback for none.
    CallbackId id = kNoCallback;
    // Free-form parameter handed to the callback.
    int64 param = 0;
  };

  // Constraints for triggering a rule.
  struct Preconditions {
    bool operator==(const Preconditions& other) const {
      return max_whitespace_gap == other.max_whitespace_gap;
    }

    // The maximum allowed whitespace between parts of the rule.
    // The default of -1 allows for unbounded whitespace.
    int8 max_whitespace_gap = -1;
  };

  // The left-hand side of a rule: the produced nonterminal together with the
  // callback and preconditions attached to this particular production.
  struct Lhs {
    bool operator==(const Lhs& other) const {
      return std::tie(nonterminal, callback, preconditions) ==
             std::tie(other.nonterminal, other.callback, other.preconditions);
    }

    Nonterm nonterminal = kUnassignedNonterm;
    Callback callback;
    Preconditions preconditions;
  };
  // All left-hand sides that share one right-hand side.
  using LhsSet = std::vector<Lhs>;

  // A rules shard. Shards keep logically separate rule sets (e.g. per
  // locale) apart while still sharing one terminal string pool at
  // serialization time.
  struct RulesShard {
    // Terminal rules.
    std::unordered_map<std::string, LhsSet> terminal_rules;
    std::unordered_map<std::string, LhsSet> lowercase_terminal_rules;

    // Unary rules.
    std::unordered_map<Nonterm, LhsSet> unary_rules;

    // Binary rules.
    std::unordered_map<TwoNonterms, LhsSet, BinaryRuleHasher> binary_rules;
  };

  // `filters`: callback ids to be treated as filters (filter callbacks
  // prevent nonterminal sharing between rules).
  // `num_shards`: number of independent rule shards to create.
  explicit Ir(const std::unordered_set<CallbackId>& filters = {},
              const int num_shards = 1)
      : num_nonterminals_(0), filters_(filters), shards_(num_shards) {}

  // Adds a new non-terminal. Ids are allocated sequentially via
  // pre-increment, so the first id handed out is 1. A non-empty `name` is
  // recorded for debug output and name lookup.
  Nonterm AddNonterminal(const std::string& name = "") {
    const Nonterm nonterminal = ++num_nonterminals_;
    if (!name.empty()) {
      // Record debug information.
      // NOTE(review): adding the same name twice silently remaps it in
      // nonterminal_ids_ — callers appear expected to use unique names.
      nonterminal_names_[nonterminal] = name;
      nonterminal_ids_[name] = nonterminal;
    }
    return nonterminal;
  }

  // Defines a nonterminal if not yet defined.
  Nonterm DefineNonterminal(Nonterm nonterminal) {
    return (nonterminal != kUnassignedNonterm) ? nonterminal : AddNonterminal();
  }

  // Defines a new non-terminal that cannot be shared internally, i.e. it
  // keeps its own identity even when another rule has an identical
  // right-hand side.
  Nonterm AddUnshareableNonterminal(const std::string& name = "") {
    const Nonterm nonterminal = AddNonterminal(name);
    nonshareable_.insert(nonterminal);
    return nonterminal;
  }

  // Gets the non-terminal for a given name, if it was previously defined.
  // Returns kUnassignedNonterm for unknown names.
  Nonterm GetNonterminalForName(const std::string& name) const {
    const auto it = nonterminal_ids_.find(name);
    if (it == nonterminal_ids_.end()) {
      return kUnassignedNonterm;
    }
    return it->second;
  }

  // Adds a terminal rule <lhs> ::= terminal.
  // `shard` selects which RulesShard receives the rule.
  Nonterm Add(const Lhs& lhs, const std::string& terminal,
              bool case_sensitive = false, int shard = 0);
  Nonterm Add(const Nonterm lhs, const std::string& terminal,
              bool case_sensitive = false, int shard = 0) {
    return Add(Lhs{lhs}, terminal, case_sensitive, shard);
  }

  // Adds a unary rule <lhs> ::= <rhs>.
  Nonterm Add(const Lhs& lhs, Nonterm rhs, int shard = 0) {
    return AddRule(lhs, rhs, &shards_[shard].unary_rules);
  }
  Nonterm Add(Nonterm lhs, Nonterm rhs, int shard = 0) {
    return Add(Lhs{lhs}, rhs, shard);
  }

  // Adds a binary rule <lhs> ::= <rhs_1> <rhs_2>.
  Nonterm Add(const Lhs& lhs, Nonterm rhs_1, Nonterm rhs_2, int shard = 0) {
    return AddRule(lhs, {rhs_1, rhs_2}, &shards_[shard].binary_rules);
  }
  Nonterm Add(Nonterm lhs, Nonterm rhs_1, Nonterm rhs_2, int shard = 0) {
    return Add(Lhs{lhs}, rhs_1, rhs_2, shard);
  }

  // Adds a rule <lhs> ::= <rhs_1> <rhs_2> ... <rhs_k>
  //
  // If k > 2, we internally create a series of Nonterms representing prefixes
  // of the full rhs.
  // <temp_1> ::= <RHS_1> <RHS_2>
  // <temp_2> ::= <temp_1> <RHS_3>
  // ...
  // <LHS> ::= <temp_(k-1)> <RHS_k>
  Nonterm Add(const Lhs& lhs, const std::vector<Nonterm>& rhs, int shard = 0);
  Nonterm Add(Nonterm lhs, const std::vector<Nonterm>& rhs, int shard = 0) {
    return Add(Lhs{lhs}, rhs, shard);
  }

  // Adds a regex rule <lhs> ::= <regex_pattern>.
  Nonterm AddRegex(Nonterm lhs, const std::string& regex_pattern);

  // Adds a definition for a nonterminal provided by a text annotation.
  void AddAnnotation(Nonterm lhs, const std::string& annotation);

  // Serializes a rule set in the intermediate representation into the
  // memory mappable inference format.
  void Serialize(bool include_debug_information, RulesSetT* output) const;

  std::string SerializeAsFlatbuffer(
      bool include_debug_information = false) const;

  const std::vector<RulesShard>& shards() const { return shards_; }

 private:
  // Adds a rule to `rules`, keyed by its right-hand side. If the rhs is new,
  // a (possibly fresh) nonterminal is defined for it; otherwise the lhs is
  // merged into the existing entry's LhsSet via AddToSet, which may share an
  // already-assigned nonterminal id.
  template <typename R, typename H>
  Nonterm AddRule(const Lhs& lhs, const R& rhs,
                  std::unordered_map<R, LhsSet, H>* rules) {
    const auto it = rules->find(rhs);

    // Rhs was not yet used.
    if (it == rules->end()) {
      const Nonterm nonterminal = DefineNonterminal(lhs.nonterminal);
      rules->insert(it,
                    {rhs, {Lhs{nonterminal, lhs.callback, lhs.preconditions}}});
      return nonterminal;
    }

    return AddToSet(lhs, &it->second);
  }

  // Adds a new callback to an lhs set, potentially sharing nonterminal ids and
  // existing callbacks.
  Nonterm AddToSet(const Lhs& lhs, LhsSet* lhs_set);

  // Serializes the sharded terminal rules.
  // NOTE(review): takes std::unique_ptr but this header does not include
  // <memory> directly — relies on a transitive include; consider adding it.
  void SerializeTerminalRules(
      RulesSetT* rules_set,
      std::vector<std::unique_ptr<RulesSet_::RulesT>>* rules_shards) const;

  // The defined non-terminals.
  // Count of nonterminals defined so far; doubles as the last id handed out.
  Nonterm num_nonterminals_;
  // Nonterminals excluded from rhs-based sharing (see
  // AddUnshareableNonterminal).
  std::unordered_set<Nonterm> nonshareable_;

  // The set of callbacks that should be treated as filters.
  std::unordered_set<CallbackId> filters_;

  // The sharded rules.
  std::vector<RulesShard> shards_;

  // The regex rules, as (pattern, produced nonterminal) pairs.
  std::vector<std::pair<std::string, Nonterm>> regex_rules_;

  // Mapping from annotation name to nonterminal.
  std::vector<std::pair<std::string, Nonterm>> annotations_;

  // Debug information.
  std::unordered_map<Nonterm, std::string> nonterminal_names_;
  std::unordered_map<std::string, Nonterm> nonterminal_ids_;
};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_IR_H_
diff --git a/native/utils/grammar/utils/ir_test.cc b/native/utils/grammar/utils/ir_test.cc
new file mode 100644
index 0000000..d2438dd
--- /dev/null
+++ b/native/utils/grammar/utils/ir_test.cc
@@ -0,0 +1,238 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/utils/ir.h"
+
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/types.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::Eq;
+using ::testing::IsEmpty;
+using ::testing::Ne;
+using ::testing::SizeIs;
+
// Verifies that terminal rules with an unassigned lhs share nonterminal ids
// when the terminal matches, while nonterminals created via
// AddUnshareableNonterminal never participate in sharing.
TEST(IrTest, HandlesSharingWithTerminalRules) {
  Ir ir;

  // <t1> ::= the
  const Nonterm t1 = ir.Add(kUnassignedNonterm, "the");

  // <t2> ::= quick
  const Nonterm t2 = ir.Add(kUnassignedNonterm, "quick");

  // <t3> ::= quick -- should share with <t2>
  const Nonterm t3 = ir.Add(kUnassignedNonterm, "quick");

  // <t4> ::= quick -- specify unshareable <t4>
  // <t4> ::= brown
  const Nonterm t4_unshareable = ir.AddUnshareableNonterminal();
  ir.Add(t4_unshareable, "quick");
  ir.Add(t4_unshareable, "brown");

  // <t5> ::= brown -- should not be shared with <t4>
  const Nonterm t5 = ir.Add(kUnassignedNonterm, "brown");

  // <t6> ::= brown -- specify unshareable <t6>
  const Nonterm t6_unshareable = ir.AddUnshareableNonterminal();
  ir.Add(t6_unshareable, "brown");

  // <t7> ::= brown -- should share with <t5>
  const Nonterm t7 = ir.Add(kUnassignedNonterm, "brown");

  EXPECT_THAT(t1, Ne(kUnassignedNonterm));
  EXPECT_THAT(t2, Ne(kUnassignedNonterm));
  EXPECT_THAT(t1, Ne(t2));
  EXPECT_THAT(t2, Eq(t3));
  EXPECT_THAT(t4_unshareable, Ne(kUnassignedNonterm));
  EXPECT_THAT(t4_unshareable, Ne(t3));
  EXPECT_THAT(t4_unshareable, Ne(t5));
  EXPECT_THAT(t6_unshareable, Ne(kUnassignedNonterm));
  EXPECT_THAT(t6_unshareable, Ne(t4_unshareable));
  EXPECT_THAT(t6_unshareable, Ne(t5));
  EXPECT_THAT(t7, Eq(t5));
}
+
+TEST(IrTest, HandlesSharingWithNonterminalRules) {
+ Ir ir;
+
+ // Setup a few terminal rules.
+ const std::vector<Nonterm> rhs = {
+ ir.Add(kUnassignedNonterm, "the"), ir.Add(kUnassignedNonterm, "quick"),
+ ir.Add(kUnassignedNonterm, "brown"), ir.Add(kUnassignedNonterm, "fox")};
+
+ // Check for proper sharing using nonterminal rules.
+ for (int rhs_length = 1; rhs_length <= rhs.size(); rhs_length++) {
+ std::vector<Nonterm> rhs_truncated = rhs;
+ rhs_truncated.resize(rhs_length);
+ const Nonterm nt_u = ir.AddUnshareableNonterminal();
+ ir.Add(nt_u, rhs_truncated);
+ const Nonterm nt_1 = ir.Add(kUnassignedNonterm, rhs_truncated);
+ const Nonterm nt_2 = ir.Add(kUnassignedNonterm, rhs_truncated);
+
+ EXPECT_THAT(nt_1, Eq(nt_2));
+ EXPECT_THAT(nt_1, Ne(nt_u));
+ }
+}
+
// Verifies that plain output callbacks do not block sharing, while callbacks
// registered as filters force distinct nonterminals.
TEST(IrTest, HandlesSharingWithCallbacksWithSameParameters) {
  // Test sharing in the presence of callbacks.
  constexpr CallbackId kOutput1 = 1;
  constexpr CallbackId kOutput2 = 2;
  constexpr CallbackId kFilter1 = 3;
  constexpr CallbackId kFilter2 = 4;
  Ir ir(/*filters=*/{kFilter1, kFilter2});

  const Nonterm x1 = ir.Add(kUnassignedNonterm, "hello");
  // Output callback: shares the nonterminal with x1.
  const Nonterm x2 =
      ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput1, 0}}, "hello");
  // Filter callback: must get its own nonterminal.
  const Nonterm x3 =
      ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter1, 0}}, "hello");
  const Nonterm x4 =
      ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput2, 0}}, "hello");
  const Nonterm x5 =
      ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter2, 0}}, "hello");

  // Duplicate entry.
  const Nonterm x6 =
      ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput2, 0}}, "hello");

  EXPECT_THAT(x2, Eq(x1));
  EXPECT_THAT(x3, Ne(x1));
  EXPECT_THAT(x4, Eq(x1));
  EXPECT_THAT(x5, Ne(x1));
  EXPECT_THAT(x5, Ne(x3));
  EXPECT_THAT(x6, Ne(x3));
}
+
+TEST(IrTest, HandlesSharingWithCallbacksWithDifferentParameters) {
+ // Test sharing in the presence of callbacks.
+ constexpr CallbackId kOutput = 1;
+ constexpr CallbackId kFilter = 2;
+ Ir ir(/*filters=*/{kFilter});
+
+ const Nonterm x1 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput, 0}}, "world");
+ const Nonterm x2 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput, 1}}, "world");
+ const Nonterm x3 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter, 0}}, "world");
+ const Nonterm x4 = ir.Add(Ir::Lhs{kUnassignedNonterm, {kFilter, 1}}, "world");
+
+ EXPECT_THAT(x2, Eq(x1));
+ EXPECT_THAT(x3, Ne(x1));
+ EXPECT_THAT(x4, Ne(x1));
+ EXPECT_THAT(x4, Ne(x3));
+}
+
// End-to-end check of Serialize(): terminal dedup, suffix merging in the
// string pool, and the split between callback-carrying and plain rules.
TEST(IrTest, SerializesRulesToFlatbufferFormat) {
  constexpr CallbackId kOutput = 1;
  Ir ir;
  const Nonterm verb = ir.AddUnshareableNonterminal();
  ir.Add(verb, "buy");
  ir.Add(Ir::Lhs{verb, {kOutput}}, "bring");
  ir.Add(verb, "upbring");
  ir.Add(verb, "remind");
  const Nonterm set_reminder = ir.AddUnshareableNonterminal();
  ir.Add(set_reminder,
         std::vector<Nonterm>{ir.Add(kUnassignedNonterm, "remind"),
                              ir.Add(kUnassignedNonterm, "me"),
                              ir.Add(kUnassignedNonterm, "to"), verb});
  const Nonterm action = ir.AddUnshareableNonterminal();
  ir.Add(action, set_reminder);
  RulesSetT rules;
  ir.Serialize(/*include_debug_information=*/false, &rules);

  EXPECT_THAT(rules.rules, SizeIs(1));

  // Only one rule uses a callback, the rest will be encoded directly.
  EXPECT_THAT(rules.lhs, SizeIs(1));
  EXPECT_THAT(rules.lhs.front().callback_id(), kOutput);

  // 6 distinct terminals: "buy", "upbring", "bring", "remind", "me" and "to".
  // All were added case-insensitively, so they land in the lowercase pool.
  EXPECT_THAT(rules.rules.front()->lowercase_terminal_rules->terminal_offsets,
              SizeIs(6));
  EXPECT_THAT(rules.rules.front()->terminal_rules->terminal_offsets, IsEmpty());

  // As "bring" is a suffix of "upbring" it is expected to be suffix merged in
  // the string pool
  EXPECT_THAT(rules.terminals,
              Eq(std::string("buy\0me\0remind\0to\0upbring\0", 25)));

  // The 4-element rhs of <set_reminder> is binarized into 3 binary rules.
  EXPECT_THAT(rules.rules.front()->binary_rules, SizeIs(3));

  // One unary rule: <action> ::= <set_reminder>
  EXPECT_THAT(rules.rules.front()->unary_rules, SizeIs(1));
}
+
// Verifies that rules stay separated by shard while all terminals are merged
// into one shared, sorted string pool.
TEST(IrTest, HandlesRulesSharding) {
  Ir ir(/*filters=*/{}, /*num_shards=*/2);
  const Nonterm verb = ir.AddUnshareableNonterminal();
  const Nonterm set_reminder = ir.AddUnshareableNonterminal();

  // Shard 0: en
  ir.Add(verb, "buy");
  ir.Add(verb, "bring");
  ir.Add(verb, "remind");
  ir.Add(set_reminder,
         std::vector<Nonterm>{ir.Add(kUnassignedNonterm, "remind"),
                              ir.Add(kUnassignedNonterm, "me"),
                              ir.Add(kUnassignedNonterm, "to"), verb});

  // Shard 1: de
  ir.Add(verb, "kaufen", /*case_sensitive=*/false, /*shard=*/1);
  ir.Add(verb, "bringen", /*case_sensitive=*/false, /*shard=*/1);
  ir.Add(verb, "erinnern", /*case_sensitive=*/false, /*shard=*/1);
  ir.Add(set_reminder,
         std::vector<Nonterm>{ir.Add(kUnassignedNonterm, "erinnere",
                                     /*case_sensitive=*/false, /*shard=*/1),
                              ir.Add(kUnassignedNonterm, "mich",
                                     /*case_sensitive=*/false, /*shard=*/1),
                              ir.Add(kUnassignedNonterm, "zu",
                                     /*case_sensitive=*/false, /*shard=*/1),
                              verb},
         /*shard=*/1);

  // Test that terminal strings are correctly merged into the shared
  // string pool.
  RulesSetT rules;
  ir.Serialize(/*include_debug_information=*/false, &rules);

  // One rules entry per shard.
  EXPECT_THAT(rules.rules, SizeIs(2));

  // 5 distinct terminals: "buy", "bring", "remind", "me" and "to".
  EXPECT_THAT(rules.rules[0]->lowercase_terminal_rules->terminal_offsets,
              SizeIs(5));
  EXPECT_THAT(rules.rules[0]->terminal_rules->terminal_offsets, IsEmpty());

  // 6 distinct terminals: "kaufen", "bringen", "erinnern", "erinnere", "mich"
  // and "zu".
  EXPECT_THAT(rules.rules[1]->lowercase_terminal_rules->terminal_offsets,
              SizeIs(6));
  EXPECT_THAT(rules.rules[1]->terminal_rules->terminal_offsets, IsEmpty());

  // All terminals from both shards in one sorted, NUL-delimited pool.
  EXPECT_THAT(rules.terminals,
              Eq(std::string("bring\0bringen\0buy\0erinnere\0erinnern\0kaufen\0"
                             "me\0mich\0remind\0to\0zu\0",
                             64)));

  // Each shard binarizes its 4-element <set_reminder> rhs into 3 rules.
  EXPECT_THAT(rules.rules[0]->binary_rules, SizeIs(3));
  EXPECT_THAT(rules.rules[1]->binary_rules, SizeIs(3));
}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/utils/rules.cc b/native/utils/grammar/utils/rules.cc
new file mode 100644
index 0000000..d6e4b76
--- /dev/null
+++ b/native/utils/grammar/utils/rules.cc
@@ -0,0 +1,472 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/utils/rules.h"
+
+#include <set>
+
+#include "utils/grammar/utils/ir.h"
+#include "utils/strings/append.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+// Returns whether a nonterminal is a pre-defined one.
+bool IsPredefinedNonterminal(const std::string& nonterminal_name) {
+ if (nonterminal_name == kStartNonterm || nonterminal_name == kEndNonterm ||
+ nonterminal_name == kTokenNonterm || nonterminal_name == kDigitsNonterm ||
+ nonterminal_name == kWordBreakNonterm) {
+ return true;
+ }
+ for (int digits = 1; digits <= kMaxNDigitsNontermLength; digits++) {
+ if (nonterminal_name == strings::StringPrintf(kNDigitsNonterm, digits)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+// Gets an assigned Nonterm for a nonterminal or kUnassignedNonterm if not yet
+// assigned.
+Nonterm GetAssignedIdForNonterminal(
+ const int nonterminal, const std::unordered_map<int, Nonterm>& assignment) {
+ const auto it = assignment.find(nonterminal);
+ if (it == assignment.end()) {
+ return kUnassignedNonterm;
+ }
+ return it->second;
+}
+
+// Checks whether all the nonterminals in the rhs of a rule have already been
+// assigned Nonterm values.
+bool IsRhsAssigned(const Rules::Rule& rule,
+ const std::unordered_map<int, Nonterm>& nonterminals) {
+ for (const Rules::RhsElement& element : rule.rhs) {
+ // Terminals are always considered assigned, check only for non-terminals.
+ if (element.is_terminal) {
+ continue;
+ }
+ if (GetAssignedIdForNonterminal(element.nonterminal, nonterminals) ==
+ kUnassignedNonterm) {
+ return false;
+ }
+ }
+
+ // Check that all parts of an exclusion are defined.
+ if (rule.callback == static_cast<CallbackId>(DefaultCallback::kExclusion)) {
+ if (GetAssignedIdForNonterminal(rule.callback_param, nonterminals) ==
+ kUnassignedNonterm) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
// Lowers a single high-level rule down into the intermediate representation.
// All rhs nonterminals must already have assignments in `nonterminals`
// (guaranteed by IsRhsAssigned); the Nonterm produced for `lhs_index` is
// written back into `nonterminals`.
void LowerRule(const int lhs_index, const Rules::Rule& rule,
               std::unordered_map<int, Nonterm>* nonterminals, Ir* ir) {
  const CallbackId callback = rule.callback;
  int64 callback_param = rule.callback_param;

  // Resolve id of excluded nonterminal in exclusion rules: the high-level
  // callback_param holds a nonterminal *index*, which is translated into the
  // assigned Nonterm id here.
  if (callback == static_cast<CallbackId>(DefaultCallback::kExclusion)) {
    callback_param = GetAssignedIdForNonterminal(callback_param, *nonterminals);
    TC3_CHECK_NE(callback_param, kUnassignedNonterm);
  }

  // Special case for terminal rules.
  if (rule.rhs.size() == 1 && rule.rhs.front().is_terminal) {
    (*nonterminals)[lhs_index] =
        ir->Add(Ir::Lhs{GetAssignedIdForNonterminal(lhs_index, *nonterminals),
                        /*callback=*/{callback, callback_param},
                        /*preconditions=*/{rule.max_whitespace_gap}},
                rule.rhs.front().terminal, rule.case_sensitive, rule.shard);
    return;
  }

  // Nonterminal rules.
  std::vector<Nonterm> rhs_nonterms;
  for (const Rules::RhsElement& element : rule.rhs) {
    if (element.is_terminal) {
      // Inline terminals inside a longer rhs get anonymous nonterminals.
      rhs_nonterms.push_back(ir->Add(Ir::Lhs{kUnassignedNonterm},
                                     element.terminal, rule.case_sensitive,
                                     rule.shard));
    } else {
      Nonterm nonterminal_id =
          GetAssignedIdForNonterminal(element.nonterminal, *nonterminals);
      TC3_CHECK_NE(nonterminal_id, kUnassignedNonterm);
      rhs_nonterms.push_back(nonterminal_id);
    }
  }
  // Passing the previously assigned id (if any) lets Ir reuse it instead of
  // minting a fresh nonterminal.
  (*nonterminals)[lhs_index] =
      ir->Add(Ir::Lhs{GetAssignedIdForNonterminal(lhs_index, *nonterminals),
                      /*callback=*/{callback, callback_param},
                      /*preconditions=*/{rule.max_whitespace_gap}},
              rhs_nonterms, rule.shard);
}
+// Check whether this component is a non-terminal.
+bool IsNonterminal(StringPiece rhs_component) {
+ return rhs_component[0] == '<' &&
+ rhs_component[rhs_component.size() - 1] == '>';
+}
+
// Sanity check for common typos -- '<' or '>' in a terminal.
// A '?' is also rejected, since the optionality marker is only stripped from
// the end of a component; anywhere else it is almost certainly a typo.
// NOTE(review): compares StringPiece::find against std::string::npos --
// assumes the two npos values agree, confirm in stringpiece.h.
void ValidateTerminal(StringPiece rhs_component) {
  TC3_CHECK_EQ(rhs_component.find('<'), std::string::npos)
      << "Rhs terminal `" << rhs_component << "` contains an angle bracket.";
  TC3_CHECK_EQ(rhs_component.find('>'), std::string::npos)
      << "Rhs terminal `" << rhs_component << "` contains an angle bracket.";
  TC3_CHECK_EQ(rhs_component.find('?'), std::string::npos)
      << "Rhs terminal `" << rhs_component << "` contains a question mark.";
}
+
+} // namespace
+
+int Rules::AddNonterminal(const std::string& nonterminal_name) {
+ std::string key = nonterminal_name;
+ auto alias_it = nonterminal_alias_.find(key);
+ if (alias_it != nonterminal_alias_.end()) {
+ key = alias_it->second;
+ }
+ auto it = nonterminal_names_.find(key);
+ if (it != nonterminal_names_.end()) {
+ return it->second;
+ }
+ const int index = nonterminals_.size();
+ nonterminals_.push_back(NontermInfo{key});
+ nonterminal_names_.insert(it, {key, index});
+ return index;
+}
+
+int Rules::AddNewNonterminal() {
+ const int index = nonterminals_.size();
+ nonterminals_.push_back(NontermInfo{});
+ return index;
+}
+
// Registers `alias` as an alternative name for `nonterminal_name`.
// insert_or_assign returns the (possibly pre-existing) entry; the check
// verifies it maps to `nonterminal_name`, so re-adding the same alias for the
// same target is a no-op while rebinding it to a different target aborts.
void Rules::AddAlias(const std::string& nonterminal_name,
                     const std::string& alias) {
  TC3_CHECK_EQ(nonterminal_alias_.insert_or_assign(alias, nonterminal_name)
                   .first->second,
               nonterminal_name)
      << "Cannot redefine alias: " << alias;
}
+
+// Defines a nonterminal for an externally provided annotation.
+int Rules::AddAnnotation(const std::string& annotation_name) {
+ auto [it, inserted] =
+ annotation_nonterminals_.insert({annotation_name, nonterminals_.size()});
+ if (inserted) {
+ nonterminals_.push_back(NontermInfo{});
+ }
+ return it->second;
+}
+
// Binds an annotation to an existing named nonterminal.
// The TC3_CHECK aborts if the annotation name was already mapped (whether by
// a previous BindAnnotation or by AddAnnotation) — each annotation may be
// bound exactly once.
void Rules::BindAnnotation(const std::string& nonterminal_name,
                           const std::string& annotation_name) {
  auto [_, inserted] = annotation_nonterminals_.insert(
      {annotation_name, AddNonterminal(nonterminal_name)});
  TC3_CHECK(inserted);
}
+
+bool Rules::IsNonterminalOfName(const RhsElement& element,
+ const std::string& nonterminal) const {
+ if (element.is_terminal) {
+ return false;
+ }
+ return (nonterminals_[element.nonterminal].name == nonterminal);
+}
+
// Recursively enumerates every combination of including/omitting the
// optional rhs elements (indexed by the iterator range
// [optional_element_indices, optional_element_indices_end)) and emits one
// concrete rule per combination. `omit_these` is the scratch vector holding
// the current include/omit decisions, indexed parallel to `rhs`.
//
// Note: For k optional components this creates 2^k rules, but it would be
// possible to be smarter about this and only use 2k rules instead.
// However that might be slower as it requires an extra rule firing at match
// time for every omitted optional element.
void Rules::ExpandOptionals(
    const int lhs, const std::vector<RhsElement>& rhs,
    const CallbackId callback, const int64 callback_param,
    const int8 max_whitespace_gap, const bool case_sensitive, const int shard,
    std::vector<int>::const_iterator optional_element_indices,
    std::vector<int>::const_iterator optional_element_indices_end,
    std::vector<bool>* omit_these) {
  if (optional_element_indices == optional_element_indices_end) {
    // All optional elements have been decided, so generate one concrete rule
    // from the current include/omit combination.
    Rule r;
    for (uint32 i = 0; i < rhs.size(); i++) {
      if (!omit_these->at(i)) {
        r.rhs.push_back(rhs[i]);
      }
    }
    r.callback = callback;
    r.callback_param = callback_param;
    r.max_whitespace_gap = max_whitespace_gap;
    r.case_sensitive = case_sensitive;
    r.shard = shard;
    nonterminals_[lhs].rules.push_back(rules_.size());
    rules_.push_back(r);
    return;
  }

  const int next_optional_part = *optional_element_indices;
  ++optional_element_indices;

  // Recursive call 1: The optional part is omitted.
  (*omit_these)[next_optional_part] = true;
  ExpandOptionals(lhs, rhs, callback, callback_param, max_whitespace_gap,
                  case_sensitive, shard, optional_element_indices,
                  optional_element_indices_end, omit_these);

  // Recursive call 2: The optional part is required.
  (*omit_these)[next_optional_part] = false;
  ExpandOptionals(lhs, rhs, callback, callback_param, max_whitespace_gap,
                  case_sensitive, shard, optional_element_indices,
                  optional_element_indices_end, omit_these);
}
+
+std::vector<Rules::RhsElement> Rules::ResolveAnchors(
+ const std::vector<RhsElement>& rhs) const {
+ if (rhs.size() <= 2) {
+ return rhs;
+ }
+ auto begin = rhs.begin();
+ auto end = rhs.end();
+ if (IsNonterminalOfName(rhs.front(), kStartNonterm) &&
+ IsNonterminalOfName(rhs[1], kFiller)) {
+ // Skip start anchor and filler.
+ begin += 2;
+ }
+ if (IsNonterminalOfName(rhs.back(), kEndNonterm) &&
+ IsNonterminalOfName(rhs[rhs.size() - 2], kFiller)) {
+ // Skip filler and end anchor.
+ end -= 2;
+ }
+ return std::vector<Rules::RhsElement>(begin, end);
+}
+
// Rewrites `<a> <filler>` pairs in a rhs into left-recursive token-absorbing
// rules, so the grammar stays (mostly) binary:
//   <a_with_tokens> ::= <a>                      (only if the filler was
//                                                 optional)
//   <a_with_tokens> ::= <a> <token>              (otherwise)
//   <a_with_tokens> ::= <a_with_tokens> <token>
// and replaces the pair by <a_with_tokens> in the result.
std::vector<Rules::RhsElement> Rules::ResolveFillers(
    const std::vector<RhsElement>& rhs) {
  std::vector<RhsElement> result;
  for (int i = 0; i < rhs.size();) {
    // Keep the element as-is unless it is immediately followed by a filler.
    // A filler itself, an optional element, or the last element never starts
    // a rewrite.
    if (i == rhs.size() - 1 || IsNonterminalOfName(rhs[i], kFiller) ||
        rhs[i].is_optional || !IsNonterminalOfName(rhs[i + 1], kFiller)) {
      result.push_back(rhs[i]);
      i++;
      continue;
    }

    // We have the case:
    //   <a> <filler>
    // rewrite as:
    //   <a_with_tokens> ::= <a>
    //   <a_with_tokens> ::= <a_with_tokens> <token>
    const int with_tokens_nonterminal = AddNewNonterminal();
    const RhsElement token(AddNonterminal(kTokenNonterm),
                           /*is_optional=*/false);
    if (rhs[i + 1].is_optional) {
      // Optional filler: zero tokens are allowed.
      // <a_with_tokens> ::= <a>
      Add(with_tokens_nonterminal, {rhs[i]});
    } else {
      // Required filler: at least one token must follow.
      // <a_with_tokens> ::= <a> <token>
      Add(with_tokens_nonterminal, {rhs[i], token});
    }
    // <a_with_tokens> ::= <a_with_tokens> <token>
    const RhsElement with_tokens(with_tokens_nonterminal,
                                 /*is_optional=*/false);
    Add(with_tokens_nonterminal, {with_tokens, token});
    result.push_back(with_tokens);
    // Skip both the original element and the consumed filler.
    i += 2;
  }
  return result;
}
+
+std::vector<Rules::RhsElement> Rules::OptimizeRhs(
+ const std::vector<RhsElement>& rhs) {
+ return ResolveFillers(ResolveAnchors(rhs));
+}
+
+void Rules::Add(const int lhs, const std::vector<RhsElement>& rhs,
+ const CallbackId callback, const int64 callback_param,
+ const int8 max_whitespace_gap, const bool case_sensitive,
+ const int shard) {
+ // Resolve anchors and fillers.
+ const std::vector optimized_rhs = OptimizeRhs(rhs);
+
+ std::vector<int> optional_element_indices;
+ TC3_CHECK_LT(optional_element_indices.size(), optimized_rhs.size())
+ << "Rhs must contain at least one non-optional element.";
+ for (int i = 0; i < optimized_rhs.size(); i++) {
+ if (optimized_rhs[i].is_optional) {
+ optional_element_indices.push_back(i);
+ }
+ }
+ std::vector<bool> omit_these(optimized_rhs.size(), false);
+ ExpandOptionals(lhs, optimized_rhs, callback, callback_param,
+ max_whitespace_gap, case_sensitive, shard,
+ optional_element_indices.begin(),
+ optional_element_indices.end(), &omit_these);
+}
+
// Adds a rule `lhs ::= rhs` given as strings. Components wrapped in angle
// brackets are nonterminals, everything else is a terminal; a trailing '?'
// marks a component optional. Delegates to the RhsElement overload above.
// NOTE(review): an empty string in `rhs` would index rhs_component[size-1]
// out of bounds below — callers appear expected not to pass empty
// components, confirm.
void Rules::Add(const std::string& lhs, const std::vector<std::string>& rhs,
                const CallbackId callback, const int64 callback_param,
                const int8 max_whitespace_gap, const bool case_sensitive,
                const int shard) {
  TC3_CHECK(!rhs.empty()) << "Rhs cannot be empty (Lhs=" << lhs << ")";
  // Pre-defined nonterminals are lexer-provided and cannot be redefined.
  TC3_CHECK(!IsPredefinedNonterminal(lhs));
  std::vector<RhsElement> rhs_elements;
  rhs_elements.reserve(rhs.size());
  for (StringPiece rhs_component : rhs) {
    // Check whether this component is optional.
    bool is_optional = false;
    if (rhs_component[rhs_component.size() - 1] == '?') {
      rhs_component.RemoveSuffix(1);
      is_optional = true;
    }
    // Check whether this component is a non-terminal.
    if (IsNonterminal(rhs_component)) {
      rhs_elements.push_back(
          RhsElement(AddNonterminal(rhs_component.ToString()), is_optional));
    } else {
      // A terminal.
      // Sanity check for common typos -- '<' or '>' in a terminal.
      ValidateTerminal(rhs_component);
      rhs_elements.push_back(RhsElement(rhs_component.ToString(), is_optional));
    }
  }
  Add(AddNonterminal(lhs), rhs_elements, callback, callback_param,
      max_whitespace_gap, case_sensitive, shard);
}
+
// Adds a rule that only matches if `excluded_nonterminal` does NOT match the
// same span. The excluded nonterminal's index travels in callback_param and
// is resolved to its assigned Nonterm during lowering (see LowerRule).
void Rules::AddWithExclusion(const std::string& lhs,
                             const std::vector<std::string>& rhs,
                             const std::string& excluded_nonterminal,
                             const int8 max_whitespace_gap,
                             const bool case_sensitive, const int shard) {
  Add(lhs, rhs,
      /*callback=*/static_cast<CallbackId>(DefaultCallback::kExclusion),
      /*callback_param=*/AddNonterminal(excluded_nonterminal),
      max_whitespace_gap, case_sensitive, shard);
}
+
// Adds an assertion rule; `negative` (carried in callback_param) selects
// between a negative and a positive assertion.
void Rules::AddAssertion(const std::string& lhs,
                         const std::vector<std::string>& rhs,
                         const bool negative, const int8 max_whitespace_gap,
                         const bool case_sensitive, const int shard) {
  Add(lhs, rhs,
      /*callback=*/static_cast<CallbackId>(DefaultCallback::kAssertion),
      /*callback_param=*/negative, max_whitespace_gap, case_sensitive, shard);
}
+
// Adds a rule that maps a match to the constant `value` (carried in
// callback_param) via the default mapping callback.
void Rules::AddValueMapping(const std::string& lhs,
                            const std::vector<std::string>& rhs,
                            const int64 value, const int8 max_whitespace_gap,
                            const bool case_sensitive, const int shard) {
  Add(lhs, rhs,
      /*callback=*/static_cast<CallbackId>(DefaultCallback::kMapping),
      /*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
}
+
// Adds a regex rule for the named nonterminal, creating the nonterminal on
// first use.
void Rules::AddRegex(const std::string& lhs, const std::string& regex_pattern) {
  AddRegex(AddNonterminal(lhs), regex_pattern);
}
+
// Adds a regex rule for an existing nonterminal index. The pattern is stored
// in `regex_rules_` and referenced by index from the nonterminal's info.
void Rules::AddRegex(int lhs, const std::string& regex_pattern) {
  nonterminals_[lhs].regex_rules.push_back(regex_rules_.size());
  regex_rules_.push_back(regex_pattern);
}
+
// Lowers all collected high-level rules into an Ir.
// `predefined_nonterminals` lists additional nonterminal names to treat like
// the built-in predefined ones (they are defined as unshareable before any
// rule is lowered).
Ir Rules::Finalize(const std::set<std::string>& predefined_nonterminals) const {
  Ir rules(filters_, num_shards_);
  // Maps high-level nonterminal indices to their assigned Ir Nonterm ids.
  std::unordered_map<int, Nonterm> nonterminal_ids;

  // Pending rules to process, as (nonterminal index, rule index) pairs.
  std::set<std::pair<int, int>> scheduled_rules;

  // Define all used predefined nonterminals.
  for (const auto& it : nonterminal_names_) {
    if (IsPredefinedNonterminal(it.first) ||
        predefined_nonterminals.find(it.first) !=
            predefined_nonterminals.end()) {
      nonterminal_ids[it.second] = rules.AddUnshareableNonterminal(it.first);
    }
  }

  // Assign (unmergeable) Nonterm values to any nonterminals that have
  // multiple rules or that have a filter callback on some rule.
  for (int i = 0; i < nonterminals_.size(); i++) {
    const NontermInfo& nonterminal = nonterminals_[i];
    bool unmergeable =
        (nonterminal.from_annotation || nonterminal.rules.size() > 1 ||
         !nonterminal.regex_rules.empty());
    for (const int rule_index : nonterminal.rules) {
      const Rule& rule = rules_[rule_index];

      // Schedule rule.
      scheduled_rules.insert({i, rule_index});

      if (rule.callback != kNoCallback &&
          filters_.find(rule.callback) != filters_.end()) {
        unmergeable = true;
      }
    }

    if (unmergeable) {
      // Define unique nonterminal id.
      nonterminal_ids[i] = rules.AddUnshareableNonterminal(nonterminal.name);
    } else {
      nonterminal_ids[i] = rules.AddNonterminal(nonterminal.name);
    }

    // Define regex rules.
    for (const int regex_rule : nonterminal.regex_rules) {
      rules.AddRegex(nonterminal_ids[i], regex_rules_[regex_rule]);
    }
  }

  // Define annotations.
  for (const auto& [annotation, nonterminal] : annotation_nonterminals_) {
    rules.AddAnnotation(nonterminal_ids[nonterminal], annotation);
  }

  // Now, keep adding eligible rules (rules whose rhs is completely assigned)
  // until we can't make any more progress.
  // Note: The following code is quadratic in the worst case.
  // This seems fine as this will only run as part of the compilation of the
  // grammar rules during model assembly.
  bool changed = true;
  while (changed) {
    changed = false;
    for (auto nt_and_rule = scheduled_rules.begin();
         nt_and_rule != scheduled_rules.end();) {
      const Rule& rule = rules_[nt_and_rule->second];
      if (IsRhsAssigned(rule, nonterminal_ids)) {
        // Compile the rule. LowerRule may assign new nonterminal ids, which
        // can make other pending rules eligible, hence the restart below.
        LowerRule(/*lhs_index=*/nt_and_rule->first, rule, &nonterminal_ids,
                  &rules);
        scheduled_rules.erase(
            nt_and_rule++);  // Iterator is advanced before erase.
        changed = true;
        break;
      } else {
        nt_and_rule++;
      }
    }
  }
  // All rules must have been lowered, otherwise the grammar references an
  // undefined nonterminal (e.g. an unresolved exclusion target).
  TC3_CHECK(scheduled_rules.empty());
  return rules;
}
+
+} // namespace libtextclassifier3::grammar
diff --git a/native/utils/grammar/utils/rules.h b/native/utils/grammar/utils/rules.h
new file mode 100644
index 0000000..5a2cbc2
--- /dev/null
+++ b/native/utils/grammar/utils/rules.h
@@ -0,0 +1,225 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility functions for pre-processing, creating and testing context free
+// grammars.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_RULES_H_
+#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_RULES_H_
+
+#include <unordered_map>
+#include <vector>
+
+#include "utils/grammar/types.h"
+#include "utils/grammar/utils/ir.h"
+
+namespace libtextclassifier3::grammar {
+
+// Special nonterminals.
+constexpr const char* kFiller = "<filler>";
+
+// All rules for a grammar will be collected in a rules object.
+//
+// Rules r;
+// CallbackId date_output_callback = 1;
+// CallbackId day_filter_callback = 2; r.DefineFilter(day_filter_callback);
+// CallbackId year_filter_callback = 3; r.DefineFilter(year_filter_callback);
+// r.Add("<date>", {"<monthname>", "<day>", <year>"},
+// date_output_callback);
+// r.Add("<monthname>", {"January"});
+// ...
+// r.Add("<monthname>", {"December"});
+// r.Add("<day>", {"<string_of_digits>"}, day_filter_callback);
+// r.Add("<year>", {"<string_of_digits>"}, year_filter_callback);
+//
+// The Add() method adds a rule with a given lhs, rhs, and (optionally)
+// callback. The rhs is just a list of terminals and nonterminals. Anything
+// surrounded in angle brackets is considered a nonterminal. A "?" can follow
+// any element of the RHS, like this:
+//
+// r.Add("<date>", {"<monthname>", "<day>?", ",?", "<year>"});
+//
+// This indicates that the <day> and "," parts of the rhs are optional.
+// (This is just notational shorthand for adding a bunch of rules.)
+//
+// Once you're done adding rules and callbacks to the Rules object,
+// call r.Finalize() on it. This lowers the rule set into an internal
+// representation.
+class Rules {
+ public:
+ explicit Rules(const int num_shards = 1) : num_shards_(num_shards) {}
+
+ // Represents one item in a right-hand side, a single terminal or nonterminal.
+ struct RhsElement {
+ RhsElement() {}
+ explicit RhsElement(const std::string& terminal, const bool is_optional)
+ : is_terminal(true), terminal(terminal), is_optional(is_optional) {}
+ explicit RhsElement(const int nonterminal, const bool is_optional)
+ : is_terminal(false),
+ nonterminal(nonterminal),
+ is_optional(is_optional) {}
+ bool is_terminal;
+ std::string terminal;
+ int nonterminal;
+ bool is_optional;
+ };
+
+ // Represents the right-hand side, and possibly callback, of one rule.
+ struct Rule {
+ std::vector<RhsElement> rhs;
+ CallbackId callback = kNoCallback;
+ int64 callback_param = 0;
+ int8 max_whitespace_gap = -1;
+ bool case_sensitive = false;
+ int shard = 0;
+ };
+
+ struct NontermInfo {
+ // The name of the non-terminal, if defined.
+ std::string name;
+
+ // Whether the nonterminal is provided via an annotation.
+ bool from_annotation = false;
+
+ // Rules that have this non-terminal as the lhs.
+ std::vector<int> rules;
+
+ // Regex rules that have this non-terminal as the lhs.
+ std::vector<int> regex_rules;
+ };
+
+ // Adds a rule `lhs ::= rhs` with the given callback id and parameter.
+ // Note: Nonterminal names are in angle brackets and cannot contain
+ // whitespace. The `rhs` is a list of components, each of which is either:
+ // * A nonterminal name (in angle brackets)
+ // * A terminal
+ // optionally followed by a `?` which indicates that the component is
+ // optional. The `rhs` must contain at least one non-optional component.
+ void Add(const std::string& lhs, const std::vector<std::string>& rhs,
+ const CallbackId callback = kNoCallback,
+ const int64 callback_param = 0, int8 max_whitespace_gap = -1,
+ bool case_sensitive = false, int shard = 0);
+
+ // Adds a rule `lhs ::= rhs` with the given callback id and parameter.
+ // The `rhs` must contain at least one non-optional component.
+ void Add(int lhs, const std::vector<RhsElement>& rhs,
+ CallbackId callback = kNoCallback, int64 callback_param = 0,
+ int8 max_whitespace_gap = -1, bool case_sensitive = false,
+ int shard = 0);
+
+ // Adds a rule `lhs ::= rhs` with exclusion.
+ // The rule only matches, if `excluded_nonterminal` doesn't match the same
+ // span.
+ void AddWithExclusion(const std::string& lhs,
+ const std::vector<std::string>& rhs,
+ const std::string& excluded_nonterminal,
+ int8 max_whitespace_gap = -1,
+ bool case_sensitive = false, int shard = 0);
+
+ // Adds an assertion callback.
+ void AddAssertion(const std::string& lhs, const std::vector<std::string>& rhs,
+ bool negative = true, int8 max_whitespace_gap = -1,
+ bool case_sensitive = false, int shard = 0);
+
+ // Adds a mapping callback.
+ void AddValueMapping(const std::string& lhs,
+ const std::vector<std::string>& rhs, int64 value,
+ int8 max_whitespace_gap = -1,
+ bool case_sensitive = false, int shard = 0);
+
+ // Adds a regex rule.
+ void AddRegex(const std::string& lhs, const std::string& regex_pattern);
+ void AddRegex(int lhs, const std::string& regex_pattern);
+
+ // Creates a nonterminal with the given name, if one doesn't already exist.
+ int AddNonterminal(const std::string& nonterminal_name);
+
+ // Creates a new nonterminal.
+ int AddNewNonterminal();
+
+ // Defines a nonterminal for an externally provided annotation.
+ int AddAnnotation(const std::string& annotation_name);
+
+ // Defines a nonterminal for an externally provided annotation.
+ void BindAnnotation(const std::string& nonterminal_name,
+ const std::string& annotation_name);
+
+ // Adds an alias for a nonterminal. This is a separate name for the same
+ // nonterminal.
+ void AddAlias(const std::string& nonterminal_name, const std::string& alias);
+
+ // Defines a new filter id.
+ void DefineFilter(const CallbackId filter_id) { filters_.insert(filter_id); }
+
+ // Lowers the rule set into the intermediate representation.
+ // Treats nonterminals given by the argument `predefined_nonterminals` as
+ // defined externally. This allows to define rules that are dependent on
+ // non-terminals produced by e.g. existing text annotations and that will be
+ // fed to the matcher by the lexer.
+ Ir Finalize(const std::set<std::string>& predefined_nonterminals = {}) const;
+
+ private:
+ void ExpandOptionals(
+ int lhs, const std::vector<RhsElement>& rhs, CallbackId callback,
+ int64 callback_param, int8 max_whitespace_gap, bool case_sensitive,
+ int shard, std::vector<int>::const_iterator optional_element_indices,
+ std::vector<int>::const_iterator optional_element_indices_end,
+ std::vector<bool>* omit_these);
+
+ // Applies optimizations to the right hand side of a rule.
+ std::vector<RhsElement> OptimizeRhs(const std::vector<RhsElement>& rhs);
+
+ // Removes start and end anchors in case they are followed (respectively
+ // preceded) by unbounded filler.
+ std::vector<RhsElement> ResolveAnchors(
+ const std::vector<RhsElement>& rhs) const;
+
+ // Rewrites fillers in a rule.
+ // Fillers in a rule such as `lhs ::= <a> <filler> <b>` could be lowered as
+ // <tokens> ::= <token>
+ // <tokens> ::= <tokens> <token>
+ // This has the disadvantage that it will produce a match for each possible
+ // span in the text, which is quadratic in the number of tokens.
+ // It can be more efficiently written as:
+ // `lhs ::= <a_with_tokens> <b>` with
+ // `<a_with_tokens> ::= <a>`
+ // `<a_with_tokens> ::= <a_with_tokens> <token>`
+ // In this each occurrence of `<a>` can start a sequence of tokens.
+ std::vector<RhsElement> ResolveFillers(const std::vector<RhsElement>& rhs);
+
+ // Checks whether an element denotes a specific nonterminal.
+ bool IsNonterminalOfName(const RhsElement& element,
+ const std::string& nonterminal) const;
+
+ const int num_shards_;
+
+ // Non-terminal to id map.
+ std::unordered_map<std::string, int> nonterminal_names_;
+ std::vector<NontermInfo> nonterminals_;
+ std::unordered_map<std::string, std::string> nonterminal_alias_;
+ std::unordered_map<std::string, int> annotation_nonterminals_;
+
+ // Rules.
+ std::vector<Rule> rules_;
+ std::vector<std::string> regex_rules_;
+
+ // Ids of callbacks that should be treated as filters.
+ std::unordered_set<CallbackId> filters_;
+};
+
+} // namespace libtextclassifier3::grammar
+
+#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_UTILS_RULES_H_
diff --git a/native/utils/grammar/utils/rules_test.cc b/native/utils/grammar/utils/rules_test.cc
new file mode 100644
index 0000000..6761118
--- /dev/null
+++ b/native/utils/grammar/utils/rules_test.cc
@@ -0,0 +1,201 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/grammar/utils/rules.h"
+
+#include "utils/grammar/rules_generated.h"
+#include "utils/grammar/utils/ir.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+using ::testing::IsEmpty;
+using ::testing::SizeIs;
+
+TEST(SerializeRulesTest, HandlesSimpleRuleSet) {
+ Rules rules;
+
+ rules.Add("<verb>", {"buy"});
+ rules.Add("<verb>", {"bring"});
+ rules.Add("<verb>", {"remind"});
+ rules.Add("<reminder>", {"remind", "me", "to", "<verb>"});
+ rules.Add("<action>", {"<reminder>"});
+
+ const Ir ir = rules.Finalize();
+ RulesSetT frozen_rules;
+ ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
+
+ EXPECT_THAT(frozen_rules.rules, SizeIs(1));
+ EXPECT_THAT(frozen_rules.lhs, IsEmpty());
+ EXPECT_EQ(frozen_rules.terminals,
+ std::string("bring\0buy\0me\0remind\0to\0", 23));
+ EXPECT_THAT(frozen_rules.rules.front()->binary_rules, SizeIs(3));
+ EXPECT_THAT(frozen_rules.rules.front()->unary_rules, SizeIs(1));
+}
+
+TEST(SerializeRulesTest, HandlesRulesSetWithCallbacks) {
+ Rules rules;
+ const CallbackId output = 1;
+ const CallbackId filter = 2;
+ rules.DefineFilter(filter);
+
+ rules.Add("<verb>", {"buy"});
+ rules.Add("<verb>", {"bring"}, output, 0);
+ rules.Add("<verb>", {"remind"}, output, 0);
+ rules.Add("<reminder>", {"remind", "me", "to", "<verb>"});
+ rules.Add("<action>", {"<reminder>"}, filter, 0);
+
+ const Ir ir = rules.Finalize();
+ RulesSetT frozen_rules;
+ ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
+
+ EXPECT_THAT(frozen_rules.rules, SizeIs(1));
+ EXPECT_EQ(frozen_rules.terminals,
+ std::string("bring\0buy\0me\0remind\0to\0", 23));
+
+ // We have two identical output calls and one filter call in the rule set
+ // definition above.
+ EXPECT_THAT(frozen_rules.lhs, SizeIs(2));
+
+ EXPECT_THAT(frozen_rules.rules.front()->binary_rules, SizeIs(3));
+ EXPECT_THAT(frozen_rules.rules.front()->unary_rules, SizeIs(1));
+}
+
+TEST(SerializeRulesTest, HandlesRulesWithWhitespaceGapLimits) {
+ Rules rules;
+ rules.Add("<iata>", {"lx"});
+ rules.Add("<iata>", {"aa"});
+ rules.Add("<flight>", {"<iata>", "<4_digits>"}, kNoCallback, 0,
+ /*max_whitespace_gap=*/0);
+
+ const Ir ir = rules.Finalize();
+ RulesSetT frozen_rules;
+ ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
+
+ EXPECT_THAT(frozen_rules.rules, SizeIs(1));
+ EXPECT_EQ(frozen_rules.terminals, std::string("aa\0lx\0", 6));
+ EXPECT_THAT(frozen_rules.lhs, SizeIs(1));
+}
+
+TEST(SerializeRulesTest, HandlesCaseSensitiveTerminals) {
+ Rules rules;
+ rules.Add("<iata>", {"LX"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
+ /*case_sensitive=*/true);
+ rules.Add("<iata>", {"AA"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
+ /*case_sensitive=*/true);
+ rules.Add("<iata>", {"dl"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
+ /*case_sensitive=*/false);
+ rules.Add("<flight>", {"<iata>", "<4_digits>"}, kNoCallback, 0,
+ /*max_whitespace_gap=*/0);
+
+ const Ir ir = rules.Finalize();
+ RulesSetT frozen_rules;
+ ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
+
+ EXPECT_THAT(frozen_rules.rules, SizeIs(1));
+ EXPECT_EQ(frozen_rules.terminals, std::string("AA\0LX\0dl\0", 9));
+ EXPECT_THAT(frozen_rules.lhs, SizeIs(1));
+}
+
+TEST(SerializeRulesTest, HandlesMultipleShards) {
+ Rules rules(/*num_shards=*/2);
+ rules.Add("<iata>", {"LX"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
+ /*case_sensitive=*/true, /*shard=*/0);
+ rules.Add("<iata>", {"aa"}, kNoCallback, 0, /*max_whitespace_gap=*/-1,
+ /*case_sensitive=*/false, /*shard=*/1);
+
+ const Ir ir = rules.Finalize();
+ RulesSetT frozen_rules;
+ ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
+
+ EXPECT_THAT(frozen_rules.rules, SizeIs(2));
+ EXPECT_EQ(frozen_rules.terminals, std::string("LX\0aa\0", 6));
+}
+
+TEST(SerializeRulesTest, HandlesRegexRules) {
+ Rules rules;
+ rules.AddRegex("<code>", "[A-Z]+");
+ rules.AddRegex("<numbers>", "\\d+");
+ RulesSetT frozen_rules;
+ rules.Finalize().Serialize(/*include_debug_information=*/false,
+ &frozen_rules);
+ EXPECT_THAT(frozen_rules.regex_annotator, SizeIs(2));
+}
+
+TEST(SerializeRulesTest, HandlesAlias) {
+ Rules rules;
+ rules.Add("<iata>", {"lx"});
+ rules.Add("<iata>", {"aa"});
+ rules.Add("<flight>", {"<iata>", "<4_digits>"});
+ rules.AddAlias("<flight_number>", "<flight>");
+
+ const Ir ir = rules.Finalize();
+ RulesSetT frozen_rules;
+ ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
+
+ EXPECT_THAT(frozen_rules.rules, SizeIs(1));
+ EXPECT_EQ(frozen_rules.terminals, std::string("aa\0lx\0", 6));
+ EXPECT_THAT(frozen_rules.rules.front()->binary_rules, SizeIs(1));
+
+ // Only alias, no rule.
+ EXPECT_THAT(frozen_rules.rules.front()->unary_rules, IsEmpty());
+
+ EXPECT_THAT(frozen_rules.lhs, IsEmpty());
+}
+
+TEST(SerializeRulesTest, ResolvesAnchorsAndFillers) {
+ Rules rules;
+ rules.Add("<code>",
+ {"<^>", "<filler>", "this", "is", "a", "test", "<filler>", "<$>"});
+ const Ir ir = rules.Finalize();
+ RulesSetT frozen_rules;
+ ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
+
+ EXPECT_THAT(frozen_rules.rules, SizeIs(1));
+ EXPECT_EQ(frozen_rules.terminals, std::string("a\0test\0this\0", 12));
+
+ // Expect removal of anchors and fillers in this case.
+ // The rule above is equivalent to: <code> ::= this is a test, binarized into
+ // <tmp_0> ::= this is
+ // <tmp_1> ::= <tmp_0> a
+ // <code> ::= <tmp_1> test
+ EXPECT_THAT(frozen_rules.rules.front()->binary_rules, SizeIs(3));
+
+ EXPECT_THAT(frozen_rules.rules.front()->unary_rules, IsEmpty());
+ EXPECT_THAT(frozen_rules.lhs, IsEmpty());
+}
+
+TEST(SerializeRulesTest, HandlesAnnotations) {
+ Rules rules;
+ rules.AddAnnotation("phone");
+ rules.AddAnnotation("url");
+ rules.AddAnnotation("tracking_number");
+ const Ir ir = rules.Finalize();
+ RulesSetT frozen_rules;
+ ir.Serialize(/*include_debug_information=*/false, &frozen_rules);
+
+ EXPECT_THAT(frozen_rules.rules, SizeIs(1));
+ EXPECT_THAT(frozen_rules.nonterminals->annotation_nt, SizeIs(3));
+ EXPECT_EQ(frozen_rules.nonterminals->annotation_nt[0]->key, "phone");
+ EXPECT_EQ(frozen_rules.nonterminals->annotation_nt[1]->key,
+ "tracking_number");
+ EXPECT_EQ(frozen_rules.nonterminals->annotation_nt[2]->key, "url");
+}
+
+} // namespace
+} // namespace libtextclassifier3::grammar
diff --git a/utils/hash/farmhash.cc b/native/utils/hash/farmhash.cc
similarity index 100%
rename from utils/hash/farmhash.cc
rename to native/utils/hash/farmhash.cc
diff --git a/utils/hash/farmhash.h b/native/utils/hash/farmhash.h
similarity index 100%
rename from utils/hash/farmhash.h
rename to native/utils/hash/farmhash.h
diff --git a/native/utils/i18n/language-tag.fbs b/native/utils/i18n/language-tag.fbs
new file mode 100755
index 0000000..a2e1077
--- /dev/null
+++ b/native/utils/i18n/language-tag.fbs
@@ -0,0 +1,24 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// BCP 47 tag.
+namespace libtextclassifier3;
+table LanguageTag {
+ language:string (shared);
+ script:string (shared);
+ region:string (shared);
+}
+
diff --git a/native/utils/i18n/locale.cc b/native/utils/i18n/locale.cc
new file mode 100644
index 0000000..d5a1109
--- /dev/null
+++ b/native/utils/i18n/locale.cc
@@ -0,0 +1,220 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/i18n/locale.h"
+
+#include "utils/strings/split.h"
+
+namespace libtextclassifier3 {
+
+namespace {
+constexpr const char* kAnyMatch = "*";
+
+// BCP 47 code for "Undetermined Language".
+constexpr const char* kUnknownLanguageCode = "und";
+
+bool CheckLanguage(StringPiece language) {
+ if (language.size() == 1 && language.data()[0] == '*') {
+ return true;
+ }
+
+ if (language.size() != 2 && language.size() != 3) {
+ return false;
+ }
+
+ // Needs to be all lowercase.
+ for (int i = 0; i < language.size(); ++i) {
+ if (!std::islower(language[i])) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool CheckScript(StringPiece script) {
+ if (script.size() != 4) {
+ return false;
+ }
+
+ if (!std::isupper(script[0])) {
+ return false;
+ }
+
+ // Needs to be all lowercase.
+ for (int i = 1; i < script.size(); ++i) {
+ if (!std::islower(script[i])) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool CheckRegion(StringPiece region) {
+ if (region.size() == 2) {
+ return std::isupper(region[0]) && std::isupper(region[1]);
+ } else if (region.size() == 3) {
+ return std::isdigit(region[0]) && std::isdigit(region[1]) &&
+ std::isdigit(region[2]);
+ } else {
+ return false;
+ }
+}
+
+} // namespace
+
+Locale Locale::FromBCP47(const std::string& locale_tag) {
+ std::vector<StringPiece> parts = strings::Split(locale_tag, '-');
+ if (parts.empty()) {
+ return Locale::Invalid();
+ }
+
+ auto parts_it = parts.begin();
+ StringPiece language = *parts_it;
+ if (!CheckLanguage(language)) {
+ return Locale::Invalid();
+ }
+ ++parts_it;
+
+ StringPiece script;
+ if (parts_it != parts.end()) {
+ script = *parts_it;
+ if (!CheckScript(script)) {
+ script = "";
+ } else {
+ ++parts_it;
+ }
+ }
+
+ StringPiece region;
+ if (parts_it != parts.end()) {
+ region = *parts_it;
+ if (!CheckRegion(region)) {
+ region = "";
+ } else {
+ ++parts_it;
+ }
+ }
+
+ // NOTE: We don't parse the rest of the BCP47 tag here even if specified.
+
+ return Locale(language.ToString(), script.ToString(), region.ToString());
+}
+
+Locale Locale::FromLanguageTag(const LanguageTag* language_tag) {
+ if (language_tag == nullptr || language_tag->language() == nullptr) {
+ return Locale::Invalid();
+ }
+
+ StringPiece language = language_tag->language()->c_str();
+ if (!CheckLanguage(language)) {
+ return Locale::Invalid();
+ }
+
+ StringPiece script;
+ if (language_tag->script() != nullptr) {
+ script = language_tag->script()->c_str();
+ if (!CheckScript(script)) {
+ script = "";
+ }
+ }
+
+ StringPiece region;
+ if (language_tag->region() != nullptr) {
+ region = language_tag->region()->c_str();
+ if (!CheckRegion(region)) {
+ region = "";
+ }
+ }
+ return Locale(language.ToString(), script.ToString(), region.ToString());
+}
+
+bool Locale::IsUnknown() const {
+ return is_valid_ && language_ == kUnknownLanguageCode;
+}
+
+bool Locale::IsLocaleSupported(const Locale& locale,
+ const std::vector<Locale>& supported_locales,
+ bool default_value) {
+ if (!locale.IsValid()) {
+ return false;
+ }
+ if (locale.IsUnknown()) {
+ return default_value;
+ }
+ for (const Locale& supported_locale : supported_locales) {
+ if (!supported_locale.IsValid()) {
+ continue;
+ }
+ const bool language_matches =
+ supported_locale.Language().empty() ||
+ supported_locale.Language() == kAnyMatch ||
+ supported_locale.Language() == locale.Language();
+ const bool script_matches = supported_locale.Script().empty() ||
+ supported_locale.Script() == kAnyMatch ||
+ locale.Script().empty() ||
+ supported_locale.Script() == locale.Script();
+ const bool region_matches = supported_locale.Region().empty() ||
+ supported_locale.Region() == kAnyMatch ||
+ locale.Region().empty() ||
+ supported_locale.Region() == locale.Region();
+ if (language_matches && script_matches && region_matches) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool Locale::IsAnyLocaleSupported(const std::vector<Locale>& locales,
+ const std::vector<Locale>& supported_locales,
+ bool default_value) {
+ if (locales.empty()) {
+ return default_value;
+ }
+ if (supported_locales.empty()) {
+ return default_value;
+ }
+ for (const Locale& locale : locales) {
+ if (IsLocaleSupported(locale, supported_locales, default_value)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Locale& locale) {
+ return stream << "Locale(language=" << locale.Language()
+ << ", script=" << locale.Script()
+ << ", region=" << locale.Region()
+ << ", is_valid=" << locale.IsValid()
+ << ", is_unknown=" << locale.IsUnknown() << ")";
+}
+
+bool ParseLocales(StringPiece locales_list, std::vector<Locale>* locales) {
+ for (const auto& locale_str : strings::Split(locales_list, ',')) {
+ const Locale locale = Locale::FromBCP47(locale_str.ToString());
+ if (!locale.IsValid()) {
+ TC3_LOG(ERROR) << "Invalid locale " << locale_str.ToString();
+ return false;
+ }
+ locales->push_back(locale);
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/i18n/locale.h b/native/utils/i18n/locale.h
new file mode 100644
index 0000000..308846d
--- /dev/null
+++ b/native/utils/i18n/locale.h
@@ -0,0 +1,90 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_H_
+#define LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_H_
+
+#include <string>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+#include "utils/i18n/language-tag_generated.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+class Locale {
+ public:
+ // Constructs the object from a valid BCP47 tag. If the tag is invalid,
+ // an object is created that gives false when IsInvalid() is called.
+ static Locale FromBCP47(const std::string& locale_tag);
+
+ // Constructs the object from a flatbuffer language tag.
+ static Locale FromLanguageTag(const LanguageTag* language_tag);
+
+ // Creates a prototypical invalid locale object.
+ static Locale Invalid() {
+ Locale locale(/*language=*/"", /*script=*/"", /*region=*/"");
+ locale.is_valid_ = false;
+ return locale;
+ }
+
+ std::string Language() const { return language_; }
+
+ std::string Script() const { return script_; }
+
+ std::string Region() const { return region_; }
+
+ bool IsValid() const { return is_valid_; }
+ bool IsUnknown() const;
+
+ // Returns whether any of the given locales is supported by any of the
+ // supported locales. Returns default value if the given 'locales' list, or
+ // 'supported_locales' list is empty or an unknown locale is found.
+ // Locale::FromBCP47("*") means any locale.
+ static bool IsAnyLocaleSupported(const std::vector<Locale>& locales,
+ const std::vector<Locale>& supported_locales,
+ bool default_value);
+
+ private:
+ Locale(const std::string& language, const std::string& script,
+ const std::string& region)
+ : language_(language),
+ script_(script),
+ region_(region),
+ is_valid_(true) {}
+
+ static bool IsLocaleSupported(const Locale& locale,
+ const std::vector<Locale>& supported_locales,
+ bool default_value);
+
+ std::string language_;
+ std::string script_;
+ std::string region_;
+ bool is_valid_;
+};
+
+// Pretty-printing function for Locale.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Locale& locale);
+
+// Parses a comma-separated list of BCP47 tags.
+bool ParseLocales(StringPiece locales_list, std::vector<Locale>* locales);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_H_
diff --git a/utils/i18n/locale_test.cc b/native/utils/i18n/locale_test.cc
similarity index 100%
rename from utils/i18n/locale_test.cc
rename to native/utils/i18n/locale_test.cc
diff --git a/native/utils/intents/intent-config.fbs b/native/utils/intents/intent-config.fbs
new file mode 100755
index 0000000..672eb9d
--- /dev/null
+++ b/native/utils/intents/intent-config.fbs
@@ -0,0 +1,201 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+include "utils/zlib/buffer.fbs";
+
+// The type of variable to fetch.
+namespace libtextclassifier3;
+enum AndroidSimpleIntentGeneratorVariableType : int {
+ INVALID_VARIABLE = 0,
+
+ // The raw text that was classified.
+ RAW_TEXT = 1,
+
+ // Text as a URL with explicit protocol. If no protocol was specified, http
+ // is prepended.
+ URL_TEXT = 2,
+
+ // The raw text, but URL encoded.
+ URL_ENCODED_TEXT = 3,
+
+ // For dates/times: the instant of the event in UTC millis.
+ EVENT_TIME_MS_UTC = 4,
+
+ // For dates/times: the start of the event in UTC millis.
+ EVENT_START_MS_UTC = 5,
+
+ // For dates/times: the end of the event in UTC millis.
+ EVENT_END_MS_UTC = 6,
+
+ // Name of the package that's running the classifier.
+ PACKAGE_NAME = 7,
+}
+
+// Enumerates the possible extra types for the simple intent generator.
+namespace libtextclassifier3;
+enum AndroidSimpleIntentGeneratorExtraType : int {
+ INVALID_EXTRA_TYPE = 0,
+ STRING = 1,
+ // Use string_ field.
+
+ BOOL = 2,
+ // Use bool_ field.
+
+ VARIABLE_AS_LONG = 3,
+ // Use int32_ field for the variable index.
+}
+
+// Enumerates the possible condition types for the simple intent generator.
+namespace libtextclassifier3;
+enum AndroidSimpleIntentGeneratorConditionType : int {
+ INVALID_CONDITION_TYPE = 0,
+
+ // Queries the UserManager for the given boolean restriction. The condition
+ // passes if the result is of getBoolean is false. The name of the
+ // restriction to check is in the string_ field.
+ USER_RESTRICTION_NOT_SET = 1,
+
+ // Checks that the parsed event start time is at least a give number of
+ // milliseconds in the future. (Only valid if there is a parsed event
+ // time) The offset is stored in the int64_ field.
+ EVENT_START_IN_FUTURE_MS = 2,
+}
+
+// Describes how intents for the various entity types should be generated on
+// Android. This is distributed through the model, but not used by
+// libtextclassifier yet - rather, it's passed to the calling Java code, which
+// implements the Intent generation logic.
+namespace libtextclassifier3;
+table AndroidIntentFactoryOptions {
+ entity:[AndroidIntentFactoryEntityOptions];
+}
+
+// Describes how intents should be generated for a particular entity type.
+namespace libtextclassifier3;
+table AndroidIntentFactoryEntityOptions {
+ // The entity type as defined by one of the TextClassifier ENTITY_TYPE
+ // constants. (e.g. "address", "phone", etc.)
+ entity_type:string (shared);
+
+ // List of generators for all the different types of intents that should
+ // be made available for the entity type.
+ generator:[AndroidIntentGeneratorOptions];
+}
+
+// Configures a single Android Intent generator.
+namespace libtextclassifier3;
+table AndroidIntentGeneratorOptions {
+ // Strings for UI elements.
+ strings:[AndroidIntentGeneratorStrings];
+
+ // Generator specific configuration.
+ simple:AndroidSimpleIntentGeneratorOptions;
+}
+
+// Language dependent configuration for an Android Intent generator.
+namespace libtextclassifier3;
+table AndroidIntentGeneratorStrings {
+ // BCP 47 tag for the supported locale. Note that because of API level
+ // restrictions, this must /not/ use wildcards. To e.g. match all English
+ // locales, use only "en" and not "en_*". Reference the java.util.Locale
+ // constructor for details.
+ language_tag:string (shared);
+
+ // Title shown for the action (see RemoteAction.getTitle).
+ title:string (shared);
+
+ // Description shown for the action (see
+ // RemoteAction.getContentDescription).
+ description:string (shared);
+}
+
+// An extra to set on a simple intent generator Intent.
+namespace libtextclassifier3;
+table AndroidSimpleIntentGeneratorExtra {
+ // The name of the extra to set.
+ name:string (shared);
+
+ // The type of the extra to set.
+ type:AndroidSimpleIntentGeneratorExtraType;
+
+ string_:string (shared);
+ bool_:bool;
+ int32_:int;
+}
+
+// A condition that needs to be fulfilled for an Intent to get generated.
+namespace libtextclassifier3;
+table AndroidSimpleIntentGeneratorCondition {
+ type:AndroidSimpleIntentGeneratorConditionType;
+ string_:string (shared);
+ int32_:int;
+ int64_:long;
+}
+
+// Configures an intent generator where the logic is simple to be expressed with
+// basic rules - which covers the vast majority of use cases and is analogous
+// to Android Actions.
+// Most strings (action, data, type, ...) may contain variable references. To
+// use them, the generator must first declare all the variables it wishes to use
+// in the variables field. The values then become available as numbered
+// arguments (using the normal java.util.Formatter syntax) in the order they
+// were specified.
+namespace libtextclassifier3;
+table AndroidSimpleIntentGeneratorOptions {
+ // The action to set on the Intent (see Intent.setAction). Supports variables.
+ action:string (shared);
+
+ // The data to set on the Intent (see Intent.setData). Supports variables.
+ data:string (shared);
+
+ // The type to set on the Intent (see Intent.setType). Supports variables.
+ type:string (shared);
+
+ // The list of all the extras to add to the Intent.
+ extra:[AndroidSimpleIntentGeneratorExtra];
+
+ // The list of all the variables that become available for substitution in
+ // the action, data, type and extra strings. To e.g. set a field to the value
+ // of the first variable, use "%0$s".
+ variable:[AndroidSimpleIntentGeneratorVariableType];
+
+ // The list of all conditions that need to be fulfilled for Intent generation.
+ condition:[AndroidSimpleIntentGeneratorCondition];
+}
+
+// Describes how intents should be generated for a particular entity type.
+namespace libtextclassifier3.IntentFactoryModel_;
+table IntentGenerator {
+ // The type of the intent generator, e.g. the entity type as defined by
+ // on the TextClassifier ENTITY_TYPE constants e.g. "address", "phone", etc.
+ type:string (shared);
+
+ // The template generator lua code, either as text source or precompiled
+ // bytecode.
+ lua_template_generator:[ubyte];
+
+ compressed_lua_template_generator:CompressedBuffer;
+}
+
+// Describes how intents for the various entity types should be generated.
+namespace libtextclassifier3;
+table IntentFactoryModel {
+ generator:[IntentFactoryModel_.IntentGenerator];
+
+ // Whether to precompile the generators when loading.
+ precompile_generators:bool = false;
+}
+
diff --git a/native/utils/intents/intent-generator.cc b/native/utils/intents/intent-generator.cc
new file mode 100644
index 0000000..4cb3e40
--- /dev/null
+++ b/native/utils/intents/intent-generator.cc
@@ -0,0 +1,938 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/intents/intent-generator.h"
+
+#include <vector>
+
+#include "actions/types.h"
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/base/statusor.h"
+#include "utils/hash/farmhash.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-helper.h"
+#include "utils/java/string_utils.h"
+#include "utils/lua-utils.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/strings/substitute.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/variant.h"
+#include "utils/zlib/zlib.h"
+#include "flatbuffers/reflection_generated.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lua.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+namespace {
+
+static constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
+static constexpr const char* kHashKey = "hash";
+static constexpr const char* kUrlSchemaKey = "url_schema";
+static constexpr const char* kUrlHostKey = "url_host";
+static constexpr const char* kUrlEncodeKey = "urlencode";
+static constexpr const char* kPackageNameKey = "package_name";
+static constexpr const char* kDeviceLocaleKey = "device_locales";
+static constexpr const char* kFormatKey = "format";
+
+// An Android specific Lua environment with JNI backed callbacks.
+class JniLuaEnvironment : public LuaEnvironment {
+ public:
+ JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
+ const jobject context,
+ const std::vector<Locale>& device_locales);
+ // Environment setup.
+ bool Initialize();
+
+ // Runs an intent generator snippet.
+ bool RunIntentGenerator(const std::string& generator_snippet,
+ std::vector<RemoteActionTemplate>* remote_actions);
+
+ protected:
+ virtual void SetupExternalHook();
+
+ int HandleExternalCallback();
+ int HandleAndroidCallback();
+ int HandleUserRestrictionsCallback();
+ int HandleUrlEncode();
+ int HandleUrlSchema();
+ int HandleHash();
+ int HandleFormat();
+ int HandleAndroidStringResources();
+ int HandleUrlHost();
+
+ // Checks and retrieves string resources from the model.
+ bool LookupModelStringResource() const;
+
+ // Reads and creates a RemoteAction result from Lua.
+ RemoteActionTemplate ReadRemoteActionTemplateResult() const;
+
+ // Reads the extras from the Lua result.
+ std::map<std::string, Variant> ReadExtras() const;
+
+ // Retrieves user manager if not previously done.
+ bool RetrieveUserManager();
+
+ // Retrieves system resources if not previously done.
+ bool RetrieveSystemResources();
+
+ // Parses the URL string using Uri.parse from Java.
+ StatusOr<ScopedLocalRef<jobject>> ParseUri(StringPiece url) const;
+
+ // Read remote action templates from lua generator.
+ int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);
+
+ const Resources& resources_;
+ JNIEnv* jenv_;
+ const JniCache* jni_cache_;
+ const jobject context_;
+ std::vector<Locale> device_locales_;
+
+ ScopedGlobalRef<jobject> usermanager_;
+ // Whether we previously attempted to retrieve the UserManager.
+ bool usermanager_retrieved_;
+
+ ScopedGlobalRef<jobject> system_resources_;
+ // Whether we previously attempted to retrieve the system resources.
+ bool system_resources_resources_retrieved_;
+
+ // Cached JNI references for Java strings `string` and `android`.
+ ScopedGlobalRef<jstring> string_;
+ ScopedGlobalRef<jstring> android_;
+};
+
+JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
+ const JniCache* jni_cache,
+ const jobject context,
+ const std::vector<Locale>& device_locales)
+ : resources_(resources),
+ jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
+ jni_cache_(jni_cache),
+ context_(context),
+ device_locales_(device_locales),
+ usermanager_(/*object=*/nullptr,
+ /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
+ usermanager_retrieved_(false),
+ system_resources_(/*object=*/nullptr,
+ /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
+ system_resources_resources_retrieved_(false),
+ string_(/*object=*/nullptr,
+ /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
+ android_(/*object=*/nullptr,
+ /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
+
+bool JniLuaEnvironment::Initialize() {
+ TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> string_value,
+ JniHelper::NewStringUTF(jenv_, "string"));
+ string_ = MakeGlobalRef(string_value.get(), jenv_, jni_cache_->jvm);
+ TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> android_value,
+ JniHelper::NewStringUTF(jenv_, "android"));
+ android_ = MakeGlobalRef(android_value.get(), jenv_, jni_cache_->jvm);
+ if (string_ == nullptr || android_ == nullptr) {
+ TC3_LOG(ERROR) << "Could not allocate constant strings references.";
+ return false;
+ }
+ return (RunProtected([this] {
+ LoadDefaultLibraries();
+ SetupExternalHook();
+ lua_setglobal(state_, "external");
+ return LUA_OK;
+ }) == LUA_OK);
+}
+
+void JniLuaEnvironment::SetupExternalHook() {
+ // This exposes an `external` object with the following fields:
+ // * entity: the bundle with all information about a classification.
+ // * android: callbacks into specific android provided methods.
+ // * android.user_restrictions: callbacks to check user permissions.
+ // * android.R: callbacks to retrieve string resources.
+ PushLazyObject(&JniLuaEnvironment::HandleExternalCallback);
+
+ // android
+ PushLazyObject(&JniLuaEnvironment::HandleAndroidCallback);
+ {
+ // android.user_restrictions
+ PushLazyObject(&JniLuaEnvironment::HandleUserRestrictionsCallback);
+ lua_setfield(state_, /*idx=*/-2, "user_restrictions");
+
+ // android.R
+ // Callback to access android string resources.
+ PushLazyObject(&JniLuaEnvironment::HandleAndroidStringResources);
+ lua_setfield(state_, /*idx=*/-2, "R");
+ }
+ lua_setfield(state_, /*idx=*/-2, "android");
+}
+
+int JniLuaEnvironment::HandleExternalCallback() {
+ const StringPiece key = ReadString(kIndexStackTop);
+ if (key.Equals(kHashKey)) {
+ PushFunction(&JniLuaEnvironment::HandleHash);
+ return 1;
+ } else if (key.Equals(kFormatKey)) {
+ PushFunction(&JniLuaEnvironment::HandleFormat);
+ return 1;
+ } else {
+ TC3_LOG(ERROR) << "Undefined external access " << key;
+ lua_error(state_);
+ return 0;
+ }
+}
+
+int JniLuaEnvironment::HandleAndroidCallback() {
+ const StringPiece key = ReadString(kIndexStackTop);
+ if (key.Equals(kDeviceLocaleKey)) {
+ // Provide the locale as table with the individual fields set.
+ lua_newtable(state_);
+ for (int i = 0; i < device_locales_.size(); i++) {
+ // Adjust index to 1-based indexing for Lua.
+ lua_pushinteger(state_, i + 1);
+ lua_newtable(state_);
+ PushString(device_locales_[i].Language());
+ lua_setfield(state_, -2, "language");
+ PushString(device_locales_[i].Region());
+ lua_setfield(state_, -2, "region");
+ PushString(device_locales_[i].Script());
+ lua_setfield(state_, -2, "script");
+ lua_settable(state_, /*idx=*/-3);
+ }
+ return 1;
+ } else if (key.Equals(kPackageNameKey)) {
+ if (context_ == nullptr) {
+ TC3_LOG(ERROR) << "Context invalid.";
+ lua_error(state_);
+ return 0;
+ }
+
+ StatusOr<ScopedLocalRef<jstring>> status_or_package_name_str =
+ JniHelper::CallObjectMethod<jstring>(
+ jenv_, context_, jni_cache_->context_get_package_name);
+
+ if (!status_or_package_name_str.ok()) {
+ TC3_LOG(ERROR) << "Error calling Context.getPackageName";
+ lua_error(state_);
+ return 0;
+ }
+ StatusOr<std::string> status_or_package_name_std_str =
+ ToStlString(jenv_, status_or_package_name_str.ValueOrDie().get());
+ if (!status_or_package_name_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_package_name_std_str.ValueOrDie());
+ return 1;
+ } else if (key.Equals(kUrlEncodeKey)) {
+ PushFunction(&JniLuaEnvironment::HandleUrlEncode);
+ return 1;
+ } else if (key.Equals(kUrlHostKey)) {
+ PushFunction(&JniLuaEnvironment::HandleUrlHost);
+ return 1;
+ } else if (key.Equals(kUrlSchemaKey)) {
+ PushFunction(&JniLuaEnvironment::HandleUrlSchema);
+ return 1;
+ } else {
+ TC3_LOG(ERROR) << "Undefined android reference " << key;
+ lua_error(state_);
+ return 0;
+ }
+}
+
+int JniLuaEnvironment::HandleUserRestrictionsCallback() {
+ if (jni_cache_->usermanager_class == nullptr ||
+ jni_cache_->usermanager_get_user_restrictions == nullptr) {
+ // UserManager is only available for API level >= 17 and
+ // getUserRestrictions only for API level >= 18, so we just return false
+ // normally here.
+ lua_pushboolean(state_, false);
+ return 1;
+ }
+
+ // Get user manager if not previously retrieved.
+ if (!RetrieveUserManager()) {
+ TC3_LOG(ERROR) << "Error retrieving user manager.";
+ lua_error(state_);
+ return 0;
+ }
+
+ StatusOr<ScopedLocalRef<jobject>> status_or_bundle =
+ JniHelper::CallObjectMethod(
+ jenv_, usermanager_.get(),
+ jni_cache_->usermanager_get_user_restrictions);
+ if (!status_or_bundle.ok() || status_or_bundle.ValueOrDie() == nullptr) {
+ TC3_LOG(ERROR) << "Error calling getUserRestrictions";
+ lua_error(state_);
+ return 0;
+ }
+
+ const StringPiece key_str = ReadString(kIndexStackTop);
+ if (key_str.empty()) {
+ TC3_LOG(ERROR) << "Expected string, got null.";
+ lua_error(state_);
+ return 0;
+ }
+
+ const StatusOr<ScopedLocalRef<jstring>> status_or_key =
+ jni_cache_->ConvertToJavaString(key_str);
+ if (!status_or_key.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ const StatusOr<bool> status_or_permission = JniHelper::CallBooleanMethod(
+ jenv_, status_or_bundle.ValueOrDie().get(),
+ jni_cache_->bundle_get_boolean, status_or_key.ValueOrDie().get());
+ if (!status_or_permission.ok()) {
+ TC3_LOG(ERROR) << "Error getting bundle value";
+ lua_pushboolean(state_, false);
+ } else {
+ lua_pushboolean(state_, status_or_permission.ValueOrDie());
+ }
+ return 1;
+}
+
+int JniLuaEnvironment::HandleUrlEncode() {
+ const StringPiece input = ReadString(/*index=*/1);
+ if (input.empty()) {
+ TC3_LOG(ERROR) << "Expected string, got null.";
+ lua_error(state_);
+ return 0;
+ }
+
+ // Call Java URL encoder.
+ const StatusOr<ScopedLocalRef<jstring>> status_or_input_str =
+ jni_cache_->ConvertToJavaString(input);
+ if (!status_or_input_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ StatusOr<ScopedLocalRef<jstring>> status_or_encoded_str =
+ JniHelper::CallStaticObjectMethod<jstring>(
+ jenv_, jni_cache_->urlencoder_class.get(),
+ jni_cache_->urlencoder_encode, status_or_input_str.ValueOrDie().get(),
+ jni_cache_->string_utf8.get());
+
+ if (!status_or_encoded_str.ok()) {
+ TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
+ lua_error(state_);
+ return 0;
+ }
+ const StatusOr<std::string> status_or_encoded_std_str =
+ ToStlString(jenv_, status_or_encoded_str.ValueOrDie().get());
+ if (!status_or_encoded_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_encoded_std_str.ValueOrDie());
+ return 1;
+}
+
+StatusOr<ScopedLocalRef<jobject>> JniLuaEnvironment::ParseUri(
+ StringPiece url) const {
+ if (url.empty()) {
+ return {Status::UNKNOWN};
+ }
+
+ // Call to Java URI parser.
+ TC3_ASSIGN_OR_RETURN(
+ const StatusOr<ScopedLocalRef<jstring>> status_or_url_str,
+ jni_cache_->ConvertToJavaString(url));
+
+ // Try to parse uri and get scheme.
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> uri,
+ JniHelper::CallStaticObjectMethod(jenv_, jni_cache_->uri_class.get(),
+ jni_cache_->uri_parse,
+ status_or_url_str.ValueOrDie().get()));
+ if (uri == nullptr) {
+ TC3_LOG(ERROR) << "Error calling Uri.parse";
+ return {Status::UNKNOWN};
+ }
+ return uri;
+}
+
+int JniLuaEnvironment::HandleUrlSchema() {
+ StringPiece url = ReadString(/*index=*/1);
+
+ const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
+ if (!status_or_parsed_uri.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+
+ const StatusOr<ScopedLocalRef<jstring>> status_or_scheme_str =
+ JniHelper::CallObjectMethod<jstring>(
+ jenv_, status_or_parsed_uri.ValueOrDie().get(),
+ jni_cache_->uri_get_scheme);
+ if (!status_or_scheme_str.ok()) {
+ TC3_LOG(ERROR) << "Error calling Uri.getScheme";
+ lua_error(state_);
+ return 0;
+ }
+ if (status_or_scheme_str.ValueOrDie() == nullptr) {
+ lua_pushnil(state_);
+ } else {
+ const StatusOr<std::string> status_or_scheme_std_str =
+ ToStlString(jenv_, status_or_scheme_str.ValueOrDie().get());
+ if (!status_or_scheme_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_scheme_std_str.ValueOrDie());
+ }
+ return 1;
+}
+
+int JniLuaEnvironment::HandleUrlHost() {
+ const StringPiece url = ReadString(kIndexStackTop);
+
+ const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
+ if (!status_or_parsed_uri.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+
+ const StatusOr<ScopedLocalRef<jstring>> status_or_host_str =
+ JniHelper::CallObjectMethod<jstring>(
+ jenv_, status_or_parsed_uri.ValueOrDie().get(),
+ jni_cache_->uri_get_host);
+ if (!status_or_host_str.ok()) {
+ TC3_LOG(ERROR) << "Error calling Uri.getHost";
+ lua_error(state_);
+ return 0;
+ }
+
+ if (status_or_host_str.ValueOrDie() == nullptr) {
+ lua_pushnil(state_);
+ } else {
+ const StatusOr<std::string> status_or_host_std_str =
+ ToStlString(jenv_, status_or_host_str.ValueOrDie().get());
+ if (!status_or_host_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_host_std_str.ValueOrDie());
+ }
+ return 1;
+}
+
+int JniLuaEnvironment::HandleHash() {
+ const StringPiece input = ReadString(kIndexStackTop);
+ lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
+ return 1;
+}
+
+int JniLuaEnvironment::HandleFormat() {
+ const int num_args = lua_gettop(state_);
+ std::vector<StringPiece> args(num_args - 1);
+ for (int i = 0; i < num_args - 1; i++) {
+ args[i] = ReadString(/*index=*/i + 2);
+ }
+ PushString(strings::Substitute(ReadString(/*index=*/1), args));
+ return 1;
+}
+
+bool JniLuaEnvironment::LookupModelStringResource() const {
+ // Handle only lookup by name.
+ if (lua_type(state_, kIndexStackTop) != LUA_TSTRING) {
+ return false;
+ }
+
+ const StringPiece resource_name = ReadString(kIndexStackTop);
+ std::string resource_content;
+ if (!resources_.GetResourceContent(device_locales_, resource_name,
+ &resource_content)) {
+ // Resource cannot be provided by the model.
+ return false;
+ }
+
+ PushString(resource_content);
+ return true;
+}
+
+int JniLuaEnvironment::HandleAndroidStringResources() {
+ // Check whether the requested resource can be served from the model data.
+ if (LookupModelStringResource()) {
+ return 1;
+ }
+
+ // Get system resources if not previously retrieved.
+ if (!RetrieveSystemResources()) {
+ TC3_LOG(ERROR) << "Error retrieving system resources.";
+ lua_error(state_);
+ return 0;
+ }
+
+ int resource_id;
+ switch (lua_type(state_, kIndexStackTop)) {
+ case LUA_TNUMBER:
+ resource_id = Read<int>(/*index=*/kIndexStackTop);
+ break;
+ case LUA_TSTRING: {
+ const StringPiece resource_name_str = ReadString(kIndexStackTop);
+ if (resource_name_str.empty()) {
+ TC3_LOG(ERROR) << "No resource name provided.";
+ lua_error(state_);
+ return 0;
+ }
+ const StatusOr<ScopedLocalRef<jstring>> status_or_resource_name =
+ jni_cache_->ConvertToJavaString(resource_name_str);
+ if (!status_or_resource_name.ok()) {
+ TC3_LOG(ERROR) << "Invalid resource name.";
+ lua_error(state_);
+ return 0;
+ }
+ StatusOr<int> status_or_resource_id = JniHelper::CallIntMethod(
+ jenv_, system_resources_.get(), jni_cache_->resources_get_identifier,
+ status_or_resource_name.ValueOrDie().get(), string_.get(),
+ android_.get());
+ if (!status_or_resource_id.ok()) {
+ TC3_LOG(ERROR) << "Error calling getIdentifier.";
+ lua_error(state_);
+ return 0;
+ }
+ resource_id = status_or_resource_id.ValueOrDie();
+ break;
+ }
+ default:
+ TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
+ lua_error(state_);
+ return 0;
+ }
+ if (resource_id == 0) {
+ TC3_LOG(ERROR) << "Resource not found.";
+ lua_pushnil(state_);
+ return 1;
+ }
+ StatusOr<ScopedLocalRef<jstring>> status_or_resource_str =
+ JniHelper::CallObjectMethod<jstring>(jenv_, system_resources_.get(),
+ jni_cache_->resources_get_string,
+ resource_id);
+ if (!status_or_resource_str.ok()) {
+ TC3_LOG(ERROR) << "Error calling getString.";
+ lua_error(state_);
+ return 0;
+ }
+
+ if (status_or_resource_str.ValueOrDie() == nullptr) {
+ lua_pushnil(state_);
+ } else {
+ StatusOr<std::string> status_or_resource_std_str =
+ ToStlString(jenv_, status_or_resource_str.ValueOrDie().get());
+ if (!status_or_resource_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_resource_std_str.ValueOrDie());
+ }
+ return 1;
+}
+
+bool JniLuaEnvironment::RetrieveSystemResources() {
+ if (system_resources_resources_retrieved_) {
+ return (system_resources_ != nullptr);
+ }
+ system_resources_resources_retrieved_ = true;
+ TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jobject> system_resources_ref,
+ JniHelper::CallStaticObjectMethod(
+ jenv_, jni_cache_->resources_class.get(),
+ jni_cache_->resources_get_system));
+ system_resources_ =
+ MakeGlobalRef(system_resources_ref.get(), jenv_, jni_cache_->jvm);
+ return (system_resources_ != nullptr);
+}
+
+bool JniLuaEnvironment::RetrieveUserManager() {
+ if (context_ == nullptr) {
+ return false;
+ }
+ if (usermanager_retrieved_) {
+ return (usermanager_ != nullptr);
+ }
+ usermanager_retrieved_ = true;
+ TC3_ASSIGN_OR_RETURN_FALSE(const ScopedLocalRef<jstring> service,
+ JniHelper::NewStringUTF(jenv_, "user"));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ const ScopedLocalRef<jobject> usermanager_ref,
+ JniHelper::CallObjectMethod(jenv_, context_,
+ jni_cache_->context_get_system_service,
+ service.get()));
+
+ usermanager_ = MakeGlobalRef(usermanager_ref.get(), jenv_, jni_cache_->jvm);
+ return (usermanager_ != nullptr);
+}
+
+RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() const {
+ RemoteActionTemplate result;
+ // Read intent template.
+ lua_pushnil(state_);
+ while (Next(/*index=*/-2)) {
+ const StringPiece key = ReadString(/*index=*/-2);
+ if (key.Equals("title_without_entity")) {
+ result.title_without_entity = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("title_with_entity")) {
+ result.title_with_entity = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("description")) {
+ result.description = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("description_with_app_name")) {
+ result.description_with_app_name =
+ Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("action")) {
+ result.action = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("data")) {
+ result.data = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("type")) {
+ result.type = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("flags")) {
+ result.flags = Read<int>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("package_name")) {
+ result.package_name = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("request_code")) {
+ result.request_code = Read<int>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("category")) {
+ result.category = ReadVector<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("extra")) {
+ result.extra = ReadExtras();
+ } else {
+ TC3_LOG(INFO) << "Unknown entry: " << key;
+ }
+ lua_pop(state_, 1);
+ }
+ lua_pop(state_, 1);
+ return result;
+}
+
+std::map<std::string, Variant> JniLuaEnvironment::ReadExtras() const {
+ if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected extras table, got: "
+ << lua_type(state_, kIndexStackTop);
+ lua_pop(state_, 1);
+ return {};
+ }
+ std::map<std::string, Variant> extras;
+ lua_pushnil(state_);
+ while (Next(/*index=*/-2)) {
+ // Each entry is a table specifying name and value.
+ // The value is specified via a type-specific field, as Lua doesn't make
+ // it easy to distinguish between different number types.
+ if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected a table for an extra, got: "
+ << lua_type(state_, kIndexStackTop);
+ lua_pop(state_, 1);
+ return {};
+ }
+ std::string name;
+ Variant value;
+
+ lua_pushnil(state_);
+ while (Next(/*index=*/-2)) {
+ const StringPiece key = ReadString(/*index=*/-2);
+ if (key.Equals("name")) {
+ name = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals("int_value")) {
+ value = Variant(Read<int>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("long_value")) {
+ value = Variant(Read<int64>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("float_value")) {
+ value = Variant(Read<float>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("bool_value")) {
+ value = Variant(Read<bool>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("string_value")) {
+ value = Variant(Read<std::string>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("string_array_value")) {
+ value = Variant(ReadVector<std::string>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("float_array_value")) {
+ value = Variant(ReadVector<float>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("int_array_value")) {
+ value = Variant(ReadVector<int>(/*index=*/kIndexStackTop));
+ } else if (key.Equals("named_variant_array_value")) {
+ value = Variant(ReadExtras());
+ } else {
+ TC3_LOG(INFO) << "Unknown extra field: " << key;
+ }
+ lua_pop(state_, 1);
+ }
+ if (!name.empty()) {
+ extras[name] = value;
+ } else {
+ TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
+ }
+ lua_pop(state_, 1);
+ }
+ return extras;
+}
+
+int JniLuaEnvironment::ReadRemoteActionTemplates(
+ std::vector<RemoteActionTemplate>* result) {
+ // Read result.
+ if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Unexpected result for snippet: "
+ << lua_type(state_, kIndexStackTop);
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+
+ // Read remote action templates array.
+ lua_pushnil(state_);
+ while (Next(/*index=*/-2)) {
+ if (lua_type(state_, kIndexStackTop) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected intent table, got: "
+ << lua_type(state_, kIndexStackTop);
+ lua_pop(state_, 1);
+ continue;
+ }
+ result->push_back(ReadRemoteActionTemplateResult());
+ }
+ lua_pop(state_, /*n=*/1);
+ return LUA_OK;
+}
+
+bool JniLuaEnvironment::RunIntentGenerator(
+ const std::string& generator_snippet,
+ std::vector<RemoteActionTemplate>* remote_actions) {
+ int status;
+ status = luaL_loadbuffer(state_, generator_snippet.data(),
+ generator_snippet.size(),
+ /*name=*/nullptr);
+ if (status != LUA_OK) {
+ TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
+ return false;
+ }
+ status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
+ if (status != LUA_OK) {
+ TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
+ return false;
+ }
+ if (RunProtected(
+ [this, remote_actions] {
+ return ReadRemoteActionTemplates(remote_actions);
+ },
+ /*num_args=*/1) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not read results.";
+ return false;
+ }
+ // Check that we correctly cleaned-up the state.
+ const int stack_size = lua_gettop(state_);
+ if (stack_size > 0) {
+ TC3_LOG(ERROR) << "Unexpected stack size.";
+ lua_settop(state_, 0);
+ return false;
+ }
+ return true;
+}
+
+// Lua environment for classification result intent generation.
+class AnnotatorJniEnvironment : public JniLuaEnvironment {
+ public:
+ AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
+ const jobject context,
+ const std::vector<Locale>& device_locales,
+ const std::string& entity_text,
+ const ClassificationResult& classification,
+ const int64 reference_time_ms_utc,
+ const reflection::Schema* entity_data_schema)
+ : JniLuaEnvironment(resources, jni_cache, context, device_locales),
+ entity_text_(entity_text),
+ classification_(classification),
+ reference_time_ms_utc_(reference_time_ms_utc),
+ entity_data_schema_(entity_data_schema) {}
+
+ protected:
+ void SetupExternalHook() override {
+ JniLuaEnvironment::SetupExternalHook();
+ lua_pushinteger(state_, reference_time_ms_utc_);
+ lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);
+
+ PushAnnotation(classification_, entity_text_, entity_data_schema_);
+ lua_setfield(state_, /*idx=*/-2, "entity");
+ }
+
+ const std::string& entity_text_;
+ const ClassificationResult& classification_;
+ const int64 reference_time_ms_utc_;
+
+ // Reflection schema data.
+ const reflection::Schema* const entity_data_schema_;
+};
+
+// Lua environment for actions intent generation.
+class ActionsJniLuaEnvironment : public JniLuaEnvironment {
+ public:
+ ActionsJniLuaEnvironment(
+ const Resources& resources, const JniCache* jni_cache,
+ const jobject context, const std::vector<Locale>& device_locales,
+ const Conversation& conversation, const ActionSuggestion& action,
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema)
+ : JniLuaEnvironment(resources, jni_cache, context, device_locales),
+ conversation_(conversation),
+ action_(action),
+ actions_entity_data_schema_(actions_entity_data_schema),
+ annotations_entity_data_schema_(annotations_entity_data_schema) {}
+
+ protected:
+ void SetupExternalHook() override {
+ JniLuaEnvironment::SetupExternalHook();
+ PushConversation(&conversation_.messages, annotations_entity_data_schema_);
+ lua_setfield(state_, /*idx=*/-2, "conversation");
+
+ PushAction(action_, actions_entity_data_schema_,
+ annotations_entity_data_schema_);
+ lua_setfield(state_, /*idx=*/-2, "entity");
+ }
+
+ const Conversation& conversation_;
+ const ActionSuggestion& action_;
+ const reflection::Schema* actions_entity_data_schema_;
+ const reflection::Schema* annotations_entity_data_schema_;
+};
+
+} // namespace
+
+std::unique_ptr<IntentGenerator> IntentGenerator::Create(
+ const IntentFactoryModel* options, const ResourcePool* resources,
+ const std::shared_ptr<JniCache>& jni_cache) {
+ std::unique_ptr<IntentGenerator> intent_generator(
+ new IntentGenerator(options, resources, jni_cache));
+
+ if (options == nullptr || options->generator() == nullptr) {
+ TC3_LOG(ERROR) << "No intent generator options.";
+ return nullptr;
+ }
+
+ std::unique_ptr<ZlibDecompressor> zlib_decompressor =
+ ZlibDecompressor::Instance();
+ if (!zlib_decompressor) {
+ TC3_LOG(ERROR) << "Cannot initialize decompressor.";
+ return nullptr;
+ }
+
+ for (const IntentFactoryModel_::IntentGenerator* generator :
+ *options->generator()) {
+ std::string lua_template_generator;
+ if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
+ generator->lua_template_generator(),
+ generator->compressed_lua_template_generator(),
+ &lua_template_generator)) {
+ TC3_LOG(ERROR) << "Could not decompress generator template.";
+ return nullptr;
+ }
+
+ std::string lua_code = lua_template_generator;
+ if (options->precompile_generators()) {
+ if (!Compile(lua_template_generator, &lua_code)) {
+ TC3_LOG(ERROR) << "Could not precompile generator template.";
+ return nullptr;
+ }
+ }
+
+ intent_generator->generators_[generator->type()->str()] = lua_code;
+ }
+
+ return intent_generator;
+}
+
+std::vector<Locale> IntentGenerator::ParseDeviceLocales(
+ const jstring device_locales) const {
+ if (device_locales == nullptr) {
+ TC3_LOG(ERROR) << "No locales provided.";
+ return {};
+ }
+ ScopedStringChars locales_str =
+ GetScopedStringChars(jni_cache_->GetEnv(), device_locales);
+ if (locales_str == nullptr) {
+ TC3_LOG(ERROR) << "Cannot retrieve provided locales.";
+ return {};
+ }
+ std::vector<Locale> locales;
+ if (!ParseLocales(reinterpret_cast<const char*>(locales_str.get()),
+ &locales)) {
+ TC3_LOG(ERROR) << "Cannot parse locales.";
+ return {};
+ }
+ return locales;
+}
+
+bool IntentGenerator::GenerateIntents(
+ const jstring device_locales, const ClassificationResult& classification,
+ const int64 reference_time_ms_utc, const std::string& text,
+ const CodepointSpan selection_indices, const jobject context,
+ const reflection::Schema* annotations_entity_data_schema,
+ std::vector<RemoteActionTemplate>* remote_actions) const {
+ if (options_ == nullptr) {
+ return false;
+ }
+
+ // Retrieve generator for specified entity.
+ auto it = generators_.find(classification.collection);
+ if (it == generators_.end()) {
+ TC3_VLOG(INFO) << "Cannot find a generator for the specified collection.";
+ return true;
+ }
+
+ const std::string entity_text =
+ UTF8ToUnicodeText(text, /*do_copy=*/false)
+ .UTF8Substring(selection_indices.first, selection_indices.second);
+
+ std::unique_ptr<AnnotatorJniEnvironment> interpreter(
+ new AnnotatorJniEnvironment(
+ resources_, jni_cache_.get(), context,
+ ParseDeviceLocales(device_locales), entity_text, classification,
+ reference_time_ms_utc, annotations_entity_data_schema));
+
+ if (!interpreter->Initialize()) {
+ TC3_LOG(ERROR) << "Could not create Lua interpreter.";
+ return false;
+ }
+
+ return interpreter->RunIntentGenerator(it->second, remote_actions);
+}
+
+bool IntentGenerator::GenerateIntents(
+ const jstring device_locales, const ActionSuggestion& action,
+ const Conversation& conversation, const jobject context,
+ const reflection::Schema* annotations_entity_data_schema,
+ const reflection::Schema* actions_entity_data_schema,
+ std::vector<RemoteActionTemplate>* remote_actions) const {
+ if (options_ == nullptr) {
+ return false;
+ }
+
+ // Retrieve generator for specified action.
+ auto it = generators_.find(action.type);
+ if (it == generators_.end()) {
+ return true;
+ }
+
+ std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
+ new ActionsJniLuaEnvironment(
+ resources_, jni_cache_.get(), context,
+ ParseDeviceLocales(device_locales), conversation, action,
+ actions_entity_data_schema, annotations_entity_data_schema));
+
+ if (!interpreter->Initialize()) {
+ TC3_LOG(ERROR) << "Could not create Lua interpreter.";
+ return false;
+ }
+
+ return interpreter->RunIntentGenerator(it->second, remote_actions);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/intents/intent-generator.h b/native/utils/intents/intent-generator.h
new file mode 100644
index 0000000..2a45191
--- /dev/null
+++ b/native/utils/intents/intent-generator.h
@@ -0,0 +1,71 @@
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
+
+#include <jni.h>
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "actions/types.h"
+#include "annotator/types.h"
+#include "utils/i18n/locale.h"
+#include "utils/intents/intent-config_generated.h"
+#include "utils/intents/remote-action-template.h"
+#include "utils/java/jni-cache.h"
+#include "utils/resources.h"
+#include "utils/resources_generated.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Helper class to generate Android intents for text classifier results.
+class IntentGenerator {
+ public:
+ static std::unique_ptr<IntentGenerator> Create(
+ const IntentFactoryModel* options, const ResourcePool* resources,
+ const std::shared_ptr<JniCache>& jni_cache);
+
+ // Generates intents for a classification result.
+ // Returns true, if the intent generator snippets could be successfully run,
+ // returns false otherwise.
+ bool GenerateIntents(const jstring device_locales,
+ const ClassificationResult& classification,
+ const int64 reference_time_ms_utc,
+ const std::string& text,
+ const CodepointSpan selection_indices,
+ const jobject context,
+ const reflection::Schema* annotations_entity_data_schema,
+ std::vector<RemoteActionTemplate>* remote_actions) const;
+
+ // Generates intents for an action suggestion.
+ // Returns true, if the intent generator snippets could be successfully run,
+ // returns false otherwise.
+ bool GenerateIntents(const jstring device_locales,
+ const ActionSuggestion& action,
+ const Conversation& conversation, const jobject context,
+ const reflection::Schema* annotations_entity_data_schema,
+ const reflection::Schema* actions_entity_data_schema,
+ std::vector<RemoteActionTemplate>* remote_actions) const;
+
+ private:
+ IntentGenerator(const IntentFactoryModel* options,
+ const ResourcePool* resources,
+ const std::shared_ptr<JniCache>& jni_cache)
+ : options_(options),
+ resources_(Resources(resources)),
+ jni_cache_(jni_cache) {}
+
+ std::vector<Locale> ParseDeviceLocales(const jstring device_locales) const;
+
+ const IntentFactoryModel* options_;
+ const Resources resources_;
+ std::shared_ptr<JniCache> jni_cache_;
+ std::map<std::string, std::string> generators_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
diff --git a/native/utils/intents/jni.cc b/native/utils/intents/jni.cc
new file mode 100644
index 0000000..051d078
--- /dev/null
+++ b/native/utils/intents/jni.cc
@@ -0,0 +1,354 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/intents/jni.h"
+
+#include <memory>
+
+#include "utils/base/statusor.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-helper.h"
+
+namespace libtextclassifier3 {
+
// The macros below are intended to reduce the boilerplate and avoid
// easily introduced copy/paste errors.

// Aborts (via TC3_CHECK) when the given JNI pointer is null.
#define TC3_CHECK_JNI_PTR(PTR) TC3_CHECK((PTR) != nullptr)

// Looks up the Java class NAME and stores a global reference in
// handler->FIELD. Checking ok() before ValueOrDie() makes a failed lookup
// abort with a message naming the class, instead of dying inside
// ValueOrDie() with no context.
#define TC3_GET_CLASS(FIELD, NAME)                                         \
  {                                                                        \
    StatusOr<ScopedLocalRef<jclass>> status_or_clazz =                     \
        JniHelper::FindClass(env, NAME);                                   \
    TC3_CHECK(status_or_clazz.ok()) << "Error finding class: " << NAME;    \
    handler->FIELD = MakeGlobalRef(status_or_clazz.ValueOrDie().release(), \
                                   env, jni_cache->jvm);                   \
    TC3_CHECK_JNI_PTR(handler->FIELD) << "Error finding class: " << NAME;  \
  }

// Looks up method NAME with SIGNATURE on handler->CLASS and stores the
// jmethodID in handler->FIELD; aborts when the method cannot be found.
#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE)                       \
  handler->FIELD = env->GetMethodID(handler->CLASS.get(), NAME, SIGNATURE); \
  TC3_CHECK(handler->FIELD) << "Error finding method: " << NAME;
+
// Creates and initializes the handler, caching the Java classes and
// constructor jmethodIDs needed to build RemoteActionTemplate and
// NamedVariant objects. Returns nullptr if no JNIEnv is available; aborts
// (via TC3_CHECK inside the macros) if a class or method cannot be found.
std::unique_ptr<RemoteActionTemplatesHandler>
RemoteActionTemplatesHandler::Create(
    const std::shared_ptr<JniCache>& jni_cache) {
  JNIEnv* env = jni_cache->GetEnv();
  if (env == nullptr) {
    return nullptr;
  }

  std::unique_ptr<RemoteActionTemplatesHandler> handler(
      new RemoteActionTemplatesHandler(jni_cache));

  // java.lang.Integer and its Integer(int) constructor, used for boxing
  // optional ints (flags, request codes).
  TC3_GET_CLASS(integer_class_, "java/lang/Integer");
  TC3_GET_METHOD(integer_class_, integer_init_, "<init>", "(I)V");

  // RemoteActionTemplate and its all-fields constructor. The signature
  // string must stay in sync with the Java class declaration.
  TC3_GET_CLASS(remote_action_template_class_,
                TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR);
  TC3_GET_METHOD(
      remote_action_template_class_, remote_action_template_init_, "<init>",
      "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/"
      "String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/"
      "Integer;[Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
      TC3_NAMED_VARIANT_CLASS_NAME_STR ";Ljava/lang/Integer;)V");

  // NamedVariant and one constructor overload per supported value type.
  TC3_GET_CLASS(named_variant_class_,
                TC3_PACKAGE_PATH TC3_NAMED_VARIANT_CLASS_NAME_STR);

  TC3_GET_METHOD(named_variant_class_, named_variant_from_int_, "<init>",
                 "(Ljava/lang/String;I)V");
  TC3_GET_METHOD(named_variant_class_, named_variant_from_long_, "<init>",
                 "(Ljava/lang/String;J)V");
  TC3_GET_METHOD(named_variant_class_, named_variant_from_float_, "<init>",
                 "(Ljava/lang/String;F)V");
  TC3_GET_METHOD(named_variant_class_, named_variant_from_double_, "<init>",
                 "(Ljava/lang/String;D)V");
  TC3_GET_METHOD(named_variant_class_, named_variant_from_bool_, "<init>",
                 "(Ljava/lang/String;Z)V");
  TC3_GET_METHOD(named_variant_class_, named_variant_from_string_, "<init>",
                 "(Ljava/lang/String;Ljava/lang/String;)V");
  TC3_GET_METHOD(named_variant_class_, named_variant_from_string_array_,
                 "<init>", "(Ljava/lang/String;[Ljava/lang/String;)V");
  TC3_GET_METHOD(named_variant_class_, named_variant_from_float_array_,
                 "<init>", "(Ljava/lang/String;[F)V");
  TC3_GET_METHOD(named_variant_class_, named_variant_from_int_array_, "<init>",
                 "(Ljava/lang/String;[I)V");
  TC3_GET_METHOD(
      named_variant_class_, named_variant_from_named_variant_array_, "<init>",
      "(Ljava/lang/String;[L" TC3_PACKAGE_PATH TC3_NAMED_VARIANT_CLASS_NAME_STR
      ";)V");
  return handler;
}
+
+StatusOr<ScopedLocalRef<jstring>> RemoteActionTemplatesHandler::AsUTF8String(
+ const Optional<std::string>& optional) const {
+ if (!optional.has_value()) {
+ return {{nullptr, jni_cache_->GetEnv()}};
+ }
+ return jni_cache_->ConvertToJavaString(optional.value());
+}
+
+StatusOr<ScopedLocalRef<jobject>> RemoteActionTemplatesHandler::AsInteger(
+ const Optional<int>& optional) const {
+ if (!optional.has_value()) {
+ return {{nullptr, jni_cache_->GetEnv()}};
+ }
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> result,
+ JniHelper::NewObject(jni_cache_->GetEnv(), integer_class_.get(),
+ integer_init_, optional.value()));
+
+ return result;
+}
+
+StatusOr<ScopedLocalRef<jobjectArray>>
+RemoteActionTemplatesHandler::AsStringArray(
+ const std::vector<std::string>& values) const {
+ if (values.empty()) {
+ return {{nullptr, jni_cache_->GetEnv()}};
+ }
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> result,
+ JniHelper::NewObjectArray(jni_cache_->GetEnv(), values.size(),
+ jni_cache_->string_class.get(), nullptr));
+
+ for (int k = 0; k < values.size(); k++) {
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> value_str,
+ jni_cache_->ConvertToJavaString(values[k]));
+ jni_cache_->GetEnv()->SetObjectArrayElement(result.get(), k,
+ value_str.get());
+ }
+ return result;
+}
+
+StatusOr<ScopedLocalRef<jfloatArray>>
+RemoteActionTemplatesHandler::AsFloatArray(
+ const std::vector<float>& values) const {
+ if (values.empty()) {
+ return {{nullptr, jni_cache_->GetEnv()}};
+ }
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jfloatArray> result,
+ JniHelper::NewFloatArray(jni_cache_->GetEnv(), values.size()));
+
+ jni_cache_->GetEnv()->SetFloatArrayRegion(result.get(), /*start=*/0,
+ /*len=*/values.size(),
+ &(values[0]));
+ return result;
+}
+
+StatusOr<ScopedLocalRef<jintArray>> RemoteActionTemplatesHandler::AsIntArray(
+ const std::vector<int>& values) const {
+ if (values.empty()) {
+ return {{nullptr, jni_cache_->GetEnv()}};
+ }
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jintArray> result,
+ JniHelper::NewIntArray(jni_cache_->GetEnv(), values.size()));
+
+ jni_cache_->GetEnv()->SetIntArrayRegion(result.get(), /*start=*/0,
+ /*len=*/values.size(), &(values[0]));
+ return result;
+}
+
// Converts a single (name, Variant) pair to a Java NamedVariant object,
// dispatching on the variant's runtime type to the matching constructor
// overload cached in Create(). An empty or unsupported variant yields
// Status::UNKNOWN.
StatusOr<ScopedLocalRef<jobject>> RemoteActionTemplatesHandler::AsNamedVariant(
    const std::string& name_str, const Variant& value) const {
  TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> name,
                       jni_cache_->ConvertToJavaString(name_str));

  JNIEnv* env = jni_cache_->GetEnv();
  switch (value.GetType()) {
    // Scalar types map directly to the primitive-typed constructors.
    case Variant::TYPE_INT_VALUE:
      return JniHelper::NewObject(env, named_variant_class_.get(),
                                  named_variant_from_int_, name.get(),
                                  value.Value<int>());

    case Variant::TYPE_INT64_VALUE:
      return JniHelper::NewObject(env, named_variant_class_.get(),
                                  named_variant_from_long_, name.get(),
                                  value.Value<int64>());

    case Variant::TYPE_FLOAT_VALUE:
      return JniHelper::NewObject(env, named_variant_class_.get(),
                                  named_variant_from_float_, name.get(),
                                  value.Value<float>());

    case Variant::TYPE_DOUBLE_VALUE:
      return JniHelper::NewObject(env, named_variant_class_.get(),
                                  named_variant_from_double_, name.get(),
                                  value.Value<double>());

    case Variant::TYPE_BOOL_VALUE:
      return JniHelper::NewObject(env, named_variant_class_.get(),
                                  named_variant_from_bool_, name.get(),
                                  value.Value<bool>());

    case Variant::TYPE_STRING_VALUE: {
      TC3_ASSIGN_OR_RETURN(
          ScopedLocalRef<jstring> value_jstring,
          jni_cache_->ConvertToJavaString(value.ConstRefValue<std::string>()));
      return JniHelper::NewObject(env, named_variant_class_.get(),
                                  named_variant_from_string_, name.get(),
                                  value_jstring.get());
    }

    // Vector types are converted with the As*Array helpers above.
    case Variant::TYPE_STRING_VECTOR_VALUE: {
      TC3_ASSIGN_OR_RETURN(
          ScopedLocalRef<jobjectArray> value_jstring_array,
          AsStringArray(value.ConstRefValue<std::vector<std::string>>()));

      return JniHelper::NewObject(env, named_variant_class_.get(),
                                  named_variant_from_string_array_, name.get(),
                                  value_jstring_array.get());
    }

    case Variant::TYPE_FLOAT_VECTOR_VALUE: {
      TC3_ASSIGN_OR_RETURN(
          ScopedLocalRef<jfloatArray> value_jfloat_array,
          AsFloatArray(value.ConstRefValue<std::vector<float>>()));

      return JniHelper::NewObject(env, named_variant_class_.get(),
                                  named_variant_from_float_array_, name.get(),
                                  value_jfloat_array.get());
    }

    case Variant::TYPE_INT_VECTOR_VALUE: {
      TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jintArray> value_jint_array,
                           AsIntArray(value.ConstRefValue<std::vector<int>>()));

      return JniHelper::NewObject(env, named_variant_class_.get(),
                                  named_variant_from_int_array_, name.get(),
                                  value_jint_array.get());
    }

    // A nested map becomes a NamedVariant[] via recursion through
    // AsNamedVariantArray.
    case Variant::TYPE_STRING_VARIANT_MAP_VALUE: {
      TC3_ASSIGN_OR_RETURN(
          ScopedLocalRef<jobjectArray> value_jobect_array,
          AsNamedVariantArray(
              value.ConstRefValue<std::map<std::string, Variant>>()));
      return JniHelper::NewObject(env, named_variant_class_.get(),
                                  named_variant_from_named_variant_array_,
                                  name.get(), value_jobect_array.get());
    }

    case Variant::TYPE_EMPTY:
      return {Status::UNKNOWN};

    default:
      TC3_LOG(ERROR) << "Unsupported NamedVariant type: " << value.GetType();
      return {Status::UNKNOWN};
  }
}
+
+StatusOr<ScopedLocalRef<jobjectArray>>
+RemoteActionTemplatesHandler::AsNamedVariantArray(
+ const std::map<std::string, Variant>& values) const {
+ JNIEnv* env = jni_cache_->GetEnv();
+ if (values.empty()) {
+ return {{nullptr, env}};
+ }
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> result,
+ JniHelper::NewObjectArray(jni_cache_->GetEnv(), values.size(),
+ named_variant_class_.get(), nullptr));
+ int element_index = 0;
+ for (const auto& key_value_pair : values) {
+ if (!key_value_pair.second.HasValue()) {
+ element_index++;
+ continue;
+ }
+ TC3_ASSIGN_OR_RETURN(
+ StatusOr<ScopedLocalRef<jobject>> named_extra,
+ AsNamedVariant(key_value_pair.first, key_value_pair.second));
+ env->SetObjectArrayElement(result.get(), element_index,
+ named_extra.ValueOrDie().get());
+ element_index++;
+ }
+ return result;
+}
+
+StatusOr<ScopedLocalRef<jobjectArray>>
+RemoteActionTemplatesHandler::RemoteActionTemplatesToJObjectArray(
+ const std::vector<RemoteActionTemplate>& remote_actions) const {
+ JNIEnv* env = jni_cache_->GetEnv();
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> results,
+ JniHelper::NewObjectArray(env, remote_actions.size(),
+ remote_action_template_class_.get(), nullptr));
+
+ for (int i = 0; i < remote_actions.size(); i++) {
+ const RemoteActionTemplate& remote_action = remote_actions[i];
+
+ TC3_ASSIGN_OR_RETURN(
+ const StatusOr<ScopedLocalRef<jstring>> title_without_entity,
+ AsUTF8String(remote_action.title_without_entity));
+ TC3_ASSIGN_OR_RETURN(
+ const StatusOr<ScopedLocalRef<jstring>> title_with_entity,
+ AsUTF8String(remote_action.title_with_entity));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jstring>> description,
+ AsUTF8String(remote_action.description));
+ TC3_ASSIGN_OR_RETURN(
+ const StatusOr<ScopedLocalRef<jstring>> description_with_app_name,
+ AsUTF8String(remote_action.description_with_app_name));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jstring>> action,
+ AsUTF8String(remote_action.action));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jstring>> data,
+ AsUTF8String(remote_action.data));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jstring>> type,
+ AsUTF8String(remote_action.type));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jobject>> flags,
+ AsInteger(remote_action.flags));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jobjectArray>> category,
+ AsStringArray(remote_action.category));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jstring>> package,
+ AsUTF8String(remote_action.package_name));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jobjectArray>> extra,
+ AsNamedVariantArray(remote_action.extra));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jobject>> request_code,
+ AsInteger(remote_action.request_code));
+
+ TC3_ASSIGN_OR_RETURN(
+ const ScopedLocalRef<jobject> result,
+ JniHelper::NewObject(
+ env, remote_action_template_class_.get(),
+ remote_action_template_init_,
+ title_without_entity.ValueOrDie().get(),
+ title_with_entity.ValueOrDie().get(),
+ description.ValueOrDie().get(),
+ description_with_app_name.ValueOrDie().get(),
+ action.ValueOrDie().get(), data.ValueOrDie().get(),
+ type.ValueOrDie().get(), flags.ValueOrDie().get(),
+ category.ValueOrDie().get(), package.ValueOrDie().get(),
+ extra.ValueOrDie().get(), request_code.ValueOrDie().get()));
+ env->SetObjectArrayElement(results.get(), i, result.get());
+ }
+ return results;
+}
+
+StatusOr<ScopedLocalRef<jobjectArray>>
+RemoteActionTemplatesHandler::EntityDataAsNamedVariantArray(
+ const reflection::Schema* entity_data_schema,
+ const std::string& serialized_entity_data) const {
+ ReflectiveFlatbufferBuilder entity_data_builder(entity_data_schema);
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = entity_data_builder.NewRoot();
+ buffer->MergeFromSerializedFlatbuffer(serialized_entity_data);
+ std::map<std::string, Variant> entity_data_map = buffer->AsFlatMap();
+ return AsNamedVariantArray(entity_data_map);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/intents/jni.h b/native/utils/intents/jni.h
new file mode 100644
index 0000000..ada2631
--- /dev/null
+++ b/native/utils/intents/jni.h
@@ -0,0 +1,113 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_H_
+#define LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_H_
+
+#include <jni.h>
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "utils/base/statusor.h"
+#include "utils/flatbuffers.h"
+#include "utils/intents/remote-action-template.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-cache.h"
+#include "utils/optional.h"
+#include "utils/variant.h"
+
// Name of the Java RemoteActionTemplate class; may be overridden at build
// time to target a differently named class.
#ifndef TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME
#define TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME RemoteActionTemplate
#endif

// Stringified class name, usable in JNI signature string literals.
#define TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR \
  TC3_ADD_QUOTES(TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME)

// Name of the Java NamedVariant class; may be overridden at build time.
#ifndef TC3_NAMED_VARIANT_CLASS_NAME
#define TC3_NAMED_VARIANT_CLASS_NAME NamedVariant
#endif

// Stringified class name, usable in JNI signature string literals.
#define TC3_NAMED_VARIANT_CLASS_NAME_STR \
  TC3_ADD_QUOTES(TC3_NAMED_VARIANT_CLASS_NAME)
+
+namespace libtextclassifier3 {
+
// A helper class to create RemoteActionTemplate object from model results.
//
// All conversion methods return a StatusOr wrapping a local reference; a
// "missing" input (empty optional, empty container) is represented as a
// null reference rather than an error.
class RemoteActionTemplatesHandler {
 public:
  // Creates the handler and caches the needed Java classes/method ids.
  // Returns nullptr if no JNIEnv can be obtained from the cache.
  static std::unique_ptr<RemoteActionTemplatesHandler> Create(
      const std::shared_ptr<JniCache>& jni_cache);

  // Conversion helpers from native values to their Java counterparts.
  StatusOr<ScopedLocalRef<jstring>> AsUTF8String(
      const Optional<std::string>& optional) const;
  StatusOr<ScopedLocalRef<jobject>> AsInteger(
      const Optional<int>& optional) const;
  StatusOr<ScopedLocalRef<jobjectArray>> AsStringArray(
      const std::vector<std::string>& values) const;
  StatusOr<ScopedLocalRef<jfloatArray>> AsFloatArray(
      const std::vector<float>& values) const;
  StatusOr<ScopedLocalRef<jintArray>> AsIntArray(
      const std::vector<int>& values) const;
  StatusOr<ScopedLocalRef<jobject>> AsNamedVariant(const std::string& name,
                                                   const Variant& value) const;
  StatusOr<ScopedLocalRef<jobjectArray>> AsNamedVariantArray(
      const std::map<std::string, Variant>& values) const;

  // Converts a list of native results to a Java RemoteActionTemplate[].
  StatusOr<ScopedLocalRef<jobjectArray>> RemoteActionTemplatesToJObjectArray(
      const std::vector<RemoteActionTemplate>& remote_actions) const;

  // Parses serialized entity data with the given schema and converts the
  // fields to a Java NamedVariant[].
  StatusOr<ScopedLocalRef<jobjectArray>> EntityDataAsNamedVariantArray(
      const reflection::Schema* entity_data_schema,
      const std::string& serialized_entity_data) const;

 private:
  explicit RemoteActionTemplatesHandler(
      const std::shared_ptr<JniCache>& jni_cache)
      : jni_cache_(jni_cache),
        integer_class_(nullptr, jni_cache->jvm),
        remote_action_template_class_(nullptr, jni_cache->jvm),
        named_variant_class_(nullptr, jni_cache->jvm) {}

  std::shared_ptr<JniCache> jni_cache_;

  // Cached global class references and constructor jmethodIDs, populated
  // once in Create().

  // java.lang.Integer
  ScopedGlobalRef<jclass> integer_class_;
  jmethodID integer_init_ = nullptr;

  // RemoteActionTemplate
  ScopedGlobalRef<jclass> remote_action_template_class_;
  jmethodID remote_action_template_init_ = nullptr;

  // NamedVariant — one constructor overload per supported value type.
  ScopedGlobalRef<jclass> named_variant_class_;
  jmethodID named_variant_from_int_ = nullptr;
  jmethodID named_variant_from_long_ = nullptr;
  jmethodID named_variant_from_float_ = nullptr;
  jmethodID named_variant_from_double_ = nullptr;
  jmethodID named_variant_from_bool_ = nullptr;
  jmethodID named_variant_from_string_ = nullptr;
  jmethodID named_variant_from_string_array_ = nullptr;
  jmethodID named_variant_from_float_array_ = nullptr;
  jmethodID named_variant_from_int_array_ = nullptr;
  jmethodID named_variant_from_named_variant_array_ = nullptr;
};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_H_
diff --git a/native/utils/intents/remote-action-template.h b/native/utils/intents/remote-action-template.h
new file mode 100644
index 0000000..4aaf6ba
--- /dev/null
+++ b/native/utils/intents/remote-action-template.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_REMOTE_ACTION_TEMPLATE_H_
+#define LIBTEXTCLASSIFIER_UTILS_INTENTS_REMOTE_ACTION_TEMPLATE_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "utils/optional.h"
+#include "utils/variant.h"
+
+namespace libtextclassifier3 {
+
// A template with parameters for an Android remote action.
// Plain data holder; fields left unset (empty Optional / empty container)
// are treated as "not provided" by the JNI conversion code in jni.h/jni.cc.
struct RemoteActionTemplate {
  // Title shown for the action (see: RemoteAction.getTitle).
  Optional<std::string> title_without_entity;

  // Title with entity for the action. It is not guaranteed that the client
  // will use this, so title should be always given and general enough.
  Optional<std::string> title_with_entity;

  // Description shown for the action (see: RemoteAction.getContentDescription).
  Optional<std::string> description;

  // Description shown for the action (see: RemoteAction.getContentDescription)
  // when app name is available. Caller is expected to replace the placeholder
  // by the name of the app that is going to handle the action.
  Optional<std::string> description_with_app_name;

  // The action to set on the Intent (see: Intent.setAction).
  Optional<std::string> action;

  // The data to set on the Intent (see: Intent.setData).
  Optional<std::string> data;

  // The type to set on the Intent (see: Intent.setType).
  Optional<std::string> type;

  // Flags for launching the Intent (see: Intent.setFlags).
  Optional<int> flags;

  // Categories to set on the Intent (see: Intent.addCategory).
  std::vector<std::string> category;

  // Explicit application package to set on the Intent (see: Intent.setPackage).
  Optional<std::string> package_name;

  // The list of all the extras to add to the Intent.
  std::map<std::string, Variant> extra;

  // Private request code to use for the Intent.
  Optional<int> request_code;
};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_REMOTE_ACTION_TEMPLATE_H_
diff --git a/native/utils/intents/zlib-utils.cc b/native/utils/intents/zlib-utils.cc
new file mode 100644
index 0000000..78489cc
--- /dev/null
+++ b/native/utils/intents/zlib-utils.cc
@@ -0,0 +1,71 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/intents/zlib-utils.h"
+
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/zlib/buffer_generated.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+bool CompressIntentModel(IntentFactoryModelT* intent_model) {
+ std::unique_ptr<ZlibCompressor> intent_zlib_compressor =
+ ZlibCompressor::Instance();
+ for (auto& generator : intent_model->generator) {
+ generator->compressed_lua_template_generator.reset(new CompressedBufferT);
+ intent_zlib_compressor->Compress(
+ std::string(reinterpret_cast<const char*>(
+ generator->lua_template_generator.data()),
+ generator->lua_template_generator.size()),
+ generator->compressed_lua_template_generator.get());
+ generator->lua_template_generator.clear();
+ }
+ return true;
+}
+
+bool DecompressIntentModel(IntentFactoryModelT* intent_model) {
+ std::unique_ptr<ZlibDecompressor> zlib_decompressor =
+ ZlibDecompressor::Instance();
+ if (!zlib_decompressor) {
+ TC3_LOG(ERROR) << "Cannot initialize decompressor.";
+ return false;
+ }
+
+ for (std::unique_ptr<IntentFactoryModel_::IntentGeneratorT>& generator :
+ intent_model->generator) {
+ if (generator->compressed_lua_template_generator == nullptr) {
+ continue;
+ }
+
+ std::string lua_template_generator;
+ if (!zlib_decompressor->MaybeDecompress(
+ generator->compressed_lua_template_generator.get(),
+ &lua_template_generator)) {
+ TC3_LOG(ERROR) << "Cannot decompress intent template.";
+ return false;
+ }
+ generator->lua_template_generator = std::vector<uint8_t>(
+ lua_template_generator.begin(), lua_template_generator.end());
+
+ generator->compressed_lua_template_generator.reset(nullptr);
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/intents/zlib-utils.h b/native/utils/intents/zlib-utils.h
new file mode 100644
index 0000000..b9a370f
--- /dev/null
+++ b/native/utils/intents/zlib-utils.h
@@ -0,0 +1,29 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_ZLIB_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_INTENTS_ZLIB_UTILS_H_
+
+#include "utils/intents/intent-config_generated.h"
+
+namespace libtextclassifier3 {
+
+bool CompressIntentModel(IntentFactoryModelT* intent_model);
+bool DecompressIntentModel(IntentFactoryModelT* intent_model);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_ZLIB_UTILS_H_
diff --git a/native/utils/java/jni-base.cc b/native/utils/java/jni-base.cc
new file mode 100644
index 0000000..e0829b7
--- /dev/null
+++ b/native/utils/java/jni-base.cc
@@ -0,0 +1,46 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/java/jni-base.h"
+
+#include "utils/base/status.h"
+#include "utils/java/string_utils.h"
+
+namespace libtextclassifier3 {
+
+bool EnsureLocalCapacity(JNIEnv* env, int capacity) {
+ return env->EnsureLocalCapacity(capacity) == JNI_OK;
+}
+
+bool JniExceptionCheckAndClear(JNIEnv* env) {
+ TC3_CHECK(env != nullptr);
+ const bool result = env->ExceptionCheck();
+ if (result) {
+ env->ExceptionDescribe();
+ env->ExceptionClear();
+ }
+ return result;
+}
+
+StatusOr<std::string> ToStlString(JNIEnv* env, const jstring& str) {
+ std::string result;
+ if (!JStringToUtf8String(env, str, &result)) {
+ return {Status::UNKNOWN};
+ }
+ return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/java/jni-base.h b/native/utils/java/jni-base.h
new file mode 100644
index 0000000..c7b04e6
--- /dev/null
+++ b/native/utils/java/jni-base.h
@@ -0,0 +1,219 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_
+#define LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_
+
+#include <jni.h>
+
+#include <string>
+
+#include "utils/base/statusor.h"
+
// When we use a macro as an argument for a macro, an additional level of
// indirection is needed, if the macro argument is used with # or ##.
#define TC3_ADD_QUOTES_HELPER(TOKEN) #TOKEN
#define TC3_ADD_QUOTES(TOKEN) TC3_ADD_QUOTES_HELPER(TOKEN)

// Java package of the generated JNI entry points; overridable at build time.
#ifndef TC3_PACKAGE_NAME
#define TC3_PACKAGE_NAME com_google_android_textclassifier
#endif

// Slash-separated form of the package, used in JNI class lookups.
#ifndef TC3_PACKAGE_PATH
#define TC3_PACKAGE_PATH \
  "com/google/android/textclassifier/"
#endif

// Expands to the mangled symbol name JNI expects:
// Java_<package>_<class>_<method>.
#define TC3_JNI_METHOD_NAME_INTERNAL(package_name, class_name, method_name) \
  Java_##package_name##_##class_name##_##method_name

// Declares an exported JNI method with the given return type.
#define TC3_JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, \
                                 method_name)                           \
  JNIEXPORT return_type JNICALL TC3_JNI_METHOD_NAME_INTERNAL(           \
      package_name, class_name, method_name)

// The indirection is needed to correctly expand the TC3_PACKAGE_NAME macro.
// See the explanation near TC3_ADD_QUOTES macro.
#define TC3_JNI_METHOD2(return_type, package_name, class_name, method_name) \
  TC3_JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, method_name)

#define TC3_JNI_METHOD(return_type, class_name, method_name) \
  TC3_JNI_METHOD2(return_type, TC3_PACKAGE_NAME, class_name, method_name)

// Name-only variants (no declaration), e.g. for JNI_OnLoad registration.
#define TC3_JNI_METHOD_NAME2(package_name, class_name, method_name) \
  TC3_JNI_METHOD_NAME_INTERNAL(package_name, class_name, method_name)

#define TC3_JNI_METHOD_NAME(class_name, method_name) \
  TC3_JNI_METHOD_NAME2(TC3_PACKAGE_NAME, class_name, method_name)
+
+namespace libtextclassifier3 {
+
+// Returns true if the requested capacity is available.
+bool EnsureLocalCapacity(JNIEnv* env, int capacity);
+
+// Returns true if there was an exception. Also it clears the exception.
+bool JniExceptionCheckAndClear(JNIEnv* env);
+
+StatusOr<std::string> ToStlString(JNIEnv* env, const jstring& str);
+
// A deleter to be used with std::unique_ptr to delete JNI global references.
// Stores the JavaVM (not a JNIEnv) so deletion works from any thread that
// can attach to the VM.
class GlobalRefDeleter {
 public:
  explicit GlobalRefDeleter(JavaVM* jvm) : jvm_(jvm) {}

  GlobalRefDeleter(const GlobalRefDeleter& orig) = default;

  // Copy assignment to allow move semantics in ScopedGlobalRef.
  GlobalRefDeleter& operator=(const GlobalRefDeleter& rhs) {
    // jvm_ is const; assignment is only valid between deleters for the
    // same VM.
    TC3_CHECK_EQ(jvm_, rhs.jvm_);
    return *this;
  }

  // The delete operator. Silently a no-op when the object or VM is null,
  // or when no JNIEnv can be obtained for the current thread.
  void operator()(jobject object) const {
    JNIEnv* env;
    if (object != nullptr && jvm_ != nullptr &&
        JNI_OK ==
            jvm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_4)) {
      env->DeleteGlobalRef(object);
    }
  }

 private:
  // The jvm_ stashed to use for deletion.
  JavaVM* const jvm_;
};
+
// A deleter to be used with std::unique_ptr to delete JNI local references.
// Caches a JNIEnv, which is thread-local — see the note on ScopedRef below.
class LocalRefDeleter {
 public:
  explicit LocalRefDeleter(JNIEnv* env)
      : env_(env) {}  // NOLINT(runtime/explicit)

  LocalRefDeleter(const LocalRefDeleter& orig) = default;

  // Copy assignment to allow move semantics in ScopedLocalRef.
  LocalRefDeleter& operator=(const LocalRefDeleter& rhs) {
    env_ = rhs.env_;
    return *this;
  }

  // The delete operator. No-op when no env was provided.
  void operator()(jobject object) const {
    if (env_) {
      env_->DeleteLocalRef(object);
    }
  }

 private:
  // The env_ stashed to use for deletion. Thread-local, don't share!
  JNIEnv* env_;
};
+
// A smart pointer that deletes a reference when it goes out of scope.
//
// Note that this class is not thread-safe since it caches JNIEnv in
// the deleter. Do not use the same jobject across different threads.
//
// Implemented as a thin wrapper over std::unique_ptr with a stateful
// deleter (GlobalRefDeleter or LocalRefDeleter); move-only.
template <typename T, typename Env, typename Deleter>
class ScopedRef {
 public:
  ScopedRef() : ptr_(nullptr, Deleter(nullptr)) {}
  // Takes ownership of `value`; `env` (JNIEnv or JavaVM) is captured by
  // the deleter for the eventual release.
  ScopedRef(T value, Env* env) : ptr_(value, Deleter(env)) {}

  // Non-owning access to the underlying reference.
  T get() const { return ptr_.get(); }

  // Releases ownership without deleting the reference.
  T release() { return ptr_.release(); }

  bool operator!() const { return !ptr_; }

  bool operator==(void* value) const { return ptr_.get() == value; }

  explicit operator bool() const { return ptr_ != nullptr; }

  // Replaces both the owned reference and the deleter's environment.
  void reset(T value, Env* env) {
    ptr_.reset(value);
    ptr_.get_deleter() = Deleter(env);
  }

 private:
  std::unique_ptr<typename std::remove_pointer<T>::type, Deleter> ptr_;
};
+
// Comparison operators so ScopedRefs can be compared with each other and
// with nullptr; all compare the raw underlying references.
template <typename T, typename U, typename Env, typename Deleter>
inline bool operator==(const ScopedRef<T, Env, Deleter>& x,
                       const ScopedRef<U, Env, Deleter>& y) {
  return x.get() == y.get();
}

template <typename T, typename Env, typename Deleter>
inline bool operator==(const ScopedRef<T, Env, Deleter>& x, std::nullptr_t) {
  return x.get() == nullptr;
}

template <typename T, typename Env, typename Deleter>
inline bool operator==(std::nullptr_t, const ScopedRef<T, Env, Deleter>& x) {
  return nullptr == x.get();
}

template <typename T, typename U, typename Env, typename Deleter>
inline bool operator!=(const ScopedRef<T, Env, Deleter>& x,
                       const ScopedRef<U, Env, Deleter>& y) {
  return x.get() != y.get();
}

template <typename T, typename Env, typename Deleter>
inline bool operator!=(const ScopedRef<T, Env, Deleter>& x, std::nullptr_t) {
  return x.get() != nullptr;
}

template <typename T, typename Env, typename Deleter>
inline bool operator!=(std::nullptr_t, const ScopedRef<T, Env, Deleter>& x) {
  return nullptr != x.get();
}

template <typename T, typename U, typename Env, typename Deleter>
inline bool operator<(const ScopedRef<T, Env, Deleter>& x,
                      const ScopedRef<U, Env, Deleter>& y) {
  return x.get() < y.get();
}

template <typename T, typename U, typename Env, typename Deleter>
inline bool operator>(const ScopedRef<T, Env, Deleter>& x,
                      const ScopedRef<U, Env, Deleter>& y) {
  return x.get() > y.get();
}
+
// A smart pointer that deletes a JNI global reference when it goes out
// of scope. Usage is:
// ScopedGlobalRef<jobject> scoped_global(env->JniFunction(), jvm);
template <typename T>
using ScopedGlobalRef = ScopedRef<T, JavaVM, GlobalRefDeleter>;

// Ditto, but usage is:
// ScopedLocalRef<jobject> scoped_local(env->JniFunction(), env);
template <typename T>
using ScopedLocalRef = ScopedRef<T, JNIEnv, LocalRefDeleter>;

// A helper to create global references.
// Promotes `object` to a JNI global reference owned by the returned
// ScopedGlobalRef; a null `object` yields a null (empty) ref.
template <typename T>
ScopedGlobalRef<T> MakeGlobalRef(T object, JNIEnv* env, JavaVM* jvm) {
  const jobject global_object = env->NewGlobalRef(object);
  return ScopedGlobalRef<T>(reinterpret_cast<T>(global_object), jvm);
}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_
diff --git a/native/utils/java/jni-cache.cc b/native/utils/java/jni-cache.cc
new file mode 100644
index 0000000..0be769d
--- /dev/null
+++ b/native/utils/java/jni-cache.cc
@@ -0,0 +1,315 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/java/jni-cache.h"
+
+#include "utils/base/logging.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-helper.h"
+
+namespace libtextclassifier3 {
+
+// Constructs an empty cache whose global references are tied to `jvm`.
+// Every cached class/object reference starts out null; they are populated
+// by JniCache::Create(). The Android-only members are compiled in only when
+// __ANDROID__ is defined.
+JniCache::JniCache(JavaVM* jvm)
+    : jvm(jvm),
+      string_class(nullptr, jvm),
+      string_utf8(nullptr, jvm),
+      pattern_class(nullptr, jvm),
+      matcher_class(nullptr, jvm),
+      locale_class(nullptr, jvm),
+      locale_us(nullptr, jvm),
+      breakiterator_class(nullptr, jvm),
+      integer_class(nullptr, jvm),
+      calendar_class(nullptr, jvm),
+      timezone_class(nullptr, jvm),
+      urlencoder_class(nullptr, jvm)
+#ifdef __ANDROID__
+      ,
+      context_class(nullptr, jvm),
+      uri_class(nullptr, jvm),
+      usermanager_class(nullptr, jvm),
+      bundle_class(nullptr, jvm),
+      resources_class(nullptr, jvm)
+#endif
+{
+}
+
+// The macros below are intended to reduce the boilerplate in Create and avoid
+// easily introduced copy/paste errors. They assume `env`, `jvm` and `result`
+// (a std::unique_ptr<JniCache>) are in scope.
+// The TC3_CHECK_* helpers abort the process (via TC3_CHECK) on failure.
+#define TC3_CHECK_JNI_PTR(PTR) TC3_CHECK((PTR) != nullptr)
+#define TC3_CHECK_JNI_RESULT(RESULT) TC3_CHECK(RESULT)
+
+// Looks up a class and stores a global reference in result-><FIELD>_class.
+// Makes the enclosing function return nullptr when the lookup fails.
+#define TC3_GET_CLASS_OR_RETURN_NULL(FIELD, NAME) \
+ { \
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jclass> clazz, \
+ JniHelper::FindClass(env, NAME)); \
+ result->FIELD##_class = MakeGlobalRef(clazz.get(), env, jvm); \
+ if (result->FIELD##_class == nullptr) { \
+ TC3_LOG(ERROR) << "Error finding class: " << NAME; \
+ return nullptr; \
+ } \
+ }
+
+// Like TC3_GET_CLASS_OR_RETURN_NULL, but leaves result-><FIELD>_class null
+// (instead of failing) when the class is absent, e.g. on older platforms.
+#define TC3_GET_OPTIONAL_CLASS(FIELD, NAME) \
+ { \
+ StatusOr<ScopedLocalRef<jclass>> status_or_class = \
+ JniHelper::FindClass(env, NAME); \
+ if (status_or_class.ok()) { \
+ result->FIELD##_class = MakeGlobalRef( \
+ std::move(status_or_class).ValueOrDie().get(), env, jvm); \
+ } \
+ }
+
+// Caches a method id; aborts (TC3_CHECK) if the method cannot be found.
+#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ result->CLASS##_##FIELD = \
+ env->GetMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
+ TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \
+ << "Error finding method: " << NAME;
+
+// Best-effort method lookup: leaves the id null and clears any pending
+// NoSuchMethodError exception when the method is absent.
+#define TC3_GET_OPTIONAL_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ if (result->CLASS##_class != nullptr) { \
+ result->CLASS##_##FIELD = \
+ env->GetMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
+ env->ExceptionClear(); \
+ }
+
+// Static-method variant of TC3_GET_OPTIONAL_METHOD.
+#define TC3_GET_OPTIONAL_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ if (result->CLASS##_class != nullptr) { \
+ result->CLASS##_##FIELD = \
+ env->GetStaticMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
+ env->ExceptionClear(); \
+ }
+
+// Static-method variant of TC3_GET_METHOD; aborts on failure.
+#define TC3_GET_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ result->CLASS##_##FIELD = \
+ env->GetStaticMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
+ TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \
+ << "Error finding method: " << NAME;
+
+// Caches a global reference to a static object field (e.g. Locale.US);
+// makes the enclosing function return nullptr on failure.
+#define TC3_GET_STATIC_OBJECT_FIELD_OR_RETURN_NULL(CLASS, FIELD, NAME, \
+ SIGNATURE) \
+ { \
+ const jfieldID CLASS##_##FIELD##_field = \
+ env->GetStaticFieldID(result->CLASS##_class.get(), NAME, SIGNATURE); \
+ TC3_CHECK_JNI_RESULT(CLASS##_##FIELD##_field) \
+ << "Error finding field id: " << NAME; \
+ TC3_ASSIGN_OR_RETURN_NULL( \
+ ScopedLocalRef<jobject> static_object, \
+ JniHelper::GetStaticObjectField(env, result->CLASS##_class.get(), \
+ CLASS##_##FIELD##_field)); \
+ result->CLASS##_##FIELD = MakeGlobalRef(static_object.get(), env, jvm); \
+ if (result->CLASS##_##FIELD == nullptr) { \
+ TC3_LOG(ERROR) << "Error finding field: " << NAME; \
+ return nullptr; \
+ } \
+ }
+
+// Caches the value of a static int field (e.g. Calendar.YEAR). Aborts on
+// lookup failure. NOTE(review): TC3_CHECK_JNI_RESULT on the value itself
+// would also fire for a legitimate value of 0 — none of the constants read
+// here are 0, but confirm before reusing this macro elsewhere.
+#define TC3_GET_STATIC_INT_FIELD(CLASS, FIELD, NAME) \
+ const jfieldID CLASS##_##FIELD##_field = \
+ env->GetStaticFieldID(result->CLASS##_class.get(), NAME, "I"); \
+ TC3_CHECK_JNI_RESULT(CLASS##_##FIELD##_field) \
+ << "Error finding field id: " << NAME; \
+ result->CLASS##_##FIELD = env->GetStaticIntField( \
+ result->CLASS##_class.get(), CLASS##_##FIELD##_field); \
+ TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \
+ << "Error finding field: " << NAME;
+
+// Builds a fully populated JniCache for the JavaVM behind `env`.
+// Returns nullptr if `env` is null, the JavaVM cannot be obtained, or a
+// required class / static-object lookup fails. Note that method lookups via
+// TC3_GET_METHOD / TC3_GET_STATIC_METHOD abort the process (TC3_CHECK) on
+// failure rather than returning nullptr.
+std::unique_ptr<JniCache> JniCache::Create(JNIEnv* env) {
+  if (env == nullptr) {
+    return nullptr;
+  }
+  JavaVM* jvm = nullptr;
+  if (JNI_OK != env->GetJavaVM(&jvm) || jvm == nullptr) {
+    return nullptr;
+  }
+  std::unique_ptr<JniCache> result(new JniCache(jvm));
+
+  // String
+  TC3_GET_CLASS_OR_RETURN_NULL(string, "java/lang/String");
+  TC3_GET_METHOD(string, init_bytes_charset, "<init>",
+ "([BLjava/lang/String;)V");
+  TC3_GET_METHOD(string, code_point_count, "codePointCount", "(II)I");
+  TC3_GET_METHOD(string, length, "length", "()I");
+  // Cached "UTF-8" charset-name string, used by ConvertToJavaString.
+  TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> result_string,
+ JniHelper::NewStringUTF(env, "UTF-8"));
+  result->string_utf8 = MakeGlobalRef(result_string.get(), env, jvm);
+  TC3_CHECK_JNI_PTR(result->string_utf8);
+
+  // Pattern
+  TC3_GET_CLASS_OR_RETURN_NULL(pattern, "java/util/regex/Pattern");
+  TC3_GET_STATIC_METHOD(pattern, compile, "compile",
+ "(Ljava/lang/String;)Ljava/util/regex/Pattern;");
+  TC3_GET_METHOD(pattern, matcher, "matcher",
+ "(Ljava/lang/CharSequence;)Ljava/util/regex/Matcher;");
+
+  // Matcher
+  TC3_GET_CLASS_OR_RETURN_NULL(matcher, "java/util/regex/Matcher");
+  TC3_GET_METHOD(matcher, matches, "matches", "()Z");
+  TC3_GET_METHOD(matcher, find, "find", "()Z");
+  TC3_GET_METHOD(matcher, reset, "reset", "()Ljava/util/regex/Matcher;");
+  TC3_GET_METHOD(matcher, start_idx, "start", "(I)I");
+  TC3_GET_METHOD(matcher, end_idx, "end", "(I)I");
+  TC3_GET_METHOD(matcher, group, "group", "()Ljava/lang/String;");
+  TC3_GET_METHOD(matcher, group_idx, "group", "(I)Ljava/lang/String;");
+
+  // Locale
+  TC3_GET_CLASS_OR_RETURN_NULL(locale, "java/util/Locale");
+  TC3_GET_STATIC_OBJECT_FIELD_OR_RETURN_NULL(locale, us, "US",
+ "Ljava/util/Locale;");
+  TC3_GET_METHOD(locale, init_string, "<init>", "(Ljava/lang/String;)V");
+  // forLanguageTag is optional: it is absent on older platform API levels.
+  TC3_GET_OPTIONAL_STATIC_METHOD(locale, for_language_tag, "forLanguageTag",
+ "(Ljava/lang/String;)Ljava/util/Locale;");
+
+  // BreakIterator
+  TC3_GET_CLASS_OR_RETURN_NULL(breakiterator, "java/text/BreakIterator");
+  TC3_GET_STATIC_METHOD(breakiterator, getwordinstance, "getWordInstance",
+ "(Ljava/util/Locale;)Ljava/text/BreakIterator;");
+  TC3_GET_METHOD(breakiterator, settext, "setText", "(Ljava/lang/String;)V");
+  TC3_GET_METHOD(breakiterator, next, "next", "()I");
+
+  // Integer
+  TC3_GET_CLASS_OR_RETURN_NULL(integer, "java/lang/Integer");
+  TC3_GET_STATIC_METHOD(integer, parse_int, "parseInt",
+ "(Ljava/lang/String;)I");
+
+  // Calendar.
+  TC3_GET_CLASS_OR_RETURN_NULL(calendar, "java/util/Calendar");
+  TC3_GET_STATIC_METHOD(
+      calendar, get_instance, "getInstance",
+ "(Ljava/util/TimeZone;Ljava/util/Locale;)Ljava/util/Calendar;");
+  TC3_GET_METHOD(calendar, get_first_day_of_week, "getFirstDayOfWeek", "()I");
+  TC3_GET_METHOD(calendar, get_time_in_millis, "getTimeInMillis", "()J");
+  TC3_GET_METHOD(calendar, set_time_in_millis, "setTimeInMillis", "(J)V");
+  TC3_GET_METHOD(calendar, add, "add", "(II)V");
+  TC3_GET_METHOD(calendar, get, "get", "(I)I");
+  TC3_GET_METHOD(calendar, set, "set", "(II)V");
+  // Cache the Calendar int constants so native code can pass them to
+  // get/set/add without repeated field lookups.
+  TC3_GET_STATIC_INT_FIELD(calendar, zone_offset, "ZONE_OFFSET");
+  TC3_GET_STATIC_INT_FIELD(calendar, dst_offset, "DST_OFFSET");
+  TC3_GET_STATIC_INT_FIELD(calendar, year, "YEAR");
+  TC3_GET_STATIC_INT_FIELD(calendar, month, "MONTH");
+  TC3_GET_STATIC_INT_FIELD(calendar, day_of_year, "DAY_OF_YEAR");
+  TC3_GET_STATIC_INT_FIELD(calendar, day_of_month, "DAY_OF_MONTH");
+  TC3_GET_STATIC_INT_FIELD(calendar, day_of_week, "DAY_OF_WEEK");
+  TC3_GET_STATIC_INT_FIELD(calendar, hour_of_day, "HOUR_OF_DAY");
+  TC3_GET_STATIC_INT_FIELD(calendar, minute, "MINUTE");
+  TC3_GET_STATIC_INT_FIELD(calendar, second, "SECOND");
+  TC3_GET_STATIC_INT_FIELD(calendar, millisecond, "MILLISECOND");
+  TC3_GET_STATIC_INT_FIELD(calendar, sunday, "SUNDAY");
+  TC3_GET_STATIC_INT_FIELD(calendar, monday, "MONDAY");
+  TC3_GET_STATIC_INT_FIELD(calendar, tuesday, "TUESDAY");
+  TC3_GET_STATIC_INT_FIELD(calendar, wednesday, "WEDNESDAY");
+  TC3_GET_STATIC_INT_FIELD(calendar, thursday, "THURSDAY");
+  TC3_GET_STATIC_INT_FIELD(calendar, friday, "FRIDAY");
+  TC3_GET_STATIC_INT_FIELD(calendar, saturday, "SATURDAY");
+
+  // TimeZone.
+  TC3_GET_CLASS_OR_RETURN_NULL(timezone, "java/util/TimeZone");
+  TC3_GET_STATIC_METHOD(timezone, get_timezone, "getTimeZone",
+ "(Ljava/lang/String;)Ljava/util/TimeZone;");
+
+  // URLEncoder.
+  TC3_GET_CLASS_OR_RETURN_NULL(urlencoder, "java/net/URLEncoder");
+  TC3_GET_STATIC_METHOD(
+      urlencoder, encode, "encode",
+ "(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;");
+
+#ifdef __ANDROID__
+  // Context.
+  TC3_GET_CLASS_OR_RETURN_NULL(context, "android/content/Context");
+  TC3_GET_METHOD(context, get_package_name, "getPackageName",
+ "()Ljava/lang/String;");
+  TC3_GET_METHOD(context, get_system_service, "getSystemService",
+ "(Ljava/lang/String;)Ljava/lang/Object;");
+
+  // Uri.
+  TC3_GET_CLASS_OR_RETURN_NULL(uri, "android/net/Uri");
+  TC3_GET_STATIC_METHOD(uri, parse, "parse",
+ "(Ljava/lang/String;)Landroid/net/Uri;");
+  TC3_GET_METHOD(uri, get_scheme, "getScheme", "()Ljava/lang/String;");
+  TC3_GET_METHOD(uri, get_host, "getHost", "()Ljava/lang/String;");
+
+  // UserManager. Optional: may be absent, so lookups must not abort.
+  TC3_GET_OPTIONAL_CLASS(usermanager, "android/os/UserManager");
+  TC3_GET_OPTIONAL_METHOD(usermanager, get_user_restrictions,
+ "getUserRestrictions", "()Landroid/os/Bundle;");
+
+  // Bundle.
+  TC3_GET_CLASS_OR_RETURN_NULL(bundle, "android/os/Bundle");
+  TC3_GET_METHOD(bundle, get_boolean, "getBoolean", "(Ljava/lang/String;)Z");
+
+  // String resources.
+  TC3_GET_CLASS_OR_RETURN_NULL(resources, "android/content/res/Resources");
+  TC3_GET_STATIC_METHOD(resources, get_system, "getSystem",
+ "()Landroid/content/res/Resources;");
+  TC3_GET_METHOD(resources, get_identifier, "getIdentifier",
+ "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)I");
+  TC3_GET_METHOD(resources, get_string, "getString", "(I)Ljava/lang/String;");
+#endif
+
+  return result;
+}
+
+#undef TC3_GET_STATIC_INT_FIELD
+#undef TC3_GET_STATIC_OBJECT_FIELD_OR_RETURN_NULL
+#undef TC3_GET_STATIC_METHOD
+#undef TC3_GET_METHOD
+#undef TC3_GET_CLASS_OR_RETURN_NULL
+#undef TC3_GET_OPTIONAL_CLASS
+#undef TC3_CHECK_JNI_PTR
+
+// Returns the JNIEnv attached to the calling thread, or nullptr (with an
+// error log) when the thread is not attached to the JavaVM.
+JNIEnv* JniCache::GetEnv() const {
+  void* attached_env = nullptr;
+  if (jvm->GetEnv(&attached_env, JNI_VERSION_1_4) != JNI_OK) {
+    TC3_LOG(ERROR) << "JavaICU UniLib used on unattached thread";
+    return nullptr;
+  }
+  return static_cast<JNIEnv*>(attached_env);
+}
+
+// Clears any Java exception pending on the calling thread's JNIEnv and
+// returns whether one was pending.
+// NOTE(review): GetEnv() can return nullptr on unattached threads — confirm
+// JniExceptionCheckAndClear tolerates a null env.
+bool JniCache::ExceptionCheckAndClear() const {
+  return JniExceptionCheckAndClear(GetEnv());
+}
+
+// Converts a UTF-8 byte buffer into a java.lang.String by constructing it
+// through the cached String(byte[], String charsetName) constructor with the
+// cached "UTF-8" charset name.
+// Returns an error Status if the calling thread has no JNIEnv, an allocation
+// fails, or a Java exception is raised along the way.
+StatusOr<ScopedLocalRef<jstring>> JniCache::ConvertToJavaString(
+    const char* utf8_text, const int utf8_text_size_bytes) const {
+  JNIEnv* jenv = GetEnv();
+  if (jenv == nullptr) {
+    // Unattached thread: no JNI calls are possible.
+    return {Status::UNKNOWN};
+  }
+
+  // Create java byte array.
+  TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jbyteArray> text_java_utf8,
+                       JniHelper::NewByteArray(jenv, utf8_text_size_bytes));
+  jenv->SetByteArrayRegion(text_java_utf8.get(), 0, utf8_text_size_bytes,
+                           reinterpret_cast<const jbyte*>(utf8_text));
+  // SetByteArrayRegion can raise ArrayIndexOutOfBoundsException; calling
+  // further JNI functions while an exception is pending is undefined
+  // behavior, so check and clear it before proceeding.
+  if (JniExceptionCheckAndClear(jenv)) {
+    return {Status::UNKNOWN};
+  }
+
+  // Create the string with a UTF-8 charset.
+  TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> result,
+                       JniHelper::NewObject<jstring>(
+                           jenv, string_class.get(), string_init_bytes_charset,
+                           text_java_utf8.get(), string_utf8.get()));
+
+  return result;
+}
+
+// Convenience overload: converts the bytes viewed by a StringPiece.
+StatusOr<ScopedLocalRef<jstring>> JniCache::ConvertToJavaString(
+    StringPiece utf8_text) const {
+  return ConvertToJavaString(utf8_text.data(), utf8_text.size());
+}
+
+// Convenience overload: converts a UnicodeText via its underlying UTF-8
+// byte buffer.
+StatusOr<ScopedLocalRef<jstring>> JniCache::ConvertToJavaString(
+    const UnicodeText& text) const {
+  return ConvertToJavaString(text.data(), text.size_bytes());
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/java/jni-cache.h b/native/utils/java/jni-cache.h
new file mode 100644
index 0000000..ab48419
--- /dev/null
+++ b/native/utils/java/jni-cache.h
@@ -0,0 +1,153 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_CACHE_H_
+#define LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_CACHE_H_
+
+#include <jni.h>
+
+#include "utils/base/statusor.h"
+#include "utils/java/jni-base.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+// A helper class to cache class and method pointers for calls from JNI to Java.
+// (for implementations such as Java ICU that need to make calls from C++ to
+// Java)
+//
+// All members are populated by Create(); they are only valid when Create()
+// returned a non-null cache. jmethodID members and the cached Calendar jint
+// constants are value-initialized so that reads are deterministic even on
+// paths where population failed part-way.
+struct JniCache {
+  // Builds a fully populated cache, or returns nullptr on failure.
+  static std::unique_ptr<JniCache> Create(JNIEnv* env);
+
+  // Returns the JNIEnv of the calling thread, or nullptr if unattached.
+  JNIEnv* GetEnv() const;
+  // Clears a pending Java exception; returns true if one was pending.
+  bool ExceptionCheckAndClear() const;
+
+  JavaVM* jvm = nullptr;
+
+  // java.lang.String
+  ScopedGlobalRef<jclass> string_class;
+  jmethodID string_init_bytes_charset = nullptr;
+  jmethodID string_code_point_count = nullptr;
+  jmethodID string_length = nullptr;
+  // Cached "UTF-8" charset-name string passed to the String constructor.
+  ScopedGlobalRef<jstring> string_utf8;
+
+  // java.util.regex.Pattern
+  ScopedGlobalRef<jclass> pattern_class;
+  jmethodID pattern_compile = nullptr;
+  jmethodID pattern_matcher = nullptr;
+
+  // java.util.regex.Matcher
+  ScopedGlobalRef<jclass> matcher_class;
+  jmethodID matcher_matches = nullptr;
+  jmethodID matcher_find = nullptr;
+  jmethodID matcher_reset = nullptr;
+  jmethodID matcher_start_idx = nullptr;
+  jmethodID matcher_end_idx = nullptr;
+  jmethodID matcher_group = nullptr;
+  jmethodID matcher_group_idx = nullptr;
+
+  // java.util.Locale
+  ScopedGlobalRef<jclass> locale_class;
+  ScopedGlobalRef<jobject> locale_us;
+  jmethodID locale_init_string = nullptr;
+  // Optional: may stay null on platforms without Locale.forLanguageTag.
+  jmethodID locale_for_language_tag = nullptr;
+
+  // java.text.BreakIterator
+  ScopedGlobalRef<jclass> breakiterator_class;
+  jmethodID breakiterator_getwordinstance = nullptr;
+  jmethodID breakiterator_settext = nullptr;
+  jmethodID breakiterator_next = nullptr;
+
+  // java.lang.Integer
+  ScopedGlobalRef<jclass> integer_class;
+  jmethodID integer_parse_int = nullptr;
+
+  // java.util.Calendar
+  ScopedGlobalRef<jclass> calendar_class;
+  jmethodID calendar_get_instance = nullptr;
+  jmethodID calendar_get_first_day_of_week = nullptr;
+  jmethodID calendar_get_time_in_millis = nullptr;
+  jmethodID calendar_set_time_in_millis = nullptr;
+  jmethodID calendar_add = nullptr;
+  jmethodID calendar_get = nullptr;
+  jmethodID calendar_set = nullptr;
+  // Cached values of the java.util.Calendar int constants (ZONE_OFFSET,
+  // YEAR, ...). Zero-initialized for consistency with the jmethodID members
+  // above; the real values are filled in by Create().
+  jint calendar_zone_offset = 0;
+  jint calendar_dst_offset = 0;
+  jint calendar_year = 0;
+  jint calendar_month = 0;
+  jint calendar_day_of_year = 0;
+  jint calendar_day_of_month = 0;
+  jint calendar_day_of_week = 0;
+  jint calendar_hour_of_day = 0;
+  jint calendar_minute = 0;
+  jint calendar_second = 0;
+  jint calendar_millisecond = 0;
+  jint calendar_sunday = 0;
+  jint calendar_monday = 0;
+  jint calendar_tuesday = 0;
+  jint calendar_wednesday = 0;
+  jint calendar_thursday = 0;
+  jint calendar_friday = 0;
+  jint calendar_saturday = 0;
+
+  // java.util.TimeZone
+  ScopedGlobalRef<jclass> timezone_class;
+  jmethodID timezone_get_timezone = nullptr;
+
+  // java.net.URLEncoder
+  ScopedGlobalRef<jclass> urlencoder_class;
+  jmethodID urlencoder_encode = nullptr;
+
+  // NOTE(review): the members below are populated only under __ANDROID__
+  // (see jni-cache.cc) but declared unconditionally here — confirm that
+  // default construction of ScopedGlobalRef is safe on non-Android builds.
+  // android.content.Context
+  ScopedGlobalRef<jclass> context_class;
+  jmethodID context_get_package_name = nullptr;
+  jmethodID context_get_system_service = nullptr;
+
+  // android.net.Uri
+  ScopedGlobalRef<jclass> uri_class;
+  jmethodID uri_parse = nullptr;
+  jmethodID uri_get_scheme = nullptr;
+  jmethodID uri_get_host = nullptr;
+
+  // android.os.UserManager (optional: may be absent on some devices)
+  ScopedGlobalRef<jclass> usermanager_class;
+  jmethodID usermanager_get_user_restrictions = nullptr;
+
+  // android.os.Bundle
+  ScopedGlobalRef<jclass> bundle_class;
+  jmethodID bundle_get_boolean = nullptr;
+
+  // android.content.res.Resources
+  ScopedGlobalRef<jclass> resources_class;
+  jmethodID resources_get_system = nullptr;
+  jmethodID resources_get_identifier = nullptr;
+  jmethodID resources_get_string = nullptr;
+
+  // Helper to convert lib3 UnicodeText to Java strings.
+  StatusOr<ScopedLocalRef<jstring>> ConvertToJavaString(
+      const char* utf8_text, const int utf8_text_size_bytes) const;
+  StatusOr<ScopedLocalRef<jstring>> ConvertToJavaString(
+      StringPiece utf8_text) const;
+  StatusOr<ScopedLocalRef<jstring>> ConvertToJavaString(
+      const UnicodeText& text) const;
+
+ private:
+  explicit JniCache(JavaVM* jvm);
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_CACHE_H_
diff --git a/native/utils/java/jni-helper.cc b/native/utils/java/jni-helper.cc
new file mode 100644
index 0000000..d1677e4
--- /dev/null
+++ b/native/utils/java/jni-helper.cc
@@ -0,0 +1,177 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/java/jni-helper.h"
+
+namespace libtextclassifier3 {
+
+// Wraps env->FindClass: ensures local-reference capacity, converts a pending
+// exception or null result into an error Status, and returns a scoped ref.
+StatusOr<ScopedLocalRef<jclass>> JniHelper::FindClass(JNIEnv* env,
+ const char* class_name) {
+  TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+  ScopedLocalRef<jclass> result(env->FindClass(class_name), env);
+  TC3_NO_EXCEPTION_OR_RETURN;
+  TC3_NOT_NULL_OR_RETURN;
+  return result;
+}
+
+// Wraps env->GetMethodID with exception/null checking.
+// NOTE(review): the `return_type` parameter is actually the full JNI method
+// signature (e.g. "(I)Ljava/lang/String;"), not just the return type — the
+// name is misleading but kept for interface compatibility.
+StatusOr<jmethodID> JniHelper::GetMethodID(JNIEnv* env, jclass clazz,
+ const char* method_name,
+ const char* return_type) {
+  jmethodID result = env->GetMethodID(clazz, method_name, return_type);
+  TC3_NO_EXCEPTION_OR_RETURN;
+  TC3_NOT_NULL_OR_RETURN;
+  return result;
+}
+
+// Wraps env->GetStaticObjectField; the result is a new local reference.
+// NOTE(review): the `class_name` parameter is a jclass handle, not a name.
+StatusOr<ScopedLocalRef<jobject>> JniHelper::GetStaticObjectField(
+    JNIEnv* env, jclass class_name, jfieldID field_id) {
+  TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+  ScopedLocalRef<jobject> result(
+      env->GetStaticObjectField(class_name, field_id), env);
+  TC3_NO_EXCEPTION_OR_RETURN;
+  TC3_NOT_NULL_OR_RETURN;
+  return result;
+}
+
+// Wraps env->NewByteArray; fails with an error Status on OutOfMemoryError.
+StatusOr<ScopedLocalRef<jbyteArray>> JniHelper::NewByteArray(JNIEnv* env,
+ jsize length) {
+  TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+  ScopedLocalRef<jbyteArray> result(env->NewByteArray(length), env);
+  TC3_NO_EXCEPTION_OR_RETURN;
+  TC3_NOT_NULL_OR_RETURN;
+  return result;
+}
+
+// The Call*Method wrappers below forward their varargs to the corresponding
+// env->Call*MethodV function, then convert any pending Java exception into
+// an error Status. Primitive JNI results are widened/converted to the
+// corresponding C++ type (jboolean->bool, jint->int32, jlong->int64, ...).
+Status JniHelper::CallVoidMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...) {
+  va_list args;
+  va_start(args, method_id);
+  env->CallVoidMethodV(object, method_id, args);
+  va_end(args);
+
+  TC3_NO_EXCEPTION_OR_RETURN;
+  return Status::OK;
+}
+
+StatusOr<bool> JniHelper::CallBooleanMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...) {
+  va_list args;
+  va_start(args, method_id);
+  bool result = env->CallBooleanMethodV(object, method_id, args);
+  va_end(args);
+
+  TC3_NO_EXCEPTION_OR_RETURN;
+  return result;
+}
+
+StatusOr<int32> JniHelper::CallIntMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...) {
+  va_list args;
+  va_start(args, method_id);
+  jint result = env->CallIntMethodV(object, method_id, args);
+  va_end(args);
+
+  TC3_NO_EXCEPTION_OR_RETURN;
+  return result;
+}
+
+StatusOr<int64> JniHelper::CallLongMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...) {
+  va_list args;
+  va_start(args, method_id);
+  jlong result = env->CallLongMethodV(object, method_id, args);
+  va_end(args);
+
+  TC3_NO_EXCEPTION_OR_RETURN;
+  return result;
+}
+
+StatusOr<float> JniHelper::CallFloatMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...) {
+  va_list args;
+  va_start(args, method_id);
+  jfloat result = env->CallFloatMethodV(object, method_id, args);
+  va_end(args);
+
+  TC3_NO_EXCEPTION_OR_RETURN;
+  return result;
+}
+
+StatusOr<double> JniHelper::CallDoubleMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...) {
+  va_list args;
+  va_start(args, method_id);
+  jdouble result = env->CallDoubleMethodV(object, method_id, args);
+  va_end(args);
+
+  TC3_NO_EXCEPTION_OR_RETURN;
+  return result;
+}
+
+// Wraps env->NewIntArray with capacity, exception and null checking.
+StatusOr<ScopedLocalRef<jintArray>> JniHelper::NewIntArray(JNIEnv* env,
+ jsize length) {
+  TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+  ScopedLocalRef<jintArray> result(env->NewIntArray(length), env);
+  TC3_NO_EXCEPTION_OR_RETURN;
+  TC3_NOT_NULL_OR_RETURN;
+  return result;
+}
+
+// Wraps env->NewFloatArray with capacity, exception and null checking.
+StatusOr<ScopedLocalRef<jfloatArray>> JniHelper::NewFloatArray(JNIEnv* env,
+ jsize length) {
+  TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+  ScopedLocalRef<jfloatArray> result(env->NewFloatArray(length), env);
+  TC3_NO_EXCEPTION_OR_RETURN;
+  TC3_NOT_NULL_OR_RETURN;
+  return result;
+}
+
+// Stores `val` at `index`, converting ArrayIndexOutOfBounds/ArrayStore
+// exceptions into an error Status.
+Status JniHelper::SetObjectArrayElement(JNIEnv* env, jobjectArray array,
+ jsize index, jobject val) {
+  TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+  env->SetObjectArrayElement(array, index, val);
+  TC3_NO_EXCEPTION_OR_RETURN;
+  return Status::OK;
+}
+
+// Allocates a jobjectArray of `length` elements of `element_class`, each
+// initialized to `initial_element` (may be null).
+StatusOr<ScopedLocalRef<jobjectArray>> JniHelper::NewObjectArray(
+    JNIEnv* env, jsize length, jclass element_class, jobject initial_element) {
+  TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+  ScopedLocalRef<jobjectArray> result(
+      env->NewObjectArray(length, element_class, initial_element), env);
+  TC3_NO_EXCEPTION_OR_RETURN;
+  TC3_NOT_NULL_OR_RETURN;
+  return result;
+}
+
+// Returns the element count of any Java array.
+StatusOr<jsize> JniHelper::GetArrayLength(JNIEnv* env,
+ jarray jinput_fragments) {
+  TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+  jsize result = env->GetArrayLength(jinput_fragments);
+  TC3_NO_EXCEPTION_OR_RETURN;
+  return result;
+}
+
+// Wraps env->NewStringUTF. `bytes` must be modified UTF-8 per the JNI spec.
+StatusOr<ScopedLocalRef<jstring>> JniHelper::NewStringUTF(JNIEnv* env,
+ const char* bytes) {
+  TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+  ScopedLocalRef<jstring> result(env->NewStringUTF(bytes), env);
+  TC3_NO_EXCEPTION_OR_RETURN;
+  TC3_NOT_NULL_OR_RETURN;
+  return result;
+}
+} // namespace libtextclassifier3
diff --git a/native/utils/java/jni-helper.h b/native/utils/java/jni-helper.h
new file mode 100644
index 0000000..55d4696
--- /dev/null
+++ b/native/utils/java/jni-helper.h
@@ -0,0 +1,158 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility class that provides similar calls like JNIEnv, but performs
+// additional checks on them, so that it's harder to use them incorrectly.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_HELPER_H_
+#define LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_HELPER_H_
+
+#include <jni.h>
+
+#include <string>
+
+#include "utils/base/status.h"
+#include "utils/base/statusor.h"
+#include "utils/java/jni-base.h"
+
+#define TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN \
+ if (!EnsureLocalCapacity(env, 1)) { \
+ TC3_LOG(ERROR) << "EnsureLocalCapacity(1) failed."; \
+ return {Status::UNKNOWN}; \
+ }
+
+#define TC3_NO_EXCEPTION_OR_RETURN \
+ if (JniExceptionCheckAndClear(env)) { \
+ return {Status::UNKNOWN}; \
+ }
+
+#define TC3_NOT_NULL_OR_RETURN \
+ if (result == nullptr) { \
+ return {Status::UNKNOWN}; \
+ }
+
+#define TC3_DEFINE_VARIADIC_SCOPED_LOCAL_REF_ENV_METHOD( \
+ METHOD_NAME, RETURN_TYPE, INPUT_TYPE, POST_CHECK) \
+ template <typename T = RETURN_TYPE> \
+ static StatusOr<ScopedLocalRef<T>> METHOD_NAME( \
+ JNIEnv* env, INPUT_TYPE object, jmethodID method_id, ...) { \
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN; \
+ \
+ va_list args; \
+ va_start(args, method_id); \
+ ScopedLocalRef<T> result( \
+ reinterpret_cast<T>(env->METHOD_NAME##V(object, method_id, args)), \
+ env); \
+ POST_CHECK \
+ va_end(args); \
+ \
+ TC3_NO_EXCEPTION_OR_RETURN; \
+ return result; \
+ }
+
+#define TC3_JNI_NO_CHECK \
+ {}
+
+namespace libtextclassifier3 {
+
+// Thin static wrappers around JNIEnv calls that ensure local-reference
+// capacity, convert JNI failures and pending Java exceptions into Status
+// errors, and wrap returned local references in ScopedLocalRef so they are
+// released automatically.
+class JniHelper {
+ public:
+  // Misc methods.
+  static StatusOr<ScopedLocalRef<jclass>> FindClass(JNIEnv* env,
+ const char* class_name);
+
+  // Returns the element at `index`, cast to T (default jobject).
+  template <typename T = jobject>
+  static StatusOr<ScopedLocalRef<T>> GetObjectArrayElement(JNIEnv* env,
+ jobjectArray array,
+ jsize index);
+  // NOTE(review): `return_type` is the full JNI method signature, not just
+  // the return type.
+  static StatusOr<jmethodID> GetMethodID(JNIEnv* env, jclass clazz,
+ const char* method_name,
+ const char* return_type);
+
+  static StatusOr<ScopedLocalRef<jobject>> GetStaticObjectField(
+      JNIEnv* env, jclass class_name, jfieldID field_id);
+
+  // New* methods.
+  TC3_DEFINE_VARIADIC_SCOPED_LOCAL_REF_ENV_METHOD(NewObject, jobject, jclass,
+ TC3_NOT_NULL_OR_RETURN);
+  static StatusOr<ScopedLocalRef<jobjectArray>> NewObjectArray(
+      JNIEnv* env, jsize length, jclass element_class,
+      jobject initial_element = nullptr);
+  static StatusOr<ScopedLocalRef<jbyteArray>> NewByteArray(JNIEnv* env,
+ jsize length);
+  static StatusOr<ScopedLocalRef<jintArray>> NewIntArray(JNIEnv* env,
+ jsize length);
+  static StatusOr<ScopedLocalRef<jstring>> NewStringUTF(JNIEnv* env,
+ const char* bytes);
+  static StatusOr<ScopedLocalRef<jfloatArray>> NewFloatArray(JNIEnv* env,
+ jsize length);
+
+  static StatusOr<jsize> GetArrayLength(JNIEnv* env, jarray jinput_fragments);
+
+  static Status SetObjectArrayElement(JNIEnv* env, jobjectArray array,
+ jsize index, jobject val);
+
+  // Call* methods. Object-returning calls are not null-checked, since a Java
+  // method may legitimately return null.
+  TC3_DEFINE_VARIADIC_SCOPED_LOCAL_REF_ENV_METHOD(CallObjectMethod, jobject,
+ jobject, TC3_JNI_NO_CHECK);
+  TC3_DEFINE_VARIADIC_SCOPED_LOCAL_REF_ENV_METHOD(CallStaticObjectMethod,
+ jobject, jclass,
+ TC3_JNI_NO_CHECK);
+  static Status CallVoidMethod(JNIEnv* env, jobject object, jmethodID method_id,
+ ...);
+  static StatusOr<bool> CallBooleanMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...);
+  static StatusOr<int32> CallIntMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...);
+  static StatusOr<int64> CallLongMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...);
+  static StatusOr<float> CallFloatMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...);
+  static StatusOr<double> CallDoubleMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...);
+
+  // Returns the jint result converted to T.
+  template <class T>
+  static StatusOr<T> CallStaticIntMethod(JNIEnv* env, jclass clazz,
+ jmethodID method_id, ...);
+};
+
+// Fetches array[index] as a new local reference, cast to T. A pending
+// ArrayIndexOutOfBoundsException is converted into an error Status. A null
+// element is returned as a null ScopedLocalRef (no error).
+template <typename T>
+StatusOr<ScopedLocalRef<T>> JniHelper::GetObjectArrayElement(JNIEnv* env,
+ jobjectArray array,
+ jsize index) {
+  TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+  ScopedLocalRef<T> result(
+      reinterpret_cast<T>(env->GetObjectArrayElement(array, index)), env);
+
+  TC3_NO_EXCEPTION_OR_RETURN;
+  return result;
+}
+
+// Invokes a static method returning a Java int and converts the jint result
+// to T. A pending Java exception is converted into an error Status.
+template <class T>
+StatusOr<T> JniHelper::CallStaticIntMethod(JNIEnv* env, jclass clazz,
+ jmethodID method_id, ...) {
+  va_list args;
+  va_start(args, method_id);
+  jint result = env->CallStaticIntMethodV(clazz, method_id, args);
+  va_end(args);
+
+  TC3_NO_EXCEPTION_OR_RETURN;
+  return result;
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_HELPER_H_
diff --git a/native/utils/java/string_utils.cc b/native/utils/java/string_utils.cc
new file mode 100644
index 0000000..ca518a0
--- /dev/null
+++ b/native/utils/java/string_utils.cc
@@ -0,0 +1,75 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/java/string_utils.h"
+
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+// Copies the contents of a Java byte array into `result`.
+// Returns false if `array` is null or its elements cannot be accessed;
+// `result` is left unchanged in that case.
+bool JByteArrayToString(JNIEnv* env, const jbyteArray& array,
+                        std::string* result) {
+  if (array == nullptr) {
+    return false;
+  }
+
+  // FIX: pass nullptr for the optional `isCopy` out-parameter. The previous
+  // code passed JNI_FALSE (the integer 0), which only compiled because 0 is
+  // a null-pointer constant.
+  jbyte* const array_bytes = env->GetByteArrayElements(array, nullptr);
+  if (array_bytes == nullptr) {
+    return false;
+  }
+
+  const int array_length = env->GetArrayLength(array);
+  result->assign(reinterpret_cast<char*>(array_bytes), array_length);
+
+  // JNI_ABORT: the buffer was not modified, so no copy-back is needed.
+  env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
+
+  return true;
+}
+
+// Converts a java.lang.String to a UTF-8 std::string by calling
+// String.getBytes("UTF-8"). A null `jstr` converts to the empty string.
+// Returns false (leaving `result` unspecified) on any JNI failure.
+// FIX: the previous version ignored the results of GetMethodID,
+// CallObjectMethod and JByteArrayToString, so a thrown Java exception or a
+// null byte array led to undefined behavior and the function always
+// reported success.
+bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
+                         std::string* result) {
+  if (jstr == nullptr) {
+    *result = std::string();
+    return true;
+  }
+
+  jclass string_class = env->FindClass("java/lang/String");
+  if (!string_class) {
+    TC3_LOG(ERROR) << "Can't find String class";
+    return false;
+  }
+
+  jmethodID get_bytes_id =
+      env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
+  if (get_bytes_id == nullptr) {
+    TC3_LOG(ERROR) << "Can't find String.getBytes";
+    env->DeleteLocalRef(string_class);
+    return false;
+  }
+
+  jstring encoding = env->NewStringUTF("UTF-8");
+  if (encoding == nullptr) {
+    env->DeleteLocalRef(string_class);
+    return false;
+  }
+
+  jbyteArray array = reinterpret_cast<jbyteArray>(
+      env->CallObjectMethod(jstr, get_bytes_id, encoding));
+
+  bool success = false;
+  if (env->ExceptionCheck()) {
+    // getBytes threw; clear the exception before making further JNI calls.
+    env->ExceptionClear();
+  } else {
+    success = JByteArrayToString(env, array, result);
+  }
+
+  // Release the local references.
+  if (array != nullptr) {
+    env->DeleteLocalRef(array);
+  }
+  env->DeleteLocalRef(string_class);
+  env->DeleteLocalRef(encoding);
+
+  return success;
+}
+
+// Returns the string's characters (modified UTF-8, per GetStringUTFChars)
+// wrapped in a ScopedStringChars whose deleter calls ReleaseStringUTFChars.
+// If `is_copy` is non-null it receives whether the JVM made a copy.
+ScopedStringChars GetScopedStringChars(JNIEnv* env, jstring string,
+ jboolean* is_copy) {
+  return ScopedStringChars(env->GetStringUTFChars(string, is_copy),
+ StringCharsReleaser(env, string));
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/java/string_utils.h b/native/utils/java/string_utils.h
similarity index 100%
rename from utils/java/string_utils.h
rename to native/utils/java/string_utils.h
diff --git a/native/utils/lua-utils.cc b/native/utils/lua-utils.cc
new file mode 100644
index 0000000..d6fe2c4
--- /dev/null
+++ b/native/utils/lua-utils.cc
@@ -0,0 +1,674 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/lua-utils.h"
+
+// lua_dump takes an extra argument "strip" in 5.3, but not in 5.2.
+#ifndef TC3_AOSP
+#define lua_dump(L, w, d, s) lua_dump((L), (w), (d))
+#endif
+
+namespace libtextclassifier3 {
+namespace {
+static constexpr luaL_Reg defaultlibs[] = {{"_G", luaopen_base},
+ {LUA_TABLIBNAME, luaopen_table},
+ {LUA_STRLIBNAME, luaopen_string},
+ {LUA_BITLIBNAME, luaopen_bit32},
+ {LUA_MATHLIBNAME, luaopen_math},
+ {nullptr, nullptr}};
+
+static constexpr const char kTextKey[] = "text";
+static constexpr const char kTimeUsecKey[] = "parsed_time_ms_utc";
+static constexpr const char kGranularityKey[] = "granularity";
+static constexpr const char kCollectionKey[] = "collection";
+static constexpr const char kNameKey[] = "name";
+static constexpr const char kScoreKey[] = "score";
+static constexpr const char kPriorityScoreKey[] = "priority_score";
+static constexpr const char kTypeKey[] = "type";
+static constexpr const char kResponseTextKey[] = "response_text";
+static constexpr const char kAnnotationKey[] = "annotation";
+static constexpr const char kSpanKey[] = "span";
+static constexpr const char kMessageKey[] = "message";
+static constexpr const char kBeginKey[] = "begin";
+static constexpr const char kEndKey[] = "end";
+static constexpr const char kClassificationKey[] = "classification";
+static constexpr const char kSerializedEntity[] = "serialized_entity";
+static constexpr const char kEntityKey[] = "entity";
+
+// Implementation of a lua_Writer that appends the data to a string.
+int LuaStringWriter(lua_State* state, const void* data, size_t size,
+ void* result) {
+ std::string* const result_string = static_cast<std::string*>(result);
+ result_string->insert(result_string->size(), static_cast<const char*>(data),
+ size);
+ return LUA_OK;
+}
+
+} // namespace
+
+LuaEnvironment::LuaEnvironment() { state_ = luaL_newstate(); }
+
+LuaEnvironment::~LuaEnvironment() {
+ if (state_ != nullptr) {
+ lua_close(state_);
+ }
+}
+
+void LuaEnvironment::PushFlatbuffer(const reflection::Schema* schema,
+ const reflection::Object* type,
+ const flatbuffers::Table* table) const {
+ PushLazyObject(
+ std::bind(&LuaEnvironment::GetField, this, schema, type, table));
+}
+
+int LuaEnvironment::GetField(const reflection::Schema* schema,
+ const reflection::Object* type,
+ const flatbuffers::Table* table) const {
+ const char* field_name = lua_tostring(state_, /*idx=*/kIndexStackTop);
+ const reflection::Field* field = type->fields()->LookupByKey(field_name);
+ if (field == nullptr) {
+ lua_error(state_);
+ return 0;
+ }
+ // Provide primitive fields directly.
+ const reflection::BaseType field_type = field->type()->base_type();
+ switch (field_type) {
+ case reflection::Bool:
+ case reflection::UByte:
+ Push(table->GetField<uint8>(field->offset(), field->default_integer()));
+ break;
+ case reflection::Byte:
+ Push(table->GetField<int8>(field->offset(), field->default_integer()));
+ break;
+ case reflection::Int:
+ Push(table->GetField<int32>(field->offset(), field->default_integer()));
+ break;
+ case reflection::UInt:
+ Push(table->GetField<uint32>(field->offset(), field->default_integer()));
+ break;
+ case reflection::Short:
+ Push(table->GetField<int16>(field->offset(), field->default_integer()));
+ break;
+ case reflection::UShort:
+ Push(table->GetField<uint16>(field->offset(), field->default_integer()));
+ break;
+ case reflection::Long:
+ Push(table->GetField<int64>(field->offset(), field->default_integer()));
+ break;
+ case reflection::ULong:
+ Push(table->GetField<uint64>(field->offset(), field->default_integer()));
+ break;
+ case reflection::Float:
+ Push(table->GetField<float>(field->offset(), field->default_real()));
+ break;
+ case reflection::Double:
+ Push(table->GetField<double>(field->offset(), field->default_real()));
+ break;
+ case reflection::String: {
+ Push(table->GetPointer<const flatbuffers::String*>(field->offset()));
+ break;
+ }
+ case reflection::Obj: {
+ const flatbuffers::Table* field_table =
+ table->GetPointer<const flatbuffers::Table*>(field->offset());
+ if (field_table == nullptr) {
+ // Field was not set in entity data.
+ return 0;
+ }
+ const reflection::Object* field_type =
+ schema->objects()->Get(field->type()->index());
+ PushFlatbuffer(schema, field_type, field_table);
+ break;
+ }
+ case reflection::Vector: {
+ const flatbuffers::Vector<flatbuffers::Offset<void>>* field_vector =
+ table->GetPointer<
+ const flatbuffers::Vector<flatbuffers::Offset<void>>*>(
+ field->offset());
+ if (field_vector == nullptr) {
+ // Repeated field was not set in flatbuffer.
+ PushEmptyVector();
+ break;
+ }
+ switch (field->type()->element()) {
+ case reflection::Bool:
+ case reflection::UByte:
+ PushRepeatedField(
+ table->GetPointer<const flatbuffers::Vector<uint8>*>(
+ field->offset()));
+ break;
+ case reflection::Byte:
+ PushRepeatedField(table->GetPointer<const flatbuffers::Vector<int8>*>(
+ field->offset()));
+ break;
+ case reflection::Int:
+ PushRepeatedField(
+ table->GetPointer<const flatbuffers::Vector<int32>*>(
+ field->offset()));
+ break;
+ case reflection::UInt:
+ PushRepeatedField(
+ table->GetPointer<const flatbuffers::Vector<uint32>*>(
+ field->offset()));
+ break;
+ case reflection::Short:
+ PushRepeatedField(
+ table->GetPointer<const flatbuffers::Vector<int16>*>(
+ field->offset()));
+ break;
+ case reflection::UShort:
+ PushRepeatedField(
+ table->GetPointer<const flatbuffers::Vector<uint16>*>(
+ field->offset()));
+ break;
+ case reflection::Long:
+ PushRepeatedField(
+ table->GetPointer<const flatbuffers::Vector<int64>*>(
+ field->offset()));
+ break;
+ case reflection::ULong:
+ PushRepeatedField(
+ table->GetPointer<const flatbuffers::Vector<uint64>*>(
+ field->offset()));
+ break;
+ case reflection::Float:
+ PushRepeatedField(
+ table->GetPointer<const flatbuffers::Vector<float>*>(
+ field->offset()));
+ break;
+ case reflection::Double:
+ PushRepeatedField(
+ table->GetPointer<const flatbuffers::Vector<double>*>(
+ field->offset()));
+ break;
+ case reflection::String:
+ PushRepeatedField(
+ table->GetPointer<const flatbuffers::Vector<
+ flatbuffers::Offset<flatbuffers::String>>*>(field->offset()));
+ break;
+ case reflection::Obj:
+ PushRepeatedFlatbufferField(
+ schema, schema->objects()->Get(field->type()->index()),
+ table->GetPointer<const flatbuffers::Vector<
+ flatbuffers::Offset<flatbuffers::Table>>*>(field->offset()));
+ break;
+ default:
+ TC3_LOG(ERROR) << "Unsupported repeated type: "
+ << field->type()->element();
+ lua_error(state_);
+ return 0;
+ }
+ break;
+ }
+ default:
+ TC3_LOG(ERROR) << "Unsupported type: " << field_type;
+ lua_error(state_);
+ return 0;
+ }
+ return 1;
+}
+
+int LuaEnvironment::ReadFlatbuffer(const int index,
+ ReflectiveFlatbuffer* buffer) const {
+ if (buffer == nullptr) {
+ TC3_LOG(ERROR) << "Called ReadFlatbuffer with null buffer: " << index;
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+ if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected table, got: "
+ << lua_type(state_, /*idx=*/kIndexStackTop);
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+
+ lua_pushnil(state_);
+ while (Next(index - 1)) {
+ const StringPiece key = ReadString(/*index=*/index - 1);
+ const reflection::Field* field = buffer->GetFieldOrNull(key);
+ if (field == nullptr) {
+ TC3_LOG(ERROR) << "Unknown field: " << key;
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+ switch (field->type()->base_type()) {
+ case reflection::Obj:
+ ReadFlatbuffer(/*index=*/kIndexStackTop, buffer->Mutable(field));
+ break;
+ case reflection::Bool:
+ buffer->Set(field, Read<bool>(/*index=*/kIndexStackTop));
+ break;
+ case reflection::Byte:
+ buffer->Set(field, Read<int8>(/*index=*/kIndexStackTop));
+ break;
+ case reflection::UByte:
+ buffer->Set(field, Read<uint8>(/*index=*/kIndexStackTop));
+ break;
+ case reflection::Int:
+ buffer->Set(field, Read<int32>(/*index=*/kIndexStackTop));
+ break;
+ case reflection::UInt:
+ buffer->Set(field, Read<uint32>(/*index=*/kIndexStackTop));
+ break;
+ case reflection::Long:
+ buffer->Set(field, Read<int64>(/*index=*/kIndexStackTop));
+ break;
+ case reflection::ULong:
+ buffer->Set(field, Read<uint64>(/*index=*/kIndexStackTop));
+ break;
+ case reflection::Float:
+ buffer->Set(field, Read<float>(/*index=*/kIndexStackTop));
+ break;
+ case reflection::Double:
+ buffer->Set(field, Read<double>(/*index=*/kIndexStackTop));
+ break;
+ case reflection::String: {
+ buffer->Set(field, ReadString(/*index=*/kIndexStackTop));
+ break;
+ }
+ case reflection::Vector: {
+ // Read repeated field.
+ switch (field->type()->element()) {
+ case reflection::Bool:
+ ReadRepeatedField<bool>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
+ break;
+ case reflection::Byte:
+ ReadRepeatedField<int8>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
+ break;
+ case reflection::UByte:
+ ReadRepeatedField<uint8>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
+ break;
+ case reflection::Int:
+ ReadRepeatedField<int32>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
+ break;
+ case reflection::UInt:
+ ReadRepeatedField<uint32>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
+ break;
+ case reflection::Long:
+ ReadRepeatedField<int64>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
+ break;
+ case reflection::ULong:
+ ReadRepeatedField<uint64>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
+ break;
+ case reflection::Float:
+ ReadRepeatedField<float>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
+ break;
+ case reflection::Double:
+ ReadRepeatedField<double>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
+ break;
+ case reflection::String:
+ ReadRepeatedField<std::string>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
+ break;
+ case reflection::Obj:
+ ReadRepeatedField<ReflectiveFlatbuffer>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
+ break;
+ default:
+ TC3_LOG(ERROR) << "Unsupported repeated field type: "
+ << field->type()->element();
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+ break;
+ }
+ default:
+ TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type();
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+ lua_pop(state_, 1);
+ }
+ return LUA_OK;
+}
+
+void LuaEnvironment::LoadDefaultLibraries() {
+ for (const luaL_Reg* lib = defaultlibs; lib->func; lib++) {
+ luaL_requiref(state_, lib->name, lib->func, 1);
+ lua_pop(state_, 1); // Remove lib.
+ }
+}
+
+StringPiece LuaEnvironment::ReadString(const int index) const {
+ size_t length = 0;
+ const char* data = lua_tolstring(state_, index, &length);
+ return StringPiece(data, length);
+}
+
+void LuaEnvironment::PushString(const StringPiece str) const {
+ lua_pushlstring(state_, str.data(), str.size());
+}
+
+bool LuaEnvironment::Compile(StringPiece snippet, std::string* bytecode) const {
+ if (luaL_loadbuffer(state_, snippet.data(), snippet.size(),
+ /*name=*/nullptr) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not compile lua snippet: "
+ << ReadString(/*index=*/kIndexStackTop);
+ lua_pop(state_, 1);
+ return false;
+ }
+ if (lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not dump compiled lua snippet.";
+ lua_pop(state_, 1);
+ return false;
+ }
+ lua_pop(state_, 1);
+ return true;
+}
+
+void LuaEnvironment::PushAnnotation(
+ const ClassificationResult& classification,
+ const reflection::Schema* entity_data_schema) const {
+ if (entity_data_schema == nullptr ||
+ classification.serialized_entity_data.empty()) {
+ // Empty table.
+ lua_newtable(state_);
+ } else {
+ PushFlatbuffer(entity_data_schema,
+ flatbuffers::GetRoot<flatbuffers::Table>(
+ classification.serialized_entity_data.data()));
+ }
+ Push(classification.datetime_parse_result.time_ms_utc);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTimeUsecKey);
+ Push(classification.datetime_parse_result.granularity);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kGranularityKey);
+ Push(classification.collection);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kCollectionKey);
+ Push(classification.score);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kScoreKey);
+ Push(classification.serialized_entity_data);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSerializedEntity);
+}
+
+void LuaEnvironment::PushAnnotation(
+ const ClassificationResult& classification, StringPiece text,
+ const reflection::Schema* entity_data_schema) const {
+ PushAnnotation(classification, entity_data_schema);
+ Push(text);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTextKey);
+}
+
+void LuaEnvironment::PushAnnotation(
+ const ActionSuggestionAnnotation& annotation,
+ const reflection::Schema* entity_data_schema) const {
+ PushAnnotation(annotation.entity, annotation.span.text, entity_data_schema);
+ PushString(annotation.name);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kNameKey);
+ {
+ lua_newtable(state_);
+ Push(annotation.span.message_index);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kMessageKey);
+ Push(annotation.span.span.first);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kBeginKey);
+ Push(annotation.span.span.second);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kEndKey);
+ }
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSpanKey);
+}
+
+void LuaEnvironment::PushAnnotatedSpan(
+ const AnnotatedSpan& annotated_span,
+ const reflection::Schema* entity_data_schema) const {
+ lua_newtable(state_);
+ {
+ lua_newtable(state_);
+ Push(annotated_span.span.first);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kBeginKey);
+ Push(annotated_span.span.second);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kEndKey);
+ }
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSpanKey);
+ PushAnnotations(&annotated_span.classification, entity_data_schema);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kClassificationKey);
+}
+
+void LuaEnvironment::PushAnnotatedSpans(
+ const std::vector<AnnotatedSpan>* annotated_spans,
+ const reflection::Schema* entity_data_schema) const {
+ PushIterator(annotated_spans ? annotated_spans->size() : 0,
+ [this, annotated_spans, entity_data_schema](const int64 index) {
+ PushAnnotatedSpan(annotated_spans->at(index),
+ entity_data_schema);
+ return 1;
+ });
+}
+
+MessageTextSpan LuaEnvironment::ReadSpan() const {
+ MessageTextSpan span;
+ lua_pushnil(state_);
+ while (Next(/*index=*/kIndexStackTop - 1)) {
+ const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
+ if (key.Equals(kMessageKey)) {
+ span.message_index = Read<int>(/*index=*/kIndexStackTop);
+ } else if (key.Equals(kBeginKey)) {
+ span.span.first = Read<int>(/*index=*/kIndexStackTop);
+ } else if (key.Equals(kEndKey)) {
+ span.span.second = Read<int>(/*index=*/kIndexStackTop);
+ } else if (key.Equals(kTextKey)) {
+ span.text = Read<std::string>(/*index=*/kIndexStackTop);
+ } else {
+ TC3_LOG(INFO) << "Unknown span field: " << key;
+ }
+ lua_pop(state_, 1);
+ }
+ return span;
+}
+
+int LuaEnvironment::ReadAnnotations(
+ const reflection::Schema* entity_data_schema,
+ std::vector<ActionSuggestionAnnotation>* annotations) const {
+ if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected annotations table, got: "
+ << lua_type(state_, /*idx=*/kIndexStackTop);
+ lua_pop(state_, 1);
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+
+ // Read actions.
+ lua_pushnil(state_);
+ while (Next(/*index=*/kIndexStackTop - 1)) {
+ if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected annotation table, got: "
+ << lua_type(state_, /*idx=*/kIndexStackTop);
+ lua_pop(state_, 1);
+ continue;
+ }
+ annotations->push_back(ReadAnnotation(entity_data_schema));
+ lua_pop(state_, 1);
+ }
+ return LUA_OK;
+}
+
+ActionSuggestionAnnotation LuaEnvironment::ReadAnnotation(
+ const reflection::Schema* entity_data_schema) const {
+ ActionSuggestionAnnotation annotation;
+ lua_pushnil(state_);
+ while (Next(/*index=*/kIndexStackTop - 1)) {
+ const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
+ if (key.Equals(kNameKey)) {
+ annotation.name = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals(kSpanKey)) {
+ annotation.span = ReadSpan();
+ } else if (key.Equals(kEntityKey)) {
+ annotation.entity = ReadClassificationResult(entity_data_schema);
+ } else {
+ TC3_LOG(ERROR) << "Unknown annotation field: " << key;
+ }
+ lua_pop(state_, 1);
+ }
+ return annotation;
+}
+
+ClassificationResult LuaEnvironment::ReadClassificationResult(
+ const reflection::Schema* entity_data_schema) const {
+ ClassificationResult classification;
+ lua_pushnil(state_);
+ while (Next(/*index=*/kIndexStackTop - 1)) {
+ const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
+ if (key.Equals(kCollectionKey)) {
+ classification.collection = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals(kScoreKey)) {
+ classification.score = Read<float>(/*index=*/kIndexStackTop);
+ } else if (key.Equals(kTimeUsecKey)) {
+ classification.datetime_parse_result.time_ms_utc =
+ Read<int64>(/*index=*/kIndexStackTop);
+ } else if (key.Equals(kGranularityKey)) {
+ classification.datetime_parse_result.granularity =
+ static_cast<DatetimeGranularity>(
+ lua_tonumber(state_, /*idx=*/kIndexStackTop));
+ } else if (key.Equals(kSerializedEntity)) {
+ classification.serialized_entity_data =
+ Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals(kEntityKey)) {
+ auto buffer = ReflectiveFlatbufferBuilder(entity_data_schema).NewRoot();
+ ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
+ classification.serialized_entity_data = buffer->Serialize();
+ } else {
+ TC3_LOG(INFO) << "Unknown classification result field: " << key;
+ }
+ lua_pop(state_, 1);
+ }
+ return classification;
+}
+
+void LuaEnvironment::PushAction(
+ const ActionSuggestion& action,
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema) const {
+ if (actions_entity_data_schema == nullptr ||
+ action.serialized_entity_data.empty()) {
+ // Empty table.
+ lua_newtable(state_);
+ } else {
+ PushFlatbuffer(actions_entity_data_schema,
+ flatbuffers::GetRoot<flatbuffers::Table>(
+ action.serialized_entity_data.data()));
+ }
+ PushString(action.type);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTypeKey);
+ PushString(action.response_text);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kResponseTextKey);
+ Push(action.score);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kScoreKey);
+ Push(action.priority_score);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kPriorityScoreKey);
+ PushAnnotations(&action.annotations, annotations_entity_data_schema);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kAnnotationKey);
+}
+
+void LuaEnvironment::PushActions(
+ const std::vector<ActionSuggestion>* actions,
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema) const {
+ PushIterator(actions ? actions->size() : 0,
+ [this, actions, actions_entity_data_schema,
+ annotations_entity_data_schema](const int64 index) {
+ PushAction(actions->at(index), actions_entity_data_schema,
+ annotations_entity_data_schema);
+ return 1;
+ });
+}
+
+ActionSuggestion LuaEnvironment::ReadAction(
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema) const {
+ ActionSuggestion action;
+ lua_pushnil(state_);
+ while (Next(/*index=*/kIndexStackTop - 1)) {
+ const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
+ if (key.Equals(kResponseTextKey)) {
+ action.response_text = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals(kTypeKey)) {
+ action.type = Read<std::string>(/*index=*/kIndexStackTop);
+ } else if (key.Equals(kScoreKey)) {
+ action.score = Read<float>(/*index=*/kIndexStackTop);
+ } else if (key.Equals(kPriorityScoreKey)) {
+ action.priority_score = Read<float>(/*index=*/kIndexStackTop);
+ } else if (key.Equals(kAnnotationKey)) {
+ ReadAnnotations(actions_entity_data_schema, &action.annotations);
+ } else if (key.Equals(kEntityKey)) {
+ auto buffer =
+ ReflectiveFlatbufferBuilder(actions_entity_data_schema).NewRoot();
+ ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
+ action.serialized_entity_data = buffer->Serialize();
+ } else {
+ TC3_LOG(INFO) << "Unknown action field: " << key;
+ }
+ lua_pop(state_, 1);
+ }
+ return action;
+}
+
+int LuaEnvironment::ReadActions(
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema,
+ std::vector<ActionSuggestion>* actions) const {
+ // Read actions.
+ lua_pushnil(state_);
+ while (Next(/*index=*/kIndexStackTop - 1)) {
+ if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected action table, got: "
+ << lua_type(state_, /*idx=*/kIndexStackTop);
+ lua_pop(state_, 1);
+ continue;
+ }
+ actions->push_back(
+ ReadAction(actions_entity_data_schema, annotations_entity_data_schema));
+ lua_pop(state_, /*n=*/1);
+ }
+ lua_pop(state_, /*n=*/1);
+
+ return LUA_OK;
+}
+
+void LuaEnvironment::PushConversation(
+ const std::vector<ConversationMessage>* conversation,
+ const reflection::Schema* annotations_entity_data_schema) const {
+ PushIterator(
+ conversation ? conversation->size() : 0,
+ [this, conversation, annotations_entity_data_schema](const int64 index) {
+ const ConversationMessage& message = conversation->at(index);
+ lua_newtable(state_);
+ Push(message.user_id);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "user_id");
+ Push(message.text);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "text");
+ Push(message.reference_time_ms_utc);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "time_ms_utc");
+ Push(message.reference_timezone);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "timezone");
+ PushAnnotatedSpans(&message.annotations,
+ annotations_entity_data_schema);
+ lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "annotation");
+ return 1;
+ });
+}
+
+bool Compile(StringPiece snippet, std::string* bytecode) {
+ return LuaEnvironment().Compile(snippet, bytecode);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/lua-utils.h b/native/utils/lua-utils.h
new file mode 100644
index 0000000..b01471a
--- /dev/null
+++ b/native/utils/lua-utils.h
@@ -0,0 +1,623 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
+
+#include <vector>
+
+#include "actions/types.h"
+#include "annotator/types.h"
+#include "utils/flatbuffers.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/variant.h"
+#include "flatbuffers/reflection_generated.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lua.h"
+#include "lualib.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+
+static constexpr const char kLengthKey[] = "__len";
+static constexpr const char kPairsKey[] = "__pairs";
+static constexpr const char kIndexKey[] = "__index";
+static constexpr const char kGcKey[] = "__gc";
+static constexpr const char kNextKey[] = "__next";
+
+static constexpr const int kIndexStackTop = -1;
+
+// Casts to the lua user data type.
+template <typename T>
+void* AsUserData(const T* value) {
+ return static_cast<void*>(const_cast<T*>(value));
+}
+template <typename T>
+void* AsUserData(const T value) {
+ return reinterpret_cast<void*>(value);
+}
+
+// Retrieves up-values.
+template <typename T>
+T FromUpValue(const int index, lua_State* state) {
+ return static_cast<T>(lua_touserdata(state, lua_upvalueindex(index)));
+}
+
+class LuaEnvironment {
+ public:
+ virtual ~LuaEnvironment();
+ LuaEnvironment();
+
+ // Compile a lua snippet into binary bytecode.
+ // NOTE: The compiled bytecode might not be compatible across Lua versions
+ // and platforms.
+ bool Compile(StringPiece snippet, std::string* bytecode) const;
+
+ // Loads default libraries.
+ void LoadDefaultLibraries();
+
+ // Provides a callback to Lua.
+ template <typename T>
+ void PushFunction(int (T::*handler)()) {
+ PushFunction(std::bind(handler, static_cast<T*>(this)));
+ }
+
+ template <typename F>
+ void PushFunction(const F& func) const {
+ // Copy closure to the lua stack.
+ new (lua_newuserdata(state_, sizeof(func))) F(func);
+
+ // Register garbage collection callback.
+ lua_newtable(state_);
+ lua_pushcfunction(state_, &ReleaseFunction<F>);
+ lua_setfield(state_, -2, kGcKey);
+ lua_setmetatable(state_, -2);
+
+ // Push dispatch.
+ lua_pushcclosure(state_, &CallFunction<F>, 1);
+ }
+
+ // Sets up a named table that calls back whenever a member is accessed.
+ // This allows to lazily provide required information to the script.
+ template <typename T>
+ void PushLazyObject(int (T::*handler)()) {
+ PushLazyObject(std::bind(handler, static_cast<T*>(this)));
+ }
+
+ template <typename F>
+ void PushLazyObject(const F& func) const {
+ lua_newtable(state_);
+ lua_newtable(state_);
+ PushFunction(func);
+ lua_setfield(state_, -2, kIndexKey);
+ lua_setmetatable(state_, -2);
+ }
+
+ void Push(const int64 value) const { lua_pushinteger(state_, value); }
+ void Push(const uint64 value) const { lua_pushinteger(state_, value); }
+ void Push(const int32 value) const { lua_pushinteger(state_, value); }
+ void Push(const uint32 value) const { lua_pushinteger(state_, value); }
+ void Push(const int16 value) const { lua_pushinteger(state_, value); }
+ void Push(const uint16 value) const { lua_pushinteger(state_, value); }
+ void Push(const int8 value) const { lua_pushinteger(state_, value); }
+ void Push(const uint8 value) const { lua_pushinteger(state_, value); }
+ void Push(const float value) const { lua_pushnumber(state_, value); }
+ void Push(const double value) const { lua_pushnumber(state_, value); }
+ void Push(const bool value) const { lua_pushboolean(state_, value); }
+ void Push(const StringPiece value) const { PushString(value); }
+ void Push(const flatbuffers::String* value) const {
+ if (value == nullptr) {
+ PushString("");
+ } else {
+ PushString(StringPiece(value->c_str(), value->size()));
+ }
+ }
+
+ template <typename T>
+ T Read(const int index = -1) const;
+
+ template <>
+ int64 Read<int64>(const int index) const {
+ return static_cast<int64>(lua_tonumber(state_, /*idx=*/index));
+ }
+
+ template <>
+ uint64 Read<uint64>(const int index) const {
+ return static_cast<uint64>(lua_tonumber(state_, /*idx=*/index));
+ }
+
+ template <>
+ int32 Read<int32>(const int index) const {
+ return static_cast<int32>(lua_tonumber(state_, /*idx=*/index));
+ }
+
+ template <>
+ uint32 Read<uint32>(const int index) const {
+ return static_cast<uint32>(lua_tonumber(state_, /*idx=*/index));
+ }
+
+ template <>
+ int16 Read<int16>(const int index) const {
+ return static_cast<int16>(lua_tonumber(state_, /*idx=*/index));
+ }
+
+ template <>
+ uint16 Read<uint16>(const int index) const {
+ return static_cast<uint16>(lua_tonumber(state_, /*idx=*/index));
+ }
+
+ template <>
+ int8 Read<int8>(const int index) const {
+ return static_cast<int8>(lua_tonumber(state_, /*idx=*/index));
+ }
+
+ template <>
+ uint8 Read<uint8>(const int index) const {
+ return static_cast<uint8>(lua_tonumber(state_, /*idx=*/index));
+ }
+
+ template <>
+ float Read<float>(const int index) const {
+ return static_cast<float>(lua_tonumber(state_, /*idx=*/index));
+ }
+
+ template <>
+ double Read<double>(const int index) const {
+ return static_cast<double>(lua_tonumber(state_, /*idx=*/index));
+ }
+
+ template <>
+ bool Read<bool>(const int index) const {
+ return lua_toboolean(state_, /*idx=*/index);
+ }
+
+ template <>
+ StringPiece Read<StringPiece>(const int index) const {
+ return ReadString(index);
+ }
+
+ template <>
+ std::string Read<std::string>(const int index) const {
+ return ReadString(index).ToString();
+ }
+
+ // Reads a string from the stack.
+ StringPiece ReadString(int index) const;
+
+ // Pushes a string to the stack.
+ void PushString(const StringPiece str) const;
+
+ // Pushes a flatbuffer to the stack.
+ void PushFlatbuffer(const reflection::Schema* schema,
+ const flatbuffers::Table* table) const {
+ PushFlatbuffer(schema, schema->root_table(), table);
+ }
+
+ // Reads a flatbuffer from the stack.
+ int ReadFlatbuffer(int index, ReflectiveFlatbuffer* buffer) const;
+
+ // Pushes an iterator.
+ template <typename ItemCallback, typename KeyCallback>
+ void PushIterator(const int length, const ItemCallback& item_callback,
+ const KeyCallback& key_callback) const {
+ lua_newtable(state_);
+ CreateIteratorMetatable(length, item_callback);
+ PushFunction([this, length, item_callback, key_callback]() {
+ return Iterator::Dispatch(this, length, item_callback, key_callback);
+ });
+ lua_setfield(state_, -2, kIndexKey);
+ lua_setmetatable(state_, -2);
+ }
+
+ template <typename ItemCallback>
+ void PushIterator(const int length, const ItemCallback& item_callback) const {
+ lua_newtable(state_);
+ CreateIteratorMetatable(length, item_callback);
+ PushFunction([this, length, item_callback]() {
+ return Iterator::Dispatch(this, length, item_callback);
+ });
+ lua_setfield(state_, -2, kIndexKey);
+ lua_setmetatable(state_, -2);
+ }
+
+ template <typename ItemCallback>
+ void CreateIteratorMetatable(const int length,
+ const ItemCallback& item_callback) const {
+ lua_newtable(state_);
+ PushFunction([this, length]() { return Iterator::Length(this, length); });
+ lua_setfield(state_, -2, kLengthKey);
+ PushFunction([this, length, item_callback]() {
+ return Iterator::IterItems(this, length, item_callback);
+ });
+ lua_setfield(state_, -2, kPairsKey);
+ PushFunction([this, length, item_callback]() {
+ return Iterator::Next(this, length, item_callback);
+ });
+ lua_setfield(state_, -2, kNextKey);
+ }
+
+ template <typename T>
+ void PushVectorIterator(const std::vector<T>* items) const {
+ PushIterator(items ? items->size() : 0, [this, items](const int64 pos) {
+ this->Push(items->at(pos));
+ return 1;
+ });
+ }
+
+ template <typename T>
+ void PushVector(const std::vector<T>& items) const {
+ lua_newtable(state_);
+ for (int i = 0; i < items.size(); i++) {
+ // Key: index, 1-based.
+ Push(i + 1);
+
+ // Value.
+ Push(items[i]);
+ lua_settable(state_, /*idx=*/-3);
+ }
+ }
+
+ void PushEmptyVector() const { lua_newtable(state_); }
+
+ template <typename T>
+ std::vector<T> ReadVector(const int index = -1) const {
+ std::vector<T> result;
+ if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) {
+ TC3_LOG(ERROR) << "Expected a table, got: "
+ << lua_type(state_, /*idx=*/kIndexStackTop);
+ lua_pop(state_, 1);
+ return {};
+ }
+ lua_pushnil(state_);
+ while (Next(index - 1)) {
+ result.push_back(Read<T>(/*index=*/kIndexStackTop));
+ lua_pop(state_, 1);
+ }
+ return result;
+ }
+
+ // Runs a closure in protected mode.
+ // `func`: closure to run in protected mode.
+ // `num_args`: number of arguments from the lua stack to process.
+ // `num_results`: number of result values pushed on the stack.
+ template <typename F>
+ int RunProtected(const F& func, const int num_args = 0,
+ const int num_results = 0) const {
+ PushFunction(func);
+ // Put the closure before the arguments on the stack.
+ if (num_args > 0) {
+ lua_insert(state_, -(1 + num_args));
+ }
+ return lua_pcall(state_, num_args, num_results, /*errfunc=*/0);
+ }
+
+ // Auxiliary methods to handle model results.
+ // Provides an annotation to lua.
+ void PushAnnotation(const ClassificationResult& classification,
+ const reflection::Schema* entity_data_schema) const;
+ void PushAnnotation(const ClassificationResult& classification,
+ StringPiece text,
+ const reflection::Schema* entity_data_schema) const;
+ void PushAnnotation(const ActionSuggestionAnnotation& annotation,
+ const reflection::Schema* entity_data_schema) const;
+
+ template <typename Annotation>
+ void PushAnnotations(const std::vector<Annotation>* annotations,
+ const reflection::Schema* entity_data_schema) const {
+ PushIterator(
+ annotations ? annotations->size() : 0,
+ [this, annotations, entity_data_schema](const int64 index) {
+ PushAnnotation(annotations->at(index), entity_data_schema);
+ return 1;
+ },
+ [this, annotations, entity_data_schema](StringPiece name) {
+ if (const Annotation* annotation =
+ GetAnnotationByName(*annotations, name)) {
+ PushAnnotation(*annotation, entity_data_schema);
+ return 1;
+ } else {
+ return 0;
+ }
+ });
+ }
+
+ // Pushes a span to the lua stack.
+ void PushAnnotatedSpan(const AnnotatedSpan& annotated_span,
+ const reflection::Schema* entity_data_schema) const;
+ void PushAnnotatedSpans(const std::vector<AnnotatedSpan>* annotated_spans,
+ const reflection::Schema* entity_data_schema) const;
+
+ // Reads a message text span from lua.
+ MessageTextSpan ReadSpan() const;
+
+ ActionSuggestionAnnotation ReadAnnotation(
+ const reflection::Schema* entity_data_schema) const;
+ int ReadAnnotations(
+ const reflection::Schema* entity_data_schema,
+ std::vector<ActionSuggestionAnnotation>* annotations) const;
+ ClassificationResult ReadClassificationResult(
+ const reflection::Schema* entity_data_schema) const;
+
+ // Provides an action to lua.
+ void PushAction(
+ const ActionSuggestion& action,
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema) const;
+
+ void PushActions(
+ const std::vector<ActionSuggestion>* actions,
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema) const;
+
+ ActionSuggestion ReadAction(
+ const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema) const;
+
+ int ReadActions(const reflection::Schema* actions_entity_data_schema,
+ const reflection::Schema* annotations_entity_data_schema,
+ std::vector<ActionSuggestion>* actions) const;
+
+ // Conversation message iterator.
+ void PushConversation(
+ const std::vector<ConversationMessage>* conversation,
+ const reflection::Schema* annotations_entity_data_schema) const;
+
+ lua_State* state() const { return state_; }
+
+ protected:
+ // Wrapper for handling iteration over containers.
+ class Iterator {
+ public:
+ // Starts a new key-value pair iterator.
+ template <typename ItemCallback>
+ static int IterItems(const LuaEnvironment* env, const int length,
+ const ItemCallback& callback) {
+ env->PushFunction([env, callback, length, pos = 0]() mutable {
+ if (pos >= length) {
+ lua_pushnil(env->state());
+ return 1;
+ }
+
+ // Push key.
+ lua_pushinteger(env->state(), pos + 1);
+
+ // Push item.
+ return 1 + callback(pos++);
+ });
+ return 1; // Num. results.
+ }
+
+ // Gets the next element.
+ template <typename ItemCallback>
+ static int Next(const LuaEnvironment* env, const int length,
+ const ItemCallback& item_callback) {
+ int64 pos = lua_isnil(env->state(), /*idx=*/kIndexStackTop)
+ ? 0
+ : env->Read<int64>(/*index=*/kIndexStackTop);
+ if (pos < length) {
+ // Push next key.
+ lua_pushinteger(env->state(), pos + 1);
+
+ // Push item.
+ return 1 + item_callback(pos);
+ } else {
+ lua_pushnil(env->state());
+ return 1;
+ }
+ }
+
+ // Returns the length of the container the iterator processes.
+ static int Length(const LuaEnvironment* env, const int length) {
+ lua_pushinteger(env->state(), length);
+ return 1; // Num. results.
+ }
+
+ // Handles item queries to the iterator.
+ // Elements of the container can either be queried by name or index.
+ // Dispatch will check how an element is accessed and
+ // calls `key_callback` for access by name and `item_callback` for access by
+ // index.
+ template <typename ItemCallback, typename KeyCallback>
+ static int Dispatch(const LuaEnvironment* env, const int length,
+ const ItemCallback& item_callback,
+ const KeyCallback& key_callback) {
+ switch (lua_type(env->state(), kIndexStackTop)) {
+ case LUA_TNUMBER: {
+ // Lua is one based, so adjust the index here.
+ const int64 index = env->Read<int64>(/*index=*/kIndexStackTop) - 1;
+ if (index < 0 || index >= length) {
+ TC3_LOG(ERROR) << "Invalid index: " << index;
+ lua_error(env->state());
+ return 0;
+ }
+ return item_callback(index);
+ }
+ case LUA_TSTRING: {
+ return key_callback(env->ReadString(kIndexStackTop));
+ }
+ default:
+ TC3_LOG(ERROR) << "Unexpected access type: "
+ << lua_type(env->state(), kIndexStackTop);
+ lua_error(env->state());
+ return 0;
+ }
+ }
+
+ template <typename ItemCallback>
+ static int Dispatch(const LuaEnvironment* env, const int length,
+ const ItemCallback& item_callback) {
+ switch (lua_type(env->state(), kIndexStackTop)) {
+ case LUA_TNUMBER: {
+ // Lua is one based, so adjust the index here.
+ const int64 index = env->Read<int64>(/*index=*/kIndexStackTop) - 1;
+ if (index < 0 || index >= length) {
+ TC3_LOG(ERROR) << "Invalid index: " << index;
+ lua_error(env->state());
+ return 0;
+ }
+ return item_callback(index);
+ }
+ default:
+ TC3_LOG(ERROR) << "Unexpected access type: "
+ << lua_type(env->state(), kIndexStackTop);
+ lua_error(env->state());
+ return 0;
+ }
+ }
+ };
+
+ // Calls the destructor on a previously pushed userdata object.
+ template <typename T>
+ static int ReleaseFunction(lua_State* state) {
+ static_cast<T*>(lua_touserdata(state, 1))->~T();
+ return 0;
+ }
+
+ template <typename T>
+ static int CallFunction(lua_State* state) {
+ return (*static_cast<T*>(lua_touserdata(state, lua_upvalueindex(1))))();
+ }
+
+ // Auxiliary methods to expose (reflective) flatbuffer based data to Lua.
+ void PushFlatbuffer(const reflection::Schema* schema,
+ const reflection::Object* type,
+ const flatbuffers::Table* table) const;
+ int GetField(const reflection::Schema* schema, const reflection::Object* type,
+ const flatbuffers::Table* table) const;
+
+ // Reads a repeated field from lua.
+ template <typename T>
+ void ReadRepeatedField(const int index, RepeatedField* result) const {
+ for (const auto& element : ReadVector<T>(index)) {
+ result->Add(element);
+ }
+ }
+
+ template <>
+ void ReadRepeatedField<ReflectiveFlatbuffer>(const int index,
+ RepeatedField* result) const {
+ lua_pushnil(state_);
+ while (Next(index - 1)) {
+ ReadFlatbuffer(index, result->Add());
+ lua_pop(state_, 1);
+ }
+ }
+
+ // Pushes a repeated field to the lua stack.
+ template <typename T>
+ void PushRepeatedField(const flatbuffers::Vector<T>* items) const {
+ PushIterator(items ? items->size() : 0, [this, items](const int64 pos) {
+ Push(items->Get(pos));
+ return 1; // Num. results.
+ });
+ }
+
+ void PushRepeatedFlatbufferField(
+ const reflection::Schema* schema, const reflection::Object* type,
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>>* items)
+ const {
+ PushIterator(items ? items->size() : 0,
+ [this, schema, type, items](const int64 pos) {
+ PushFlatbuffer(schema, type, items->Get(pos));
+ return 1; // Num. results.
+ });
+ }
+
+ // Overloads Lua next function to use __next key on the metatable.
+ // This allows us to treat lua objects and lazy objects provided by our
+ // callbacks uniformly.
+ int Next(int index) const {
+ // Check whether the (meta)table of this object has an associated "__next"
+ // entry. This means, we registered our own callback. So we explicitly call
+ // that.
+ if (luaL_getmetafield(state_, index, kNextKey)) {
+ // Callback is now on top of the stack, so adjust relative indices by 1.
+ if (index < 0) {
+ index--;
+ }
+
+ // Copy the reference to the table.
+ lua_pushvalue(state_, index);
+
+ // Move the key to top to have it as second argument for the callback.
+ // Copy the key to the top.
+ lua_pushvalue(state_, -3);
+
+ // Remove the copy of the key.
+ lua_remove(state_, -4);
+
+ // Call the callback with (key and table as arguments).
+ lua_pcall(state_, /*nargs=*/2 /* table, key */,
+ /*nresults=*/2 /* key, item */, 0);
+
+ // Next returned nil, it's the end.
+ if (lua_isnil(state_, kIndexStackTop)) {
+ // Remove nil value.
+ // Results will be padded to `nresults` specified above, so we need
+ // to remove two elements here.
+ lua_pop(state_, 2);
+ return 0;
+ }
+
+ return 2; // Num. results.
+ } else if (lua_istable(state_, index)) {
+ return lua_next(state_, index);
+ }
+
+ // Remove the key.
+ lua_pop(state_, 1);
+ return 0;
+ }
+
+ static const ClassificationResult* GetAnnotationByName(
+ const std::vector<ClassificationResult>& annotations, StringPiece name) {
+ // Lookup annotation by collection.
+ for (const ClassificationResult& annotation : annotations) {
+ if (name.Equals(annotation.collection)) {
+ return &annotation;
+ }
+ }
+ TC3_LOG(ERROR) << "No annotation with collection: " << name << " found.";
+ return nullptr;
+ }
+
+ static const ActionSuggestionAnnotation* GetAnnotationByName(
+ const std::vector<ActionSuggestionAnnotation>& annotations,
+ StringPiece name) {
+ // Lookup annotation by name.
+ for (const ActionSuggestionAnnotation& annotation : annotations) {
+ if (name.Equals(annotation.name)) {
+ return &annotation;
+ }
+ }
+ TC3_LOG(ERROR) << "No annotation with name: " << name << " found.";
+ return nullptr;
+ }
+
+ lua_State* state_;
+}; // class LuaEnvironment
+
+bool Compile(StringPiece snippet, std::string* bytecode);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
diff --git a/native/utils/lua-utils_test.cc b/native/utils/lua-utils_test.cc
new file mode 100644
index 0000000..8c9f8de
--- /dev/null
+++ b/native/utils/lua-utils_test.cc
@@ -0,0 +1,333 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/lua-utils.h"
+
+#include <string>
+
+#include "utils/flatbuffers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAre;
+using testing::Eq;
+using testing::FloatEq;
+
+std::string TestFlatbufferSchema() {
+ // Creates a test schema for flatbuffer passing tests.
+ // Cannot use the object oriented API here as that is not available for the
+ // reflection schema.
+ flatbuffers::FlatBufferBuilder schema_builder;
+ std::vector<flatbuffers::Offset<reflection::Field>> fields = {
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("float_field"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Float),
+ /*id=*/0,
+ /*offset=*/4),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("nested_field"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Obj,
+ /*element=*/reflection::None,
+ /*index=*/0 /* self */),
+ /*id=*/1,
+ /*offset=*/6),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("repeated_nested_field"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Vector,
+ /*element=*/reflection::Obj,
+ /*index=*/0 /* self */),
+ /*id=*/2,
+ /*offset=*/8),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("repeated_string_field"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::Vector,
+ /*element=*/reflection::String),
+ /*id=*/3,
+ /*offset=*/10),
+ reflection::CreateField(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("string_field"),
+ /*type=*/
+ reflection::CreateType(schema_builder,
+ /*base_type=*/reflection::String),
+ /*id=*/4,
+ /*offset=*/12)};
+
+ std::vector<flatbuffers::Offset<reflection::Enum>> enums;
+ std::vector<flatbuffers::Offset<reflection::Object>> objects = {
+ reflection::CreateObject(
+ schema_builder,
+ /*name=*/schema_builder.CreateString("TestData"),
+ /*fields=*/
+ schema_builder.CreateVectorOfSortedTables(&fields))};
+ schema_builder.Finish(reflection::CreateSchema(
+ schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
+ schema_builder.CreateVectorOfSortedTables(&enums),
+ /*(unused) file_ident=*/0,
+ /*(unused) file_ext=*/0,
+ /*root_table*/ objects[0]));
+ return std::string(
+ reinterpret_cast<const char*>(schema_builder.GetBufferPointer()),
+ schema_builder.GetSize());
+}
+
+class LuaUtilsTest : public testing::Test, protected LuaEnvironment {
+ protected:
+ LuaUtilsTest()
+ : serialized_flatbuffer_schema_(TestFlatbufferSchema()),
+ schema_(flatbuffers::GetRoot<reflection::Schema>(
+ serialized_flatbuffer_schema_.data())),
+ flatbuffer_builder_(schema_) {
+ EXPECT_THAT(RunProtected([this] {
+ LoadDefaultLibraries();
+ return LUA_OK;
+ }),
+ Eq(LUA_OK));
+ }
+
+ void RunScript(StringPiece script) {
+ EXPECT_THAT(luaL_loadbuffer(state_, script.data(), script.size(),
+ /*name=*/nullptr),
+ Eq(LUA_OK));
+ EXPECT_THAT(
+ lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0),
+ Eq(LUA_OK));
+ }
+
+ const std::string serialized_flatbuffer_schema_;
+ const reflection::Schema* schema_;
+ ReflectiveFlatbufferBuilder flatbuffer_builder_;
+};
+
+TEST_F(LuaUtilsTest, HandlesVectors) {
+ {
+ PushVector(std::vector<int64>{1, 2, 3, 4, 5});
+ EXPECT_THAT(ReadVector<int64>(), ElementsAre(1, 2, 3, 4, 5));
+ }
+ {
+ PushVector(std::vector<std::string>{"hello", "there"});
+ EXPECT_THAT(ReadVector<std::string>(), ElementsAre("hello", "there"));
+ }
+ {
+ PushVector(std::vector<bool>{true, true, false});
+ EXPECT_THAT(ReadVector<bool>(), ElementsAre(true, true, false));
+ }
+}
+
+TEST_F(LuaUtilsTest, HandlesVectorIterators) {
+ {
+ const std::vector<int64> elements = {1, 2, 3, 4, 5};
+ PushVectorIterator(&elements);
+ EXPECT_THAT(ReadVector<int64>(), ElementsAre(1, 2, 3, 4, 5));
+ }
+ {
+ const std::vector<std::string> elements = {"hello", "there"};
+ PushVectorIterator(&elements);
+ EXPECT_THAT(ReadVector<std::string>(), ElementsAre("hello", "there"));
+ }
+ {
+ const std::vector<bool> elements = {true, true, false};
+ PushVectorIterator(&elements);
+ EXPECT_THAT(ReadVector<bool>(), ElementsAre(true, true, false));
+ }
+}
+
+TEST_F(LuaUtilsTest, ReadsFlatbufferResults) {
+ // Setup.
+ RunScript(R"lua(
+ return {
+ float_field = 42.1,
+ string_field = "hello there",
+
+ -- Nested field.
+ nested_field = {
+ float_field = 64,
+ string_field = "hello nested",
+ },
+
+ -- Repeated fields.
+ repeated_string_field = { "a", "bold", "one" },
+ repeated_nested_field = {
+ { string_field = "a" },
+ { string_field = "b" },
+ { repeated_string_field = { "nested", "nested2" } },
+ },
+ }
+ )lua");
+
+ // Read the flatbuffer.
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ ReadFlatbuffer(/*index=*/-1, buffer.get());
+ const std::string serialized_buffer = buffer->Serialize();
+
+ // Check fields. As we do not have flatbuffer compiled generated code for the
+ // ad hoc generated test schema, we have to read by manually using field
+ // offsets.
+ const flatbuffers::Table* flatbuffer_data =
+ flatbuffers::GetRoot<flatbuffers::Table>(serialized_buffer.data());
+ EXPECT_THAT(flatbuffer_data->GetField<float>(/*field=*/4, /*defaultval=*/0),
+ FloatEq(42.1));
+ EXPECT_THAT(
+ flatbuffer_data->GetPointer<const flatbuffers::String*>(/*field=*/12)
+ ->str(),
+ "hello there");
+
+ // Read the nested field.
+ const flatbuffers::Table* nested_field =
+ flatbuffer_data->GetPointer<const flatbuffers::Table*>(/*field=*/6);
+ EXPECT_THAT(nested_field->GetField<float>(/*field=*/4, /*defaultval=*/0),
+ FloatEq(64));
+ EXPECT_THAT(
+ nested_field->GetPointer<const flatbuffers::String*>(/*field=*/12)->str(),
+ "hello nested");
+
+ // Read the repeated string field.
+ auto repeated_strings = flatbuffer_data->GetPointer<
+ flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
+ /*field=*/10);
+ EXPECT_THAT(repeated_strings->size(), Eq(3));
+ EXPECT_THAT(repeated_strings->GetAsString(0)->str(), Eq("a"));
+ EXPECT_THAT(repeated_strings->GetAsString(1)->str(), Eq("bold"));
+ EXPECT_THAT(repeated_strings->GetAsString(2)->str(), Eq("one"));
+
+ // Read the repeated nested field.
+ auto repeated_nested_fields = flatbuffer_data->GetPointer<
+ flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>>*>(
+ /*field=*/8);
+ EXPECT_THAT(repeated_nested_fields->size(), Eq(3));
+ EXPECT_THAT(repeated_nested_fields->Get(0)
+ ->GetPointer<const flatbuffers::String*>(/*field=*/12)
+ ->str(),
+ "a");
+ EXPECT_THAT(repeated_nested_fields->Get(1)
+ ->GetPointer<const flatbuffers::String*>(/*field=*/12)
+ ->str(),
+ "b");
+}
+
+TEST_F(LuaUtilsTest, HandlesSimpleFlatbufferFields) {
+ // Create test flatbuffer.
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ buffer->Set("float_field", 42.f);
+ const std::string serialized_buffer = buffer->Serialize();
+ PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
+ serialized_buffer.data()));
+ lua_setglobal(state_, "arg");
+
+ // Setup.
+ RunScript(R"lua(
+ return arg.float_field
+ )lua");
+
+ EXPECT_THAT(Read<float>(), FloatEq(42));
+}
+
+TEST_F(LuaUtilsTest, HandlesRepeatedFlatbufferFields) {
+ // Create test flatbuffer.
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ RepeatedField* repeated_field = buffer->Repeated("repeated_string_field");
+ repeated_field->Add("this");
+ repeated_field->Add("is");
+ repeated_field->Add("a");
+ repeated_field->Add("test");
+ const std::string serialized_buffer = buffer->Serialize();
+ PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
+ serialized_buffer.data()));
+ lua_setglobal(state_, "arg");
+
+ // Return flatbuffer repeated field as vector.
+ RunScript(R"lua(
+ return arg.repeated_string_field
+ )lua");
+
+ EXPECT_THAT(ReadVector<std::string>(),
+ ElementsAre("this", "is", "a", "test"));
+}
+
+TEST_F(LuaUtilsTest, HandlesRepeatedNestedFlatbufferFields) {
+ // Create test flatbuffer.
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ RepeatedField* repeated_field = buffer->Repeated("repeated_nested_field");
+ repeated_field->Add()->Set("string_field", "hello");
+ repeated_field->Add()->Set("string_field", "my");
+ ReflectiveFlatbuffer* nested = repeated_field->Add();
+ nested->Set("string_field", "old");
+ RepeatedField* nested_repeated = nested->Repeated("repeated_string_field");
+ nested_repeated->Add("friend");
+ nested_repeated->Add("how");
+ nested_repeated->Add("are");
+ repeated_field->Add()->Set("string_field", "you?");
+ const std::string serialized_buffer = buffer->Serialize();
+ PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
+ serialized_buffer.data()));
+ lua_setglobal(state_, "arg");
+
+ RunScript(R"lua(
+ result = {}
+ for _, nested in pairs(arg.repeated_nested_field) do
+ result[#result + 1] = nested.string_field
+ for _, nested_string in pairs(nested.repeated_string_field) do
+ result[#result + 1] = nested_string
+ end
+ end
+ return result
+ )lua");
+
+ EXPECT_THAT(
+ ReadVector<std::string>(),
+ ElementsAre("hello", "my", "old", "friend", "how", "are", "you?"));
+}
+
+TEST_F(LuaUtilsTest, CorrectlyReadsTwoFlatbuffersSimultaneously) {
+ // The first flatbuffer.
+ std::unique_ptr<ReflectiveFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
+ buffer->Set("string_field", "first");
+ const std::string serialized_buffer = buffer->Serialize();
+ PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
+ serialized_buffer.data()));
+ lua_setglobal(state_, "arg");
+ // The second flatbuffer.
+ std::unique_ptr<ReflectiveFlatbuffer> buffer2 = flatbuffer_builder_.NewRoot();
+ buffer2->Set("string_field", "second");
+ const std::string serialized_buffer2 = buffer2->Serialize();
+ PushFlatbuffer(schema_, flatbuffers::GetRoot<flatbuffers::Table>(
+ serialized_buffer2.data()));
+ lua_setglobal(state_, "arg2");
+
+ RunScript(R"lua(
+ return {arg.string_field, arg2.string_field}
+ )lua");
+
+ EXPECT_THAT(ReadVector<std::string>(), ElementsAre("first", "second"));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/math/fastexp.cc b/native/utils/math/fastexp.cc
similarity index 100%
rename from utils/math/fastexp.cc
rename to native/utils/math/fastexp.cc
diff --git a/native/utils/math/fastexp.h b/native/utils/math/fastexp.h
new file mode 100644
index 0000000..8128627
--- /dev/null
+++ b/native/utils/math/fastexp.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Fast approximation for exp.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_MATH_FASTEXP_H_
+#define LIBTEXTCLASSIFIER_UTILS_MATH_FASTEXP_H_
+
+#include <cassert>
+#include <cmath>
+#include <limits>
+
+#include "utils/base/casts.h"
+#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+class FastMathClass {
+ private:
+ static constexpr int kBits = 7;
+ static constexpr int kMask1 = (1 << kBits) - 1;
+ static constexpr int kMask2 = 0xFF << kBits;
+ static constexpr float kLogBase2OfE = 1.44269504088896340736f;
+
+ struct Table {
+ int32 exp1[1 << kBits];
+ };
+
+ public:
+ float VeryFastExp2(float f) const {
+ TC3_DCHECK_LE(fabs(f), 126);
+ const float g = f + (127 + (1 << (23 - kBits)));
+ const int32 x = bit_cast<int32>(g);
+ int32 ret = ((x & kMask2) << (23 - kBits))
+ | cache_.exp1[x & kMask1];
+ return bit_cast<float>(ret);
+ }
+
+ float VeryFastExp(float f) const {
+ return VeryFastExp2(f * kLogBase2OfE);
+ }
+
+ private:
+ static const Table cache_;
+};
+
+extern FastMathClass FastMathInstance;
+
+inline float VeryFastExp(float f) { return FastMathInstance.VeryFastExp(f); }
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_MATH_FASTEXP_H_
diff --git a/utils/math/softmax.cc b/native/utils/math/softmax.cc
similarity index 100%
rename from utils/math/softmax.cc
rename to native/utils/math/softmax.cc
diff --git a/utils/math/softmax.h b/native/utils/math/softmax.h
similarity index 100%
rename from utils/math/softmax.h
rename to native/utils/math/softmax.h
diff --git a/utils/memory/mmap.cc b/native/utils/memory/mmap.cc
similarity index 100%
rename from utils/memory/mmap.cc
rename to native/utils/memory/mmap.cc
diff --git a/native/utils/memory/mmap.h b/native/utils/memory/mmap.h
new file mode 100644
index 0000000..974cc02
--- /dev/null
+++ b/native/utils/memory/mmap.h
@@ -0,0 +1,141 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_MEMORY_MMAP_H_
+#define LIBTEXTCLASSIFIER_UTILS_MEMORY_MMAP_H_
+
+#include <stddef.h>
+
+#include <string>
+
+#include "utils/base/integral_types.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Handle for a memory area where a file has been mmapped.
+//
+// Similar to a pointer: you "allocate" it using MmapFile(filename) and "delete"
+// it using Unmap(). Just like a pointer, it is passed around by value (see
+// signature of MmapFile and Unmap; fortunately, it's a small class, so there
+// shouldn't be any significant performance penalty) and its usage is not
+// necessarily scoped (that's why the destructor is not performing the unmap).
+//
+// Note: on program termination, each still unmapped file is automatically
+// unmapped. Hence, it is not an error if you don't call Unmap() (provided you
+// are ok keeping that file in memory the whole time).
+class MmapHandle {
+ public:
+ MmapHandle(void *start, size_t num_bytes, void *unmap_addr = nullptr)
+ : start_(start), num_bytes_(num_bytes), unmap_addr_(unmap_addr) {}
+
+ // Returns start address for the memory area where a file has been mmapped.
+ void *start() const { return start_; }
+
+ // Returns address to use for munmap call. If unmap_addr was not specified
+ // the start address is used.
+ void *unmap_addr() const {
+ if (unmap_addr_ != nullptr) {
+ return unmap_addr_;
+ } else {
+ return start_;
+ }
+ }
+
+ // Returns number of bytes of the memory area from start().
+ size_t num_bytes() const { return num_bytes_; }
+
+ // Shortcut to simplify checking success of MmapFile(). See usage example
+ // from the doc of that function.
+ bool ok() const { return start() != nullptr; }
+
+ // Returns a StringPiece pointing to the same underlying bytes.
+ StringPiece to_stringpiece() const {
+ return StringPiece(reinterpret_cast<char *>(start_), num_bytes_);
+ }
+
+ private:
+ // See doc for start(). Not owned.
+ void *const start_;
+
+ // See doc for num_bytes().
+ const size_t num_bytes_;
+
+ // Address to use for unmapping.
+ void *const unmap_addr_;
+};
+
+// Maps the full content of a file in memory (using mmap).
+//
+// When done using the file content, one can unmap using Unmap(). Otherwise,
+// all mapped files are unmapped when the program terminates.
+//
+// Sample usage:
+//
+// MmapHandle mmap_handle = MmapFile(filename);
+// TC3_DCHECK(mmap_handle.ok()) << "Unable to mmap " << filename;
+//
+// ... use data from addresses
+// ... [mmap_handle.start, mmap_handle.start + mmap_handle.num_bytes)
+//
+// Unmap(mmap_handle); // Unmap logs errors internally.
+//
+// Note: one can read *and* write the num_bytes bytes from start, but those
+// writes are not propagated to the underlying file, nor to other processes that
+// may have mmapped that file (all changes are local to current process).
+MmapHandle MmapFile(const std::string &filename);
+
+// Like MmapFile(const std::string &filename), but uses a file descriptor.
+MmapHandle MmapFile(int fd);
+
+// Maps a segment of a file to memory. File is given by a file descriptor, and
+// offset (relative to the beginning of the file) and size specify the segment
+// to be mapped. NOTE: Internally, we align the offset for the call to mmap
+// system call to be a multiple of page size, so offset does NOT have to be a
+// multiple of the page size.
+MmapHandle MmapFile(int fd, int64 segment_offset, int64 segment_size);
+
+// Unmaps a file mapped using MmapFile. Returns true on success, false
+// otherwise.
+bool Unmap(MmapHandle mmap_handle);
+
+// Scoped mmapping of a file. Mmaps a file on construction, unmaps it on
+// destruction.
+class ScopedMmap {
+ public:
+ explicit ScopedMmap(const std::string &filename)
+ : handle_(MmapFile(filename)) {}
+
+ explicit ScopedMmap(int fd) : handle_(MmapFile(fd)) {}
+
+ ScopedMmap(int fd, int segment_offset, int segment_size)
+ : handle_(MmapFile(fd, segment_offset, segment_size)) {}
+
+ ~ScopedMmap() {
+ if (handle_.ok()) {
+ Unmap(handle_);
+ }
+ }
+
+ const MmapHandle &handle() const { return handle_; }
+
+ private:
+ MmapHandle handle_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_MEMORY_MMAP_H_
diff --git a/native/utils/normalization.cc b/native/utils/normalization.cc
new file mode 100644
index 0000000..f9623f7
--- /dev/null
+++ b/native/utils/normalization.cc
@@ -0,0 +1,75 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/normalization.h"
+
+#include "utils/base/logging.h"
+#include "utils/normalization_generated.h"
+
+namespace libtextclassifier3 {
+
+UnicodeText NormalizeText(const UniLib& unilib,
+ const NormalizationOptions* normalization_options,
+ const UnicodeText& text) {
+ return NormalizeTextCodepointWise(
+ unilib, normalization_options->codepointwise_normalization(), text);
+}
+
+UnicodeText NormalizeTextCodepointWise(const UniLib& unilib,
+ const uint32 codepointwise_ops,
+ const UnicodeText& text) {
+ // Sanity check.
+ TC3_CHECK(!((codepointwise_ops &
+ NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE) &&
+ (codepointwise_ops &
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE)));
+
+ UnicodeText result;
+ for (const char32 codepoint : text) {
+ // Skip whitespace.
+ if ((codepointwise_ops &
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE) &&
+ unilib.IsWhitespace(codepoint)) {
+ continue;
+ }
+
+ // Skip punctuation.
+ if ((codepointwise_ops &
+ NormalizationOptions_::
+ CodepointwiseNormalizationOp_DROP_PUNCTUATION) &&
+ unilib.IsPunctuation(codepoint)) {
+ continue;
+ }
+
+ int32 normalized_codepoint = codepoint;
+
+ // Lower case.
+ if (codepointwise_ops &
+ NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE) {
+ normalized_codepoint = unilib.ToLower(normalized_codepoint);
+
+ // Upper case.
+ } else if (codepointwise_ops &
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE) {
+ normalized_codepoint = unilib.ToUpper(normalized_codepoint);
+ }
+
+ result.push_back(normalized_codepoint);
+ }
+ return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/normalization.fbs b/native/utils/normalization.fbs
new file mode 100755
index 0000000..4d43f10
--- /dev/null
+++ b/native/utils/normalization.fbs
@@ -0,0 +1,40 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// The possible codepoint wise normalization operations.
+namespace libtextclassifier3.NormalizationOptions_;
+enum CodepointwiseNormalizationOp : int {
+ NONE = 0,
+
+ // Lower-case the string.
+ LOWERCASE = 1,
+
+ // Upper-case the string.
+ UPPERCASE = 4,
+
+ // Remove whitespace.
+ DROP_WHITESPACE = 8,
+
+ // Remove punctuation.
+ DROP_PUNCTUATION = 16,
+}
+
+namespace libtextclassifier3;
+table NormalizationOptions {
+ // Codepoint wise normalizations to apply, represents a bit field.
+ codepointwise_normalization:uint;
+}
+
diff --git a/native/utils/normalization.h b/native/utils/normalization.h
new file mode 100644
index 0000000..ff00783
--- /dev/null
+++ b/native/utils/normalization.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Methods for string normalization.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_NORMALIZATION_H_
+#define LIBTEXTCLASSIFIER_UTILS_NORMALIZATION_H_
+
+#include "utils/base/integral_types.h"
+#include "utils/normalization_generated.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// Normalizes a text according to the options.
+UnicodeText NormalizeText(const UniLib& unilib,
+ const NormalizationOptions* normalization_options,
+ const UnicodeText& text);
+
+// Normalizes a text codepoint wise by applying each codepoint wise op in
+// `codepointwise_ops` that is interpreted as a set of
+// `CodepointwiseNormalizationOp`.
+UnicodeText NormalizeTextCodepointWise(const UniLib& unilib,
+ const uint32 codepointwise_ops,
+ const UnicodeText& text);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_NORMALIZATION_H_
diff --git a/native/utils/normalization_test.cc b/native/utils/normalization_test.cc
new file mode 100644
index 0000000..1f731c7
--- /dev/null
+++ b/native/utils/normalization_test.cc
@@ -0,0 +1,121 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/normalization.h"
+
+#include <string>
+
+#include "utils/base/integral_types.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::Eq;
+
+class NormalizationTest : public testing::Test {
+ protected:
+ NormalizationTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+
+ std::string NormalizeTextCodepointWise(const std::string& text,
+ const int32 codepointwise_ops) {
+ return libtextclassifier3::NormalizeTextCodepointWise(
+ unilib_, codepointwise_ops,
+ UTF8ToUnicodeText(text, /*do_copy=*/false))
+ .ToUTF8String();
+ }
+
+ UniLib unilib_;
+};
+
+TEST_F(NormalizationTest, ReturnsIdenticalStringWhenNoNormalization) {
+ EXPECT_THAT(NormalizeTextCodepointWise(
+ "Never gonna let you down.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_NONE),
+ Eq("Never gonna let you down."));
+}
+
+#if !defined(TC3_UNILIB_DUMMY)
+TEST_F(NormalizationTest, DropsWhitespace) {
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "Never gonna let you down.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE),
+ Eq("Nevergonnaletyoudown."));
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "Never\tgonna\t\tlet\tyou\tdown.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE),
+ Eq("Nevergonnaletyoudown."));
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "Never\u2003gonna\u2003let\u2003you\u2003down.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE),
+ Eq("Nevergonnaletyoudown."));
+}
+
+TEST_F(NormalizationTest, DropsPunctuation) {
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "Never gonna let you down.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_PUNCTUATION),
+ Eq("Never gonna let you down"));
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "αʹ. Σημεῖόν ἐστιν, οὗ μέρος οὐθέν.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_PUNCTUATION),
+ Eq("αʹ Σημεῖόν ἐστιν οὗ μέρος οὐθέν"));
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "978—3—16—148410—0",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_PUNCTUATION),
+ Eq("9783161484100"));
+}
+
+TEST_F(NormalizationTest, LowercasesUnicodeText) {
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "αʹ. Σημεῖόν ἐστιν, οὗ μέρος οὐθέν.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE),
+ Eq("αʹ. σημεῖόν ἐστιν, οὗ μέρος οὐθέν."));
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "αʹ. Σημεῖόν ἐστιν, οὗ μέρος οὐθέν.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
+ NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE),
+ Eq("αʹ.σημεῖόνἐστιν,οὗμέροςοὐθέν."));
+}
+
+TEST_F(NormalizationTest, UppercasesUnicodeText) {
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "Κανένας άνθρωπος δεν ξέρει",
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE),
+ Eq("ΚΑΝΈΝΑΣ ΆΝΘΡΩΠΟΣ ΔΕΝ ΞΈΡΕΙ"));
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "Κανένας άνθρωπος δεν ξέρει",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE),
+ Eq("ΚΑΝΈΝΑΣΆΝΘΡΩΠΟΣΔΕΝΞΈΡΕΙ"));
+}
+#endif
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/optional.h b/native/utils/optional.h
similarity index 100%
rename from utils/optional.h
rename to native/utils/optional.h
diff --git a/native/utils/regex-match.cc b/native/utils/regex-match.cc
new file mode 100644
index 0000000..13c881f
--- /dev/null
+++ b/native/utils/regex-match.cc
@@ -0,0 +1,184 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/regex-match.h"
+
+#include <memory>
+
+#include "annotator/types.h"
+
+#ifndef TC3_DISABLE_LUA
+#include "utils/lua-utils.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lualib.h"
+#ifdef __cplusplus
+}
+#endif
+#endif
+
+namespace libtextclassifier3 {
+namespace {
+
+#ifndef TC3_DISABLE_LUA
+// Provide a lua environment for running regex match post verification.
+// It sets up and exposes the match data as well as the context.
+class LuaVerifier : public LuaEnvironment {
+ public:
+ static std::unique_ptr<LuaVerifier> Create(
+ const std::string& context, const std::string& verifier_code,
+ const UniLib::RegexMatcher* matcher);
+
+ bool Verify(bool* result);
+
+ private:
+ explicit LuaVerifier(const std::string& context,
+ const std::string& verifier_code,
+ const UniLib::RegexMatcher* matcher)
+ : context_(context), verifier_code_(verifier_code), matcher_(matcher) {}
+ bool Initialize();
+
+ // Provides details of a capturing group to lua.
+ int GetCapturingGroup();
+
+ const std::string& context_;
+ const std::string& verifier_code_;
+ const UniLib::RegexMatcher* matcher_;
+};
+
+bool LuaVerifier::Initialize() {
+ // Run protected to not lua panic in case of setup failure.
+ return RunProtected([this] {
+ LoadDefaultLibraries();
+
+ // Expose context of the match as `context` global variable.
+ PushString(context_);
+ lua_setglobal(state_, "context");
+
+ // Expose match array as `match` global variable.
+ // Each entry `match[i]` exposes the ith capturing group as:
+ // * `begin`: span start
+ // * `end`: span end
+ // * `text`: the text
+ PushLazyObject(&LuaVerifier::GetCapturingGroup);
+ lua_setglobal(state_, "match");
+ return LUA_OK;
+ }) == LUA_OK;
+}
+
+std::unique_ptr<LuaVerifier> LuaVerifier::Create(
+ const std::string& context, const std::string& verifier_code,
+ const UniLib::RegexMatcher* matcher) {
+ auto verifier = std::unique_ptr<LuaVerifier>(
+ new LuaVerifier(context, verifier_code, matcher));
+ if (!verifier->Initialize()) {
+ TC3_LOG(ERROR) << "Could not initialize lua environment.";
+ return nullptr;
+ }
+ return verifier;
+}
+
+int LuaVerifier::GetCapturingGroup() {
+ if (lua_type(state_, /*idx=*/-1) != LUA_TNUMBER) {
+ TC3_LOG(ERROR) << "Unexpected type for match group lookup: "
+ << lua_type(state_, /*idx=*/-1);
+ lua_error(state_);
+ return 0;
+ }
+ const int group_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
+ int status = UniLib::RegexMatcher::kNoError;
+ const CodepointSpan span = {matcher_->Start(group_id, &status),
+ matcher_->End(group_id, &status)};
+ std::string text = matcher_->Group(group_id, &status).ToUTF8String();
+ if (status != UniLib::RegexMatcher::kNoError) {
+ TC3_LOG(ERROR) << "Could not extract span from capturing group.";
+ lua_error(state_);
+ return 0;
+ }
+ lua_newtable(state_);
+ lua_pushinteger(state_, span.first);
+ lua_setfield(state_, /*idx=*/-2, "begin");
+ lua_pushinteger(state_, span.second);
+ lua_setfield(state_, /*idx=*/-2, "end");
+ PushString(text);
+ lua_setfield(state_, /*idx=*/-2, "text");
+ return 1;
+}
+
+bool LuaVerifier::Verify(bool* result) {
+ if (luaL_loadbuffer(state_, verifier_code_.data(), verifier_code_.size(),
+ /*name=*/nullptr) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not load verifier snippet.";
+ return false;
+ }
+
+ if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not run verifier snippet.";
+ return false;
+ }
+
+ if (RunProtected(
+ [this, result] {
+ if (lua_type(state_, /*idx=*/-1) != LUA_TBOOLEAN) {
+ TC3_LOG(ERROR) << "Unexpected verification result type: "
+ << lua_type(state_, /*idx=*/-1);
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
+ *result = lua_toboolean(state_, /*idx=*/-1);
+ return LUA_OK;
+ },
+ /*num_args=*/1) != LUA_OK) {
+ TC3_LOG(ERROR) << "Could not read lua result.";
+ return false;
+ }
+ return true;
+}
+#endif // TC3_DISABLE_LUA
+
+} // namespace
+
+Optional<std::string> GetCapturingGroupText(const UniLib::RegexMatcher* matcher,
+ const int group_id) {
+ int status = UniLib::RegexMatcher::kNoError;
+ std::string group_text = matcher->Group(group_id, &status).ToUTF8String();
+ if (status != UniLib::RegexMatcher::kNoError || group_text.empty()) {
+ return Optional<std::string>();
+ }
+ return Optional<std::string>(group_text);
+}
+
+bool VerifyMatch(const std::string& context,
+ const UniLib::RegexMatcher* matcher,
+ const std::string& lua_verifier_code) {
+ bool status = false;
+#ifndef TC3_DISABLE_LUA
+ auto verifier = LuaVerifier::Create(context, lua_verifier_code, matcher);
+ if (verifier == nullptr) {
+ TC3_LOG(ERROR) << "Could not create verifier.";
+ return false;
+ }
+ if (!verifier->Verify(&status)) {
+    TC3_LOG(ERROR) << "Could not run verification.";
+ return false;
+ }
+#endif // TC3_DISABLE_LUA
+ return status;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/regex-match.h b/native/utils/regex-match.h
new file mode 100644
index 0000000..1466b86
--- /dev/null
+++ b/native/utils/regex-match.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_REGEX_MATCH_H_
+#define LIBTEXTCLASSIFIER_UTILS_REGEX_MATCH_H_
+
+#include "utils/optional.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// Returns text of a capturing group if the capturing group was fulfilled in
+// the regex match.
+Optional<std::string> GetCapturingGroupText(const UniLib::RegexMatcher* matcher,
+ const int group_id);
+
+// Post-checks a regular expression match with a lua verifier script.
+// The verifier can access:
+// * `context`: The context as a string.
+// * `match`: The groups of the regex match as an array, each group gives
+// * `begin`: span start
+// * `end`: span end
+// * `text`: the text
+// The verifier is expected to return a boolean, indicating whether the
+// verification succeeded or not.
+// Returns true if the verification was successful, false if not.
+bool VerifyMatch(const std::string& context,
+ const UniLib::RegexMatcher* matcher,
+ const std::string& lua_verifier_code);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_REGEX_MATCH_H_
diff --git a/native/utils/regex-match_test.cc b/native/utils/regex-match_test.cc
new file mode 100644
index 0000000..c45fb29
--- /dev/null
+++ b/native/utils/regex-match_test.cc
@@ -0,0 +1,114 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/regex-match.h"
+
+#include <memory>
+
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+class RegexMatchTest : public testing::Test {
+ protected:
+ RegexMatchTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+ UniLib unilib_;
+};
+
+#ifdef TC3_UNILIB_ICU
+#ifndef TC3_DISABLE_LUA
+TEST_F(RegexMatchTest, HandlesSimpleVerification) {
+ EXPECT_TRUE(VerifyMatch(/*context=*/"", /*matcher=*/nullptr, "return true;"));
+}
+#endif // TC3_DISABLE_LUA
+
+#ifndef TC3_DISABLE_LUA
+TEST_F(RegexMatchTest, HandlesCustomVerification) {
+ UnicodeText pattern = UTF8ToUnicodeText("(\\d{16})",
+ /*do_copy=*/true);
+ UnicodeText message = UTF8ToUnicodeText("cc: 4012888888881881",
+ /*do_copy=*/true);
+ const std::string verifier = R"(
+function luhn(candidate)
+ local sum = 0
+ local num_digits = string.len(candidate)
+ local parity = num_digits % 2
+ for pos = 1,num_digits do
+ d = tonumber(string.sub(candidate, pos, pos))
+ if pos % 2 ~= parity then
+ d = d * 2
+ end
+ if d > 9 then
+ d = d - 9
+ end
+ sum = sum + d
+ end
+ return (sum % 10) == 0
+end
+return luhn(match[1].text);
+ )";
+ const std::unique_ptr<UniLib::RegexPattern> regex_pattern =
+ unilib_.CreateRegexPattern(pattern);
+ ASSERT_TRUE(regex_pattern != nullptr);
+ const std::unique_ptr<UniLib::RegexMatcher> matcher =
+ regex_pattern->Matcher(message);
+ ASSERT_TRUE(matcher != nullptr);
+ int status = UniLib::RegexMatcher::kNoError;
+ ASSERT_TRUE(matcher->Find(&status) &&
+ status == UniLib::RegexMatcher::kNoError);
+
+ EXPECT_TRUE(VerifyMatch(message.ToUTF8String(), matcher.get(), verifier));
+}
+#endif // TC3_DISABLE_LUA
+
+TEST_F(RegexMatchTest, RetrievesMatchGroupTest) {
+ UnicodeText pattern =
+ UTF8ToUnicodeText("never gonna (?:give (you) up|let (you) down)",
+ /*do_copy=*/true);
+ const std::unique_ptr<UniLib::RegexPattern> regex_pattern =
+ unilib_.CreateRegexPattern(pattern);
+ ASSERT_TRUE(regex_pattern != nullptr);
+ UnicodeText message =
+ UTF8ToUnicodeText("never gonna give you up - never gonna let you down");
+ const std::unique_ptr<UniLib::RegexMatcher> matcher =
+ regex_pattern->Matcher(message);
+ ASSERT_TRUE(matcher != nullptr);
+ int status = UniLib::RegexMatcher::kNoError;
+
+ ASSERT_TRUE(matcher->Find(&status) &&
+ status == UniLib::RegexMatcher::kNoError);
+ EXPECT_THAT(GetCapturingGroupText(matcher.get(), 0).value(),
+ testing::Eq("never gonna give you up"));
+ EXPECT_THAT(GetCapturingGroupText(matcher.get(), 1).value(),
+ testing::Eq("you"));
+ EXPECT_FALSE(GetCapturingGroupText(matcher.get(), 2).has_value());
+
+ ASSERT_TRUE(matcher->Find(&status) &&
+ status == UniLib::RegexMatcher::kNoError);
+ EXPECT_THAT(GetCapturingGroupText(matcher.get(), 0).value(),
+ testing::Eq("never gonna let you down"));
+ EXPECT_FALSE(GetCapturingGroupText(matcher.get(), 1).has_value());
+ EXPECT_THAT(GetCapturingGroupText(matcher.get(), 2).value(),
+ testing::Eq("you"));
+}
+#endif
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/resources.cc b/native/utils/resources.cc
new file mode 100644
index 0000000..2ae2def
--- /dev/null
+++ b/native/utils/resources.cc
@@ -0,0 +1,248 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/resources.h"
+
+#include "utils/base/logging.h"
+#include "utils/zlib/buffer_generated.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+namespace {
+bool isWildcardMatch(const flatbuffers::String* left,
+ const std::string& right) {
+ return (left == nullptr || right.empty());
+}
+
+bool isExactMatch(const flatbuffers::String* left, const std::string& right) {
+ if (left == nullptr) {
+ return right.empty();
+ }
+ return left->str() == right;
+}
+
+} // namespace
+
+int Resources::LocaleMatch(const Locale& locale,
+ const LanguageTag* entry_locale) const {
+ int match = LOCALE_NO_MATCH;
+ if (isExactMatch(entry_locale->language(), locale.Language())) {
+ match |= LOCALE_LANGUAGE_MATCH;
+ } else if (isWildcardMatch(entry_locale->language(), locale.Language())) {
+ match |= LOCALE_LANGUAGE_WILDCARD_MATCH;
+ }
+
+ if (isExactMatch(entry_locale->script(), locale.Script())) {
+ match |= LOCALE_SCRIPT_MATCH;
+ } else if (isWildcardMatch(entry_locale->script(), locale.Script())) {
+ match |= LOCALE_SCRIPT_WILDCARD_MATCH;
+ }
+
+ if (isExactMatch(entry_locale->region(), locale.Region())) {
+ match |= LOCALE_REGION_MATCH;
+ } else if (isWildcardMatch(entry_locale->region(), locale.Region())) {
+ match |= LOCALE_REGION_WILDCARD_MATCH;
+ }
+
+ return match;
+}
+
+const ResourceEntry* Resources::FindResource(
+ const StringPiece resource_name) const {
+ if (resources_ == nullptr || resources_->resource_entry() == nullptr) {
+ TC3_LOG(ERROR) << "No resources defined.";
+ return nullptr;
+ }
+ const ResourceEntry* entry =
+ resources_->resource_entry()->LookupByKey(resource_name.data());
+ if (entry == nullptr) {
+ TC3_LOG(ERROR) << "Resource " << resource_name.ToString() << " not found";
+ return nullptr;
+ }
+ return entry;
+}
+
+int Resources::BestResourceForLocales(
+ const ResourceEntry* resource, const std::vector<Locale>& locales) const {
+ // Find best match based on locale.
+ int resource_id = -1;
+ int locale_match = LOCALE_NO_MATCH;
+ const auto* resources = resource->resource();
+ for (int user_locale = 0; user_locale < locales.size(); user_locale++) {
+ if (!locales[user_locale].IsValid()) {
+ continue;
+ }
+ for (int i = 0; i < resources->size(); i++) {
+ for (const int locale_id : *resources->Get(i)->locale()) {
+ const int candidate_match = LocaleMatch(
+ locales[user_locale], resources_->locale()->Get(locale_id));
+
+ // Only consider if at least the language matches.
+ if ((candidate_match & LOCALE_LANGUAGE_MATCH) == 0 &&
+ (candidate_match & LOCALE_LANGUAGE_WILDCARD_MATCH) == 0) {
+ continue;
+ }
+
+ if (candidate_match > locale_match) {
+ locale_match = candidate_match;
+ resource_id = i;
+ }
+ }
+ }
+
+ // If the language matches exactly, we are already finished.
+ // We found an exact language match.
+ if (locale_match & LOCALE_LANGUAGE_MATCH) {
+ return resource_id;
+ }
+ }
+ return resource_id;
+}
+
+bool Resources::GetResourceContent(const std::vector<Locale>& locales,
+ const StringPiece resource_name,
+ std::string* result) const {
+ const ResourceEntry* entry = FindResource(resource_name);
+ if (entry == nullptr || entry->resource() == nullptr) {
+ return false;
+ }
+
+ int resource_id = BestResourceForLocales(entry, locales);
+ if (resource_id < 0) {
+ return false;
+ }
+ const auto* resource = entry->resource()->Get(resource_id);
+ if (resource->content() != nullptr) {
+ *result = resource->content()->str();
+ return true;
+ } else if (resource->compressed_content() != nullptr) {
+ std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance(
+ resources_->compression_dictionary()->data(),
+ resources_->compression_dictionary()->size());
+ if (decompressor != nullptr &&
+ decompressor->MaybeDecompress(resource->compressed_content(), result)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool CompressResources(ResourcePoolT* resources,
+ const bool build_compression_dictionary,
+ const int dictionary_sample_every) {
+ std::vector<unsigned char> dictionary;
+ if (build_compression_dictionary) {
+ {
+ // Build up a compression dictionary.
+ std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
+ int i = 0;
+ for (auto& entry : resources->resource_entry) {
+ for (auto& resource : entry->resource) {
+ if (resource->content.empty()) {
+ continue;
+ }
+ i++;
+
+ // Use a sample of the entries to build up a custom compression
+ // dictionary. Using all entries will generally not give a benefit
+ // for small data sizes, so we subsample here.
+ if (i % dictionary_sample_every != 0) {
+ continue;
+ }
+ CompressedBufferT compressed_content;
+ compressor->Compress(resource->content, &compressed_content);
+ }
+ }
+ compressor->GetDictionary(&dictionary);
+ resources->compression_dictionary.assign(
+ dictionary.data(), dictionary.data() + dictionary.size());
+ }
+ }
+
+ for (auto& entry : resources->resource_entry) {
+ for (auto& resource : entry->resource) {
+ if (resource->content.empty()) {
+ continue;
+ }
+ // Try compressing the data.
+ std::unique_ptr<ZlibCompressor> compressor =
+ build_compression_dictionary
+ ? ZlibCompressor::Instance(dictionary.data(), dictionary.size())
+ : ZlibCompressor::Instance();
+ if (!compressor) {
+ TC3_LOG(ERROR) << "Cannot create zlib compressor.";
+ return false;
+ }
+
+ CompressedBufferT compressed_content;
+ compressor->Compress(resource->content, &compressed_content);
+
+ // Only keep compressed version if smaller.
+ if (compressed_content.uncompressed_size >
+ compressed_content.buffer.size()) {
+ resource->content.clear();
+ resource->compressed_content.reset(new CompressedBufferT);
+ *resource->compressed_content = compressed_content;
+ }
+ }
+ }
+ return true;
+}
+
+std::string CompressSerializedResources(const std::string& resources,
+    const bool build_compression_dictionary, const int dictionary_sample_every) {
+  std::unique_ptr<ResourcePoolT> unpacked_resources(
+      flatbuffers::GetRoot<ResourcePool>(resources.data())->UnPack());
+  TC3_CHECK(unpacked_resources != nullptr);
+  TC3_CHECK(CompressResources(unpacked_resources.get(),
+      build_compression_dictionary, dictionary_sample_every));
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(ResourcePool::Pack(builder, unpacked_resources.get()));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+bool DecompressResources(ResourcePoolT* resources,
+ const bool build_compression_dictionary) {
+ std::vector<unsigned char> dictionary;
+
+ for (auto& entry : resources->resource_entry) {
+ for (auto& resource : entry->resource) {
+ if (resource->compressed_content == nullptr) {
+ continue;
+ }
+
+ std::unique_ptr<ZlibDecompressor> zlib_decompressor =
+ build_compression_dictionary
+ ? ZlibDecompressor::Instance(dictionary.data(), dictionary.size())
+ : ZlibDecompressor::Instance();
+ if (!zlib_decompressor) {
+ TC3_LOG(ERROR) << "Cannot initialize decompressor.";
+ return false;
+ }
+
+ if (!zlib_decompressor->MaybeDecompress(
+ resource->compressed_content.get(), &resource->content)) {
+ TC3_LOG(ERROR) << "Cannot decompress resource.";
+ return false;
+ }
+ resource->compressed_content.reset(nullptr);
+ }
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/resources.fbs b/native/utils/resources.fbs
new file mode 100755
index 0000000..aae57cf
--- /dev/null
+++ b/native/utils/resources.fbs
@@ -0,0 +1,39 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+include "utils/i18n/language-tag.fbs";
+include "utils/zlib/buffer.fbs";
+
+namespace libtextclassifier3;
+table Resource {
+ locale:[int];
+ content:string (shared);
+ compressed_content:CompressedBuffer;
+}
+
+namespace libtextclassifier3;
+table ResourceEntry {
+ name:string (key, shared);
+ resource:[Resource];
+}
+
+namespace libtextclassifier3;
+table ResourcePool {
+ locale:[LanguageTag];
+ resource_entry:[ResourceEntry];
+ compression_dictionary:[ubyte];
+}
+
diff --git a/native/utils/resources.h b/native/utils/resources.h
new file mode 100644
index 0000000..96f9683
--- /dev/null
+++ b/native/utils/resources.h
@@ -0,0 +1,80 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_RESOURCES_H_
+#define LIBTEXTCLASSIFIER_UTILS_RESOURCES_H_
+
+#include <vector>
+
+#include "utils/i18n/language-tag_generated.h"
+#include "utils/i18n/locale.h"
+#include "utils/resources_generated.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Class for accessing localized model resources.
+class Resources {
+ public:
+ explicit Resources(const ResourcePool* resources) : resources_(resources) {}
+
+ // Returns the string value associated with the particular resource.
+ // `locales` are locales in preference order.
+ bool GetResourceContent(const std::vector<Locale>& locales,
+ const StringPiece resource_name,
+ std::string* result) const;
+
+ private:
+ // Match priorities: language > script > region with wildcard matches being
+ // weaker than an exact match.
+ // For a resource lookup, at least language needs to (weakly) match.
+ // c.f. developer.android.com/guide/topics/resources/multilingual-support
+ enum LocaleMatch {
+ LOCALE_NO_MATCH = 0,
+ LOCALE_REGION_WILDCARD_MATCH = 1 << 0,
+ LOCALE_REGION_MATCH = 1 << 1,
+ LOCALE_SCRIPT_WILDCARD_MATCH = 1 << 2,
+ LOCALE_SCRIPT_MATCH = 1 << 3,
+ LOCALE_LANGUAGE_WILDCARD_MATCH = 1 << 4,
+ LOCALE_LANGUAGE_MATCH = 1 << 5
+ };
+ int LocaleMatch(const Locale& locale, const LanguageTag* entry_locale) const;
+
+ // Finds a resource entry by name.
+ const ResourceEntry* FindResource(const StringPiece resource_name) const;
+
+ // Finds the best locale matching resource from a resource entry.
+ int BestResourceForLocales(const ResourceEntry* resource,
+ const std::vector<Locale>& locales) const;
+
+ const ResourcePool* resources_;
+};
+
+// Compresses resources in place.
+bool CompressResources(ResourcePoolT* resources,
+ const bool build_compression_dictionary = false,
+ const int dictionary_sample_every = 1);
+std::string CompressSerializedResources(
+ const std::string& resources,
+ const bool build_compression_dictionary = false,
+ const int dictionary_sample_every = 1);
+
+bool DecompressResources(ResourcePoolT* resources,
+ const bool build_compression_dictionary = false);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_RESOURCES_H_
diff --git a/utils/resources_test.cc b/native/utils/resources_test.cc
similarity index 100%
rename from utils/resources_test.cc
rename to native/utils/resources_test.cc
diff --git a/native/utils/sentencepiece/encoder.cc b/native/utils/sentencepiece/encoder.cc
new file mode 100644
index 0000000..2754ab7
--- /dev/null
+++ b/native/utils/sentencepiece/encoder.cc
@@ -0,0 +1,96 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/encoder.h"
+
+namespace libtextclassifier3 {
+
// Segments `normalized_text` into pieces so that the sum of the piece scores
// is maximized, via a Viterbi-style dynamic program over byte positions.
// On success, `encoded_text` holds start_code_, then the chosen piece ids
// (dictionary ids shifted by encoding_offset_; unknown_code_ unshifted),
// then end_code_. Returns false only if prefix matching fails.
bool Encoder::Encode(StringPiece normalized_text,
                     std::vector<int>* encoded_text) const {
  const int len = normalized_text.size();
  if (len <= 0) {
    // Empty input: emit only the boundary markers.
    *encoded_text = {start_code_, end_code_};
    return true;
  }
  // We use `previous_pos` to indicate whether a dynamic programming state was
  // reachable.
  // segmentation[i] is the best segmentation of the first i bytes; index
  // `len` is the final state.
  std::vector<SegmentationEntry> segmentation(
      len + 1, {/*score=*/0, /*previous_pos=*/-1, /*piece_id=*/-1,
                /*num_pieces=*/0});
  for (int i = 0; i < len; i++) {
    // State couldn't be reached.
    if (i > 0 && segmentation[i].previous_pos < 0) {
      // Advance position.
      normalized_text.RemovePrefix(1);
      continue;
    }
    // Check whether we can use the unknown token.
    if (unknown_code_ >= 0) {
      const int pos = i + 1;  // The unknown token consumes exactly one byte.
      const float unknown_penalty = segmentation[i].score + unknown_score_;
      if (segmentation[pos].previous_pos < 0 ||
          segmentation[pos].score < unknown_penalty) {
        // Merge multiple unknown tokens into one.
        if (segmentation[i].piece_id == unknown_code_) {
          segmentation[pos] = {/*score=*/unknown_penalty,
                               /*previous_pos=*/segmentation[i].previous_pos,
                               /*piece_id=*/unknown_code_,
                               /*num_pieces=*/segmentation[i].num_pieces};
        } else {
          segmentation[pos] = {/*score=*/unknown_penalty,
                               /*previous_pos=*/i,
                               /*piece_id=*/unknown_code_,
                               /*num_pieces=*/segmentation[i].num_pieces + 1};
        }
      }
    }
    // Relax all dictionary pieces that start at byte position i.
    std::vector<StringSet::Match> matches;
    if (!pieces_->FindAllPrefixMatches(normalized_text, &matches)) {
      TC3_LOG(ERROR)
          << "Couldn't successfully gather prefix sentence piece matches.";
      return false;
    }
    for (const auto& match : matches) {
      TC3_CHECK(match.id >= 0 && match.id < num_pieces_);
      const int pos = i + match.match_length;
      const float candidate_score = segmentation[i].score + scores_[match.id];
      if (segmentation[pos].previous_pos < 0 ||
          segmentation[pos].score < candidate_score) {
        segmentation[pos] = {/*score=*/candidate_score, /*previous_pos=*/i,
                             /*piece_id=*/match.id + encoding_offset_,
                             /*num_pieces=*/segmentation[i].num_pieces + 1};
      }
    }
    // Advance position.
    normalized_text.RemovePrefix(1);
  }
  if (segmentation[len].num_pieces <= 0) {
    // No segmentation reached the end of the input; emit only the markers.
    *encoded_text = {start_code_, end_code_};
    return true;
  }
  // Backtrack from the final state to recover the piece sequence in order.
  const int num_pieces = segmentation[len].num_pieces;
  encoded_text->resize(num_pieces + 2);
  (*encoded_text)[num_pieces + 1] = end_code_;
  int pos = len;
  for (int i = num_pieces; i > 0; i--) {
    (*encoded_text)[i] = segmentation[pos].piece_id;
    pos = segmentation[pos].previous_pos;
  }
  (*encoded_text)[0] = start_code_;
  return true;
}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/sentencepiece/encoder.h b/native/utils/sentencepiece/encoder.h
new file mode 100644
index 0000000..304c4e5
--- /dev/null
+++ b/native/utils/sentencepiece/encoder.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
+
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/container/string-set.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Encoder to segment/tokenize strings into pieces such that the sum of the
+// scores of the pieces used is maximized.
+class Encoder {
+ public:
+ // pieces: the list of valid sentence pieces represented as a string set, e.g.
+ // a trie.
+ // num_pieces: the number of pieces in the trie.
+ // pieces_scores: the scores of the individual pieces.
+ // start_code: code that is used as encoding of the start of input.
+ // end_code: code that is used as encoding of the end of input.
+ // encoding_offset: value added to the sentence piece ids to make them
+ // not interesecting with start_code and end_code.
+ // unknown_code: code that is used for out-of-dictionary characters.
+ // unknown_score: the penality score associated with the unknown code.
+ Encoder(const StringSet* pieces, const int num_pieces,
+ const float* pieces_scores, int start_code = 0, int end_code = 1,
+ int encoding_offset = 2, int unknown_code = -1,
+ float unknown_score = 0.f)
+ : num_pieces_(num_pieces),
+ scores_(pieces_scores),
+ pieces_(pieces),
+ start_code_(start_code),
+ end_code_(end_code),
+ encoding_offset_(encoding_offset),
+ unknown_code_(unknown_code),
+ unknown_score_(unknown_score) {}
+
+ // Segment the input so that the total score of the pieces used is maximized.
+ // This is a simplified implementation of the general Viterbi algorithm,
+ // assuming independence between individual pieces.
+ bool Encode(StringPiece normalized_text,
+ std::vector<int>* encoded_text) const;
+
+ private:
+ // State in the dynamic programming algorithm.
+ struct SegmentationEntry {
+ // Accumulated score.
+ float score;
+
+ // Position before last piece.
+ int previous_pos;
+
+ // Last piece used.
+ int piece_id;
+
+ // Total number of pieces used.
+ int num_pieces;
+ };
+
+ const int num_pieces_;
+ const float* scores_;
+ const StringSet* pieces_;
+ const int start_code_;
+ const int end_code_;
+ const int encoding_offset_;
+ const int unknown_code_;
+ const int unknown_score_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
diff --git a/native/utils/sentencepiece/encoder_test.cc b/native/utils/sentencepiece/encoder_test.cc
new file mode 100644
index 0000000..740db35
--- /dev/null
+++ b/native/utils/sentencepiece/encoder_test.cc
@@ -0,0 +1,122 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/encoder.h"
+
+#include <memory>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "utils/container/sorted-strings-table.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAre;
+
TEST(EncoderTest, SimpleTokenization) {
  // Pieces (id): hell=0, hello=1, o=2, there=3, stored as a "\0"-separated
  // blob with start offsets {0, 5, 11, 13}.
  const char pieces_table[] = "hell\0hello\0o\0there\0";
  const uint32 offsets[] = {0, 5, 11, 13};
  float scores[] = {-0.5, -1.0, -10.0, -1.0};
  std::unique_ptr<StringSet> pieces(new SortedStringsTable(
      /*num_pieces=*/4, offsets, StringPiece(pieces_table, 18)));
  // Uses the defaults from encoder.h: start=0, end=1, encoding offset=2.
  const Encoder encoder(pieces.get(),
                        /*num_pieces=*/4, scores);

  {
    std::vector<int> encoded_text;
    EXPECT_TRUE(encoder.Encode("hellothere", &encoded_text));
    // hello (id 1 -> 3) + there (id 3 -> 5), wrapped in start/end codes.
    EXPECT_THAT(encoded_text, ElementsAre(0, 3, 5, 1));
  }

  // Make probability of hello very low:
  // hello gets now tokenized as hell + o.
  scores[1] = -100.0;
  {
    std::vector<int> encoded_text;
    EXPECT_TRUE(encoder.Encode("hellothere", &encoded_text));
    EXPECT_THAT(encoded_text, ElementsAre(0, 2, 4, 5, 1));
  }
}
+
TEST(EncoderTest, HandlesEdgeCases) {
  // Same pieces/scores as SimpleTokenization: hell=0, hello=1, o=2, there=3.
  const char pieces_table[] = "hell\0hello\0o\0there\0";
  const uint32 offsets[] = {0, 5, 11, 13};
  float scores[] = {-0.5, -1.0, -10.0, -1.0};
  std::unique_ptr<StringSet> pieces(new SortedStringsTable(
      /*num_pieces=*/4, offsets, StringPiece(pieces_table, 18)));
  const Encoder encoder(pieces.get(),
                        /*num_pieces=*/4, scores);
  {
    std::vector<int> encoded_text;
    EXPECT_TRUE(encoder.Encode("hellhello", &encoded_text));
    EXPECT_THAT(encoded_text, ElementsAre(0, 2, 3, 1));
  }
  {
    std::vector<int> encoded_text;
    EXPECT_TRUE(encoder.Encode("hellohell", &encoded_text));
    EXPECT_THAT(encoded_text, ElementsAre(0, 3, 2, 1));
  }
  {
    // Empty input yields just the start/end markers.
    std::vector<int> encoded_text;
    EXPECT_TRUE(encoder.Encode("", &encoded_text));
    EXPECT_THAT(encoded_text, ElementsAre(0, 1));
  }
  {
    // Without an unknown code, unencodable input also yields only markers.
    std::vector<int> encoded_text;
    EXPECT_TRUE(encoder.Encode("hellathere", &encoded_text));
    EXPECT_THAT(encoded_text, ElementsAre(0, 1));
  }
}
+
TEST(EncoderTest, HandlesOutOfDictionary) {
  const char pieces_table[] = "hell\0hello\0o\0there\0";
  const uint32 offsets[] = {0, 5, 11, 13};
  float scores[] = {-0.5, -1.0, -10.0, -1.0};
  std::unique_ptr<StringSet> pieces(new SortedStringsTable(
      /*num_pieces=*/4, offsets, StringPiece(pieces_table, 18)));
  // Code 2 marks out-of-dictionary characters, so dictionary piece ids are
  // shifted by an encoding offset of 3 instead of the default 2.
  const Encoder encoder(pieces.get(),
                        /*num_pieces=*/4, scores,
                        /*start_code=*/0, /*end_code=*/1,
                        /*encoding_offset=*/3, /*unknown_code=*/2,
                        /*unknown_score=*/-100.0);
  {
    std::vector<int> encoded_text;
    EXPECT_TRUE(encoder.Encode("hellhello", &encoded_text));
    EXPECT_THAT(encoded_text, ElementsAre(0, 3, 4, 1));
  }
  {
    std::vector<int> encoded_text;
    EXPECT_TRUE(encoder.Encode("hellohell", &encoded_text));
    EXPECT_THAT(encoded_text, ElementsAre(0, 4, 3, 1));
  }
  {
    std::vector<int> encoded_text;
    EXPECT_TRUE(encoder.Encode("", &encoded_text));
    EXPECT_THAT(encoded_text, ElementsAre(0, 1));
  }
  {
    // The unmatchable "a" is emitted as the unknown code.
    std::vector<int> encoded_text;
    EXPECT_TRUE(encoder.Encode("hellathere", &encoded_text));
    EXPECT_THAT(encoded_text,
                ElementsAre(0, /*hell*/ 3, /*unknown*/ 2, /*there*/ 6, 1));
  }
}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/sentencepiece/normalizer.cc b/native/utils/sentencepiece/normalizer.cc
new file mode 100644
index 0000000..4cee507
--- /dev/null
+++ b/native/utils/sentencepiece/normalizer.cc
@@ -0,0 +1,151 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/normalizer.h"
+
+#include "utils/base/logging.h"
+#include "utils/strings/utf8.h"
+
+namespace libtextclassifier3 {
+
// Normalizes a plain utf8 `input` into the SentencePiece-internal form in
// *normalized_input, applying the charsmap replacement rules plus the
// whitespace options set at construction time. Returns false if some prefix
// of the input can't be normalized.
// NOTE(review): on the non-empty path the result is appended to
// *normalized_input without clearing it first — callers presumably pass an
// empty string; confirm before relying on accumulation.
bool SentencePieceNormalizer::Normalize(StringPiece input,
                                        std::string* normalized_input) const {
  // Ignores heading space.
  if (remove_extra_whitespaces_) {
    while (!input.empty()) {
      std::pair<StringPiece, int> suffix_and_length;
      if (!NormalizePrefix(input, &suffix_and_length)) {
        TC3_LOG(ERROR) << "Couldn't find match in normalization table.";
        return false;
      }
      if (suffix_and_length.second <= 0) {
        TC3_LOG(ERROR) << "Consumed string is empty.";
        return false;
      }
      // Stop stripping once the normalized prefix is anything other than a
      // single plain space.
      if (suffix_and_length.first.size() != 1 ||
          suffix_and_length.first[0] != ' ') {
        break;
      }
      input.RemovePrefix(suffix_and_length.second);
    }
  }

  if (input.empty()) {
    *normalized_input = "";
    return true;
  }

  // Reserves the output buffer to avoid re-allocations.
  // (Factor 3 covers space escaping, which turns 1 byte into 3.)
  const int kReservedSize = input.size() * 3;
  normalized_input->reserve(kReservedSize);

  // Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK)
  // if escape_whitespaces() is set (default = true).
  const StringPiece kSpaceSymbol = "\xe2\x96\x81";

  // Adds a space symbol as a prefix (default is true)
  // With this prefix, "world" and "hello world" are converted into
  // "_world" and "_hello_world", which help the trainer to extract
  // "_world" as one symbol.
  if (add_dummy_prefix_) {
    if (escape_whitespaces_) {
      normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size());
    } else {
      normalized_input->append(" ");
    }
  }

  bool is_prev_space = remove_extra_whitespaces_;
  while (!input.empty()) {
    std::pair<StringPiece, int> p;
    if (!NormalizePrefix(input, &p)) {
      TC3_LOG(ERROR) << "Couldn't normalize string.";
      return false;
    }
    if (p.second <= 0) {
      TC3_LOG(ERROR) << "Consumed string is empty.";
      return false;
    }

    StringPiece sp = p.first;

    // Removes heading spaces in sentence piece,
    // if the previous sentence piece ends with whitespace.
    while (is_prev_space && ConsumePrefix(&sp, " ")) {
    }

    if (!sp.empty()) {
      const char* data = sp.data();
      for (int n = 0; n < sp.size(); ++n) {
        if (escape_whitespaces_ && data[n] == ' ') {
          normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size());
        } else {
          *normalized_input += data[n];
        }
      }
      // Checks whether the last character of sp is whitespace.
      is_prev_space = EndsWith(sp, " ");
    }
    input.RemovePrefix(p.second);
    // Runs of whitespace are only collapsed when removal is enabled.
    is_prev_space = is_prev_space && remove_extra_whitespaces_;
  }

  // Ignores trailing space.
  if (remove_extra_whitespaces_) {
    const StringPiece space = escape_whitespaces_ ? kSpaceSymbol : " ";
    while (EndsWith(*normalized_input, space)) {
      const int length = normalized_input->size() - space.size();
      normalized_input->resize(length);
    }
  }
  return true;
}
+
// Normalizes the longest matching prefix of `input`.
// On success, prefix->first holds the normalized replacement text and
// prefix->second the number of input bytes consumed. The replacement may
// point into the normalization table, into `input` itself (when no rule
// matches a valid utf8 character), or at a static replacement character.
bool SentencePieceNormalizer::NormalizePrefix(
    StringPiece input, std::pair<StringPiece, int>* prefix) const {
  if (input.empty()) return true;
  // Leftmost-longest match against the replacement rules.
  StringSet::Match match;
  if (!charsmap_trie_.LongestPrefixMatch(input, &match)) {
    TC3_LOG(ERROR) << "Couldn't find match in normalization table.";
    return false;
  }
  const bool no_match = match.match_length <= 0;
  if (no_match) {
    const int char_length = ValidUTF8CharLength(input.data(), input.size());
    if (char_length <= 0) {
      // Found a malformed utf8.
      // The rune is set to be 0xFFFD (REPLACEMENT CHARACTER),
      // which is a valid Unicode of three bytes in utf8,
      // but here we only consume one byte.
      static const char kReplacementChar[] = "\xEF\xBF\xBD";
      prefix->first = StringPiece(kReplacementChar, 3);
      prefix->second = 1;  // Consumes 1 byte, but emits 0xFFFD.
    } else {
      // No normalization rule applies: pass the character through unchanged.
      prefix->first = StringPiece(input.data(), char_length);
      prefix->second = char_length;
    }
  } else {
    if (match.id < 0 || match.id >= charsmap_normalized_.size()) {
      TC3_LOG(ERROR) << "Invalid entry in normalization table.";
      return false;
    }
    // The trie value is an offset into the "\0"-delimited blob of
    // normalized strings.
    prefix->first = StringPiece(&charsmap_normalized_.data()[match.id]);
    prefix->second = match.match_length;
  }
  return true;
}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/sentencepiece/normalizer.h b/native/utils/sentencepiece/normalizer.h
new file mode 100644
index 0000000..0dea60d
--- /dev/null
+++ b/native/utils/sentencepiece/normalizer.h
@@ -0,0 +1,80 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_NORMALIZER_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_NORMALIZER_H_
+
+#include <memory>
+#include <string>
+
+#include "utils/container/double-array-trie.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
// Normalizer implements a simple text normalizer with user-defined
// string-to-string rules and leftmost longest matching.
class SentencePieceNormalizer {
 public:
  // charsmap_trie and charsmap_normalized specify the normalization/replacement
  // string-to-string rules in the following way:
  // A match in the trie for a string will return the offset in
  // charsmap_normalized that contains the replacement string.
  //
  // add_dummy_prefix: Whether to add dummy whitespace at the beginning of the
  // text in order to treat "world" in "world" and "hello world" uniformly.
  //
  // remove_extra_whitespaces: Whether to remove leading, trailing and duplicate
  // internal whitespace.
  //
  // escape_whitespaces: Whether to replace whitespace with a meta symbol.
  //
  // NOTE: charsmap_normalized is held as a StringPiece, so the underlying
  // buffer must outlive this normalizer.
  SentencePieceNormalizer(const DoubleArrayTrie& charsmap_trie,
                          StringPiece charsmap_normalized,
                          bool add_dummy_prefix = true,
                          bool remove_extra_whitespaces = true,
                          bool escape_whitespaces = true)
      : charsmap_trie_(charsmap_trie),
        charsmap_normalized_(charsmap_normalized),
        add_dummy_prefix_(add_dummy_prefix),
        remove_extra_whitespaces_(remove_extra_whitespaces),
        escape_whitespaces_(escape_whitespaces) {}

  // Normalizes a plain utf8 string into an internal representation for
  // Sentencepiece model.
  bool Normalize(StringPiece input, std::string* normalized_input) const;

 private:
  // Normalizes the prefix of `input` and returns the pair of
  // normalized prefix and the length of the prefix of `input` processed in the
  // normalization.
  bool NormalizePrefix(StringPiece input,
                       std::pair<StringPiece, int>* prefix) const;

  // Internal trie for efficient longest prefix string matching.
  DoubleArrayTrie charsmap_trie_;

  // "\0"-delimited concatenated normalized strings.
  // the value of `charsmap_trie_` stores offsets into this string.
  StringPiece charsmap_normalized_;

  const bool add_dummy_prefix_;
  const bool remove_extra_whitespaces_;
  const bool escape_whitespaces_;
};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_NORMALIZER_H_
diff --git a/native/utils/strings/append.cc b/native/utils/strings/append.cc
new file mode 100644
index 0000000..ec0346f
--- /dev/null
+++ b/native/utils/strings/append.cc
@@ -0,0 +1,89 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/strings/append.h"
+
+#include <stdarg.h>
+
+#include <cstring>
+#include <string>
+#include <vector>
+
+namespace libtextclassifier3 {
+namespace strings {
+
// Appends vsnprintf-formatted output to *strp.
// If `bufsize` > 0, at most `bufsize` characters are appended (output is
// truncated to the hint). If `bufsize` <= 0, the required size is computed
// first (formatting twice). Consumes `arglist` (calls va_end on it).
void SStringAppendV(std::string *strp, int bufsize, const char *fmt,
                    va_list arglist) {
  int capacity = bufsize;
  if (capacity <= 0) {
    // Probe for the required size using a copy of the argument list; the
    // original list is still needed for the real formatting pass below.
    // (Previously va_end was mistakenly called on `arglist` here, ending the
    // list before its second use and never ending `backup`.)
    va_list backup;
    va_copy(backup, arglist);
    capacity = vsnprintf(nullptr, 0, fmt, backup);
    va_end(backup);
    if (capacity < 0) {
      // Formatting error; leave *strp unchanged.
      va_end(arglist);
      return;
    }
  }

  size_t start = strp->size();
  strp->resize(strp->size() + capacity + 1);

  int written = vsnprintf(&(*strp)[start], capacity + 1, fmt, arglist);
  va_end(arglist);
  if (written < 0) {
    written = 0;  // Formatting error: keep nothing.
  }
  strp->resize(start + std::min(capacity, written));
}
+
+void SStringAppendF(std::string *strp,
+ int bufsize,
+ const char *fmt, ...) {
+ va_list arglist;
+ va_start(arglist, fmt);
+ SStringAppendV(strp, bufsize, fmt, arglist);
+}
+
+std::string StringPrintf(const char* fmt, ...) {
+ std::string s;
+ va_list arglist;
+ va_start(arglist, fmt);
+ SStringAppendV(&s, 0, fmt, arglist);
+ return s;
+}
+
// Concatenates the strings in `vec`, separated by `delim`.
// Returns the empty string for an empty vector.
std::string JoinStrings(const char *delim,
                        const std::vector<std::string> &vec) {
  std::string joined;
  if (vec.empty()) {
    return joined;
  }

  // Pre-compute the final size: all pieces plus one delimiter between each
  // adjacent pair.
  const size_t delim_len = strlen(delim);
  size_t total_len = delim_len * (vec.size() - 1);
  for (const std::string &piece : vec) {
    total_len += piece.size();
  }
  joined.reserve(total_len);

  // Emit delimiter-before-piece (skipping the first) so no trailing
  // delimiter needs stripping afterwards.
  for (size_t i = 0; i < vec.size(); ++i) {
    if (i > 0) {
      joined.append(delim, delim_len);
    }
    joined.append(vec[i]);
  }
  return joined;
}
+
+} // namespace strings
+} // namespace libtextclassifier3
diff --git a/native/utils/strings/append.h b/native/utils/strings/append.h
new file mode 100644
index 0000000..4b4d0b0
--- /dev/null
+++ b/native/utils/strings/append.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_APPEND_H_
+#define LIBTEXTCLASSIFIER_UTILS_STRINGS_APPEND_H_
+
+#include <string>
+#include <vector>
+
+namespace libtextclassifier3 {
+namespace strings {
+
+// Append vsnprintf to strp. If bufsize hint is > 0 it is
+// used. Otherwise we compute the required bufsize (which is somewhat
+// expensive).
+void SStringAppendV(std::string *strp, int bufsize, const char *fmt,
+ va_list arglist);
+
+void SStringAppendF(std::string *strp, int bufsize, const char *fmt, ...)
+ __attribute__((format(printf, 3, 4)));
+
+std::string StringPrintf(const char *fmt, ...)
+ __attribute__((format(printf, 1, 2)));
+
+std::string JoinStrings(const char *delim, const std::vector<std::string> &vec);
+
+} // namespace strings
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_APPEND_H_
diff --git a/native/utils/strings/append_test.cc b/native/utils/strings/append_test.cc
new file mode 100644
index 0000000..8950761
--- /dev/null
+++ b/native/utils/strings/append_test.cc
@@ -0,0 +1,60 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/strings/append.h"
+
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace strings {
+
TEST(StringUtilTest, SStringAppendF) {
  std::string str;
  SStringAppendF(&str, 5, "%d %d", 0, 1);
  EXPECT_EQ(str, "0 1");

  // Appends to the existing contents rather than replacing them.
  SStringAppendF(&str, 1, "%d", 9);
  EXPECT_EQ(str, "0 19");

  // The positive bufsize hint is a hard cap: "10" is truncated to "1".
  SStringAppendF(&str, 1, "%d", 10);
  EXPECT_EQ(str, "0 191");

  str.clear();

  SStringAppendF(&str, 5, "%d", 100);
  EXPECT_EQ(str, "100");
}
+
TEST(StringUtilTest, SStringAppendFBufCalc) {
  // A bufsize of 0 makes the function compute the required size itself.
  std::string str;
  SStringAppendF(&str, 0, "%d %s %d", 1, "hello", 2);
  EXPECT_EQ(str, "1 hello 2");
}
+
TEST(StringUtilTest, JoinStrings) {
  std::vector<std::string> vec;
  vec.push_back("1");
  vec.push_back("2");
  vec.push_back("3");

  EXPECT_EQ("1,2,3", JoinStrings(",", vec));
  // An empty delimiter simply concatenates.
  EXPECT_EQ("123", JoinStrings("", vec));
  EXPECT_EQ("1, 2, 3", JoinStrings(", ", vec));
  // Joining no elements yields the empty string.
  EXPECT_EQ("", JoinStrings(",", std::vector<std::string>()));
}
+
+} // namespace strings
+} // namespace libtextclassifier3
diff --git a/native/utils/strings/numbers.cc b/native/utils/strings/numbers.cc
new file mode 100644
index 0000000..39ff1fd
--- /dev/null
+++ b/native/utils/strings/numbers.cc
@@ -0,0 +1,81 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/strings/numbers.h"
+
+#ifdef COMPILER_MSVC
+#include <sstream>
+#endif // COMPILER_MSVC
+
+#include <stdlib.h>
+
+namespace libtextclassifier3 {
+
+// This conversion is only valid for numerical base is 10 (radix)
+bool ParseInt32(const char *c_str, int32 *value) {
+ char *temp;
+ // Short version of man strtol:
+ //
+ // strtol parses some optional whitespaces, an optional +/- sign, and next a
+ // succession of digits. If it finds some digits, it sets temp to point to
+ // the first character after that succession of digits and returns the parsed
+ // integer.
+ //
+ // If there were no digits at all, strtol() sets temp to be c_str (the start
+ // address) and returns 0.
+ // Explicitly setting this to base 10 as 0 means the base used is determined
+ // by the format which can cause problems
+ *value = strtol(c_str, &temp, 10); // NOLINT
+
+ // temp != c_str means that the input string contained at least one digit (see
+ // above). *temp == '\0' means the input string does not contain any random
+ // chars after the number.
+ return (temp != c_str) && (*temp == '\0');
+}
+
+// This conversion is only valid for numerical base is 10 (radix)
+bool ParseInt64(const char *c_str, int64 *value) {
+ char *temp;
+
+ // Explicitly setting this to base 10 as 0 means the base used is determined
+ // by the format which can cause problems
+ *value = strtoll(c_str, &temp, 10); // NOLINT
+
+ // See comments inside ParseInt32.
+ return (temp != c_str) && (*temp == '\0');
+}
+
// Parses a double from `c_str` into *value.
// Returns true iff the whole string was a valid floating-point literal.
bool ParseDouble(const char *c_str, double *value) {
  char *parse_end = nullptr;
  *value = strtod(c_str, &parse_end);
  // At least one character consumed and nothing trailing.
  return parse_end != c_str && *parse_end == '\0';
}
+
// Converts a 64-bit integer to its decimal string representation.
#ifdef COMPILER_MSVC
// MSVC build: format via a stringstream.
std::string IntToString(int64 input) {
  std::stringstream stream;
  stream << input;
  return stream.str();
}
#else
// All other compilers: delegate to std::to_string.
std::string IntToString(int64 input) {
  return std::to_string(input);
}
#endif // COMPILER_MSVC
+
+} // namespace libtextclassifier3
diff --git a/utils/strings/numbers.h b/native/utils/strings/numbers.h
similarity index 100%
rename from utils/strings/numbers.h
rename to native/utils/strings/numbers.h
diff --git a/native/utils/strings/numbers_test.cc b/native/utils/strings/numbers_test.cc
new file mode 100644
index 0000000..bf2f84a
--- /dev/null
+++ b/native/utils/strings/numbers_test.cc
@@ -0,0 +1,109 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/strings/numbers.h"
+
+#include "utils/base/integral_types.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
// Runs ParseInt32 on `c_str` and checks both the success flag and, when
// success is expected, the parsed value.
void TestParseInt32(const char *c_str, bool expected_parsing_success,
                    int32 expected_parsed_value = 0) {
  int32 parsed_value = 0;
  EXPECT_EQ(expected_parsing_success, ParseInt32(c_str, &parsed_value));
  if (expected_parsing_success) {
    EXPECT_EQ(expected_parsed_value, parsed_value);
  }
}

// Optional leading whitespace and sign are accepted; leading zeros are
// decimal, not octal ("08"/"09" must parse, since the base is fixed to 10).
TEST(ParseInt32Test, Normal) {
  TestParseInt32("2", true, 2);
  TestParseInt32("-357", true, -357);
  TestParseInt32("7", true, 7);
  TestParseInt32("+7", true, 7);
  TestParseInt32(" +7", true, 7);
  TestParseInt32("-23", true, -23);
  TestParseInt32(" -23", true, -23);
  TestParseInt32("04", true, 4);
  TestParseInt32("07", true, 7);
  TestParseInt32("08", true, 8);
  TestParseInt32("09", true, 9);
}

// Empty, whitespace-only, non-numeric, and trailing-garbage inputs all fail.
TEST(ParseInt32Test, ErrorCases) {
  TestParseInt32("", false);
  TestParseInt32(" ", false);
  TestParseInt32("not-a-number", false);
  TestParseInt32("123a", false);
}
+
// Runs ParseInt64 on `c_str` and checks both the success flag and, when
// success is expected, the parsed value.
void TestParseInt64(const char *c_str, bool expected_parsing_success,
                    int64 expected_parsed_value = 0) {
  int64 parsed_value = 0;
  EXPECT_EQ(expected_parsing_success, ParseInt64(c_str, &parsed_value));
  if (expected_parsing_success) {
    EXPECT_EQ(expected_parsed_value, parsed_value);
  }
}

// Same acceptance rules as ParseInt32: optional whitespace/sign, decimal
// interpretation of leading zeros.
TEST(ParseInt64Test, Normal) {
  TestParseInt64("2", true, 2);
  TestParseInt64("-357", true, -357);
  TestParseInt64("7", true, 7);
  TestParseInt64("+7", true, 7);
  TestParseInt64(" +7", true, 7);
  TestParseInt64("-23", true, -23);
  TestParseInt64(" -23", true, -23);
  TestParseInt64("07", true, 7);
  TestParseInt64("08", true, 8);
}

// Invalid and partially-numeric inputs must fail.
TEST(ParseInt64Test, ErrorCases) {
  TestParseInt64("", false);
  TestParseInt64(" ", false);
  TestParseInt64("not-a-number", false);
  TestParseInt64("23z", false);
}
+
// Runs ParseDouble on `c_str` and checks the success flag; on expected
// success the value is compared with a small tolerance since the inputs are
// decimal fractions.
void TestParseDouble(const char *c_str, bool expected_parsing_success,
                     double expected_parsed_value = 0.0) {
  double parsed_value = 0.0;
  EXPECT_EQ(expected_parsing_success, ParseDouble(c_str, &parsed_value));
  if (expected_parsing_success) {
    EXPECT_NEAR(expected_parsed_value, parsed_value, 0.00001);
  }
}

// Optional leading whitespace and sign are accepted.
TEST(ParseDoubleTest, Normal) {
  TestParseDouble("2", true, 2.0);
  TestParseDouble("-357.023", true, -357.023);
  TestParseDouble("7.04", true, 7.04);
  TestParseDouble("+7.2", true, 7.2);
  TestParseDouble(" +7.236", true, 7.236);
  TestParseDouble("-23.4", true, -23.4);
  TestParseDouble(" -23.4", true, -23.4);
}

// Invalid and partially-numeric inputs must fail.
TEST(ParseDoubleTest, ErrorCases) {
  TestParseDouble("", false);
  TestParseDouble(" ", false);
  TestParseDouble("not-a-number", false);
  TestParseDouble("23.5a", false);
}
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/strings/split.cc b/native/utils/strings/split.cc
similarity index 100%
rename from utils/strings/split.cc
rename to native/utils/strings/split.cc
diff --git a/native/utils/strings/split.h b/native/utils/strings/split.h
new file mode 100644
index 0000000..98d066e
--- /dev/null
+++ b/native/utils/strings/split.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_SPLIT_H_
+#define LIBTEXTCLASSIFIER_UTILS_STRINGS_SPLIT_H_
+
+#include <string>
+#include <vector>
+
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace strings {
+
+std::vector<StringPiece> Split(const StringPiece &text, char delim);
+// Delete overload that takes r-value string, to avoid common pitfalls like:
+// Split(GetSomeTransientString())
+std::vector<StringPiece> Split(const std::string &&text, char delim) = delete;
+
+} // namespace strings
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_SPLIT_H_
diff --git a/native/utils/strings/stringpiece.h b/native/utils/strings/stringpiece.h
new file mode 100644
index 0000000..6160ead
--- /dev/null
+++ b/native/utils/strings/stringpiece.h
@@ -0,0 +1,179 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_STRINGPIECE_H_
+#define LIBTEXTCLASSIFIER_UTILS_STRINGS_STRINGPIECE_H_
+
+#include <cstddef>
+#include <cstring>
+#include <string>
+
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+// Read-only "view" of a piece of data. Does not own the underlying data.
+// Read-only "view" of a piece of data. Does not own the underlying data.
+// Cheap to copy; the viewed bytes must outlive every StringPiece referring
+// to them.
+class StringPiece {
+ public:
+  static constexpr size_t npos = static_cast<size_t>(-1);
+
+  StringPiece() : StringPiece(nullptr, 0) {}
+
+  // NOTE(review): strlen/memcmp/memchr are used in this header but <cstring>
+  // is not included here; presumably it arrives transitively -- confirm.
+  StringPiece(const char* str)  // NOLINT(runtime/explicit)
+      : start_(str), size_(str == nullptr ? 0 : strlen(str)) {}
+
+  StringPiece(const char* start, size_t size) : start_(start), size_(size) {}
+
+  // Intentionally no "explicit" keyword: in function calls, we want strings to
+  // be converted to StringPiece implicitly.
+  StringPiece(const std::string& s)  // NOLINT(runtime/explicit)
+      : StringPiece(s.data(), s.size()) {}
+
+  // View of `len` bytes of `s` starting at `offset`. No bounds checking.
+  StringPiece(const std::string& s, int offset, int len)
+      : StringPiece(s.data() + offset, len) {}
+
+  // Unchecked byte access; `i` must be < size().
+  char operator[](size_t i) const { return start_[i]; }
+
+  // Returns start address of underlying data.
+  const char* data() const { return start_; }
+
+  // Returns number of bytes of underlying data.
+  size_t size() const { return size_; }
+  size_t length() const { return size_; }
+
+  bool empty() const { return size_ == 0; }
+
+  // Returns a std::string containing a copy of the underlying data.
+  std::string ToString() const { return std::string(data(), size()); }
+
+  // Returns whether string ends with a given suffix. An empty suffix always
+  // matches.
+  bool EndsWith(StringPiece suffix) const {
+    return suffix.empty() || (size_ >= suffix.size() &&
+                              memcmp(start_ + (size_ - suffix.size()),
+                                     suffix.data(), suffix.size()) == 0);
+  }
+
+  // Returns whether the string begins with a given prefix. An empty prefix
+  // always matches.
+  bool StartsWith(StringPiece prefix) const {
+    return prefix.empty() ||
+           (size_ >= prefix.size() &&
+            memcmp(start_, prefix.data(), prefix.size()) == 0);
+  }
+
+  // Byte-wise equality of the two views.
+  bool Equals(StringPiece other) const {
+    return size() == other.size() && memcmp(start_, other.data(), size_) == 0;
+  }
+
+  // Removes the first `n` characters from the string piece. Note that the
+  // underlying string is not changed, only the view.
+  // NOTE(review): `n` is signed and is checked against the unsigned size_;
+  // a negative `n` would be a caller bug -- confirm callers never pass one.
+  void RemovePrefix(int n) {
+    TC3_CHECK_LE(n, size_);
+    start_ += n;
+    size_ -= n;
+  }
+
+  // Removes the last `n` characters from the string piece. Note that the
+  // underlying string is not changed, only the view.
+  void RemoveSuffix(int n) {
+    TC3_CHECK_LE(n, size_);
+    size_ -= n;
+  }
+
+  // Finds the first occurrence of the character `c` within the `StringPiece`,
+  // returning the position of the first match, or `npos` if no match was
+  // found.
+  // Here
+  //   - c is the char to search for in the StringPiece
+  //   - pos is the position at which to start the search.
+  size_t find(char c, size_t pos = 0) const noexcept {
+    if (empty() || pos >= size_) {
+      return npos;
+    }
+    const char* result =
+        static_cast<const char*>(memchr(start_ + pos, c, size_ - pos));
+    return result != nullptr ? result - start_ : npos;
+  }
+
+  // Substring search: returns the position of the first occurrence of `s` at
+  // or after `pos`, or npos. Special case: an empty needle on an empty view
+  // (with pos == 0) matches at position 0.
+  size_t find(StringPiece s, size_t pos = 0) const noexcept {
+    if (empty() || pos >= size_) {
+      if (empty() && pos == 0 && s.empty()) {
+        return 0;
+      }
+      return npos;
+    }
+    const char* result = memmatch(start_ + pos, size_ - pos, s.start_, s.size_);
+    return result ? result - start_ : npos;
+  }
+
+ private:
+  // Naive substring search over raw memory: memchr for the first needle byte,
+  // then memcmp the full needle at that candidate; repeat past failed
+  // candidates. Returns a pointer to the match or nullptr. Worst case
+  // O(haylen * neelen).
+  const char* memmatch(const char* phaystack, size_t haylen,
+                       const char* pneedle, size_t neelen) const {
+    if (0 == neelen) {
+      return phaystack;  // Even if haylen is 0.
+    }
+    if (haylen < neelen) {
+      return nullptr;
+    }
+
+    const char* match;
+    // Last position at which the needle could still start.
+    const char* hayend = phaystack + haylen - neelen + 1;
+    while ((match = static_cast<const char*>(
+                memchr(phaystack, pneedle[0], hayend - phaystack)))) {
+      if (memcmp(match, pneedle, neelen) == 0) {
+        return match;
+      } else {
+        phaystack = match + 1;
+      }
+    }
+    return nullptr;
+  }
+
+  const char* start_;  // Not owned.
+  size_t size_;
+};
+
+// Free-function convenience wrapper around StringPiece::EndsWith; lets
+// callers rely on implicit conversions to StringPiece for both arguments.
+inline bool EndsWith(StringPiece text, StringPiece suffix) {
+  return text.EndsWith(suffix);
+}
+
+// Free-function convenience wrapper around StringPiece::StartsWith; lets
+// callers rely on implicit conversions to StringPiece for both arguments.
+inline bool StartsWith(StringPiece text, StringPiece prefix) {
+  return text.StartsWith(prefix);
+}
+
+inline bool ConsumePrefix(StringPiece* text, StringPiece prefix) {
+ if (!text->StartsWith(prefix)) {
+ return false;
+ }
+ text->RemovePrefix(prefix.size());
+ return true;
+}
+
+inline bool ConsumeSuffix(StringPiece* text, StringPiece suffix) {
+ if (!text->EndsWith(suffix)) {
+ return false;
+ }
+ text->RemoveSuffix(suffix.size());
+ return true;
+}
+
+// Streams the viewed bytes into a TC3 logging stream by appending them to
+// the stream's message buffer.
+inline logging::LoggingStringStream& operator<<(
+    logging::LoggingStringStream& stream, StringPiece message) {
+  stream.message.append(message.data(), message.size());
+  return stream;
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_STRINGPIECE_H_
diff --git a/native/utils/strings/stringpiece_test.cc b/native/utils/strings/stringpiece_test.cc
new file mode 100644
index 0000000..64808d3
--- /dev/null
+++ b/native/utils/strings/stringpiece_test.cc
@@ -0,0 +1,95 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Suffix matching, including the always-matching empty suffix.
+TEST(StringPieceTest, EndsWith) {
+  EXPECT_TRUE(EndsWith("hello there!", "there!"));
+  EXPECT_TRUE(EndsWith("hello there!", "!"));
+  EXPECT_FALSE(EndsWith("hello there!", "there"));
+  EXPECT_FALSE(EndsWith("hello there!", " hello there!"));
+  EXPECT_TRUE(EndsWith("hello there!", ""));
+  EXPECT_FALSE(EndsWith("", "hello there!"));
+}
+
+// Prefix matching, including the always-matching empty prefix.
+TEST(StringPieceTest, StartsWith) {
+  EXPECT_TRUE(StartsWith("hello there!", "hello"));
+  EXPECT_TRUE(StartsWith("hello there!", "hello "));
+  EXPECT_FALSE(StartsWith("hello there!", "there!"));
+  EXPECT_FALSE(StartsWith("hello there!", " hello there! "));
+  EXPECT_TRUE(StartsWith("hello there!", ""));
+  EXPECT_FALSE(StartsWith("", "hello there!"));
+}
+
+// ConsumePrefix mutates the view only on success; empty prefixes always
+// succeed, even on an already-empty view.
+TEST(StringPieceTest, ConsumePrefix) {
+  StringPiece str("hello there!");
+  EXPECT_TRUE(ConsumePrefix(&str, "hello "));
+  EXPECT_EQ(str.ToString(), "there!");
+  EXPECT_TRUE(ConsumePrefix(&str, "there"));
+  EXPECT_EQ(str.ToString(), "!");
+  EXPECT_FALSE(ConsumePrefix(&str, "!!"));
+  EXPECT_TRUE(ConsumePrefix(&str, ""));
+  EXPECT_TRUE(ConsumePrefix(&str, "!"));
+  EXPECT_EQ(str.ToString(), "");
+  EXPECT_TRUE(ConsumePrefix(&str, ""));
+  EXPECT_FALSE(ConsumePrefix(&str, "!"));
+}
+
+// ConsumeSuffix mutates the view only on success; empty suffixes always
+// succeed, even on an already-empty view.
+TEST(StringPieceTest, ConsumeSuffix) {
+  StringPiece str("hello there!");
+  EXPECT_TRUE(ConsumeSuffix(&str, "!"));
+  EXPECT_EQ(str.ToString(), "hello there");
+  EXPECT_TRUE(ConsumeSuffix(&str, " there"));
+  EXPECT_EQ(str.ToString(), "hello");
+  EXPECT_FALSE(ConsumeSuffix(&str, "!!"));
+  EXPECT_TRUE(ConsumeSuffix(&str, ""));
+  EXPECT_TRUE(ConsumeSuffix(&str, "hello"));
+  EXPECT_EQ(str.ToString(), "");
+  EXPECT_TRUE(ConsumeSuffix(&str, ""));
+  EXPECT_FALSE(ConsumeSuffix(&str, "!"));
+}
+
+// Single-character find, with and without a starting position.
+TEST(StringPieceTest, Find) {
+  StringPiece str("<hello there!>");
+  EXPECT_EQ(str.find('<'), 0);
+  EXPECT_EQ(str.find('>'), str.length() - 1);
+  EXPECT_EQ(str.find('?'), StringPiece::npos);
+  EXPECT_EQ(str.find('<', str.length() - 1), StringPiece::npos);
+  EXPECT_EQ(str.find('<', 0), 0);
+  EXPECT_EQ(str.find('>', str.length() - 1), str.length() - 1);
+}
+
+// Substring find, with and without a starting position.
+TEST(StringPieceTest, FindStringPiece) {
+  StringPiece str("<foo bar baz!>");
+  EXPECT_EQ(str.find("foo"), 1);
+  EXPECT_EQ(str.find("bar"), 5);
+  EXPECT_EQ(str.find("baz"), 9);
+  EXPECT_EQ(str.find("qux"), StringPiece::npos);
+  EXPECT_EQ(str.find("?"), StringPiece::npos);
+  EXPECT_EQ(str.find(">"), str.length() - 1);
+  EXPECT_EQ(str.find("<", str.length() - 1), StringPiece::npos);
+  EXPECT_EQ(str.find("<", 0), 0);
+  EXPECT_EQ(str.find(">", str.length() - 1), str.length() - 1);
+}
+
+}  // namespace
+}  // namespace libtextclassifier3
diff --git a/utils/strings/substitute.cc b/native/utils/strings/substitute.cc
similarity index 100%
rename from utils/strings/substitute.cc
rename to native/utils/strings/substitute.cc
diff --git a/utils/strings/substitute.h b/native/utils/strings/substitute.h
similarity index 100%
rename from utils/strings/substitute.h
rename to native/utils/strings/substitute.h
diff --git a/utils/strings/substitute_test.cc b/native/utils/strings/substitute_test.cc
similarity index 100%
rename from utils/strings/substitute_test.cc
rename to native/utils/strings/substitute_test.cc
diff --git a/native/utils/strings/utf8.cc b/native/utils/strings/utf8.cc
new file mode 100644
index 0000000..932e2a5
--- /dev/null
+++ b/native/utils/strings/utf8.cc
@@ -0,0 +1,125 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/strings/utf8.h"
+
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+bool IsValidUTF8(const char *src, int size) {
+ for (int i = 0; i < size;) {
+ const int char_length = ValidUTF8CharLength(src + i, size - i);
+ if (char_length <= 0) {
+ return false;
+ }
+ i += char_length;
+ }
+ return true;
+}
+
+// Returns the byte length of the first codepoint in `src` (limited to `size`
+// bytes), or -1 if the bytes do not start a structurally valid UTF-8
+// character.
+// NOTE(review): validation is structural only -- a non-trail lead byte
+// followed by the right number of trail bytes. Overlong encodings (e.g.
+// C0 80) and lead bytes >= F5 are not rejected; confirm callers do not
+// expect stricter Unicode well-formedness.
+int ValidUTF8CharLength(const char *src, int size) {
+  // Unexpected trail byte.
+  if (IsTrailByte(src[0])) {
+    return -1;
+  }
+
+  const int num_codepoint_bytes = GetNumBytesForUTF8Char(&src[0]);
+  if (num_codepoint_bytes <= 0 || num_codepoint_bytes > size) {
+    return -1;
+  }
+
+  // Check that remaining bytes in the codepoint are trailing bytes.
+  for (int k = 1; k < num_codepoint_bytes; k++) {
+    if (!IsTrailByte(src[k])) {
+      return -1;
+    }
+  }
+
+  return num_codepoint_bytes;
+}
+
+int SafeTruncateLength(const char *str, int truncate_at) {
+ // Always want to truncate at the start of a character, so if
+ // it's in a middle, back up toward the start
+ while (IsTrailByte(str[truncate_at]) && (truncate_at > 0)) {
+ truncate_at--;
+ }
+ return truncate_at;
+}
+
+// Decodes the UTF-8 character starting at `str` into a codepoint.
+// Precondition (DCHECKed): `str` points at a structurally valid character;
+// up to 4 bytes may be read depending on the lead byte.
+char32 ValidCharToRune(const char *str) {
+  TC3_DCHECK(!IsTrailByte(str[0]) && GetNumBytesForUTF8Char(str) > 0);
+
+  // Convert from UTF-8
+  unsigned char byte1 = static_cast<unsigned char>(str[0]);
+  if (byte1 < 0x80) {
+    // One character sequence: 00000 - 0007F.
+    return byte1;
+  }
+
+  // Multi-byte sequence: a second byte is guaranteed by the precondition.
+  unsigned char byte2 = static_cast<unsigned char>(str[1]);
+  if (byte1 < 0xE0) {
+    // Two character sequence: 00080 - 007FF.
+    return ((byte1 & 0x1F) << 6) | (byte2 & 0x3F);
+  }
+
+  unsigned char byte3 = static_cast<unsigned char>(str[2]);
+  if (byte1 < 0xF0) {
+    // Three character sequence: 00800 - 0FFFF.
+    return ((byte1 & 0x0F) << 12) | ((byte2 & 0x3F) << 6) | (byte3 & 0x3F);
+  }
+
+  unsigned char byte4 = static_cast<unsigned char>(str[3]);
+  // Four character sequence: 10000 - 1FFFF.
+  return ((byte1 & 0x07) << 18) | ((byte2 & 0x3F) << 12) |
+         ((byte3 & 0x3F) << 6) | (byte4 & 0x3F);
+}
+
+// Encodes `rune` as UTF-8 into `dest` (which must have room for 4 bytes)
+// and returns the number of bytes written (1-4).
+// NOTE(review): the rune is assumed to be a valid codepoint; values above
+// 10FFFF are not rejected here -- confirm callers guarantee validity.
+int ValidRuneToChar(const char32 rune, char *dest) {
+  // Convert to unsigned for range check.
+  uint32 c;
+
+  // 1 char 00-7F
+  c = rune;
+  if (c <= 0x7F) {
+    dest[0] = static_cast<char>(c);
+    return 1;
+  }
+
+  // 2 char 0080-07FF
+  if (c <= 0x07FF) {
+    dest[0] = 0xC0 | static_cast<char>(c >> 1 * 6);
+    dest[1] = 0x80 | (c & 0x3F);
+    return 2;
+  }
+
+  // 3 char 0800-FFFF
+  if (c <= 0xFFFF) {
+    dest[0] = 0xE0 | static_cast<char>(c >> 2 * 6);
+    dest[1] = 0x80 | ((c >> 1 * 6) & 0x3F);
+    dest[2] = 0x80 | (c & 0x3F);
+    return 3;
+  }
+
+  // 4 char 10000-1FFFFF
+  dest[0] = 0xF0 | static_cast<char>(c >> 3 * 6);
+  dest[1] = 0x80 | ((c >> 2 * 6) & 0x3F);
+  dest[2] = 0x80 | ((c >> 1 * 6) & 0x3F);
+  dest[3] = 0x80 | (c & 0x3F);
+  return 4;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/strings/utf8.h b/native/utils/strings/utf8.h
new file mode 100644
index 0000000..e871731
--- /dev/null
+++ b/native/utils/strings/utf8.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_UTF8_H_
+#define LIBTEXTCLASSIFIER_UTILS_STRINGS_UTF8_H_
+
+#include "utils/base/integral_types.h"
+
+namespace libtextclassifier3 {
+
+// Returns the length (number of bytes) of the Unicode code point starting at
+// src, based on inspecting just that one byte. Preconditions: src != NULL,
+// *src can be read.
// Returns the length (number of bytes) of the Unicode code point starting at
// src, based on inspecting just that one byte. The high nibble of the lead
// byte selects the length: 0x0-0xB -> 1 (ASCII, or a stray trail byte),
// 0xC/0xD -> 2, 0xE -> 3, 0xF -> 4.
// Preconditions: src != NULL, *src can be read.
static inline int GetNumBytesForUTF8Char(const char *src) {
  // Cast through unsigned char so the value is in [0, 255] on every platform
  // (plain char may be signed, e.g. on iOS).
  static constexpr char kCharLengths[16] = {1, 1, 1, 1, 1, 1, 1, 1,
                                            1, 1, 1, 1, 2, 2, 3, 4};
  return kCharLengths[static_cast<unsigned char>(*src) >> 4];
}
+
+// Returns true if this byte is a trailing UTF-8 byte (10xx xxxx)
// Returns true if this byte is a trailing UTF-8 byte (10xx xxxx).
static inline bool IsTrailByte(char x) {
  // Trail bytes are exactly the values in [0x80, 0xBF]: top two bits == 10.
  return (static_cast<unsigned char>(x) & 0xC0) == 0x80;
}
+
+// Returns true iff src points to a well-formed UTF-8 string.
+bool IsValidUTF8(const char *src, int size);
+
+// Returns byte length of the first valid codepoint in the string, otherwise -1
+// if pointing to an ill-formed UTF-8 character.
+int ValidUTF8CharLength(const char *src, int size);
+
+// Helper to ensure that strings are not truncated in the middle of
+// multi-byte UTF-8 characters.
+// Given a string, and a position at which to truncate, returns the
+// last position not after the provided cut point, that would truncate a
+// full character.
+int SafeTruncateLength(const char *str, int truncate_at);
+
+// Gets a unicode codepoint from a valid utf8 encoding.
+char32 ValidCharToRune(const char *str);
+
+// Converts a valid codepoint to utf8.
+// Returns the length of the encoding.
+int ValidRuneToChar(const char32 rune, char *dest);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_UTF8_H_
diff --git a/native/utils/strings/utf8_test.cc b/native/utils/strings/utf8_test.cc
new file mode 100644
index 0000000..28d971b
--- /dev/null
+++ b/native/utils/strings/utf8_test.cc
@@ -0,0 +1,79 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/strings/utf8.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Lead-byte length classification for 1- to 4-byte characters.
+TEST(Utf8Test, ComputesUtf8LengthOfUnicodeCharacters) {
+  EXPECT_EQ(GetNumBytesForUTF8Char("\x00"), 1);
+  EXPECT_EQ(GetNumBytesForUTF8Char("h"), 1);
+  EXPECT_EQ(GetNumBytesForUTF8Char("😋"), 4);
+  EXPECT_EQ(GetNumBytesForUTF8Char("㍿"), 3);
+}
+
+// Structural validation of whole byte sequences, including truncated and
+// over-long trail-byte runs.
+TEST(Utf8Test, IsValidUTF8) {
+  EXPECT_TRUE(IsValidUTF8("1234😋hello", 13));
+  EXPECT_TRUE(IsValidUTF8("\u304A\u00B0\u106B", 8));
+  EXPECT_TRUE(IsValidUTF8("this is a test😋😋😋", 26));
+  EXPECT_TRUE(IsValidUTF8("\xf0\x9f\x98\x8b", 4));
+  // Too short (string is too short).
+  EXPECT_FALSE(IsValidUTF8("\xf0\x9f", 2));
+  // Too long (too many trailing bytes).
+  EXPECT_FALSE(IsValidUTF8("\xf0\x9f\x98\x8b\x8b", 5));
+  // Too short (too few trailing bytes).
+  EXPECT_FALSE(IsValidUTF8("\xf0\x9f\x98\x61\x61", 5));
+}
+
+// Length of only the FIRST character; -1 for ill-formed starts.
+TEST(Utf8Test, ValidUTF8CharLength) {
+  EXPECT_EQ(ValidUTF8CharLength("1234😋hello", 13), 1);
+  EXPECT_EQ(ValidUTF8CharLength("\u304A\u00B0\u106B", 8), 3);
+  EXPECT_EQ(ValidUTF8CharLength("this is a test😋😋😋", 26), 1);
+  EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x8b", 4), 4);
+  // Too short (string is too short).
+  EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f", 2), -1);
+  // Too long (too many trailing bytes). First character is valid.
+  EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x8b\x8b", 5), 4);
+  // Too short (too few trailing bytes).
+  EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x61\x61", 5), -1);
+}
+
+// Truncation backs up over trail bytes to a character boundary.
+TEST(Utf8Test, CorrectlyTruncatesStrings) {
+  EXPECT_EQ(SafeTruncateLength("FooBar", 3), 3);
+  EXPECT_EQ(SafeTruncateLength("früh", 3), 2);
+  EXPECT_EQ(SafeTruncateLength("مَمِمّمَّمِّ", 5), 4);
+}
+
+// UTF-8 byte sequences decode to the expected codepoints.
+TEST(Utf8Test, CorrectlyConvertsFromUtf8) {
+  EXPECT_EQ(ValidCharToRune("a"), 97);
+  EXPECT_EQ(ValidCharToRune("\0"), 0);
+  EXPECT_EQ(ValidCharToRune("\u304A"), 0x304a);
+  EXPECT_EQ(ValidCharToRune("\xe3\x81\x8a"), 0x304a);
+}
+
+// Codepoints encode to the expected number of UTF-8 bytes.
+TEST(Utf8Test, CorrectlyConvertsToUtf8) {
+  char utf8_encoding[4];
+  EXPECT_EQ(ValidRuneToChar(97, utf8_encoding), 1);
+  EXPECT_EQ(ValidRuneToChar(0, utf8_encoding), 1);
+  EXPECT_EQ(ValidRuneToChar(0x304a, utf8_encoding), 3);
+}
+
+}  // namespace
+}  // namespace libtextclassifier3
diff --git a/utils/tensor-view.cc b/native/utils/tensor-view.cc
similarity index 100%
rename from utils/tensor-view.cc
rename to native/utils/tensor-view.cc
diff --git a/utils/tensor-view.h b/native/utils/tensor-view.h
similarity index 100%
rename from utils/tensor-view.h
rename to native/utils/tensor-view.h
diff --git a/utils/tensor-view_test.cc b/native/utils/tensor-view_test.cc
similarity index 100%
rename from utils/tensor-view_test.cc
rename to native/utils/tensor-view_test.cc
diff --git a/native/utils/test-utils.cc b/native/utils/test-utils.cc
new file mode 100644
index 0000000..8996a4a
--- /dev/null
+++ b/native/utils/test-utils.cc
@@ -0,0 +1,68 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/test-utils.h"
+
+#include <iterator>
+
+#include "utils/codepoint-range.h"
+#include "utils/strings/utf8.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+using libtextclassifier3::Token;
+
+// Convenience wrapper: tokenizes using the single delimiter ' '.
+std::vector<Token> TokenizeOnSpace(const std::string& text) {
+  return TokenizeOnDelimiters(text, {' '});
+}
+
+// Splits `text` into tokens at every codepoint contained in `delimiters`.
+// Delimiters are dropped and empty tokens are skipped. Token start/end
+// offsets are in codepoints (not bytes) into `text`.
+std::vector<Token> TokenizeOnDelimiters(
+    const std::string& text, const std::unordered_set<char32>& delimiters) {
+  // View onto `text`; no copy is made, so `text` must stay alive throughout.
+  const UnicodeText unicode_text = UTF8ToUnicodeText(text, /*do_copy=*/false);
+
+  std::vector<Token> result;
+
+  int token_start_codepoint = 0;
+  auto token_start_it = unicode_text.begin();
+  int codepoint_idx = 0;
+
+  UnicodeText::const_iterator it;
+  for (it = unicode_text.begin(); it < unicode_text.end(); it++) {
+    if (delimiters.find(*it) != delimiters.end()) {
+      // Only add a token when the string is non-empty.
+      if (token_start_it != it) {
+        result.push_back(Token{UnicodeText::UTF8Substring(token_start_it, it),
+                               token_start_codepoint, codepoint_idx});
+      }
+
+      // The next token begins immediately after this delimiter.
+      token_start_codepoint = codepoint_idx + 1;
+      token_start_it = it;
+      token_start_it++;
+    }
+
+    codepoint_idx++;
+  }
+  // Flush the trailing token, if any.
+  // Only add a token when the string is non-empty.
+  if (token_start_it != it) {
+    result.push_back(Token{UnicodeText::UTF8Substring(token_start_it, it),
+                           token_start_codepoint, codepoint_idx});
+  }
+
+  return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/test-utils.h b/native/utils/test-utils.h
new file mode 100644
index 0000000..0e75190
--- /dev/null
+++ b/native/utils/test-utils.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utilities for tests.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
+
+#include <string>
+
+#include "annotator/types.h"
+
+namespace libtextclassifier3 {
+
+// Returns a list of Tokens for a given input string, by tokenizing on space.
+std::vector<Token> TokenizeOnSpace(const std::string& text);
+
+// Returns a list of Tokens for a given input string, by tokenizing on the
+// given set of delimiter codepoints.
+std::vector<Token> TokenizeOnDelimiters(
+ const std::string& text, const std::unordered_set<char32>& delimiters);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
diff --git a/native/utils/test-utils_test.cc b/native/utils/test-utils_test.cc
new file mode 100644
index 0000000..bdaa285
--- /dev/null
+++ b/native/utils/test-utils_test.cc
@@ -0,0 +1,100 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/test-utils.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Space tokenization: start/end are codepoint offsets, so the multi-byte
+// characters in "Jörg"/"Zürich" still count as one position each.
+TEST(TestUtilTest, TokenizeOnSpace) {
+  std::vector<Token> tokens =
+      TokenizeOnSpace("Where is Jörg Borg located? Maybe in Zürich ...");
+
+  EXPECT_EQ(tokens.size(), 9);
+
+  EXPECT_EQ(tokens[0].value, "Where");
+  EXPECT_EQ(tokens[0].start, 0);
+  EXPECT_EQ(tokens[0].end, 5);
+
+  EXPECT_EQ(tokens[1].value, "is");
+  EXPECT_EQ(tokens[1].start, 6);
+  EXPECT_EQ(tokens[1].end, 8);
+
+  EXPECT_EQ(tokens[2].value, "Jörg");
+  EXPECT_EQ(tokens[2].start, 9);
+  EXPECT_EQ(tokens[2].end, 13);
+
+  EXPECT_EQ(tokens[3].value, "Borg");
+  EXPECT_EQ(tokens[3].start, 14);
+  EXPECT_EQ(tokens[3].end, 18);
+
+  EXPECT_EQ(tokens[4].value, "located?");
+  EXPECT_EQ(tokens[4].start, 19);
+  EXPECT_EQ(tokens[4].end, 27);
+
+  EXPECT_EQ(tokens[5].value, "Maybe");
+  EXPECT_EQ(tokens[5].start, 28);
+  EXPECT_EQ(tokens[5].end, 33);
+
+  EXPECT_EQ(tokens[6].value, "in");
+  EXPECT_EQ(tokens[6].start, 34);
+  EXPECT_EQ(tokens[6].end, 36);
+
+  EXPECT_EQ(tokens[7].value, "Zürich");
+  EXPECT_EQ(tokens[7].start, 37);
+  EXPECT_EQ(tokens[7].end, 43);
+
+  EXPECT_EQ(tokens[8].value, "...");
+  EXPECT_EQ(tokens[8].start, 44);
+  EXPECT_EQ(tokens[8].end, 47);
+}
+
+// Multiple delimiters: runs of adjacent delimiters ("?!:"-adjacent splits)
+// produce no empty tokens.
+TEST(TestUtilTest, TokenizeOnDelimiters) {
+  std::vector<Token> tokens = TokenizeOnDelimiters(
+      "This might be čomplíčateď?!: Oder?", {' ', '?', '!'});
+
+  EXPECT_EQ(tokens.size(), 6);
+
+  EXPECT_EQ(tokens[0].value, "This");
+  EXPECT_EQ(tokens[0].start, 0);
+  EXPECT_EQ(tokens[0].end, 4);
+
+  EXPECT_EQ(tokens[1].value, "might");
+  EXPECT_EQ(tokens[1].start, 7);
+  EXPECT_EQ(tokens[1].end, 12);
+
+  EXPECT_EQ(tokens[2].value, "be");
+  EXPECT_EQ(tokens[2].start, 13);
+  EXPECT_EQ(tokens[2].end, 15);
+
+  EXPECT_EQ(tokens[3].value, "čomplíčateď");
+  EXPECT_EQ(tokens[3].start, 16);
+  EXPECT_EQ(tokens[3].end, 27);
+
+  EXPECT_EQ(tokens[4].value, ":");
+  EXPECT_EQ(tokens[4].start, 29);
+  EXPECT_EQ(tokens[4].end, 30);
+
+  EXPECT_EQ(tokens[5].value, "Oder");
+  EXPECT_EQ(tokens[5].start, 31);
+  EXPECT_EQ(tokens[5].end, 35);
+}
+
+}  // namespace
+}  // namespace libtextclassifier3
diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc
new file mode 100644
index 0000000..55faea5
--- /dev/null
+++ b/native/utils/tflite-model-executor.cc
@@ -0,0 +1,283 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite-model-executor.h"
+
+#include "utils/base/logging.h"
+#include "tensorflow/lite/kernels/register.h"
+
+// Forward declaration of custom TensorFlow Lite ops for registration.
+namespace tflite {
+namespace ops {
+namespace builtin {
+TfLiteRegistration* Register_ADD();
+TfLiteRegistration* Register_CONCATENATION();
+TfLiteRegistration* Register_CONV_2D();
+TfLiteRegistration* Register_EQUAL();
+TfLiteRegistration* Register_FULLY_CONNECTED();
+TfLiteRegistration* Register_GREATER_EQUAL();
+TfLiteRegistration* Register_L2_NORMALIZATION();
+TfLiteRegistration* Register_MUL();
+TfLiteRegistration* Register_RESHAPE();
+TfLiteRegistration* Register_REDUCE_MAX();
+TfLiteRegistration* Register_REDUCE_ANY();
+TfLiteRegistration* Register_SOFTMAX();
+TfLiteRegistration* Register_GATHER();
+TfLiteRegistration* Register_TRANSPOSE();
+TfLiteRegistration* Register_SUB();
+TfLiteRegistration* Register_DIV();
+TfLiteRegistration* Register_STRIDED_SLICE();
+TfLiteRegistration* Register_EXP();
+TfLiteRegistration* Register_TOPK_V2();
+TfLiteRegistration* Register_SPLIT();
+TfLiteRegistration* Register_CAST();
+TfLiteRegistration* Register_MAXIMUM();
+TfLiteRegistration* Register_MINIMUM();
+TfLiteRegistration* Register_NEG();
+TfLiteRegistration* Register_SLICE();
+TfLiteRegistration* Register_LOG();
+TfLiteRegistration* Register_SUM();
+TfLiteRegistration* Register_PACK();
+TfLiteRegistration* Register_DEQUANTIZE();
+TfLiteRegistration* Register_MEAN();
+TfLiteRegistration* Register_LESS();
+TfLiteRegistration* Register_TILE();
+TfLiteRegistration* Register_SQUARED_DIFFERENCE();
+TfLiteRegistration* Register_RSQRT();
+TfLiteRegistration* Register_LOG_SOFTMAX();
+TfLiteRegistration* Register_WHERE();
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
+
+#ifdef TC3_WITH_ACTIONS_OPS
+#include "utils/tflite/dist_diversification.h"
+#include "utils/tflite/text_encoder.h"
+#include "utils/tflite/token_encoder.h"
+
+void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
+ resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
+ tflite::ops::builtin::Register_ADD(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_CONCATENATION,
+ tflite::ops::builtin::Register_CONCATENATION(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_CONV_2D,
+ tflite::ops::builtin::Register_CONV_2D(),
+ /*min_version=*/1,
+ /*max_version=*/3);
+ resolver->AddBuiltin(::tflite::BuiltinOperator_EQUAL,
+ ::tflite::ops::builtin::Register_EQUAL());
+
+ resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
+ tflite::ops::builtin::Register_FULLY_CONNECTED(),
+ /*min_version=*/1,
+ /*max_version=*/4);
+ resolver->AddBuiltin(::tflite::BuiltinOperator_GREATER_EQUAL,
+ ::tflite::ops::builtin::Register_GREATER_EQUAL());
+ resolver->AddBuiltin(tflite::BuiltinOperator_L2_NORMALIZATION,
+ tflite::ops::builtin::Register_L2_NORMALIZATION(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_MUL,
+ tflite::ops::builtin::Register_MUL());
+ resolver->AddBuiltin(tflite::BuiltinOperator_RESHAPE,
+ tflite::ops::builtin::Register_RESHAPE());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_MAX,
+ ::tflite::ops::builtin::Register_REDUCE_MAX());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_ANY,
+ ::tflite::ops::builtin::Register_REDUCE_ANY());
+ resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
+ tflite::ops::builtin::Register_SOFTMAX(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_GATHER,
+ tflite::ops::builtin::Register_GATHER(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE,
+ tflite::ops::builtin::Register_TRANSPOSE(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_SUB,
+ tflite::ops::builtin::Register_SUB(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_DIV,
+ tflite::ops::builtin::Register_DIV());
+ resolver->AddBuiltin(tflite::BuiltinOperator_STRIDED_SLICE,
+ tflite::ops::builtin::Register_STRIDED_SLICE(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_EXP,
+ tflite::ops::builtin::Register_EXP());
+ resolver->AddBuiltin(tflite::BuiltinOperator_TOPK_V2,
+ tflite::ops::builtin::Register_TOPK_V2(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_SPLIT,
+ tflite::ops::builtin::Register_SPLIT(),
+ /*min_version=*/1,
+ /*max_version=*/3);
+ resolver->AddBuiltin(tflite::BuiltinOperator_CAST,
+ tflite::ops::builtin::Register_CAST());
+ resolver->AddBuiltin(tflite::BuiltinOperator_MAXIMUM,
+ tflite::ops::builtin::Register_MAXIMUM(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_MINIMUM,
+ tflite::ops::builtin::Register_MINIMUM(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_NEG,
+ tflite::ops::builtin::Register_NEG());
+ resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
+ tflite::ops::builtin::Register_SLICE(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_LOG,
+ tflite::ops::builtin::Register_LOG());
+ resolver->AddBuiltin(tflite::BuiltinOperator_SUM,
+ tflite::ops::builtin::Register_SUM());
+ resolver->AddBuiltin(tflite::BuiltinOperator_PACK,
+ tflite::ops::builtin::Register_PACK(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_DEQUANTIZE,
+ tflite::ops::builtin::Register_DEQUANTIZE(),
+ /*min_version=*/1,
+ /*max_version=*/2);
+ resolver->AddBuiltin(tflite::BuiltinOperator_MEAN,
+ tflite::ops::builtin::Register_MEAN());
+ resolver->AddBuiltin(tflite::BuiltinOperator_LESS,
+ tflite::ops::builtin::Register_LESS());
+ resolver->AddBuiltin(tflite::BuiltinOperator_TILE,
+ tflite::ops::builtin::Register_TILE());
+ resolver->AddBuiltin(tflite::BuiltinOperator_SQUARED_DIFFERENCE,
+ tflite::ops::builtin::Register_SQUARED_DIFFERENCE());
+ resolver->AddBuiltin(tflite::BuiltinOperator_RSQRT,
+ tflite::ops::builtin::Register_RSQRT());
+ resolver->AddBuiltin(tflite::BuiltinOperator_LOG_SOFTMAX,
+ tflite::ops::builtin::Register_LOG_SOFTMAX());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_WHERE,
+ ::tflite::ops::builtin::Register_WHERE());
+}
+#else
+void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
+ resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
+ tflite::ops::builtin::Register_FULLY_CONNECTED());
+}
+#endif // TC3_WITH_ACTIONS_OPS
+
+namespace libtextclassifier3 {
+
+inline std::unique_ptr<tflite::OpResolver> BuildOpResolver() {
+#ifdef TC3_USE_SELECTIVE_REGISTRATION
+ std::unique_ptr<tflite::MutableOpResolver> resolver(
+ new tflite::MutableOpResolver);
+ RegisterSelectedOps(resolver.get());
+#else
+ std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver(
+ new tflite::ops::builtin::BuiltinOpResolver);
+#endif
+#ifdef TC3_WITH_ACTIONS_OPS
+ resolver->AddCustom("DistanceDiversification",
+ tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
+ resolver->AddCustom("TextEncoder",
+ tflite::ops::custom::Register_TEXT_ENCODER());
+ resolver->AddCustom("TokenEncoder",
+ tflite::ops::custom::Register_TOKEN_ENCODER());
+#endif // TC3_WITH_ACTIONS_OPS
+ return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
+}
+
+std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
+ const tflite::Model* model_spec) {
+ std::unique_ptr<const tflite::FlatBufferModel> model(
+ tflite::FlatBufferModel::BuildFromModel(model_spec));
+ if (!model || !model->initialized()) {
+ TC3_LOG(ERROR) << "Could not build TFLite model from a model spec.";
+ return nullptr;
+ }
+ return model;
+}
+
+std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
+ const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
+ const tflite::Model* model =
+ flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data());
+ flatbuffers::Verifier verifier(model_spec_buffer->data(),
+ model_spec_buffer->size());
+ if (!model->Verify(verifier)) {
+ return nullptr;
+ }
+ return TfLiteModelFromModelSpec(model);
+}
+
+TfLiteModelExecutor::TfLiteModelExecutor(
+ std::unique_ptr<const tflite::FlatBufferModel> model)
+ : model_(std::move(model)), resolver_(BuildOpResolver()) {}
+TfLiteModelExecutor::TfLiteModelExecutor(
+ std::unique_ptr<const tflite::FlatBufferModel> model,
+ std::unique_ptr<tflite::OpResolver> resolver)
+ : model_(std::move(model)), resolver_(std::move(resolver)) {}
+
+std::unique_ptr<tflite::Interpreter> TfLiteModelExecutor::CreateInterpreter()
+ const {
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter);
+ return interpreter;
+}
+
+template <>
+void TfLiteModelExecutor::SetInput(const int input_index,
+ const std::vector<std::string>& input_data,
+ tflite::Interpreter* interpreter) const {
+ tflite::DynamicBuffer buf;
+ for (const std::string& s : input_data) {
+ buf.AddString(s.data(), s.length());
+ }
+ buf.WriteToTensorAsVector(
+ interpreter->tensor(interpreter->inputs()[input_index]));
+}
+
+template <>
+std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
+ const int output_index, const tflite::Interpreter* interpreter) const {
+ const TfLiteTensor* output_tensor =
+ interpreter->tensor(interpreter->outputs()[output_index]);
+ const int num_strings = tflite::GetStringCount(output_tensor);
+ std::vector<tflite::StringRef> output(num_strings);
+ for (int i = 0; i < num_strings; i++) {
+ output[i] = tflite::GetString(output_tensor, i);
+ }
+ return output;
+}
+
+template <>
+std::vector<std::string> TfLiteModelExecutor::Output(
+ const int output_index, const tflite::Interpreter* interpreter) const {
+ std::vector<std::string> output;
+ for (const tflite::StringRef& s :
+ Output<tflite::StringRef>(output_index, interpreter)) {
+ output.push_back(std::string(s.str, s.len));
+ }
+ return output;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/tflite-model-executor.h b/native/utils/tflite-model-executor.h
new file mode 100644
index 0000000..a4432ff
--- /dev/null
+++ b/native/utils/tflite-model-executor.h
@@ -0,0 +1,159 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Contains classes that can execute different models/parts of a model.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
+
+#include <cstdint>
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/tensor-view.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/op_resolver.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace libtextclassifier3 {
+
+std::unique_ptr<tflite::OpResolver> BuildOpResolver();
+std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
+ const tflite::Model*);
+std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
+ const flatbuffers::Vector<uint8_t>*);
+
+// Executor for the text selection prediction and classification models.
+class TfLiteModelExecutor {
+ public:
+ static std::unique_ptr<TfLiteModelExecutor> FromModelSpec(
+ const tflite::Model* model_spec) {
+ auto model = TfLiteModelFromModelSpec(model_spec);
+ if (!model) {
+ return nullptr;
+ }
+ return std::unique_ptr<TfLiteModelExecutor>(
+ new TfLiteModelExecutor(std::move(model)));
+ }
+
+ static std::unique_ptr<TfLiteModelExecutor> FromBuffer(
+ const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
+ auto model = TfLiteModelFromBuffer(model_spec_buffer);
+ if (!model) {
+ return nullptr;
+ }
+ return std::unique_ptr<TfLiteModelExecutor>(
+ new TfLiteModelExecutor(std::move(model)));
+ }
+
+ // Creates an Interpreter for the model that serves as a scratch-pad for the
+ // inference. The Interpreter is NOT thread-safe.
+ std::unique_ptr<tflite::Interpreter> CreateInterpreter() const;
+
+ template <typename T>
+ void SetInput(const int input_index, const TensorView<T>& input_data,
+ tflite::Interpreter* interpreter) const {
+ input_data.copy_to(interpreter->typed_input_tensor<T>(input_index),
+ input_data.size());
+ }
+
+ template <typename T>
+ void SetInput(const int input_index, const std::vector<T>& input_data,
+ tflite::Interpreter* interpreter) const {
+ std::copy(input_data.begin(), input_data.end(),
+ interpreter->typed_input_tensor<T>(input_index));
+ }
+
+ template <typename T>
+ void SetInput(const int input_index, const T input_value,
+ tflite::Interpreter* interpreter) const {
+ TfLiteTensor* input_tensor =
+ interpreter->tensor(interpreter->inputs()[input_index]);
+ switch (input_tensor->type) {
+ case kTfLiteFloat32:
+ *tflite::GetTensorData<float>(input_tensor) = input_value;
+ break;
+ case kTfLiteInt32:
+ *tflite::GetTensorData<int32_t>(input_tensor) = input_value;
+ break;
+ case kTfLiteUInt8:
+ *tflite::GetTensorData<uint8_t>(input_tensor) = input_value;
+ break;
+ case kTfLiteInt64:
+ *tflite::GetTensorData<int64_t>(input_tensor) = input_value;
+ break;
+ case kTfLiteBool:
+ *tflite::GetTensorData<bool>(input_tensor) = input_value;
+ break;
+ case kTfLiteInt16:
+ *tflite::GetTensorData<int16_t>(input_tensor) = input_value;
+ break;
+ case kTfLiteInt8:
+ *tflite::GetTensorData<int8_t>(input_tensor) = input_value;
+ break;
+ default:
+ break;
+ }
+ }
+
+ template <typename T>
+ TensorView<T> OutputView(const int output_index,
+ const tflite::Interpreter* interpreter) const {
+ const TfLiteTensor* output_tensor =
+ interpreter->tensor(interpreter->outputs()[output_index]);
+ return TensorView<T>(interpreter->typed_output_tensor<T>(output_index),
+ std::vector<int>(output_tensor->dims->data,
+ output_tensor->dims->data +
+ output_tensor->dims->size));
+ }
+
+ template <typename T>
+ std::vector<T> Output(const int output_index,
+ const tflite::Interpreter* interpreter) const {
+ TensorView<T> output_view = OutputView<T>(output_index, interpreter);
+ return std::vector<T>(output_view.data(),
+ output_view.data() + output_view.size());
+ }
+
+ protected:
+ explicit TfLiteModelExecutor(
+ std::unique_ptr<const tflite::FlatBufferModel> model);
+ TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model,
+ std::unique_ptr<tflite::OpResolver> resolver);
+
+ std::unique_ptr<const tflite::FlatBufferModel> model_;
+ std::unique_ptr<tflite::OpResolver> resolver_;
+};
+
+template <>
+void TfLiteModelExecutor::SetInput(const int input_index,
+ const std::vector<std::string>& input_data,
+ tflite::Interpreter* interpreter) const;
+
+template <>
+std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
+ const int output_index, const tflite::Interpreter* interpreter) const;
+
+template <>
+std::vector<std::string> TfLiteModelExecutor::Output(
+ const int output_index, const tflite::Interpreter* interpreter) const;
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
diff --git a/utils/tflite/dist_diversification.cc b/native/utils/tflite/dist_diversification.cc
similarity index 100%
rename from utils/tflite/dist_diversification.cc
rename to native/utils/tflite/dist_diversification.cc
diff --git a/utils/tflite/dist_diversification.h b/native/utils/tflite/dist_diversification.h
similarity index 100%
rename from utils/tflite/dist_diversification.h
rename to native/utils/tflite/dist_diversification.h
diff --git a/utils/tflite/encoder_common.cc b/native/utils/tflite/encoder_common.cc
similarity index 100%
rename from utils/tflite/encoder_common.cc
rename to native/utils/tflite/encoder_common.cc
diff --git a/utils/tflite/encoder_common.h b/native/utils/tflite/encoder_common.h
similarity index 100%
rename from utils/tflite/encoder_common.h
rename to native/utils/tflite/encoder_common.h
diff --git a/native/utils/tflite/text_encoder.cc b/native/utils/tflite/text_encoder.cc
new file mode 100644
index 0000000..78bb51a
--- /dev/null
+++ b/native/utils/tflite/text_encoder.cc
@@ -0,0 +1,299 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/text_encoder.h"
+
+#include <memory>
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/container/double-array-trie.h"
+#include "utils/container/sorted-strings-table.h"
+#include "utils/sentencepiece/encoder.h"
+#include "utils/sentencepiece/normalizer.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/tflite/encoder_common.h"
+#include "utils/tflite/text_encoder_config_generated.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+struct TextEncoderOp {
+ std::unique_ptr<SentencePieceNormalizer> normalizer;
+ std::unique_ptr<Encoder> encoder;
+ std::unique_ptr<StringSet> matcher;
+};
+
+// Input parameters for the op.
+// The conversation message as a (1, conversation length) string tensor.
+constexpr const int kInputTexts = 0;
+
+// The number of messages, the conversation length, int scalar.
+constexpr const int kInputNumInputs = 1;
+
+// Maximum output length of the encoding, int scalar.
+constexpr const int kInputMaxLength = 2;
+
+// Additional attributes to align to the sentence pieces, e.g. user ids per
+// message.
+constexpr const int kInputAttr = 3;
+
+// Output parameters for the op.
+// The text sentence piece encodings as ids, (1, max output length) int tensor.
+constexpr const int kOutputEncoded = 0;
+
+// Relative position of each sentence piece in the input text,
+// (1, max output length) int tensor.
+constexpr const int kOutputPosition = 1;
+
+// Output length after trimming to the maximum output length specified.
+// int scalar.
+constexpr const int kOutputLengths = 2;
+
+// Padded and sentence piece aligned provided attributes, e.g. user id per
+// sentence piece.
+constexpr const int kOutputAttr = 3;
+
+const char kTextEncoderConfigAttr[] = "text_encoder_config";
+
+// Initializes text encoder object from serialized options:
+// The options are a flexbuffers attribute map that contain the op config
+// with the key `text_encoder_config` as `TextEncoderConfig`.
+void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
+ const flexbuffers::Map& attr_map =
+ flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
+ .AsMap();
+ const flexbuffers::Blob serialized_config =
+ attr_map[kTextEncoderConfigAttr].AsBlob();
+ const TextEncoderConfig* config =
+ flatbuffers::GetRoot<TextEncoderConfig>(serialized_config.data());
+
+ std::unique_ptr<TextEncoderOp> encoder_op(new TextEncoderOp());
+
+ // Create normalizer from options.
+ const TrieNode* charsmap_trie_nodes = reinterpret_cast<const TrieNode*>(
+ config->normalization_charsmap()->Data());
+ const int charsmap_trie_nodes_length =
+ config->normalization_charsmap()->size() / sizeof(TrieNode);
+ encoder_op->normalizer.reset(new SentencePieceNormalizer(
+ DoubleArrayTrie(charsmap_trie_nodes, charsmap_trie_nodes_length),
+ StringPiece(config->normalization_charsmap_values()->data(),
+ config->normalization_charsmap_values()->size()),
+ config->add_dummy_prefix(), config->remove_extra_whitespaces(),
+ config->escape_whitespaces()));
+
+ const int num_pieces = config->pieces_scores()->Length();
+
+ switch (config->matcher_type()) {
+ case SentencePieceMatcherType_MAPPED_TRIE: {
+ const TrieNode* pieces_trie_nodes =
+ reinterpret_cast<const TrieNode*>(config->pieces()->Data());
+ const int pieces_trie_nodes_length =
+ config->pieces()->Length() / sizeof(TrieNode);
+ encoder_op->matcher.reset(
+ new DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length));
+ break;
+ }
+ case SentencePieceMatcherType_SORTED_STRING_TABLE: {
+ encoder_op->matcher.reset(new SortedStringsTable(
+ num_pieces, config->pieces_offsets()->data(),
+ StringPiece(config->pieces()->data(), config->pieces()->Length())));
+ break;
+ }
+ default: {
+ TC3_LOG(ERROR) << "Unknown sentence piece matcher type.";
+ return nullptr;
+ }
+ }
+ encoder_op->encoder.reset(new Encoder(
+ encoder_op->matcher.get(), num_pieces, config->pieces_scores()->data(),
+ config->start_code(), config->end_code(), config->encoding_offset(),
+ config->unknown_code(), config->unknown_score()));
+ return encoder_op.release();
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<TextEncoderOp*>(buffer);
+}
+
+TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
+ int max_output_length) {
+ TF_LITE_ENSURE_OK(
+ context,
+ ResizeOutputTensor(max_output_length,
+ &context->tensors[node->outputs->data[kOutputEncoded]],
+ context));
+
+ TF_LITE_ENSURE_OK(
+ context,
+ ResizeOutputTensor(
+ max_output_length,
+ &context->tensors[node->outputs->data[kOutputPosition]], context));
+
+ const int num_output_attrs = node->outputs->size - kOutputAttr;
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TF_LITE_ENSURE_OK(
+ context,
+ ResizeOutputTensor(
+ max_output_length,
+ &context->tensors[node->outputs->data[kOutputAttr + i]], context));
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // Check that the batch dimension is kBatchSize.
+ const TfLiteTensor& input_text =
+ context->tensors[node->inputs->data[kInputTexts]];
+ TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank);
+ TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize);
+
+ TfLiteTensor& output_lengths =
+ context->tensors[node->outputs->data[kOutputLengths]];
+ TfLiteTensor& output_encoded =
+ context->tensors[node->outputs->data[kOutputEncoded]];
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[kOutputPosition]];
+
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, &output_lengths,
+ CreateIntArray({kEncoderBatchSize})));
+
+ // Check that there are enough outputs for attributes.
+ const int num_output_attrs = node->outputs->size - kOutputAttr;
+ TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
+
+ // Copy attribute types from input to output tensors.
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteTensor& input = context->tensors[node->inputs->data[kInputAttr + i]];
+ TfLiteTensor& output =
+ context->tensors[node->outputs->data[kOutputAttr + i]];
+ output.type = input.type;
+ }
+
+ const TfLiteTensor& output_length =
+ context->tensors[node->inputs->data[kInputMaxLength]];
+
+ if (tflite::IsConstantTensor(&output_length)) {
+ return ResizeOutputTensors(context, node, output_length.data.i64[0]);
+ } else {
+ tflite::SetTensorToDynamic(&output_encoded);
+ tflite::SetTensorToDynamic(&output_positions);
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteTensor& output_attr =
+ context->tensors[node->outputs->data[kOutputAttr + i]];
+ tflite::SetTensorToDynamic(&output_attr);
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ if (node->user_data == nullptr) {
+ return kTfLiteError;
+ }
+ const TextEncoderOp* encoder_op =
+ reinterpret_cast<TextEncoderOp*>(node->user_data);
+ const TfLiteTensor& input_text =
+ context->tensors[node->inputs->data[kInputTexts]];
+ const int num_strings = tflite::GetStringCount(&input_text);
+ // Check that the number of strings matches the length parameter.
+ const int num_strings_param =
+ context->tensors[node->inputs->data[kInputNumInputs]].data.i32[0];
+ TF_LITE_ENSURE_EQ(context, num_strings, num_strings_param);
+
+ TfLiteTensor& output_encoded =
+ context->tensors[node->outputs->data[kOutputEncoded]];
+ if (tflite::IsDynamicTensor(&output_encoded)) {
+ const TfLiteTensor& output_length =
+ context->tensors[node->inputs->data[kInputMaxLength]];
+ TF_LITE_ENSURE_OK(
+ context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
+ }
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[kOutputPosition]];
+
+ std::vector<int> encoded_total;
+ std::vector<int> encoded_offsets;
+ std::vector<int> encoded_positions;
+ encoded_offsets.reserve(num_strings);
+ const int max_output_length = output_encoded.dims->data[1];
+ const int max_encoded_position = max_output_length;
+
+ for (int i = 0; i < num_strings; ++i) {
+ const auto& strref = tflite::GetString(&input_text, i);
+ std::string normalized;
+ TF_LITE_ENSURE(context,
+ encoder_op->normalizer->Normalize(
+ StringPiece(strref.str, strref.len), &normalized));
+ std::vector<int> encoded;
+ TF_LITE_ENSURE(context, encoder_op->encoder->Encode(normalized, &encoded));
+ encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
+ encoded_offsets.push_back(encoded_total.size());
+ for (int i = 0; i < encoded.size(); i++) {
+ encoded_positions.push_back(std::min(i, max_encoded_position - 1));
+ }
+ }
+
+ const int num_skip = CopyDataToTensorAndPadOrTruncate(
+ max_output_length, encoded_total,
+ /*padding_value=*/encoded_total.back(), &output_encoded);
+ TfLiteTensor& output_lengths =
+ context->tensors[node->outputs->data[kOutputLengths]];
+ output_lengths.data.i32[0] = encoded_total.size() - num_skip;
+ CopyDataToTensorAndPadOrTruncate(max_output_length, encoded_positions,
+ /*padding_value=*/max_encoded_position,
+ &output_positions);
+
+ // Process attributes, all checks of sizes and types are done in Prepare.
+ const int num_output_attrs = node->outputs->size - kOutputAttr;
+ TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
+ for (int i = 0; i < num_output_attrs; ++i) {
+ TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
+ context->tensors[node->inputs->data[kInputAttr + i]], encoded_offsets,
+ num_skip, context,
+ &context->tensors[node->outputs->data[kOutputAttr + i]]);
+ if (attr_status != kTfLiteOk) {
+ return attr_status;
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace
+} // namespace libtextclassifier3
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_TEXT_ENCODER() {
+ static TfLiteRegistration registration = {
+ libtextclassifier3::Initialize, libtextclassifier3::Free,
+ libtextclassifier3::Prepare, libtextclassifier3::Eval};
+ return ®istration;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/utils/tflite/text_encoder.h b/native/utils/tflite/text_encoder.h
similarity index 100%
rename from utils/tflite/text_encoder.h
rename to native/utils/tflite/text_encoder.h
diff --git a/utils/tflite/text_encoder_config.fbs b/native/utils/tflite/text_encoder_config.fbs
similarity index 100%
rename from utils/tflite/text_encoder_config.fbs
rename to native/utils/tflite/text_encoder_config.fbs
diff --git a/utils/tflite/token_encoder.cc b/native/utils/tflite/token_encoder.cc
similarity index 100%
rename from utils/tflite/token_encoder.cc
rename to native/utils/tflite/token_encoder.cc
diff --git a/utils/tflite/token_encoder.h b/native/utils/tflite/token_encoder.h
similarity index 100%
rename from utils/tflite/token_encoder.h
rename to native/utils/tflite/token_encoder.h
diff --git a/native/utils/token-feature-extractor.cc b/native/utils/token-feature-extractor.cc
new file mode 100644
index 0000000..ee915db
--- /dev/null
+++ b/native/utils/token-feature-extractor.cc
@@ -0,0 +1,310 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/token-feature-extractor.h"
+
+#include <cctype>
+#include <string>
+
+#include "utils/base/logging.h"
+#include "utils/hash/farmhash.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+namespace {
+
+std::string RemapTokenAscii(const std::string& token,
+ const TokenFeatureExtractorOptions& options) {
+ if (!options.remap_digits && !options.lowercase_tokens) {
+ return token;
+ }
+
+ std::string copy = token;
+ for (int i = 0; i < token.size(); ++i) {
+ if (options.remap_digits && isdigit(copy[i])) {
+ copy[i] = '0';
+ }
+ if (options.lowercase_tokens) {
+ copy[i] = tolower(copy[i]);
+ }
+ }
+ return copy;
+}
+
+void RemapTokenUnicode(const std::string& token,
+ const TokenFeatureExtractorOptions& options,
+ const UniLib& unilib, UnicodeText* remapped) {
+ if (!options.remap_digits && !options.lowercase_tokens) {
+ // Leave remapped untouched.
+ return;
+ }
+
+ UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false);
+ remapped->clear();
+ for (auto it = word.begin(); it != word.end(); ++it) {
+ if (options.remap_digits && unilib.IsDigit(*it)) {
+ remapped->push_back('0');
+ } else if (options.lowercase_tokens) {
+ remapped->push_back(unilib.ToLower(*it));
+ } else {
+ remapped->push_back(*it);
+ }
+ }
+}
+
+} // namespace
+
+TokenFeatureExtractor::TokenFeatureExtractor(
+ const TokenFeatureExtractorOptions& options, const UniLib* unilib)
+ : options_(options), unilib_(*unilib) {
+ for (const std::string& pattern : options.regexp_features) {
+ regex_patterns_.push_back(std::unique_ptr<UniLib::RegexPattern>(
+ unilib_.CreateRegexPattern(UTF8ToUnicodeText(
+ pattern.c_str(), pattern.size(), /*do_copy=*/false))));
+ }
+}
+
+bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
+ std::vector<int>* sparse_features,
+ std::vector<float>* dense_features) const {
+ if (!dense_features) {
+ return false;
+ }
+ if (sparse_features) {
+ *sparse_features = ExtractCharactergramFeatures(token);
+ }
+ *dense_features = ExtractDenseFeatures(token, is_in_span);
+ return true;
+}
+
+std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
+ const Token& token) const {
+ if (options_.unicode_aware_features) {
+ return ExtractCharactergramFeaturesUnicode(token);
+ } else {
+ return ExtractCharactergramFeaturesAscii(token);
+ }
+}
+
+std::vector<float> TokenFeatureExtractor::ExtractDenseFeatures(
+ const Token& token, bool is_in_span) const {
+ std::vector<float> dense_features;
+
+ if (options_.extract_case_feature) {
+ if (options_.unicode_aware_features) {
+ UnicodeText token_unicode =
+ UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+ if (!token.value.empty() && unilib_.IsUpper(*token_unicode.begin())) {
+ dense_features.push_back(1.0);
+ } else {
+ dense_features.push_back(-1.0);
+ }
+ } else {
+ if (!token.value.empty() && isupper(*token.value.begin())) {
+ dense_features.push_back(1.0);
+ } else {
+ dense_features.push_back(-1.0);
+ }
+ }
+ }
+
+ if (options_.extract_selection_mask_feature) {
+ if (is_in_span) {
+ dense_features.push_back(1.0);
+ } else {
+ if (options_.unicode_aware_features) {
+ dense_features.push_back(-1.0);
+ } else {
+ dense_features.push_back(0.0);
+ }
+ }
+ }
+
+ // Add regexp features.
+ if (!regex_patterns_.empty()) {
+ UnicodeText token_unicode =
+ UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+ for (int i = 0; i < regex_patterns_.size(); ++i) {
+ if (!regex_patterns_[i].get()) {
+ dense_features.push_back(-1.0);
+ continue;
+ }
+ auto matcher = regex_patterns_[i]->Matcher(token_unicode);
+ int status;
+ if (matcher->Matches(&status)) {
+ dense_features.push_back(1.0);
+ } else {
+ dense_features.push_back(-1.0);
+ }
+ }
+ }
+
+ return dense_features;
+}
+
+int TokenFeatureExtractor::HashToken(StringPiece token) const {
+ if (options_.allowed_chargrams.empty()) {
+ return tc3farmhash::Fingerprint64(token) % options_.num_buckets;
+ } else {
+ // Padding and out-of-vocabulary tokens have extra buckets reserved because
+ // they are special and important tokens, and we don't want them to share
+ // embedding with other charactergrams.
+ // TODO(zilka): Experimentally verify.
+ const int kNumExtraBuckets = 2;
+ const std::string token_string = token.ToString();
+ if (token_string == "<PAD>") {
+ return 1;
+ } else if (options_.allowed_chargrams.find(token_string) ==
+ options_.allowed_chargrams.end()) {
+ return 0; // Out-of-vocabulary.
+ } else {
+ return (tc3farmhash::Fingerprint64(token) %
+ (options_.num_buckets - kNumExtraBuckets)) +
+ kNumExtraBuckets;
+ }
+ }
+}
+
+std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii(
+ const Token& token) const {
+ std::vector<int> result;
+ if (token.is_padding || token.value.empty()) {
+ result.push_back(HashToken("<PAD>"));
+ } else {
+ const std::string word = RemapTokenAscii(token.value, options_);
+
+ // Trim words that are over max_word_length characters.
+ const int max_word_length = options_.max_word_length;
+ std::string feature_word;
+ if (word.size() > max_word_length) {
+ feature_word =
+ "^" + word.substr(0, max_word_length / 2) + "\1" +
+ word.substr(word.size() - max_word_length / 2, max_word_length / 2) +
+ "$";
+ } else {
+ // Add a prefix and suffix to the word.
+ feature_word = "^" + word + "$";
+ }
+
+ // Upper-bound the number of charactergram extracted to avoid resizing.
+ result.reserve(options_.chargram_orders.size() * feature_word.size());
+
+ if (options_.chargram_orders.empty()) {
+ result.push_back(HashToken(feature_word));
+ } else {
+ // Generate the character-grams.
+ for (int chargram_order : options_.chargram_orders) {
+ if (chargram_order == 1) {
+ for (int i = 1; i < feature_word.size() - 1; ++i) {
+ result.push_back(
+ HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
+ }
+ } else {
+ for (int i = 0;
+ i < static_cast<int>(feature_word.size()) - chargram_order + 1;
+ ++i) {
+ result.push_back(HashToken(StringPiece(feature_word, /*offset=*/i,
+ /*len=*/chargram_order)));
+ }
+ }
+ }
+ }
+ }
+ return result;
+}
+
+std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
+ const Token& token) const {
+ std::vector<int> result;
+ if (token.is_padding || token.value.empty()) {
+ result.push_back(HashToken("<PAD>"));
+ } else {
+ UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+ RemapTokenUnicode(token.value, options_, unilib_, &word);
+
+ // Trim the word if needed by finding a left-cut point and right-cut point.
+ auto left_cut = word.begin();
+ auto right_cut = word.end();
+ for (int i = 0; i < options_.max_word_length / 2; i++) {
+ if (left_cut < right_cut) {
+ ++left_cut;
+ }
+ if (left_cut < right_cut) {
+ --right_cut;
+ }
+ }
+
+ std::string feature_word;
+ if (left_cut == right_cut) {
+ feature_word = "^" + word.UTF8Substring(word.begin(), word.end()) + "$";
+ } else {
+ // clang-format off
+ feature_word = "^" +
+ word.UTF8Substring(word.begin(), left_cut) +
+ "\1" +
+ word.UTF8Substring(right_cut, word.end()) +
+ "$";
+ // clang-format on
+ }
+
+ const UnicodeText feature_word_unicode =
+ UTF8ToUnicodeText(feature_word, /*do_copy=*/false);
+
+ // Upper-bound the number of charactergram extracted to avoid resizing.
+ result.reserve(options_.chargram_orders.size() * feature_word.size());
+
+ if (options_.chargram_orders.empty()) {
+ result.push_back(HashToken(feature_word));
+ } else {
+ // Generate the character-grams.
+ for (int chargram_order : options_.chargram_orders) {
+ UnicodeText::const_iterator it_start = feature_word_unicode.begin();
+ UnicodeText::const_iterator it_end = feature_word_unicode.end();
+ if (chargram_order == 1) {
+ ++it_start;
+ --it_end;
+ }
+
+ UnicodeText::const_iterator it_chargram_start = it_start;
+ UnicodeText::const_iterator it_chargram_end = it_start;
+ bool chargram_is_complete = true;
+ for (int i = 0; i < chargram_order; ++i) {
+ if (it_chargram_end == it_end) {
+ chargram_is_complete = false;
+ break;
+ }
+ ++it_chargram_end;
+ }
+ if (!chargram_is_complete) {
+ continue;
+ }
+
+ for (; it_chargram_end <= it_end;
+ ++it_chargram_start, ++it_chargram_end) {
+ const int length_bytes =
+ it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
+ result.push_back(HashToken(
+ StringPiece(it_chargram_start.utf8_data(), length_bytes)));
+ }
+ }
+ }
+ }
+ return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/token-feature-extractor.h b/native/utils/token-feature-extractor.h
new file mode 100644
index 0000000..b3f2f33
--- /dev/null
+++ b/native/utils/token-feature-extractor.h
@@ -0,0 +1,117 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TOKEN_FEATURE_EXTRACTOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_TOKEN_FEATURE_EXTRACTOR_H_
+
+#include <memory>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+struct TokenFeatureExtractorOptions {
+ // Number of buckets used for hashing charactergrams.
+ int num_buckets = 0;
+
+ // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3
+ // character trigrams etc.
+ std::vector<int> chargram_orders;
+
+ // Whether to extract the token case feature.
+ bool extract_case_feature = false;
+
+ // If true, will use the unicode-aware functionality for extracting features.
+ bool unicode_aware_features = false;
+
+ // Whether to extract the selection mask feature.
+ bool extract_selection_mask_feature = false;
+
+ // Regexp features to extract.
+ std::vector<std::string> regexp_features;
+
+ // Whether to remap digits to a single number.
+ bool remap_digits = false;
+
+ // Whether to lowercase all tokens.
+ bool lowercase_tokens = false;
+
+ // Maximum length of a word.
+ int max_word_length = 20;
+
+ // List of allowed charactergrams. The extracted charactergrams are filtered
+ // using this list, and charactergrams that are not present are interpreted as
+ // out-of-vocabulary.
+ // If no allowed_chargrams are specified, all charactergrams are allowed.
+ std::unordered_set<std::string> allowed_chargrams;
+};
+
+class TokenFeatureExtractor {
+ public:
+ // Does not take ownership of unilib, which must refer to a valid unilib
+ // instance that outlives this feature extractor.
+ explicit TokenFeatureExtractor(const TokenFeatureExtractorOptions& options,
+ const UniLib* unilib);
+
+ // Extracts both the sparse (charactergram) and the dense features from a
+ // token. is_in_span is a bool indicator whether the token is a part of the
+ // selection span (true) or not (false).
+ // The sparse_features output is optional. Fails and returns false if
+ // dense_features is a nullptr.
+ bool Extract(const Token& token, bool is_in_span,
+ std::vector<int>* sparse_features,
+ std::vector<float>* dense_features) const;
+
+ // Extracts the sparse (charactergram) features from the token.
+ std::vector<int> ExtractCharactergramFeatures(const Token& token) const;
+
+ // Extracts the dense features from the token. is_in_span is a bool indicator
+ // whether the token is a part of the selection span (true) or not (false).
+ std::vector<float> ExtractDenseFeatures(const Token& token,
+ bool is_in_span) const;
+
+ int DenseFeaturesCount() const {
+ int feature_count =
+ options_.extract_case_feature + options_.extract_selection_mask_feature;
+ feature_count += regex_patterns_.size();
+ return feature_count;
+ }
+
+ protected:
+ // Hashes given token to given number of buckets.
+ int HashToken(StringPiece token) const;
+
+ // Extracts the charactergram features from the token in a non-unicode-aware
+ // way.
+ std::vector<int> ExtractCharactergramFeaturesAscii(const Token& token) const;
+
+ // Extracts the charactergram features from the token in a unicode-aware way.
+ std::vector<int> ExtractCharactergramFeaturesUnicode(
+ const Token& token) const;
+
+ private:
+ TokenFeatureExtractorOptions options_;
+ std::vector<std::unique_ptr<UniLib::RegexPattern>> regex_patterns_;
+ const UniLib& unilib_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TOKEN_FEATURE_EXTRACTOR_H_
diff --git a/native/utils/token-feature-extractor_test.cc b/native/utils/token-feature-extractor_test.cc
new file mode 100644
index 0000000..15a434c
--- /dev/null
+++ b/native/utils/token-feature-extractor_test.cc
@@ -0,0 +1,579 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/token-feature-extractor.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+class TokenFeatureExtractorTest : public ::testing::Test {
+ protected:
+ explicit TokenFeatureExtractorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+ UniLib unilib_;
+};
+
+class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
+ public:
+ using TokenFeatureExtractor::HashToken;
+ using TokenFeatureExtractor::TokenFeatureExtractor;
+};
+
+TEST_F(TokenFeatureExtractorTest, ExtractAscii) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2, 3};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = false;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("H"),
+ extractor.HashToken("e"),
+ extractor.HashToken("l"),
+ extractor.HashToken("l"),
+ extractor.HashToken("o"),
+ extractor.HashToken("^H"),
+ extractor.HashToken("He"),
+ extractor.HashToken("el"),
+ extractor.HashToken("ll"),
+ extractor.HashToken("lo"),
+ extractor.HashToken("o$"),
+ extractor.HashToken("^He"),
+ extractor.HashToken("Hel"),
+ extractor.HashToken("ell"),
+ extractor.HashToken("llo"),
+ extractor.HashToken("lo$")
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("w"),
+ extractor.HashToken("o"),
+ extractor.HashToken("r"),
+ extractor.HashToken("l"),
+ extractor.HashToken("d"),
+ extractor.HashToken("!"),
+ extractor.HashToken("^w"),
+ extractor.HashToken("wo"),
+ extractor.HashToken("or"),
+ extractor.HashToken("rl"),
+ extractor.HashToken("ld"),
+ extractor.HashToken("d!"),
+ extractor.HashToken("!$"),
+ extractor.HashToken("^wo"),
+ extractor.HashToken("wor"),
+ extractor.HashToken("orl"),
+ extractor.HashToken("rld"),
+ extractor.HashToken("ld!"),
+ extractor.HashToken("d!$"),
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+}
+
+TEST_F(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = false;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("^Hello$")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("^world!$")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+}
+
+TEST_F(TokenFeatureExtractorTest, ExtractUnicode) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2, 3};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = true;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("H"),
+ extractor.HashToken("ě"),
+ extractor.HashToken("l"),
+ extractor.HashToken("l"),
+ extractor.HashToken("ó"),
+ extractor.HashToken("^H"),
+ extractor.HashToken("Hě"),
+ extractor.HashToken("ěl"),
+ extractor.HashToken("ll"),
+ extractor.HashToken("ló"),
+ extractor.HashToken("ó$"),
+ extractor.HashToken("^Hě"),
+ extractor.HashToken("Hěl"),
+ extractor.HashToken("ěll"),
+ extractor.HashToken("lló"),
+ extractor.HashToken("ló$")
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("w"),
+ extractor.HashToken("o"),
+ extractor.HashToken("r"),
+ extractor.HashToken("l"),
+ extractor.HashToken("d"),
+ extractor.HashToken("!"),
+ extractor.HashToken("^w"),
+ extractor.HashToken("wo"),
+ extractor.HashToken("or"),
+ extractor.HashToken("rl"),
+ extractor.HashToken("ld"),
+ extractor.HashToken("d!"),
+ extractor.HashToken("!$"),
+ extractor.HashToken("^wo"),
+ extractor.HashToken("wor"),
+ extractor.HashToken("orl"),
+ extractor.HashToken("rld"),
+ extractor.HashToken("ld!"),
+ extractor.HashToken("d!$"),
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+}
+
+TEST_F(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = true;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("^Hělló$")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray({
+ extractor.HashToken("^world!$"),
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+}
+
+#ifdef TC3_TEST_ICU
+TEST_F(TokenFeatureExtractorTest, ICUCaseFeature) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = true;
+ options.extract_selection_mask_feature = false;
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"Ř", 23, 29}, false, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"ř", 23, 29}, false, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
+}
+#endif
+
+TEST_F(TokenFeatureExtractorTest, DigitRemapping) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.remap_digits = true;
+ options.unicode_aware_features = false;
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
+ &dense_features);
+
+ std::vector<int> sparse_features2;
+ extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+
+ extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features,
+ testing::Not(testing::ElementsAreArray(sparse_features2)));
+}
+
+TEST_F(TokenFeatureExtractorTest, DigitRemappingUnicode) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.remap_digits = true;
+ options.unicode_aware_features = true;
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
+ &dense_features);
+
+ std::vector<int> sparse_features2;
+ extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+
+ extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features,
+ testing::Not(testing::ElementsAreArray(sparse_features2)));
+}
+
+TEST_F(TokenFeatureExtractorTest, LowercaseAscii) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.lowercase_tokens = true;
+ options.unicode_aware_features = false;
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features,
+ &dense_features);
+
+ std::vector<int> sparse_features2;
+ extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+
+ extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+}
+
+#ifdef TC3_TEST_ICU
+TEST_F(TokenFeatureExtractorTest, LowercaseUnicode) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.lowercase_tokens = true;
+ options.unicode_aware_features = true;
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"ŘŘ", 0, 6}, true, &sparse_features, &dense_features);
+
+ std::vector<int> sparse_features2;
+ extractor.Extract(Token{"řř", 0, 6}, true, &sparse_features2,
+ &dense_features);
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
+}
+#endif
+
+#ifdef TC3_TEST_ICU
+TEST_F(TokenFeatureExtractorTest, RegexFeatures) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.remap_digits = false;
+ options.unicode_aware_features = false;
+ options.regexp_features.push_back("^[a-z]+$"); // all lower case.
+ options.regexp_features.push_back("^[0-9]+$"); // all digits.
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+
+ dense_features.clear();
+ extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));
+
+ dense_features.clear();
+ extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
+
+ dense_features.clear();
+ extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features,
+ &dense_features);
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
+}
+#endif
+
+TEST_F(TokenFeatureExtractorTest, ExtractTooLongWord) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{22};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = true;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ // Test that this runs. ASAN should catch problems.
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+ extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true,
+ &sparse_features, &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
+ extractor.HashToken("abcdefghij\1qřstuvwxyz$"),
+ // clang-format on
+ }));
+}
+
+TEST_F(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = true;
+ options.extract_selection_mask_feature = true;
+
+ TestingTokenFeatureExtractor extractor_unicode(options, &unilib_);
+
+ options.unicode_aware_features = false;
+ TestingTokenFeatureExtractor extractor_ascii(options, &unilib_);
+
+ for (const std::string& input :
+ {"https://www.abcdefgh.com/in/xxxkkkvayio",
+ "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
+ "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w", "abcd",
+ "x", "Hello", "Hey,", "Hi", ""}) {
+ std::vector<int> sparse_features_unicode;
+ std::vector<float> dense_features_unicode;
+ extractor_unicode.Extract(Token{input, 0, 0}, true,
+ &sparse_features_unicode,
+ &dense_features_unicode);
+
+ std::vector<int> sparse_features_ascii;
+ std::vector<float> dense_features_ascii;
+ extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii,
+ &dense_features_ascii);
+
+ EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input;
+ EXPECT_THAT(dense_features_unicode, dense_features_ascii) << input;
+ }
+}
+
+TEST_F(TokenFeatureExtractorTest, ExtractForPadToken) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = false;
+ options.extract_selection_mask_feature = true;
+
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token(), false, &sparse_features, &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+}
+
+TEST_F(TokenFeatureExtractorTest, ExtractFiltered) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2, 3};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = false;
+ options.extract_selection_mask_feature = true;
+ options.allowed_chargrams.insert("^H");
+ options.allowed_chargrams.insert("ll");
+ options.allowed_chargrams.insert("llo");
+ options.allowed_chargrams.insert("w");
+ options.allowed_chargrams.insert("!");
+ options.allowed_chargrams.insert("\xc4"); // UTF8 control character.
+
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ extractor.Extract(Token{"Hěllo", 0, 5}, true, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features,
+ testing::ElementsAreArray({
+ // clang-format off
+ 0,
+ extractor.HashToken("\xc4"),
+ 0,
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("^H"),
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("ll"),
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("llo"),
+ 0
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
+
+ sparse_features.clear();
+ dense_features.clear();
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
+ &dense_features);
+
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("w"),
+ 0,
+ 0,
+ 0,
+ 0,
+ extractor.HashToken("!"),
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
+ EXPECT_EQ(extractor.HashToken("<PAD>"), 1);
+}
+
+TEST_F(TokenFeatureExtractorTest, ExtractEmptyToken) {
+ TokenFeatureExtractorOptions options;
+ options.num_buckets = 1000;
+ options.chargram_orders = std::vector<int>{1, 2, 3};
+ options.extract_case_feature = true;
+ options.unicode_aware_features = false;
+ options.extract_selection_mask_feature = true;
+ TestingTokenFeatureExtractor extractor(options, &unilib_);
+
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ // Should not crash.
+ extractor.Extract(Token(), true, &sparse_features, &dense_features);
+
+ EXPECT_THAT(sparse_features, testing::ElementsAreArray({
+ // clang-format off
+ extractor.HashToken("<PAD>"),
+ // clang-format on
+ }));
+ EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/tokenizer.cc b/native/utils/tokenizer.cc
new file mode 100644
index 0000000..bd47592
--- /dev/null
+++ b/native/utils/tokenizer.cc
@@ -0,0 +1,342 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tokenizer.h"
+
+#include <algorithm>
+
+#include "utils/base/logging.h"
+#include "utils/base/macros.h"
+#include "utils/strings/utf8.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+Tokenizer::Tokenizer(
+ const TokenizationType type, const UniLib* unilib,
+ const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
+ const std::vector<const CodepointRange*>&
+ internal_tokenizer_codepoint_ranges,
+ const bool split_on_script_change,
+ const bool icu_preserve_whitespace_tokens,
+ const bool preserve_floating_numbers)
+ : type_(type),
+ unilib_(unilib),
+ split_on_script_change_(split_on_script_change),
+ icu_preserve_whitespace_tokens_(icu_preserve_whitespace_tokens),
+ preserve_floating_numbers_(preserve_floating_numbers) {
+ for (const TokenizationCodepointRange* range : codepoint_ranges) {
+ codepoint_ranges_.emplace_back(range->UnPack());
+ }
+
+ std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
+ [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
+ const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
+ return a->start < b->start;
+ });
+
+ SortCodepointRanges(internal_tokenizer_codepoint_ranges,
+ &internal_tokenizer_codepoint_ranges_);
+}
+
+const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange(
+ int codepoint) const {
+ auto it = std::lower_bound(
+ codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
+ [](const std::unique_ptr<const TokenizationCodepointRangeT>& range,
+ int codepoint) {
+ // This function compares range with the codepoint for the purpose of
+ // finding the first greater or equal range. Because of the use of
+ // std::lower_bound it needs to return true when range < codepoint;
+ // the first time it will return false the lower bound is found and
+ // returned.
+ //
+ // It might seem weird that the condition is range.end <= codepoint
+ // here but when codepoint == range.end it means it's actually just
+ // outside of the range, thus the range is less than the codepoint.
+ return range->end <= codepoint;
+ });
+ if (it != codepoint_ranges_.end() && (*it)->start <= codepoint &&
+ (*it)->end > codepoint) {
+ return it->get();
+ } else {
+ return nullptr;
+ }
+}
+
+void Tokenizer::GetScriptAndRole(char32 codepoint,
+ TokenizationCodepointRange_::Role* role,
+ int* script) const {
+ const TokenizationCodepointRangeT* range = FindTokenizationRange(codepoint);
+ if (range) {
+ *role = range->role;
+ *script = range->script_id;
+ } else {
+ *role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ *script = kUnknownScript;
+ }
+}
+
+std::vector<Token> Tokenizer::Tokenize(const std::string& text) const {
+ UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
+ return Tokenize(text_unicode);
+}
+
+std::vector<Token> Tokenizer::Tokenize(const UnicodeText& text_unicode) const {
+ switch (type_) {
+ case TokenizationType_INTERNAL_TOKENIZER:
+ return InternalTokenize(text_unicode);
+ case TokenizationType_ICU:
+ TC3_FALLTHROUGH_INTENDED;
+ case TokenizationType_MIXED: {
+ std::vector<Token> result;
+ if (!ICUTokenize(text_unicode, &result)) {
+ return {};
+ }
+ if (type_ == TokenizationType_MIXED) {
+ InternalRetokenize(text_unicode, &result);
+ }
+ return result;
+ }
+ case TokenizationType_LETTER_DIGIT: {
+ std::vector<Token> result;
+ if (!NumberTokenize(text_unicode, &result)) {
+ return {};
+ }
+ return result;
+ }
+ default:
+ TC3_LOG(ERROR) << "Unknown tokenization type specified. Using internal.";
+ return InternalTokenize(text_unicode);
+ }
+}
+
+void AppendCodepointToToken(UnicodeText::const_iterator it, Token* token) {
+ token->value += std::string(
+ it.utf8_data(), it.utf8_data() + GetNumBytesForUTF8Char(it.utf8_data()));
+}
+
+std::vector<Token> Tokenizer::InternalTokenize(
+ const UnicodeText& text_unicode) const {
+ std::vector<Token> result;
+ Token new_token("", 0, 0);
+ int codepoint_index = 0;
+
+ int last_script = kInvalidScript;
+ for (auto it = text_unicode.begin(); it != text_unicode.end();
+ ++it, ++codepoint_index) {
+ TokenizationCodepointRange_::Role role;
+ int script;
+ GetScriptAndRole(*it, &role, &script);
+
+ if (role & TokenizationCodepointRange_::Role_SPLIT_BEFORE ||
+ (split_on_script_change_ && last_script != kInvalidScript &&
+ last_script != script)) {
+ if (!new_token.value.empty()) {
+ result.push_back(new_token);
+ }
+ new_token = Token("", codepoint_index, codepoint_index);
+ }
+ if (!(role & TokenizationCodepointRange_::Role_DISCARD_CODEPOINT)) {
+ new_token.end += 1;
+ AppendCodepointToToken(it, &new_token);
+ }
+ if (role & TokenizationCodepointRange_::Role_SPLIT_AFTER) {
+ if (!new_token.value.empty()) {
+ result.push_back(new_token);
+ }
+ new_token = Token("", codepoint_index + 1, codepoint_index + 1);
+ }
+
+ last_script = script;
+ }
+ if (!new_token.value.empty()) {
+ result.push_back(new_token);
+ }
+
+ return result;
+}
+
+void Tokenizer::TokenizeSubstring(const UnicodeText& unicode_text,
+ CodepointSpan span,
+ std::vector<Token>* result) const {
+ if (span.first < 0) {
+ // There is no span to tokenize.
+ return;
+ }
+
+ // Extract the substring.
+ UnicodeText text = UnicodeText::Substring(unicode_text, span.first,
+ span.second, /*do_copy=*/false);
+
+ // Run the tokenizer and update the token bounds to reflect the offset of the
+ // substring.
+ std::vector<Token> tokens = InternalTokenize(text);
+
+ // Avoids progressive capacity increases in the for loop.
+ result->reserve(result->size() + tokens.size());
+ for (Token& token : tokens) {
+ token.start += span.first;
+ token.end += span.first;
+ result->emplace_back(std::move(token));
+ }
+}
+
+void Tokenizer::InternalRetokenize(const UnicodeText& unicode_text,
+ std::vector<Token>* tokens) const {
+ std::vector<Token> result;
+ CodepointSpan span(-1, -1);
+ for (Token& token : *tokens) {
+ const UnicodeText unicode_token_value =
+ UTF8ToUnicodeText(token.value, /*do_copy=*/false);
+ bool should_retokenize = true;
+ for (const int codepoint : unicode_token_value) {
+ if (!IsCodepointInRanges(codepoint,
+ internal_tokenizer_codepoint_ranges_)) {
+ should_retokenize = false;
+ break;
+ }
+ }
+
+ if (should_retokenize) {
+ if (span.first < 0) {
+ span.first = token.start;
+ }
+ span.second = token.end;
+ } else {
+ TokenizeSubstring(unicode_text, span, &result);
+ span.first = -1;
+ result.emplace_back(std::move(token));
+ }
+ }
+ TokenizeSubstring(unicode_text, span, &result);
+
+ *tokens = std::move(result);
+}
+
+bool Tokenizer::ICUTokenize(const UnicodeText& context_unicode,
+ std::vector<Token>* result) const {
+ std::unique_ptr<UniLib::BreakIterator> break_iterator =
+ unilib_->CreateBreakIterator(context_unicode);
+ if (!break_iterator) {
+ return false;
+ }
+ int last_break_index = 0;
+ int break_index = 0;
+ int last_unicode_index = 0;
+ int unicode_index = 0;
+ auto token_begin_it = context_unicode.begin();
+ while ((break_index = break_iterator->Next()) !=
+ UniLib::BreakIterator::kDone) {
+ const int token_length = break_index - last_break_index;
+ unicode_index = last_unicode_index + token_length;
+
+ auto token_end_it = token_begin_it;
+ std::advance(token_end_it, token_length);
+
+ // Determine if the whole token is whitespace.
+ bool is_whitespace = true;
+ for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
+ if (!unilib_->IsWhitespace(*char_it)) {
+ is_whitespace = false;
+ break;
+ }
+ }
+
+ const std::string token =
+ context_unicode.UTF8Substring(token_begin_it, token_end_it);
+
+ if (!is_whitespace || icu_preserve_whitespace_tokens_) {
+ result->push_back(Token(token, last_unicode_index, unicode_index,
+ /*is_padding=*/false, is_whitespace));
+ }
+
+ last_break_index = break_index;
+ last_unicode_index = unicode_index;
+ token_begin_it = token_end_it;
+ }
+
+ return true;
+}
+
+bool Tokenizer::NumberTokenize(const UnicodeText& text_unicode,
+ std::vector<Token>* result) const {
+ Token new_token("", 0, 0);
+ NumberTokenType current_token_type = NOT_SET;
+ int codepoint_index = 0;
+
+ auto PushToken = [&new_token, result]() {
+ if (!new_token.value.empty()) {
+ result->push_back(new_token);
+ }
+ };
+
+ auto MaybeResetTokenAndAddChar =
+ [&new_token, PushToken, ¤t_token_type](
+ int codepoint_index, NumberTokenType token_type,
+ UnicodeText::const_iterator it, bool is_whitespace = false) {
+ if (current_token_type != token_type) {
+ PushToken();
+ new_token = Token("", codepoint_index, codepoint_index,
+ /*is_padding=*/false, is_whitespace);
+ }
+ new_token.end += 1;
+ AppendCodepointToToken(it, &new_token);
+ current_token_type = token_type;
+ };
+
+ auto FinishTokenAndAddSeparator =
+ [&new_token, result, ¤t_token_type, PushToken](
+ int codepoint_index, UnicodeText::const_iterator it) {
+ PushToken();
+
+ result->emplace_back("", codepoint_index, codepoint_index + 1);
+ AppendCodepointToToken(it, &result->back());
+
+ new_token = Token("", codepoint_index + 1, codepoint_index + 1);
+ current_token_type = NOT_SET;
+ };
+
+ for (auto it = text_unicode.begin(); it != text_unicode.end();
+ ++it, ++codepoint_index) {
+ if (unilib_->IsDigit(*it)) {
+ MaybeResetTokenAndAddChar(codepoint_index, NUMERICAL, it);
+ } else if (unilib_->IsLetter(*it)) {
+ MaybeResetTokenAndAddChar(codepoint_index, TERM, it);
+ } else if (unilib_->IsWhitespace(*it)) {
+ MaybeResetTokenAndAddChar(codepoint_index, WHITESPACE, it,
+ /*is_whitespace=*/true);
+ } else if (unilib_->IsDot(*it) && preserve_floating_numbers_) {
+ auto it_next = std::next(it);
+ if (current_token_type == NUMERICAL && it_next != text_unicode.end() &&
+ unilib_->IsDigit(*it_next)) {
+ new_token.end += 1;
+ AppendCodepointToToken(it, &new_token);
+ } else {
+        // A dot that does not continue a number (not preceded by digits, at
+        // the end of the input, or followed by a non-digit) is its own token.
+ FinishTokenAndAddSeparator(codepoint_index, it);
+ }
+ } else {
+ FinishTokenAndAddSeparator(codepoint_index, it);
+ }
+ }
+ PushToken();
+
+ return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/tokenizer.fbs b/native/utils/tokenizer.fbs
new file mode 100755
index 0000000..c0a3919
--- /dev/null
+++ b/native/utils/tokenizer.fbs
@@ -0,0 +1,73 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Controls the type of tokenization the model will use for the input text.
+namespace libtextclassifier3;
+enum TokenizationType : int {
+ INVALID_TOKENIZATION_TYPE = 0,
+
+ // Use the internal tokenizer for tokenization.
+ INTERNAL_TOKENIZER = 1,
+
+ // Use ICU for tokenization.
+ ICU = 2,
+
+ // First apply ICU tokenization. Then identify stretches of tokens
+ // consisting only of codepoints in internal_tokenizer_codepoint_ranges
+ // and re-tokenize them using the internal tokenizer.
+ MIXED = 3,
+
+ // Tokenizer parsing out numbers, words and separators.
+ LETTER_DIGIT = 4,
+}
+
+// Role of the codepoints in the range.
+namespace libtextclassifier3.TokenizationCodepointRange_;
+enum Role : int {
+ // Concatenates the codepoint to the current run of codepoints.
+ DEFAULT_ROLE = 0,
+
+ // Splits a run of codepoints before the current codepoint.
+ SPLIT_BEFORE = 1,
+
+ // Splits a run of codepoints after the current codepoint.
+ SPLIT_AFTER = 2,
+
+ // Each codepoint will be a separate token. Good e.g. for Chinese
+ // characters.
+ TOKEN_SEPARATOR = 3,
+
+ // Discards the codepoint.
+ DISCARD_CODEPOINT = 4,
+
+ // Common values:
+ // Splits on the characters and discards them. Good e.g. for the space
+ // character.
+ WHITESPACE_SEPARATOR = 7,
+}
+
+// Represents a codepoint range [start, end) with its role for tokenization.
+namespace libtextclassifier3;
+table TokenizationCodepointRange {
+ start:int;
+ end:int;
+ role:TokenizationCodepointRange_.Role;
+
+ // Integer identifier of the script this range denotes. Negative values are
+ // reserved for Tokenizer's internal use.
+ script_id:int;
+}
+
diff --git a/native/utils/tokenizer.h b/native/utils/tokenizer.h
new file mode 100644
index 0000000..63b95d8
--- /dev/null
+++ b/native/utils/tokenizer.h
@@ -0,0 +1,156 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TOKENIZER_H_
+#define LIBTEXTCLASSIFIER_UTILS_TOKENIZER_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/codepoint-range.h"
+#include "utils/tokenizer_generated.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+const int kInvalidScript = -1;
+const int kUnknownScript = -2;
+
+// Tokenizer splits the input string into a sequence of tokens, according to
+// the configuration.
+class Tokenizer {
+ public:
+ // `codepoint_ranges`: Codepoint ranges that determine how different
+ // codepoints are tokenized. The ranges must not overlap.
+ // `internal_tokenizer_codepoint_ranges`: Codepoint ranges that define which
+ // tokens should be re-tokenized with the internal tokenizer in the mixed
+ // tokenization mode.
+ // `split_on_script_change`: Whether to consider a change of codepoint script
+ // in a sequence of characters as a token boundary. If True, will treat
+ // script change as a token boundary.
+ // `icu_preserve_whitespace_tokens`: If true, will include empty tokens in the
+ // output (in the ICU tokenization mode).
+ // `preserve_floating_numbers`: If true (default), will keep dots between
+ // digits together, not making separate tokens (in the LETTER_DIGIT
+ // tokenization mode).
+ Tokenizer(
+ const TokenizationType type, const UniLib* unilib,
+ const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
+ const std::vector<const CodepointRange*>&
+ internal_tokenizer_codepoint_ranges,
+ const bool split_on_script_change,
+ const bool icu_preserve_whitespace_tokens,
+ const bool preserve_floating_numbers);
+
+ Tokenizer(
+ const TokenizationType type, const UniLib* unilib,
+ const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
+ const std::vector<const CodepointRange*>&
+ internal_tokenizer_codepoint_ranges,
+ const bool split_on_script_change,
+ const bool icu_preserve_whitespace_tokens)
+ : Tokenizer(type, unilib, codepoint_ranges,
+ internal_tokenizer_codepoint_ranges, split_on_script_change,
+ icu_preserve_whitespace_tokens,
+ /*preserve_floating_numbers=*/true) {}
+
+ Tokenizer(
+ const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
+ const bool split_on_script_change)
+ : Tokenizer(TokenizationType_INTERNAL_TOKENIZER, /*unilib=*/nullptr,
+ codepoint_ranges, /*internal_tokenizer_codepoint_ranges=*/{},
+ split_on_script_change,
+ /*icu_preserve_whitespace_tokens=*/false,
+ /*preserve_floating_numbers=*/true) {}
+
+ // Describes the type of tokens used in the NumberTokenizer.
+ enum NumberTokenType {
+ INVALID_TOKEN_TYPE,
+ NUMERICAL,
+ TERM,
+ WHITESPACE,
+ SEPARATOR,
+ NOT_SET
+ };
+
+ // Tokenizes the input string using the selected tokenization method.
+ std::vector<Token> Tokenize(const std::string& text) const;
+
+ // Same as above but takes UnicodeText.
+ std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
+
+ protected:
+ // Finds the tokenization codepoint range config for given codepoint.
+ // Internally uses binary search so should be O(log(# of codepoint_ranges)).
+ const TokenizationCodepointRangeT* FindTokenizationRange(int codepoint) const;
+
+ // Finds the role and script for given codepoint. If not found, DEFAULT_ROLE
+ // and kUnknownScript are assigned.
+ void GetScriptAndRole(char32 codepoint,
+ TokenizationCodepointRange_::Role* role,
+ int* script) const;
+
+ // Tokenizes a substring of the unicode string, appending the resulting tokens
+ // to the output vector. The resulting tokens have bounds relative to the full
+ // string. Does nothing if the start of the span is negative.
+ void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
+ std::vector<Token>* result) const;
+
+ std::vector<Token> InternalTokenize(const UnicodeText& text_unicode) const;
+
+ // Takes the result of ICU tokenization and retokenizes stretches of tokens
+ // made of a specific subset of characters using the internal tokenizer.
+ void InternalRetokenize(const UnicodeText& unicode_text,
+ std::vector<Token>* tokens) const;
+
+ // Tokenizes the input text using ICU tokenizer.
+ bool ICUTokenize(const UnicodeText& context_unicode,
+ std::vector<Token>* result) const;
+
+ // Tokenizes the input in number, word and separator tokens.
+ bool NumberTokenize(const UnicodeText& text_unicode,
+ std::vector<Token>* result) const;
+
+ private:
+ const TokenizationType type_;
+
+ const UniLib* unilib_;
+
+ // Codepoint ranges that determine how different codepoints are tokenized.
+ // The ranges must not overlap.
+ std::vector<std::unique_ptr<const TokenizationCodepointRangeT>>
+ codepoint_ranges_;
+
+ // Codepoint ranges that define which tokens (consisting of which codepoints)
+ // should be re-tokenized with the internal tokenizer in the mixed
+ // tokenization mode.
+ // NOTE: Must be sorted.
+ std::vector<CodepointRangeStruct> internal_tokenizer_codepoint_ranges_;
+
+ // If true, tokens will be additionally split when the codepoint's script_id
+ // changes.
+ const bool split_on_script_change_;
+
+ const bool icu_preserve_whitespace_tokens_;
+ const bool preserve_floating_numbers_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TOKENIZER_H_
diff --git a/native/utils/tokenizer_test.cc b/native/utils/tokenizer_test.cc
new file mode 100644
index 0000000..f73f8f8
--- /dev/null
+++ b/native/utils/tokenizer_test.cc
@@ -0,0 +1,626 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tokenizer.h"
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAreArray;
+
+class TestingTokenizer : public Tokenizer {
+ public:
+ TestingTokenizer(
+ const TokenizationType type, const UniLib* unilib,
+ const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
+ const std::vector<const CodepointRange*>&
+ internal_tokenizer_codepoint_ranges,
+ const bool split_on_script_change,
+ const bool icu_preserve_whitespace_tokens,
+ const bool preserve_floating_numbers)
+ : Tokenizer(type, unilib, codepoint_ranges,
+ internal_tokenizer_codepoint_ranges, split_on_script_change,
+ icu_preserve_whitespace_tokens, preserve_floating_numbers) {}
+
+ using Tokenizer::FindTokenizationRange;
+};
+
+class TestingTokenizerProxy {
+ public:
+ TestingTokenizerProxy(
+ TokenizationType type,
+ const std::vector<TokenizationCodepointRangeT>& codepoint_range_configs,
+ const std::vector<CodepointRangeT>& internal_codepoint_range_configs,
+ const bool split_on_script_change,
+ const bool icu_preserve_whitespace_tokens,
+ const bool preserve_floating_numbers)
+ : INIT_UNILIB_FOR_TESTING(unilib_) {
+ const int num_configs = codepoint_range_configs.size();
+ std::vector<const TokenizationCodepointRange*> configs_fb;
+ configs_fb.reserve(num_configs);
+ const int num_internal_configs = internal_codepoint_range_configs.size();
+ std::vector<const CodepointRange*> internal_configs_fb;
+ internal_configs_fb.reserve(num_internal_configs);
+ buffers_.reserve(num_configs + num_internal_configs);
+ for (int i = 0; i < num_configs; i++) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(CreateTokenizationCodepointRange(
+ builder, &codepoint_range_configs[i]));
+ buffers_.push_back(builder.Release());
+ configs_fb.push_back(flatbuffers::GetRoot<TokenizationCodepointRange>(
+ buffers_.back().data()));
+ }
+ for (int i = 0; i < num_internal_configs; i++) {
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(
+ CreateCodepointRange(builder, &internal_codepoint_range_configs[i]));
+ buffers_.push_back(builder.Release());
+ internal_configs_fb.push_back(
+ flatbuffers::GetRoot<CodepointRange>(buffers_.back().data()));
+ }
+ tokenizer_ = std::unique_ptr<TestingTokenizer>(new TestingTokenizer(
+ type, &unilib_, configs_fb, internal_configs_fb, split_on_script_change,
+ icu_preserve_whitespace_tokens, preserve_floating_numbers));
+ }
+
+ TokenizationCodepointRange_::Role TestFindTokenizationRole(int c) const {
+ const TokenizationCodepointRangeT* range =
+ tokenizer_->FindTokenizationRange(c);
+ if (range != nullptr) {
+ return range->role;
+ } else {
+ return TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ }
+ }
+
+ std::vector<Token> Tokenize(const std::string& utf8_text) const {
+ return tokenizer_->Tokenize(utf8_text);
+ }
+
+ private:
+ UniLib unilib_;
+ std::vector<flatbuffers::DetachedBuffer> buffers_;
+ std::unique_ptr<TestingTokenizer> tokenizer_;
+};
+
+TEST(TokenizerTest, FindTokenizationRange) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0;
+ config->end = 10;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 1234;
+ config->end = 12345;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+ TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
+ {}, /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false,
+ /*preserve_floating_numbers=*/false);
+
+ // Test hits to the first group.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(0),
+ TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(5),
+ TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(10),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+
+ // Test a hit to the second group.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(31),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(32),
+ TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(33),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+
+ // Test hits to the third group.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(1233),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(1234),
+ TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(12344),
+ TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(12345),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+
+ // Test a hit outside.
+ EXPECT_EQ(tokenizer.TestFindTokenizationRole(99),
+ TokenizationCodepointRange_::Role_DEFAULT_ROLE);
+}
+
+TEST(TokenizerTest, TokenizeOnSpace) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ configs.emplace_back();
+ config = &configs.back();
+ // Space character.
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+
+ TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
+ {},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false,
+ /*preserve_floating_numbers=*/false);
+ std::vector<Token> tokens = tokenizer.Tokenize("Hello world!");
+
+ EXPECT_THAT(tokens,
+ ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)}));
+}
+
+TEST(TokenizerTest, TokenizeOnSpaceAndScriptChange) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ // Latin.
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0;
+ config->end = 32;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ config->script_id = 1;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+ config->script_id = 1;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 33;
+ config->end = 0x77F + 1;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ config->script_id = 1;
+
+ TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
+ {},
+ /*split_on_script_change=*/true,
+ /*icu_preserve_whitespace_tokens=*/false,
+ /*preserve_floating_numbers=*/false);
+ EXPECT_THAT(tokenizer.Tokenize("앨라배마 주 전화(123) 456-789웹사이트"),
+ std::vector<Token>({Token("앨라배마", 0, 4), Token("주", 5, 6),
+ Token("전화", 7, 10), Token("(123)", 10, 15),
+ Token("456-789", 16, 23),
+ Token("웹사이트", 23, 28)}));
+}
+
+TEST(TokenizerTest, TokenizeComplex) {
+ std::vector<TokenizationCodepointRangeT> configs;
+ TokenizationCodepointRangeT* config;
+
+ // Source: http://www.unicode.org/Public/10.0.0/ucd/Blocks-10.0.0d1.txt
+  // Latin - Cyrillic.
+ // 0000..007F; Basic Latin
+ // 0080..00FF; Latin-1 Supplement
+ // 0100..017F; Latin Extended-A
+ // 0180..024F; Latin Extended-B
+ // 0250..02AF; IPA Extensions
+ // 02B0..02FF; Spacing Modifier Letters
+ // 0300..036F; Combining Diacritical Marks
+ // 0370..03FF; Greek and Coptic
+ // 0400..04FF; Cyrillic
+ // 0500..052F; Cyrillic Supplement
+ // 0530..058F; Armenian
+ // 0590..05FF; Hebrew
+ // 0600..06FF; Arabic
+ // 0700..074F; Syriac
+ // 0750..077F; Arabic Supplement
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0;
+ config->end = 32;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 32;
+ config->end = 33;
+ config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 33;
+ config->end = 0x77F + 1;
+ config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
+
+ // CJK
+ // 2E80..2EFF; CJK Radicals Supplement
+ // 3000..303F; CJK Symbols and Punctuation
+ // 3040..309F; Hiragana
+ // 30A0..30FF; Katakana
+ // 3100..312F; Bopomofo
+ // 3130..318F; Hangul Compatibility Jamo
+ // 3190..319F; Kanbun
+ // 31A0..31BF; Bopomofo Extended
+ // 31C0..31EF; CJK Strokes
+ // 31F0..31FF; Katakana Phonetic Extensions
+ // 3200..32FF; Enclosed CJK Letters and Months
+ // 3300..33FF; CJK Compatibility
+ // 3400..4DBF; CJK Unified Ideographs Extension A
+ // 4DC0..4DFF; Yijing Hexagram Symbols
+ // 4E00..9FFF; CJK Unified Ideographs
+ // A000..A48F; Yi Syllables
+ // A490..A4CF; Yi Radicals
+ // A4D0..A4FF; Lisu
+ // A500..A63F; Vai
+ // F900..FAFF; CJK Compatibility Ideographs
+ // FE30..FE4F; CJK Compatibility Forms
+ // 20000..2A6DF; CJK Unified Ideographs Extension B
+ // 2A700..2B73F; CJK Unified Ideographs Extension C
+ // 2B740..2B81F; CJK Unified Ideographs Extension D
+ // 2B820..2CEAF; CJK Unified Ideographs Extension E
+ // 2CEB0..2EBEF; CJK Unified Ideographs Extension F
+ // 2F800..2FA1F; CJK Compatibility Ideographs Supplement
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2E80;
+ config->end = 0x2EFF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x3000;
+ config->end = 0xA63F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0xF900;
+ config->end = 0xFAFF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0xFE30;
+ config->end = 0xFE4F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x20000;
+ config->end = 0x2A6DF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2A700;
+ config->end = 0x2B73F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2B740;
+ config->end = 0x2B81F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2B820;
+ config->end = 0x2CEAF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2CEB0;
+ config->end = 0x2EBEF + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x2F800;
+ config->end = 0x2FA1F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+ // Thai.
+ // 0E00..0E7F; Thai
+ configs.emplace_back();
+ config = &configs.back();
+ config->start = 0x0E00;
+ config->end = 0x0E7F + 1;
+ config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
+
+ TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
+ {},
+ /*split_on_script_change=*/false,
+ /*icu_preserve_whitespace_tokens=*/false,
+ /*preserve_floating_numbers=*/false);
+ std::vector<Token> tokens;
+
+ tokens = tokenizer.Tokenize(
+ "問少目木輸走猶術権自京門録球変。細開括省用掲情結傍走愛明氷。");
+ EXPECT_EQ(tokens.size(), 30);
+
+ tokens = tokenizer.Tokenize("問少目 hello 木輸ยามきゃ");
+ // clang-format off
+ EXPECT_THAT(
+ tokens,
+ ElementsAreArray({Token("問", 0, 1),
+ Token("少", 1, 2),
+ Token("目", 2, 3),
+ Token("hello", 4, 9),
+ Token("木", 10, 11),
+ Token("輸", 11, 12),
+ Token("ย", 12, 13),
+ Token("า", 13, 14),
+ Token("ม", 14, 15),
+ Token("き", 15, 16),
+ Token("ゃ", 16, 17)}));
+ // clang-format on
+}
+
#if defined(TC3_TEST_ICU) || defined(__APPLE__)
// With icu_preserve_whitespace_tokens enabled, the spaces between the
// ICU-segmented Thai words are emitted as their own tokens.
TEST(TokenizerTest, ICUTokenizeWithWhitespaces) {
  TestingTokenizerProxy tokenizer(TokenizationType_ICU, {}, {},
                                  /*split_on_script_change=*/false,
                                  /*icu_preserve_whitespace_tokens=*/true,
                                  /*preserve_floating_numbers=*/false);
  std::vector<Token> tokens = tokenizer.Tokenize("พระบาท สมเด็จ พระ ปร มิ");
  // clang-format off
  ASSERT_EQ(tokens,
            std::vector<Token>({Token("พระบาท", 0, 6),
                                Token(" ", 6, 7),
                                Token("สมเด็จ", 7, 13),
                                Token(" ", 13, 14),
                                Token("พระ", 14, 17),
                                Token(" ", 17, 18),
                                Token("ปร", 18, 20),
                                Token(" ", 20, 21),
                                Token("มิ", 21, 23)}));
  // clang-format on
}

// Each punctuation character is emitted as a single-character token.
TEST(TokenizerTest, ICUTokenizePunctuation) {
  TestingTokenizerProxy tokenizer(TokenizationType_ICU, {}, {},
                                  /*split_on_script_change=*/false,
                                  /*icu_preserve_whitespace_tokens=*/true,
                                  /*preserve_floating_numbers=*/false);
  std::vector<Token> tokens =
      tokenizer.Tokenize("The interval is: -(12, 138*)");
  // clang-format off
  ASSERT_EQ(
      tokens,
      std::vector<Token>({Token("The", 0, 3),
                          Token(" ", 3, 4),
                          Token("interval", 4, 12),
                          Token(" ", 12, 13),
                          Token("is", 13, 15),
                          Token(":", 15, 16),
                          Token(" ", 16, 17),
                          Token("-", 17, 18),
                          Token("(", 18, 19),
                          Token("12", 19, 21),
                          Token(",", 21, 22),
                          Token(" ", 22, 23),
                          Token("138", 23, 26),
                          Token("*", 26, 27),
                          Token(")", 27, 28)}));
  // clang-format on
}

// ICU keeps decimal numbers as single tokens, including ones written with
// the small full stop '﹒' as the decimal separator.
TEST(TokenizerTest, ICUTokenizeWithNumbers) {
  TestingTokenizerProxy tokenizer(TokenizationType_ICU, {}, {},
                                  /*split_on_script_change=*/false,
                                  /*icu_preserve_whitespace_tokens=*/true,
                                  /*preserve_floating_numbers=*/false);
  std::vector<Token> tokens = tokenizer.Tokenize("3.1 3﹒2 3.3");
  // clang-format off
  ASSERT_EQ(tokens,
            std::vector<Token>({Token("3.1", 0, 3),
                                Token(" ", 3, 4),
                                Token("3﹒2", 4, 7),
                                Token(" ", 7, 8),
                                Token("3.3", 8, 11)}));
  // clang-format on
}
#endif
+
#if defined(TC3_TEST_ICU)
// ICU word segmentation of unspaced Thai text (whitespace preservation off).
TEST(TokenizerTest, ICUTokenize) {
  TestingTokenizerProxy tokenizer(TokenizationType_ICU, {}, {},
                                  /*split_on_script_change=*/false,
                                  /*icu_preserve_whitespace_tokens=*/false,
                                  /*preserve_floating_numbers=*/false);
  std::vector<Token> tokens = tokenizer.Tokenize("พระบาทสมเด็จพระปรมิ");
  // clang-format off
  ASSERT_EQ(tokens,
            std::vector<Token>({Token("พระบาท", 0, 6),
                                Token("สมเด็จ", 6, 12),
                                Token("พระ", 12, 15),
                                Token("ปร", 15, 17),
                                Token("มิ", 17, 19)}));
  // clang-format on
}

// Mixed mode: codepoints covered by the internal ranges (here [0, 592),
// i.e. Latin) are tokenized internally; everything else falls back to ICU.
TEST(TokenizerTest, MixedTokenize) {
  std::vector<TokenizationCodepointRangeT> configs;
  TokenizationCodepointRangeT* config;

  // ASCII space acts as a whitespace separator for the internal tokenizer.
  configs.emplace_back();
  config = &configs.back();
  config->start = 32;
  config->end = 33;
  config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;

  // NOTE(review): "interal_config" is a typo for "internal_config"
  // (local variable only; no behavioral impact).
  std::vector<CodepointRangeT> internal_configs;
  CodepointRangeT* interal_config;

  internal_configs.emplace_back();
  interal_config = &internal_configs.back();
  interal_config->start = 0;
  interal_config->end = 128;

  internal_configs.emplace_back();
  interal_config = &internal_configs.back();
  interal_config->start = 128;
  interal_config->end = 256;

  internal_configs.emplace_back();
  interal_config = &internal_configs.back();
  interal_config->start = 256;
  interal_config->end = 384;

  internal_configs.emplace_back();
  interal_config = &internal_configs.back();
  interal_config->start = 384;
  interal_config->end = 592;

  TestingTokenizerProxy tokenizer(TokenizationType_MIXED, configs,
                                  internal_configs,
                                  /*split_on_script_change=*/false,
                                  /*icu_preserve_whitespace_tokens=*/false,
                                  /*preserve_floating_numbers=*/false);

  std::vector<Token> tokens = tokenizer.Tokenize(
      "こんにちはJapanese-ląnguagę text 你好世界 http://www.google.com/");
  ASSERT_EQ(
      tokens,
      // clang-format off
      std::vector<Token>({Token("こんにちは", 0, 5),
                          Token("Japanese-ląnguagę", 5, 22),
                          Token("text", 23, 27),
                          Token("你好", 28, 30),
                          Token("世界", 30, 32),
                          Token("http://www.google.com/", 33, 55)}));
  // clang-format on
}

// With split_on_script_change enabled, a token boundary is inserted whenever
// the Unicode script changes (Hangul -> digits -> Hangul below); with it
// disabled, the whole run stays one token.
TEST(TokenizerTest, InternalTokenizeOnScriptChange) {
  std::vector<TokenizationCodepointRangeT> configs;
  TokenizationCodepointRangeT* config;

  configs.emplace_back();
  config = &configs.back();
  config->start = 0;
  config->end = 256;
  config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;

  {
    TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER,
                                    configs, {},
                                    /*split_on_script_change=*/false,
                                    /*icu_preserve_whitespace_tokens=*/false,
                                    /*preserve_floating_numbers=*/false);

    EXPECT_EQ(tokenizer.Tokenize("앨라배마123웹사이트"),
              std::vector<Token>({Token("앨라배마123웹사이트", 0, 11)}));
  }

  {
    TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER,
                                    configs, {},
                                    /*split_on_script_change=*/true,
                                    /*icu_preserve_whitespace_tokens=*/false,
                                    /*preserve_floating_numbers=*/false);
    EXPECT_EQ(tokenizer.Tokenize("앨라배마123웹사이트"),
              std::vector<Token>({Token("앨라배마", 0, 4), Token("123", 4, 7),
                                  Token("웹사이트", 7, 11)}));
  }
}
#endif
+
// Letter/digit tokenizer splits on character-class changes; with
// preserve_floating_numbers, "3.14" stays a single token while a leading or
// trailing dot does not attach.
TEST(TokenizerTest, LetterDigitTokenize) {
  TestingTokenizerProxy tokenizer(TokenizationType_LETTER_DIGIT, {}, {},
                                  /*split_on_script_change=*/false,
                                  /*icu_preserve_whitespace_tokens=*/false,
                                  /*preserve_floating_numbers=*/true);
  std::vector<Token> tokens = tokenizer.Tokenize("7% -3.14 68.9#? 7% $99 .18.");
  ASSERT_EQ(tokens,
            std::vector<Token>(
                {Token("7", 0, 1), Token("%", 1, 2), Token(" ", 2, 3),
                 Token("-", 3, 4), Token("3.14", 4, 8), Token(" ", 8, 9),
                 Token("68.9", 9, 13), Token("#", 13, 14), Token("?", 14, 15),
                 Token(" ", 15, 16), Token("7", 16, 17), Token("%", 17, 18),
                 Token(" ", 18, 19), Token("$", 19, 20), Token("99", 20, 22),
                 Token(" ", 22, 23), Token(".", 23, 24), Token("18", 24, 26),
                 Token(".", 26, 27)}));
}

// Non-ASCII letters (accented Latin, Katakana) count as letters too.
TEST(TokenizerTest, LetterDigitTokenizeUnicode) {
  TestingTokenizerProxy tokenizer(TokenizationType_LETTER_DIGIT, {}, {},
                                  /*split_on_script_change=*/false,
                                  /*icu_preserve_whitespace_tokens=*/false,
                                  /*preserve_floating_numbers=*/true);
  std::vector<Token> tokens = tokenizer.Tokenize("2 pércént 3パーセント");
  ASSERT_EQ(tokens, std::vector<Token>({Token("2", 0, 1), Token(" ", 1, 2),
                                        Token("pércént", 2, 9),
                                        Token(" ", 9, 10), Token("3", 10, 11),
                                        Token("パーセント", 11, 16)}));
}

// A floating number may use the small full stop '﹒' as its decimal
// separator.
TEST(TokenizerTest, LetterDigitTokenizeWithDots) {
  TestingTokenizerProxy tokenizer(TokenizationType_LETTER_DIGIT, {}, {},
                                  /*split_on_script_change=*/false,
                                  /*icu_preserve_whitespace_tokens=*/false,
                                  /*preserve_floating_numbers=*/true);
  std::vector<Token> tokens = tokenizer.Tokenize("3 3﹒2 3.3%");
  ASSERT_EQ(tokens,
            std::vector<Token>({Token("3", 0, 1), Token(" ", 1, 2),
                                Token("3﹒2", 2, 5), Token(" ", 5, 6),
                                Token("3.3", 6, 9), Token("%", 9, 10)}));
}

// Without preserve_floating_numbers every dot splits the digits around it.
TEST(TokenizerTest, LetterDigitTokenizeDoNotPreserveFloatingNumbers) {
  TestingTokenizerProxy tokenizer(TokenizationType_LETTER_DIGIT, {}, {},
                                  /*split_on_script_change=*/false,
                                  /*icu_preserve_whitespace_tokens=*/false,
                                  /*preserve_floating_numbers=*/false);
  std::vector<Token> tokens = tokenizer.Tokenize("15.12.2019 january's 3.2");
  ASSERT_EQ(tokens,
            std::vector<Token>(
                {Token("15", 0, 2), Token(".", 2, 3), Token("12", 3, 5),
                 Token(".", 5, 6), Token("2019", 6, 10), Token(" ", 10, 11),
                 Token("january", 11, 18), Token("'", 18, 19),
                 Token("s", 19, 20), Token(" ", 20, 21), Token("3", 21, 22),
                 Token(".", 22, 23), Token("2", 23, 24)}));
}

// Runs of symbols are split one symbol per token.
TEST(TokenizerTest, LetterDigitTokenizeStrangeStringFloatingNumbers) {
  TestingTokenizerProxy tokenizer(TokenizationType_LETTER_DIGIT, {}, {},
                                  /*split_on_script_change=*/false,
                                  /*icu_preserve_whitespace_tokens=*/false,
                                  /*preserve_floating_numbers=*/false);
  std::vector<Token> tokens = tokenizer.Tokenize("The+2345++the +íí+");
  ASSERT_EQ(tokens,
            std::vector<Token>({Token("The", 0, 3), Token("+", 3, 4),
                                Token("2345", 4, 8), Token("+", 8, 9),
                                Token("+", 9, 10), Token("the", 10, 13),
                                Token(" ", 13, 14), Token("+", 14, 15),
                                Token("íí", 15, 17), Token("+", 17, 18)}));
}

// Consecutive whitespace codepoints merge into a single whitespace token.
TEST(TokenizerTest, LetterDigitTokenizeWhitespcesInSameToken) {
  TestingTokenizerProxy tokenizer(TokenizationType_LETTER_DIGIT, {}, {},
                                  /*split_on_script_change=*/false,
                                  /*icu_preserve_whitespace_tokens=*/false,
                                  /*preserve_floating_numbers=*/false);
  std::vector<Token> tokens = tokenizer.Tokenize("2 3  4   5");
  ASSERT_EQ(tokens, std::vector<Token>({Token("2", 0, 1), Token(" ", 1, 2),
                                        Token("3", 2, 3), Token(" ", 3, 5),
                                        Token("4", 5, 6), Token(" ", 6, 9),
                                        Token("5", 9, 10)}));
}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unicodetext.cc b/native/utils/utf8/unicodetext.cc
new file mode 100644
index 0000000..7b56ce2
--- /dev/null
+++ b/native/utils/utf8/unicodetext.cc
@@ -0,0 +1,323 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/utf8/unicodetext.h"
+
+#include <string.h>
+
+#include <algorithm>
+
+#include "utils/base/logging.h"
+#include "utils/strings/utf8.h"
+
+namespace libtextclassifier3 {
+
+// *************** Data representation **********
+// Note: the copy constructor is undefined.
+
+UnicodeText::Repr& UnicodeText::Repr::operator=(Repr&& src) {
+ if (ours_ && data_) delete[] data_;
+ data_ = src.data_;
+ size_ = src.size_;
+ capacity_ = src.capacity_;
+ ours_ = src.ours_;
+ src.ours_ = false;
+ return *this;
+}
+
+void UnicodeText::Repr::PointTo(const char* data, int size) {
+ if (ours_ && data_) delete[] data_; // If we owned the old buffer, free it.
+ data_ = const_cast<char*>(data);
+ size_ = size;
+ capacity_ = size;
+ ours_ = false;
+}
+
// Makes this Repr own a private copy of the given bytes.
void UnicodeText::Repr::Copy(const char* data, int size) {
  resize(size);
  memcpy(data_, data, size);
}

// Sets the logical size to new_size bytes, taking ownership of the storage.
// Newly exposed bytes (when growing) are zero-filled.
void UnicodeText::Repr::resize(int new_size) {
  if (new_size == 0) {
    clear();
  } else {
    // Reallocate if we are currently an alias (we must not write into
    // foreign memory) or if the existing buffer is too small.
    if (!ours_ || new_size > capacity_) reserve(new_size);
    // Clear the memory in the expanded part.
    if (size_ < new_size) memset(data_ + size_, 0, new_size - size_);
    size_ = new_size;
    ours_ = true;
  }
}

// Ensures an owned buffer of at least new_capacity bytes, preserving the
// current contents. Growth is geometric (~1.5x + 20) to amortize appends.
void UnicodeText::Repr::reserve(int new_capacity) {
  // If there's already enough capacity, and we're an owner, do nothing.
  if (capacity_ >= new_capacity && ours_) return;

  // Otherwise, allocate a new buffer.
  capacity_ = std::max(new_capacity, (3 * capacity_) / 2 + 20);
  char* new_data = new char[capacity_];

  // If there is an old buffer, copy it into the new buffer.
  if (data_) {
    memcpy(new_data, data_, size_);
    if (ours_) delete[] data_;  // If we owned the old buffer, free it.
  }
  data_ = new_data;
  ours_ = true;  // We own the new buffer.
  // size_ is unchanged.
}

// Appends byte_length raw bytes at the end, growing the buffer if needed.
// NOTE(review): `bytes` must not alias this Repr's own buffer — reserve()
// may reallocate and free it mid-copy; confirm callers never self-append.
void UnicodeText::Repr::append(const char* bytes, int byte_length) {
  reserve(size_ + byte_length);
  memcpy(data_ + size_, bytes, byte_length);
  size_ += byte_length;
}
+
+void UnicodeText::Repr::clear() {
+ if (ours_) delete[] data_;
+ data_ = nullptr;
+ size_ = capacity_ = 0;
+ ours_ = true;
+}
+
+// *************** UnicodeText ******************
+
UnicodeText::UnicodeText() {}

// With do_copy the new text owns a duplicate of src's bytes; otherwise it
// becomes a non-owning view of src's buffer and must not outlive it.
UnicodeText::UnicodeText(const UnicodeText& src, bool do_copy) {
  if (do_copy) {
    Copy(src);
  } else {
    repr_.PointTo(src.repr_.data_, src.repr_.size_);
  }
}

UnicodeText& UnicodeText::operator=(UnicodeText&& src) {
  this->repr_ = std::move(src.repr_);
  return *this;
}

// Replaces the contents with an owned copy of src's bytes.
UnicodeText& UnicodeText::Copy(const UnicodeText& src) {
  repr_.Copy(src.repr_.data_, src.repr_.size_);
  return *this;
}

// Becomes a non-owning alias of the given UTF-8 buffer (no copy made).
UnicodeText& UnicodeText::PointToUTF8(const char* buffer, int byte_length) {
  repr_.PointTo(buffer, byte_length);
  return *this;
}

// Replaces the contents with an owned copy of the given UTF-8 bytes.
UnicodeText& UnicodeText::CopyUTF8(const char* buffer, int byte_length) {
  repr_.Copy(buffer, byte_length);
  return *this;
}

// Appends raw UTF-8 bytes; may reallocate, invalidating iterators/pointers.
UnicodeText& UnicodeText::AppendUTF8(const char* utf8, int len) {
  repr_.append(utf8, len);
  return *this;
}

const char* UnicodeText::data() const { return repr_.data_; }

int UnicodeText::size_bytes() const { return repr_.size_; }
+
namespace {

enum {
  RuneError = 0xFFFD,  // Decoding error in UTF.
  RuneMax = 0x10FFFF,  // Maximum rune value.
};

// Encodes `rune` as UTF-8 into `dest` (which must have room for at least
// 4 bytes) and returns the number of bytes written. Values above RuneMax
// are replaced with RuneError. NOTE(review): surrogate values
// (0xD800..0xDFFF) are not rejected here — presumably callers only supply
// interchange-valid codepoints; confirm at the call sites.
int runetochar(const char32 rune, char* dest) {
  // Convert to unsigned for range check.
  uint32 c;

  // 1 char 00-7F
  c = rune;
  if (c <= 0x7F) {
    dest[0] = static_cast<char>(c);
    return 1;
  }

  // 2 char 0080-07FF
  if (c <= 0x07FF) {
    dest[0] = 0xC0 | static_cast<char>(c >> 1 * 6);
    dest[1] = 0x80 | (c & 0x3F);
    return 2;
  }

  // Range check (done here, after the 2-byte case, because RuneError
  // itself needs the 3-byte encoding below).
  if (c > RuneMax) {
    c = RuneError;
  }

  // 3 char 0800-FFFF
  if (c <= 0xFFFF) {
    dest[0] = 0xE0 | static_cast<char>(c >> 2 * 6);
    dest[1] = 0x80 | ((c >> 1 * 6) & 0x3F);
    dest[2] = 0x80 | (c & 0x3F);
    return 3;
  }

  // 4 char 10000-1FFFFF
  dest[0] = 0xF0 | static_cast<char>(c >> 3 * 6);
  dest[1] = 0x80 | ((c >> 2 * 6) & 0x3F);
  dest[2] = 0x80 | ((c >> 1 * 6) & 0x3F);
  dest[3] = 0x80 | (c & 0x3F);
  return 4;
}

}  // namespace
+
// Appends a single codepoint, UTF-8 encoded (1-4 bytes).
UnicodeText& UnicodeText::push_back(char32 ch) {
  char str[4];
  int char_len = runetochar(ch, str);
  repr_.append(str, char_len);
  return *this;
}

void UnicodeText::clear() { repr_.clear(); }

// O(n): walks the whole buffer counting codepoints.
int UnicodeText::size_codepoints() const {
  return std::distance(begin(), end());
}

bool UnicodeText::empty() const { return size_bytes() == 0; }

bool UnicodeText::is_valid() const {
  return IsValidUTF8(repr_.data_, repr_.size_);
}

// Byte-wise equality of the underlying UTF-8; equivalent to codepoint
// equality because each codepoint has a unique UTF-8 encoding.
bool UnicodeText::operator==(const UnicodeText& other) const {
  if (repr_.size_ != other.repr_.size_) {
    return false;
  }
  return memcmp(repr_.data_, other.repr_.data_, repr_.size_) == 0;
}

std::string UnicodeText::ToUTF8String() const {
  return UTF8Substring(begin(), end());
}

// Returns the UTF-8 bytes covering [begin_codepoint, end_codepoint).
// O(end_codepoint) because the iterators advance linearly.
std::string UnicodeText::UTF8Substring(int begin_codepoint,
                                       int end_codepoint) const {
  auto span_begin = begin();
  std::advance(span_begin, begin_codepoint);
  auto span_end = span_begin;
  std::advance(span_end, end_codepoint - begin_codepoint);
  return UTF8Substring(span_begin, span_end);
}

std::string UnicodeText::UTF8Substring(const const_iterator& it_begin,
                                       const const_iterator& it_end) {
  return std::string(it_begin.it_, it_end.it_ - it_begin.it_);
}

// Codepoint-indexed substring; see the iterator overload for do_copy
// semantics.
UnicodeText UnicodeText::Substring(const UnicodeText& text, int begin_codepoint,
                                   int end_codepoint, bool do_copy) {
  auto it_begin = text.begin();
  std::advance(it_begin, begin_codepoint);
  auto it_end = text.begin();
  std::advance(it_end, end_codepoint);

  return Substring(it_begin, it_end, do_copy);
}

// With do_copy the result owns its bytes; otherwise it aliases the original
// buffer and must not outlive it.
UnicodeText UnicodeText::Substring(const const_iterator& it_begin,
                                   const const_iterator& it_end, bool do_copy) {
  if (do_copy) {
    UnicodeText result;
    result.repr_.Copy(it_begin.it_, it_end.it_ - it_begin.it_);
    return result;
  } else {
    UnicodeText result;
    result.repr_.PointTo(it_begin.it_, it_end.it_ - it_begin.it_);
    return result;
  }
}

UnicodeText::~UnicodeText() {}
+
+// ******************* UnicodeText::const_iterator *********************
+
+// The implementation of const_iterator would be nicer if it
+// inherited from boost::iterator_facade
+// (http://boost.org/libs/iterator/doc/iterator_facade.html).
+
UnicodeText::const_iterator::const_iterator() : it_(nullptr) {}

UnicodeText::const_iterator& UnicodeText::const_iterator::operator=(
    const const_iterator& other) {
  if (&other != this) it_ = other.it_;
  return *this;
}

UnicodeText::const_iterator UnicodeText::begin() const {
  return const_iterator(repr_.data_);
}

UnicodeText::const_iterator UnicodeText::end() const {
  return const_iterator(repr_.data_ + repr_.size_);
}

// Iterators over the same text are ordered by buffer position.
bool operator<(const UnicodeText::const_iterator& lhs,
               const UnicodeText::const_iterator& rhs) {
  return lhs.it_ < rhs.it_;
}

char32 UnicodeText::const_iterator::operator*() const {
  // (We could call chartorune here, but that does some
  // error-checking, and we're guaranteed that our data is valid
  // UTF-8. Also, we expect this routine to be called very often. So
  // for speed, we do the calculation ourselves.)
  return ValidCharToRune(it_);
}

// Advances to the next codepoint by skipping this codepoint's UTF-8 bytes.
UnicodeText::const_iterator& UnicodeText::const_iterator::operator++() {
  it_ += GetNumBytesForUTF8Char(it_);
  return *this;
}

// Retreats to the previous codepoint by backing over trail bytes until a
// lead byte is reached. Must not be called on begin().
UnicodeText::const_iterator& UnicodeText::const_iterator::operator--() {
  while (IsTrailByte(*--it_)) {
  }
  return *this;
}
+
// Wraps a UTF-8 buffer: with do_copy the result owns a copy; otherwise it
// is a non-owning alias bound to the buffer's lifetime.
UnicodeText UTF8ToUnicodeText(const char* utf8_buf, int len, bool do_copy) {
  UnicodeText t;
  if (do_copy) {
    t.CopyUTF8(utf8_buf, len);
  } else {
    t.PointToUTF8(utf8_buf, len);
  }
  return t;
}

// NUL-terminated variant; length is computed with strlen.
UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy) {
  return UTF8ToUnicodeText(utf8_buf, strlen(utf8_buf), do_copy);
}

UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy) {
  return UTF8ToUnicodeText(str.data(), str.size(), do_copy);
}

UnicodeText UTF8ToUnicodeText(StringPiece str, bool do_copy) {
  return UTF8ToUnicodeText(str.data(), str.size(), do_copy);
}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unicodetext.h b/native/utils/utf8/unicodetext.h
new file mode 100644
index 0000000..9810480
--- /dev/null
+++ b/native/utils/utf8/unicodetext.h
@@ -0,0 +1,242 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNICODETEXT_H_
+#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNICODETEXT_H_
+
+#include <iterator>
+#include <string>
+#include <utility>
+
+#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// ***************************** UnicodeText **************************
+//
+// A UnicodeText object is a wrapper around a sequence of Unicode
+// codepoint values that allows iteration over these values.
+//
+// The internal representation of the text is UTF-8. Since UTF-8 is a
+// variable-width format, UnicodeText does not provide random access
+// to the text, and changes to the text are permitted only at the end.
+//
+// The UnicodeText class defines a const_iterator. The dereferencing
+// operator (*) returns a codepoint (int32). The iterator is a
+// read-only iterator. It becomes invalid if the text is changed.
+//
+// Codepoints are integers in the range [0, 0xD7FF] or [0xE000,
+// 0x10FFFF], but UnicodeText has the additional restriction that it
+// can contain only those characters that are valid for interchange on
+// the Web. This excludes all of the control codes except for carriage
+// return, line feed, and horizontal tab. It also excludes
+// non-characters, but codepoints that are in the Private Use regions
+// are allowed, as are codepoints that are unassigned. (See the
+// Unicode reference for details.)
+//
+// MEMORY MANAGEMENT:
+//
+// PointToUTF8(buffer, size) creates an alias pointing to buffer.
+//
+// The purpose of an alias is to avoid making an unnecessary copy of a
+// UTF-8 buffer while still providing access to the Unicode values
+// within that text through iterators. The lifetime of an alias must not
+// exceed the lifetime of the buffer from which it was constructed.
+//
+// Aliases should be used with care. If the source from which an alias
+// was created is freed, or if the contents are changed, while the
+// alias is still in use, fatal errors could result. But it can be
+// quite useful to have a UnicodeText "window" through which to see a
+// UTF-8 buffer without having to pay the price of making a copy.
+
class UnicodeText {
 public:
  class const_iterator;

  UnicodeText();  // Create an empty text.
  // With do_copy the new text owns a copy of src's bytes; otherwise it is a
  // non-owning alias of src's buffer (see MEMORY MANAGEMENT above).
  UnicodeText(const UnicodeText& src, bool do_copy = true);
  UnicodeText& operator=(UnicodeText&& src);
  ~UnicodeText();

  // Read-only bidirectional iterator over codepoints. Invalidated by any
  // mutation of the text.
  class const_iterator {
    typedef const_iterator CI;

   public:
    typedef std::bidirectional_iterator_tag iterator_category;
    typedef char32 value_type;
    typedef int difference_type;
    typedef void pointer;            // (Not needed.)
    typedef const char32 reference;  // (Needed for const_reverse_iterator)

    // Iterators are default-constructible.
    const_iterator();

    // It's safe to make multiple passes over a UnicodeText.
    const_iterator& operator=(const const_iterator& other);

    char32 operator*() const;  // Dereference

    const_iterator& operator++();     // Advance (++iter)
    const_iterator operator++(int) {  // (iter++)
      const_iterator result(*this);
      ++*this;
      return result;
    }

    const_iterator& operator--();     // Retreat (--iter)
    const_iterator operator--(int) {  // (iter--)
      const_iterator result(*this);
      --*this;
      return result;
    }

    friend bool operator==(const CI& lhs, const CI& rhs) {
      return lhs.it_ == rhs.it_;
    }
    friend bool operator!=(const CI& lhs, const CI& rhs) {
      return !(lhs == rhs);
    }
    friend bool operator<(const CI& lhs, const CI& rhs);
    friend bool operator>(const CI& lhs, const CI& rhs) { return rhs < lhs; }
    friend bool operator<=(const CI& lhs, const CI& rhs) {
      return !(rhs < lhs);
    }
    friend bool operator>=(const CI& lhs, const CI& rhs) {
      return !(lhs < rhs);
    }

    // Number of bytes the current codepoint occupies, derived from its
    // UTF-8 lead byte.
    int utf8_length() const {
      const unsigned char byte = static_cast<unsigned char>(it_[0]);
      if (byte < 0x80) {
        return 1;
      } else if (byte < 0xE0) {
        return 2;
      } else if (byte < 0xF0) {
        return 3;
      } else {
        return 4;
      }
    }
    // Raw pointer to this codepoint's first UTF-8 byte.
    const char* utf8_data() const { return it_; }

   private:
    friend class UnicodeText;
    explicit const_iterator(const char* it) : it_(it) {}

    const char* it_;
  };

  const_iterator begin() const;
  const_iterator end() const;

  // Gets pointer to the underlying utf8 data.
  const char* data() const;

  // Gets length (in bytes) of the underlying utf8 data.
  int size_bytes() const;

  // Computes length (in number of Unicode codepoints) of the underlying utf8
  // data.
  // NOTE: Complexity O(n).
  int size_codepoints() const;

  bool empty() const;

  // Checks whether the underlying data is valid utf8 data.
  bool is_valid() const;

  bool operator==(const UnicodeText& other) const;

  // x.PointToUTF8(buf,len) changes x so that it points to buf
  // ("becomes an alias"). It does not take ownership or copy buf.
  // This function assumes that the input is interchange valid UTF8.
  UnicodeText& Copy(const UnicodeText& src);
  UnicodeText& PointToUTF8(const char* utf8_buffer, int byte_length);
  UnicodeText& CopyUTF8(const char* utf8_buffer, int byte_length);

  // Calling this may invalidate pointers to underlying data.
  UnicodeText& AppendUTF8(const char* utf8, int len);
  UnicodeText& push_back(char32 ch);
  void clear();

  std::string ToUTF8String() const;
  std::string UTF8Substring(int begin_codepoint, int end_codepoint) const;
  static std::string UTF8Substring(const const_iterator& it_begin,
                                   const const_iterator& it_end);
  // With do_copy the result owns its bytes; otherwise it aliases the input
  // buffer and must not outlive it.
  static UnicodeText Substring(const UnicodeText& text, int begin_codepoint,
                               int end_codepoint, bool do_copy = true);
  static UnicodeText Substring(const const_iterator& it_begin,
                               const const_iterator& it_end,
                               bool do_copy = true);

 private:
  friend class const_iterator;

  // Owned-or-aliased byte buffer; ours_ records whether data_ must be freed.
  class Repr {  // A byte-string.
   public:
    char* data_;
    int size_;
    int capacity_;
    bool ours_;  // Do we own data_?

    Repr() : data_(nullptr), size_(0), capacity_(0), ours_(true) {}
    Repr& operator=(Repr&& src);
    ~Repr() {
      if (ours_) delete[] data_;
    }

    void clear();
    void reserve(int capacity);
    void resize(int size);

    void append(const char* bytes, int byte_length);
    void Copy(const char* data, int size);
    void PointTo(const char* data, int size);

   private:
    // Copying is disallowed; use move-assignment or explicit Copy().
    Repr& operator=(const Repr&);
    Repr(const Repr& other);
  };

  Repr repr_;
};
+
+typedef std::pair<UnicodeText::const_iterator, UnicodeText::const_iterator>
+ UnicodeTextRange;
+
+// NOTE: The following are needed to avoid implicit conversion from char* to
+// std::string, or from ::string to std::string, because if this happens it
+// often results in invalid memory access to a temporary object created during
+// such conversion (if do_copy == false).
+// NOTE: These methods don't check if the input string is UTF8 well formed, for
+// efficiency reasons. Use UnicodeText::is_valid() when explicitly needed.
+UnicodeText UTF8ToUnicodeText(const char* utf8_buf, int len,
+ bool do_copy = true);
+UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy = true);
+UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy = true);
+UnicodeText UTF8ToUnicodeText(StringPiece str, bool do_copy = true);
+
// Streams the text's raw UTF-8 bytes into the logging stream.
inline logging::LoggingStringStream& operator<<(
    logging::LoggingStringStream& stream, const UnicodeText& message) {
  stream.message.append(message.data(), message.size_bytes());
  return stream;
}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNICODETEXT_H_
diff --git a/native/utils/utf8/unicodetext_test.cc b/native/utils/utf8/unicodetext_test.cc
new file mode 100644
index 0000000..4e8883b
--- /dev/null
+++ b/native/utils/utf8/unicodetext_test.cc
@@ -0,0 +1,228 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/utf8/unicodetext.h"
+
+#include "utils/strings/stringpiece.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
// Fixture providing an empty text and a five-codepoint text mixing 1-byte
// (0x34), 2-byte (0x1C0), 3-byte (0x4E8C, 0xD7DB) and 4-byte (0x1D11E)
// UTF-8 encodings.
class UnicodeTextTest : public testing::Test {
 protected:
  UnicodeTextTest() : empty_text_() {
    text_.push_back(0x1C0);
    text_.push_back(0x4E8C);
    text_.push_back(0xD7DB);
    text_.push_back(0x34);
    text_.push_back(0x1D11E);
  }

  UnicodeText empty_text_;
  UnicodeText text_;
};
+
// Construction from another UnicodeText: copying duplicates the bytes,
// aliasing shares them; both render the same UTF-8.
TEST(UnicodeTextTest, ConstructionFromUnicodeText) {
  UnicodeText text = UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false);
  EXPECT_EQ(UnicodeText(text).ToUTF8String(), "1234😋hello");
  EXPECT_EQ(UnicodeText(text, /*do_copy=*/false).ToUTF8String(), "1234😋hello");
}

// Tests for our modifications of UnicodeText.
TEST(UnicodeTextTest, Custom) {
  UnicodeText text = UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false);
  EXPECT_EQ(text.ToUTF8String(), "1234😋hello");
  EXPECT_EQ(text.size_codepoints(), 10);
  EXPECT_EQ(text.size_bytes(), 13);

  auto it_begin = text.begin();
  std::advance(it_begin, 4);
  auto it_end = text.begin();
  std::advance(it_end, 6);
  EXPECT_EQ(text.UTF8Substring(it_begin, it_end), "😋h");
}

// Same checks as Custom, but constructed through the StringPiece overload.
TEST(UnicodeTextTest, StringPieceView) {
  std::string raw_text = "1234😋hello";
  UnicodeText text =
      UTF8ToUnicodeText(StringPiece(raw_text), /*do_copy=*/false);
  EXPECT_EQ(text.ToUTF8String(), "1234😋hello");
  EXPECT_EQ(text.size_codepoints(), 10);
  EXPECT_EQ(text.size_bytes(), 13);

  auto it_begin = text.begin();
  std::advance(it_begin, 4);
  auto it_end = text.begin();
  std::advance(it_end, 6);
  EXPECT_EQ(text.UTF8Substring(it_begin, it_end), "😋h");
}

// Substring by iterator pair and by codepoint index, both copying and
// aliasing, all yield the same two-codepoint result.
TEST(UnicodeTextTest, Substring) {
  UnicodeText text = UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false);

  EXPECT_EQ(
      UnicodeText::Substring(std::next(text.begin(), 4),
                             std::next(text.begin(), 6), /*do_copy=*/true),
      UTF8ToUnicodeText("😋h"));
  EXPECT_EQ(
      UnicodeText::Substring(std::next(text.begin(), 4),
                             std::next(text.begin(), 6), /*do_copy=*/false),
      UTF8ToUnicodeText("😋h"));
  EXPECT_EQ(UnicodeText::Substring(text, 4, 6, /*do_copy=*/true),
            UTF8ToUnicodeText("😋h"));
  EXPECT_EQ(UnicodeText::Substring(text, 4, 6, /*do_copy=*/false),
            UTF8ToUnicodeText("😋h"));
}

// PointToUTF8 aliases the source buffer (same pointer); copy-constructing
// from an alias duplicates the bytes (different pointer).
TEST(UnicodeTextTest, Ownership) {
  const std::string src = "\u304A\u00B0\u106B";

  UnicodeText alias;
  alias.PointToUTF8(src.data(), src.size());
  EXPECT_EQ(alias.data(), src.data());
  UnicodeText::const_iterator it = alias.begin();
  EXPECT_EQ(*it++, 0x304A);
  EXPECT_EQ(*it++, 0x00B0);
  EXPECT_EQ(*it++, 0x106B);
  EXPECT_EQ(it, alias.end());

  UnicodeText t = alias;  // Copy initialization copies the data.
  EXPECT_NE(t.data(), alias.data());
}

// is_valid() accepts well-formed UTF-8 and rejects truncated sequences,
// excess trail bytes, and missing trail bytes.
TEST(UnicodeTextTest, Validation) {
  EXPECT_TRUE(UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false).is_valid());
  EXPECT_TRUE(
      UTF8ToUnicodeText("\u304A\u00B0\u106B", /*do_copy=*/false).is_valid());
  EXPECT_TRUE(
      UTF8ToUnicodeText("this is a test😋😋😋", /*do_copy=*/false).is_valid());
  EXPECT_TRUE(
      UTF8ToUnicodeText("\xf0\x9f\x98\x8b", /*do_copy=*/false).is_valid());
  // Too short (string is too short).
  EXPECT_FALSE(UTF8ToUnicodeText("\xf0\x9f", /*do_copy=*/false).is_valid());
  // Too long (too many trailing bytes).
  EXPECT_FALSE(
      UTF8ToUnicodeText("\xf0\x9f\x98\x8b\x8b", /*do_copy=*/false).is_valid());
  // Too short (too few trailing bytes).
  EXPECT_FALSE(
      UTF8ToUnicodeText("\xf0\x9f\x98\x61\x61", /*do_copy=*/false).is_valid());
  // Invalid with context.
  EXPECT_FALSE(
      UTF8ToUnicodeText("hello \xf0\x9f\x98\x61\x61 world1", /*do_copy=*/false)
          .is_valid());
}
+
class IteratorTest : public UnicodeTextTest {};

// Forward iteration visits the five codepoints in order and ends at end().
TEST_F(IteratorTest, Iterates) {
  UnicodeText::const_iterator iter = text_.begin();
  EXPECT_EQ(0x1C0, *iter);
  EXPECT_EQ(&iter, &++iter);  // operator++ returns *this.
  EXPECT_EQ(0x4E8C, *iter++);
  EXPECT_EQ(0xD7DB, *iter);
  // Make sure you can dereference more than once.
  EXPECT_EQ(0xD7DB, *iter);
  EXPECT_EQ(0x34, *++iter);
  EXPECT_EQ(0x1D11E, *++iter);
  ASSERT_TRUE(iter != text_.end());
  iter++;
  EXPECT_TRUE(iter == text_.end());
}

TEST_F(IteratorTest, MultiPass) {
  // Also tests Default Constructible and Assignable.
  UnicodeText::const_iterator i1, i2;
  i1 = text_.begin();
  i2 = i1;
  EXPECT_EQ(0x4E8C, *++i1);
  EXPECT_TRUE(i1 != i2);
  EXPECT_EQ(0x1C0, *i2);
  ++i2;
  EXPECT_TRUE(i1 == i2);
  EXPECT_EQ(0x4E8C, *i2);
}

// Backward iteration from end() reaches begin() through the same codepoints.
TEST_F(IteratorTest, ReverseIterates) {
  UnicodeText::const_iterator iter = text_.end();
  EXPECT_TRUE(iter == text_.end());
  iter--;
  ASSERT_TRUE(iter != text_.end());
  EXPECT_EQ(0x1D11E, *iter--);
  EXPECT_EQ(0x34, *iter);
  EXPECT_EQ(0xD7DB, *--iter);
  // Make sure you can dereference more than once.
  EXPECT_EQ(0xD7DB, *iter);
  --iter;
  EXPECT_EQ(0x4E8C, *iter--);
  EXPECT_EQ(0x1C0, *iter);
  EXPECT_TRUE(iter == text_.begin());
}

// Relational operators order iterators by buffer position.
TEST_F(IteratorTest, Comparable) {
  UnicodeText::const_iterator i1, i2;
  i1 = text_.begin();
  i2 = i1;
  ++i2;

  EXPECT_TRUE(i1 < i2);
  EXPECT_TRUE(text_.begin() <= i1);
  EXPECT_FALSE(i1 >= i2);
  EXPECT_FALSE(i1 > text_.end());
}

// std::advance works through the bidirectional-iterator interface.
TEST_F(IteratorTest, Advance) {
  UnicodeText::const_iterator iter = text_.begin();
  EXPECT_EQ(0x1C0, *iter);
  std::advance(iter, 4);
  EXPECT_EQ(0x1D11E, *iter);
  ++iter;
  EXPECT_TRUE(iter == text_.end());
}

// std::distance counts codepoints, not bytes.
TEST_F(IteratorTest, Distance) {
  UnicodeText::const_iterator iter = text_.begin();
  EXPECT_EQ(0, std::distance(text_.begin(), iter));
  EXPECT_EQ(5, std::distance(iter, text_.end()));
  ++iter;
  ++iter;
  EXPECT_EQ(2, std::distance(text_.begin(), iter));
  EXPECT_EQ(3, std::distance(iter, text_.end()));
  ++iter;
  ++iter;
  EXPECT_EQ(4, std::distance(text_.begin(), iter));
  ++iter;
  EXPECT_EQ(0, std::distance(iter, text_.end()));
}
+
class OperatorTest : public UnicodeTextTest {};

// clear() makes a non-empty text compare equal to an empty one.
TEST_F(OperatorTest, Clear) {
  UnicodeText empty_text(UTF8ToUnicodeText("", /*do_copy=*/false));
  EXPECT_FALSE(text_ == empty_text);
  text_.clear();
  EXPECT_TRUE(text_ == empty_text);
}

// empty() reflects whether any bytes are stored.
TEST_F(OperatorTest, Empty) {
  EXPECT_TRUE(empty_text_.empty());
  EXPECT_FALSE(text_.empty());
  text_.clear();
  EXPECT_TRUE(text_.empty());
}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unilib-common.cc b/native/utils/utf8/unilib-common.cc
new file mode 100644
index 0000000..de52086
--- /dev/null
+++ b/native/utils/utf8/unilib-common.cc
@@ -0,0 +1,677 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/utf8/unilib-common.h"
+
+#include <algorithm>
+
+namespace libtextclassifier3 {
+namespace {
+
+#define ARRAYSIZE(a) sizeof(a) / sizeof(*a)
+
+// Derived from http://www.unicode.org/Public/UNIDATA/UnicodeData.txt
+// grep -E "Ps" UnicodeData.txt | \
+// sed -rne "s/^([0-9A-Z]{4});.*(PAREN|BRACKET|BRAKCET|BRACE).*/0x\1, /p"
+// IMPORTANT: entries with the same offsets in kOpeningBrackets and
+// kClosingBrackets must be counterparts.
+constexpr char32 kOpeningBrackets[] = {
+ 0x0028, 0x005B, 0x007B, 0x0F3C, 0x2045, 0x207D, 0x208D, 0x2329, 0x2768,
+ 0x276A, 0x276C, 0x2770, 0x2772, 0x2774, 0x27E6, 0x27E8, 0x27EA, 0x27EC,
+ 0x27EE, 0x2983, 0x2985, 0x2987, 0x2989, 0x298B, 0x298D, 0x298F, 0x2991,
+ 0x2993, 0x2995, 0x2997, 0x29FC, 0x2E22, 0x2E24, 0x2E26, 0x2E28, 0x3008,
+ 0x300A, 0x300C, 0x300E, 0x3010, 0x3014, 0x3016, 0x3018, 0x301A, 0xFD3F,
+ 0xFE17, 0xFE35, 0xFE37, 0xFE39, 0xFE3B, 0xFE3D, 0xFE3F, 0xFE41, 0xFE43,
+ 0xFE47, 0xFE59, 0xFE5B, 0xFE5D, 0xFF08, 0xFF3B, 0xFF5B, 0xFF5F, 0xFF62};
+constexpr int kNumOpeningBrackets = ARRAYSIZE(kOpeningBrackets);
+
+// grep -E "Pe" UnicodeData.txt | \
+// sed -rne "s/^([0-9A-Z]{4});.*(PAREN|BRACKET|BRAKCET|BRACE).*/0x\1, /p"
+constexpr char32 kClosingBrackets[] = {
+ 0x0029, 0x005D, 0x007D, 0x0F3D, 0x2046, 0x207E, 0x208E, 0x232A, 0x2769,
+ 0x276B, 0x276D, 0x2771, 0x2773, 0x2775, 0x27E7, 0x27E9, 0x27EB, 0x27ED,
+ 0x27EF, 0x2984, 0x2986, 0x2988, 0x298A, 0x298C, 0x298E, 0x2990, 0x2992,
+ 0x2994, 0x2996, 0x2998, 0x29FD, 0x2E23, 0x2E25, 0x2E27, 0x2E29, 0x3009,
+ 0x300B, 0x300D, 0x300F, 0x3011, 0x3015, 0x3017, 0x3019, 0x301B, 0xFD3E,
+ 0xFE18, 0xFE36, 0xFE38, 0xFE3A, 0xFE3C, 0xFE3E, 0xFE40, 0xFE42, 0xFE44,
+ 0xFE48, 0xFE5A, 0xFE5C, 0xFE5E, 0xFF09, 0xFF3D, 0xFF5D, 0xFF60, 0xFF63};
+constexpr int kNumClosingBrackets = ARRAYSIZE(kClosingBrackets);
+
+// grep -E "WS" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
+constexpr char32 kWhitespaces[] = {
+ 0x0009, 0x000A, 0x000B, 0x000C, 0x000D, 0x0020, 0x0085, 0x00A0,
+ 0x1680, 0x2000, 0x2001, 0x2002, 0x2003, 0x2004, 0x2005, 0x2006,
+ 0x2007, 0x2008, 0x2009, 0x200A, 0x2028, 0x2029, 0x202F, 0x205F,
+ 0x21C7, 0x21C8, 0x21C9, 0x21CA, 0x21F6, 0x2B31, 0x2B84, 0x2B85,
+ 0x2B86, 0x2B87, 0x2B94, 0x3000, 0x4DCC, 0x10344, 0x10347, 0x1DA0A,
+ 0x1DA0B, 0x1DA0C, 0x1DA0D, 0x1DA0E, 0x1DA0F, 0x1DA10, 0x1F4F0, 0x1F500,
+ 0x1F501, 0x1F502, 0x1F503, 0x1F504, 0x1F5D8, 0x1F5DE};
+constexpr int kNumWhitespaces = ARRAYSIZE(kWhitespaces);
+
+// grep -E "Nd" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
+// As the name suggests, these ranges are always 10 codepoints long, so we just
+// store the end of the range.
+constexpr char32 kDecimalDigitRangesEnd[] = {
+ 0x0039, 0x0669, 0x06f9, 0x07c9, 0x096f, 0x09ef, 0x0a6f, 0x0aef,
+ 0x0b6f, 0x0bef, 0x0c6f, 0x0cef, 0x0d6f, 0x0def, 0x0e59, 0x0ed9,
+ 0x0f29, 0x1049, 0x1099, 0x17e9, 0x1819, 0x194f, 0x19d9, 0x1a89,
+ 0x1a99, 0x1b59, 0x1bb9, 0x1c49, 0x1c59, 0xa629, 0xa8d9, 0xa909,
+ 0xa9d9, 0xa9f9, 0xaa59, 0xabf9, 0xff19, 0x104a9, 0x1106f, 0x110f9,
+ 0x1113f, 0x111d9, 0x112f9, 0x11459, 0x114d9, 0x11659, 0x116c9, 0x11739,
+ 0x118e9, 0x11c59, 0x11d59, 0x16a69, 0x16b59, 0x1d7ff};
+constexpr int kNumDecimalDigitRangesEnd = ARRAYSIZE(kDecimalDigitRangesEnd);
+
+// Visual source: https://en.wikipedia.org/wiki/Latin_script_in_Unicode
+// Source https://unicode-search.net/unicode-namesearch.pl?term=letter
+// clang-format off
+// grep "LATIN " latters.txt | grep -v "TAG LATIN" | grep -v "SQUARED LATIN" | grep -v "CIRCLED LATIN" | grep -v "PARENTHESIZED LATIN" | cut -d' ' -f1 | cut -d'+' -f2 | sed -re "s/([0-9A-Z]+).*/0x\1, /" | tr -d "\n" NOLINT
+// clang-format on
+constexpr char32 kLatinLettersRangesStart[] = {0x0041, 0x0061, 0x00C0, 0x00D8,
+ 0x00F8, 0x1D00, 0x2C60, 0xAB30,
+ 0xFF21, 0xFF41};
+constexpr int kNumLatinLettersRangesStart = ARRAYSIZE(kLatinLettersRangesStart);
+constexpr char32 kLatinLettersRangesEnd[] = {0x005A, 0x007A, 0x00D6, 0x00F7,
+ 0x02A8, 0x1EFF, 0xA7B7, 0xAB64,
+ 0xFF3A, 0xFF5A};
+constexpr int kNumLatinLettersRangesEnd = ARRAYSIZE(kLatinLettersRangesEnd);
+
+// Source https://unicode-search.net/unicode-namesearch.pl?term=letter
+constexpr char32 kArabicLettersRangesStart[] = {
+ 0x0620, 0x0641, 0x066E, 0x06EE, 0x0750, 0x08A0, 0xFB50, 0xFDFA, 0xFE80};
+constexpr int kNumArabicLettersRangesStart =
+ ARRAYSIZE(kArabicLettersRangesStart);
+constexpr char32 kArabicLettersRangesEnd[] = {
+ 0x063F, 0x064A, 0x06D5, 0x06FF, 0x077F, 0x08BD, 0xFBFF, 0xFDFB, 0xFEF4};
+constexpr int kNumArabicLettersRangesEnd = ARRAYSIZE(kArabicLettersRangesEnd);
+
+// Source https://unicode-search.net/unicode-namesearch.pl?term=letter
+constexpr char32 kCyrillicLettersRangesStart[] = {0x0400, 0x1C80, 0x2DE0,
+ 0xA640, 0xA674, 0xA680};
+constexpr int kNumCyrillicLettersRangesStart =
+ ARRAYSIZE(kCyrillicLettersRangesStart);
+constexpr char32 kCyrillicLettersRangesEnd[] = {0x052F, 0x1C88, 0x2DFF,
+ 0xA66E, 0xA67B, 0xA69F};
+constexpr int kNumCyrillicLettersRangesEnd =
+ ARRAYSIZE(kCyrillicLettersRangesEnd);
+
+constexpr char32 kChineseLettersRangesStart[] = {
+ 0x4E00, 0xF900, 0x2F800, 0xFE30, 0x3400,
+ 0x20000, 0x2A700, 0x2B740, 0x2B820, 0x2CEB0};
+constexpr int kNumChineseLettersRangesStart =
+ ARRAYSIZE(kChineseLettersRangesStart);
+constexpr char32 kChineseLettersRangesEnd[] = {
+ 0x9FFF, 0xFAFF, 0x2FA1F, 0xFE4F, 0x4DBF,
+ 0x2A6DF, 0x2B73F, 0x2B81F, 0x2CEAF, 0x2EBEF};
+constexpr int kNumChineseLettersRangesEnd = ARRAYSIZE(kChineseLettersRangesEnd);
+
+// Source https://unicode-search.net/unicode-namesearch.pl?term=letter
+// Hiragana and Katakana
+constexpr char32 kJapaneseLettersRangesStart[] = {0x3041, 0x30A1, 0x31F0,
+ 0xFF66};
+constexpr int kNumJapaneseLettersRangesStart =
+ ARRAYSIZE(kJapaneseLettersRangesStart);
+constexpr char32 kJapaneseLettersRangesEnd[] = {0x3096, 0x30FA, 0x31FF, 0xFF9D};
+constexpr int kNumJapaneseLettersRangesEnd =
+ ARRAYSIZE(kJapaneseLettersRangesEnd);
+
+// Source https://unicode-search.net/unicode-namesearch.pl?term=letter
+// Hangul
+constexpr char32 kKoreanLettersRangesStart[] = {0x3131, 0xFFA1};
+constexpr int kNumKoreanLettersRangesStart =
+ ARRAYSIZE(kKoreanLettersRangesStart);
+constexpr char32 kKoreanLettersRangesEnd[] = {0x318E, 0xFFDC};
+constexpr int kNumKoreanLettersRangesEnd = ARRAYSIZE(kKoreanLettersRangesEnd);
+
+// Source https://unicode-search.net/unicode-namesearch.pl?term=letter
+constexpr char32 kThaiLettersRangesStart[] = {0x0E01};
+constexpr int kNumThaiLettersRangesStart = ARRAYSIZE(kThaiLettersRangesStart);
+constexpr char32 kThaiLettersRangesEnd[] = {0x0E2E};
+constexpr int kNumThaiLettersRangesEnd = ARRAYSIZE(kThaiLettersRangesEnd);
+
+// grep -E ";P.;" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
+constexpr char32 kPunctuationRangesStart[] = {
+ 0x0021, 0x0025, 0x002c, 0x003a, 0x003f, 0x005b, 0x005f, 0x007b,
+ 0x007d, 0x00a1, 0x00a7, 0x00ab, 0x00b6, 0x00bb, 0x00bf, 0x037e,
+ 0x0387, 0x055a, 0x0589, 0x05be, 0x05c0, 0x05c3, 0x05c6, 0x05f3,
+ 0x0609, 0x060c, 0x061b, 0x061e, 0x066a, 0x06d4, 0x0700, 0x07f7,
+ 0x0830, 0x085e, 0x0964, 0x0970, 0x09fd, 0x0a76, 0x0af0, 0x0c77,
+ 0x0c84, 0x0df4, 0x0e4f, 0x0e5a, 0x0f04, 0x0f14, 0x0f3a, 0x0f85,
+ 0x0fd0, 0x0fd9, 0x104a, 0x10fb, 0x1360, 0x1400, 0x166e, 0x169b,
+ 0x16eb, 0x1735, 0x17d4, 0x17d8, 0x1800, 0x1944, 0x1a1e, 0x1aa0,
+ 0x1aa8, 0x1b5a, 0x1bfc, 0x1c3b, 0x1c7e, 0x1cc0, 0x1cd3, 0x2010,
+ 0x2030, 0x2045, 0x2053, 0x207d, 0x208d, 0x2308, 0x2329, 0x2768,
+ 0x27c5, 0x27e6, 0x2983, 0x29d8, 0x29fc, 0x2cf9, 0x2cfe, 0x2d70,
+ 0x2e00, 0x2e30, 0x3001, 0x3008, 0x3014, 0x3030, 0x303d, 0x30a0,
+ 0x30fb, 0xa4fe, 0xa60d, 0xa673, 0xa67e, 0xa6f2, 0xa874, 0xa8ce,
+ 0xa8f8, 0xa8fc, 0xa92e, 0xa95f, 0xa9c1, 0xa9de, 0xaa5c, 0xaade,
+ 0xaaf0, 0xabeb, 0xfd3e, 0xfe10, 0xfe30, 0xfe54, 0xfe63, 0xfe68,
+ 0xfe6a, 0xff01, 0xff05, 0xff0c, 0xff1a, 0xff1f, 0xff3b, 0xff3f,
+ 0xff5b, 0xff5d, 0xff5f, 0x10100, 0x1039f, 0x103d0, 0x1056f, 0x10857,
+ 0x1091f, 0x1093f, 0x10a50, 0x10a7f, 0x10af0, 0x10b39, 0x10b99, 0x10f55,
+ 0x11047, 0x110bb, 0x110be, 0x11140, 0x11174, 0x111c5, 0x111cd, 0x111db,
+ 0x111dd, 0x11238, 0x112a9, 0x1144b, 0x1145b, 0x1145d, 0x114c6, 0x115c1,
+ 0x11641, 0x11660, 0x1173c, 0x1183b, 0x119e2, 0x11a3f, 0x11a9a, 0x11a9e,
+ 0x11c41, 0x11c70, 0x11ef7, 0x11fff, 0x12470, 0x16a6e, 0x16af5, 0x16b37,
+ 0x16b44, 0x16e97, 0x16fe2, 0x1bc9f, 0x1da87, 0x1e95e};
+constexpr int kNumPunctuationRangesStart = ARRAYSIZE(kPunctuationRangesStart);
+constexpr char32 kPunctuationRangesEnd[] = {
+ 0x0023, 0x002a, 0x002f, 0x003b, 0x0040, 0x005d, 0x005f, 0x007b,
+ 0x007d, 0x00a1, 0x00a7, 0x00ab, 0x00b7, 0x00bb, 0x00bf, 0x037e,
+ 0x0387, 0x055f, 0x058a, 0x05be, 0x05c0, 0x05c3, 0x05c6, 0x05f4,
+ 0x060a, 0x060d, 0x061b, 0x061f, 0x066d, 0x06d4, 0x070d, 0x07f9,
+ 0x083e, 0x085e, 0x0965, 0x0970, 0x09fd, 0x0a76, 0x0af0, 0x0c77,
+ 0x0c84, 0x0df4, 0x0e4f, 0x0e5b, 0x0f12, 0x0f14, 0x0f3d, 0x0f85,
+ 0x0fd4, 0x0fda, 0x104f, 0x10fb, 0x1368, 0x1400, 0x166e, 0x169c,
+ 0x16ed, 0x1736, 0x17d6, 0x17da, 0x180a, 0x1945, 0x1a1f, 0x1aa6,
+ 0x1aad, 0x1b60, 0x1bff, 0x1c3f, 0x1c7f, 0x1cc7, 0x1cd3, 0x2027,
+ 0x2043, 0x2051, 0x205e, 0x207e, 0x208e, 0x230b, 0x232a, 0x2775,
+ 0x27c6, 0x27ef, 0x2998, 0x29db, 0x29fd, 0x2cfc, 0x2cff, 0x2d70,
+ 0x2e2e, 0x2e4f, 0x3003, 0x3011, 0x301f, 0x3030, 0x303d, 0x30a0,
+ 0x30fb, 0xa4ff, 0xa60f, 0xa673, 0xa67e, 0xa6f7, 0xa877, 0xa8cf,
+ 0xa8fa, 0xa8fc, 0xa92f, 0xa95f, 0xa9cd, 0xa9df, 0xaa5f, 0xaadf,
+ 0xaaf1, 0xabeb, 0xfd3f, 0xfe19, 0xfe52, 0xfe61, 0xfe63, 0xfe68,
+ 0xfe6b, 0xff03, 0xff0a, 0xff0f, 0xff1b, 0xff20, 0xff3d, 0xff3f,
+ 0xff5b, 0xff5d, 0xff65, 0x10102, 0x1039f, 0x103d0, 0x1056f, 0x10857,
+ 0x1091f, 0x1093f, 0x10a58, 0x10a7f, 0x10af6, 0x10b3f, 0x10b9c, 0x10f59,
+ 0x1104d, 0x110bc, 0x110c1, 0x11143, 0x11175, 0x111c8, 0x111cd, 0x111db,
+ 0x111df, 0x1123d, 0x112a9, 0x1144f, 0x1145b, 0x1145d, 0x114c6, 0x115d7,
+ 0x11643, 0x1166c, 0x1173e, 0x1183b, 0x119e2, 0x11a46, 0x11a9c, 0x11aa2,
+ 0x11c45, 0x11c71, 0x11ef8, 0x11fff, 0x12474, 0x16a6f, 0x16af5, 0x16b3b,
+ 0x16b44, 0x16e9a, 0x16fe2, 0x1bc9f, 0x1da8b, 0x1e95f};
+constexpr int kNumPunctuationRangesEnd = ARRAYSIZE(kPunctuationRangesEnd);
+
+// grep -E "Lu" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
+// There are three common ways in which upper/lower case codepoint ranges
+// were introduced: one offs, dense ranges, and ranges that alternate between
+// lower and upper case. For the sake of keeping out binary size down, we
+// treat each independently.
+constexpr char32 kUpperSingles[] = {
+ 0x01b8, 0x01bc, 0x01c4, 0x01c7, 0x01ca, 0x01f1, 0x0376, 0x037f,
+ 0x03cf, 0x03f4, 0x03fa, 0x10c7, 0x10cd, 0x2102, 0x2107, 0x2115,
+ 0x2145, 0x2183, 0x2c72, 0x2c75, 0x2cf2, 0xa7b6};
+constexpr int kNumUpperSingles = ARRAYSIZE(kUpperSingles);
+constexpr char32 kUpperRanges1Start[] = {
+ 0x0041, 0x00c0, 0x00d8, 0x0181, 0x018a, 0x018e, 0x0193, 0x0196,
+ 0x019c, 0x019f, 0x01b2, 0x01f7, 0x023a, 0x023d, 0x0244, 0x0389,
+ 0x0392, 0x03a3, 0x03d2, 0x03fd, 0x0531, 0x10a0, 0x13a0, 0x1f08,
+ 0x1f18, 0x1f28, 0x1f38, 0x1f48, 0x1f68, 0x1fb8, 0x1fc8, 0x1fd8,
+ 0x1fe8, 0x1ff8, 0x210b, 0x2110, 0x2119, 0x212b, 0x2130, 0x213e,
+ 0x2c00, 0x2c63, 0x2c6e, 0x2c7e, 0xa7ab, 0xa7b0};
+constexpr int kNumUpperRanges1Start = ARRAYSIZE(kUpperRanges1Start);
+constexpr char32 kUpperRanges1End[] = {
+ 0x005a, 0x00d6, 0x00de, 0x0182, 0x018b, 0x0191, 0x0194, 0x0198,
+ 0x019d, 0x01a0, 0x01b3, 0x01f8, 0x023b, 0x023e, 0x0246, 0x038a,
+ 0x03a1, 0x03ab, 0x03d4, 0x042f, 0x0556, 0x10c5, 0x13f5, 0x1f0f,
+ 0x1f1d, 0x1f2f, 0x1f3f, 0x1f4d, 0x1f6f, 0x1fbb, 0x1fcb, 0x1fdb,
+ 0x1fec, 0x1ffb, 0x210d, 0x2112, 0x211d, 0x212d, 0x2133, 0x213f,
+ 0x2c2e, 0x2c64, 0x2c70, 0x2c80, 0xa7ae, 0xa7b4};
+constexpr int kNumUpperRanges1End = ARRAYSIZE(kUpperRanges1End);
+constexpr char32 kUpperRanges2Start[] = {
+ 0x0100, 0x0139, 0x014a, 0x0179, 0x0184, 0x0187, 0x01a2, 0x01a7, 0x01ac,
+ 0x01af, 0x01b5, 0x01cd, 0x01de, 0x01f4, 0x01fa, 0x0241, 0x0248, 0x0370,
+ 0x0386, 0x038c, 0x038f, 0x03d8, 0x03f7, 0x0460, 0x048a, 0x04c1, 0x04d0,
+ 0x1e00, 0x1e9e, 0x1f59, 0x2124, 0x2c60, 0x2c67, 0x2c82, 0x2ceb, 0xa640,
+ 0xa680, 0xa722, 0xa732, 0xa779, 0xa77e, 0xa78b, 0xa790, 0xa796};
+constexpr int kNumUpperRanges2Start = ARRAYSIZE(kUpperRanges2Start);
+constexpr char32 kUpperRanges2End[] = {
+ 0x0136, 0x0147, 0x0178, 0x017d, 0x0186, 0x0189, 0x01a6, 0x01a9, 0x01ae,
+ 0x01b1, 0x01b7, 0x01db, 0x01ee, 0x01f6, 0x0232, 0x0243, 0x024e, 0x0372,
+ 0x0388, 0x038e, 0x0391, 0x03ee, 0x03f9, 0x0480, 0x04c0, 0x04cd, 0x052e,
+ 0x1e94, 0x1efe, 0x1f5f, 0x212a, 0x2c62, 0x2c6d, 0x2ce2, 0x2ced, 0xa66c,
+ 0xa69a, 0xa72e, 0xa76e, 0xa77d, 0xa786, 0xa78d, 0xa792, 0xa7aa};
+constexpr int kNumUpperRanges2End = ARRAYSIZE(kUpperRanges2End);
+
+// grep -E "Ll" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
+constexpr char32 kLowerSingles[] = {
+ 0x00b5, 0x0188, 0x0192, 0x0195, 0x019e, 0x01b0, 0x01c6, 0x01c9,
+ 0x01f0, 0x023c, 0x0242, 0x0377, 0x0390, 0x03f5, 0x03f8, 0x1fbe,
+ 0x210a, 0x2113, 0x212f, 0x2134, 0x2139, 0x214e, 0x2184, 0x2c61,
+ 0x2ce4, 0x2cf3, 0x2d27, 0x2d2d, 0xa7af, 0xa7c3, 0xa7fa, 0x1d7cb};
+constexpr int kNumLowerSingles = ARRAYSIZE(kLowerSingles);
+constexpr char32 kLowerRanges1Start[] = {
+ 0x0061, 0x00df, 0x00f8, 0x017f, 0x018c, 0x0199, 0x01b9, 0x01bd,
+ 0x0234, 0x023f, 0x0250, 0x0295, 0x037b, 0x03ac, 0x03d0, 0x03d5,
+ 0x03f0, 0x03fb, 0x0430, 0x0560, 0x10d0, 0x10fd, 0x13f8, 0x1c80,
+ 0x1d00, 0x1d6b, 0x1d79, 0x1e96, 0x1f00, 0x1f10, 0x1f20, 0x1f30,
+ 0x1f40, 0x1f50, 0x1f60, 0x1f70, 0x1f80, 0x1f90, 0x1fa0, 0x1fb0,
+ 0x1fb6, 0x1fc2, 0x1fc6, 0x1fd0, 0x1fd6, 0x1fe0, 0x1ff2, 0x1ff6,
+ 0x210e, 0x213c, 0x2146, 0x2c30, 0x2c65, 0x2c77, 0x2d00, 0xa730,
+ 0xa772, 0xa794, 0xab30, 0xab60, 0xab70, 0xfb00, 0xfb13, 0xff41,
+ 0x10428, 0x104d8, 0x10cc0, 0x118c0, 0x16e60, 0x1d41a, 0x1d44e, 0x1d456,
+ 0x1d482, 0x1d4b6, 0x1d4be, 0x1d4c5, 0x1d4ea, 0x1d51e, 0x1d552, 0x1d586,
+ 0x1d5ba, 0x1d5ee, 0x1d622, 0x1d656, 0x1d68a, 0x1d6c2, 0x1d6dc, 0x1d6fc,
+ 0x1d716, 0x1d736, 0x1d750, 0x1d770, 0x1d78a, 0x1d7aa, 0x1d7c4, 0x1e922};
+constexpr int kNumLowerRanges1Start = ARRAYSIZE(kLowerRanges1Start);
+constexpr char32 kLowerRanges1End[] = {
+ 0x007a, 0x00f6, 0x00ff, 0x0180, 0x018d, 0x019b, 0x01ba, 0x01bf,
+ 0x0239, 0x0240, 0x0293, 0x02af, 0x037d, 0x03ce, 0x03d1, 0x03d7,
+ 0x03f3, 0x03fc, 0x045f, 0x0588, 0x10fa, 0x10ff, 0x13fd, 0x1c88,
+ 0x1d2b, 0x1d77, 0x1d9a, 0x1e9d, 0x1f07, 0x1f15, 0x1f27, 0x1f37,
+ 0x1f45, 0x1f57, 0x1f67, 0x1f7d, 0x1f87, 0x1f97, 0x1fa7, 0x1fb4,
+ 0x1fb7, 0x1fc4, 0x1fc7, 0x1fd3, 0x1fd7, 0x1fe7, 0x1ff4, 0x1ff7,
+ 0x210f, 0x213d, 0x2149, 0x2c5e, 0x2c66, 0x2c7b, 0x2d25, 0xa731,
+ 0xa778, 0xa795, 0xab5a, 0xab67, 0xabbf, 0xfb06, 0xfb17, 0xff5a,
+ 0x1044f, 0x104fb, 0x10cf2, 0x118df, 0x16e7f, 0x1d433, 0x1d454, 0x1d467,
+ 0x1d49b, 0x1d4b9, 0x1d4c3, 0x1d4cf, 0x1d503, 0x1d537, 0x1d56b, 0x1d59f,
+ 0x1d5d3, 0x1d607, 0x1d63b, 0x1d66f, 0x1d6a5, 0x1d6da, 0x1d6e1, 0x1d714,
+ 0x1d71b, 0x1d74e, 0x1d755, 0x1d788, 0x1d78f, 0x1d7c2, 0x1d7c9, 0x1e943};
+constexpr int kNumLowerRanges1End = ARRAYSIZE(kLowerRanges1End);
+constexpr char32 kLowerRanges2Start[] = {
+ 0x0101, 0x0138, 0x0149, 0x017a, 0x0183, 0x01a1, 0x01a8, 0x01ab,
+ 0x01b4, 0x01cc, 0x01dd, 0x01f3, 0x01f9, 0x0247, 0x0371, 0x03d9,
+ 0x0461, 0x048b, 0x04c2, 0x04cf, 0x1e01, 0x1e9f, 0x2c68, 0x2c71,
+ 0x2c74, 0x2c81, 0x2cec, 0xa641, 0xa681, 0xa723, 0xa733, 0xa77a,
+ 0xa77f, 0xa78c, 0xa791, 0xa797, 0xa7b5, 0x1d4bb};
+constexpr int kNumLowerRanges2Start = ARRAYSIZE(kLowerRanges2Start);
+constexpr char32 kLowerRanges2End[] = {
+ 0x0137, 0x0148, 0x0177, 0x017e, 0x0185, 0x01a5, 0x01aa, 0x01ad,
+ 0x01b6, 0x01dc, 0x01ef, 0x01f5, 0x0233, 0x024f, 0x0373, 0x03ef,
+ 0x0481, 0x04bf, 0x04ce, 0x052f, 0x1e95, 0x1eff, 0x2c6c, 0x2c73,
+ 0x2c76, 0x2ce3, 0x2cee, 0xa66d, 0xa69b, 0xa72f, 0xa771, 0xa77c,
+ 0xa787, 0xa78e, 0xa793, 0xa7a9, 0xa7bf, 0x1d4bd};
+constexpr int kNumLowerRanges2End = ARRAYSIZE(kLowerRanges2End);
+
+// grep -E "Lu" UnicodeData.txt | \
+// sed -rne "s/^([0-9A-Z]+);.*;([0-9A-Z]+);$/(0x\1, 0x\2), /p"
+// We have two strategies for mapping from upper to lower case. We have single
+// character lookups that do not follow a pattern, and ranges for which there
+// is a constant codepoint shift.
+// Note that these ranges ignore anything that's not an upper case character,
+// so when applied to a non-uppercase character the result is incorrect.
+constexpr int kToLowerSingles[] = {
+ 0x0130, 0x0178, 0x0181, 0x0186, 0x018b, 0x018e, 0x018f, 0x0190, 0x0191,
+ 0x0194, 0x0196, 0x0197, 0x0198, 0x019c, 0x019d, 0x019f, 0x01a6, 0x01a9,
+ 0x01ae, 0x01b7, 0x01f6, 0x01f7, 0x0220, 0x023a, 0x023d, 0x023e, 0x0243,
+ 0x0244, 0x0245, 0x037f, 0x0386, 0x038c, 0x03cf, 0x03f4, 0x03f9, 0x04c0,
+ 0x1e9e, 0x1fec, 0x2126, 0x212a, 0x212b, 0x2132, 0x2183, 0x2c60, 0x2c62,
+ 0x2c63, 0x2c64, 0x2c6d, 0x2c6e, 0x2c6f, 0x2c70, 0xa77d, 0xa78d, 0xa7aa,
+ 0xa7ab, 0xa7ac, 0xa7ad, 0xa7ae, 0xa7b0, 0xa7b1, 0xa7b2, 0xa7b3};
+constexpr int kNumToLowerSingles = ARRAYSIZE(kToLowerSingles);
+constexpr int kToLowerSinglesOffsets[] = {
+ -199, -121, 210, 206, 1, 79, 202, 203, 1,
+ 207, 211, 209, 1, 211, 213, 214, 218, 218,
+ 218, 219, -97, -56, -130, 10795, -163, 10792, -195,
+ 69, 71, 116, 38, 64, 8, -60, -7, 15,
+ -7615, -7, -7517, -8383, -8262, 28, 1, 1, -10743,
+ -3814, -10727, -10780, -10749, -10783, -10782, -35332, -42280, -42308,
+ -42319, -42315, -42305, -42308, -42258, -42282, -42261, 928};
+constexpr int kNumToLowerSinglesOffsets = ARRAYSIZE(kToLowerSinglesOffsets);
+constexpr int kToUpperSingles[] = {
+ 0x00b5, 0x00ff, 0x0131, 0x017f, 0x0180, 0x0195, 0x0199, 0x019a, 0x019e,
+ 0x01bf, 0x01dd, 0x01f3, 0x0250, 0x0251, 0x0252, 0x0253, 0x0254, 0x0259,
+ 0x025b, 0x025c, 0x0260, 0x0261, 0x0263, 0x0265, 0x0266, 0x0268, 0x0269,
+ 0x026a, 0x026b, 0x026c, 0x026f, 0x0271, 0x0272, 0x0275, 0x027d, 0x0280,
+ 0x0282, 0x0283, 0x0287, 0x0288, 0x0289, 0x028c, 0x0292, 0x029d, 0x029e,
+ 0x03ac, 0x03c2, 0x03cc, 0x03d0, 0x03d1, 0x03d5, 0x03d6, 0x03d7, 0x03f0,
+ 0x03f1, 0x03f2, 0x03f3, 0x03f5, 0x04cf, 0x1c80, 0x1c81, 0x1c82, 0x1c85,
+ 0x1c86, 0x1c87, 0x1c88, 0x1d79, 0x1d7d, 0x1d8e, 0x1e9b, 0x1fb3, 0x1fbe,
+ 0x1fc3, 0x1fe5, 0x1ff3, 0x214e, 0x2184, 0x2c61, 0x2c65, 0x2c66, 0xa794,
+ 0xab53};
+constexpr int kNumToUpperSingles = ARRAYSIZE(kToUpperSingles);
+constexpr int kToUpperSinglesOffsets[] = {
+ 743, 121, -232, -300, 195, 97, -1, 163, 130, 56,
+ -79, -2, 10783, 10780, 10782, -210, -206, -202, -203, 42319,
+ -205, 42315, -207, 42280, 42308, -209, -211, 42308, 10743, 42305,
+ -211, 10749, -213, -214, 10727, -218, 42307, -218, 42282, -218,
+ -69, -71, -219, 42261, 42258, -38, -31, -64, -62, -57,
+ -47, -54, -8, -86, -80, 7, -116, -96, -15, -6254,
+ -6253, -6244, -6243, -6236, -6181, 35266, 35332, 3814, 35384, -59,
+ 9, -7205, 9, 7, 9, -28, -1, -1, -10795, -10792,
+ 48, -928};
+constexpr int kNumToUpperSinglesOffsets = ARRAYSIZE(kToUpperSinglesOffsets);
+constexpr int kToLowerRangesStart[] = {
+ 0x0041, 0x0100, 0x0189, 0x01a0, 0x01b1, 0x01b3, 0x0388, 0x038e, 0x0391,
+ 0x03d8, 0x03fd, 0x0400, 0x0410, 0x0460, 0x0531, 0x10a0, 0x13a0, 0x13f0,
+ 0x1e00, 0x1f08, 0x1fba, 0x1fc8, 0x1fd8, 0x1fda, 0x1fe8, 0x1fea, 0x1ff8,
+ 0x1ffa, 0x2c00, 0x2c67, 0x2c7e, 0x2c80, 0xff21, 0x10400, 0x10c80, 0x118a0};
+constexpr int kNumToLowerRangesStart = ARRAYSIZE(kToLowerRangesStart);
+constexpr int kToLowerRangesEnd[] = {
+ 0x00de, 0x0187, 0x019f, 0x01af, 0x01b2, 0x0386, 0x038c, 0x038f, 0x03cf,
+ 0x03fa, 0x03ff, 0x040f, 0x042f, 0x052e, 0x0556, 0x10cd, 0x13ef, 0x13f5,
+ 0x1efe, 0x1fb9, 0x1fbb, 0x1fcb, 0x1fd9, 0x1fdb, 0x1fe9, 0x1fec, 0x1ff9,
+ 0x2183, 0x2c64, 0x2c75, 0x2c7f, 0xa7b6, 0xff3a, 0x104d3, 0x10cb2, 0x118bf};
+constexpr int kNumToLowerRangesEnd = ARRAYSIZE(kToLowerRangesEnd);
+constexpr int kToLowerRangesOffsets[] = {
+ 32, 1, 205, 1, 217, 1, 37, 63, 32, 1, -130, 80,
+ 32, 1, 48, 7264, 38864, 8, 1, -8, -74, -86, -8, -100,
+ -8, -112, -128, -126, 48, 1, -10815, 1, 32, 40, 64, 32};
+constexpr int kNumToLowerRangesOffsets = ARRAYSIZE(kToLowerRangesOffsets);
+constexpr int kToUpperRangesStart[] = {
+ 0x0061, 0x0101, 0x01c6, 0x01ce, 0x023f, 0x0242, 0x0256, 0x028a,
+ 0x0371, 0x037b, 0x03ad, 0x03b1, 0x03cd, 0x03d9, 0x0430, 0x0450,
+ 0x0461, 0x0561, 0x10d0, 0x13f8, 0x1c83, 0x1e01, 0x1f00, 0x1f70,
+ 0x1f72, 0x1f76, 0x1f78, 0x1f7a, 0x1f7c, 0x1f80, 0x2c30, 0x2c68,
+ 0x2d00, 0xa641, 0xab70, 0xff41, 0x10428, 0x10cc0, 0x118c0};
+constexpr int kNumToUpperRangesStart = ARRAYSIZE(kToUpperRangesStart);
+constexpr int kToUpperRangesEnd[] = {
+ 0x00fe, 0x01bd, 0x01cc, 0x023c, 0x0240, 0x024f, 0x0257, 0x028b,
+ 0x0377, 0x037d, 0x03af, 0x03cb, 0x03ce, 0x03fb, 0x044f, 0x045f,
+ 0x052f, 0x0586, 0x10ff, 0x13fd, 0x1c84, 0x1eff, 0x1f67, 0x1f71,
+ 0x1f75, 0x1f77, 0x1f79, 0x1f7b, 0x1f7d, 0x1fe1, 0x2c5e, 0x2cf3,
+ 0x2d2d, 0xa7c3, 0xabbf, 0xff5a, 0x104fb, 0x10cf2, 0x16e7f};
+constexpr int kNumToUpperRangesEnd = ARRAYSIZE(kToUpperRangesEnd);
+constexpr int kToUpperRangesOffsets[]{
+ -32, -1, -2, -1, 10815, -1, -205, -217, -1, 130, -37, -32, -63,
+ -1, -32, -80, -1, -48, 3008, -8, -6242, -1, 8, 74, 86, 100,
+ 128, 112, 126, 8, -48, -1, -7264, -1, -38864, -32, -40, -64, -32};
+constexpr int kNumToUpperRangesOffsets = ARRAYSIZE(kToUpperRangesOffsets);
+
+// Source: https://unicode-search.net/unicode-namesearch.pl?term=PERCENT
+constexpr char32 kPercentages[] = {0x0025, 0x066A, 0xFE6A, 0xFF05};
+constexpr int kNumPercentages = ARRAYSIZE(kPercentages);
+
+// Source from https://unicode-search.net/unicode-namesearch.pl?term=SLASH
+constexpr char32 kSlashes[] = {0x002f, 0x0337, 0x0338, 0x2044, 0x2215, 0xff0f};
+constexpr int kNumSlashes = ARRAYSIZE(kSlashes);
+
+// Source: https://unicode-search.net/unicode-namesearch.pl?term=minus
+constexpr char32 kMinuses[] = {0x002d, 0x02d7, 0x2212, 0xff0d};
+constexpr int kNumMinuses = ARRAYSIZE(kMinuses);
+
+// Source: https://unicode-search.net/unicode-namesearch.pl?term=NUMBER%20SIGN
+constexpr char32 kNumberSign[] = {0x0023, 0xfe5f, 0xff03};
+constexpr int kNumNumberSign = ARRAYSIZE(kNumberSign);
+
+// Source: https://unicode-search.net/unicode-namesearch.pl?term=period
+constexpr char32 kDots[] = {0x002e, 0xfe52, 0xff0e};
+constexpr int kNumDots = ARRAYSIZE(kDots);
+
+#undef ARRAYSIZE
+
+static_assert(kNumOpeningBrackets == kNumClosingBrackets,
+ "mismatching number of opening and closing brackets");
+static_assert(kNumLowerRanges1Start == kNumLowerRanges1End,
+ "number of uppercase stride 1 range starts/ends doesn't match");
+static_assert(kNumLowerRanges2Start == kNumLowerRanges2End,
+ "number of uppercase stride 2 range starts/ends doesn't match");
+static_assert(kNumUpperRanges1Start == kNumUpperRanges1End,
+ "number of uppercase stride 1 range starts/ends doesn't match");
+static_assert(kNumUpperRanges2Start == kNumUpperRanges2End,
+ "number of uppercase stride 2 range starts/ends doesn't match");
+static_assert(kNumToLowerSingles == kNumToLowerSinglesOffsets,
+ "number of to lower singles and offsets doesn't match");
+static_assert(kNumToLowerRangesStart == kNumToLowerRangesEnd,
+ "mismatching number of range starts/ends for to lower ranges");
+static_assert(kNumToLowerRangesStart == kNumToLowerRangesOffsets,
+ "number of to lower ranges and offsets doesn't match");
+static_assert(kNumToUpperSingles == kNumToUpperSinglesOffsets,
+ "number of to upper singles and offsets doesn't match");
+static_assert(kNumToUpperRangesStart == kNumToUpperRangesEnd,
+ "mismatching number of range starts/ends for to upper ranges");
+static_assert(kNumToUpperRangesStart == kNumToUpperRangesOffsets,
+ "number of to upper ranges and offsets doesn't match");
+static_assert(kNumPunctuationRangesStart == kNumPunctuationRangesEnd,
+ "mismatch number of start/ends for punctuation ranges.");
+static_assert(kNumLatinLettersRangesStart == kNumLatinLettersRangesEnd,
+ "mismatch number of start/ends for letters ranges.");
+static_assert(kNumArabicLettersRangesStart == kNumArabicLettersRangesEnd,
+ "mismatch number of start/ends for letters ranges.");
+static_assert(kNumCyrillicLettersRangesStart == kNumCyrillicLettersRangesEnd,
+ "mismatch number of start/ends for letters ranges.");
+static_assert(kNumChineseLettersRangesStart == kNumChineseLettersRangesEnd,
+ "mismatch number of start/ends for letters ranges.");
+static_assert(kNumJapaneseLettersRangesStart == kNumJapaneseLettersRangesEnd,
+ "mismatch number of start/ends for letters ranges.");
+static_assert(kNumKoreanLettersRangesStart == kNumKoreanLettersRangesEnd,
+ "mismatch number of start/ends for letters ranges.");
+static_assert(kNumThaiLettersRangesStart == kNumThaiLettersRangesEnd,
+ "mismatch number of start/ends for letters ranges.");
+
+constexpr int kNoMatch = -1;
+
+// Returns the index of the element in the array that matched the given
+// codepoint, or kNoMatch if the element didn't exist.
+// The input array must be in sorted order.
+int GetMatchIndex(const char32* array, int array_length, char32 c) {
+ const char32* end = array + array_length;
+ const auto find_it = std::lower_bound(array, end, c);
+ if (find_it != end && *find_it == c) {
+ return find_it - array;
+ } else {
+ return kNoMatch;
+ }
+}
+
// Returns the index of the range in the array that overlapped the given
// codepoint, or kNoMatch if no such range existed.
// `arr` holds the inclusive *ends* of consecutive fixed-length ranges.
// The input array must be in sorted order.
int GetOverlappingRangeIndex(const char32* arr, int arr_length,
                             int range_length, char32 c) {
  const char32* end = arr + arr_length;
  // First stored range end that is >= c identifies the only candidate range.
  const auto find_it = std::lower_bound(arr, end, c);
  if (find_it == end) {
    // c lies beyond the last range end; no range can contain it.
    return kNoMatch;
  }
  // The end is inclusive, so we subtract one less than the range length to
  // recover the range start.
  const char32 range_end = *find_it;
  const char32 range_start = range_end - (range_length - 1);
  if (c < range_start || range_end < c) {
    return kNoMatch;
  } else {
    return find_it - arr;
  }
}
+
+// As above, but with explicit codepoint start and end indices for the range.
+// The input array must be in sorted order.
+int GetOverlappingRangeIndex(const char32* start_arr, const char32* end_arr,
+ int arr_length, int stride, char32 c) {
+ const char32* end_arr_end = end_arr + arr_length;
+ const auto find_it = std::lower_bound(end_arr, end_arr_end, c);
+ if (find_it == end_arr_end) {
+ return kNoMatch;
+ }
+ // Find the corresponding start.
+ const int range_index = find_it - end_arr;
+ const char32 range_start = start_arr[range_index];
+ const char32 range_end = *find_it;
+ if (c < range_start || range_end < c) {
+ return kNoMatch;
+ }
+ if ((c - range_start) % stride == 0) {
+ return range_index;
+ } else {
+ return kNoMatch;
+ }
+}
+
+} // anonymous namespace
+
+bool IsOpeningBracket(char32 codepoint) {
+ return GetMatchIndex(kOpeningBrackets, kNumOpeningBrackets, codepoint) >= 0;
+}
+
+bool IsClosingBracket(char32 codepoint) {
+ return GetMatchIndex(kClosingBrackets, kNumClosingBrackets, codepoint) >= 0;
+}
+
+bool IsWhitespace(char32 codepoint) {
+ return GetMatchIndex(kWhitespaces, kNumWhitespaces, codepoint) >= 0;
+}
+
+bool IsDigit(char32 codepoint) {
+ return GetOverlappingRangeIndex(kDecimalDigitRangesEnd,
+ kNumDecimalDigitRangesEnd,
+ /*range_length=*/10, codepoint) >= 0;
+}
+
+bool IsLower(char32 codepoint) {
+ if (GetMatchIndex(kLowerSingles, kNumLowerSingles, codepoint) >= 0) {
+ return true;
+ } else if (GetOverlappingRangeIndex(kLowerRanges1Start, kLowerRanges1End,
+ kNumLowerRanges1Start, /*stride=*/1,
+ codepoint) >= 0) {
+ return true;
+ } else if (GetOverlappingRangeIndex(kLowerRanges2Start, kLowerRanges2End,
+ kNumLowerRanges2Start, /*stride=*/2,
+ codepoint) >= 0) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool IsUpper(char32 codepoint) {
+ if (GetMatchIndex(kUpperSingles, kNumUpperSingles, codepoint) >= 0) {
+ return true;
+ } else if (GetOverlappingRangeIndex(kUpperRanges1Start, kUpperRanges1End,
+ kNumUpperRanges1Start, /*stride=*/1,
+ codepoint) >= 0) {
+ return true;
+ } else if (GetOverlappingRangeIndex(kUpperRanges2Start, kUpperRanges2End,
+ kNumUpperRanges2Start, /*stride=*/2,
+ codepoint) >= 0) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool IsPunctuation(char32 codepoint) {
+ return (GetOverlappingRangeIndex(
+ kPunctuationRangesStart, kPunctuationRangesEnd,
+ kNumPunctuationRangesStart, /*stride=*/1, codepoint) >= 0);
+}
+
+bool IsPercentage(char32 codepoint) {
+ return GetMatchIndex(kPercentages, kNumPercentages, codepoint) >= 0;
+}
+
+bool IsSlash(char32 codepoint) {
+ return GetMatchIndex(kSlashes, kNumSlashes, codepoint) >= 0;
+}
+
+bool IsMinus(char32 codepoint) {
+ return GetMatchIndex(kMinuses, kNumMinuses, codepoint) >= 0;
+}
+
+bool IsNumberSign(char32 codepoint) {
+ return GetMatchIndex(kNumberSign, kNumNumberSign, codepoint) >= 0;
+}
+
+bool IsDot(char32 codepoint) {
+ return GetMatchIndex(kDots, kNumDots, codepoint) >= 0;
+}
+
+bool IsLatinLetter(char32 codepoint) {
+ return (GetOverlappingRangeIndex(
+ kLatinLettersRangesStart, kLatinLettersRangesEnd,
+ kNumLatinLettersRangesStart, /*stride=*/1, codepoint) >= 0);
+}
+
+bool IsArabicLetter(char32 codepoint) {
+ return (GetOverlappingRangeIndex(
+ kArabicLettersRangesStart, kArabicLettersRangesEnd,
+ kNumArabicLettersRangesStart, /*stride=*/1, codepoint) >= 0);
+}
+
+bool IsCyrillicLetter(char32 codepoint) {
+ return (GetOverlappingRangeIndex(
+ kCyrillicLettersRangesStart, kCyrillicLettersRangesEnd,
+ kNumCyrillicLettersRangesStart, /*stride=*/1, codepoint) >= 0);
+}
+
+bool IsChineseLetter(char32 codepoint) {
+ return (GetOverlappingRangeIndex(
+ kChineseLettersRangesStart, kChineseLettersRangesEnd,
+ kNumChineseLettersRangesStart, /*stride=*/1, codepoint) >= 0);
+}
+
+bool IsJapaneseLetter(char32 codepoint) {
+ return (GetOverlappingRangeIndex(
+ kJapaneseLettersRangesStart, kJapaneseLettersRangesEnd,
+ kNumJapaneseLettersRangesStart, /*stride=*/1, codepoint) >= 0);
+}
+
+bool IsKoreanLetter(char32 codepoint) {
+ return (GetOverlappingRangeIndex(
+ kKoreanLettersRangesStart, kKoreanLettersRangesEnd,
+ kNumKoreanLettersRangesStart, /*stride=*/1, codepoint) >= 0);
+}
+
+bool IsThaiLetter(char32 codepoint) {
+ return (GetOverlappingRangeIndex(
+ kThaiLettersRangesStart, kThaiLettersRangesEnd,
+ kNumThaiLettersRangesStart, /*stride=*/1, codepoint) >= 0);
+}
+
+bool IsCJTletter(char32 codepoint) {
+ return IsJapaneseLetter(codepoint) || IsChineseLetter(codepoint) ||
+ IsThaiLetter(codepoint);
+}
+
+bool IsLetter(char32 codepoint) {
+ return IsLatinLetter(codepoint) || IsArabicLetter(codepoint) ||
+ IsCyrillicLetter(codepoint) || IsJapaneseLetter(codepoint) ||
+ IsKoreanLetter(codepoint) || IsThaiLetter(codepoint) ||
+ IsChineseLetter(codepoint);
+}
+
char32 ToLower(char32 codepoint) {
  // Make sure we still produce output even if the method is called for a
  // codepoint that's not an uppercase character.
  if (!IsUpper(codepoint)) {
    return codepoint;
  }
  // Irregular mappings first: each entry pairs with its own offset.
  const int singles_idx =
      GetMatchIndex(kToLowerSingles, kNumToLowerSingles, codepoint);
  if (singles_idx >= 0) {
    return codepoint + kToLowerSinglesOffsets[singles_idx];
  }
  // Regular mappings: a constant codepoint shift applies to the whole range.
  const int ranges_idx =
      GetOverlappingRangeIndex(kToLowerRangesStart, kToLowerRangesEnd,
                               kNumToLowerRangesStart, /*stride=*/1, codepoint);
  if (ranges_idx >= 0) {
    return codepoint + kToLowerRangesOffsets[ranges_idx];
  }
  // Uppercase but not covered by either table: leave unchanged.
  return codepoint;
}
+
char32 ToUpper(char32 codepoint) {
  // Make sure we still produce output even if the method is called for a
  // codepoint that's not a lowercase character.
  if (!IsLower(codepoint)) {
    return codepoint;
  }
  // Irregular mappings first: each entry pairs with its own offset.
  const int singles_idx =
      GetMatchIndex(kToUpperSingles, kNumToUpperSingles, codepoint);
  if (singles_idx >= 0) {
    return codepoint + kToUpperSinglesOffsets[singles_idx];
  }
  // Regular mappings: a constant codepoint shift applies to the whole range.
  const int ranges_idx =
      GetOverlappingRangeIndex(kToUpperRangesStart, kToUpperRangesEnd,
                               kNumToUpperRangesStart, /*stride=*/1, codepoint);
  if (ranges_idx >= 0) {
    return codepoint + kToUpperRangesOffsets[ranges_idx];
  }
  // Lowercase but not covered by either table: leave unchanged.
  return codepoint;
}
+
+char32 GetPairedBracket(char32 codepoint) {
+ const int open_offset =
+ GetMatchIndex(kOpeningBrackets, kNumOpeningBrackets, codepoint);
+ if (open_offset >= 0) {
+ return kClosingBrackets[open_offset];
+ }
+ const int close_offset =
+ GetMatchIndex(kClosingBrackets, kNumClosingBrackets, codepoint);
+ if (close_offset >= 0) {
+ return kOpeningBrackets[close_offset];
+ }
+ return codepoint;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unilib-common.h b/native/utils/utf8/unilib-common.h
new file mode 100644
index 0000000..4f03de7
--- /dev/null
+++ b/native/utils/utf8/unilib-common.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_COMMON_H_
+#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_COMMON_H_
+
+#include "utils/base/integral_types.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+bool IsOpeningBracket(char32 codepoint);
+bool IsClosingBracket(char32 codepoint);
+bool IsWhitespace(char32 codepoint);
+bool IsDigit(char32 codepoint);
+bool IsLower(char32 codepoint);
+bool IsUpper(char32 codepoint);
+bool IsPunctuation(char32 codepoint);
+bool IsPercentage(char32 codepoint);
+bool IsSlash(char32 codepoint);
+bool IsMinus(char32 codepoint);
+bool IsNumberSign(char32 codepoint);
+bool IsDot(char32 codepoint);
+
+bool IsLatinLetter(char32 codepoint);
+bool IsArabicLetter(char32 codepoint);
+bool IsCyrillicLetter(char32 codepoint);
+bool IsChineseLetter(char32 codepoint);
+bool IsJapaneseLetter(char32 codepoint);
+bool IsKoreanLetter(char32 codepoint);
+bool IsThaiLetter(char32 codepoint);
+bool IsLetter(char32 codepoint);
+bool IsCJTletter(char32 codepoint);
+
+char32 ToLower(char32 codepoint);
+char32 ToUpper(char32 codepoint);
+char32 GetPairedBracket(char32 codepoint);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_COMMON_H_
diff --git a/native/utils/utf8/unilib-javaicu.cc b/native/utils/utf8/unilib-javaicu.cc
new file mode 100644
index 0000000..de6b5ed
--- /dev/null
+++ b/native/utils/utf8/unilib-javaicu.cc
@@ -0,0 +1,540 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/utf8/unilib-javaicu.h"
+
+#include <math.h>
+
+#include <cassert>
+#include <cctype>
+#include <map>
+
+#include "utils/base/logging.h"
+#include "utils/base/statusor.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/string_utils.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib-common.h"
+
+namespace libtextclassifier3 {
+
+UniLibBase::UniLibBase() {
+ TC3_LOG(FATAL) << "Java ICU UniLib must be initialized with a JniCache.";
+}
+
+UniLibBase::UniLibBase(const std::shared_ptr<JniCache>& jni_cache)
+ : jni_cache_(jni_cache) {}
+
+bool UniLibBase::IsOpeningBracket(char32 codepoint) const {
+ return libtextclassifier3::IsOpeningBracket(codepoint);
+}
+
+bool UniLibBase::IsClosingBracket(char32 codepoint) const {
+ return libtextclassifier3::IsClosingBracket(codepoint);
+}
+
+bool UniLibBase::IsWhitespace(char32 codepoint) const {
+ return libtextclassifier3::IsWhitespace(codepoint);
+}
+
+bool UniLibBase::IsDigit(char32 codepoint) const {
+ return libtextclassifier3::IsDigit(codepoint);
+}
+
+bool UniLibBase::IsLower(char32 codepoint) const {
+ return libtextclassifier3::IsLower(codepoint);
+}
+
+bool UniLibBase::IsUpper(char32 codepoint) const {
+ return libtextclassifier3::IsUpper(codepoint);
+}
+
+bool UniLibBase::IsPunctuation(char32 codepoint) const {
+ return libtextclassifier3::IsPunctuation(codepoint);
+}
+
+char32 UniLibBase::ToLower(char32 codepoint) const {
+ return libtextclassifier3::ToLower(codepoint);
+}
+
+char32 UniLibBase::ToUpper(char32 codepoint) const {
+ return libtextclassifier3::ToUpper(codepoint);
+}
+
+char32 UniLibBase::GetPairedBracket(char32 codepoint) const {
+ return libtextclassifier3::GetPairedBracket(codepoint);
+}
+
+// -----------------------------------------------------------------------------
+// Implementations that call out to JVM. Behold the beauty.
+// -----------------------------------------------------------------------------
+
+bool UniLibBase::ParseInt32(const UnicodeText& text, int32* result) const {
+ return ParseInt(text, result);
+}
+
+bool UniLibBase::ParseInt64(const UnicodeText& text, int64* result) const {
+ return ParseInt(text, result);
+}
+
+bool UniLibBase::ParseDouble(const UnicodeText& text, double* result) const {
+ if (!jni_cache_) {
+ return false;
+ }
+
+ JNIEnv* env = jni_cache_->GetEnv();
+ auto it_dot = text.begin();
+ for (; it_dot != text.end() && !IsDot(*it_dot); it_dot++) {
+ }
+
+ int64 integer_part;
+ if (!ParseInt(UnicodeText::Substring(text.begin(), it_dot, /*do_copy=*/false),
+ &integer_part)) {
+ return false;
+ }
+
+ int64 fractional_part = 0;
+ if (it_dot != text.end()) {
+ std::string fractional_part_str =
+ UnicodeText::UTF8Substring(++it_dot, text.end());
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ const ScopedLocalRef<jstring> fractional_text_java,
+ jni_cache_->ConvertToJavaString(fractional_part_str));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ fractional_part,
+ JniHelper::CallStaticIntMethod<int64>(
+ env, jni_cache_->integer_class.get(), jni_cache_->integer_parse_int,
+ fractional_text_java.get()));
+ }
+
+ double factional_part_double = fractional_part;
+ while (factional_part_double >= 1) {
+ factional_part_double /= 10;
+ }
+ *result = integer_part + factional_part_double;
+
+ return true;
+}
+
+std::unique_ptr<UniLibBase::RegexPattern> UniLibBase::CreateRegexPattern(
+ const UnicodeText& regex) const {
+ return std::unique_ptr<UniLibBase::RegexPattern>(
+ new UniLibBase::RegexPattern(jni_cache_.get(), regex, /*lazy=*/false));
+}
+
+std::unique_ptr<UniLibBase::RegexPattern> UniLibBase::CreateLazyRegexPattern(
+ const UnicodeText& regex) const {
+ return std::unique_ptr<UniLibBase::RegexPattern>(
+ new UniLibBase::RegexPattern(jni_cache_.get(), regex, /*lazy=*/true));
+}
+
+UniLibBase::RegexPattern::RegexPattern(const JniCache* jni_cache,
+ const UnicodeText& pattern, bool lazy)
+ : jni_cache_(jni_cache),
+ pattern_(nullptr, jni_cache ? jni_cache->jvm : nullptr),
+ initialized_(false),
+ initialization_failure_(false),
+ pattern_text_(pattern) {
+ if (!lazy) {
+ LockedInitializeIfNotAlready();
+ }
+}
+
+Status UniLibBase::RegexPattern::LockedInitializeIfNotAlready() const {
+ std::lock_guard<std::mutex> guard(mutex_);
+ if (initialized_ || initialization_failure_) {
+ return Status::OK;
+ }
+
+ if (jni_cache_) {
+ JNIEnv* jenv = jni_cache_->GetEnv();
+ initialization_failure_ = true;
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> regex_java,
+ jni_cache_->ConvertToJavaString(pattern_text_));
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jobject> pattern,
+ JniHelper::CallStaticObjectMethod(
+ jenv, jni_cache_->pattern_class.get(),
+ jni_cache_->pattern_compile, regex_java.get()));
+ pattern_ = MakeGlobalRef(pattern.get(), jenv, jni_cache_->jvm);
+ if (pattern_ == nullptr) {
+ return Status::UNKNOWN;
+ }
+
+ initialization_failure_ = false;
+ initialized_ = true;
+ pattern_text_.clear(); // We don't need this anymore.
+ }
+ return Status::OK;
+}
+
+constexpr int UniLibBase::RegexMatcher::kError;
+constexpr int UniLibBase::RegexMatcher::kNoError;
+
+std::unique_ptr<UniLibBase::RegexMatcher> UniLibBase::RegexPattern::Matcher(
+ const UnicodeText& context) const {
+ LockedInitializeIfNotAlready(); // Possibly lazy initialization.
+ if (initialization_failure_) {
+ return nullptr;
+ }
+
+ if (jni_cache_) {
+ JNIEnv* env = jni_cache_->GetEnv();
+ const StatusOr<ScopedLocalRef<jstring>> status_or_context_java =
+ jni_cache_->ConvertToJavaString(context);
+ if (!status_or_context_java.ok() || !status_or_context_java.ValueOrDie()) {
+ return nullptr;
+ }
+ const StatusOr<ScopedLocalRef<jobject>> status_or_matcher =
+ JniHelper::CallObjectMethod(env, pattern_.get(),
+ jni_cache_->pattern_matcher,
+ status_or_context_java.ValueOrDie().get());
+ if (jni_cache_->ExceptionCheckAndClear() || !status_or_matcher.ok() ||
+ !status_or_matcher.ValueOrDie()) {
+ return nullptr;
+ }
+ return std::unique_ptr<UniLibBase::RegexMatcher>(new RegexMatcher(
+ jni_cache_,
+ MakeGlobalRef(status_or_matcher.ValueOrDie().get(), env,
+ jni_cache_->jvm),
+ MakeGlobalRef(status_or_context_java.ValueOrDie().get(), env,
+ jni_cache_->jvm)));
+ } else {
+ // NOTE: A valid object needs to be created here to pass the interface
+ // tests.
+ return std::unique_ptr<UniLibBase::RegexMatcher>(
+ new RegexMatcher(jni_cache_, {}, {}));
+ }
+}
+
+UniLibBase::RegexMatcher::RegexMatcher(const JniCache* jni_cache,
+ ScopedGlobalRef<jobject> matcher,
+ ScopedGlobalRef<jstring> text)
+ : jni_cache_(jni_cache),
+ matcher_(std::move(matcher)),
+ text_(std::move(text)) {}
+
+bool UniLibBase::RegexMatcher::Matches(int* status) const {
+ if (jni_cache_) {
+ *status = kNoError;
+ const bool result = jni_cache_->GetEnv()->CallBooleanMethod(
+ matcher_.get(), jni_cache_->matcher_matches);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ *status = kError;
+ return false;
+ }
+ return result;
+ } else {
+ *status = kError;
+ return false;
+ }
+}
+
+bool UniLibBase::RegexMatcher::ApproximatelyMatches(int* status) {
+ *status = kNoError;
+
+ jni_cache_->GetEnv()->CallObjectMethod(matcher_.get(),
+ jni_cache_->matcher_reset);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ *status = kError;
+ return kError;
+ }
+
+ if (!Find(status) || *status != kNoError) {
+ return false;
+ }
+
+ const int found_start = jni_cache_->GetEnv()->CallIntMethod(
+ matcher_.get(), jni_cache_->matcher_start_idx, 0);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ *status = kError;
+ return kError;
+ }
+
+ const int found_end = jni_cache_->GetEnv()->CallIntMethod(
+ matcher_.get(), jni_cache_->matcher_end_idx, 0);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ *status = kError;
+ return kError;
+ }
+
+ int context_length_bmp = jni_cache_->GetEnv()->CallIntMethod(
+ text_.get(), jni_cache_->string_length);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ *status = kError;
+ return false;
+ }
+
+ if (found_start != 0 || found_end != context_length_bmp) {
+ return false;
+ }
+
+ return true;
+}
+
+bool UniLibBase::RegexMatcher::UpdateLastFindOffset() const {
+ if (!last_find_offset_dirty_) {
+ return true;
+ }
+
+ const int find_offset = jni_cache_->GetEnv()->CallIntMethod(
+ matcher_.get(), jni_cache_->matcher_start_idx, 0);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ return false;
+ }
+
+ const int codepoint_count = jni_cache_->GetEnv()->CallIntMethod(
+ text_.get(), jni_cache_->string_code_point_count, last_find_offset_,
+ find_offset);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ return false;
+ }
+
+ last_find_offset_codepoints_ += codepoint_count;
+ last_find_offset_ = find_offset;
+ last_find_offset_dirty_ = false;
+
+ return true;
+}
+
+bool UniLibBase::RegexMatcher::Find(int* status) {
+ if (jni_cache_) {
+ const bool result = jni_cache_->GetEnv()->CallBooleanMethod(
+ matcher_.get(), jni_cache_->matcher_find);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ *status = kError;
+ return false;
+ }
+
+ last_find_offset_dirty_ = true;
+ *status = kNoError;
+ return result;
+ } else {
+ *status = kError;
+ return false;
+ }
+}
+
+int UniLibBase::RegexMatcher::Start(int* status) const {
+ return Start(/*group_idx=*/0, status);
+}
+
+int UniLibBase::RegexMatcher::Start(int group_idx, int* status) const {
+ if (jni_cache_) {
+ *status = kNoError;
+
+ if (!UpdateLastFindOffset()) {
+ *status = kError;
+ return kError;
+ }
+
+ const int java_index = jni_cache_->GetEnv()->CallIntMethod(
+ matcher_.get(), jni_cache_->matcher_start_idx, group_idx);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ *status = kError;
+ return kError;
+ }
+
+ // If the group didn't participate in the match the index is -1.
+ if (java_index == -1) {
+ return -1;
+ }
+
+ const int unicode_index = jni_cache_->GetEnv()->CallIntMethod(
+ text_.get(), jni_cache_->string_code_point_count, last_find_offset_,
+ java_index);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ *status = kError;
+ return kError;
+ }
+
+ return unicode_index + last_find_offset_codepoints_;
+ } else {
+ *status = kError;
+ return kError;
+ }
+}
+
+int UniLibBase::RegexMatcher::End(int* status) const {
+ return End(/*group_idx=*/0, status);
+}
+
+int UniLibBase::RegexMatcher::End(int group_idx, int* status) const {
+ if (jni_cache_) {
+ *status = kNoError;
+
+ if (!UpdateLastFindOffset()) {
+ *status = kError;
+ return kError;
+ }
+
+ const int java_index = jni_cache_->GetEnv()->CallIntMethod(
+ matcher_.get(), jni_cache_->matcher_end_idx, group_idx);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ *status = kError;
+ return kError;
+ }
+
+ // If the group didn't participate in the match the index is -1.
+ if (java_index == -1) {
+ return -1;
+ }
+
+ const int unicode_index = jni_cache_->GetEnv()->CallIntMethod(
+ text_.get(), jni_cache_->string_code_point_count, last_find_offset_,
+ java_index);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ *status = kError;
+ return kError;
+ }
+
+ return unicode_index + last_find_offset_codepoints_;
+ } else {
+ *status = kError;
+ return kError;
+ }
+}
+
+UnicodeText UniLibBase::RegexMatcher::Group(int* status) const {
+ if (jni_cache_) {
+ JNIEnv* jenv = jni_cache_->GetEnv();
+ StatusOr<ScopedLocalRef<jstring>> status_or_java_result =
+ JniHelper::CallObjectMethod<jstring>(jenv, matcher_.get(),
+ jni_cache_->matcher_group);
+
+ if (jni_cache_->ExceptionCheckAndClear() || !status_or_java_result.ok() ||
+ !status_or_java_result.ValueOrDie()) {
+ *status = kError;
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+
+ std::string result;
+ if (!JStringToUtf8String(jenv, status_or_java_result.ValueOrDie().get(),
+ &result)) {
+ *status = kError;
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+ *status = kNoError;
+ return UTF8ToUnicodeText(result, /*do_copy=*/true);
+ } else {
+ *status = kError;
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+}
+
+UnicodeText UniLibBase::RegexMatcher::Group(int group_idx, int* status) const {
+ if (jni_cache_) {
+ JNIEnv* jenv = jni_cache_->GetEnv();
+
+ StatusOr<ScopedLocalRef<jstring>> status_or_java_result =
+ JniHelper::CallObjectMethod<jstring>(
+ jenv, matcher_.get(), jni_cache_->matcher_group_idx, group_idx);
+ if (jni_cache_->ExceptionCheckAndClear() || !status_or_java_result.ok()) {
+ *status = kError;
+ TC3_LOG(ERROR) << "Exception occurred";
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+
+ // java_result is nullptr when the group did not participate in the match.
+ // For these cases other UniLib implementations return empty string, and
+ // the participation can be checked by checking if Start() == -1.
+ if (!status_or_java_result.ValueOrDie()) {
+ *status = kNoError;
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+
+ std::string result;
+ if (!JStringToUtf8String(jenv, status_or_java_result.ValueOrDie().get(),
+ &result)) {
+ *status = kError;
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+ *status = kNoError;
+ return UTF8ToUnicodeText(result, /*do_copy=*/true);
+ } else {
+ *status = kError;
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+}
+
+constexpr int UniLibBase::BreakIterator::kDone;
+
+UniLibBase::BreakIterator::BreakIterator(const JniCache* jni_cache,
+ const UnicodeText& text)
+ : jni_cache_(jni_cache),
+ text_(nullptr, jni_cache ? jni_cache->jvm : nullptr),
+ iterator_(nullptr, jni_cache ? jni_cache->jvm : nullptr),
+ last_break_index_(0),
+ last_unicode_index_(0) {
+ if (jni_cache_) {
+ JNIEnv* jenv = jni_cache_->GetEnv();
+ StatusOr<ScopedLocalRef<jstring>> status_or_text =
+ jni_cache_->ConvertToJavaString(text);
+ if (!status_or_text.ok()) {
+ return;
+ }
+ text_ =
+ MakeGlobalRef(status_or_text.ValueOrDie().get(), jenv, jni_cache->jvm);
+ if (!text_) {
+ return;
+ }
+
+ StatusOr<ScopedLocalRef<jobject>> status_or_iterator =
+ JniHelper::CallStaticObjectMethod(
+ jenv, jni_cache->breakiterator_class.get(),
+ jni_cache->breakiterator_getwordinstance,
+ jni_cache->locale_us.get());
+ if (!status_or_iterator.ok()) {
+ return;
+ }
+ iterator_ = MakeGlobalRef(status_or_iterator.ValueOrDie().get(), jenv,
+ jni_cache->jvm);
+ if (!iterator_) {
+ return;
+ }
+ JniHelper::CallVoidMethod(jenv, iterator_.get(),
+ jni_cache->breakiterator_settext, text_.get());
+ }
+}
+
+int UniLibBase::BreakIterator::Next() {
+ if (jni_cache_) {
+ const int break_index = jni_cache_->GetEnv()->CallIntMethod(
+ iterator_.get(), jni_cache_->breakiterator_next);
+ if (jni_cache_->ExceptionCheckAndClear() ||
+ break_index == BreakIterator::kDone) {
+ return BreakIterator::kDone;
+ }
+
+ const int token_unicode_length = jni_cache_->GetEnv()->CallIntMethod(
+ text_.get(), jni_cache_->string_code_point_count, last_break_index_,
+ break_index);
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ return BreakIterator::kDone;
+ }
+
+ last_break_index_ = break_index;
+ return last_unicode_index_ += token_unicode_length;
+ }
+ return BreakIterator::kDone;
+}
+
+std::unique_ptr<UniLibBase::BreakIterator> UniLibBase::CreateBreakIterator(
+ const UnicodeText& text) const {
+ return std::unique_ptr<UniLibBase::BreakIterator>(
+ new UniLibBase::BreakIterator(jni_cache_.get(), text));
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unilib-javaicu.h b/native/utils/utf8/unilib-javaicu.h
new file mode 100644
index 0000000..d208730
--- /dev/null
+++ b/native/utils/utf8/unilib-javaicu.h
@@ -0,0 +1,209 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// An implementation of Unilib that uses Android Java interfaces via JNI. The
+// performance critical ops have been re-implemented in C++.
+// Specifically, this class must be compatible with API level 14 (ICS).
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_JAVAICU_H_
+#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_JAVAICU_H_
+
+#include <jni.h>
+
+#include <memory>
+#include <mutex> // NOLINT
+#include <string>
+
+#include "utils/base/integral_types.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-cache.h"
+#include "utils/java/jni-helper.h"
+#include "utils/java/string_utils.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+class UniLibBase {
+ public:
+ UniLibBase();
+ explicit UniLibBase(const std::shared_ptr<JniCache>& jni_cache);
+
+ bool ParseInt32(const UnicodeText& text, int32* result) const;
+ bool ParseInt64(const UnicodeText& text, int64* result) const;
+ bool ParseDouble(const UnicodeText& text, double* result) const;
+
+ bool IsOpeningBracket(char32 codepoint) const;
+ bool IsClosingBracket(char32 codepoint) const;
+ bool IsWhitespace(char32 codepoint) const;
+ bool IsDigit(char32 codepoint) const;
+ bool IsLower(char32 codepoint) const;
+ bool IsUpper(char32 codepoint) const;
+ bool IsPunctuation(char32 codepoint) const;
+
+ char32 ToLower(char32 codepoint) const;
+ char32 ToUpper(char32 codepoint) const;
+ char32 GetPairedBracket(char32 codepoint) const;
+
+ // Forward declaration for friend.
+ class RegexPattern;
+
+ class RegexMatcher {
+ public:
+ static constexpr int kError = -1;
+ static constexpr int kNoError = 0;
+
+ // Checks whether the input text matches the pattern exactly.
+ bool Matches(int* status) const;
+
+ // Approximate Matches() implementation implemented using Find(). It uses
+ // the first Find() result and then checks that it spans the whole input.
+ // NOTE: Unlike Matches() it can result in false negatives.
+ // NOTE: Resets the matcher, so the current Find() state will be lost.
+ bool ApproximatelyMatches(int* status);
+
+ // Finds occurrences of the pattern in the input text.
+ // Can be called repeatedly to find all occurrences. A call will update
+ // internal state, so that 'Start', 'End' and 'Group' can be called to get
+ // information about the match.
+ // NOTE: Any call to ApproximatelyMatches() in between Find() calls will
+ // modify the state.
+ bool Find(int* status);
+
+ // Gets the start offset of the last match (from 'Find').
+ // Sets status to 'kError' if 'Find'
+ // was not called previously.
+ int Start(int* status) const;
+
+ // Gets the start offset of the specified group of the last match.
+ // (from 'Find').
+ // Sets status to 'kError' if an invalid group was specified or if 'Find'
+ // was not called previously.
+ int Start(int group_idx, int* status) const;
+
+ // Gets the end offset of the last match (from 'Find').
+ // Sets status to 'kError' if 'Find'
+ // was not called previously.
+ int End(int* status) const;
+
+ // Gets the end offset of the specified group of the last match.
+ // (from 'Find').
+ // Sets status to 'kError' if an invalid group was specified or if 'Find'
+ // was not called previously.
+ int End(int group_idx, int* status) const;
+
+ // Gets the text of the last match (from 'Find').
+ // Sets status to 'kError' if 'Find' was not called previously.
+ UnicodeText Group(int* status) const;
+
+ // Gets the text of the specified group of the last match (from 'Find').
+ // Sets status to 'kError' if an invalid group was specified or if 'Find'
+ // was not called previously.
+ UnicodeText Group(int group_idx, int* status) const;
+
+ // Returns the matched text (the 0th capturing group).
+ std::string Text() const {
+ ScopedStringChars text_str =
+ GetScopedStringChars(jni_cache_->GetEnv(), text_.get());
+ return text_str.get();
+ }
+
+ private:
+ friend class RegexPattern;
+ RegexMatcher(const JniCache* jni_cache, ScopedGlobalRef<jobject> matcher,
+ ScopedGlobalRef<jstring> text);
+ bool UpdateLastFindOffset() const;
+
+ const JniCache* jni_cache_;
+ ScopedGlobalRef<jobject> matcher_;
+ ScopedGlobalRef<jstring> text_;
+ mutable int last_find_offset_ = 0;
+ mutable int last_find_offset_codepoints_ = 0;
+ mutable bool last_find_offset_dirty_ = true;
+ };
+
+ class RegexPattern {
+ public:
+ std::unique_ptr<RegexMatcher> Matcher(const UnicodeText& context) const;
+
+ private:
+ friend class UniLibBase;
+ RegexPattern(const JniCache* jni_cache, const UnicodeText& pattern,
+ bool lazy);
+ Status LockedInitializeIfNotAlready() const;
+
+ const JniCache* jni_cache_;
+
+ // These members need to be mutable because of the lazy initialization.
+ // NOTE: The Matcher method first ensures (using a lock) that the
+ // initialization was attempted (by using LockedInitializeIfNotAlready) and
+ // then can access them without locking.
+ mutable std::mutex mutex_;
+ mutable ScopedGlobalRef<jobject> pattern_;
+ mutable bool initialized_;
+ mutable bool initialization_failure_;
+ mutable UnicodeText pattern_text_;
+ };
+
+ class BreakIterator {
+ public:
+ int Next();
+
+ static constexpr int kDone = -1;
+
+ private:
+ friend class UniLibBase;
+ BreakIterator(const JniCache* jni_cache, const UnicodeText& text);
+
+ const JniCache* jni_cache_;
+ ScopedGlobalRef<jstring> text_;
+ ScopedGlobalRef<jobject> iterator_;
+ int last_break_index_;
+ int last_unicode_index_;
+ };
+
+ std::unique_ptr<RegexPattern> CreateRegexPattern(
+ const UnicodeText& regex) const;
+ std::unique_ptr<RegexPattern> CreateLazyRegexPattern(
+ const UnicodeText& regex) const;
+ std::unique_ptr<BreakIterator> CreateBreakIterator(
+ const UnicodeText& text) const;
+
+ private:
+ template <class T>
+ bool ParseInt(const UnicodeText& text, T* result) const;
+
+ std::shared_ptr<JniCache> jni_cache_;
+};
+
+template <class T>
+bool UniLibBase::ParseInt(const UnicodeText& text, T* result) const {
+ if (!jni_cache_) {
+ return false;
+ }
+
+ JNIEnv* env = jni_cache_->GetEnv();
+ TC3_ASSIGN_OR_RETURN_FALSE(const ScopedLocalRef<jstring> text_java,
+ jni_cache_->ConvertToJavaString(text));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ *result, JniHelper::CallStaticIntMethod<T>(
+ env, jni_cache_->integer_class.get(),
+ jni_cache_->integer_parse_int, text_java.get()));
+ return true;
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_JAVAICU_H_
diff --git a/native/utils/utf8/unilib.h b/native/utils/utf8/unilib.h
new file mode 100644
index 0000000..d0e6164
--- /dev/null
+++ b/native/utils/utf8/unilib.h
@@ -0,0 +1,149 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_H_
+#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_H_
+
+#include "utils/base/integral_types.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib-common.h"
+
+#if defined TC3_UNILIB_ICU
+#include "utils/utf8/unilib-icu.h"
+#define INIT_UNILIB_FOR_TESTING(VAR) VAR()
+#elif defined TC3_UNILIB_JAVAICU
+#include "utils/utf8/unilib-javaicu.h"
+#define INIT_UNILIB_FOR_TESTING(VAR) VAR(nullptr)
+#elif defined TC3_UNILIB_APPLE
+#include "utils/utf8/unilib-apple.h"
+#define INIT_UNILIB_FOR_TESTING(VAR) VAR()
+#elif defined TC3_UNILIB_DUMMY
+#include "utils/utf8/unilib-dummy.h"
+#define INIT_UNILIB_FOR_TESTING(VAR) VAR()
+#else
+#error No TC3_UNILIB implementation specified.
+#endif
+
+namespace libtextclassifier3 {
+
+class UniLib : public UniLibBase {
+ public:
+ using UniLibBase::UniLibBase;
+
+ // Lowercase a unicode string.
+ UnicodeText ToLowerText(const UnicodeText& text) const {
+ UnicodeText result;
+ for (const char32 codepoint : text) {
+ result.push_back(ToLower(codepoint));
+ }
+ return result;
+ }
+
+ // Uppercase a unicode string.
+ UnicodeText ToUpperText(const UnicodeText& text) const {
+ UnicodeText result;
+ for (const char32 codepoint : text) {
+ result.push_back(UniLibBase::ToUpper(codepoint));
+ }
+ return result;
+ }
+
+ bool IsLowerText(const UnicodeText& text) const {
+ for (const char32 codepoint : text) {
+ if (!IsLower(codepoint)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ bool IsUpperText(const UnicodeText& text) const {
+ for (const char32 codepoint : text) {
+ if (!IsUpper(codepoint)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ bool IsDigits(const UnicodeText& text) const {
+ for (const char32 codepoint : text) {
+ if (!IsDigit(codepoint)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ bool IsPercentage(char32 codepoint) const {
+ return libtextclassifier3::IsPercentage(codepoint);
+ }
+
+ bool IsSlash(char32 codepoint) const {
+ return libtextclassifier3::IsSlash(codepoint);
+ }
+
+ bool IsMinus(char32 codepoint) const {
+ return libtextclassifier3::IsMinus(codepoint);
+ }
+
+ bool IsNumberSign(char32 codepoint) const {
+ return libtextclassifier3::IsNumberSign(codepoint);
+ }
+
+ bool IsDot(char32 codepoint) const {
+ return libtextclassifier3::IsDot(codepoint);
+ }
+
+ bool IsLatinLetter(char32 codepoint) const {
+ return libtextclassifier3::IsLatinLetter(codepoint);
+ }
+
+ bool IsArabicLetter(char32 codepoint) const {
+ return libtextclassifier3::IsArabicLetter(codepoint);
+ }
+
+ bool IsCyrillicLetter(char32 codepoint) const {
+ return libtextclassifier3::IsCyrillicLetter(codepoint);
+ }
+
+ bool IsChineseLetter(char32 codepoint) const {
+ return libtextclassifier3::IsChineseLetter(codepoint);
+ }
+
+ bool IsJapaneseLetter(char32 codepoint) const {
+ return libtextclassifier3::IsJapaneseLetter(codepoint);
+ }
+
+ bool IsKoreanLetter(char32 codepoint) const {
+ return libtextclassifier3::IsKoreanLetter(codepoint);
+ }
+
+ bool IsThaiLetter(char32 codepoint) const {
+ return libtextclassifier3::IsThaiLetter(codepoint);
+ }
+
+ bool IsCJTletter(char32 codepoint) const {
+ return libtextclassifier3::IsCJTletter(codepoint);
+ }
+
+ bool IsLetter(char32 codepoint) const {
+ return libtextclassifier3::IsLetter(codepoint);
+ }
+};
+
+} // namespace libtextclassifier3
+#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_H_
diff --git a/native/utils/variant.cc b/native/utils/variant.cc
new file mode 100644
index 0000000..0513440
--- /dev/null
+++ b/native/utils/variant.cc
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/variant.h"
+
+namespace libtextclassifier3 {
+
+std::string Variant::ToString() const {
+ switch (GetType()) {
+ case Variant::TYPE_BOOL_VALUE:
+ if (Value<bool>()) {
+ return "true";
+ } else {
+ return "false";
+ }
+ break;
+ case Variant::TYPE_INT_VALUE:
+ return std::to_string(Value<int>());
+ break;
+ case Variant::TYPE_INT64_VALUE:
+ return std::to_string(Value<int64>());
+ break;
+ case Variant::TYPE_FLOAT_VALUE:
+ return std::to_string(Value<float>());
+ break;
+ case Variant::TYPE_DOUBLE_VALUE:
+ return std::to_string(Value<double>());
+ break;
+ case Variant::TYPE_STRING_VALUE:
+ return ConstRefValue<std::string>();
+ break;
+ default:
+ TC3_LOG(FATAL) << "Unsupported variant type: " << GetType();
+ return "";
+ break;
+ }
+}
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Variant& value) {
+ return stream << "Variant(" << value.GetType() << ", " << value.ToString()
+ << ")";
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/variant.h b/native/utils/variant.h
new file mode 100644
index 0000000..551a822
--- /dev/null
+++ b/native/utils/variant.h
@@ -0,0 +1,290 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_VARIANT_H_
+#define LIBTEXTCLASSIFIER_UTILS_VARIANT_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Represents a type-tagged union of different basic types.
+class Variant {
+ public:
+ enum Type {
+ TYPE_EMPTY = 0,
+ TYPE_INT8_VALUE = 1,
+ TYPE_UINT8_VALUE = 2,
+ TYPE_INT_VALUE = 3,
+ TYPE_UINT_VALUE = 4,
+ TYPE_INT64_VALUE = 5,
+ TYPE_UINT64_VALUE = 6,
+ TYPE_FLOAT_VALUE = 7,
+ TYPE_DOUBLE_VALUE = 8,
+ TYPE_BOOL_VALUE = 9,
+ TYPE_STRING_VALUE = 10,
+ TYPE_STRING_VECTOR_VALUE = 11,
+ TYPE_FLOAT_VECTOR_VALUE = 12,
+ TYPE_INT_VECTOR_VALUE = 13,
+ TYPE_STRING_VARIANT_MAP_VALUE = 14,
+ };
+
+ Variant() : type_(TYPE_EMPTY) {}
+ explicit Variant(const int8_t value)
+ : type_(TYPE_INT8_VALUE), int8_value_(value) {}
+ explicit Variant(const uint8_t value)
+ : type_(TYPE_UINT8_VALUE), uint8_value_(value) {}
+ explicit Variant(const int value)
+ : type_(TYPE_INT_VALUE), int_value_(value) {}
+ explicit Variant(const uint value)
+ : type_(TYPE_UINT_VALUE), uint_value_(value) {}
+ explicit Variant(const int64 value)
+ : type_(TYPE_INT64_VALUE), long_value_(value) {}
+ explicit Variant(const uint64 value)
+ : type_(TYPE_UINT64_VALUE), ulong_value_(value) {}
+ explicit Variant(const float value)
+ : type_(TYPE_FLOAT_VALUE), float_value_(value) {}
+ explicit Variant(const double value)
+ : type_(TYPE_DOUBLE_VALUE), double_value_(value) {}
+ explicit Variant(const StringPiece value)
+ : type_(TYPE_STRING_VALUE), string_value_(value.ToString()) {}
+ explicit Variant(const std::string value)
+ : type_(TYPE_STRING_VALUE), string_value_(value) {}
+ explicit Variant(const char* value)
+ : type_(TYPE_STRING_VALUE), string_value_(value) {}
+ explicit Variant(const bool value)
+ : type_(TYPE_BOOL_VALUE), bool_value_(value) {}
+ explicit Variant(const std::vector<std::string>& value)
+ : type_(TYPE_STRING_VECTOR_VALUE), string_vector_value_(value) {}
+ explicit Variant(const std::vector<float>& value)
+ : type_(TYPE_FLOAT_VECTOR_VALUE), float_vector_value_(value) {}
+ explicit Variant(const std::vector<int>& value)
+ : type_(TYPE_INT_VECTOR_VALUE), int_vector_value_(value) {}
+ explicit Variant(const std::map<std::string, Variant>& value)
+ : type_(TYPE_STRING_VARIANT_MAP_VALUE),
+ string_variant_map_value_(value) {}
+
+ Variant& operator=(const Variant&) = default;
+
+ template <class T>
+ struct dependent_false : std::false_type {};
+
+ template <typename T>
+ T Value() const {
+ static_assert(dependent_false<T>::value, "Not supported.");
+ }
+
+ template <>
+ int8 Value() const {
+ TC3_CHECK(Has<int8>());
+ return int8_value_;
+ }
+
+ template <>
+ uint8 Value() const {
+ TC3_CHECK(Has<uint8>());
+ return uint8_value_;
+ }
+
+ template <>
+ int Value() const {
+ TC3_CHECK(Has<int>());
+ return int_value_;
+ }
+
+ template <>
+ uint Value() const {
+ TC3_CHECK(Has<uint>());
+ return uint_value_;
+ }
+
+ template <>
+ int64 Value() const {
+ TC3_CHECK(Has<int64>());
+ return long_value_;
+ }
+
+ template <>
+ uint64 Value() const {
+ TC3_CHECK(Has<uint64>());
+ return ulong_value_;
+ }
+
+ template <>
+ float Value() const {
+ TC3_CHECK(Has<float>());
+ return float_value_;
+ }
+
+ template <>
+ double Value() const {
+ TC3_CHECK(Has<double>());
+ return double_value_;
+ }
+
+ template <>
+ bool Value() const {
+ TC3_CHECK(Has<bool>());
+ return bool_value_;
+ }
+
+ template <typename T>
+ const T& ConstRefValue() const;
+
+ template <>
+ const std::string& ConstRefValue() const {
+ TC3_CHECK(Has<std::string>());
+ return string_value_;
+ }
+
+ template <>
+ const std::vector<std::string>& ConstRefValue() const {
+ TC3_CHECK(Has<std::vector<std::string>>());
+ return string_vector_value_;
+ }
+
+ template <>
+ const std::vector<float>& ConstRefValue() const {
+ TC3_CHECK(Has<std::vector<float>>());
+ return float_vector_value_;
+ }
+
+ template <>
+ const std::vector<int>& ConstRefValue() const {
+ TC3_CHECK(Has<std::vector<int>>());
+ return int_vector_value_;
+ }
+
+ template <>
+ const std::map<std::string, Variant>& ConstRefValue() const {
+ TC3_CHECK((Has<std::map<std::string, Variant>>()));
+ return string_variant_map_value_;
+ }
+
+ template <typename T>
+ bool Has() const;
+
+ template <>
+ bool Has<int8>() const {
+ return type_ == TYPE_INT8_VALUE;
+ }
+
+ template <>
+ bool Has<uint8>() const {
+ return type_ == TYPE_UINT8_VALUE;
+ }
+
+ template <>
+ bool Has<int>() const {
+ return type_ == TYPE_INT_VALUE;
+ }
+
+ template <>
+ bool Has<uint>() const {
+ return type_ == TYPE_UINT_VALUE;
+ }
+
+ template <>
+ bool Has<int64>() const {
+ return type_ == TYPE_INT64_VALUE;
+ }
+
+ template <>
+ bool Has<uint64>() const {
+ return type_ == TYPE_UINT64_VALUE;
+ }
+
+ template <>
+ bool Has<float>() const {
+ return type_ == TYPE_FLOAT_VALUE;
+ }
+
+ template <>
+ bool Has<double>() const {
+ return type_ == TYPE_DOUBLE_VALUE;
+ }
+
+ template <>
+ bool Has<bool>() const {
+ return type_ == TYPE_BOOL_VALUE;
+ }
+
+ template <>
+ bool Has<std::string>() const {
+ return type_ == TYPE_STRING_VALUE;
+ }
+
+ template <>
+ bool Has<std::vector<std::string>>() const {
+ return type_ == TYPE_STRING_VECTOR_VALUE;
+ }
+
+ template <>
+ bool Has<std::vector<float>>() const {
+ return type_ == TYPE_FLOAT_VECTOR_VALUE;
+ }
+
+ template <>
+ bool Has<std::vector<int>>() const {
+ return type_ == TYPE_INT_VECTOR_VALUE;
+ }
+
+ template <>
+ bool Has<std::map<std::string, Variant>>() const {
+ return type_ == TYPE_STRING_VARIANT_MAP_VALUE;
+ }
+
+ // Converts the value of this variant to its string representation, regardless
+ // of the type of the actual value.
+ std::string ToString() const;
+
+ Type GetType() const { return type_; }
+
+ bool HasValue() const { return type_ != TYPE_EMPTY; }
+
+ private:
+ Type type_;
+ union {
+ int8_t int8_value_;
+ uint8_t uint8_value_;
+ int int_value_;
+ uint uint_value_;
+ int64 long_value_;
+ uint64 ulong_value_;
+ float float_value_;
+ double double_value_;
+ bool bool_value_;
+ };
+ std::string string_value_;
+ std::vector<std::string> string_vector_value_;
+ std::vector<float> float_vector_value_;
+ std::vector<int> int_vector_value_;
+ std::map<std::string, Variant> string_variant_map_value_;
+};
+
+// Pretty-printing function for Variant.
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Variant& value);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_VARIANT_H_
diff --git a/native/utils/variant_test.cc b/native/utils/variant_test.cc
new file mode 100644
index 0000000..cf0acfb
--- /dev/null
+++ b/native/utils/variant_test.cc
@@ -0,0 +1,73 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/variant.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(VariantTest, GetType) {
+ EXPECT_EQ(Variant().GetType(), Variant::TYPE_EMPTY);
+ EXPECT_EQ(Variant(static_cast<int8_t>(9)).GetType(),
+ Variant::TYPE_INT8_VALUE);
+ EXPECT_EQ(Variant(static_cast<uint8_t>(9)).GetType(),
+ Variant::TYPE_UINT8_VALUE);
+ EXPECT_EQ(Variant(static_cast<int>(9)).GetType(), Variant::TYPE_INT_VALUE);
+ EXPECT_EQ(Variant(static_cast<uint>(9)).GetType(), Variant::TYPE_UINT_VALUE);
+ EXPECT_EQ(Variant(static_cast<int64>(9)).GetType(),
+ Variant::TYPE_INT64_VALUE);
+ EXPECT_EQ(Variant(static_cast<uint64>(9)).GetType(),
+ Variant::TYPE_UINT64_VALUE);
+ EXPECT_EQ(Variant(static_cast<float>(9)).GetType(),
+ Variant::TYPE_FLOAT_VALUE);
+ EXPECT_EQ(Variant(static_cast<double>(9)).GetType(),
+ Variant::TYPE_DOUBLE_VALUE);
+ EXPECT_EQ(Variant(true).GetType(), Variant::TYPE_BOOL_VALUE);
+ EXPECT_EQ(Variant("hello").GetType(), Variant::TYPE_STRING_VALUE);
+}
+
+TEST(VariantTest, HasValue) {
+ EXPECT_FALSE(Variant().HasValue());
+ EXPECT_TRUE(Variant(static_cast<int8_t>(9)).HasValue());
+ EXPECT_TRUE(Variant(static_cast<uint8_t>(9)).HasValue());
+ EXPECT_TRUE(Variant(static_cast<int>(9)).HasValue());
+ EXPECT_TRUE(Variant(static_cast<uint>(9)).HasValue());
+ EXPECT_TRUE(Variant(static_cast<int64>(9)).HasValue());
+ EXPECT_TRUE(Variant(static_cast<uint64>(9)).HasValue());
+ EXPECT_TRUE(Variant(static_cast<float>(9)).HasValue());
+ EXPECT_TRUE(Variant(static_cast<double>(9)).HasValue());
+ EXPECT_TRUE(Variant(true).HasValue());
+ EXPECT_TRUE(Variant("hello").HasValue());
+}
+
+TEST(VariantTest, Value) {
+ EXPECT_EQ(Variant(static_cast<int8_t>(9)).Value<int8>(), 9);
+ EXPECT_EQ(Variant(static_cast<uint8_t>(9)).Value<uint8>(), 9);
+ EXPECT_EQ(Variant(static_cast<int>(9)).Value<int>(), 9);
+ EXPECT_EQ(Variant(static_cast<uint>(9)).Value<uint>(), 9);
+ EXPECT_EQ(Variant(static_cast<int64>(9)).Value<int64>(), 9);
+ EXPECT_EQ(Variant(static_cast<uint64>(9)).Value<uint64>(), 9);
+ EXPECT_EQ(Variant(static_cast<float>(9)).Value<float>(), 9);
+ EXPECT_EQ(Variant(static_cast<double>(9)).Value<double>(), 9);
+ EXPECT_EQ(Variant(true).Value<bool>(), true);
+ EXPECT_EQ(Variant("hello").ConstRefValue<std::string>(), "hello");
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/utils/zlib/buffer.fbs b/native/utils/zlib/buffer.fbs
similarity index 100%
rename from utils/zlib/buffer.fbs
rename to native/utils/zlib/buffer.fbs
diff --git a/utils/zlib/zlib.cc b/native/utils/zlib/zlib.cc
similarity index 100%
rename from utils/zlib/zlib.cc
rename to native/utils/zlib/zlib.cc
diff --git a/utils/zlib/zlib.h b/native/utils/zlib/zlib.h
similarity index 100%
rename from utils/zlib/zlib.h
rename to native/utils/zlib/zlib.h
diff --git a/native/utils/zlib/zlib_regex.cc b/native/utils/zlib/zlib_regex.cc
new file mode 100644
index 0000000..73b6d30
--- /dev/null
+++ b/native/utils/zlib/zlib_regex.cc
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/zlib/zlib_regex.h"
+
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/flatbuffers.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern(
+ const UniLib& unilib, const flatbuffers::String* uncompressed_pattern,
+ const CompressedBuffer* compressed_pattern, bool lazy_compile_regex,
+ ZlibDecompressor* decompressor, std::string* result_pattern_text) {
+ UnicodeText unicode_regex_pattern;
+ std::string decompressed_pattern;
+ if (compressed_pattern != nullptr &&
+ compressed_pattern->buffer() != nullptr) {
+ if (decompressor == nullptr ||
+ !decompressor->MaybeDecompress(compressed_pattern,
+ &decompressed_pattern)) {
+ TC3_LOG(ERROR) << "Cannot decompress pattern.";
+ return nullptr;
+ }
+ unicode_regex_pattern =
+ UTF8ToUnicodeText(decompressed_pattern.data(),
+ decompressed_pattern.size(), /*do_copy=*/false);
+ } else {
+ if (uncompressed_pattern == nullptr) {
+ TC3_LOG(ERROR) << "Cannot load uncompressed pattern.";
+ return nullptr;
+ }
+ unicode_regex_pattern =
+ UTF8ToUnicodeText(uncompressed_pattern->c_str(),
+ uncompressed_pattern->Length(), /*do_copy=*/false);
+ }
+
+ if (result_pattern_text != nullptr) {
+ *result_pattern_text = unicode_regex_pattern.ToUTF8String();
+ }
+
+ std::unique_ptr<UniLib::RegexPattern> regex_pattern;
+ if (lazy_compile_regex) {
+ regex_pattern = unilib.CreateLazyRegexPattern(unicode_regex_pattern);
+ } else {
+ regex_pattern = unilib.CreateRegexPattern(unicode_regex_pattern);
+ }
+
+ if (!regex_pattern) {
+ TC3_LOG(ERROR) << "Could not create pattern: "
+ << unicode_regex_pattern.ToUTF8String();
+ }
+ return regex_pattern;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/zlib/zlib_regex.h b/native/utils/zlib/zlib_regex.h
similarity index 100%
rename from utils/zlib/zlib_regex.h
rename to native/utils/zlib/zlib_regex.h
diff --git a/notification/Android.bp b/notification/Android.bp
new file mode 100644
index 0000000..966941d
--- /dev/null
+++ b/notification/Android.bp
@@ -0,0 +1,37 @@
+//
+// Copyright (C) 2019 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// A library that contains all java classes with the AndroidManifest.
+android_library {
+ name: "TextClassifierNotificationLib",
+ static_libs: ["TextClassifierNotificationLibNoManifest"],
+ sdk_version: "system_current",
+ min_sdk_version: "29",
+ manifest: "AndroidManifest.xml",
+}
+
+// Similar to TextClassifierNotificationLib, but without the AndroidManifest.
+android_library {
+ name: "TextClassifierNotificationLibNoManifest",
+ srcs: ["src/**/*.java"],
+ static_libs: [
+ "androidx.annotation_annotation",
+ "guava",
+ ],
+ sdk_version: "system_current",
+ min_sdk_version: "29",
+ manifest: "LibNoManifest_AndroidManifest.xml",
+}
diff --git a/notification/AndroidManifest.xml b/notification/AndroidManifest.xml
new file mode 100644
index 0000000..3153d1d
--- /dev/null
+++ b/notification/AndroidManifest.xml
@@ -0,0 +1,13 @@
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.android.textclassifier.notification">
+
+ <uses-sdk android:minSdkVersion="29" />
+
+ <application>
+ <activity
+ android:name=".CopyCodeActivity"
+ android:exported="false"
+ android:theme="@android:style/Theme.NoDisplay" />
+ </application>
+
+</manifest>
\ No newline at end of file
diff --git a/notification/LibNoManifest_AndroidManifest.xml b/notification/LibNoManifest_AndroidManifest.xml
new file mode 100644
index 0000000..b9ebf7d
--- /dev/null
+++ b/notification/LibNoManifest_AndroidManifest.xml
@@ -0,0 +1,30 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-->
+
+<!--
+ This is for the AndroidManifest.xml for the TextClassifierNotificationLibNoManifest library.
+ The user of this library should explicitly put the necessary components in their own
+ AndroidManifest.xml, see AndroidManifest.xml under the same folder.
+-->
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.android.textclassifier.notification">
+
+ <uses-sdk android:minSdkVersion="29" android:targetSdkVersion="29"/>
+
+</manifest>
diff --git a/notification/res/drawable/tc_notif_ic_action_open.xml b/notification/res/drawable/tc_notif_ic_action_open.xml
new file mode 100644
index 0000000..4ca7dc4
--- /dev/null
+++ b/notification/res/drawable/tc_notif_ic_action_open.xml
@@ -0,0 +1,10 @@
+<vector xmlns:android="http://schemas.android.com/apk/res/android"
+ android:width="24dp"
+ android:height="24dp"
+ android:tint="?android:colorControlNormal"
+ android:viewportHeight="24.0"
+ android:viewportWidth="24.0">
+ <path
+ android:fillColor="#FFFFFFFF"
+ android:pathData="M19 19H5V5h7V3H5c-1.11 0-2 .9-2 2v14c0 1.1.89 2 2 2h14c1.1 0 2-.9 2-2v-7h-2v7zM14 3v2h3.59l-9.83 9.83 1.41 1.41L19 6.41V10h2V3h-7z" />
+</vector>
diff --git a/notification/res/drawable/tc_notif_ic_menu_copy_material.xml b/notification/res/drawable/tc_notif_ic_menu_copy_material.xml
new file mode 100644
index 0000000..8d35374
--- /dev/null
+++ b/notification/res/drawable/tc_notif_ic_menu_copy_material.xml
@@ -0,0 +1,12 @@
+<vector xmlns:android="http://schemas.android.com/apk/res/android"
+ android:width="24dp"
+ android:height="24dp"
+ android:viewportWidth="24.0"
+ android:viewportHeight="24.0"
+ android:autoMirrored="true"
+ android:tint="?android:colorControlNormal">
+ <path
+ android:pathData="M16,1L4,1C2.9,1 2,1.9 2,3l0,14l2,0L4,3l12,0L16,1zM19,5L8,5C6.9,5 6,5.9 6,7l0,14c0,1.1 0.9,2 2,2l11,0c1.1,0 2,-0.9 2,-2L21,7C21,5.9 20.1,5 19,5zM19,21L8,21L8,7l11,0L19,21z"
+ android:fillColor="#FFFFFFFF"/>
+</vector>
+
diff --git a/notification/res/values-af/strings.xml b/notification/res/values-af/strings.xml
new file mode 100755
index 0000000..fb86971
--- /dev/null
+++ b/notification/res/values-af/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopieer \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kode is gekopieer</string>
+</resources>
diff --git a/notification/res/values-am/strings.xml b/notification/res/values-am/strings.xml
new file mode 100755
index 0000000..62ad42d
--- /dev/null
+++ b/notification/res/values-am/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">ቅዳ \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">ኮድ ተቀድቷል</string>
+</resources>
diff --git a/notification/res/values-ar/strings.xml b/notification/res/values-ar/strings.xml
new file mode 100755
index 0000000..ab32f2d
--- /dev/null
+++ b/notification/res/values-ar/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">نسخ \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">تم نسخ الرمز.</string>
+</resources>
diff --git a/notification/res/values-as/strings.xml b/notification/res/values-as/strings.xml
new file mode 100755
index 0000000..f52e2f5
--- /dev/null
+++ b/notification/res/values-as/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c প্ৰতিলিপি কৰক</string>
+ <string name="tc_notif_code_copied_to_clipboard">ক’ড প্ৰতিলিপি কৰা হ’ল</string>
+</resources>
diff --git a/notification/res/values-az/strings.xml b/notification/res/values-az/strings.xml
new file mode 100755
index 0000000..6eec653
--- /dev/null
+++ b/notification/res/values-az/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopyalayın: \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kod kopyalandı</string>
+</resources>
diff --git a/notification/res/values-b+es+419/strings.xml b/notification/res/values-b+es+419/strings.xml
new file mode 100755
index 0000000..c64eaae
--- /dev/null
+++ b/notification/res/values-b+es+419/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Copiar \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Se copió el código</string>
+</resources>
diff --git a/notification/res/values-b+sr+Latn/strings.xml b/notification/res/values-b+sr+Latn/strings.xml
new file mode 100755
index 0000000..480ef86
--- /dev/null
+++ b/notification/res/values-b+sr+Latn/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopiraj „%1$s“</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kôd je kopiran</string>
+</resources>
diff --git a/notification/res/values-be/strings.xml b/notification/res/values-be/strings.xml
new file mode 100755
index 0000000..2505167
--- /dev/null
+++ b/notification/res/values-be/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Капіраваць \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Код скапіраваны</string>
+</resources>
diff --git a/notification/res/values-bg/strings.xml b/notification/res/values-bg/strings.xml
new file mode 100755
index 0000000..f882a3b
--- /dev/null
+++ b/notification/res/values-bg/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Копиране на %1$s</string>
+ <string name="tc_notif_code_copied_to_clipboard">Кодът е копиран</string>
+</resources>
diff --git a/notification/res/values-bn/strings.xml b/notification/res/values-bn/strings.xml
new file mode 100755
index 0000000..e6fd335
--- /dev/null
+++ b/notification/res/values-bn/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c কপি করুন</string>
+ <string name="tc_notif_code_copied_to_clipboard">কোড কপি করা হয়েছে</string>
+</resources>
diff --git a/notification/res/values-bs/strings.xml b/notification/res/values-bs/strings.xml
new file mode 100755
index 0000000..79c3db7
--- /dev/null
+++ b/notification/res/values-bs/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopiraj \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kôd je kopiran</string>
+</resources>
diff --git a/notification/res/values-ca/strings.xml b/notification/res/values-ca/strings.xml
new file mode 100755
index 0000000..0a8ab7a
--- /dev/null
+++ b/notification/res/values-ca/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Copia \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">S\'ha copiat el codi</string>
+</resources>
diff --git a/notification/res/values-cs/strings.xml b/notification/res/values-cs/strings.xml
new file mode 100755
index 0000000..4d3f887
--- /dev/null
+++ b/notification/res/values-cs/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopírovat kód %1$s</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kód zkopírován</string>
+</resources>
diff --git a/notification/res/values-da/strings.xml b/notification/res/values-da/strings.xml
new file mode 100755
index 0000000..f4c7eed
--- /dev/null
+++ b/notification/res/values-da/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopiér \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Koden blev kopieret</string>
+</resources>
diff --git a/notification/res/values-de/strings.xml b/notification/res/values-de/strings.xml
new file mode 100755
index 0000000..8047fbb
--- /dev/null
+++ b/notification/res/values-de/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c kopieren</string>
+ <string name="tc_notif_code_copied_to_clipboard">Code wurde kopiert</string>
+</resources>
diff --git a/notification/res/values-el/strings.xml b/notification/res/values-el/strings.xml
new file mode 100755
index 0000000..abe8fee
--- /dev/null
+++ b/notification/res/values-el/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Αντιγραφή \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Αντιγράφηκε ο κωδικός</string>
+</resources>
diff --git a/notification/res/values-en-rGB/strings.xml b/notification/res/values-en-rGB/strings.xml
new file mode 100755
index 0000000..eae16e7
--- /dev/null
+++ b/notification/res/values-en-rGB/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Copy \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Code copied</string>
+</resources>
diff --git a/notification/res/values-es/strings.xml b/notification/res/values-es/strings.xml
new file mode 100755
index 0000000..c847d86
--- /dev/null
+++ b/notification/res/values-es/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Copiar \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Se ha copiado el código</string>
+</resources>
diff --git a/notification/res/values-et/strings.xml b/notification/res/values-et/strings.xml
new file mode 100755
index 0000000..593ba77
--- /dev/null
+++ b/notification/res/values-et/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopeeri \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kood kopeeriti</string>
+</resources>
diff --git a/notification/res/values-eu/strings.xml b/notification/res/values-eu/strings.xml
new file mode 100755
index 0000000..57b0c1c
--- /dev/null
+++ b/notification/res/values-eu/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopiatu \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kopiatu da kodea</string>
+</resources>
diff --git a/notification/res/values-fa/strings.xml b/notification/res/values-fa/strings.xml
new file mode 100755
index 0000000..8417a1e
--- /dev/null
+++ b/notification/res/values-fa/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">کپی کردن \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">کد کپی شد</string>
+</resources>
diff --git a/notification/res/values-fi/strings.xml b/notification/res/values-fi/strings.xml
new file mode 100755
index 0000000..564ebba
--- /dev/null
+++ b/notification/res/values-fi/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopioi \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Koodi kopioitu</string>
+</resources>
diff --git a/notification/res/values-fr-rCA/strings.xml b/notification/res/values-fr-rCA/strings.xml
new file mode 100755
index 0000000..e230b05
--- /dev/null
+++ b/notification/res/values-fr-rCA/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Copier « %1$s »</string>
+ <string name="tc_notif_code_copied_to_clipboard">Code copié</string>
+</resources>
diff --git a/notification/res/values-fr/strings.xml b/notification/res/values-fr/strings.xml
new file mode 100755
index 0000000..0b03c7e
--- /dev/null
+++ b/notification/res/values-fr/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Copier \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Code copié</string>
+</resources>
diff --git a/notification/res/values-gl/strings.xml b/notification/res/values-gl/strings.xml
new file mode 100755
index 0000000..2beaa34
--- /dev/null
+++ b/notification/res/values-gl/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Copiar \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Copiouse o código</string>
+</resources>
diff --git a/notification/res/values-gu/strings.xml b/notification/res/values-gu/strings.xml
new file mode 100755
index 0000000..4d898dd
--- /dev/null
+++ b/notification/res/values-gu/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c કૉપિ કરો</string>
+ <string name="tc_notif_code_copied_to_clipboard">કોડ કૉપિ કર્યો</string>
+</resources>
diff --git a/notification/res/values-hi/strings.xml b/notification/res/values-hi/strings.xml
new file mode 100755
index 0000000..bfcc7ed
--- /dev/null
+++ b/notification/res/values-hi/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c कॉपी करें</string>
+ <string name="tc_notif_code_copied_to_clipboard">कोड कॉपी किया गया</string>
+</resources>
diff --git a/notification/res/values-hr/strings.xml b/notification/res/values-hr/strings.xml
new file mode 100755
index 0000000..79c3db7
--- /dev/null
+++ b/notification/res/values-hr/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopiraj \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kôd je kopiran</string>
+</resources>
diff --git a/notification/res/values-hu/strings.xml b/notification/res/values-hu/strings.xml
new file mode 100755
index 0000000..d38ef72
--- /dev/null
+++ b/notification/res/values-hu/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201E%1$s\u201D másolása</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kód másolva</string>
+</resources>
diff --git a/notification/res/values-hy/strings.xml b/notification/res/values-hy/strings.xml
new file mode 100755
index 0000000..bf30692
--- /dev/null
+++ b/notification/res/values-hy/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Պատճենել «%1$s» կոդը</string>
+ <string name="tc_notif_code_copied_to_clipboard">Կոդը պատճենվեց</string>
+</resources>
diff --git a/notification/res/values-id/strings.xml b/notification/res/values-id/strings.xml
new file mode 100755
index 0000000..12c518d
--- /dev/null
+++ b/notification/res/values-id/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Salin \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kode disalin</string>
+</resources>
diff --git a/notification/res/values-is/strings.xml b/notification/res/values-is/strings.xml
new file mode 100755
index 0000000..c951d74
--- /dev/null
+++ b/notification/res/values-is/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Afrita \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kóði afritaður</string>
+</resources>
diff --git a/notification/res/values-it/strings.xml b/notification/res/values-it/strings.xml
new file mode 100755
index 0000000..ad214d4
--- /dev/null
+++ b/notification/res/values-it/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Copia \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Codice copiato</string>
+</resources>
diff --git a/notification/res/values-iw/strings.xml b/notification/res/values-iw/strings.xml
new file mode 100755
index 0000000..36bb51d
--- /dev/null
+++ b/notification/res/values-iw/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">העתקה של \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">הקוד הועתק</string>
+</resources>
diff --git a/notification/res/values-ja/strings.xml b/notification/res/values-ja/strings.xml
new file mode 100755
index 0000000..feeae62
--- /dev/null
+++ b/notification/res/values-ja/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">「%1$s」をコピー</string>
+ <string name="tc_notif_code_copied_to_clipboard">コードをコピーしました</string>
+</resources>
diff --git a/notification/res/values-ka/strings.xml b/notification/res/values-ka/strings.xml
new file mode 100755
index 0000000..6681fd5
--- /dev/null
+++ b/notification/res/values-ka/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">კოპირება \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">კოდი კოპირებულია</string>
+</resources>
diff --git a/notification/res/values-kk/strings.xml b/notification/res/values-kk/strings.xml
new file mode 100755
index 0000000..eccceb2
--- /dev/null
+++ b/notification/res/values-kk/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c көшіру</string>
+ <string name="tc_notif_code_copied_to_clipboard">Код көшірілді.</string>
+</resources>
diff --git a/notification/res/values-km/strings.xml b/notification/res/values-km/strings.xml
new file mode 100755
index 0000000..bf66909
--- /dev/null
+++ b/notification/res/values-km/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">ចម្លង \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">បានចម្លងកូដ</string>
+</resources>
diff --git a/notification/res/values-kn/strings.xml b/notification/res/values-kn/strings.xml
new file mode 100755
index 0000000..80ff7ed
--- /dev/null
+++ b/notification/res/values-kn/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c ಅನ್ನು ನಕಲಿಸಿ</string>
+ <string name="tc_notif_code_copied_to_clipboard">ಕೋಡ್ ನಕಲಿಸಲಾಗಿದೆ</string>
+</resources>
diff --git a/notification/res/values-ko/strings.xml b/notification/res/values-ko/strings.xml
new file mode 100755
index 0000000..48c3614
--- /dev/null
+++ b/notification/res/values-ko/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c 복사</string>
+ <string name="tc_notif_code_copied_to_clipboard">코드 복사됨</string>
+</resources>
diff --git a/notification/res/values-ky/strings.xml b/notification/res/values-ky/strings.xml
new file mode 100755
index 0000000..f9043f8
--- /dev/null
+++ b/notification/res/values-ky/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c кодун көчүрүү</string>
+ <string name="tc_notif_code_copied_to_clipboard">Код көчүрүлдү</string>
+</resources>
diff --git a/notification/res/values-lo/strings.xml b/notification/res/values-lo/strings.xml
new file mode 100755
index 0000000..06b91db
--- /dev/null
+++ b/notification/res/values-lo/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">ສຳເນົາ \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">ສຳເນົາແລ້ວ</string>
+</resources>
diff --git a/notification/res/values-lt/strings.xml b/notification/res/values-lt/strings.xml
new file mode 100755
index 0000000..f864d8a
--- /dev/null
+++ b/notification/res/values-lt/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopijuoti „%1$s“</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kodas nukopijuotas</string>
+</resources>
diff --git a/notification/res/values-lv/strings.xml b/notification/res/values-lv/strings.xml
new file mode 100755
index 0000000..e2fe3bc
--- /dev/null
+++ b/notification/res/values-lv/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopēt kodu \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kods ir nokopēts</string>
+</resources>
diff --git a/notification/res/values-mk/strings.xml b/notification/res/values-mk/strings.xml
new file mode 100755
index 0000000..841df09
--- /dev/null
+++ b/notification/res/values-mk/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Копирај „%1$s“</string>
+ <string name="tc_notif_code_copied_to_clipboard">Кодот е копиран</string>
+</resources>
diff --git a/notification/res/values-ml/strings.xml b/notification/res/values-ml/strings.xml
new file mode 100755
index 0000000..e64cf36
--- /dev/null
+++ b/notification/res/values-ml/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c പകർത്തുക</string>
+ <string name="tc_notif_code_copied_to_clipboard">കോഡ് പകർത്തി</string>
+</resources>
diff --git a/notification/res/values-mn/strings.xml b/notification/res/values-mn/strings.xml
new file mode 100755
index 0000000..9c04d40
--- /dev/null
+++ b/notification/res/values-mn/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c-г хуулах</string>
+ <string name="tc_notif_code_copied_to_clipboard">Кодыг хуулсан</string>
+</resources>
diff --git a/notification/res/values-mr/strings.xml b/notification/res/values-mr/strings.xml
new file mode 100755
index 0000000..ca058c5
--- /dev/null
+++ b/notification/res/values-mr/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c कॉपी करा</string>
+ <string name="tc_notif_code_copied_to_clipboard">कोड कॉपी केला</string>
+</resources>
diff --git a/notification/res/values-ms/strings.xml b/notification/res/values-ms/strings.xml
new file mode 100755
index 0000000..06415a6
--- /dev/null
+++ b/notification/res/values-ms/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Salin \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kod disalin</string>
+</resources>
diff --git a/notification/res/values-my/strings.xml b/notification/res/values-my/strings.xml
new file mode 100755
index 0000000..144c6b9
--- /dev/null
+++ b/notification/res/values-my/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c ကို ကူးယူပါ</string>
+ <string name="tc_notif_code_copied_to_clipboard">ကုဒ်ကို ကူးယူလိုက်ပါပြီ</string>
+</resources>
diff --git a/notification/res/values-ne/strings.xml b/notification/res/values-ne/strings.xml
new file mode 100755
index 0000000..6940c77
--- /dev/null
+++ b/notification/res/values-ne/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c प्रतिलिपि गर्नु…</string>
+ <string name="tc_notif_code_copied_to_clipboard">कोड प्रतिलिपि गरियो</string>
+</resources>
diff --git a/notification/res/values-nl/strings.xml b/notification/res/values-nl/strings.xml
new file mode 100755
index 0000000..642def9
--- /dev/null
+++ b/notification/res/values-nl/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c kopiëren</string>
+ <string name="tc_notif_code_copied_to_clipboard">Code gekopieerd</string>
+</resources>
diff --git a/notification/res/values-no/strings.xml b/notification/res/values-no/strings.xml
new file mode 100755
index 0000000..e22bdae
--- /dev/null
+++ b/notification/res/values-no/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopiér «%1$s»</string>
+ <string name="tc_notif_code_copied_to_clipboard">Koden er kopiert</string>
+</resources>
diff --git a/notification/res/values-or/strings.xml b/notification/res/values-or/strings.xml
new file mode 100755
index 0000000..175fa11
--- /dev/null
+++ b/notification/res/values-or/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c କପି କରନ୍ତୁ</string>
+ <string name="tc_notif_code_copied_to_clipboard">କୋଡ୍ କପି କରାଯାଇଛି</string>
+</resources>
diff --git a/notification/res/values-pa/strings.xml b/notification/res/values-pa/strings.xml
new file mode 100755
index 0000000..8bdbae8
--- /dev/null
+++ b/notification/res/values-pa/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c ਕਾਪੀ ਕਰੋ</string>
+ <string name="tc_notif_code_copied_to_clipboard">ਕੋਡ ਕਾਪੀ ਕੀਤਾ ਗਿਆ</string>
+</resources>
diff --git a/notification/res/values-pl/strings.xml b/notification/res/values-pl/strings.xml
new file mode 100755
index 0000000..916ba1e
--- /dev/null
+++ b/notification/res/values-pl/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopiuj kod „%1$s”</string>
+ <string name="tc_notif_code_copied_to_clipboard">Skopiowano kod</string>
+</resources>
diff --git a/notification/res/values-pt-rBR/strings.xml b/notification/res/values-pt-rBR/strings.xml
new file mode 100755
index 0000000..c925814
--- /dev/null
+++ b/notification/res/values-pt-rBR/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Copiar \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Código copiado</string>
+</resources>
diff --git a/notification/res/values-pt-rPT/strings.xml b/notification/res/values-pt-rPT/strings.xml
new file mode 100755
index 0000000..c925814
--- /dev/null
+++ b/notification/res/values-pt-rPT/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Copiar \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Código copiado</string>
+</resources>
diff --git a/notification/res/values-ro/strings.xml b/notification/res/values-ro/strings.xml
new file mode 100755
index 0000000..4a62325
--- /dev/null
+++ b/notification/res/values-ro/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Copiați \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Codul a fost copiat</string>
+</resources>
diff --git a/notification/res/values-ru/strings.xml b/notification/res/values-ru/strings.xml
new file mode 100755
index 0000000..4f26348
--- /dev/null
+++ b/notification/res/values-ru/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Копировать \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Код скопирован.</string>
+</resources>
diff --git a/notification/res/values-si/strings.xml b/notification/res/values-si/strings.xml
new file mode 100755
index 0000000..493b85d
--- /dev/null
+++ b/notification/res/values-si/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c පිටපත් කරන්න</string>
+ <string name="tc_notif_code_copied_to_clipboard">කේතය පිටපත් කරන ලදී</string>
+</resources>
diff --git a/notification/res/values-sk/strings.xml b/notification/res/values-sk/strings.xml
new file mode 100755
index 0000000..0a455f4
--- /dev/null
+++ b/notification/res/values-sk/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopírovať \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kód bol skopírovaný</string>
+</resources>
diff --git a/notification/res/values-sl/strings.xml b/notification/res/values-sl/strings.xml
new file mode 100755
index 0000000..273e7b7
--- /dev/null
+++ b/notification/res/values-sl/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopiraj \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Koda je kopirana</string>
+</resources>
diff --git a/notification/res/values-sq/strings.xml b/notification/res/values-sq/strings.xml
new file mode 100755
index 0000000..8caadbd
--- /dev/null
+++ b/notification/res/values-sq/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopjo \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kodi u kopjua</string>
+</resources>
diff --git a/notification/res/values-sr/strings.xml b/notification/res/values-sr/strings.xml
new file mode 100755
index 0000000..9df7992
--- /dev/null
+++ b/notification/res/values-sr/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Копирај „%1$s“</string>
+ <string name="tc_notif_code_copied_to_clipboard">Кôд је копиран</string>
+</resources>
diff --git a/notification/res/values-sv/strings.xml b/notification/res/values-sv/strings.xml
new file mode 100755
index 0000000..fe603ba
--- /dev/null
+++ b/notification/res/values-sv/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopiera \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Koden har kopierats</string>
+</resources>
diff --git a/notification/res/values-sw/strings.xml b/notification/res/values-sw/strings.xml
new file mode 100755
index 0000000..6ea6774
--- /dev/null
+++ b/notification/res/values-sw/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Nakili \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Imenakili msimbo</string>
+</resources>
diff --git a/notification/res/values-ta/strings.xml b/notification/res/values-ta/strings.xml
new file mode 100755
index 0000000..ec0d6fc
--- /dev/null
+++ b/notification/res/values-ta/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\u201c%1$s\u201c என்பதை நகலெடு</string>
+ <string name="tc_notif_code_copied_to_clipboard">குறியீடு நகலெடுக்கப்பட்டது</string>
+</resources>
diff --git a/notification/res/values-te/strings.xml b/notification/res/values-te/strings.xml
new file mode 100755
index 0000000..7240035
--- /dev/null
+++ b/notification/res/values-te/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">\'\u201c%1$s\u201c\'ను కాపీ చేయి</string>
+ <string name="tc_notif_code_copied_to_clipboard">కోడ్ కాపీ చేయబడింది</string>
+</resources>
diff --git a/notification/res/values-th/strings.xml b/notification/res/values-th/strings.xml
new file mode 100755
index 0000000..e5373e2
--- /dev/null
+++ b/notification/res/values-th/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">คัดลอก \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">คัดลอกรหัสแล้ว</string>
+</resources>
diff --git a/notification/res/values-tl/strings.xml b/notification/res/values-tl/strings.xml
new file mode 100755
index 0000000..b761fa3
--- /dev/null
+++ b/notification/res/values-tl/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopyahin ang \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Nakopya ang code</string>
+</resources>
diff --git a/notification/res/values-tr/strings.xml b/notification/res/values-tr/strings.xml
new file mode 100755
index 0000000..3c15f94
--- /dev/null
+++ b/notification/res/values-tr/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopyala: \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kod kopyalandı</string>
+</resources>
diff --git a/notification/res/values-uk/strings.xml b/notification/res/values-uk/strings.xml
new file mode 100755
index 0000000..c53d94e
--- /dev/null
+++ b/notification/res/values-uk/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Скопіювати \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Код скопійовано</string>
+</resources>
diff --git a/notification/res/values-ur/strings.xml b/notification/res/values-ur/strings.xml
new file mode 100755
index 0000000..bcf1a8d
--- /dev/null
+++ b/notification/res/values-ur/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">کاپی کریں \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">کوڈ کاپی ہو گیا</string>
+</resources>
diff --git a/notification/res/values-uz/strings.xml b/notification/res/values-uz/strings.xml
new file mode 100755
index 0000000..c1a8319
--- /dev/null
+++ b/notification/res/values-uz/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Nusxa olish: \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Kod nusxalandi</string>
+</resources>
diff --git a/notification/res/values-vi/strings.xml b/notification/res/values-vi/strings.xml
new file mode 100755
index 0000000..8b13c43
--- /dev/null
+++ b/notification/res/values-vi/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Sao chép \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Đã sao chép mã</string>
+</resources>
diff --git a/notification/res/values-zh-rHK/strings.xml b/notification/res/values-zh-rHK/strings.xml
new file mode 100755
index 0000000..fa9f576
--- /dev/null
+++ b/notification/res/values-zh-rHK/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">複製「%1$s」</string>
+ <string name="tc_notif_code_copied_to_clipboard">已複製代碼</string>
+</resources>
diff --git a/notification/res/values-zh-rTW/strings.xml b/notification/res/values-zh-rTW/strings.xml
new file mode 100755
index 0000000..c5df879
--- /dev/null
+++ b/notification/res/values-zh-rTW/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">複製「%1$s」</string>
+ <string name="tc_notif_code_copied_to_clipboard">已複製密碼</string>
+</resources>
diff --git a/notification/res/values-zh/strings.xml b/notification/res/values-zh/strings.xml
new file mode 100755
index 0000000..036c45a
--- /dev/null
+++ b/notification/res/values-zh/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">复制 \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">已复制代码</string>
+</resources>
diff --git a/notification/res/values-zu/strings.xml b/notification/res/values-zu/strings.xml
new file mode 100755
index 0000000..20cc778
--- /dev/null
+++ b/notification/res/values-zu/strings.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <string name="tc_notif_copy_code_desc">Kopisha \u201c%1$s\u201c</string>
+ <string name="tc_notif_code_copied_to_clipboard">Ikhodi ikopishiwe</string>
+</resources>
diff --git a/notification/res/values/strings.xml b/notification/res/values/strings.xml
new file mode 100644
index 0000000..4a23013
--- /dev/null
+++ b/notification/res/values/strings.xml
@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <!-- Action chip to copy a one-time code to the user's clipboard [CHAR LIMIT=30] -->
+ <string name="tc_notif_copy_code_desc">Copy \u201c<xliff:g id="code">%1$s</xliff:g>\u201d</string>
+ <!-- Toast to display when text is copied to the device clipboard [CHAR LIMIT=NONE] -->
+ <string name="tc_notif_code_copied_to_clipboard">Code copied</string>
+</resources>
diff --git a/notification/src/com/android/textclassifier/notification/CopyCodeActivity.java b/notification/src/com/android/textclassifier/notification/CopyCodeActivity.java
new file mode 100644
index 0000000..4cfb299
--- /dev/null
+++ b/notification/src/com/android/textclassifier/notification/CopyCodeActivity.java
@@ -0,0 +1,53 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.notification;
+
+import android.app.Activity;
+import android.content.ClipData;
+import android.content.ClipboardManager;
+import android.content.Intent;
+import android.os.Bundle;
+import android.text.TextUtils;
+import android.util.Log;
+import android.widget.Toast;
+import javax.annotation.Nullable;
+
+/** Handles the copy code action. */
+public class CopyCodeActivity extends Activity {
+ private static final String TAG = "CopyCodeActivity";
+
+ @Override
+ protected void onCreate(@Nullable Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ handleIntent();
+ finish();
+ }
+
+ private void handleIntent() {
+ String code = getIntent().getStringExtra(Intent.EXTRA_TEXT);
+ if (TextUtils.isEmpty(code)) {
+ Log.w(TAG, "handleIntent: empty code");
+ return;
+ }
+ ClipboardManager clipboardManager = getSystemService(ClipboardManager.class);
+ ClipData clipData = ClipData.newPlainText(null, code);
+ clipboardManager.setPrimaryClip(clipData);
+ Toast.makeText(
+ getApplicationContext(), R.string.tc_notif_code_copied_to_clipboard, Toast.LENGTH_SHORT)
+ .show();
+ }
+}
diff --git a/notification/src/com/android/textclassifier/notification/NotificationUtils.java b/notification/src/com/android/textclassifier/notification/NotificationUtils.java
new file mode 100644
index 0000000..3f80fa1
--- /dev/null
+++ b/notification/src/com/android/textclassifier/notification/NotificationUtils.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.notification;
+
+import android.app.Notification;
+import android.app.Notification.Style;
+import android.app.RemoteInput;
+import android.service.notification.StatusBarNotification;
+import java.util.Objects;
+
+final class NotificationUtils {
+
+ /**
+ * Returns whether the given status bar notification is showing some incoming messages.
+ *
+ * @see Notification#CATEGORY_MESSAGE
+ * @see Notification.MessagingStyle
+ */
+ static boolean isMessaging(StatusBarNotification statusBarNotification) {
+ return isCategory(statusBarNotification, Notification.CATEGORY_MESSAGE)
+ || isPublicVersionCategory(statusBarNotification, Notification.CATEGORY_MESSAGE)
+ || hasStyle(statusBarNotification, Notification.MessagingStyle.class);
+ }
+
+ /**
+ * Returns whether the given status bar notification has an reply button that allows user to do
+ * inline reply.
+ */
+ static boolean hasInlineReply(StatusBarNotification statusBarNotification) {
+ Notification.Action[] actions = statusBarNotification.getNotification().actions;
+ if (actions == null) {
+ return false;
+ }
+ for (Notification.Action action : actions) {
+ RemoteInput[] remoteInputs = action.getRemoteInputs();
+ if (remoteInputs == null) {
+ continue;
+ }
+ for (RemoteInput remoteInput : remoteInputs) {
+ if (remoteInput.getAllowFreeFormInput()) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ private static boolean hasStyle(
+ StatusBarNotification statusBarNotification, Class<? extends Style> targetStyle) {
+ String templateClass =
+ statusBarNotification.getNotification().extras.getString(Notification.EXTRA_TEMPLATE);
+ return targetStyle.getName().equals(templateClass);
+ }
+
+ private static boolean isCategory(StatusBarNotification statusBarNotification, String category) {
+ return Objects.equals(statusBarNotification.getNotification().category, category);
+ }
+
+ private static boolean isPublicVersionCategory(
+ StatusBarNotification statusBarNotification, String category) {
+ Notification publicVersion = statusBarNotification.getNotification().publicVersion;
+ return publicVersion != null && isCategoryInternal(publicVersion, category);
+ }
+
+ private static boolean isCategoryInternal(Notification notification, String category) {
+ return Objects.equals(notification.category, category);
+ }
+
+ private NotificationUtils() {}
+}
diff --git a/notification/src/com/android/textclassifier/notification/SmartSuggestions.java b/notification/src/com/android/textclassifier/notification/SmartSuggestions.java
new file mode 100644
index 0000000..4c6b5e1
--- /dev/null
+++ b/notification/src/com/android/textclassifier/notification/SmartSuggestions.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.notification;
+
+import android.app.Notification;
+import android.app.Notification.Action;
+import com.google.common.collect.ImmutableList;
+import java.util.List;
+
+/** Suggestions on a given conversation. */
+public final class SmartSuggestions {
+ private final ImmutableList<CharSequence> replies;
+ private final ImmutableList<Notification.Action> actions;
+
+ public SmartSuggestions(List<CharSequence> replies, List<Notification.Action> actions) {
+ this.replies = ImmutableList.copyOf(replies);
+ this.actions = ImmutableList.copyOf(actions);
+ }
+
+ /** Gets a list of suggested replies. */
+ public ImmutableList<CharSequence> getReplies() {
+ return replies;
+ }
+
+ /** Gets a list of suggested actions. */
+ public ImmutableList<Action> getActions() {
+ return actions;
+ }
+}
diff --git a/notification/src/com/android/textclassifier/notification/SmartSuggestionsConfig.java b/notification/src/com/android/textclassifier/notification/SmartSuggestionsConfig.java
new file mode 100644
index 0000000..841759d
--- /dev/null
+++ b/notification/src/com/android/textclassifier/notification/SmartSuggestionsConfig.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.notification;
+
+/** Configurations for the smart actions feature. */
+public interface SmartSuggestionsConfig {
+ /** To generate contextual replies for notifications or not. */
+ boolean shouldGenerateReplies();
+
+ /** To generate contextual actions for notifications or not. */
+ boolean shouldGenerateActions();
+
+ /** The maximum number of suggestions to generate for a conversation. */
+ int getMaxSuggestions();
+
+ /**
+ * The maximum number of messages that should be extracted from a conversation when constructing
+ * suggestions for that conversation.
+ */
+ int getMaxMessagesToExtract();
+}
diff --git a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
new file mode 100644
index 0000000..a6aa9ae
--- /dev/null
+++ b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
@@ -0,0 +1,472 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.notification;
+
+import android.app.Notification;
+import android.app.Notification.MessagingStyle.Message;
+import android.app.PendingIntent;
+import android.app.Person;
+import android.app.RemoteAction;
+import android.app.RemoteInput;
+import android.content.Context;
+import android.content.Intent;
+import android.graphics.drawable.Icon;
+import android.os.Bundle;
+import android.os.Process;
+import android.service.notification.StatusBarNotification;
+import android.text.TextUtils;
+import android.util.ArrayMap;
+import android.util.Log;
+import android.util.LruCache;
+import android.util.Pair;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.TextClassificationContext;
+import android.view.textclassifier.TextClassificationManager;
+import android.view.textclassifier.TextClassifier;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import java.time.Instant;
+import java.time.ZoneOffset;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import javax.annotation.Nullable;
+
+/**
+ * Generates suggestions from incoming notifications and handles related logging.
+ *
+ * <p>This class is not thread-safe. Either call methods in this class in a single worker thread or
+ * guard all the calls with the same lock.
+ */
+public class SmartSuggestionsHelper {
+ private static final String TAG = "SmartSuggestionsHelper";
+
+ static final String ENTITIES_EXTRAS = "entities-extras";
+ static final String KEY_ACTION_TYPE = "action_type";
+ static final String KEY_ACTION_SCORE = "action_score";
+ static final String KEY_TEXT = "text";
+ static final String NOTIFICATION_KEY = "notificationKey";
+ // Copied from ConversationAction.java
+ static final String TYPE_COPY = "copy";
+
+ // If a notification has any of these flags set, it's ineligible for actions being added.
+ private static final int FLAG_MASK_INELGIBILE_FOR_ACTIONS =
+ Notification.FLAG_ONGOING_EVENT
+ | Notification.FLAG_FOREGROUND_SERVICE
+ | Notification.FLAG_GROUP_SUMMARY
+ | Notification.FLAG_NO_CLEAR;
+ private static final int MAX_RESULT_ID_TO_CACHE = 20;
+ private static final ImmutableList<String> HINTS =
+ ImmutableList.of(ConversationActions.Request.HINT_FOR_NOTIFICATION);
+ private static final ConversationActions EMPTY_CONVERSATION_ACTIONS =
+ new ConversationActions(ImmutableList.of(), null);
+
+ private final Context context;
+ private final TextClassificationManager textClassificationManager;
+ private final SmartSuggestionsConfig config;
+ private final LruCache<String, SmartSuggestionsLogSession> sessionCache =
+ new LruCache<String, SmartSuggestionsLogSession>(MAX_RESULT_ID_TO_CACHE) {
+ @Override
+ protected void entryRemoved(
+ boolean evicted,
+ String key,
+ SmartSuggestionsLogSession oldSession,
+ SmartSuggestionsLogSession newSession) {
+ oldSession.destroy();
+ }
+ };
+ private final TextClassificationContext textClassificationContext;
+
+ public SmartSuggestionsHelper(Context context, SmartSuggestionsConfig config) {
+ this.context = context;
+ textClassificationManager = this.context.getSystemService(TextClassificationManager.class);
+ this.config = config;
+ this.textClassificationContext =
+ new TextClassificationContext.Builder(
+ context.getPackageName(), TextClassifier.WIDGET_TYPE_NOTIFICATION)
+ .build();
+ }
+
+ /**
+ * Notifies a notification is enqueued and returns some suggestions based on the conversation in
+ * the given status bar notification.
+ */
+ public SmartSuggestions onNotificationEnqueued(StatusBarNotification statusBarNotification) {
+ // Whenever onNotificationEnqueued() is called again on the same notification key, its
+ // previous session is ended.
+ sessionCache.remove(statusBarNotification.getKey());
+
+ boolean eligibleForReplyAdjustment =
+ config.shouldGenerateReplies() && isEligibleForReplyAdjustment(statusBarNotification);
+ boolean eligibleForActionAdjustment =
+ config.shouldGenerateActions() && isEligibleForActionAdjustment(statusBarNotification);
+
+ TextClassifier textClassifier =
+ textClassificationManager.createTextClassificationSession(textClassificationContext);
+
+ ConversationActions conversationActionsResult =
+ suggestConversationActions(
+ textClassifier,
+ statusBarNotification,
+ eligibleForReplyAdjustment,
+ eligibleForActionAdjustment);
+
+ String resultId = conversationActionsResult.getId();
+ List<ConversationAction> conversationActions =
+ conversationActionsResult.getConversationActions();
+
+ ArrayList<CharSequence> replies = new ArrayList<>();
+ Map<CharSequence, Float> repliesScore = new ArrayMap<>();
+ for (ConversationAction conversationAction : conversationActions) {
+ CharSequence textReply = conversationAction.getTextReply();
+ if (TextUtils.isEmpty(textReply)) {
+ continue;
+ }
+ replies.add(textReply);
+ repliesScore.put(textReply, conversationAction.getConfidenceScore());
+ }
+
+ ArrayList<Notification.Action> actions = new ArrayList<>();
+ for (ConversationAction conversationAction : conversationActions) {
+ if (!TextUtils.isEmpty(conversationAction.getTextReply())) {
+ continue;
+ }
+ Notification.Action notificationAction;
+ RemoteAction remoteAction = conversationAction.getAction();
+ if (remoteAction == null) {
+ notificationAction = createNotificationActionWithoutRemoteAction(conversationAction);
+ } else {
+ notificationAction =
+ createNotificationActionFromRemoteAction(
+ remoteAction,
+ conversationAction.getType(),
+ conversationAction.getConfidenceScore());
+ }
+ if (notificationAction != null) {
+ actions.add(notificationAction);
+ }
+ }
+ if (TextUtils.isEmpty(resultId)) {
+ textClassifier.destroy();
+ } else {
+ SmartSuggestionsLogSession session =
+ new SmartSuggestionsLogSession(
+ resultId, repliesScore, textClassifier, textClassificationContext);
+ session.onSuggestionsGenerated(conversationActions);
+
+ // Store the session if we expect more logging from it, destroy it otherwise.
+ if (!conversationActions.isEmpty()
+ && suggestionsMightBeUsedInNotification(
+ statusBarNotification, !actions.isEmpty(), !replies.isEmpty())) {
+ sessionCache.put(statusBarNotification.getKey(), session);
+ } else {
+ session.destroy();
+ }
+ }
+
+ return new SmartSuggestions(replies, actions);
+ }
+
+ /**
+ * Creates notification action from ConversationAction that does not come with a RemoteAction. It
+ * could happen because we don't have common intents for some actions, like copying text.
+ */
+ @Nullable
+ private Notification.Action createNotificationActionWithoutRemoteAction(
+ ConversationAction conversationAction) {
+ if (TYPE_COPY.equals(conversationAction.getType())) {
+ return createCopyCodeAction(conversationAction);
+ }
+ return null;
+ }
+
+ @Nullable
+ private Notification.Action createCopyCodeAction(ConversationAction conversationAction) {
+ Bundle extras = conversationAction.getExtras();
+ Bundle entitiesExtas = extras.getParcelable(ENTITIES_EXTRAS);
+ if (entitiesExtas == null) {
+ return null;
+ }
+ String code = entitiesExtas.getString(KEY_TEXT);
+ if (TextUtils.isEmpty(code)) {
+ return null;
+ }
+ String contentDescription = context.getString(R.string.tc_notif_copy_code_desc, code);
+ Intent intent = new Intent(context, CopyCodeActivity.class);
+ intent.putExtra(Intent.EXTRA_TEXT, code);
+
+ RemoteAction remoteAction =
+ new RemoteAction(
+ Icon.createWithResource(context, R.drawable.tc_notif_ic_menu_copy_material),
+ code,
+ contentDescription,
+ PendingIntent.getActivity(
+ context, code.hashCode(), intent, PendingIntent.FLAG_UPDATE_CURRENT));
+
+ return createNotificationActionFromRemoteAction(
+ remoteAction, TYPE_COPY, conversationAction.getConfidenceScore());
+ }
+
+ /**
+ * Returns whether the suggestion might be used in the notifications in SysUI.
+ *
+ * <p>Currently, NAS has no idea if suggestions will actually be used in the notification, and
+ * thus this function tries to make a heuristic. This function tries to optimize the precision,
+ * that means when it is unsure, it will return false. The objective is to avoid false positive,
+ * which could pollute the log and CTR as we are logging click rate of suggestions that could be
+ * never visible to users. On the other hand, it is fine to have false negative because it would
+ * be just like sampling.
+ */
+ private static boolean suggestionsMightBeUsedInNotification(
+ StatusBarNotification statusBarNotification, boolean hasSmartAction, boolean hasSmartReply) {
+ Notification notification = statusBarNotification.getNotification();
+ boolean hasAppGeneratedContextualActions = !notification.getContextualActions().isEmpty();
+
+ Pair<RemoteInput, Notification.Action> freeformRemoteInputAndAction =
+ notification.findRemoteInputActionPair(/* requiresFreeform */ true);
+ boolean hasAppGeneratedReplies = false;
+ boolean allowGeneratedReplies = false;
+ if (freeformRemoteInputAndAction != null) {
+ RemoteInput freeformRemoteInput = freeformRemoteInputAndAction.first;
+ Notification.Action actionWithFreeformRemoteInput = freeformRemoteInputAndAction.second;
+ CharSequence[] choices = freeformRemoteInput.getChoices();
+ hasAppGeneratedReplies = (choices != null && choices.length > 0);
+ allowGeneratedReplies = actionWithFreeformRemoteInput.getAllowGeneratedReplies();
+ }
+
+ if (hasAppGeneratedReplies || hasAppGeneratedContextualActions) {
+ return false;
+ }
+ return (hasSmartAction && notification.getAllowSystemGeneratedContextualActions())
+ || (hasSmartReply && allowGeneratedReplies);
+ }
+
+ /** Adds action adjustments based on the notification contents. */
+ private ConversationActions suggestConversationActions(
+ TextClassifier textClassifier,
+ StatusBarNotification statusBarNotification,
+ boolean includeReplies,
+ boolean includeActions) {
+ if (!includeReplies && !includeActions) {
+ return EMPTY_CONVERSATION_ACTIONS;
+ }
+ ImmutableList<ConversationActions.Message> messages =
+ extractMessages(statusBarNotification.getNotification());
+ if (messages.isEmpty()) {
+ return EMPTY_CONVERSATION_ACTIONS;
+ }
+ // Do not generate smart actions if the last message is from the local user.
+ ConversationActions.Message lastMessage = Iterables.getLast(messages);
+ if (arePersonsEqual(ConversationActions.Message.PERSON_USER_SELF, lastMessage.getAuthor())) {
+ return EMPTY_CONVERSATION_ACTIONS;
+ }
+
+ TextClassifier.EntityConfig.Builder typeConfigBuilder =
+ new TextClassifier.EntityConfig.Builder();
+ if (!includeReplies) {
+ typeConfigBuilder.setExcludedTypes(ImmutableList.of(ConversationAction.TYPE_TEXT_REPLY));
+ } else if (!includeActions) {
+ typeConfigBuilder
+ .setIncludedTypes(ImmutableList.of(ConversationAction.TYPE_TEXT_REPLY))
+ .includeTypesFromTextClassifier(false);
+ }
+
+ // Put the notification key into the request extras
+ Bundle extra = new Bundle();
+ extra.putString(NOTIFICATION_KEY, statusBarNotification.getKey());
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(messages)
+ .setMaxSuggestions(config.getMaxSuggestions())
+ .setHints(HINTS)
+ .setExtras(extra)
+ .setTypeConfig(typeConfigBuilder.build())
+ .build();
+
+ return textClassifier.suggestConversationActions(request);
+ }
+
+ /**
+ * Notifies that a notification has been expanded or collapsed.
+ *
+ * @param statusBarNotification status bar notification
+ * @param isExpanded true for when a notification is expanded, false for when it is collapsed
+ */
+ public void onNotificationExpansionChanged(
+ StatusBarNotification statusBarNotification, boolean isExpanded) {
+ SmartSuggestionsLogSession session = sessionCache.get(statusBarNotification.getKey());
+ if (session == null) {
+ return;
+ }
+ session.onNotificationExpansionChanged(isExpanded);
+ }
+
+ /** Notifies that a direct reply has been sent from a notification. */
+ public void onNotificationDirectReplied(String key) {
+ SmartSuggestionsLogSession session = sessionCache.get(key);
+ if (session == null) {
+ return;
+ }
+ session.onDirectReplied();
+ }
+
+ /**
+ * Notifies that a suggested reply has been sent.
+ *
+ * @param key the notification key
+ * @param reply the reply that is just sent
+ * @param source the source that provided the reply, e.g. SOURCE_FROM_ASSISTANT
+ */
+ public void onSuggestedReplySent(String key, CharSequence reply, int source) {
+ SmartSuggestionsLogSession session = sessionCache.get(key);
+ if (session == null) {
+ return;
+ }
+ session.onSuggestedReplySent(reply, source);
+ }
+
+ /**
+ * Notifies an action is clicked.
+ *
+ * @param key the notification key
+ * @param action the action that is just clicked
+ * @param source the source that provided the reply, e.g. SOURCE_FROM_ASSISTANT
+ */
+ public void onActionClicked(String key, Notification.Action action, int source) {
+ SmartSuggestionsLogSession session = sessionCache.get(key);
+ if (session == null) {
+ return;
+ }
+ session.onActionClicked(action, source);
+ }
+
+ /** Clears the internal cache. */
+ public void clearCache() {
+ sessionCache.evictAll();
+ }
+
+ private Notification.Action createNotificationActionFromRemoteAction(
+ RemoteAction remoteAction, String actionType, float score) {
+ Icon icon =
+ remoteAction.shouldShowIcon()
+ ? remoteAction.getIcon()
+ : Icon.createWithResource(context, R.drawable.tc_notif_ic_action_open);
+ Bundle extras = new Bundle();
+ extras.putString(KEY_ACTION_TYPE, actionType);
+ extras.putFloat(KEY_ACTION_SCORE, score);
+ return new Notification.Action.Builder(
+ icon, remoteAction.getTitle(), remoteAction.getActionIntent())
+ .setContextual(true)
+ .addExtras(extras)
+ .build();
+ }
+
+ /**
+ * Returns whether a notification is eligible for action adjustments.
+ *
+ * <p>We exclude system notifications, those that get refreshed frequently, or ones that relate to
+ * fundamental phone functionality where any error would result in a very negative user
+ * experience.
+ */
+ private static boolean isEligibleForActionAdjustment(
+ StatusBarNotification statusBarNotification) {
+ String pkg = statusBarNotification.getPackageName();
+ if (!Process.myUserHandle().equals(statusBarNotification.getUser())) {
+ return false;
+ }
+ Notification notification = statusBarNotification.getNotification();
+ if ((notification.flags & FLAG_MASK_INELGIBILE_FOR_ACTIONS) != 0) {
+ return false;
+ }
+ if (TextUtils.isEmpty(pkg) || pkg.equals("android")) {
+ return false;
+ }
+ // For now, we are only interested in messages.
+ return NotificationUtils.isMessaging(statusBarNotification);
+ }
+
+ private static boolean isEligibleForReplyAdjustment(StatusBarNotification statusBarNotification) {
+ if (!Process.myUserHandle().equals(statusBarNotification.getUser())) {
+ return false;
+ }
+ String pkg = statusBarNotification.getPackageName();
+ if (TextUtils.isEmpty(pkg) || pkg.equals("android")) {
+ return false;
+ }
+ // For now, we are only interested in messages.
+ if (!NotificationUtils.isMessaging(statusBarNotification)) {
+ return false;
+ }
+ // Does not make sense to provide suggested replies if it is not something that can be
+ // replied to.
+ if (!NotificationUtils.hasInlineReply(statusBarNotification)) {
+ return false;
+ }
+ return true;
+ }
+
+ /** Returns the text most salient for action extraction in a notification. */
+ private ImmutableList<ConversationActions.Message> extractMessages(Notification notification) {
+ List<Message> messages =
+ Message.getMessagesFromBundleArray(
+ notification.extras.getParcelableArray(Notification.EXTRA_MESSAGES));
+ if (messages == null || messages.isEmpty()) {
+ return ImmutableList.of(
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText(notification.extras.getCharSequence(Notification.EXTRA_TEXT))
+ .build());
+ }
+ Person localUser = notification.extras.getParcelable(Notification.EXTRA_MESSAGING_PERSON);
+ if (localUser == null) {
+ Log.w(TAG, "EXTRA_MESSAGING_PERSON is missing, failed to extract messages.");
+ return ImmutableList.of();
+ }
+ Deque<ConversationActions.Message> extractMessages = new ArrayDeque<>();
+ for (int i = messages.size() - 1; i >= 0; i--) {
+ Message message = messages.get(i);
+ if (message == null) {
+ continue;
+ }
+ Person senderPerson = message.getSenderPerson();
+ // As per the javadoc of Notification.addMessage(), a null sender refers to the user
+ // themselves.
+ Person author =
+ senderPerson == null || arePersonsEqual(localUser, senderPerson)
+ ? ConversationActions.Message.PERSON_USER_SELF
+ : senderPerson;
+ extractMessages.push(
+ new ConversationActions.Message.Builder(author)
+ .setText(message.getText())
+ .setReferenceTime(
+ Instant.ofEpochMilli(message.getTimestamp()).atZone(ZoneOffset.systemDefault()))
+ .build());
+ if (extractMessages.size() >= config.getMaxMessagesToExtract()) {
+ break;
+ }
+ }
+ return ImmutableList.copyOf(new ArrayList<>(extractMessages));
+ }
+
+ private static boolean arePersonsEqual(Person left, Person right) {
+ return Objects.equals(left.getKey(), right.getKey())
+ && TextUtils.equals(left.getName(), right.getName())
+ && Objects.equals(left.getUri(), right.getUri());
+ }
+}
diff --git a/notification/src/com/android/textclassifier/notification/SmartSuggestionsLogSession.java b/notification/src/com/android/textclassifier/notification/SmartSuggestionsLogSession.java
new file mode 100644
index 0000000..4ac82c7
--- /dev/null
+++ b/notification/src/com/android/textclassifier/notification/SmartSuggestionsLogSession.java
@@ -0,0 +1,145 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.notification;
+
+import android.app.Notification;
+import android.service.notification.NotificationAssistantService;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.TextClassificationContext;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextClassifierEvent;
+import com.google.common.collect.ImmutableMap;
+import java.util.List;
+import java.util.Map;
+
+/** Logs events related to a {@link TextClassifier} result. */
+final class SmartSuggestionsLogSession {
+
+ private final TextClassificationContext textClassificationContext;
+ private final String resultId;
+ private final ImmutableMap<CharSequence, Float> repliesScores;
+ private final TextClassifier textClassifier;
+ private boolean isSeenEventLogged;
+
+ /**
+ * Creates a session for logging purpose.
+ *
+ * @param resultId the result id from a {@link TextClassifier} result object.
+ * @param repliesScores a map contains suggested replies and their scores.
+ * @param textClassifier a text classifier that the log will be sent to. This instance will take
+ * the ownership of this text classifier, do not further interact with it directly.
+ * @param textClassificationContext a context where the {@link TextClassifier} API is performed.
+ */
+ SmartSuggestionsLogSession(
+ String resultId,
+ Map<CharSequence, Float> repliesScores,
+ TextClassifier textClassifier,
+ TextClassificationContext textClassificationContext) {
+ this.resultId = resultId;
+ this.repliesScores = ImmutableMap.copyOf(repliesScores);
+ this.textClassifier = textClassifier;
+ this.textClassificationContext = textClassificationContext;
+ }
+
+ /**
+ * Notifies that a notification has been expanded or collapsed.
+ *
+ * @param isExpanded true for when a notification is expanded, false for when it is collapsed.
+ */
+ void onNotificationExpansionChanged(boolean isExpanded) {
+ if (!isExpanded) {
+ return;
+ }
+ // Only report if this is the first time the user sees these suggestions.
+ if (isSeenEventLogged) {
+ return;
+ }
+ isSeenEventLogged = true;
+ TextClassifierEvent textClassifierEvent =
+ createTextClassifierEventBuilder(TextClassifierEvent.TYPE_ACTIONS_SHOWN, resultId).build();
+ // TODO(tonymak): If possible, report which replies / actions are actually seen by user.
+ textClassifier.onTextClassifierEvent(textClassifierEvent);
+ }
+
+ /**
+ * Notifies that an action has been clicked.
+ *
+ * @param action the action that was just clicked
+ * @param source the source that provided the reply, e.g. SOURCE_FROM_ASSISTANT
+ */
+ void onActionClicked(Notification.Action action, int source) {
+ if (source != NotificationAssistantService.SOURCE_FROM_ASSISTANT) {
+ return;
+ }
+ String actionType = action.getExtras().getString(SmartSuggestionsHelper.KEY_ACTION_TYPE);
+ if (actionType == null) {
+ return;
+ }
+ TextClassifierEvent textClassifierEvent =
+ createTextClassifierEventBuilder(TextClassifierEvent.TYPE_SMART_ACTION, resultId)
+ .setEntityTypes(actionType)
+ .build();
+ textClassifier.onTextClassifierEvent(textClassifierEvent);
+ }
+
+ /**
+ * Notifies that a suggested reply has been sent.
+ *
+ * @param reply the reply that is just sent
+ * @param source the source that provided the reply, e.g. SOURCE_FROM_ASSISTANT
+ */
+ void onSuggestedReplySent(CharSequence reply, int source) {
+ if (source != NotificationAssistantService.SOURCE_FROM_ASSISTANT) {
+ return;
+ }
+ TextClassifierEvent textClassifierEvent =
+ createTextClassifierEventBuilder(TextClassifierEvent.TYPE_SMART_ACTION, resultId)
+ .setEntityTypes(ConversationAction.TYPE_TEXT_REPLY)
+ .setScores(repliesScores.getOrDefault(reply, 0f))
+ .build();
+ textClassifier.onTextClassifierEvent(textClassifierEvent);
+ }
+
+ /** Notifies that a direct reply has been sent from a notification. */
+ void onDirectReplied() {
+ TextClassifierEvent textClassifierEvent =
+ createTextClassifierEventBuilder(TextClassifierEvent.TYPE_MANUAL_REPLY, resultId).build();
+ textClassifier.onTextClassifierEvent(textClassifierEvent);
+ }
+
+ /** Notifies that some suggestions have been generated. */
+ void onSuggestionsGenerated(List<ConversationAction> generatedActions) {
+ TextClassifierEvent textClassifierEvent =
+ createTextClassifierEventBuilder(TextClassifierEvent.TYPE_ACTIONS_GENERATED, resultId)
+ .setEntityTypes(
+ generatedActions.stream().map(ConversationAction::getType).toArray(String[]::new))
+ .build();
+ textClassifier.onTextClassifierEvent(textClassifierEvent);
+ }
+
+ /** Destroys this session. Do not call any method in this class after this is called. */
+ void destroy() {
+ textClassifier.destroy();
+ }
+
+ private TextClassifierEvent.ConversationActionsEvent.Builder createTextClassifierEventBuilder(
+ int eventType, String resultId) {
+ return new TextClassifierEvent.ConversationActionsEvent.Builder(eventType)
+ .setEventContext(this.textClassificationContext)
+ .setResultId(resultId);
+ }
+}
diff --git a/notification/tests/Android.bp b/notification/tests/Android.bp
new file mode 100644
index 0000000..a7703cd
--- /dev/null
+++ b/notification/tests/Android.bp
@@ -0,0 +1,40 @@
+//
+// Copyright (C) 2019 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+android_test {
+ name: "TextClassifierNotificationTests",
+
+ manifest: "AndroidManifest.xml",
+
+ srcs: [
+ "src/**/*.java",
+ ],
+
+ static_libs: [
+ "androidx.test.ext.junit",
+ "androidx.test.rules",
+ "androidx.test.ext.truth",
+ "mockito-target-minus-junit4",
+ "compatibility-device-util-axt",
+ "TextClassifierNotificationLib"
+ ],
+
+ test_suites: [
+ "device-tests", "mts"
+ ],
+
+ instrumentation_for: "TextClassifierNotificationLib",
+}
\ No newline at end of file
diff --git a/notification/tests/AndroidManifest.xml b/notification/tests/AndroidManifest.xml
new file mode 100644
index 0000000..81308e3
--- /dev/null
+++ b/notification/tests/AndroidManifest.xml
@@ -0,0 +1,16 @@
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.android.textclassifier.notification">
+
+ <uses-sdk
+ android:minSdkVersion="29"
+ android:targetSdkVersion="29" />
+
+ <application>
+ <uses-library android:name="android.test.runner"/>
+ </application>
+
+ <instrumentation
+ android:name="androidx.test.runner.AndroidJUnitRunner"
+ android:targetPackage="com.android.textclassifier.notification"/>
+
+</manifest>
diff --git a/notification/tests/AndroidTest.xml b/notification/tests/AndroidTest.xml
new file mode 100644
index 0000000..1890e75
--- /dev/null
+++ b/notification/tests/AndroidTest.xml
@@ -0,0 +1,33 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- Copyright (C) 2020 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<!-- This test config file is auto-generated. -->
+<configuration description="Runs TextClassifierNotificationTests.">
+ <option name="test-suite-tag" value="apct" />
+ <option name="test-suite-tag" value="apct-instrumentation" />
+ <target_preparer class="com.android.tradefed.targetprep.suite.SuiteApkInstaller">
+ <option name="cleanup-apks" value="true" />
+ <option name="test-file-name" value="TextClassifierNotificationTests.apk" />
+ </target_preparer>
+
+ <test class="com.android.tradefed.testtype.AndroidJUnitTest" >
+ <option name="package" value="com.android.textclassifier.notification" />
+ <option name="runner" value="androidx.test.runner.AndroidJUnitRunner" />
+ </test>
+
+ <object type="module_controller" class="com.android.tradefed.testtype.suite.module.MainlineTestModuleController">
+ <option name="mainline-module-package-name" value="com.google.android.extservices" />
+ </object>
+</configuration>
diff --git a/notification/tests/src/com/android/textclassifier/notification/CopyCodeActivityTest.java b/notification/tests/src/com/android/textclassifier/notification/CopyCodeActivityTest.java
new file mode 100644
index 0000000..966fbe0
--- /dev/null
+++ b/notification/tests/src/com/android/textclassifier/notification/CopyCodeActivityTest.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.notification;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.ClipData;
+import android.content.ClipboardManager;
+import android.content.Intent;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import androidx.test.platform.app.InstrumentationRegistry;
+import androidx.test.rule.ActivityTestRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public final class CopyCodeActivityTest {
+
+ private static final String CODE_TO_COPY = "code";
+ private static final Intent EMPTY_INTENT =
+ new Intent(Intent.ACTION_VIEW).putExtra(Intent.EXTRA_TEXT, "");
+ private static final Intent CODE_INTENT =
+ new Intent(Intent.ACTION_VIEW).putExtra(Intent.EXTRA_TEXT, CODE_TO_COPY);
+
+ @Rule
+ public ActivityTestRule<CopyCodeActivity> activityRule =
+ new ActivityTestRule<>(
+ CopyCodeActivity.class, /* initialTouchMode= */ false, /* launchActivity= */ false);
+
+ @Test
+ public void onCreate_emptyCode() throws Exception {
+ ClipboardManager clipboardManager =
+ ApplicationProvider.getApplicationContext().getSystemService(ClipboardManager.class);
+ // Use shell's permissions to ensure we can access the clipboard
+ InstrumentationRegistry.getInstrumentation().getUiAutomation().adoptShellPermissionIdentity();
+ clipboardManager.clearPrimaryClip();
+
+ activityRule.launchActivity(EMPTY_INTENT);
+
+ try {
+ assertThat(clipboardManager.hasPrimaryClip()).isFalse();
+ } finally {
+ InstrumentationRegistry.getInstrumentation().getUiAutomation().dropShellPermissionIdentity();
+ }
+ }
+
+ @Test
+ public void onCreate_codeCopied() throws Exception {
+ ClipboardManager clipboardManager =
+ ApplicationProvider.getApplicationContext().getSystemService(ClipboardManager.class);
+ // Use shell's permissions to ensure we can access the clipboard
+ InstrumentationRegistry.getInstrumentation().getUiAutomation().adoptShellPermissionIdentity();
+ clipboardManager.clearPrimaryClip();
+
+ activityRule.launchActivity(CODE_INTENT);
+
+ ClipData clipFromClipboard;
+ try {
+ assertThat(clipboardManager.hasPrimaryClip()).isTrue();
+ clipFromClipboard = clipboardManager.getPrimaryClip();
+ } finally {
+ clipboardManager.clearPrimaryClip();
+ InstrumentationRegistry.getInstrumentation().getUiAutomation().dropShellPermissionIdentity();
+ }
+
+ assertThat(clipFromClipboard).isNotNull();
+ assertThat(clipFromClipboard.getItemCount()).isEqualTo(1);
+ assertThat(clipFromClipboard.getItemAt(0).getText().toString()).isEqualTo(CODE_TO_COPY);
+ }
+}
diff --git a/notification/tests/src/com/android/textclassifier/notification/NotificationUtilsTest.java b/notification/tests/src/com/android/textclassifier/notification/NotificationUtilsTest.java
new file mode 100644
index 0000000..f279830
--- /dev/null
+++ b/notification/tests/src/com/android/textclassifier/notification/NotificationUtilsTest.java
@@ -0,0 +1,148 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.notification;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.Notification;
+import android.app.Notification.MessagingStyle;
+import android.app.PendingIntent;
+import android.app.Person;
+import android.app.RemoteInput;
+import android.content.Intent;
+import android.graphics.drawable.Icon;
+import android.os.Process;
+import android.service.notification.StatusBarNotification;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.LargeTest;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@LargeTest
+@RunWith(AndroidJUnit4.class)
+public class NotificationUtilsTest {
+
+ @Test
+ public void isMessaging_categoryMessage() {
+ Notification notification =
+ new Notification.Builder(ApplicationProvider.getApplicationContext(), "channel")
+ .setCategory(Notification.CATEGORY_MESSAGE)
+ .build();
+ StatusBarNotification statusBarNotification = createStatusBarNotification(notification);
+
+ assertThat(NotificationUtils.isMessaging(statusBarNotification)).isTrue();
+ }
+
+ @Test
+ public void isMessaging_messagingStyle() {
+ Notification notification =
+ new Notification.Builder(ApplicationProvider.getApplicationContext(), "channel")
+ .setStyle(new MessagingStyle(new Person.Builder().setName("name").build()))
+ .build();
+ StatusBarNotification statusBarNotification = createStatusBarNotification(notification);
+
+ assertThat(NotificationUtils.isMessaging(statusBarNotification)).isTrue();
+ }
+
+ @Test
+ public void isMessaging_publicVersionCategoryMessage() {
+ Notification publicVersion =
+ new Notification.Builder(ApplicationProvider.getApplicationContext(), "channel")
+ .setCategory(Notification.CATEGORY_MESSAGE)
+ .build();
+ Notification notification =
+ new Notification.Builder(ApplicationProvider.getApplicationContext(), "channel")
+ .setPublicVersion(publicVersion)
+ .build();
+ StatusBarNotification statusBarNotification = createStatusBarNotification(notification);
+
+ assertThat(NotificationUtils.isMessaging(statusBarNotification)).isTrue();
+ }
+
+ @Test
+ public void isMessaging_negative() {
+ Notification notification =
+ new Notification.Builder(ApplicationProvider.getApplicationContext(), "channel")
+ .setContentText("Hello")
+ .build();
+ StatusBarNotification statusBarNotification = createStatusBarNotification(notification);
+
+ assertThat(NotificationUtils.isMessaging(statusBarNotification)).isFalse();
+ }
+
+ @Test
+ public void hasInlineReply_positive() {
+ Notification.Action archiveAction =
+ new Notification.Action.Builder(
+ Icon.createWithData(new byte[0], 0, 0),
+ "archive",
+ PendingIntent.getActivity(
+ ApplicationProvider.getApplicationContext(), 0, new Intent(), 0))
+ .build();
+ Notification.Action replyAction =
+ new Notification.Action.Builder(
+ Icon.createWithData(new byte[0], 0, 0),
+ "reply",
+ PendingIntent.getActivity(
+ ApplicationProvider.getApplicationContext(), 0, new Intent(), 0))
+ .addRemoteInput(
+ new RemoteInput.Builder("resultKey").setAllowFreeFormInput(true).build())
+ .build();
+
+ Notification notification =
+ new Notification.Builder(ApplicationProvider.getApplicationContext(), "channel")
+ .setActions(archiveAction, replyAction)
+ .build();
+ StatusBarNotification statusBarNotification = createStatusBarNotification(notification);
+
+ assertThat(NotificationUtils.hasInlineReply(statusBarNotification)).isTrue();
+ }
+
+ @Test
+ public void hasInlineReply_negative() {
+ Notification.Action archiveAction =
+ new Notification.Action.Builder(
+ Icon.createWithData(new byte[0], 0, 0),
+ "archive",
+ PendingIntent.getActivity(
+ ApplicationProvider.getApplicationContext(), 0, new Intent(), 0))
+ .build();
+
+ Notification notification =
+ new Notification.Builder(ApplicationProvider.getApplicationContext(), "channel")
+ .setActions(archiveAction)
+ .build();
+ StatusBarNotification statusBarNotification = createStatusBarNotification(notification);
+
+ assertThat(NotificationUtils.hasInlineReply(statusBarNotification)).isFalse();
+ }
+
+ private static StatusBarNotification createStatusBarNotification(Notification notification) {
+ return new StatusBarNotification(
+ "pkg.name",
+ "pkg.name",
+ /* id= */ 2,
+ "tag",
+ /* uid= */ 1,
+ /* initialPid= */ 1,
+ /* score= */ 1,
+ notification,
+ Process.myUserHandle(),
+ System.currentTimeMillis());
+ }
+}
diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
new file mode 100644
index 0000000..1cbfbf2
--- /dev/null
+++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
@@ -0,0 +1,506 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.notification;
+
+import static android.view.textclassifier.ConversationActions.Message.PERSON_USER_SELF;
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.Notification;
+import android.app.PendingIntent;
+import android.app.Person;
+import android.app.RemoteAction;
+import android.app.RemoteInput;
+import android.content.Context;
+import android.content.Intent;
+import android.graphics.drawable.Icon;
+import android.net.Uri;
+import android.os.Bundle;
+import android.os.Process;
+import android.service.notification.NotificationAssistantService;
+import android.service.notification.StatusBarNotification;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.ConversationActions.Message;
+import android.view.textclassifier.TextClassificationManager;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextClassifierEvent;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.LargeTest;
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@LargeTest
+@RunWith(AndroidJUnit4.class)
+public class SmartSuggestionsHelperTest {
+ private static final String PACKAGE_NAME = "random.app";
+ private static final String MESSAGE = "Where are you?";
+ private static final CharSequence SMART_REPLY = "Home";
+ private static final CharSequence ACTION_TITLE = "Open";
+
+ private static final String RESULT_ID = "id";
+ private static final float REPLY_SCORE = 0.7f;
+ private static final float ACTION_SCORE = 1.0f;
+ private final Context context = ApplicationProvider.getApplicationContext();
+ private final FakeTextClassifier fakeTextClassifier = new FakeTextClassifier();
+ private final TestConfig config = new TestConfig();
+ private SmartSuggestionsHelper smartActions;
+ private Notification.Builder notificationBuilder;
+
+ @Before
+ public void setup() {
+ TextClassificationManager textClassificationManager =
+ context.getSystemService(TextClassificationManager.class);
+ textClassificationManager.setTextClassifier(fakeTextClassifier);
+ smartActions = new SmartSuggestionsHelper(context, config);
+ notificationBuilder = new Notification.Builder(context, "id");
+ }
+
+ @Test
+ public void onNotificationEnqueued_notMessageCategory() {
+ Notification notification = notificationBuilder.setContentText(MESSAGE).build();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+
+ SmartSuggestions smartSuggestions = smartActions.onNotificationEnqueued(statusBarNotification);
+
+ assertThat(smartSuggestions.getReplies()).isEmpty();
+ assertThat(smartSuggestions.getActions()).isEmpty();
+ }
+
+ @Test
+ public void onNotificationEnqueued_fromSystem() {
+ Notification notification =
+ notificationBuilder
+ .setContentText(MESSAGE)
+ .setCategory(Notification.CATEGORY_MESSAGE)
+ .setActions(createReplyAction())
+ .build();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, "android");
+
+ SmartSuggestions smartSuggestions = smartActions.onNotificationEnqueued(statusBarNotification);
+
+ assertThat(smartSuggestions.getReplies()).isEmpty();
+ assertThat(smartSuggestions.getActions()).isEmpty();
+ }
+
+ @Test
+ public void onNotificationEnqueued_noInlineReply() {
+ Notification notification =
+ notificationBuilder
+ .setContentText(MESSAGE)
+ .setCategory(Notification.CATEGORY_MESSAGE)
+ .build();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+
+ SmartSuggestions smartSuggestions = smartActions.onNotificationEnqueued(statusBarNotification);
+
+ assertThat(smartSuggestions.getReplies()).isEmpty();
+ assertAdjustmentWithSmartAction(smartSuggestions);
+ }
+
+ @Test
+ public void onNotificationEnqueued_messageCategoryNotification() {
+ Notification notification = createMessageCategoryNotification();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+
+ SmartSuggestions smartSuggestions = smartActions.onNotificationEnqueued(statusBarNotification);
+
+ assertAdjustmentWithSmartReply(smartSuggestions);
+ assertAdjustmentWithSmartAction(smartSuggestions);
+ ConversationActions.Request request = fakeTextClassifier.getLastRequest();
+ List<Message> messages = request.getConversation();
+ assertThat(messages).hasSize(1);
+ assertThat(messages.get(0).getText().toString()).isEqualTo(MESSAGE);
+ }
+
+ @Test
+ public void onNotificationEnqueued_messageStyleNotification() {
+ Person me = new Person.Builder().setName("Me").build();
+ Person userA = new Person.Builder().setName("A").build();
+ Person userB = new Person.Builder().setName("B").build();
+ Notification.MessagingStyle style =
+ new Notification.MessagingStyle(me)
+ .addMessage("firstMessage", 1000, (Person) null)
+ .addMessage("secondMessage", 2000, me)
+ .addMessage("thirdMessage", 3000, userA)
+ .addMessage("fourthMessage", 4000, userB);
+ Notification notification =
+ notificationBuilder
+ .setContentText("You have three new messages")
+ .setStyle(style)
+ .setActions(createReplyAction())
+ .build();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+ SmartSuggestions smartSuggestions = smartActions.onNotificationEnqueued(statusBarNotification);
+
+ assertAdjustmentWithSmartReply(smartSuggestions);
+ assertAdjustmentWithSmartAction(smartSuggestions);
+ ConversationActions.Request request = fakeTextClassifier.getLastRequest();
+ List<ConversationActions.Message> messages = request.getConversation();
+ assertThat(messages).hasSize(4);
+
+ assertMessage(messages.get(0), "firstMessage", PERSON_USER_SELF, 1000);
+ assertMessage(messages.get(1), "secondMessage", PERSON_USER_SELF, 2000);
+ assertMessage(messages.get(2), "thirdMessage", userA, 3000);
+ assertMessage(messages.get(3), "fourthMessage", userB, 4000);
+ }
+
+ @Test
+ public void onNotificationEnqueued_lastMessageFromLocalUser() {
+ Person me = new Person.Builder().setName("Me").build();
+ Person userA = new Person.Builder().setName("A").build();
+ Notification.MessagingStyle style =
+ new Notification.MessagingStyle(me)
+ .addMessage("firstMessage", 1000, userA)
+ .addMessage("secondMessage", 2000, me);
+ Notification notification =
+ notificationBuilder
+ .setContentText("You have two new messages")
+ .setStyle(style)
+ .setActions(createReplyAction())
+ .build();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+
+ SmartSuggestions smartSuggestions = smartActions.onNotificationEnqueued(statusBarNotification);
+
+ assertThat(smartSuggestions.getReplies()).isEmpty();
+ assertThat(smartSuggestions.getActions()).isEmpty();
+ }
+
+ @Test
+ public void onNotificationEnqueued_messageStyleNotification_missingPerson() {
+ Person me = new Person.Builder().setName("Me").build();
+ Notification.MessagingStyle style =
+ new Notification.MessagingStyle(me).addMessage("message", 1000, (Person) null);
+ Notification notification =
+ notificationBuilder
+ .setContentText("You have one new message")
+ .setStyle(style)
+ .setActions(createReplyAction())
+ .build();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+
+ SmartSuggestions smartSuggestions = smartActions.onNotificationEnqueued(statusBarNotification);
+
+ assertThat(smartSuggestions.getReplies()).isEmpty();
+ assertThat(smartSuggestions.getActions()).isEmpty();
+ }
+
+ @Test
+ public void onSuggestedReplySent() {
+ Notification notification = createMessageCategoryNotification();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+
+ smartActions.onNotificationEnqueued(statusBarNotification);
+ smartActions.onSuggestedReplySent(
+ statusBarNotification.getKey(),
+ SMART_REPLY,
+ NotificationAssistantService.SOURCE_FROM_ASSISTANT);
+
+ List<TextClassifierEvent> textClassifierEvents = fakeTextClassifier.getTextClassifierEvents();
+ assertThat(textClassifierEvents).hasSize(2);
+ TextClassifierEvent firstEvent = textClassifierEvents.get(0);
+ assertThat(firstEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_ACTIONS_GENERATED);
+ assertThat(firstEvent.getEntityTypes())
+ .asList()
+ .containsExactly(ConversationAction.TYPE_TEXT_REPLY, ConversationAction.TYPE_OPEN_URL);
+ TextClassifierEvent secondEvent = textClassifierEvents.get(1);
+ assertThat(secondEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_SMART_ACTION);
+ assertThat(secondEvent.getEntityTypes()[0]).isEqualTo(ConversationAction.TYPE_TEXT_REPLY);
+ }
+
+ @Test
+ public void onSuggestedReplySent_noMatchingSession() {
+ Notification notification = createMessageCategoryNotification();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+
+ smartActions.onNotificationEnqueued(statusBarNotification);
+ smartActions.onSuggestedReplySent(
+ "something_else", MESSAGE, NotificationAssistantService.SOURCE_FROM_ASSISTANT);
+
+ // No matching session, so TYPE_SMART_ACTION should not be logged.
+ List<TextClassifierEvent> textClassifierEvents = fakeTextClassifier.getTextClassifierEvents();
+ assertThat(textClassifierEvents).hasSize(1);
+ TextClassifierEvent firstEvent = textClassifierEvents.get(0);
+ assertThat(firstEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_ACTIONS_GENERATED);
+ }
+
+ @Test
+ public void onNotificationDirectReply() {
+ Notification notification = createMessageCategoryNotification();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+
+ smartActions.onNotificationEnqueued(statusBarNotification);
+ smartActions.onNotificationDirectReplied(statusBarNotification.getKey());
+
+ List<TextClassifierEvent> textClassifierEvents = fakeTextClassifier.getTextClassifierEvents();
+ assertThat(textClassifierEvents).hasSize(2);
+ TextClassifierEvent firstEvent = textClassifierEvents.get(0);
+ assertThat(firstEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_ACTIONS_GENERATED);
+ TextClassifierEvent secondEvent = textClassifierEvents.get(1);
+ assertThat(secondEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_MANUAL_REPLY);
+ }
+
+ @Test
+ public void oNotificationExpansionChanged_expanded() {
+ Notification notification = createMessageCategoryNotification();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+
+ smartActions.onNotificationEnqueued(statusBarNotification);
+ smartActions.onNotificationExpansionChanged(statusBarNotification, /* isExpanded= */ true);
+
+ List<TextClassifierEvent> textClassifierEvents = fakeTextClassifier.getTextClassifierEvents();
+ assertThat(textClassifierEvents).hasSize(2);
+ TextClassifierEvent firstEvent = textClassifierEvents.get(0);
+ assertThat(firstEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_ACTIONS_GENERATED);
+ TextClassifierEvent secondEvent = textClassifierEvents.get(1);
+ assertThat(secondEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ }
+
+ @Test
+ public void oNotificationExpansionChanged_notExpanded() {
+ Notification notification = createMessageCategoryNotification();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+
+ smartActions.onNotificationEnqueued(statusBarNotification);
+ smartActions.onNotificationExpansionChanged(statusBarNotification, /* isExpanded= */ false);
+
+ List<TextClassifierEvent> textClassifierEvents = fakeTextClassifier.getTextClassifierEvents();
+ assertThat(textClassifierEvents).hasSize(1);
+ TextClassifierEvent firstEvent = textClassifierEvents.get(0);
+ assertThat(firstEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_ACTIONS_GENERATED);
+ }
+
+ @Test
+ public void oNotificationExpansionChanged_expanded_logShownEventOnce() {
+ Notification notification = createMessageCategoryNotification();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+
+ smartActions.onNotificationEnqueued(statusBarNotification);
+ smartActions.onNotificationExpansionChanged(statusBarNotification, /* isExpanded= */ true);
+ smartActions.onNotificationExpansionChanged(statusBarNotification, /* isExpanded= */ true);
+
+ List<TextClassifierEvent> textClassifierEvents = fakeTextClassifier.getTextClassifierEvents();
+ assertThat(textClassifierEvents).hasSize(2);
+ TextClassifierEvent firstEvent = textClassifierEvents.get(0);
+ assertThat(firstEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_ACTIONS_GENERATED);
+ TextClassifierEvent secondEvent = textClassifierEvents.get(1);
+ assertThat(secondEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ }
+
+ @Test
+ public void copyCodeAction() {
+ Bundle extras = new Bundle();
+ Bundle entitiesExtras = new Bundle();
+ entitiesExtras.putString(SmartSuggestionsHelper.KEY_TEXT, "12345");
+ extras.putParcelable(SmartSuggestionsHelper.ENTITIES_EXTRAS, entitiesExtras);
+ ConversationAction conversationAction =
+ new ConversationAction.Builder(SmartSuggestionsHelper.TYPE_COPY).setExtras(extras).build();
+ fakeTextClassifier.setSuggestConversationActionsResponse(
+ new ConversationActions(ImmutableList.of(conversationAction), RESULT_ID));
+
+ Notification notification = createMessageCategoryNotification();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+ SmartSuggestions smartSuggestions = smartActions.onNotificationEnqueued(statusBarNotification);
+
+ assertThat(smartSuggestions.getActions()).hasSize(1);
+ assertThat(smartSuggestions.getActions().get(0).title.toString()).isEqualTo("12345");
+ }
+
+ @Ignore // Disabled because it is way too slow to run on an emulator.
+ @Test
+ public void noBinderLeakage() {
+ // Use the real text classifier from system.
+ TextClassificationManager textClassificationManager =
+ context.getSystemService(TextClassificationManager.class);
+ textClassificationManager.setTextClassifier(null);
+
+ // System server crashes when there are more than 20,000 leaked binder proxy.
+ // See
+ // http://cs/android/frameworks/base/core/java/android/os/BinderProxy.java?l=73&rcl=ae52315c8c7d0391bd3c7bca0525a98eeb4cd840.
+ for (int i = 0; i < 20000; i++) {
+ Notification notification = createMessageCategoryNotification();
+ StatusBarNotification statusBarNotification =
+ createStatusBarNotification(notification, PACKAGE_NAME);
+ smartActions.onNotificationEnqueued(statusBarNotification);
+ }
+ }
+
+ private Notification createMessageCategoryNotification() {
+ return notificationBuilder
+ .setContentText(MESSAGE)
+ .setCategory(Notification.CATEGORY_MESSAGE)
+ .setActions(createReplyAction())
+ .build();
+ }
+
+ private static StatusBarNotification createStatusBarNotification(
+ Notification notification, String packageName) {
+ return new StatusBarNotification(
+ packageName,
+ packageName,
+ /* id= */ 2,
+ "tag",
+ /* uid= */ 1,
+ /* initialPid= */ 1,
+ /* score= */ 1,
+ notification,
+ Process.myUserHandle(),
+ System.currentTimeMillis());
+ }
+
+ private Notification.Action createReplyAction() {
+ return new Notification.Action.Builder(
+ Icon.createWithResource(context, android.R.drawable.stat_sys_warning),
+ "Reply",
+ PendingIntent.getActivity(context, 0, new Intent(context, this.getClass()), 0))
+ .addRemoteInput(new RemoteInput.Builder("result").setAllowFreeFormInput(true).build())
+ .build();
+ }
+
+ private static void assertMessage(
+ ConversationActions.Message subject,
+ String expectedMessage,
+ Person expectedAuthor,
+ long expectedReferenceTime) {
+ assertThat(subject.getText().toString()).isEqualTo(expectedMessage);
+ assertThat(subject.getAuthor()).isEqualTo(expectedAuthor);
+ assertThat(subject.getReferenceTime().toInstant().toEpochMilli())
+ .isEqualTo(expectedReferenceTime);
+ }
+
+ private static void assertAdjustmentWithSmartReply(SmartSuggestions smartSuggestions) {
+ assertThat(smartSuggestions.getReplies()).containsExactly(SMART_REPLY);
+ }
+
+ private static void assertAdjustmentWithSmartAction(SmartSuggestions smartSuggestions) {
+ assertThat(smartSuggestions.getActions().get(0).title.toString())
+ .isEqualTo(ACTION_TITLE.toString());
+ }
+
+ private static class FakeTextClassifier implements TextClassifier {
+
+ private ConversationActions.Request lastRequest;
+ private final List<TextClassifierEvent> textClassifierEvents = new ArrayList<>();
+ private ConversationActions conversationActions;
+
+ @Override
+ public ConversationActions suggestConversationActions(ConversationActions.Request request) {
+ lastRequest = request;
+ if (conversationActions != null) {
+ return conversationActions;
+ }
+ Collection<String> types =
+ request
+ .getTypeConfig()
+ .resolveEntityListModifications(
+ Arrays.asList(
+ ConversationAction.TYPE_OPEN_URL, ConversationAction.TYPE_TEXT_REPLY));
+ List<ConversationAction> result = new ArrayList<>();
+
+ if (types.contains(ConversationAction.TYPE_TEXT_REPLY)) {
+ ConversationAction smartReply =
+ new ConversationAction.Builder(ConversationAction.TYPE_TEXT_REPLY)
+ .setTextReply(SMART_REPLY)
+ .setConfidenceScore(REPLY_SCORE)
+ .build();
+ result.add(smartReply);
+ }
+ if (types.contains(ConversationAction.TYPE_OPEN_URL)) {
+ Intent webIntent = new Intent(Intent.ACTION_VIEW).setData(Uri.parse("www.android.com"));
+ ConversationAction smartAction =
+ new ConversationAction.Builder(ConversationAction.TYPE_OPEN_URL)
+ .setConfidenceScore(ACTION_SCORE)
+ .setAction(
+ new RemoteAction(
+ Icon.createWithData(new byte[0], 0, 0),
+ ACTION_TITLE,
+ ACTION_TITLE,
+ PendingIntent.getActivity(
+ ApplicationProvider.getApplicationContext(), 0, webIntent, 0)))
+ .build();
+ result.add(smartAction);
+ }
+ return new ConversationActions(result, RESULT_ID);
+ }
+
+ @Override
+ public void onTextClassifierEvent(TextClassifierEvent event) {
+ textClassifierEvents.add(event);
+ }
+
+ private void setSuggestConversationActionsResponse(ConversationActions conversationActions) {
+ this.conversationActions = conversationActions;
+ }
+
+ private ConversationActions.Request getLastRequest() {
+ return lastRequest;
+ }
+
+ private List<TextClassifierEvent> getTextClassifierEvents() {
+ return ImmutableList.copyOf(textClassifierEvents);
+ }
+ }
+
+ private static class TestConfig implements SmartSuggestionsConfig {
+ private final boolean shouldGenerateReplies = true;
+ private final boolean shouldGenerateActions = true;
+ private final int maxSuggestions = 3;
+ private final int maxMessagesToExtract = 5;
+
+ @Override
+ public boolean shouldGenerateReplies() {
+ return shouldGenerateReplies;
+ }
+
+ @Override
+ public boolean shouldGenerateActions() {
+ return shouldGenerateActions;
+ }
+
+ @Override
+ public int getMaxSuggestions() {
+ return maxSuggestions;
+ }
+
+ @Override
+ public int getMaxMessagesToExtract() {
+ return maxMessagesToExtract;
+ }
+ }
+}
diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java
new file mode 100644
index 0000000..bc30fcf
--- /dev/null
+++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java
@@ -0,0 +1,181 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.notification;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+
+import android.app.Notification;
+import android.app.PendingIntent;
+import android.content.Intent;
+import android.graphics.drawable.Icon;
+import android.os.Bundle;
+import android.service.notification.NotificationAssistantService;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.TextClassificationContext;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextClassifierEvent;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.LargeTest;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+@LargeTest
+@RunWith(AndroidJUnit4.class)
+public class SmartSuggestionsLogSessionTest {
+ private static final String RESULT_ID = "resultId";
+ private static final String REPLY = "reply";
+ private static final float SCORE = 0.5f;
+
+ @Mock private TextClassifier textClassifier;
+ private SmartSuggestionsLogSession session;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+
+ session =
+ new SmartSuggestionsLogSession(
+ RESULT_ID,
+ ImmutableMap.of(REPLY, SCORE),
+ textClassifier,
+ new TextClassificationContext.Builder(
+ "pkg.name", TextClassifier.WIDGET_TYPE_NOTIFICATION)
+ .build());
+ }
+
+ @Test
+ public void onActionClicked() {
+ session.onActionClicked(
+ createNotificationAction(), NotificationAssistantService.SOURCE_FROM_ASSISTANT);
+
+ TextClassifierEvent textClassifierEvent = getLoggedEvent();
+ assertThat(textClassifierEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_SMART_ACTION);
+ assertThat(textClassifierEvent.getResultId()).isEqualTo(RESULT_ID);
+ assertThat(textClassifierEvent.getEntityTypes())
+ .asList()
+ .containsExactly(ConversationAction.TYPE_CALL_PHONE);
+ }
+
+ @Test
+ public void onActionClicked_sourceFromApp() {
+ session.onActionClicked(
+ createNotificationAction(), NotificationAssistantService.SOURCE_FROM_APP);
+
+ verify(textClassifier, never()).onTextClassifierEvent(any());
+ }
+
+ private static Notification.Action createNotificationAction() {
+ Bundle actionExtras = new Bundle();
+ actionExtras.putString(
+ SmartSuggestionsHelper.KEY_ACTION_TYPE, ConversationAction.TYPE_CALL_PHONE);
+ return new Notification.Action.Builder(
+ Icon.createWithData(new byte[0], 0, 0),
+ "Label",
+ PendingIntent.getActivity(
+ ApplicationProvider.getApplicationContext(), 0, new Intent(), 0))
+ .addExtras(actionExtras)
+ .build();
+ }
+
+ @Test
+ public void onDirectReplied() {
+ session.onDirectReplied();
+
+ TextClassifierEvent textClassifierEvent = getLoggedEvent();
+ assertThat(textClassifierEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_MANUAL_REPLY);
+ assertThat(textClassifierEvent.getResultId()).isEqualTo(RESULT_ID);
+ }
+
+ @Test
+ public void onNotificationExpansionChanged() {
+ session.onNotificationExpansionChanged(/* isExpanded= */ true);
+
+ TextClassifierEvent textClassifierEvent = getLoggedEvent();
+ assertThat(textClassifierEvent.getEventType())
+ .isEqualTo(TextClassifierEvent.TYPE_ACTIONS_SHOWN);
+ assertThat(textClassifierEvent.getResultId()).isEqualTo(RESULT_ID);
+ }
+
+ @Test
+ public void onNotificationExpansionChanged_loggedOnce() {
+ session.onNotificationExpansionChanged(/* isExpanded= */ true);
+ session.onNotificationExpansionChanged(/* isExpanded= */ true);
+
+ ArgumentCaptor<TextClassifierEvent> argumentCaptor =
+ ArgumentCaptor.forClass(TextClassifierEvent.class);
+ verify(textClassifier).onTextClassifierEvent(argumentCaptor.capture());
+ assertThat(argumentCaptor.getAllValues()).hasSize(1);
+ }
+
+ @Test
+ public void onNotificationExpansionChanged_collapsed() {
+ session.onNotificationExpansionChanged(/* isExpanded= */ false);
+
+ verify(textClassifier, never()).onTextClassifierEvent(any());
+ }
+
+ @Test
+ public void onSuggestedReplySent() {
+ session.onSuggestedReplySent(REPLY, NotificationAssistantService.SOURCE_FROM_ASSISTANT);
+
+ TextClassifierEvent textClassifierEvent = getLoggedEvent();
+ assertThat(textClassifierEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_SMART_ACTION);
+ assertThat(textClassifierEvent.getResultId()).isEqualTo(RESULT_ID);
+ assertThat(textClassifierEvent.getScores()).usingExactEquality().containsExactly(SCORE);
+ }
+
+ @Test
+ public void onSuggestedReplySent_sourceFromApp() {
+ session.onSuggestedReplySent(REPLY, NotificationAssistantService.SOURCE_FROM_APP);
+
+ verify(textClassifier, never()).onTextClassifierEvent(any());
+ }
+
+ @Test
+ public void onSuggestionsGenerated() {
+ ConversationAction callPhoneAction =
+ new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE).build();
+ ConversationAction openUrlAction =
+ new ConversationAction.Builder(ConversationAction.TYPE_OPEN_URL).build();
+
+ session.onSuggestionsGenerated(ImmutableList.of(callPhoneAction, openUrlAction));
+
+ TextClassifierEvent textClassifierEvent = getLoggedEvent();
+ assertThat(textClassifierEvent.getEventType())
+ .isEqualTo(TextClassifierEvent.TYPE_ACTIONS_GENERATED);
+ assertThat(textClassifierEvent.getEntityTypes())
+ .asList()
+ .containsExactly(ConversationAction.TYPE_CALL_PHONE, ConversationAction.TYPE_OPEN_URL);
+ }
+
+ private TextClassifierEvent getLoggedEvent() {
+ ArgumentCaptor<TextClassifierEvent> argumentCaptor =
+ ArgumentCaptor.forClass(TextClassifierEvent.class);
+ verify(textClassifier).onTextClassifierEvent(argumentCaptor.capture());
+ return argumentCaptor.getValue();
+ }
+}
diff --git a/utils/base/integral_types.h b/utils/base/integral_types.h
deleted file mode 100644
index e3253de..0000000
--- a/utils/base/integral_types.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Basic integer type definitions.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_INTEGRAL_TYPES_H_
-#define LIBTEXTCLASSIFIER_UTILS_BASE_INTEGRAL_TYPES_H_
-
-#include "utils/base/config.h"
-
-namespace libtextclassifier3 {
-
-typedef unsigned int uint32;
-typedef unsigned long long uint64;
-
-#ifndef SWIG
-typedef int int32;
-typedef unsigned char uint8; // NOLINT
-typedef unsigned short uint16; // NOLINT
-
-// A type to represent a Unicode code-point value. As of Unicode 4.0,
-// such values require up to 21 bits.
-// (For type-checking on pointers, make this explicitly signed,
-// and it should always be the signed version of whatever int32 is.)
-typedef signed int char32;
-#endif // SWIG
-
-#ifdef COMPILER_MSVC
-typedef __int64 int64;
-#else
-typedef long long int64; // NOLINT
-#endif // COMPILER_MSVC
-
-// Some compile-time assertions that our new types have the intended size.
-// static_assert exists only since C++11, so we need an ifdef.
-#ifdef LANG_CXX11
-static_assert(sizeof(int) == 4, "Our typedefs depend on int being 32 bits");
-static_assert(sizeof(uint32) == 4, "wrong size");
-static_assert(sizeof(int32) == 4, "wrong size");
-static_assert(sizeof(uint8) == 1, "wrong size");
-static_assert(sizeof(uint16) == 2, "wrong size");
-static_assert(sizeof(char32) == 4, "wrong size");
-static_assert(sizeof(int64) == 8, "wrong size");
-#endif // LANG_CXX11
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_BASE_INTEGRAL_TYPES_H_
diff --git a/utils/base/logging.cc b/utils/base/logging.cc
deleted file mode 100644
index d7ddeb8..0000000
--- a/utils/base/logging.cc
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/base/logging.h"
-
-#include <stdlib.h>
-#include <exception>
-#include <iostream>
-
-#include "utils/base/logging_raw.h"
-
-namespace libtextclassifier3 {
-namespace logging {
-
-namespace {
-// Returns pointer to beginning of last /-separated token from file_name.
-// file_name should be a pointer to a zero-terminated array of chars.
-// E.g., "foo/bar.cc" -> "bar.cc", "foo/" -> "", "foo" -> "foo".
-const char *JumpToBasename(const char *file_name) {
- if (file_name == nullptr) {
- return nullptr;
- }
-
- // Points to the beginning of the last encountered token.
- const char *last_token_start = file_name;
- while (*file_name != '\0') {
- if (*file_name == '/') {
- // Found token separator. A new (potentially empty) token starts after
- // this position. Notice that if file_name is a valid zero-terminated
- // string, file_name + 1 is a valid pointer (there is at least one char
- // after address file_name, the zero terminator).
- last_token_start = file_name + 1;
- }
- file_name++;
- }
- return last_token_start;
-}
-} // namespace
-
-LogMessage::LogMessage(LogSeverity severity, const char *file_name,
- int line_number)
- : severity_(severity) {
- stream_ << JumpToBasename(file_name) << ":" << line_number << ": ";
-}
-
-LogMessage::~LogMessage() {
- LowLevelLogging(severity_, /* tag = */ "txtClsf", stream_.message);
- if (severity_ == FATAL) {
- std::terminate(); // Will print a stacktrace (stdout or logcat).
- }
-}
-
-} // namespace logging
-} // namespace libtextclassifier3
diff --git a/utils/base/logging.h b/utils/base/logging.h
deleted file mode 100644
index 1267f5e..0000000
--- a/utils/base/logging.h
+++ /dev/null
@@ -1,167 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_H_
-#define LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_H_
-
-#include <cassert>
-#include <string>
-
-#include "utils/base/logging_levels.h"
-#include "utils/base/port.h"
-
-
-namespace libtextclassifier3 {
-namespace logging {
-
-// A tiny code footprint string stream for assembling log messages.
-struct LoggingStringStream {
- LoggingStringStream() {}
- LoggingStringStream &stream() { return *this; }
- // Needed for invocation in TC3_CHECK macro.
- explicit operator bool() const { return true; }
-
- std::string message;
-};
-
-template <typename T>
-inline LoggingStringStream &operator<<(LoggingStringStream &stream,
- const T &entry) {
- stream.message.append(std::to_string(entry));
- return stream;
-}
-
-inline LoggingStringStream &operator<<(LoggingStringStream &stream,
- const char *message) {
- stream.message.append(message);
- return stream;
-}
-
-#if defined(HAS_GLOBAL_STRING)
-inline LoggingStringStream &operator<<(LoggingStringStream &stream,
- const ::string &message) {
- stream.message.append(message);
- return stream;
-}
-#endif
-
-inline LoggingStringStream &operator<<(LoggingStringStream &stream,
- const std::string &message) {
- stream.message.append(message);
- return stream;
-}
-
-// The class that does all the work behind our TC3_LOG(severity) macros. Each
-// TC3_LOG(severity) << obj1 << obj2 << ...; logging statement creates a
-// LogMessage temporary object containing a stringstream. Each operator<< adds
-// info to that stringstream and the LogMessage destructor performs the actual
-// logging. The reason this works is that in C++, "all temporary objects are
-// destroyed as the last step in evaluating the full-expression that (lexically)
-// contains the point where they were created." For more info, see
-// http://en.cppreference.com/w/cpp/language/lifetime. Hence, the destructor is
-// invoked after the last << from that logging statement.
-class LogMessage {
- public:
- LogMessage(LogSeverity severity, const char *file_name,
- int line_number) TC3_ATTRIBUTE_NOINLINE;
-
- ~LogMessage() TC3_ATTRIBUTE_NOINLINE;
-
- // Returns the stream associated with the logger object.
- LoggingStringStream &stream() { return stream_; }
-
- private:
- const LogSeverity severity_;
-
- // Stream that "prints" all info into a string (not to a file). We construct
- // here the entire logging message and next print it in one operation.
- LoggingStringStream stream_;
-};
-
-// Pseudo-stream that "eats" the tokens <<-pumped into it, without printing
-// anything.
-class NullStream {
- public:
- NullStream() {}
- NullStream &stream() { return *this; }
-};
-template <typename T>
-inline NullStream &operator<<(NullStream &str, const T &) {
- return str;
-}
-
-} // namespace logging
-} // namespace libtextclassifier3
-
-#define TC3_LOG(severity) \
- ::libtextclassifier3::logging::LogMessage( \
- ::libtextclassifier3::logging::severity, __FILE__, __LINE__) \
- .stream()
-
-// If condition x is true, does nothing. Otherwise, crashes the program (liek
-// LOG(FATAL)) with an informative message. Can be continued with extra
-// messages, via <<, like any logging macro, e.g.,
-//
-// TC3_CHECK(my_cond) << "I think we hit a problem";
-#define TC3_CHECK(x) \
- (x) || TC3_LOG(FATAL) << __FILE__ << ":" << __LINE__ << ": check failed: \"" \
- << #x << "\" "
-
-#define TC3_CHECK_EQ(x, y) TC3_CHECK((x) == (y))
-#define TC3_CHECK_LT(x, y) TC3_CHECK((x) < (y))
-#define TC3_CHECK_GT(x, y) TC3_CHECK((x) > (y))
-#define TC3_CHECK_LE(x, y) TC3_CHECK((x) <= (y))
-#define TC3_CHECK_GE(x, y) TC3_CHECK((x) >= (y))
-#define TC3_CHECK_NE(x, y) TC3_CHECK((x) != (y))
-
-#define TC3_NULLSTREAM ::libtextclassifier3::logging::NullStream().stream()
-
-// Debug checks: a TC3_DCHECK<suffix> macro should behave like TC3_CHECK<suffix>
-// in debug mode an don't check / don't print anything in non-debug mode.
-#ifdef NDEBUG
-
-#define TC3_DCHECK(x) TC3_NULLSTREAM
-#define TC3_DCHECK_EQ(x, y) TC3_NULLSTREAM
-#define TC3_DCHECK_LT(x, y) TC3_NULLSTREAM
-#define TC3_DCHECK_GT(x, y) TC3_NULLSTREAM
-#define TC3_DCHECK_LE(x, y) TC3_NULLSTREAM
-#define TC3_DCHECK_GE(x, y) TC3_NULLSTREAM
-#define TC3_DCHECK_NE(x, y) TC3_NULLSTREAM
-
-#else // NDEBUG
-
-// In debug mode, each TC3_DCHECK<suffix> is equivalent to TC3_CHECK<suffix>,
-// i.e., a real check that crashes when the condition is not true.
-#define TC3_DCHECK(x) TC3_CHECK(x)
-#define TC3_DCHECK_EQ(x, y) TC3_CHECK_EQ(x, y)
-#define TC3_DCHECK_LT(x, y) TC3_CHECK_LT(x, y)
-#define TC3_DCHECK_GT(x, y) TC3_CHECK_GT(x, y)
-#define TC3_DCHECK_LE(x, y) TC3_CHECK_LE(x, y)
-#define TC3_DCHECK_GE(x, y) TC3_CHECK_GE(x, y)
-#define TC3_DCHECK_NE(x, y) TC3_CHECK_NE(x, y)
-
-#endif // NDEBUG
-
-#ifdef TC3_VLOG
-#define TC3_VLOG(severity) \
- ::libtextclassifier3::logging::LogMessage( \
- ::libtextclassifier3::logging::INFO, __FILE__, __LINE__) \
- .stream()
-#else
-#define TC3_VLOG(severity) TC3_NULLSTREAM
-#endif
-
-#endif // LIBTEXTCLASSIFIER_UTILS_BASE_LOGGING_H_
diff --git a/utils/base/logging_raw.cc b/utils/base/logging_raw.cc
deleted file mode 100644
index ccaef22..0000000
--- a/utils/base/logging_raw.cc
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/base/logging_raw.h"
-
-#include <stdio.h>
-#include <string>
-
-// NOTE: this file contains two implementations: one for Android, one for all
-// other cases. We always build exactly one implementation.
-#if defined(__ANDROID__)
-
-// Compiled as part of Android.
-#include <android/log.h>
-
-namespace libtextclassifier3 {
-namespace logging {
-
-namespace {
-// Converts LogSeverity to level for __android_log_write.
-int GetAndroidLogLevel(LogSeverity severity) {
- switch (severity) {
- case FATAL:
- return ANDROID_LOG_FATAL;
- case ERROR:
- return ANDROID_LOG_ERROR;
- case WARNING:
- return ANDROID_LOG_WARN;
- case INFO:
- return ANDROID_LOG_INFO;
- default:
- return ANDROID_LOG_DEBUG;
- }
-}
-} // namespace
-
-void LowLevelLogging(LogSeverity severity, const std::string& tag,
- const std::string& message) {
- const int android_log_level = GetAndroidLogLevel(severity);
-#if !defined(TC3_DEBUG_LOGGING)
- if (android_log_level != ANDROID_LOG_ERROR &&
- android_log_level != ANDROID_LOG_FATAL) {
- return;
- }
-#endif
- __android_log_write(android_log_level, tag.c_str(), message.c_str());
-}
-
-} // namespace logging
-} // namespace libtextclassifier3
-
-#else // if defined(__ANDROID__)
-
-// Not on Android: implement LowLevelLogging to print to stderr (see below).
-namespace libtextclassifier3 {
-namespace logging {
-
-namespace {
-// Converts LogSeverity to human-readable text.
-const char *LogSeverityToString(LogSeverity severity) {
- switch (severity) {
- case INFO:
- return "INFO";
- case WARNING:
- return "WARNING";
- case ERROR:
- return "ERROR";
- case FATAL:
- return "FATAL";
- default:
- return "UNKNOWN";
- }
-}
-} // namespace
-
-void LowLevelLogging(LogSeverity severity, const std::string &tag,
- const std::string &message) {
- fprintf(stderr, "[%s] %s : %s\n", LogSeverityToString(severity), tag.c_str(),
- message.c_str());
- fflush(stderr);
-}
-
-} // namespace logging
-} // namespace libtextclassifier3
-
-#endif // if defined(__ANDROID__)
diff --git a/utils/base/macros.h b/utils/base/macros.h
deleted file mode 100644
index 6739c0b..0000000
--- a/utils/base/macros.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_MACROS_H_
-#define LIBTEXTCLASSIFIER_UTILS_BASE_MACROS_H_
-
-#include "utils/base/config.h"
-
-namespace libtextclassifier3 {
-
-#if LANG_CXX11
-#define TC3_DISALLOW_COPY_AND_ASSIGN(TypeName) \
- TypeName(const TypeName &) = delete; \
- TypeName &operator=(const TypeName &) = delete
-#else // C++98 case follows
-
-// Note that these C++98 implementations cannot completely disallow copying,
-// as members and friends can still accidentally make elided copies without
-// triggering a linker error.
-#define TC3_DISALLOW_COPY_AND_ASSIGN(TypeName) \
- TypeName(const TypeName &); \
- TypeName &operator=(const TypeName &)
-#endif // LANG_CXX11
-
-// The TC3_FALLTHROUGH_INTENDED macro can be used to annotate implicit
-// fall-through between switch labels:
-//
-// switch (x) {
-// case 40:
-// case 41:
-// if (truth_is_out_there) {
-// ++x;
-// TC3_FALLTHROUGH_INTENDED; // Use instead of/along with annotations in
-// // comments.
-// } else {
-// return x;
-// }
-// case 42:
-// ...
-//
-// As shown in the example above, the TC3_FALLTHROUGH_INTENDED macro should be
-// followed by a semicolon. It is designed to mimic control-flow statements
-// like 'break;', so it can be placed in most places where 'break;' can, but
-// only if there are no statements on the execution path between it and the
-// next switch label.
-//
-// When compiled with clang in C++11 mode, the TC3_FALLTHROUGH_INTENDED macro
-// is expanded to [[clang::fallthrough]] attribute, which is analysed when
-// performing switch labels fall-through diagnostic ('-Wimplicit-fallthrough').
-// See clang documentation on language extensions for details:
-// http://clang.llvm.org/docs/AttributeReference.html#fallthrough-clang-fallthrough
-//
-// When used with unsupported compilers, the TC3_FALLTHROUGH_INTENDED macro has
-// no effect on diagnostics.
-//
-// In either case this macro has no effect on runtime behavior and performance
-// of code.
-#if defined(__clang__) && defined(__has_warning)
-#if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough")
-#define TC3_FALLTHROUGH_INTENDED [[clang::fallthrough]]
-#endif
-#elif defined(__GNUC__) && __GNUC__ >= 7
-#define TC3_FALLTHROUGH_INTENDED [[gnu::fallthrough]]
-#endif
-
-#ifndef TC3_FALLTHROUGH_INTENDED
-#define TC3_FALLTHROUGH_INTENDED \
- do { \
- } while (0)
-#endif
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_BASE_MACROS_H_
diff --git a/utils/calendar/CalendarJavaIcuLocalTest.java b/utils/calendar/CalendarJavaIcuLocalTest.java
deleted file mode 100644
index 9beb36e..0000000
--- a/utils/calendar/CalendarJavaIcuLocalTest.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.android.textclassifier.utils.calendar;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import com.google.thirdparty.robolectric.GoogleRobolectricTestRunner;
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-/** This is just a launcher of the tests because we need a valid JNIEnv in the C++ test. */
-@RunWith(GoogleRobolectricTestRunner.class)
-public class CalendarJavaIcuLocalTest {
-
- @Before
- public void setUp() throws Exception {
- System.loadLibrary("calendar-javaicu_test-lib");
- }
-
- private native boolean testsMain();
-
- @Test
- public void testNative() {
- assertThat(testsMain()).isTrue();
- }
-}
diff --git a/utils/calendar/CalendarJavaIcuTest.java b/utils/calendar/CalendarJavaIcuTest.java
deleted file mode 100644
index ab1f00a..0000000
--- a/utils/calendar/CalendarJavaIcuTest.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.android.textclassifier.utils.calendar;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-/** This is just a launcher of the tests because we need a valid JNIEnv in the C++ test. */
-@RunWith(JUnit4.class)
-public class CalendarJavaIcuTest {
-
- @Before
- public void setUp() throws Exception {
- System.loadLibrary("calendar-javaicu_test-lib");
- }
-
- private native boolean testsMain();
-
- @Test
- public void testNative() {
- assertThat(testsMain()).isTrue();
- }
-}
diff --git a/utils/calendar/calendar-common.h b/utils/calendar/calendar-common.h
deleted file mode 100644
index 5c91e22..0000000
--- a/utils/calendar/calendar-common.h
+++ /dev/null
@@ -1,362 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_COMMON_H_
-#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_COMMON_H_
-
-#include "annotator/types.h"
-#include "utils/base/integral_types.h"
-#include "utils/base/logging.h"
-#include "utils/base/macros.h"
-
-namespace libtextclassifier3 {
-namespace calendar {
-
-// Macro to reduce the amount of boilerplate needed for propagating errors.
-#define TC3_CALENDAR_CHECK(EXPR) \
- if (!(EXPR)) { \
- return false; \
- }
-
-// An implementation of CalendarLib that is independent of the particular
-// calendar implementation used (implementation type is passed as template
-// argument).
-template <class TCalendar>
-class CalendarLibTempl {
- public:
- bool InterpretParseData(const DateParseData& parse_data,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale,
- TCalendar* calendar,
- DatetimeGranularity* granularity) const;
-
- DatetimeGranularity GetGranularity(const DateParseData& data) const;
-
- private:
- // Adjusts the calendar's time instant according to a relative date reference
- // in the parsed data.
- bool ApplyRelationField(const DateParseData& parse_data,
- TCalendar* calendar) const;
-
- // Round the time instant's precision down to the given granularity.
- bool RoundToGranularity(DatetimeGranularity granularity,
- TCalendar* calendar) const;
-
- // Adjusts time in steps of relation_type, by distance steps.
- // For example:
- // - Adjusting by -2 MONTHS will return the beginning of the 1st
- // two weeks ago.
- // - Adjusting by +4 Wednesdays will return the beginning of the next
- // Wednesday at least 4 weeks from now.
- // If allow_today is true, the same day of the week may be kept
- // if it already matches the relation type.
- bool AdjustByRelation(DateParseData::RelationType relation_type, int distance,
- bool allow_today, TCalendar* calendar) const;
-};
-
-template <class TCalendar>
-bool CalendarLibTempl<TCalendar>::InterpretParseData(
- const DateParseData& parse_data, int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& reference_locale,
- TCalendar* calendar, DatetimeGranularity* granularity) const {
- TC3_CALENDAR_CHECK(calendar->Initialize(reference_timezone, reference_locale,
- reference_time_ms_utc))
-
- bool should_round_to_granularity = true;
- *granularity = GetGranularity(parse_data);
-
- // Apply each of the parsed fields in order of increasing granularity.
- static const int64 kMillisInHour = 1000 * 60 * 60;
- if (parse_data.field_set_mask & DateParseData::Fields::ZONE_OFFSET_FIELD) {
- TC3_CALENDAR_CHECK(
- calendar->SetZoneOffset(parse_data.zone_offset * kMillisInHour))
- }
- if (parse_data.field_set_mask & DateParseData::Fields::DST_OFFSET_FIELD) {
- TC3_CALENDAR_CHECK(
- calendar->SetDstOffset(parse_data.dst_offset * kMillisInHour))
- }
- if (parse_data.field_set_mask & DateParseData::Fields::RELATION_FIELD) {
- TC3_CALENDAR_CHECK(ApplyRelationField(parse_data, calendar));
- // Don't round to the granularity for relative expressions that specify the
- // distance. So that, e.g. "in 2 hours" when it's 8:35:03 will result in
- // 10:35:03.
- if (parse_data.field_set_mask &
- DateParseData::Fields::RELATION_DISTANCE_FIELD) {
- should_round_to_granularity = false;
- }
- } else {
- // By default, the parsed time is interpreted to be on the reference day.
- // But a parsed date should have time 0:00:00 unless specified.
- TC3_CALENDAR_CHECK(calendar->SetHourOfDay(0))
- TC3_CALENDAR_CHECK(calendar->SetMinute(0))
- TC3_CALENDAR_CHECK(calendar->SetSecond(0))
- TC3_CALENDAR_CHECK(calendar->SetMillisecond(0))
- }
- if (parse_data.field_set_mask & DateParseData::Fields::YEAR_FIELD) {
- TC3_CALENDAR_CHECK(calendar->SetYear(parse_data.year))
- }
- if (parse_data.field_set_mask & DateParseData::Fields::MONTH_FIELD) {
- // ICU has months starting at 0, Java and Datetime parser at 1, so we
- // need to subtract 1.
- TC3_CALENDAR_CHECK(calendar->SetMonth(parse_data.month - 1))
- }
- if (parse_data.field_set_mask & DateParseData::Fields::DAY_FIELD) {
- TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(parse_data.day_of_month))
- }
- if (parse_data.field_set_mask & DateParseData::Fields::HOUR_FIELD) {
- if (parse_data.field_set_mask & DateParseData::Fields::AMPM_FIELD &&
- parse_data.ampm == DateParseData::AMPM::PM && parse_data.hour < 12) {
- TC3_CALENDAR_CHECK(calendar->SetHourOfDay(parse_data.hour + 12))
- } else if (parse_data.ampm == DateParseData::AMPM::AM &&
- parse_data.hour == 12) {
- // Do nothing. 12am == 0.
- } else {
- TC3_CALENDAR_CHECK(calendar->SetHourOfDay(parse_data.hour))
- }
- }
- if (parse_data.field_set_mask & DateParseData::Fields::MINUTE_FIELD) {
- TC3_CALENDAR_CHECK(calendar->SetMinute(parse_data.minute))
- }
- if (parse_data.field_set_mask & DateParseData::Fields::SECOND_FIELD) {
- TC3_CALENDAR_CHECK(calendar->SetSecond(parse_data.second))
- }
-
- if (should_round_to_granularity) {
- TC3_CALENDAR_CHECK(RoundToGranularity(*granularity, calendar))
- }
- return true;
-}
-
-template <class TCalendar>
-bool CalendarLibTempl<TCalendar>::ApplyRelationField(
- const DateParseData& parse_data, TCalendar* calendar) const {
- constexpr int relation_type_mask = DateParseData::Fields::RELATION_TYPE_FIELD;
- constexpr int relation_distance_mask =
- DateParseData::Fields::RELATION_DISTANCE_FIELD;
- switch (parse_data.relation) {
- case DateParseData::Relation::UNSPECIFIED:
- TC3_LOG(ERROR) << "UNSPECIFIED RelationType.";
- return false;
- case DateParseData::Relation::NEXT:
- if (parse_data.field_set_mask & relation_type_mask) {
- TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type,
- /*distance=*/1,
- /*allow_today=*/false, calendar));
- }
- return true;
- case DateParseData::Relation::NEXT_OR_SAME:
- if (parse_data.field_set_mask & relation_type_mask) {
- TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type,
- /*distance=*/1,
- /*allow_today=*/true, calendar))
- }
- return true;
- case DateParseData::Relation::LAST:
- if (parse_data.field_set_mask & relation_type_mask) {
- TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type,
- /*distance=*/-1,
- /*allow_today=*/false, calendar))
- }
- return true;
- case DateParseData::Relation::NOW:
- return true; // NOOP
- case DateParseData::Relation::TOMORROW:
- TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(1));
- return true;
- case DateParseData::Relation::YESTERDAY:
- TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(-1));
- return true;
- case DateParseData::Relation::PAST:
- if ((parse_data.field_set_mask & relation_type_mask) &&
- (parse_data.field_set_mask & relation_distance_mask)) {
- TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type,
- -parse_data.relation_distance,
- /*allow_today=*/false, calendar))
- }
- return true;
- case DateParseData::Relation::FUTURE:
- if ((parse_data.field_set_mask & relation_type_mask) &&
- (parse_data.field_set_mask & relation_distance_mask)) {
- TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type,
- parse_data.relation_distance,
- /*allow_today=*/false, calendar))
- }
- return true;
- }
- return false;
-}
-
-template <class TCalendar>
-bool CalendarLibTempl<TCalendar>::RoundToGranularity(
- DatetimeGranularity granularity, TCalendar* calendar) const {
- // Force recomputation before doing the rounding.
- int unused;
- TC3_CALENDAR_CHECK(calendar->GetDayOfWeek(&unused));
-
- switch (granularity) {
- case GRANULARITY_YEAR:
- TC3_CALENDAR_CHECK(calendar->SetMonth(0));
- TC3_FALLTHROUGH_INTENDED;
- case GRANULARITY_MONTH:
- TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(1));
- TC3_FALLTHROUGH_INTENDED;
- case GRANULARITY_DAY:
- TC3_CALENDAR_CHECK(calendar->SetHourOfDay(0));
- TC3_FALLTHROUGH_INTENDED;
- case GRANULARITY_HOUR:
- TC3_CALENDAR_CHECK(calendar->SetMinute(0));
- TC3_FALLTHROUGH_INTENDED;
- case GRANULARITY_MINUTE:
- TC3_CALENDAR_CHECK(calendar->SetSecond(0));
- break;
-
- case GRANULARITY_WEEK:
- int first_day_of_week;
- TC3_CALENDAR_CHECK(calendar->GetFirstDayOfWeek(&first_day_of_week));
- TC3_CALENDAR_CHECK(calendar->SetDayOfWeek(first_day_of_week));
- TC3_CALENDAR_CHECK(calendar->SetHourOfDay(0));
- TC3_CALENDAR_CHECK(calendar->SetMinute(0));
- TC3_CALENDAR_CHECK(calendar->SetSecond(0));
- break;
-
- case GRANULARITY_UNKNOWN:
- case GRANULARITY_SECOND:
- break;
- }
- return true;
-}
-
-template <class TCalendar>
-bool CalendarLibTempl<TCalendar>::AdjustByRelation(
- DateParseData::RelationType relation_type, int distance, bool allow_today,
- TCalendar* calendar) const {
- const int distance_sign = distance < 0 ? -1 : 1;
- switch (relation_type) {
- case DateParseData::RelationType::MONDAY:
- case DateParseData::RelationType::TUESDAY:
- case DateParseData::RelationType::WEDNESDAY:
- case DateParseData::RelationType::THURSDAY:
- case DateParseData::RelationType::FRIDAY:
- case DateParseData::RelationType::SATURDAY:
- case DateParseData::RelationType::SUNDAY:
- if (!allow_today) {
- // If we're not including the same day as the reference, skip it.
- TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance_sign))
- }
- // Keep walking back until we hit the desired day of the week.
- while (distance != 0) {
- int day_of_week;
- TC3_CALENDAR_CHECK(calendar->GetDayOfWeek(&day_of_week))
- if (day_of_week == static_cast<int>(relation_type)) {
- distance += -distance_sign;
- if (distance == 0) break;
- }
- TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance_sign))
- }
- return true;
- case DateParseData::RelationType::SECOND:
- TC3_CALENDAR_CHECK(calendar->AddSecond(distance));
- return true;
- case DateParseData::RelationType::MINUTE:
- TC3_CALENDAR_CHECK(calendar->AddMinute(distance));
- return true;
- case DateParseData::RelationType::HOUR:
- TC3_CALENDAR_CHECK(calendar->AddHourOfDay(distance));
- return true;
- case DateParseData::RelationType::DAY:
- TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance));
- return true;
- case DateParseData::RelationType::WEEK:
- TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(7 * distance))
- TC3_CALENDAR_CHECK(calendar->SetDayOfWeek(1))
- return true;
- case DateParseData::RelationType::MONTH:
- TC3_CALENDAR_CHECK(calendar->AddMonth(distance))
- TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(1))
- return true;
- case DateParseData::RelationType::YEAR:
- TC3_CALENDAR_CHECK(calendar->AddYear(distance))
- TC3_CALENDAR_CHECK(calendar->SetDayOfYear(1))
- return true;
- default:
- TC3_LOG(ERROR) << "Unknown relation type: "
- << static_cast<int>(relation_type);
- return false;
- }
- return false;
-}
-
-template <class TCalendar>
-DatetimeGranularity CalendarLibTempl<TCalendar>::GetGranularity(
- const DateParseData& data) const {
- DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_YEAR;
- if ((data.field_set_mask & DateParseData::YEAR_FIELD) ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::YEAR))) {
- granularity = DatetimeGranularity::GRANULARITY_YEAR;
- }
- if ((data.field_set_mask & DateParseData::MONTH_FIELD) ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::MONTH))) {
- granularity = DatetimeGranularity::GRANULARITY_MONTH;
- }
- if (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::WEEK)) {
- granularity = DatetimeGranularity::GRANULARITY_WEEK;
- }
- if (data.field_set_mask & DateParseData::DAY_FIELD ||
- (data.field_set_mask & DateParseData::RELATION_FIELD &&
- (data.relation == DateParseData::Relation::NOW ||
- data.relation == DateParseData::Relation::TOMORROW ||
- data.relation == DateParseData::Relation::YESTERDAY)) ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::MONDAY ||
- data.relation_type == DateParseData::RelationType::TUESDAY ||
- data.relation_type == DateParseData::RelationType::WEDNESDAY ||
- data.relation_type == DateParseData::RelationType::THURSDAY ||
- data.relation_type == DateParseData::RelationType::FRIDAY ||
- data.relation_type == DateParseData::RelationType::SATURDAY ||
- data.relation_type == DateParseData::RelationType::SUNDAY ||
- data.relation_type == DateParseData::RelationType::DAY))) {
- granularity = DatetimeGranularity::GRANULARITY_DAY;
- }
- if (data.field_set_mask & DateParseData::HOUR_FIELD ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::HOUR))) {
- granularity = DatetimeGranularity::GRANULARITY_HOUR;
- }
- if (data.field_set_mask & DateParseData::MINUTE_FIELD ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- data.relation_type == DateParseData::RelationType::MINUTE)) {
- granularity = DatetimeGranularity::GRANULARITY_MINUTE;
- }
- if (data.field_set_mask & DateParseData::SECOND_FIELD ||
- (data.field_set_mask & DateParseData::RELATION_TYPE_FIELD &&
- (data.relation_type == DateParseData::RelationType::SECOND))) {
- granularity = DatetimeGranularity::GRANULARITY_SECOND;
- }
-
- return granularity;
-}
-
-}; // namespace calendar
-
-#undef TC3_CALENDAR_CHECK
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_COMMON_H_
diff --git a/utils/calendar/calendar-javaicu.cc b/utils/calendar/calendar-javaicu.cc
deleted file mode 100644
index ac09979..0000000
--- a/utils/calendar/calendar-javaicu.cc
+++ /dev/null
@@ -1,200 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/calendar/calendar-javaicu.h"
-
-#include "annotator/types.h"
-#include "utils/java/scoped_local_ref.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-// Generic version of icu::Calendar::add with error checking.
-bool CalendarAdd(JniCache* jni_cache, JNIEnv* jenv, jobject calendar,
- jint field, jint value) {
- jenv->CallVoidMethod(calendar, jni_cache->calendar_add, field, value);
- return !jni_cache->ExceptionCheckAndClear();
-}
-
-// Generic version of icu::Calendar::get with error checking.
-bool CalendarGet(JniCache* jni_cache, JNIEnv* jenv, jobject calendar,
- jint field, jint* value) {
- *value = jenv->CallIntMethod(calendar, jni_cache->calendar_get, field);
- return !jni_cache->ExceptionCheckAndClear();
-}
-
-// Generic version of icu::Calendar::set with error checking.
-bool CalendarSet(JniCache* jni_cache, JNIEnv* jenv, jobject calendar,
- jint field, jint value) {
- jenv->CallVoidMethod(calendar, jni_cache->calendar_set, field, value);
- return !jni_cache->ExceptionCheckAndClear();
-}
-
-// Extracts the first tag from a BCP47 tag (e.g. "en" for "en-US").
-std::string GetFirstBcp47Tag(const std::string& tag) {
- for (size_t i = 0; i < tag.size(); ++i) {
- if (tag[i] == '_' || tag[i] == '-') {
- return std::string(tag, 0, i);
- }
- }
- return tag;
-}
-
-} // anonymous namespace
-
-Calendar::Calendar(JniCache* jni_cache)
- : jni_cache_(jni_cache),
- jenv_(jni_cache_ ? jni_cache->GetEnv() : nullptr) {}
-
-bool Calendar::Initialize(const std::string& time_zone,
- const std::string& locale, int64 time_ms_utc) {
- if (!jni_cache_ || !jenv_) {
- TC3_LOG(ERROR) << "Initialize without env";
- return false;
- }
-
- // We'll assume the day indices match later on, so verify it here.
- if (jni_cache_->calendar_sunday !=
- static_cast<int>(DateParseData::RelationType::SUNDAY) ||
- jni_cache_->calendar_monday !=
- static_cast<int>(DateParseData::RelationType::MONDAY) ||
- jni_cache_->calendar_tuesday !=
- static_cast<int>(DateParseData::RelationType::TUESDAY) ||
- jni_cache_->calendar_wednesday !=
- static_cast<int>(DateParseData::RelationType::WEDNESDAY) ||
- jni_cache_->calendar_thursday !=
- static_cast<int>(DateParseData::RelationType::THURSDAY) ||
- jni_cache_->calendar_friday !=
- static_cast<int>(DateParseData::RelationType::FRIDAY) ||
- jni_cache_->calendar_saturday !=
- static_cast<int>(DateParseData::RelationType::SATURDAY)) {
- TC3_LOG(ERROR) << "day of the week indices mismatch";
- return false;
- }
-
- // Get the time zone.
- ScopedLocalRef<jstring> java_time_zone_str(
- jenv_->NewStringUTF(time_zone.c_str()));
- ScopedLocalRef<jobject> java_time_zone(jenv_->CallStaticObjectMethod(
- jni_cache_->timezone_class.get(), jni_cache_->timezone_get_timezone,
- java_time_zone_str.get()));
- if (jni_cache_->ExceptionCheckAndClear() || !java_time_zone) {
- TC3_LOG(ERROR) << "failed to get timezone";
- return false;
- }
-
- // Get the locale.
- ScopedLocalRef<jobject> java_locale;
- if (jni_cache_->locale_for_language_tag) {
- // API level 21+, we can actually parse language tags.
- ScopedLocalRef<jstring> java_locale_str(
- jenv_->NewStringUTF(locale.c_str()));
- java_locale.reset(jenv_->CallStaticObjectMethod(
- jni_cache_->locale_class.get(), jni_cache_->locale_for_language_tag,
- java_locale_str.get()));
- } else {
- // API level <21. We can't parse tags, so we just use the language.
- ScopedLocalRef<jstring> java_language_str(
- jenv_->NewStringUTF(GetFirstBcp47Tag(locale).c_str()));
- java_locale.reset(jenv_->NewObject(jni_cache_->locale_class.get(),
- jni_cache_->locale_init_string,
- java_language_str.get()));
- }
- if (jni_cache_->ExceptionCheckAndClear() || !java_locale) {
- TC3_LOG(ERROR) << "failed to get locale";
- return false;
- }
-
- // Get the calendar.
- calendar_.reset(jenv_->CallStaticObjectMethod(
- jni_cache_->calendar_class.get(), jni_cache_->calendar_get_instance,
- java_time_zone.get(), java_locale.get()));
- if (jni_cache_->ExceptionCheckAndClear() || !calendar_) {
- TC3_LOG(ERROR) << "failed to get calendar";
- return false;
- }
-
- // Set the time.
- jenv_->CallVoidMethod(calendar_.get(),
- jni_cache_->calendar_set_time_in_millis, time_ms_utc);
- if (jni_cache_->ExceptionCheckAndClear()) {
- TC3_LOG(ERROR) << "failed to set time";
- return false;
- }
- return true;
-}
-
-bool Calendar::GetFirstDayOfWeek(int* value) const {
- if (!jni_cache_ || !jenv_ || !calendar_) return false;
- *value = jenv_->CallIntMethod(calendar_.get(),
- jni_cache_->calendar_get_first_day_of_week);
- return !jni_cache_->ExceptionCheckAndClear();
-}
-
-bool Calendar::GetTimeInMillis(int64* value) const {
- if (!jni_cache_ || !jenv_ || !calendar_) return false;
- *value = jenv_->CallLongMethod(calendar_.get(),
- jni_cache_->calendar_get_time_in_millis);
- return !jni_cache_->ExceptionCheckAndClear();
-}
-
-CalendarLib::CalendarLib() {
- TC3_LOG(FATAL) << "Java ICU CalendarLib must be initialized with a JniCache.";
-}
-
-CalendarLib::CalendarLib(const std::shared_ptr<JniCache>& jni_cache)
- : jni_cache_(jni_cache) {}
-
-// Below is the boilerplate code for implementing the specialisations of
-// get/set/add for the various field types.
-#define TC3_DEFINE_FIELD_ACCESSOR(NAME, FIELD, KIND, TYPE) \
- bool Calendar::KIND##NAME(TYPE value) const { \
- if (!jni_cache_ || !jenv_ || !calendar_) return false; \
- return Calendar##KIND(jni_cache_, jenv_, calendar_.get(), \
- jni_cache_->calendar_##FIELD, value); \
- }
-#define TC3_DEFINE_ADD(NAME, CONST) \
- TC3_DEFINE_FIELD_ACCESSOR(NAME, CONST, Add, int)
-#define TC3_DEFINE_SET(NAME, CONST) \
- TC3_DEFINE_FIELD_ACCESSOR(NAME, CONST, Set, int)
-#define TC3_DEFINE_GET(NAME, CONST) \
- TC3_DEFINE_FIELD_ACCESSOR(NAME, CONST, Get, int*)
-
-TC3_DEFINE_ADD(Second, second)
-TC3_DEFINE_ADD(Minute, minute)
-TC3_DEFINE_ADD(HourOfDay, hour_of_day)
-TC3_DEFINE_ADD(DayOfMonth, day_of_month)
-TC3_DEFINE_ADD(Year, year)
-TC3_DEFINE_ADD(Month, month)
-TC3_DEFINE_GET(DayOfWeek, day_of_week)
-TC3_DEFINE_SET(ZoneOffset, zone_offset)
-TC3_DEFINE_SET(DstOffset, dst_offset)
-TC3_DEFINE_SET(Year, year)
-TC3_DEFINE_SET(Month, month)
-TC3_DEFINE_SET(DayOfYear, day_of_year)
-TC3_DEFINE_SET(DayOfMonth, day_of_month)
-TC3_DEFINE_SET(DayOfWeek, day_of_week)
-TC3_DEFINE_SET(HourOfDay, hour_of_day)
-TC3_DEFINE_SET(Minute, minute)
-TC3_DEFINE_SET(Second, second)
-TC3_DEFINE_SET(Millisecond, millisecond)
-
-#undef TC3_DEFINE_FIELD_ACCESSOR
-#undef TC3_DEFINE_ADD
-#undef TC3_DEFINE_SET
-#undef TC3_DEFINE_GET
-
-} // namespace libtextclassifier3
diff --git a/utils/calendar/calendar-javaicu.h b/utils/calendar/calendar-javaicu.h
deleted file mode 100644
index 02673cc..0000000
--- a/utils/calendar/calendar-javaicu.h
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_JAVAICU_H_
-#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_JAVAICU_H_
-
-#include <jni.h>
-#include <memory>
-#include <string>
-
-#include "annotator/types.h"
-#include "utils/base/integral_types.h"
-#include "utils/calendar/calendar-common.h"
-#include "utils/java/jni-cache.h"
-#include "utils/java/scoped_local_ref.h"
-
-namespace libtextclassifier3 {
-
-class Calendar {
- public:
- explicit Calendar(JniCache* jni_cache);
- bool Initialize(const std::string& time_zone, const std::string& locale,
- int64 time_ms_utc);
- bool AddSecond(int value) const;
- bool AddMinute(int value) const;
- bool AddHourOfDay(int value) const;
- bool AddDayOfMonth(int value) const;
- bool AddYear(int value) const;
- bool AddMonth(int value) const;
- bool GetDayOfWeek(int* value) const;
- bool GetFirstDayOfWeek(int* value) const;
- bool GetTimeInMillis(int64* value) const;
- bool SetZoneOffset(int value) const;
- bool SetDstOffset(int value) const;
- bool SetYear(int value) const;
- bool SetMonth(int value) const;
- bool SetDayOfYear(int value) const;
- bool SetDayOfMonth(int value) const;
- bool SetDayOfWeek(int value) const;
- bool SetHourOfDay(int value) const;
- bool SetMinute(int value) const;
- bool SetSecond(int value) const;
- bool SetMillisecond(int value) const;
-
- private:
- JniCache* jni_cache_;
- JNIEnv* jenv_;
- ScopedLocalRef<jobject> calendar_;
-};
-
-class CalendarLib {
- public:
- CalendarLib();
- explicit CalendarLib(const std::shared_ptr<JniCache>& jni_cache);
-
- // Returns false (dummy version).
- bool InterpretParseData(const DateParseData& parse_data,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale,
- int64* interpreted_time_ms_utc,
- DatetimeGranularity* granularity) const {
- Calendar calendar(jni_cache_.get());
- if (!impl_.InterpretParseData(parse_data, reference_time_ms_utc,
- reference_timezone, reference_locale,
- &calendar, granularity)) {
- return false;
- }
- return calendar.GetTimeInMillis(interpreted_time_ms_utc);
- }
-
- DatetimeGranularity GetGranularity(const DateParseData& data) const {
- return impl_.GetGranularity(data);
- }
-
- private:
- std::shared_ptr<JniCache> jni_cache_;
- calendar::CalendarLibTempl<Calendar> impl_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_JAVAICU_H_
diff --git a/utils/calendar/calendar.h b/utils/calendar/calendar.h
deleted file mode 100644
index 99b137f..0000000
--- a/utils/calendar/calendar.h
+++ /dev/null
@@ -1,23 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_H_
-#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_H_
-
-#include "utils/calendar/calendar-javaicu.h"
-#define INIT_CALENDARLIB_FOR_TESTING(VAR) VAR(nullptr)
-
-#endif // LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_H_
diff --git a/utils/calendar/calendar_test-include.cc b/utils/calendar/calendar_test-include.cc
deleted file mode 100644
index 70520a2..0000000
--- a/utils/calendar/calendar_test-include.cc
+++ /dev/null
@@ -1,309 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/calendar/calendar_test-include.h"
-
-namespace libtextclassifier3 {
-namespace test_internal {
-
-TEST_F(CalendarTest, Interface) {
- int64 time;
- DatetimeGranularity granularity;
- std::string timezone;
- bool result = calendarlib_.InterpretParseData(
- DateParseData{/*field_set_mask=*/0, /*year=*/0, /*month=*/0,
- /*day_of_month=*/0, /*hour=*/0, /*minute=*/0, /*second=*/0,
- /*ampm=*/static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0, /*dst_offset=*/0,
- static_cast<DateParseData::Relation>(0),
- static_cast<DateParseData::RelationType>(0),
- /*relation_distance=*/0},
- 0L, "Zurich", "en-CH", &time, &granularity);
- TC3_LOG(INFO) << result;
-}
-
-TEST_F(CalendarTest, SetsZeroTimeWhenNotRelative) {
- int64 time;
- DatetimeGranularity granularity;
- DateParseData data;
-
- data.year = 2018;
- data.field_set_mask = DateParseData::YEAR_FIELD;
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
- EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/1L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
- EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
-}
-
-TEST_F(CalendarTest, RoundingToGranularityBasic) {
- int64 time;
- DatetimeGranularity granularity;
- DateParseData data;
-
- data.year = 2018;
- data.field_set_mask = DateParseData::YEAR_FIELD;
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
- EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
-
- data.month = 4;
- data.field_set_mask |= DateParseData::MONTH_FIELD;
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
- EXPECT_EQ(time, 1522533600000L /* Apr 01 2018 00:00:00 */);
-
- data.day_of_month = 25;
- data.field_set_mask |= DateParseData::DAY_FIELD;
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
- EXPECT_EQ(time, 1524607200000L /* Apr 25 2018 00:00:00 */);
-
- data.hour = 9;
- data.field_set_mask |= DateParseData::HOUR_FIELD;
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
- EXPECT_EQ(time, 1524639600000L /* Apr 25 2018 09:00:00 */);
-
- data.minute = 33;
- data.field_set_mask |= DateParseData::MINUTE_FIELD;
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
- EXPECT_EQ(time, 1524641580000 /* Apr 25 2018 09:33:00 */);
-
- data.second = 59;
- data.field_set_mask |= DateParseData::SECOND_FIELD;
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
- EXPECT_EQ(time, 1524641639000 /* Apr 25 2018 09:33:59 */);
-}
-
-TEST_F(CalendarTest, RoundingToGranularityWeek) {
- int64 time;
- DatetimeGranularity granularity;
- // Prepare data structure that means: "next week"
- DateParseData data;
- data.field_set_mask =
- DateParseData::RELATION_FIELD | DateParseData::RELATION_TYPE_FIELD;
- data.relation = DateParseData::Relation::NEXT;
- data.relation_type = DateParseData::RelationType::WEEK;
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"de-CH", &time, &granularity));
- EXPECT_EQ(time, 342000000L /* Mon Jan 05 1970 00:00:00 */);
-
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- data,
- /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
- EXPECT_EQ(time, 255600000L /* Sun Jan 04 1970 00:00:00 */);
-}
-
-TEST_F(CalendarTest, RelativeTime) {
- const int field_mask = DateParseData::RELATION_FIELD |
- DateParseData::RELATION_TYPE_FIELD |
- DateParseData::RELATION_DISTANCE_FIELD;
- const int64 ref_time = 1524648839000L; /* 25 April 2018 09:33:59 */
- int64 time;
- DatetimeGranularity granularity;
-
- // Two Weds from now.
- const DateParseData future_wed_parse = {
- field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::FUTURE,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/2};
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- future_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
- EXPECT_EQ(time, 1525858439000L /* Wed May 09 2018 11:33:59 */);
- EXPECT_EQ(granularity, GRANULARITY_DAY);
-
- // Next Wed.
- const DateParseData next_wed_parse = {field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::NEXT,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/0};
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- next_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
- EXPECT_EQ(time, 1525253639000L /* Wed May 02 2018 11:33:59 */);
- EXPECT_EQ(granularity, GRANULARITY_DAY);
-
- // Same Wed.
- const DateParseData same_wed_parse = {field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::NEXT_OR_SAME,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/0};
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- same_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
- EXPECT_EQ(time, 1524648839000L /* Wed Apr 25 2018 11:33:59 */);
- EXPECT_EQ(granularity, GRANULARITY_DAY);
-
- // Previous Wed.
- const DateParseData last_wed_parse = {field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::LAST,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/0};
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- last_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
- EXPECT_EQ(time, 1524044039000L /* Wed Apr 18 2018 11:33:59 */);
- EXPECT_EQ(granularity, GRANULARITY_DAY);
-
- // Two Weds ago.
- const DateParseData past_wed_parse = {field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::PAST,
- DateParseData::RelationType::WEDNESDAY,
- /*relation_distance=*/2};
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- past_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
- EXPECT_EQ(time, 1523439239000L /* Wed Apr 11 2018 11:33:59 */);
- EXPECT_EQ(granularity, GRANULARITY_DAY);
-
- // In 3 hours.
- const DateParseData in_3_hours_parse = {
- field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- /*ampm=*/static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::FUTURE,
- DateParseData::RelationType::HOUR,
- /*relation_distance=*/3};
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- in_3_hours_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
- EXPECT_EQ(time, 1524659639000L /* Wed Apr 25 2018 14:33:59 */);
- EXPECT_EQ(granularity, GRANULARITY_HOUR);
-
- // In 5 minutes.
- const DateParseData in_5_minutes_parse = {
- field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- /*ampm=*/static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::FUTURE,
- DateParseData::RelationType::MINUTE,
- /*relation_distance=*/5};
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- in_5_minutes_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
- EXPECT_EQ(time, 1524649139000L /* Wed Apr 25 2018 14:33:59 */);
- EXPECT_EQ(granularity, GRANULARITY_MINUTE);
-
- // In 10 seconds.
- const DateParseData in_10_seconds_parse = {
- field_mask,
- /*year=*/0,
- /*month=*/0,
- /*day_of_month=*/0,
- /*hour=*/0,
- /*minute=*/0,
- /*second=*/0,
- /*ampm=*/static_cast<DateParseData::AMPM>(0),
- /*zone_offset=*/0,
- /*dst_offset=*/0,
- DateParseData::Relation::FUTURE,
- DateParseData::RelationType::SECOND,
- /*relation_distance=*/10};
- ASSERT_TRUE(calendarlib_.InterpretParseData(
- in_10_seconds_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
- EXPECT_EQ(time, 1524648849000L /* Wed Apr 25 2018 14:33:59 */);
- EXPECT_EQ(granularity, GRANULARITY_SECOND);
-}
-
-} // namespace test_internal
-} // namespace libtextclassifier3
diff --git a/utils/calendar/calendar_test-include.h b/utils/calendar/calendar_test-include.h
deleted file mode 100644
index 169a4ed..0000000
--- a/utils/calendar/calendar_test-include.h
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// This is a shared test between icu and javaicu calendar implementations.
-// It is meant to be #include'd.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
-#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
-
-#if defined TC3_CALENDAR_ICU
-#include "utils/calendar/calendar-icu.h"
-#define TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(VAR) VAR()
-#elif defined TC3_CALENDAR_JAVAICU
-#include <jni.h>
-extern JNIEnv* g_jenv;
-#define TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(VAR) \
- VAR(JniCache::Create(g_jenv))
-#include "utils/calendar/calendar-javaicu.h"
-#else
-#error Unsupported calendar implementation.
-#endif
-#include "utils/base/logging.h"
-
-#include "gtest/gtest.h"
-
-// This can get overridden in the javaicu version which needs to pass an JNIEnv*
-// argument to the constructor.
-#ifndef TC3_TESTING_CREATE_CALENDARLIB_INSTANCE
-
-#endif
-
-namespace libtextclassifier3 {
-namespace test_internal {
-
-class CalendarTest : public ::testing::Test {
- protected:
- CalendarTest() : TC3_TESTING_CREATE_CALENDARLIB_INSTANCE(calendarlib_) {}
- CalendarLib calendarlib_;
-};
-
-} // namespace test_internal
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_TEST_INCLUDE_H_
diff --git a/utils/flatbuffers.cc b/utils/flatbuffers.cc
deleted file mode 100644
index a4dbabd..0000000
--- a/utils/flatbuffers.cc
+++ /dev/null
@@ -1,421 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/flatbuffers.h"
-
-#include <vector>
-#include "utils/strings/numbers.h"
-#include "utils/variant.h"
-
-namespace libtextclassifier3 {
-namespace {
-bool CreateRepeatedField(
- const reflection::Schema* schema, const reflection::Type* type,
- std::unique_ptr<ReflectiveFlatbuffer::RepeatedField>* repeated_field) {
- switch (type->element()) {
- case reflection::Bool:
- repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<bool>);
- return true;
- case reflection::Int:
- repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<int>);
- return true;
- case reflection::Long:
- repeated_field->reset(
- new ReflectiveFlatbuffer::TypedRepeatedField<int64>);
- return true;
- case reflection::Float:
- repeated_field->reset(
- new ReflectiveFlatbuffer::TypedRepeatedField<float>);
- return true;
- case reflection::Double:
- repeated_field->reset(
- new ReflectiveFlatbuffer::TypedRepeatedField<double>);
- return true;
- case reflection::String:
- repeated_field->reset(
- new ReflectiveFlatbuffer::TypedRepeatedField<std::string>);
- return true;
- case reflection::Obj:
- repeated_field->reset(
- new ReflectiveFlatbuffer::TypedRepeatedField<ReflectiveFlatbuffer>(
- schema, type));
- return true;
- default:
- TC3_LOG(ERROR) << "Unsupported type: " << type->element();
- return false;
- }
-}
-} // namespace
-
-template <>
-const char* FlatbufferFileIdentifier<Model>() {
- return ModelIdentifier();
-}
-
-std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewRoot()
- const {
- if (!schema_->root_table()) {
- TC3_LOG(ERROR) << "No root table specified.";
- return nullptr;
- }
- return std::unique_ptr<ReflectiveFlatbuffer>(
- new ReflectiveFlatbuffer(schema_, schema_->root_table()));
-}
-
-std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewTable(
- StringPiece table_name) const {
- for (const reflection::Object* object : *schema_->objects()) {
- if (table_name.Equals(object->name()->str())) {
- return std::unique_ptr<ReflectiveFlatbuffer>(
- new ReflectiveFlatbuffer(schema_, object));
- }
- }
- return nullptr;
-}
-
-const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
- const StringPiece field_name) const {
- return type_->fields()->LookupByKey(field_name.data());
-}
-
-const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
- const FlatbufferField* field) const {
- // Lookup by name might be faster as the fields are sorted by name in the
- // schema data, so try that first.
- if (field->field_name() != nullptr) {
- return GetFieldOrNull(field->field_name()->str());
- }
- return GetFieldByOffsetOrNull(field->field_offset());
-}
-
-bool ReflectiveFlatbuffer::GetFieldWithParent(
- const FlatbufferFieldPath* field_path, ReflectiveFlatbuffer** parent,
- reflection::Field const** field) {
- const auto* path = field_path->field();
- if (path == nullptr || path->size() == 0) {
- return false;
- }
-
- for (int i = 0; i < path->size(); i++) {
- *parent = (i == 0 ? this : (*parent)->Mutable(*field));
- if (*parent == nullptr) {
- return false;
- }
- *field = (*parent)->GetFieldOrNull(path->Get(i));
- if (*field == nullptr) {
- return false;
- }
- }
-
- return true;
-}
-
-const reflection::Field* ReflectiveFlatbuffer::GetFieldByOffsetOrNull(
- const int field_offset) const {
- if (type_->fields() == nullptr) {
- return nullptr;
- }
- for (const reflection::Field* field : *type_->fields()) {
- if (field->offset() == field_offset) {
- return field;
- }
- }
- return nullptr;
-}
-
-bool ReflectiveFlatbuffer::IsMatchingType(const reflection::Field* field,
- const Variant& value) const {
- switch (field->type()->base_type()) {
- case reflection::Bool:
- return value.HasBool();
- case reflection::Int:
- return value.HasInt();
- case reflection::Long:
- return value.HasInt64();
- case reflection::Float:
- return value.HasFloat();
- case reflection::Double:
- return value.HasDouble();
- case reflection::String:
- return value.HasString();
- default:
- return false;
- }
-}
-
-bool ReflectiveFlatbuffer::ParseAndSet(const reflection::Field* field,
- const std::string& value) {
- switch (field->type()->base_type()) {
- case reflection::String:
- return Set(field, value);
- case reflection::Int: {
- int32 int_value;
- if (!ParseInt32(value.data(), &int_value)) {
- TC3_LOG(ERROR) << "Could not parse '" << value << "' as int32.";
- return false;
- }
- return Set(field, int_value);
- }
- case reflection::Long: {
- int64 int_value;
- if (!ParseInt64(value.data(), &int_value)) {
- TC3_LOG(ERROR) << "Could not parse '" << value << "' as int64.";
- return false;
- }
- return Set(field, int_value);
- }
- case reflection::Float: {
- double double_value;
- if (!ParseDouble(value.data(), &double_value)) {
- TC3_LOG(ERROR) << "Could not parse '" << value << "' as float.";
- return false;
- }
- return Set(field, static_cast<float>(double_value));
- }
- case reflection::Double: {
- double double_value;
- if (!ParseDouble(value.data(), &double_value)) {
- TC3_LOG(ERROR) << "Could not parse '" << value << "' as double.";
- return false;
- }
- return Set(field, double_value);
- }
- default:
- TC3_LOG(ERROR) << "Unhandled field type: " << field->type()->base_type();
- return false;
- }
-}
-
-bool ReflectiveFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path,
- const std::string& value) {
- ReflectiveFlatbuffer* parent;
- const reflection::Field* field;
- if (!GetFieldWithParent(path, &parent, &field)) {
- return false;
- }
- return parent->ParseAndSet(field, value);
-}
-
-ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
- const StringPiece field_name) {
- if (const reflection::Field* field = GetFieldOrNull(field_name)) {
- return Mutable(field);
- }
- TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
- return nullptr;
-}
-
-ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
- const reflection::Field* field) {
- if (field->type()->base_type() != reflection::Obj) {
- TC3_LOG(ERROR) << "Field is not of type Object.";
- return nullptr;
- }
- const auto entry = children_.find(field);
- if (entry != children_.end()) {
- return entry->second.get();
- }
- const auto it = children_.insert(
- /*hint=*/entry,
- std::make_pair(
- field,
- std::unique_ptr<ReflectiveFlatbuffer>(new ReflectiveFlatbuffer(
- schema_, schema_->objects()->Get(field->type()->index())))));
- return it->second.get();
-}
-
-ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
- StringPiece field_name) {
- if (const reflection::Field* field = GetFieldOrNull(field_name)) {
- return Repeated(field);
- }
- TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
- return nullptr;
-}
-
-ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
- const reflection::Field* field) {
- if (field->type()->base_type() != reflection::Vector) {
- TC3_LOG(ERROR) << "Field is not of type Vector.";
- return nullptr;
- }
-
- // If the repeated field was already set, return its instance.
- const auto entry = repeated_fields_.find(field);
- if (entry != repeated_fields_.end()) {
- return entry->second.get();
- }
-
- // Otherwise, create a new instance and store it.
- std::unique_ptr<RepeatedField> repeated_field;
- if (!CreateRepeatedField(schema_, field->type(), &repeated_field)) {
- TC3_LOG(ERROR) << "Could not create repeated field.";
- return nullptr;
- }
- const auto it = repeated_fields_.insert(
- /*hint=*/entry, std::make_pair(field, std::move(repeated_field)));
- return it->second.get();
-}
-
-flatbuffers::uoffset_t ReflectiveFlatbuffer::Serialize(
- flatbuffers::FlatBufferBuilder* builder) const {
- // Build all children before we can start with this table.
- std::vector<
- std::pair</* field vtable offset */ int,
- /* field data offset in buffer */ flatbuffers::uoffset_t>>
- offsets;
- offsets.reserve(children_.size() + repeated_fields_.size());
- for (const auto& it : children_) {
- offsets.push_back({it.first->offset(), it.second->Serialize(builder)});
- }
-
- // Create strings.
- for (const auto& it : fields_) {
- if (it.second.HasString()) {
- offsets.push_back({it.first->offset(),
- builder->CreateString(it.second.StringValue()).o});
- }
- }
-
- // Build the repeated fields.
- for (const auto& it : repeated_fields_) {
- offsets.push_back({it.first->offset(), it.second->Serialize(builder)});
- }
-
- // Build the table now.
- const flatbuffers::uoffset_t table_start = builder->StartTable();
-
- // Add scalar fields.
- for (const auto& it : fields_) {
- switch (it.second.GetType()) {
- case Variant::TYPE_BOOL_VALUE:
- builder->AddElement<uint8_t>(
- it.first->offset(), static_cast<uint8_t>(it.second.BoolValue()),
- static_cast<uint8_t>(it.first->default_integer()));
- continue;
- case Variant::TYPE_INT_VALUE:
- builder->AddElement<int32>(
- it.first->offset(), it.second.IntValue(),
- static_cast<int32>(it.first->default_integer()));
- continue;
- case Variant::TYPE_INT64_VALUE:
- builder->AddElement<int64>(it.first->offset(), it.second.Int64Value(),
- it.first->default_integer());
- continue;
- case Variant::TYPE_FLOAT_VALUE:
- builder->AddElement<float>(
- it.first->offset(), it.second.FloatValue(),
- static_cast<float>(it.first->default_real()));
- continue;
- case Variant::TYPE_DOUBLE_VALUE:
- builder->AddElement<double>(it.first->offset(), it.second.DoubleValue(),
- it.first->default_real());
- continue;
- default:
- continue;
- }
- }
-
- // Add strings, subtables and repeated fields.
- for (const auto& it : offsets) {
- builder->AddOffset(it.first, flatbuffers::Offset<void>(it.second));
- }
-
- return builder->EndTable(table_start);
-}
-
-std::string ReflectiveFlatbuffer::Serialize() const {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(flatbuffers::Offset<void>(Serialize(&builder)));
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
-bool ReflectiveFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
- // No fields to set.
- if (type_->fields() == nullptr) {
- return true;
- }
-
- for (const reflection::Field* field : *type_->fields()) {
- // Skip fields that are not explicitly set.
- if (!from->CheckField(field->offset())) {
- continue;
- }
- const reflection::BaseType type = field->type()->base_type();
- switch (type) {
- case reflection::Bool:
- Set<bool>(field, from->GetField<uint8_t>(field->offset(),
- field->default_integer()));
- break;
- case reflection::Int:
- Set<int32>(field, from->GetField<int32>(field->offset(),
- field->default_integer()));
- break;
- case reflection::Long:
- Set<int64>(field, from->GetField<int64>(field->offset(),
- field->default_integer()));
- break;
- case reflection::Float:
- Set<float>(field, from->GetField<float>(field->offset(),
- field->default_real()));
- break;
- case reflection::Double:
- Set<double>(field, from->GetField<double>(field->offset(),
- field->default_real()));
- break;
- case reflection::String:
- Set<std::string>(
- field, from->GetPointer<const flatbuffers::String*>(field->offset())
- ->str());
- break;
- case reflection::Obj:
- if (!Mutable(field)->MergeFrom(
- from->GetPointer<const flatbuffers::Table* const>(
- field->offset()))) {
- return false;
- }
- break;
- default:
- TC3_LOG(ERROR) << "Unsupported type: " << type;
- return false;
- }
- }
- return true;
-}
-
-bool ReflectiveFlatbuffer::MergeFromSerializedFlatbuffer(StringPiece from) {
- return MergeFrom(flatbuffers::GetAnyRoot(
- reinterpret_cast<const unsigned char*>(from.data())));
-}
-
-void ReflectiveFlatbuffer::AsFlatMap(
- const std::string& key_separator, const std::string& key_prefix,
- std::map<std::string, Variant>* result) const {
- // Add direct fields.
- for (auto it : fields_) {
- (*result)[key_prefix + it.first->name()->str()] = it.second;
- }
-
- // Add nested messages.
- for (auto& it : children_) {
- it.second->AsFlatMap(key_separator,
- key_prefix + it.first->name()->str() + key_separator,
- result);
- }
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/flatbuffers.fbs b/utils/flatbuffers.fbs
deleted file mode 100755
index 584b885..0000000
--- a/utils/flatbuffers.fbs
+++ /dev/null
@@ -1,32 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-// Specifies a field in a flatbuffer message.
-namespace libtextclassifier3;
-table FlatbufferField {
- // Name of the field.
- field_name:string;
-
- // Offset of the field
- field_offset:int;
-}
-
-// Specifies a (nested) field in a flatbuffer message.
-namespace libtextclassifier3;
-table FlatbufferFieldPath {
- field:[FlatbufferField];
-}
-
diff --git a/utils/flatbuffers.h b/utils/flatbuffers.h
deleted file mode 100644
index 76b095f..0000000
--- a/utils/flatbuffers.h
+++ /dev/null
@@ -1,337 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Utility functions for working with FlatBuffers.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
-#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
-
-#include <map>
-#include <memory>
-#include <string>
-
-#include "annotator/model_generated.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/variant.h"
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/reflection.h"
-
-namespace libtextclassifier3 {
-
-// Loads and interprets the buffer as 'FlatbufferMessage' and verifies its
-// integrity.
-template <typename FlatbufferMessage>
-const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) {
- const FlatbufferMessage* message =
- flatbuffers::GetRoot<FlatbufferMessage>(buffer);
- if (message == nullptr) {
- return nullptr;
- }
- flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer),
- size);
- if (message->Verify(verifier)) {
- return message;
- } else {
- return nullptr;
- }
-}
-
-// Same as above but takes string.
-template <typename FlatbufferMessage>
-const FlatbufferMessage* LoadAndVerifyFlatbuffer(const std::string& buffer) {
- return LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer.c_str(),
- buffer.size());
-}
-
-// Loads and interprets the buffer as 'FlatbufferMessage', verifies its
-// integrity and returns its mutable version.
-template <typename FlatbufferMessage>
-std::unique_ptr<typename FlatbufferMessage::NativeTableType>
-LoadAndVerifyMutableFlatbuffer(const void* buffer, int size) {
- const FlatbufferMessage* message =
- LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer, size);
- if (message == nullptr) {
- return nullptr;
- }
- return std::unique_ptr<typename FlatbufferMessage::NativeTableType>(
- message->UnPack());
-}
-
-// Same as above but takes string.
-template <typename FlatbufferMessage>
-std::unique_ptr<typename FlatbufferMessage::NativeTableType>
-LoadAndVerifyMutableFlatbuffer(const std::string& buffer) {
- return LoadAndVerifyMutableFlatbuffer<FlatbufferMessage>(buffer.c_str(),
- buffer.size());
-}
-
-template <typename FlatbufferMessage>
-const char* FlatbufferFileIdentifier() {
- return nullptr;
-}
-
-template <>
-const char* FlatbufferFileIdentifier<Model>();
-
-// Packs the mutable flatbuffer message to string.
-template <typename FlatbufferMessage>
-std::string PackFlatbuffer(
- const typename FlatbufferMessage::NativeTableType* mutable_message) {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(FlatbufferMessage::Pack(builder, mutable_message),
- FlatbufferFileIdentifier<FlatbufferMessage>());
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
-// A flatbuffer that can be built using flatbuffer reflection data of the
-// schema.
-// Normally, field information is hard-coded in code generated from a flatbuffer
-// schema. Here we lookup the necessary information for building a flatbuffer
-// from the provided reflection meta data.
-// When serializing a flatbuffer, the library requires that the sub messages
-// are already serialized, therefore we explicitly keep the field values and
-// serialize the message in (reverse) topological dependency order.
-class ReflectiveFlatbuffer {
- public:
- ReflectiveFlatbuffer(const reflection::Schema* schema,
- const reflection::Object* type)
- : schema_(schema), type_(type) {}
-
- // Encapsulates a repeated field.
- // Serves as a common base class for repeated fields.
- class RepeatedField {
- public:
- virtual ~RepeatedField() {}
-
- virtual flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const = 0;
- };
-
- // Represents a repeated field of particular type.
- template <typename T>
- class TypedRepeatedField : public RepeatedField {
- public:
- void Add(const T value) { items_.push_back(value); }
-
- flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const override {
- return builder->CreateVector(items_).o;
- }
-
- private:
- std::vector<T> items_;
- };
-
- // Specialization for strings.
- template <>
- class TypedRepeatedField<std::string> : public RepeatedField {
- public:
- void Add(const std::string& value) { items_.push_back(value); }
-
- flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const override {
- std::vector<flatbuffers::Offset<flatbuffers::String>> offsets(
- items_.size());
- for (int i = 0; i < items_.size(); i++) {
- offsets[i] = builder->CreateString(items_[i]);
- }
- return builder->CreateVector(offsets).o;
- }
-
- private:
- std::vector<std::string> items_;
- };
-
- // Specialization for repeated sub-messages.
- template <>
- class TypedRepeatedField<ReflectiveFlatbuffer> : public RepeatedField {
- public:
- TypedRepeatedField<ReflectiveFlatbuffer>(
- const reflection::Schema* const schema,
- const reflection::Type* const type)
- : schema_(schema), type_(type) {}
-
- ReflectiveFlatbuffer* Add() {
- items_.emplace_back(new ReflectiveFlatbuffer(
- schema_, schema_->objects()->Get(type_->index())));
- return items_.back().get();
- }
-
- flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const override {
- std::vector<flatbuffers::Offset<void>> offsets(items_.size());
- for (int i = 0; i < items_.size(); i++) {
- offsets[i] = items_[i]->Serialize(builder);
- }
- return builder->CreateVector(offsets).o;
- }
-
- private:
- const reflection::Schema* const schema_;
- const reflection::Type* const type_;
- std::vector<std::unique_ptr<ReflectiveFlatbuffer>> items_;
- };
-
- // Gets the field information for a field name, returns nullptr if the
- // field was not defined.
- const reflection::Field* GetFieldOrNull(const StringPiece field_name) const;
- const reflection::Field* GetFieldOrNull(const FlatbufferField* field) const;
- const reflection::Field* GetFieldByOffsetOrNull(const int field_offset) const;
-
- // Gets a nested field and the message it is defined on.
- bool GetFieldWithParent(const FlatbufferFieldPath* field_path,
- ReflectiveFlatbuffer** parent,
- reflection::Field const** field);
-
- // Checks whether a variant value type agrees with a field type.
- bool IsMatchingType(const reflection::Field* field,
- const Variant& value) const;
-
- // Sets a (primitive) field to a specific value.
- // Returns true if successful, and false if the field was not found or the
- // expected type doesn't match.
- template <typename T>
- bool Set(StringPiece field_name, T value) {
- if (const reflection::Field* field = GetFieldOrNull(field_name)) {
- return Set<T>(field, value);
- }
- return false;
- }
-
- // Sets a (primitive) field to a specific value.
- // Returns true if successful, and false if the expected type doesn't match.
- // Expects `field` to be non-null.
- template <typename T>
- bool Set(const reflection::Field* field, T value) {
- if (field == nullptr) {
- TC3_LOG(ERROR) << "Expected non-null field.";
- return false;
- }
- Variant variant_value(value);
- if (!IsMatchingType(field, variant_value)) {
- TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
- << "`, expected: " << field->type()->base_type()
- << ", got: " << variant_value.GetType();
- return false;
- }
- fields_[field] = variant_value;
- return true;
- }
-
- template <typename T>
- bool Set(const FlatbufferFieldPath* path, T value) {
- ReflectiveFlatbuffer* parent;
- const reflection::Field* field;
- if (!GetFieldWithParent(path, &parent, &field)) {
- return false;
- }
- return parent->Set<T>(field, value);
- }
-
- // Sets a (primitive) field to a specific value.
- // Parses the string value according to the field type.
- bool ParseAndSet(const reflection::Field* field, const std::string& value);
- bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value);
-
- // Gets the reflective flatbuffer for a table field.
- // Returns nullptr if the field was not found, or the field type was not a
- // table.
- ReflectiveFlatbuffer* Mutable(StringPiece field_name);
- ReflectiveFlatbuffer* Mutable(const reflection::Field* field);
-
- // Gets the reflective flatbuffer for a repeated field.
- // Returns nullptr if the field was not found, or the field type was not a
- // vector.
- RepeatedField* Repeated(StringPiece field_name);
- RepeatedField* Repeated(const reflection::Field* field);
-
- template <typename T>
- TypedRepeatedField<T>* Repeated(const reflection::Field* field) {
- return static_cast<TypedRepeatedField<T>*>(Repeated(field));
- }
-
- template <typename T>
- TypedRepeatedField<T>* Repeated(StringPiece field_name) {
- return static_cast<TypedRepeatedField<T>*>(Repeated(field_name));
- }
-
- // Serializes the flatbuffer.
- flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const;
- std::string Serialize() const;
-
- // Merges the fields from the given flatbuffer table into this flatbuffer.
- // Scalar fields will be overwritten, if present in `from`.
- // Embedded messages will be merged.
- bool MergeFrom(const flatbuffers::Table* from);
- bool MergeFromSerializedFlatbuffer(StringPiece from);
-
- // Flattens the flatbuffer as a flat map.
- // (Nested) fields names are joined by `key_separator`.
- std::map<std::string, Variant> AsFlatMap(
- const std::string& key_separator = ".") const {
- std::map<std::string, Variant> result;
- AsFlatMap(key_separator, /*key_prefix=*/"", &result);
- return result;
- }
-
- private:
- const reflection::Schema* const schema_;
- const reflection::Object* const type_;
-
- // Cached primitive fields (scalars and strings).
- std::map<const reflection::Field*, Variant> fields_;
-
- // Cached sub-messages.
- std::map<const reflection::Field*, std::unique_ptr<ReflectiveFlatbuffer>>
- children_;
-
- // Cached repeated fields.
- std::map<const reflection::Field*, std::unique_ptr<RepeatedField>>
- repeated_fields_;
-
- // Flattens the flatbuffer as a flat map.
- // (Nested) fields names are joined by `key_separator` and prefixed by
- // `key_prefix`.
- void AsFlatMap(const std::string& key_separator,
- const std::string& key_prefix,
- std::map<std::string, Variant>* result) const;
-};
-
-// A helper class to build flatbuffers based on schema reflection data.
-// Can be used to a `ReflectiveFlatbuffer` for the root message of the
-// schema, or any defined table via name.
-class ReflectiveFlatbufferBuilder {
- public:
- explicit ReflectiveFlatbufferBuilder(const reflection::Schema* schema)
- : schema_(schema) {}
-
- // Starts a new root table message.
- std::unique_ptr<ReflectiveFlatbuffer> NewRoot() const;
-
- // Starts a new table message. Returns nullptr if no table with given name is
- // found in the schema.
- std::unique_ptr<ReflectiveFlatbuffer> NewTable(
- const StringPiece table_name) const;
-
- private:
- const reflection::Schema* const schema_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
diff --git a/utils/flatbuffers_test.cc b/utils/flatbuffers_test.cc
deleted file mode 100644
index 348ca73..0000000
--- a/utils/flatbuffers_test.cc
+++ /dev/null
@@ -1,311 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <fstream>
-#include <map>
-#include <memory>
-#include <string>
-
-#include "utils/flatbuffers.h"
-#include "utils/flatbuffers_generated.h"
-#include "utils/flatbuffers_test_generated.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/reflection.h"
-#include "flatbuffers/reflection_generated.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-std::string GetTestMetadataPath() {
- return "flatbuffers_test.bfbs";
-}
-
-std::string LoadTestMetadata() {
- std::ifstream test_config_stream(GetTestMetadataPath());
- return std::string((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
-}
-
-TEST(FlatbuffersTest, PrimitiveFieldsAreCorrectlySet) {
- std::string metadata_buffer = LoadTestMetadata();
- ReflectiveFlatbufferBuilder reflective_builder(
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
-
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- EXPECT_TRUE(buffer != nullptr);
- EXPECT_TRUE(buffer->Set("an_int_field", 42));
- EXPECT_TRUE(buffer->Set("a_long_field", 84ll));
- EXPECT_TRUE(buffer->Set("a_bool_field", true));
- EXPECT_TRUE(buffer->Set("a_float_field", 1.f));
- EXPECT_TRUE(buffer->Set("a_double_field", 1.0));
-
- // Try to parse with the generated code.
- std::string serialized_entity_data = buffer->Serialize();
- std::unique_ptr<test::EntityDataT> entity_data =
- LoadAndVerifyMutableFlatbuffer<test::EntityData>(
- serialized_entity_data.data(), serialized_entity_data.size());
- EXPECT_TRUE(entity_data != nullptr);
- EXPECT_EQ(entity_data->an_int_field, 42);
- EXPECT_EQ(entity_data->a_long_field, 84);
- EXPECT_EQ(entity_data->a_bool_field, true);
- EXPECT_NEAR(entity_data->a_float_field, 1.f, 1e-4);
- EXPECT_NEAR(entity_data->a_double_field, 1.f, 1e-4);
-}
-
-TEST(FlatbuffersTest, HandlesUnknownFields) {
- std::string metadata_buffer = LoadTestMetadata();
- const reflection::Schema* schema =
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
- ReflectiveFlatbufferBuilder reflective_builder(schema);
-
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- EXPECT_TRUE(buffer != nullptr);
-
- // Add a field that is not known to the (statically generated) code.
- EXPECT_TRUE(buffer->Set("mystic", "this is an unknown field."));
-
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(flatbuffers::Offset<void>(buffer->Serialize(&builder)));
-
- // Try to read the field again.
- const flatbuffers::Table* extra =
- flatbuffers::GetAnyRoot(builder.GetBufferPointer());
- EXPECT_EQ(extra
- ->GetPointer<const flatbuffers::String*>(
- buffer->GetFieldOrNull("mystic")->offset())
- ->str(),
- "this is an unknown field.");
-}
-
-TEST(FlatbuffersTest, HandlesNestedFields) {
- std::string metadata_buffer = LoadTestMetadata();
- const reflection::Schema* schema =
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
- ReflectiveFlatbufferBuilder reflective_builder(schema);
-
- FlatbufferFieldPathT path;
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_name = "flight_number";
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_name = "carrier_code";
- flatbuffers::FlatBufferBuilder path_builder;
- path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
-
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
-
- ReflectiveFlatbuffer* parent = nullptr;
- reflection::Field const* field = nullptr;
- EXPECT_TRUE(
- buffer->GetFieldWithParent(flatbuffers::GetRoot<FlatbufferFieldPath>(
- path_builder.GetBufferPointer()),
- &parent, &field));
- EXPECT_EQ(parent, buffer->Mutable("flight_number"));
- EXPECT_EQ(field,
- buffer->Mutable("flight_number")->GetFieldOrNull("carrier_code"));
-}
-
-TEST(FlatbuffersTest, HandlesMultipleNestedFields) {
- std::string metadata_buffer = LoadTestMetadata();
- ReflectiveFlatbufferBuilder reflective_builder(
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
-
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- ReflectiveFlatbuffer* flight_info = buffer->Mutable("flight_number");
- flight_info->Set("carrier_code", "LX");
- flight_info->Set("flight_code", 38);
-
- ReflectiveFlatbuffer* contact_info = buffer->Mutable("contact_info");
- EXPECT_TRUE(contact_info->Set("first_name", "Barack"));
- EXPECT_TRUE(contact_info->Set("last_name", "Obama"));
- EXPECT_TRUE(contact_info->Set("phone_number", "1-800-TEST"));
- EXPECT_TRUE(contact_info->Set("score", 1.f));
-
- // Try to parse with the generated code.
- std::string serialized_entity_data = buffer->Serialize();
- std::unique_ptr<test::EntityDataT> entity_data =
- LoadAndVerifyMutableFlatbuffer<test::EntityData>(
- serialized_entity_data.data(), serialized_entity_data.size());
- EXPECT_TRUE(entity_data != nullptr);
- EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
- EXPECT_EQ(entity_data->flight_number->flight_code, 38);
- EXPECT_EQ(entity_data->contact_info->first_name, "Barack");
- EXPECT_EQ(entity_data->contact_info->last_name, "Obama");
- EXPECT_EQ(entity_data->contact_info->phone_number, "1-800-TEST");
- EXPECT_NEAR(entity_data->contact_info->score, 1.f, 1e-4);
-}
-
-TEST(FlatbuffersTest, HandlesFieldsSetWithNamePath) {
- std::string metadata_buffer = LoadTestMetadata();
- ReflectiveFlatbufferBuilder reflective_builder(
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
-
- FlatbufferFieldPathT path;
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_name = "flight_number";
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_name = "carrier_code";
- flatbuffers::FlatBufferBuilder path_builder;
- path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
-
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- // Test setting value using Set function.
- buffer->Mutable("flight_number")->Set("flight_code", 38);
- // Test setting value using FlatbufferFieldPath.
- buffer->Set(flatbuffers::GetRoot<FlatbufferFieldPath>(
- path_builder.GetBufferPointer()),
- "LX");
-
- // Try to parse with the generated code.
- std::string serialized_entity_data = buffer->Serialize();
- std::unique_ptr<test::EntityDataT> entity_data =
- LoadAndVerifyMutableFlatbuffer<test::EntityData>(
- serialized_entity_data.data(), serialized_entity_data.size());
- EXPECT_TRUE(entity_data != nullptr);
- EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
- EXPECT_EQ(entity_data->flight_number->flight_code, 38);
-}
-
-TEST(FlatbuffersTest, HandlesFieldsSetWithOffsetPath) {
- std::string metadata_buffer = LoadTestMetadata();
- ReflectiveFlatbufferBuilder reflective_builder(
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
-
- FlatbufferFieldPathT path;
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_offset = 14;
- path.field.emplace_back(new FlatbufferFieldT);
- path.field.back()->field_offset = 4;
- flatbuffers::FlatBufferBuilder path_builder;
- path_builder.Finish(FlatbufferFieldPath::Pack(path_builder, &path));
-
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- // Test setting value using Set function.
- buffer->Mutable("flight_number")->Set("flight_code", 38);
- // Test setting value using FlatbufferFieldPath.
- buffer->Set(flatbuffers::GetRoot<FlatbufferFieldPath>(
- path_builder.GetBufferPointer()),
- "LX");
-
- // Try to parse with the generated code.
- std::string serialized_entity_data = buffer->Serialize();
- std::unique_ptr<test::EntityDataT> entity_data =
- LoadAndVerifyMutableFlatbuffer<test::EntityData>(
- serialized_entity_data.data(), serialized_entity_data.size());
- EXPECT_TRUE(entity_data != nullptr);
- EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
- EXPECT_EQ(entity_data->flight_number->flight_code, 38);
-}
-
-TEST(FlatbuffersTest, PartialBuffersAreCorrectlyMerged) {
- std::string metadata_buffer = LoadTestMetadata();
- ReflectiveFlatbufferBuilder reflective_builder(
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- buffer->Set("an_int_field", 42);
- buffer->Set("a_long_field", 84ll);
- ReflectiveFlatbuffer* flight_info = buffer->Mutable("flight_number");
- flight_info->Set("carrier_code", "LX");
- flight_info->Set("flight_code", 38);
-
- // Create message to merge.
- test::EntityDataT additional_entity_data;
- additional_entity_data.an_int_field = 43;
- additional_entity_data.flight_number.reset(new test::FlightNumberInfoT);
- additional_entity_data.flight_number->flight_code = 39;
- additional_entity_data.contact_info.reset(new test::ContactInfoT);
- additional_entity_data.contact_info->first_name = "Barack";
- flatbuffers::FlatBufferBuilder to_merge_builder;
- to_merge_builder.Finish(
- test::EntityData::Pack(to_merge_builder, &additional_entity_data));
-
- // Merge it.
- EXPECT_TRUE(buffer->MergeFrom(
- flatbuffers::GetAnyRoot(to_merge_builder.GetBufferPointer())));
-
- // Try to parse it with the generated code.
- std::string serialized_entity_data = buffer->Serialize();
- std::unique_ptr<test::EntityDataT> entity_data =
- LoadAndVerifyMutableFlatbuffer<test::EntityData>(
- serialized_entity_data.data(), serialized_entity_data.size());
- EXPECT_TRUE(entity_data != nullptr);
- EXPECT_EQ(entity_data->an_int_field, 43);
- EXPECT_EQ(entity_data->a_long_field, 84);
- EXPECT_EQ(entity_data->flight_number->carrier_code, "LX");
- EXPECT_EQ(entity_data->flight_number->flight_code, 39);
- EXPECT_EQ(entity_data->contact_info->first_name, "Barack");
-}
-
-TEST(FlatbuffersTest, PrimitiveAndNestedFieldsAreCorrectlyFlattened) {
- std::string metadata_buffer = LoadTestMetadata();
- ReflectiveFlatbufferBuilder reflective_builder(
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data()));
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
- buffer->Set("an_int_field", 42);
- buffer->Set("a_long_field", 84ll);
- ReflectiveFlatbuffer* flight_info = buffer->Mutable("flight_number");
- flight_info->Set("carrier_code", "LX");
- flight_info->Set("flight_code", 38);
-
- std::map<std::string, Variant> entity_data_map = buffer->AsFlatMap();
- EXPECT_EQ(4, entity_data_map.size());
- EXPECT_EQ(42, entity_data_map["an_int_field"].IntValue());
- EXPECT_EQ(84, entity_data_map["a_long_field"].Int64Value());
- EXPECT_EQ("LX", entity_data_map["flight_number.carrier_code"].StringValue());
- EXPECT_EQ(38, entity_data_map["flight_number.flight_code"].IntValue());
-}
-
-TEST(FlatbuffersTest, RepeatedFieldSetThroughReflectionCanBeRead) {
- std::string metadata_buffer = LoadTestMetadata();
- const reflection::Schema* schema =
- flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
- ReflectiveFlatbufferBuilder reflective_builder(schema);
- std::unique_ptr<ReflectiveFlatbuffer> buffer = reflective_builder.NewRoot();
-
- auto reminders = buffer->Repeated<ReflectiveFlatbuffer>("reminders");
- {
- auto reminder = reminders->Add();
- reminder->Set("title", "test reminder");
- auto notes = reminder->Repeated<std::string>("notes");
- notes->Add("note A");
- notes->Add("note B");
- }
- {
- auto reminder = reminders->Add();
- reminder->Set("title", "test reminder 2");
- auto notes = reminder->Repeated<std::string>("notes");
- notes->Add("note i");
- notes->Add("note ii");
- notes->Add("note iii");
- }
- const std::string serialized_entity_data = buffer->Serialize();
-
- std::unique_ptr<test::EntityDataT> entity_data =
- LoadAndVerifyMutableFlatbuffer<test::EntityData>(
- serialized_entity_data.data(), serialized_entity_data.size());
- EXPECT_TRUE(entity_data != nullptr);
- EXPECT_EQ(2, entity_data->reminders.size());
- EXPECT_EQ("test reminder", entity_data->reminders[0]->title);
- EXPECT_THAT(entity_data->reminders[0]->notes,
- testing::ElementsAreArray({"note A", "note B"}));
- EXPECT_EQ("test reminder 2", entity_data->reminders[1]->title);
- EXPECT_THAT(entity_data->reminders[1]->notes,
- testing::ElementsAreArray({"note i", "note ii", "note iii"}));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/flatbuffers_test.fbs b/utils/flatbuffers_test.fbs
deleted file mode 100644
index 0d5b09b..0000000
--- a/utils/flatbuffers_test.fbs
+++ /dev/null
@@ -1,47 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-namespace libtextclassifier3.test;
-
-table FlightNumberInfo {
- carrier_code: string;
- flight_code: int;
-}
-
-table ContactInfo {
- first_name: string;
- last_name: string;
- phone_number: string;
- score: float;
-}
-
-table Reminder {
- title: string;
- notes: [string];
-}
-
-table EntityData {
- an_int_field: int;
- a_long_field: int64;
- a_bool_field: bool;
- a_float_field: float;
- a_double_field: double;
- flight_number: FlightNumberInfo;
- contact_info: ContactInfo;
- reminders: [Reminder];
-}
-
-root_type libtextclassifier3.test.EntityData;
diff --git a/utils/i18n/locale.cc b/utils/i18n/locale.cc
deleted file mode 100644
index 6349d63..0000000
--- a/utils/i18n/locale.cc
+++ /dev/null
@@ -1,192 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/i18n/locale.h"
-
-#include "utils/strings/split.h"
-
-namespace libtextclassifier3 {
-
-namespace {
-constexpr const char* kAnyMatch = "*";
-
-// BCP 47 code for "Undetermined Language".
-constexpr const char* kUnknownLanguageCode = "und";
-
-bool CheckLanguage(StringPiece language) {
- if (language.size() == 1 && language.data()[0] == '*') {
- return true;
- }
-
- if (language.size() != 2 && language.size() != 3) {
- return false;
- }
-
- // Needs to be all lowercase.
- for (int i = 0; i < language.size(); ++i) {
- if (!std::islower(language[i])) {
- return false;
- }
- }
-
- return true;
-}
-
-bool CheckScript(StringPiece script) {
- if (script.size() != 4) {
- return false;
- }
-
- if (!std::isupper(script[0])) {
- return false;
- }
-
- // Needs to be all lowercase.
- for (int i = 1; i < script.size(); ++i) {
- if (!std::islower(script[i])) {
- return false;
- }
- }
-
- return true;
-}
-
-bool CheckRegion(StringPiece region) {
- if (region.size() == 2) {
- return std::isupper(region[0]) && std::isupper(region[1]);
- } else if (region.size() == 3) {
- return std::isdigit(region[0]) && std::isdigit(region[1]) &&
- std::isdigit(region[2]);
- } else {
- return false;
- }
-}
-
-} // namespace
-
-Locale Locale::FromBCP47(const std::string& locale_tag) {
- std::vector<StringPiece> parts = strings::Split(locale_tag, '-');
- if (parts.empty()) {
- return Locale::Invalid();
- }
-
- auto parts_it = parts.begin();
- StringPiece language = *parts_it;
- if (!CheckLanguage(language)) {
- return Locale::Invalid();
- }
- ++parts_it;
-
- StringPiece script;
- if (parts_it != parts.end()) {
- script = *parts_it;
- if (!CheckScript(script)) {
- script = "";
- } else {
- ++parts_it;
- }
- }
-
- StringPiece region;
- if (parts_it != parts.end()) {
- region = *parts_it;
- if (!CheckRegion(region)) {
- region = "";
- } else {
- ++parts_it;
- }
- }
-
- // NOTE: We don't parse the rest of the BCP47 tag here even if specified.
-
- return Locale(language.ToString(), script.ToString(), region.ToString());
-}
-
-bool Locale::IsUnknown() const {
- return is_valid_ && language_ == kUnknownLanguageCode;
-}
-
-bool Locale::IsLocaleSupported(const Locale& locale,
- const std::vector<Locale>& supported_locales,
- bool default_value) {
- if (!locale.IsValid()) {
- return false;
- }
- if (locale.IsUnknown()) {
- return default_value;
- }
- for (const Locale& supported_locale : supported_locales) {
- if (!supported_locale.IsValid()) {
- continue;
- }
- const bool language_matches =
- supported_locale.Language().empty() ||
- supported_locale.Language() == kAnyMatch ||
- supported_locale.Language() == locale.Language();
- const bool script_matches = supported_locale.Script().empty() ||
- supported_locale.Script() == kAnyMatch ||
- locale.Script().empty() ||
- supported_locale.Script() == locale.Script();
- const bool region_matches = supported_locale.Region().empty() ||
- supported_locale.Region() == kAnyMatch ||
- locale.Region().empty() ||
- supported_locale.Region() == locale.Region();
- if (language_matches && script_matches && region_matches) {
- return true;
- }
- }
- return false;
-}
-
-bool Locale::IsAnyLocaleSupported(const std::vector<Locale>& locales,
- const std::vector<Locale>& supported_locales,
- bool default_value) {
- if (locales.empty()) {
- return default_value;
- }
- if (supported_locales.empty()) {
- return default_value;
- }
- for (const Locale& locale : locales) {
- if (IsLocaleSupported(locale, supported_locales, default_value)) {
- return true;
- }
- }
- return false;
-}
-
-logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const Locale& locale) {
- return stream << "Locale(language=" << locale.Language()
- << ", script=" << locale.Script()
- << ", region=" << locale.Region()
- << ", is_valid=" << locale.IsValid()
- << ", is_unknown=" << locale.IsUnknown() << ")";
-}
-
-bool ParseLocales(StringPiece locales_list, std::vector<Locale>* locales) {
- for (const auto& locale_str : strings::Split(locales_list, ',')) {
- const Locale locale = Locale::FromBCP47(locale_str.ToString());
- if (!locale.IsValid()) {
- TC3_LOG(ERROR) << "Invalid locale " << locale_str.ToString();
- return false;
- }
- locales->push_back(locale);
- }
- return true;
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/i18n/locale.h b/utils/i18n/locale.h
deleted file mode 100644
index 4420b56..0000000
--- a/utils/i18n/locale.h
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_H_
-#define LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_H_
-
-#include <string>
-#include <vector>
-
-#include "utils/base/integral_types.h"
-#include "utils/base/logging.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-class Locale {
- public:
- // Constructs the object from a valid BCP47 tag. If the tag is invalid,
- // an object is created that gives false when IsInvalid() is called.
- static Locale FromBCP47(const std::string& locale_tag);
-
- // Creates a prototypical invalid locale object.
- static Locale Invalid() {
- Locale locale(/*language=*/"", /*script=*/"", /*region=*/"");
- locale.is_valid_ = false;
- return locale;
- }
-
- std::string Language() const { return language_; }
-
- std::string Script() const { return script_; }
-
- std::string Region() const { return region_; }
-
- bool IsValid() const { return is_valid_; }
- bool IsUnknown() const;
-
- // Returns whether any of the given locales is supported by any of the
- // supported locales. Returns default value if the given 'locales' list, or
- // 'supported_locales' list is empty or an unknown locale is found.
- // Locale::FromBCP47("*") means any locale.
- static bool IsAnyLocaleSupported(const std::vector<Locale>& locales,
- const std::vector<Locale>& supported_locales,
- bool default_value);
-
- private:
- Locale(const std::string& language, const std::string& script,
- const std::string& region)
- : language_(language),
- script_(script),
- region_(region),
- is_valid_(true) {}
-
- static bool IsLocaleSupported(const Locale& locale,
- const std::vector<Locale>& supported_locales,
- bool default_value);
-
- std::string language_;
- std::string script_;
- std::string region_;
- bool is_valid_;
-};
-
-// Pretty-printing function for Locale.
-logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
- const Locale& locale);
-
-// Parses a comma-separated list of BCP47 tags.
-bool ParseLocales(StringPiece locales_list, std::vector<Locale>* locales);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_H_
diff --git a/utils/intents/IntentGeneratorTest.java b/utils/intents/IntentGeneratorTest.java
deleted file mode 100644
index f43ecc0..0000000
--- a/utils/intents/IntentGeneratorTest.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.android.textclassifier.utils.intents;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import android.content.Context;
-import androidx.test.InstrumentationRegistry;
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-/** This is just a launcher of the tests because we need a valid JNIEnv in the C++ test. */
-@RunWith(JUnit4.class)
-public final class IntentGeneratorTest {
-
- @Before
- public void setUp() throws Exception {
- System.loadLibrary("intent-generator-test-lib");
- }
-
- private native boolean testsMain(Context context);
-
- @Test
- public void testNative() {
- assertThat(testsMain(InstrumentationRegistry.getContext())).isTrue();
- }
-}
diff --git a/utils/intents/intent-config.fbs b/utils/intents/intent-config.fbs
deleted file mode 100755
index 09ebbb4..0000000
--- a/utils/intents/intent-config.fbs
+++ /dev/null
@@ -1,199 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-include "utils/zlib/buffer.fbs";
-
-// The type of variable to fetch.
-namespace libtextclassifier3;
-enum AndroidSimpleIntentGeneratorVariableType : int {
- INVALID_VARIABLE = 0,
-
- // The raw text that was classified.
- RAW_TEXT = 1,
-
- // Text as a URL with explicit protocol. If no protocol was specified, http
- // is prepended.
- URL_TEXT = 2,
-
- // The raw text, but URL encoded.
- URL_ENCODED_TEXT = 3,
-
- // For dates/times: the instant of the event in UTC millis.
- EVENT_TIME_MS_UTC = 4,
-
- // For dates/times: the start of the event in UTC millis.
- EVENT_START_MS_UTC = 5,
-
- // For dates/times: the end of the event in UTC millis.
- EVENT_END_MS_UTC = 6,
-
- // Name of the package that's running the classifier.
- PACKAGE_NAME = 7,
-}
-
-// Enumerates the possible extra types for the simple intent generator.
-namespace libtextclassifier3;
-enum AndroidSimpleIntentGeneratorExtraType : int {
- INVALID_EXTRA_TYPE = 0,
- STRING = 1,
- BOOL = 2,
- VARIABLE_AS_LONG = 3,
-}
-
-// Enumerates the possible condition types for the simple intent generator.
-namespace libtextclassifier3;
-enum AndroidSimpleIntentGeneratorConditionType : int {
- INVALID_CONDITION_TYPE = 0,
-
- // Queries the UserManager for the given boolean restriction. The condition
- // passes if the result is of getBoolean is false. The name of the
- // restriction to check is in the string_ field.
- USER_RESTRICTION_NOT_SET = 1,
-
- // Checks that the parsed event start time is at least a give number of
- // milliseconds in the future. (Only valid if there is a parsed event
- // time) The offset is stored in the int64_ field.
- EVENT_START_IN_FUTURE_MS = 2,
-}
-
-// Describes how intents for the various entity types should be generated on
-// Android. This is distributed through the model, but not used by
-// libtextclassifier yet - rather, it's passed to the calling Java code, which
-// implements the Intent generation logic.
-namespace libtextclassifier3;
-table AndroidIntentFactoryOptions {
- entity:[AndroidIntentFactoryEntityOptions];
-}
-
-// Describes how intents should be generated for a particular entity type.
-namespace libtextclassifier3;
-table AndroidIntentFactoryEntityOptions {
- // The entity type as defined by one of the TextClassifier ENTITY_TYPE
- // constants. (e.g. "address", "phone", etc.)
- entity_type:string;
-
- // List of generators for all the different types of intents that should
- // be made available for the entity type.
- generator:[AndroidIntentGeneratorOptions];
-}
-
-// Configures a single Android Intent generator.
-namespace libtextclassifier3;
-table AndroidIntentGeneratorOptions {
- // Strings for UI elements.
- strings:[AndroidIntentGeneratorStrings];
-
- // Generator specific configuration.
- simple:AndroidSimpleIntentGeneratorOptions;
-}
-
-// Language dependent configuration for an Android Intent generator.
-namespace libtextclassifier3;
-table AndroidIntentGeneratorStrings {
- // BCP 47 tag for the supported locale. Note that because of API level
- // restrictions, this must /not/ use wildcards. To e.g. match all English
- // locales, use only "en" and not "en_*". Reference the java.util.Locale
- // constructor for details.
- language_tag:string;
-
- // Title shown for the action (see RemoteAction.getTitle).
- title:string;
-
- // Description shown for the action (see
- // RemoteAction.getContentDescription).
- description:string;
-}
-
-// An extra to set on a simple intent generator Intent.
-namespace libtextclassifier3;
-table AndroidSimpleIntentGeneratorExtra {
- // The name of the extra to set.
- name:string;
-
- // The type of the extra to set.
- type:AndroidSimpleIntentGeneratorExtraType;
-
- string_:string;
-
- bool_:bool;
- int32_:int;
-}
-
-// A condition that needs to be fulfilled for an Intent to get generated.
-namespace libtextclassifier3;
-table AndroidSimpleIntentGeneratorCondition {
- type:AndroidSimpleIntentGeneratorConditionType;
-
- string_:string;
-
- int32_:int;
- int64_:long;
-}
-
-// Configures an intent generator where the logic is simple to be expressed with
-// basic rules - which covers the vast majority of use cases and is analogous
-// to Android Actions.
-// Most strings (action, data, type, ...) may contain variable references. To
-// use them, the generator must first declare all the variables it wishes to use
-// in the variables field. The values then become available as numbered
-// arguments (using the normal java.util.Formatter syntax) in the order they
-// were specified.
-namespace libtextclassifier3;
-table AndroidSimpleIntentGeneratorOptions {
- // The action to set on the Intent (see Intent.setAction). Supports variables.
- action:string;
-
- // The data to set on the Intent (see Intent.setData). Supports variables.
- data:string;
-
- // The type to set on the Intent (see Intent.setType). Supports variables.
- type:string;
-
- // The list of all the extras to add to the Intent.
- extra:[AndroidSimpleIntentGeneratorExtra];
-
- // The list of all the variables that become available for substitution in
- // the action, data, type and extra strings. To e.g. set a field to the value
- // of the first variable, use "%0$s".
- variable:[AndroidSimpleIntentGeneratorVariableType];
-
- // The list of all conditions that need to be fulfilled for Intent generation.
- condition:[AndroidSimpleIntentGeneratorCondition];
-}
-
-// Describes how intents should be generated for a particular entity type.
-namespace libtextclassifier3.IntentFactoryModel_;
-table IntentGenerator {
- // The type of the intent generator, e.g. the entity type as defined by
- // on the TextClassifier ENTITY_TYPE constants e.g. "address", "phone", etc.
- type:string;
-
- // The template generator lua code, either as text source or precompiled
- // bytecode.
- lua_template_generator:[ubyte];
-
- compressed_lua_template_generator:CompressedBuffer;
-}
-
-// Describes how intents for the various entity types should be generated.
-namespace libtextclassifier3;
-table IntentFactoryModel {
- generator:[IntentFactoryModel_.IntentGenerator];
-
- // Whether to precompile the generators when loading.
- precompile_generators:bool = false;
-}
-
diff --git a/utils/intents/intent-generator.cc b/utils/intents/intent-generator.cc
deleted file mode 100644
index f882515..0000000
--- a/utils/intents/intent-generator.cc
+++ /dev/null
@@ -1,899 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/intents/intent-generator.h"
-
-#include <vector>
-
-#include "actions/lua-utils.h"
-#include "actions/types.h"
-#include "annotator/types.h"
-#include "utils/base/logging.h"
-#include "utils/hash/farmhash.h"
-#include "utils/java/jni-base.h"
-#include "utils/java/string_utils.h"
-#include "utils/lua-utils.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/strings/substitute.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/variant.h"
-#include "utils/zlib/zlib.h"
-#include "flatbuffers/reflection_generated.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-#include "lauxlib.h"
-#include "lua.h"
-#ifdef __cplusplus
-}
-#endif
-
-namespace libtextclassifier3 {
-namespace {
-
-static constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
-static constexpr const char* kHashKey = "hash";
-static constexpr const char* kUrlSchemaKey = "url_schema";
-static constexpr const char* kUrlHostKey = "url_host";
-static constexpr const char* kUrlEncodeKey = "urlencode";
-static constexpr const char* kPackageNameKey = "package_name";
-static constexpr const char* kDeviceLocaleKey = "device_locales";
-static constexpr const char* kFormatKey = "format";
-
-// An Android specific Lua environment with JNI backed callbacks.
-class JniLuaEnvironment : public LuaEnvironment {
- public:
- JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
- const jobject context,
- const std::vector<Locale>& device_locales);
- // Environment setup.
- bool Initialize();
-
- // Runs an intent generator snippet.
- bool RunIntentGenerator(const std::string& generator_snippet,
- std::vector<RemoteActionTemplate>* remote_actions);
-
- protected:
- virtual void SetupExternalHook();
-
- int HandleExternalCallback();
- int HandleAndroidCallback();
- int HandleUserRestrictionsCallback();
- int HandleUrlEncode();
- int HandleUrlSchema();
- int HandleHash();
- int HandleFormat();
- int HandleAndroidStringResources();
- int HandleUrlHost();
-
- // Checks and retrieves string resources from the model.
- bool LookupModelStringResource();
-
- // Reads and create a RemoteAction result from Lua.
- RemoteActionTemplate ReadRemoteActionTemplateResult();
-
- // Reads the extras from the Lua result.
- void ReadExtras(std::map<std::string, Variant>* extra);
-
- // Reads the intent categories array from a Lua result.
- void ReadCategories(std::vector<std::string>* category);
-
- // Retrieves user manager if not previously done.
- bool RetrieveUserManager();
-
- // Retrieves system resources if not previously done.
- bool RetrieveSystemResources();
-
- // Parse the url string by using Uri.parse from Java.
- ScopedLocalRef<jobject> ParseUri(StringPiece url) const;
-
- // Read remote action templates from lua generator.
- int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);
-
- const Resources& resources_;
- JNIEnv* jenv_;
- const JniCache* jni_cache_;
- const jobject context_;
- std::vector<Locale> device_locales_;
-
- ScopedGlobalRef<jobject> usermanager_;
- // Whether we previously attempted to retrieve the UserManager before.
- bool usermanager_retrieved_;
-
- ScopedGlobalRef<jobject> system_resources_;
- // Whether we previously attempted to retrieve the system resources.
- bool system_resources_resources_retrieved_;
-
- // Cached JNI references for Java strings `string` and `android`.
- ScopedGlobalRef<jstring> string_;
- ScopedGlobalRef<jstring> android_;
-};
-
-JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
- const JniCache* jni_cache,
- const jobject context,
- const std::vector<Locale>& device_locales)
- : resources_(resources),
- jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
- jni_cache_(jni_cache),
- context_(context),
- device_locales_(device_locales),
- usermanager_(/*object=*/nullptr,
- /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
- usermanager_retrieved_(false),
- system_resources_(/*object=*/nullptr,
- /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
- system_resources_resources_retrieved_(false),
- string_(/*object=*/nullptr,
- /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
- android_(/*object=*/nullptr,
- /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
-
-bool JniLuaEnvironment::Initialize() {
- string_ =
- MakeGlobalRef(jenv_->NewStringUTF("string"), jenv_, jni_cache_->jvm);
- android_ =
- MakeGlobalRef(jenv_->NewStringUTF("android"), jenv_, jni_cache_->jvm);
- if (string_ == nullptr || android_ == nullptr) {
- TC3_LOG(ERROR) << "Could not allocate constant strings references.";
- return false;
- }
- return (RunProtected([this] {
- LoadDefaultLibraries();
- SetupExternalHook();
- lua_setglobal(state_, "external");
- return LUA_OK;
- }) == LUA_OK);
-}
-
-void JniLuaEnvironment::SetupExternalHook() {
- // This exposes an `external` object with the following fields:
- // * entity: the bundle with all information about a classification.
- // * android: callbacks into specific android provided methods.
- // * android.user_restrictions: callbacks to check user permissions.
- // * android.R: callbacks to retrieve string resources.
- BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleExternalCallback>(
- "external");
-
- // android
- BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleAndroidCallback>(
- "android");
- {
- // android.user_restrictions
- BindTable<JniLuaEnvironment,
- &JniLuaEnvironment::HandleUserRestrictionsCallback>(
- "user_restrictions");
- lua_setfield(state_, /*idx=*/-2, "user_restrictions");
-
- // android.R
- // Callback to access android string resources.
- BindTable<JniLuaEnvironment,
- &JniLuaEnvironment::HandleAndroidStringResources>("R");
- lua_setfield(state_, /*idx=*/-2, "R");
- }
- lua_setfield(state_, /*idx=*/-2, "android");
-}
-
-int JniLuaEnvironment::HandleExternalCallback() {
- const StringPiece key = ReadString(/*index=*/-1);
- if (key.Equals(kHashKey)) {
- Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleHash>();
- return 1;
- } else if (key.Equals(kFormatKey)) {
- Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleFormat>();
- return 1;
- } else {
- TC3_LOG(ERROR) << "Undefined external access " << key.ToString();
- lua_error(state_);
- return 0;
- }
-}
-
-int JniLuaEnvironment::HandleAndroidCallback() {
- const StringPiece key = ReadString(/*index=*/-1);
- if (key.Equals(kDeviceLocaleKey)) {
- // Provide the locale as table with the individual fields set.
- lua_newtable(state_);
- for (int i = 0; i < device_locales_.size(); i++) {
- // Adjust index to 1-based indexing for Lua.
- lua_pushinteger(state_, i + 1);
- lua_newtable(state_);
- PushString(device_locales_[i].Language());
- lua_setfield(state_, -2, "language");
- PushString(device_locales_[i].Region());
- lua_setfield(state_, -2, "region");
- PushString(device_locales_[i].Script());
- lua_setfield(state_, -2, "script");
- lua_settable(state_, /*idx=*/-3);
- }
- return 1;
- } else if (key.Equals(kPackageNameKey)) {
- if (context_ == nullptr) {
- TC3_LOG(ERROR) << "Context invalid.";
- lua_error(state_);
- return 0;
- }
- ScopedLocalRef<jstring> package_name_str(
- static_cast<jstring>(jenv_->CallObjectMethod(
- context_, jni_cache_->context_get_package_name)));
- if (jni_cache_->ExceptionCheckAndClear()) {
- TC3_LOG(ERROR) << "Error calling Context.getPackageName";
- lua_error(state_);
- return 0;
- }
- PushString(ToStlString(jenv_, package_name_str.get()));
- return 1;
- } else if (key.Equals(kUrlEncodeKey)) {
- Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlEncode>();
- return 1;
- } else if (key.Equals(kUrlHostKey)) {
- Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlHost>();
- return 1;
- } else if (key.Equals(kUrlSchemaKey)) {
- Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlSchema>();
- return 1;
- } else {
- TC3_LOG(ERROR) << "Undefined android reference " << key.ToString();
- lua_error(state_);
- return 0;
- }
-}
-
-int JniLuaEnvironment::HandleUserRestrictionsCallback() {
- if (jni_cache_->usermanager_class == nullptr ||
- jni_cache_->usermanager_get_user_restrictions == nullptr) {
- // UserManager is only available for API level >= 17 and
- // getUserRestrictions only for API level >= 18, so we just return false
- // normally here.
- lua_pushboolean(state_, false);
- return 1;
- }
-
- // Get user manager if not previously retrieved.
- if (!RetrieveUserManager()) {
- TC3_LOG(ERROR) << "Error retrieving user manager.";
- lua_error(state_);
- return 0;
- }
-
- ScopedLocalRef<jobject> bundle(jenv_->CallObjectMethod(
- usermanager_.get(), jni_cache_->usermanager_get_user_restrictions));
- if (jni_cache_->ExceptionCheckAndClear() || bundle == nullptr) {
- TC3_LOG(ERROR) << "Error calling getUserRestrictions";
- lua_error(state_);
- return 0;
- }
-
- const StringPiece key_str = ReadString(/*index=*/-1);
- if (key_str.empty()) {
- TC3_LOG(ERROR) << "Expected string, got null.";
- lua_error(state_);
- return 0;
- }
-
- ScopedLocalRef<jstring> key = jni_cache_->ConvertToJavaString(key_str);
- if (jni_cache_->ExceptionCheckAndClear() || key == nullptr) {
- TC3_LOG(ERROR) << "Expected string, got null.";
- lua_error(state_);
- return 0;
- }
- const bool permission = jenv_->CallBooleanMethod(
- bundle.get(), jni_cache_->bundle_get_boolean, key.get());
- if (jni_cache_->ExceptionCheckAndClear()) {
- TC3_LOG(ERROR) << "Error getting bundle value";
- lua_pushboolean(state_, false);
- } else {
- lua_pushboolean(state_, permission);
- }
- return 1;
-}
-
-int JniLuaEnvironment::HandleUrlEncode() {
- const StringPiece input = ReadString(/*index=*/1);
- if (input.empty()) {
- TC3_LOG(ERROR) << "Expected string, got null.";
- lua_error(state_);
- return 0;
- }
-
- // Call Java URL encoder.
- ScopedLocalRef<jstring> input_str = jni_cache_->ConvertToJavaString(input);
- if (jni_cache_->ExceptionCheckAndClear() || input_str == nullptr) {
- TC3_LOG(ERROR) << "Expected string, got null.";
- lua_error(state_);
- return 0;
- }
- ScopedLocalRef<jstring> encoded_str(
- static_cast<jstring>(jenv_->CallStaticObjectMethod(
- jni_cache_->urlencoder_class.get(), jni_cache_->urlencoder_encode,
- input_str.get(), jni_cache_->string_utf8.get())));
- if (jni_cache_->ExceptionCheckAndClear()) {
- TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
- lua_error(state_);
- return 0;
- }
- PushString(ToStlString(jenv_, encoded_str.get()));
- return 1;
-}
-
-ScopedLocalRef<jobject> JniLuaEnvironment::ParseUri(StringPiece url) const {
- if (url.empty()) {
- return nullptr;
- }
-
- // Call to Java URI parser.
- ScopedLocalRef<jstring> url_str = jni_cache_->ConvertToJavaString(url);
- if (jni_cache_->ExceptionCheckAndClear() || url_str == nullptr) {
- TC3_LOG(ERROR) << "Expected string, got null";
- return nullptr;
- }
-
- // Try to parse uri and get scheme.
- ScopedLocalRef<jobject> uri(jenv_->CallStaticObjectMethod(
- jni_cache_->uri_class.get(), jni_cache_->uri_parse, url_str.get()));
- if (jni_cache_->ExceptionCheckAndClear() || uri == nullptr) {
- TC3_LOG(ERROR) << "Error calling Uri.parse";
- }
- return uri;
-}
-
-int JniLuaEnvironment::HandleUrlSchema() {
- StringPiece url = ReadString(/*index=*/1);
-
- ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
- if (parsed_uri == nullptr) {
- lua_error(state_);
- return 0;
- }
-
- ScopedLocalRef<jstring> scheme_str(static_cast<jstring>(
- jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_scheme)));
- if (jni_cache_->ExceptionCheckAndClear()) {
- TC3_LOG(ERROR) << "Error calling Uri.getScheme";
- lua_error(state_);
- return 0;
- }
- if (scheme_str == nullptr) {
- lua_pushnil(state_);
- } else {
- PushString(ToStlString(jenv_, scheme_str.get()));
- }
- return 1;
-}
-
-int JniLuaEnvironment::HandleUrlHost() {
- StringPiece url = ReadString(/*index=*/-1);
-
- ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
- if (parsed_uri == nullptr) {
- lua_error(state_);
- return 0;
- }
-
- ScopedLocalRef<jstring> host_str(static_cast<jstring>(
- jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_host)));
- if (jni_cache_->ExceptionCheckAndClear()) {
- TC3_LOG(ERROR) << "Error calling Uri.getHost";
- lua_error(state_);
- return 0;
- }
- if (host_str == nullptr) {
- lua_pushnil(state_);
- } else {
- PushString(ToStlString(jenv_, host_str.get()));
- }
- return 1;
-}
-
-int JniLuaEnvironment::HandleHash() {
- const StringPiece input = ReadString(/*index=*/-1);
- lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
- return 1;
-}
-
-int JniLuaEnvironment::HandleFormat() {
- const int num_args = lua_gettop(state_);
- std::vector<StringPiece> args(num_args - 1);
- for (int i = 0; i < num_args - 1; i++) {
- args[i] = ReadString(/*index=*/i + 2);
- }
- PushString(strings::Substitute(ReadString(/*index=*/1), args));
- return 1;
-}
-
-bool JniLuaEnvironment::LookupModelStringResource() {
- // Handle only lookup by name.
- if (lua_type(state_, 2) != LUA_TSTRING) {
- return false;
- }
-
- const StringPiece resource_name = ReadString(/*index=*/-1);
- std::string resource_content;
- if (!resources_.GetResourceContent(device_locales_, resource_name,
- &resource_content)) {
- // Resource cannot be provided by the model.
- return false;
- }
-
- PushString(resource_content);
- return true;
-}
-
-int JniLuaEnvironment::HandleAndroidStringResources() {
- // Check whether the requested resource can be served from the model data.
- if (LookupModelStringResource()) {
- return 1;
- }
-
- // Get system resources if not previously retrieved.
- if (!RetrieveSystemResources()) {
- TC3_LOG(ERROR) << "Error retrieving system resources.";
- lua_error(state_);
- return 0;
- }
-
- int resource_id;
- switch (lua_type(state_, -1)) {
- case LUA_TNUMBER:
- resource_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
- break;
- case LUA_TSTRING: {
- const StringPiece resource_name_str = ReadString(/*index=*/-1);
- if (resource_name_str.empty()) {
- TC3_LOG(ERROR) << "No resource name provided.";
- lua_error(state_);
- return 0;
- }
- ScopedLocalRef<jstring> resource_name =
- jni_cache_->ConvertToJavaString(resource_name_str);
- if (resource_name == nullptr) {
- TC3_LOG(ERROR) << "Invalid resource name.";
- lua_error(state_);
- return 0;
- }
- resource_id = jenv_->CallIntMethod(
- system_resources_.get(), jni_cache_->resources_get_identifier,
- resource_name.get(), string_.get(), android_.get());
- if (jni_cache_->ExceptionCheckAndClear()) {
- TC3_LOG(ERROR) << "Error calling getIdentifier.";
- lua_error(state_);
- return 0;
- }
- break;
- }
- default:
- TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
- lua_error(state_);
- return 0;
- }
- if (resource_id == 0) {
- TC3_LOG(ERROR) << "Resource not found.";
- lua_pushnil(state_);
- return 1;
- }
- ScopedLocalRef<jstring> resource_str(static_cast<jstring>(
- jenv_->CallObjectMethod(system_resources_.get(),
- jni_cache_->resources_get_string, resource_id)));
- if (jni_cache_->ExceptionCheckAndClear()) {
- TC3_LOG(ERROR) << "Error calling getString.";
- lua_error(state_);
- return 0;
- }
- if (resource_str == nullptr) {
- lua_pushnil(state_);
- } else {
- PushString(ToStlString(jenv_, resource_str.get()));
- }
- return 1;
-}
-
-bool JniLuaEnvironment::RetrieveSystemResources() {
- if (system_resources_resources_retrieved_) {
- return (system_resources_ != nullptr);
- }
- system_resources_resources_retrieved_ = true;
- jobject system_resources_ref = jenv_->CallStaticObjectMethod(
- jni_cache_->resources_class.get(), jni_cache_->resources_get_system);
- if (jni_cache_->ExceptionCheckAndClear()) {
- TC3_LOG(ERROR) << "Error calling getSystem.";
- return false;
- }
- system_resources_ =
- MakeGlobalRef(system_resources_ref, jenv_, jni_cache_->jvm);
- return (system_resources_ != nullptr);
-}
-
-bool JniLuaEnvironment::RetrieveUserManager() {
- if (context_ == nullptr) {
- return false;
- }
- if (usermanager_retrieved_) {
- return (usermanager_ != nullptr);
- }
- usermanager_retrieved_ = true;
- ScopedLocalRef<jstring> service(jenv_->NewStringUTF("user"));
- jobject usermanager_ref = jenv_->CallObjectMethod(
- context_, jni_cache_->context_get_system_service, service.get());
- if (jni_cache_->ExceptionCheckAndClear()) {
- TC3_LOG(ERROR) << "Error calling getSystemService.";
- return false;
- }
- usermanager_ = MakeGlobalRef(usermanager_ref, jenv_, jni_cache_->jvm);
- return (usermanager_ != nullptr);
-}
-
-RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() {
- RemoteActionTemplate result;
- // Read intent template.
- lua_pushnil(state_);
- while (lua_next(state_, /*idx=*/-2)) {
- const StringPiece key = ReadString(/*index=*/-2);
- if (key.Equals("title_without_entity")) {
- result.title_without_entity = ReadString(/*index=*/-1).ToString();
- } else if (key.Equals("title_with_entity")) {
- result.title_with_entity = ReadString(/*index=*/-1).ToString();
- } else if (key.Equals("description")) {
- result.description = ReadString(/*index=*/-1).ToString();
- } else if (key.Equals("description_with_app_name")) {
- result.description_with_app_name = ReadString(/*index=*/-1).ToString();
- } else if (key.Equals("action")) {
- result.action = ReadString(/*index=*/-1).ToString();
- } else if (key.Equals("data")) {
- result.data = ReadString(/*index=*/-1).ToString();
- } else if (key.Equals("type")) {
- result.type = ReadString(/*index=*/-1).ToString();
- } else if (key.Equals("flags")) {
- result.flags = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
- } else if (key.Equals("package_name")) {
- result.package_name = ReadString(/*index=*/-1).ToString();
- } else if (key.Equals("request_code")) {
- result.request_code = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
- } else if (key.Equals("category")) {
- ReadCategories(&result.category);
- } else if (key.Equals("extra")) {
- ReadExtras(&result.extra);
- } else {
- TC3_LOG(INFO) << "Unknown entry: " << key.ToString();
- }
- lua_pop(state_, 1);
- }
- lua_pop(state_, 1);
- return result;
-}
-
-void JniLuaEnvironment::ReadCategories(std::vector<std::string>* category) {
- // Read category array.
- if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected categories table, got: "
- << lua_type(state_, /*idx=*/-1);
- lua_pop(state_, 1);
- return;
- }
- lua_pushnil(state_);
- while (lua_next(state_, /*idx=*/-2)) {
- category->push_back(ReadString(/*index=*/-1).ToString());
- lua_pop(state_, 1);
- }
-}
-
-void JniLuaEnvironment::ReadExtras(std::map<std::string, Variant>* extra) {
- if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected extras table, got: "
- << lua_type(state_, /*idx=*/-1);
- lua_pop(state_, 1);
- return;
- }
- lua_pushnil(state_);
- while (lua_next(state_, /*idx=*/-2)) {
- // Each entry is a table specifying name and value.
- // The value is specified via a type specific field as Lua doesn't allow
- // to easily distinguish between different number types.
- if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected a table for an extra, got: "
- << lua_type(state_, /*idx=*/-1);
- lua_pop(state_, 1);
- return;
- }
- std::string name;
- Variant value;
-
- lua_pushnil(state_);
- while (lua_next(state_, /*idx=*/-2)) {
- const StringPiece key = ReadString(/*index=*/-2);
- if (key.Equals("name")) {
- name = ReadString(/*index=*/-1).ToString();
- } else if (key.Equals("int_value")) {
- value = Variant(static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
- } else if (key.Equals("long_value")) {
- value = Variant(static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
- } else if (key.Equals("float_value")) {
- value = Variant(static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
- } else if (key.Equals("bool_value")) {
- value = Variant(static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
- } else if (key.Equals("string_value")) {
- value = Variant(ReadString(/*index=*/-1).ToString());
- } else {
- TC3_LOG(INFO) << "Unknown extra field: " << key.ToString();
- }
- lua_pop(state_, 1);
- }
- if (!name.empty()) {
- (*extra)[name] = value;
- } else {
- TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
- }
- lua_pop(state_, 1);
- }
-}
-
-int JniLuaEnvironment::ReadRemoteActionTemplates(
- std::vector<RemoteActionTemplate>* result) {
- // Read result.
- if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Unexpected result for snippet: " << lua_type(state_, -1);
- lua_error(state_);
- return LUA_ERRRUN;
- }
-
- // Read remote action templates array.
- lua_pushnil(state_);
- while (lua_next(state_, /*idx=*/-2)) {
- if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected intent table, got: "
- << lua_type(state_, /*idx=*/-1);
- lua_pop(state_, 1);
- continue;
- }
- result->push_back(ReadRemoteActionTemplateResult());
- }
- lua_pop(state_, /*n=*/1);
- return LUA_OK;
-}
-
-bool JniLuaEnvironment::RunIntentGenerator(
- const std::string& generator_snippet,
- std::vector<RemoteActionTemplate>* remote_actions) {
- int status;
- status = luaL_loadbuffer(state_, generator_snippet.data(),
- generator_snippet.size(),
- /*name=*/nullptr);
- if (status != LUA_OK) {
- TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
- return false;
- }
- status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
- if (status != LUA_OK) {
- TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
- return false;
- }
- if (RunProtected(
- [this, remote_actions] {
- return ReadRemoteActionTemplates(remote_actions);
- },
- /*num_args=*/1) != LUA_OK) {
- TC3_LOG(ERROR) << "Could not read results.";
- return false;
- }
- // Check that we correctly cleaned-up the state.
- const int stack_size = lua_gettop(state_);
- if (stack_size > 0) {
- TC3_LOG(ERROR) << "Unexpected stack size.";
- lua_settop(state_, 0);
- return false;
- }
- return true;
-}
-
-// Lua environment for classfication result intent generation.
-class AnnotatorJniEnvironment : public JniLuaEnvironment {
- public:
- AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
- const jobject context,
- const std::vector<Locale>& device_locales,
- const std::string& entity_text,
- const ClassificationResult& classification,
- const int64 reference_time_ms_utc,
- const reflection::Schema* entity_data_schema)
- : JniLuaEnvironment(resources, jni_cache, context, device_locales),
- entity_text_(entity_text),
- classification_(classification),
- reference_time_ms_utc_(reference_time_ms_utc),
- entity_data_schema_(entity_data_schema) {}
-
- protected:
- void SetupExternalHook() override {
- JniLuaEnvironment::SetupExternalHook();
- lua_pushinteger(state_, reference_time_ms_utc_);
- lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);
-
- PushAnnotation(classification_, entity_text_, entity_data_schema_, this);
- lua_setfield(state_, /*idx=*/-2, "entity");
- }
-
- const std::string& entity_text_;
- const ClassificationResult& classification_;
- const int64 reference_time_ms_utc_;
-
- // Reflection schema data.
- const reflection::Schema* const entity_data_schema_;
-};
-
-// Lua environment for actions intent generation.
-class ActionsJniLuaEnvironment : public JniLuaEnvironment {
- public:
- ActionsJniLuaEnvironment(
- const Resources& resources, const JniCache* jni_cache,
- const jobject context, const std::vector<Locale>& device_locales,
- const Conversation& conversation, const ActionSuggestion& action,
- const reflection::Schema* actions_entity_data_schema,
- const reflection::Schema* annotations_entity_data_schema)
- : JniLuaEnvironment(resources, jni_cache, context, device_locales),
- conversation_(conversation),
- action_(action),
- annotation_iterator_(annotations_entity_data_schema, this),
- conversation_iterator_(annotations_entity_data_schema, this),
- entity_data_schema_(actions_entity_data_schema) {}
-
- protected:
- void SetupExternalHook() override {
- JniLuaEnvironment::SetupExternalHook();
- conversation_iterator_.NewIterator("conversation", &conversation_.messages,
- state_);
- lua_setfield(state_, /*idx=*/-2, "conversation");
-
- PushAction(action_, entity_data_schema_, annotation_iterator_, this);
- lua_setfield(state_, /*idx=*/-2, "entity");
- }
-
- const Conversation& conversation_;
- const ActionSuggestion& action_;
- const AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_;
- const ConversationIterator conversation_iterator_;
- const reflection::Schema* entity_data_schema_;
-};
-
-} // namespace
-
-std::unique_ptr<IntentGenerator> IntentGenerator::Create(
- const IntentFactoryModel* options, const ResourcePool* resources,
- const std::shared_ptr<JniCache>& jni_cache) {
- std::unique_ptr<IntentGenerator> intent_generator(
- new IntentGenerator(options, resources, jni_cache));
-
- if (options == nullptr || options->generator() == nullptr) {
- TC3_LOG(ERROR) << "No intent generator options.";
- return nullptr;
- }
-
- std::unique_ptr<ZlibDecompressor> zlib_decompressor =
- ZlibDecompressor::Instance();
- if (!zlib_decompressor) {
- TC3_LOG(ERROR) << "Cannot initialize decompressor.";
- return nullptr;
- }
-
- for (const IntentFactoryModel_::IntentGenerator* generator :
- *options->generator()) {
- std::string lua_template_generator;
- if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
- generator->lua_template_generator(),
- generator->compressed_lua_template_generator(),
- &lua_template_generator)) {
- TC3_LOG(ERROR) << "Could not decompress generator template.";
- return nullptr;
- }
-
- std::string lua_code = lua_template_generator;
- if (options->precompile_generators()) {
- if (!Compile(lua_template_generator, &lua_code)) {
- TC3_LOG(ERROR) << "Could not precompile generator template.";
- return nullptr;
- }
- }
-
- intent_generator->generators_[generator->type()->str()] = lua_code;
- }
-
- return intent_generator;
-}
-
-std::vector<Locale> IntentGenerator::ParseDeviceLocales(
- const jstring device_locales) const {
- if (device_locales == nullptr) {
- TC3_LOG(ERROR) << "No locales provided.";
- return {};
- }
- ScopedStringChars locales_str =
- GetScopedStringChars(jni_cache_->GetEnv(), device_locales);
- if (locales_str == nullptr) {
- TC3_LOG(ERROR) << "Cannot retrieve provided locales.";
- return {};
- }
- std::vector<Locale> locales;
- if (!ParseLocales(reinterpret_cast<const char*>(locales_str.get()),
- &locales)) {
- TC3_LOG(ERROR) << "Cannot parse locales.";
- return {};
- }
- return locales;
-}
-
-bool IntentGenerator::GenerateIntents(
- const jstring device_locales, const ClassificationResult& classification,
- const int64 reference_time_ms_utc, const std::string& text,
- const CodepointSpan selection_indices, const jobject context,
- const reflection::Schema* annotations_entity_data_schema,
- std::vector<RemoteActionTemplate>* remote_actions) const {
- if (options_ == nullptr) {
- return false;
- }
-
- // Retrieve generator for specified entity.
- auto it = generators_.find(classification.collection);
- if (it == generators_.end()) {
- return true;
- }
-
- const std::string entity_text =
- UTF8ToUnicodeText(text, /*do_copy=*/false)
- .UTF8Substring(selection_indices.first, selection_indices.second);
-
- std::unique_ptr<AnnotatorJniEnvironment> interpreter(
- new AnnotatorJniEnvironment(
- resources_, jni_cache_.get(), context,
- ParseDeviceLocales(device_locales), entity_text, classification,
- reference_time_ms_utc, annotations_entity_data_schema));
-
- if (!interpreter->Initialize()) {
- TC3_LOG(ERROR) << "Could not create Lua interpreter.";
- return false;
- }
-
- return interpreter->RunIntentGenerator(it->second, remote_actions);
-}
-
-bool IntentGenerator::GenerateIntents(
- const jstring device_locales, const ActionSuggestion& action,
- const Conversation& conversation, const jobject context,
- const reflection::Schema* annotations_entity_data_schema,
- const reflection::Schema* actions_entity_data_schema,
- std::vector<RemoteActionTemplate>* remote_actions) const {
- if (options_ == nullptr) {
- return false;
- }
-
- // Retrieve generator for specified action.
- auto it = generators_.find(action.type);
- if (it == generators_.end()) {
- return true;
- }
-
- std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
- new ActionsJniLuaEnvironment(
- resources_, jni_cache_.get(), context,
- ParseDeviceLocales(device_locales), conversation, action,
- actions_entity_data_schema, annotations_entity_data_schema));
-
- if (!interpreter->Initialize()) {
- TC3_LOG(ERROR) << "Could not create Lua interpreter.";
- return false;
- }
-
- return interpreter->RunIntentGenerator(it->second, remote_actions);
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/intents/intent-generator.h b/utils/intents/intent-generator.h
deleted file mode 100644
index 9177adb..0000000
--- a/utils/intents/intent-generator.h
+++ /dev/null
@@ -1,113 +0,0 @@
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
-#define LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
-
-#include <jni.h>
-#include <map>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "actions/types.h"
-#include "annotator/types.h"
-#include "utils/i18n/locale.h"
-#include "utils/intents/intent-config_generated.h"
-#include "utils/java/jni-cache.h"
-#include "utils/java/scoped_local_ref.h"
-#include "utils/optional.h"
-#include "utils/resources.h"
-#include "utils/resources_generated.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-// A template with parameters for an Android remote action.
-struct RemoteActionTemplate {
- // Title shown for the action (see: RemoteAction.getTitle).
- Optional<std::string> title_without_entity;
-
- // Title with entity for the action. It is not guaranteed that the client
- // will use this, so title should be always given and general enough.
- Optional<std::string> title_with_entity;
-
- // Description shown for the action (see: RemoteAction.getContentDescription).
- Optional<std::string> description;
-
- // Description shown for the action (see: RemoteAction.getContentDescription)
- // when app name is available. Caller is expected to replace the placeholder
- // by the name of the app that is going to handle the action.
- Optional<std::string> description_with_app_name;
-
- // The action to set on the Intent (see: Intent.setAction).
- Optional<std::string> action;
-
- // The data to set on the Intent (see: Intent.setData).
- Optional<std::string> data;
-
- // The type to set on the Intent (see: Intent.setType).
- Optional<std::string> type;
-
- // Flags for launching the Intent (see: Intent.setFlags).
- Optional<int> flags;
-
- // Categories to set on the Intent (see: Intent.addCategory).
- std::vector<std::string> category;
-
- // Explicit application package to set on the Intent (see: Intent.setPackage).
- Optional<std::string> package_name;
-
- // The list of all the extras to add to the Intent.
- std::map<std::string, Variant> extra;
-
- // Private request code ot use for the Intent.
- Optional<int> request_code;
-};
-
-// Helper class to generate Android intents for text classifier results.
-class IntentGenerator {
- public:
- static std::unique_ptr<IntentGenerator> Create(
- const IntentFactoryModel* options, const ResourcePool* resources,
- const std::shared_ptr<JniCache>& jni_cache);
-
- // Generates intents for a classification result.
- // Returns true, if the intent generator snippets could be successfully run,
- // returns false otherwise.
- bool GenerateIntents(const jstring device_locales,
- const ClassificationResult& classification,
- const int64 reference_time_ms_utc,
- const std::string& text,
- const CodepointSpan selection_indices,
- const jobject context,
- const reflection::Schema* annotations_entity_data_schema,
- std::vector<RemoteActionTemplate>* remote_actions) const;
-
- // Generates intents for an action suggestion.
- // Returns true, if the intent generator snippets could be successfully run,
- // returns false otherwise.
- bool GenerateIntents(const jstring device_locales,
- const ActionSuggestion& action,
- const Conversation& conversation, const jobject context,
- const reflection::Schema* annotations_entity_data_schema,
- const reflection::Schema* actions_entity_data_schema,
- std::vector<RemoteActionTemplate>* remote_actions) const;
-
- private:
- IntentGenerator(const IntentFactoryModel* options,
- const ResourcePool* resources,
- const std::shared_ptr<JniCache>& jni_cache)
- : options_(options),
- resources_(Resources(resources)),
- jni_cache_(jni_cache) {}
-
- std::vector<Locale> ParseDeviceLocales(const jstring device_locales) const;
-
- const IntentFactoryModel* options_;
- const Resources resources_;
- std::shared_ptr<JniCache> jni_cache_;
- std::map<std::string, std::string> generators_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
diff --git a/utils/intents/jni.cc b/utils/intents/jni.cc
deleted file mode 100644
index d6274b1..0000000
--- a/utils/intents/jni.cc
+++ /dev/null
@@ -1,227 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/intents/jni.h"
-#include <memory>
-#include "utils/intents/intent-generator.h"
-#include "utils/java/scoped_local_ref.h"
-
-namespace libtextclassifier3 {
-
-// The macros below are intended to reduce the boilerplate and avoid
-// easily introduced copy/paste errors.
-#define TC3_CHECK_JNI_PTR(PTR) TC3_CHECK((PTR) != nullptr)
-#define TC3_GET_CLASS(FIELD, NAME) \
- handler->FIELD = MakeGlobalRef(env->FindClass(NAME), env, jni_cache->jvm); \
- TC3_CHECK_JNI_PTR(handler->FIELD) << "Error finding class: " << NAME;
-#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- handler->FIELD = env->GetMethodID(handler->CLASS.get(), NAME, SIGNATURE); \
- TC3_CHECK(handler->FIELD) << "Error finding method: " << NAME;
-
-std::unique_ptr<RemoteActionTemplatesHandler>
-RemoteActionTemplatesHandler::Create(
- const std::shared_ptr<JniCache>& jni_cache) {
- JNIEnv* env = jni_cache->GetEnv();
- if (env == nullptr) {
- return nullptr;
- }
-
- std::unique_ptr<RemoteActionTemplatesHandler> handler(
- new RemoteActionTemplatesHandler(jni_cache));
-
- TC3_GET_CLASS(integer_class_, "java/lang/Integer");
- TC3_GET_METHOD(integer_class_, integer_init_, "<init>", "(I)V");
-
- TC3_GET_CLASS(remote_action_template_class_,
- TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR);
- TC3_GET_METHOD(
- remote_action_template_class_, remote_action_template_init_, "<init>",
- "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/"
- "String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/"
- "Integer;[Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
- TC3_NAMED_VARIANT_CLASS_NAME_STR ";Ljava/lang/Integer;)V");
-
- TC3_GET_CLASS(named_variant_class_,
- TC3_PACKAGE_PATH TC3_NAMED_VARIANT_CLASS_NAME_STR);
-
- TC3_GET_METHOD(named_variant_class_, named_variant_from_int_, "<init>",
- "(Ljava/lang/String;I)V");
- TC3_GET_METHOD(named_variant_class_, named_variant_from_long_, "<init>",
- "(Ljava/lang/String;J)V");
- TC3_GET_METHOD(named_variant_class_, named_variant_from_float_, "<init>",
- "(Ljava/lang/String;F)V");
- TC3_GET_METHOD(named_variant_class_, named_variant_from_double_, "<init>",
- "(Ljava/lang/String;D)V");
- TC3_GET_METHOD(named_variant_class_, named_variant_from_bool_, "<init>",
- "(Ljava/lang/String;Z)V");
- TC3_GET_METHOD(named_variant_class_, named_variant_from_string_, "<init>",
- "(Ljava/lang/String;Ljava/lang/String;)V");
-
- return handler;
-}
-
-jstring RemoteActionTemplatesHandler::AsUTF8String(
- const Optional<std::string>& optional) const {
- if (!optional.has_value()) {
- return nullptr;
- }
- return jni_cache_->ConvertToJavaString(optional.value()).release();
-}
-
-jobject RemoteActionTemplatesHandler::AsInteger(
- const Optional<int>& optional) const {
- return (optional.has_value()
- ? jni_cache_->GetEnv()->NewObject(integer_class_.get(),
- integer_init_, optional.value())
- : nullptr);
-}
-
-jobjectArray RemoteActionTemplatesHandler::AsStringArray(
- const std::vector<std::string>& values) const {
- if (values.empty()) {
- return nullptr;
- }
- jobjectArray result = jni_cache_->GetEnv()->NewObjectArray(
- values.size(), jni_cache_->string_class.get(), nullptr);
- if (result == nullptr) {
- return nullptr;
- }
- for (int k = 0; k < values.size(); k++) {
- ScopedLocalRef<jstring> value_str =
- jni_cache_->ConvertToJavaString(values[k]);
- jni_cache_->GetEnv()->SetObjectArrayElement(result, k, value_str.get());
- }
- return result;
-}
-
-jobject RemoteActionTemplatesHandler::AsNamedVariant(
- const std::string& name_str, const Variant& value) const {
- ScopedLocalRef<jstring> name = jni_cache_->ConvertToJavaString(name_str);
- if (name == nullptr) {
- return nullptr;
- }
- switch (value.GetType()) {
- case Variant::TYPE_INT_VALUE:
- return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
- named_variant_from_int_,
- name.get(), value.IntValue());
- case Variant::TYPE_INT64_VALUE:
- return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
- named_variant_from_long_,
- name.get(), value.Int64Value());
- case Variant::TYPE_FLOAT_VALUE:
- return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
- named_variant_from_float_,
- name.get(), value.FloatValue());
- case Variant::TYPE_DOUBLE_VALUE:
- return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
- named_variant_from_double_,
- name.get(), value.DoubleValue());
- case Variant::TYPE_BOOL_VALUE:
- return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
- named_variant_from_bool_,
- name.get(), value.BoolValue());
- case Variant::TYPE_STRING_VALUE: {
- ScopedLocalRef<jstring> value_jstring =
- jni_cache_->ConvertToJavaString(value.StringValue());
- if (value_jstring == nullptr) {
- return nullptr;
- }
- return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
- named_variant_from_string_,
- name.get(), value_jstring.get());
- }
- default:
- return nullptr;
- }
-}
-
-jobjectArray RemoteActionTemplatesHandler::AsNamedVariantArray(
- const std::map<std::string, Variant>& values) const {
- if (values.empty()) {
- return nullptr;
- }
- jobjectArray result = jni_cache_->GetEnv()->NewObjectArray(
- values.size(), named_variant_class_.get(), nullptr);
- int element_index = 0;
- for (auto key_value_pair : values) {
- if (!key_value_pair.second.HasValue()) {
- element_index++;
- continue;
- }
- ScopedLocalRef<jobject> named_extra(
- AsNamedVariant(key_value_pair.first, key_value_pair.second),
- jni_cache_->GetEnv());
- if (named_extra == nullptr) {
- return nullptr;
- }
- jni_cache_->GetEnv()->SetObjectArrayElement(result, element_index,
- named_extra.get());
- element_index++;
- }
- return result;
-}
-
-jobjectArray RemoteActionTemplatesHandler::RemoteActionTemplatesToJObjectArray(
- const std::vector<RemoteActionTemplate>& remote_actions) const {
- const jobjectArray results = jni_cache_->GetEnv()->NewObjectArray(
- remote_actions.size(), remote_action_template_class_.get(), nullptr);
- if (results == nullptr) {
- return nullptr;
- }
- for (int i = 0; i < remote_actions.size(); i++) {
- const RemoteActionTemplate& remote_action = remote_actions[i];
- const jstring title_without_entity =
- AsUTF8String(remote_action.title_without_entity);
- const jstring title_with_entity =
- AsUTF8String(remote_action.title_with_entity);
- const jstring description = AsUTF8String(remote_action.description);
- const jstring description_with_app_name =
- AsUTF8String(remote_action.description_with_app_name);
- const jstring action = AsUTF8String(remote_action.action);
- const jstring data = AsUTF8String(remote_action.data);
- const jstring type = AsUTF8String(remote_action.type);
- const jobject flags = AsInteger(remote_action.flags);
- const jobjectArray category = AsStringArray(remote_action.category);
- const jstring package = AsUTF8String(remote_action.package_name);
- const jobjectArray extra = AsNamedVariantArray(remote_action.extra);
- const jobject request_code = AsInteger(remote_action.request_code);
- ScopedLocalRef<jobject> result(
- jni_cache_->GetEnv()->NewObject(
- remote_action_template_class_.get(), remote_action_template_init_,
- title_without_entity, title_with_entity, description,
- description_with_app_name, action, data, type, flags, category,
- package, extra, request_code),
- jni_cache_->GetEnv());
- if (result == nullptr) {
- return nullptr;
- }
- jni_cache_->GetEnv()->SetObjectArrayElement(results, i, result.get());
- }
- return results;
-}
-
-jobject RemoteActionTemplatesHandler::EntityDataAsNamedVariantArray(
- const reflection::Schema* entity_data_schema,
- const std::string& serialized_entity_data) const {
- ReflectiveFlatbufferBuilder entity_data_builder(entity_data_schema);
- std::unique_ptr<ReflectiveFlatbuffer> buffer = entity_data_builder.NewRoot();
- buffer->MergeFromSerializedFlatbuffer(serialized_entity_data);
- std::map<std::string, Variant> entity_data_map = buffer->AsFlatMap();
- return AsNamedVariantArray(entity_data_map);
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/intents/jni.h b/utils/intents/jni.h
deleted file mode 100644
index 37952a2..0000000
--- a/utils/intents/jni.h
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_H_
-#define LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_H_
-
-#include <jni.h>
-#include <map>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "utils/flatbuffers.h"
-#include "utils/intents/intent-generator.h"
-#include "utils/java/jni-base.h"
-#include "utils/java/jni-cache.h"
-#include "utils/optional.h"
-#include "utils/variant.h"
-
-#ifndef TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME
-#define TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME RemoteActionTemplate
-#endif
-
-#define TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR \
- TC3_ADD_QUOTES(TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME)
-
-#ifndef TC3_NAMED_VARIANT_CLASS_NAME
-#define TC3_NAMED_VARIANT_CLASS_NAME NamedVariant
-#endif
-
-#define TC3_NAMED_VARIANT_CLASS_NAME_STR \
- TC3_ADD_QUOTES(TC3_NAMED_VARIANT_CLASS_NAME)
-
-namespace libtextclassifier3 {
-
-// A helper class to create RemoteActionTemplate object from model results.
-class RemoteActionTemplatesHandler {
- public:
- static std::unique_ptr<RemoteActionTemplatesHandler> Create(
- const std::shared_ptr<JniCache>& jni_cache);
-
- jstring AsUTF8String(const Optional<std::string>& optional) const;
- jobject AsInteger(const Optional<int>& optional) const;
- jobjectArray AsStringArray(const std::vector<std::string>& values) const;
- jobject AsNamedVariant(const std::string& name, const Variant& value) const;
- jobjectArray AsNamedVariantArray(
- const std::map<std::string, Variant>& values) const;
-
- jobjectArray RemoteActionTemplatesToJObjectArray(
- const std::vector<RemoteActionTemplate>& remote_actions) const;
-
- jobject EntityDataAsNamedVariantArray(
- const reflection::Schema* entity_data_schema,
- const std::string& serialized_entity_data) const;
-
- private:
- explicit RemoteActionTemplatesHandler(
- const std::shared_ptr<JniCache>& jni_cache)
- : jni_cache_(jni_cache),
- integer_class_(nullptr, jni_cache->jvm),
- remote_action_template_class_(nullptr, jni_cache->jvm),
- named_variant_class_(nullptr, jni_cache->jvm) {}
-
- std::shared_ptr<JniCache> jni_cache_;
-
- // java.lang.Integer
- ScopedGlobalRef<jclass> integer_class_;
- jmethodID integer_init_ = nullptr;
-
- // RemoteActionTemplate
- ScopedGlobalRef<jclass> remote_action_template_class_;
- jmethodID remote_action_template_init_ = nullptr;
-
- // NamedVariant
- ScopedGlobalRef<jclass> named_variant_class_;
- jmethodID named_variant_from_int_ = nullptr;
- jmethodID named_variant_from_long_ = nullptr;
- jmethodID named_variant_from_float_ = nullptr;
- jmethodID named_variant_from_double_ = nullptr;
- jmethodID named_variant_from_bool_ = nullptr;
- jmethodID named_variant_from_string_ = nullptr;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_H_
diff --git a/utils/intents/zlib-utils.cc b/utils/intents/zlib-utils.cc
deleted file mode 100644
index 9f29b46..0000000
--- a/utils/intents/zlib-utils.cc
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/intents/zlib-utils.h"
-
-#include <memory>
-
-#include "utils/zlib/buffer_generated.h"
-#include "utils/zlib/zlib.h"
-
-namespace libtextclassifier3 {
-
-bool CompressIntentModel(IntentFactoryModelT* intent_model) {
- std::unique_ptr<ZlibCompressor> intent_zlib_compressor =
- ZlibCompressor::Instance();
- for (auto& generator : intent_model->generator) {
- generator->compressed_lua_template_generator.reset(new CompressedBufferT);
- intent_zlib_compressor->Compress(
- std::string(reinterpret_cast<const char*>(
- generator->lua_template_generator.data()),
- generator->lua_template_generator.size()),
- generator->compressed_lua_template_generator.get());
- generator->lua_template_generator.clear();
- }
- return true;
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/intents/zlib-utils.h b/utils/intents/zlib-utils.h
deleted file mode 100644
index afefa3d..0000000
--- a/utils/intents/zlib-utils.h
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_ZLIB_UTILS_H_
-#define LIBTEXTCLASSIFIER_UTILS_INTENTS_ZLIB_UTILS_H_
-
-#include "utils/intents/intent-config_generated.h"
-
-namespace libtextclassifier3 {
-
-bool CompressIntentModel(IntentFactoryModelT* intent_model);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_ZLIB_UTILS_H_
diff --git a/utils/java/jni-base.cc b/utils/java/jni-base.cc
deleted file mode 100644
index 4483b79..0000000
--- a/utils/java/jni-base.cc
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/java/jni-base.h"
-
-#include <jni.h>
-#include <type_traits>
-#include <vector>
-
-#include "utils/base/integral_types.h"
-#include "utils/java/scoped_local_ref.h"
-#include "utils/java/string_utils.h"
-#include "utils/memory/mmap.h"
-
-using libtextclassifier3::JStringToUtf8String;
-using libtextclassifier3::ScopedLocalRef;
-
-namespace libtextclassifier3 {
-
-std::string ToStlString(JNIEnv* env, const jstring& str) {
- std::string result;
- JStringToUtf8String(env, str, &result);
- return result;
-}
-
-jint GetFdFromFileDescriptor(JNIEnv* env, jobject fd) {
- ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"),
- env);
- if (fd_class == nullptr) {
- TC3_LOG(ERROR) << "Couldn't find FileDescriptor.";
- return reinterpret_cast<jlong>(nullptr);
- }
- jfieldID fd_class_descriptor =
- env->GetFieldID(fd_class.get(), "descriptor", "I");
- if (fd_class_descriptor == nullptr) {
- env->ExceptionClear();
- fd_class_descriptor = env->GetFieldID(fd_class.get(), "fd", "I");
- }
- if (fd_class_descriptor == nullptr) {
- TC3_LOG(ERROR) << "Couldn't find descriptor.";
- return reinterpret_cast<jlong>(nullptr);
- }
- return env->GetIntField(fd, fd_class_descriptor);
-}
-
-jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) {
- ScopedLocalRef<jclass> afd_class(
- env->FindClass("android/content/res/AssetFileDescriptor"), env);
- if (afd_class == nullptr) {
- TC3_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
- return reinterpret_cast<jlong>(nullptr);
- }
- jmethodID afd_class_getFileDescriptor = env->GetMethodID(
- afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
- if (afd_class_getFileDescriptor == nullptr) {
- TC3_LOG(ERROR) << "Couldn't find getFileDescriptor.";
- return reinterpret_cast<jlong>(nullptr);
- }
- jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor);
- return GetFdFromFileDescriptor(env, bundle_jfd);
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/java/jni-base.h b/utils/java/jni-base.h
deleted file mode 100644
index 23658a3..0000000
--- a/utils/java/jni-base.h
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_
-#define LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_
-
-#include <jni.h>
-#include <string>
-
-// When we use a macro as an argument for a macro, an additional level of
-// indirection is needed, if the macro argument is used with # or ##.
-#define TC3_ADD_QUOTES_HELPER(TOKEN) #TOKEN
-#define TC3_ADD_QUOTES(TOKEN) TC3_ADD_QUOTES_HELPER(TOKEN)
-
-#ifndef TC3_PACKAGE_NAME
-#define TC3_PACKAGE_NAME com_google_android_textclassifier
-#endif
-
-#ifndef TC3_PACKAGE_PATH
-#define TC3_PACKAGE_PATH \
- "com/google/android/textclassifier/"
-#endif
-
-#define TC3_JNI_METHOD_NAME_INTERNAL(package_name, class_name, method_name) \
- Java_##package_name##_##class_name##_##method_name
-
-#define TC3_JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, \
- method_name) \
- JNIEXPORT return_type JNICALL TC3_JNI_METHOD_NAME_INTERNAL( \
- package_name, class_name, method_name)
-
-// The indirection is needed to correctly expand the TC3_PACKAGE_NAME macro.
-// See the explanation near TC3_ADD_QUOTES macro.
-#define TC3_JNI_METHOD2(return_type, package_name, class_name, method_name) \
- TC3_JNI_METHOD_PRIMITIVE(return_type, package_name, class_name, method_name)
-
-#define TC3_JNI_METHOD(return_type, class_name, method_name) \
- TC3_JNI_METHOD2(return_type, TC3_PACKAGE_NAME, class_name, method_name)
-
-#define TC3_JNI_METHOD_NAME2(package_name, class_name, method_name) \
- TC3_JNI_METHOD_NAME_INTERNAL(package_name, class_name, method_name)
-
-#define TC3_JNI_METHOD_NAME(class_name, method_name) \
- TC3_JNI_METHOD_NAME2(TC3_PACKAGE_NAME, class_name, method_name)
-
-namespace libtextclassifier3 {
-
-template <typename T, typename F>
-std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object,
- jclass class_object, F function,
- const std::string& method_name,
- const std::string& return_java_type) {
- const jmethodID method = env->GetMethodID(class_object, method_name.c_str(),
- ("()" + return_java_type).c_str());
- if (!method) {
- return std::make_pair(false, T());
- }
- return std::make_pair(true, (env->*function)(object, method));
-}
-
-std::string ToStlString(JNIEnv* env, const jstring& str);
-
-// Get system-level file descriptor from AssetFileDescriptor.
-jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd);
-
-// Get system-level file descriptor from FileDescriptor.
-jint GetFdFromFileDescriptor(JNIEnv* env, jobject fd);
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_
diff --git a/utils/java/jni-cache.cc b/utils/java/jni-cache.cc
deleted file mode 100644
index 8c2f00a..0000000
--- a/utils/java/jni-cache.cc
+++ /dev/null
@@ -1,304 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/java/jni-cache.h"
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-JniCache::JniCache(JavaVM* jvm)
- : jvm(jvm),
- string_class(nullptr, jvm),
- string_utf8(nullptr, jvm),
- pattern_class(nullptr, jvm),
- matcher_class(nullptr, jvm),
- locale_class(nullptr, jvm),
- locale_us(nullptr, jvm),
- breakiterator_class(nullptr, jvm),
- integer_class(nullptr, jvm),
- calendar_class(nullptr, jvm),
- timezone_class(nullptr, jvm),
- urlencoder_class(nullptr, jvm)
-#ifdef __ANDROID__
- ,
- context_class(nullptr, jvm),
- uri_class(nullptr, jvm),
- usermanager_class(nullptr, jvm),
- bundle_class(nullptr, jvm),
- resources_class(nullptr, jvm)
-#endif
-{
-}
-
-// The macros below are intended to reduce the boilerplate in Create and avoid
-// easily introduced copy/paste errors.
-#define TC3_CHECK_JNI_PTR(PTR) TC3_CHECK((PTR) != nullptr)
-#define TC3_CHECK_JNI_RESULT(RESULT) TC3_CHECK(RESULT)
-
-#define TC3_GET_CLASS(FIELD, NAME) \
- result->FIELD##_class = MakeGlobalRef(env->FindClass(NAME), env, jvm); \
- TC3_CHECK_JNI_PTR(result->FIELD##_class) << "Error finding class: " << NAME;
-
-#define TC3_GET_OPTIONAL_CLASS(FIELD, NAME) \
- { \
- jclass clazz = env->FindClass(NAME); \
- if (clazz != nullptr) { \
- result->FIELD##_class = MakeGlobalRef(clazz, env, jvm); \
- } \
- env->ExceptionClear(); \
- }
-
-#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- result->CLASS##_##FIELD = \
- env->GetMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \
- << "Error finding method: " << NAME;
-
-#define TC3_GET_OPTIONAL_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- if (result->CLASS##_class != nullptr) { \
- result->CLASS##_##FIELD = \
- env->GetMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- env->ExceptionClear(); \
- }
-
-#define TC3_GET_OPTIONAL_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- if (result->CLASS##_class != nullptr) { \
- result->CLASS##_##FIELD = \
- env->GetStaticMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- env->ExceptionClear(); \
- }
-
-#define TC3_GET_STATIC_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
- result->CLASS##_##FIELD = \
- env->GetStaticMethodID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \
- << "Error finding method: " << NAME;
-
-#define TC3_GET_STATIC_OBJECT_FIELD(CLASS, FIELD, NAME, SIGNATURE) \
- const jfieldID CLASS##_##FIELD##_field = \
- env->GetStaticFieldID(result->CLASS##_class.get(), NAME, SIGNATURE); \
- TC3_CHECK_JNI_RESULT(CLASS##_##FIELD##_field) \
- << "Error finding field id: " << NAME; \
- result->CLASS##_##FIELD = \
- MakeGlobalRef(env->GetStaticObjectField(result->CLASS##_class.get(), \
- CLASS##_##FIELD##_field), \
- env, jvm); \
- TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \
- << "Error finding field: " << NAME;
-
-#define TC3_GET_STATIC_INT_FIELD(CLASS, FIELD, NAME) \
- const jfieldID CLASS##_##FIELD##_field = \
- env->GetStaticFieldID(result->CLASS##_class.get(), NAME, "I"); \
- TC3_CHECK_JNI_RESULT(CLASS##_##FIELD##_field) \
- << "Error finding field id: " << NAME; \
- result->CLASS##_##FIELD = env->GetStaticIntField( \
- result->CLASS##_class.get(), CLASS##_##FIELD##_field); \
- TC3_CHECK_JNI_RESULT(result->CLASS##_##FIELD) \
- << "Error finding field: " << NAME;
-
-std::unique_ptr<JniCache> JniCache::Create(JNIEnv* env) {
- if (env == nullptr) {
- return nullptr;
- }
- JavaVM* jvm = nullptr;
- if (JNI_OK != env->GetJavaVM(&jvm) || jvm == nullptr) {
- return nullptr;
- }
- std::unique_ptr<JniCache> result(new JniCache(jvm));
-
- // String
- TC3_GET_CLASS(string, "java/lang/String");
- TC3_GET_METHOD(string, init_bytes_charset, "<init>",
- "([BLjava/lang/String;)V");
- TC3_GET_METHOD(string, code_point_count, "codePointCount", "(II)I");
- TC3_GET_METHOD(string, length, "length", "()I");
- result->string_utf8 = MakeGlobalRef(env->NewStringUTF("UTF-8"), env, jvm);
- TC3_CHECK_JNI_PTR(result->string_utf8);
-
- // Pattern
- TC3_GET_CLASS(pattern, "java/util/regex/Pattern");
- TC3_GET_STATIC_METHOD(pattern, compile, "compile",
- "(Ljava/lang/String;)Ljava/util/regex/Pattern;");
- TC3_GET_METHOD(pattern, matcher, "matcher",
- "(Ljava/lang/CharSequence;)Ljava/util/regex/Matcher;");
-
- // Matcher
- TC3_GET_CLASS(matcher, "java/util/regex/Matcher");
- TC3_GET_METHOD(matcher, matches, "matches", "()Z");
- TC3_GET_METHOD(matcher, find, "find", "()Z");
- TC3_GET_METHOD(matcher, reset, "reset", "()Ljava/util/regex/Matcher;");
- TC3_GET_METHOD(matcher, start_idx, "start", "(I)I");
- TC3_GET_METHOD(matcher, end_idx, "end", "(I)I");
- TC3_GET_METHOD(matcher, group, "group", "()Ljava/lang/String;");
- TC3_GET_METHOD(matcher, group_idx, "group", "(I)Ljava/lang/String;");
-
- // Locale
- TC3_GET_CLASS(locale, "java/util/Locale");
- TC3_GET_STATIC_OBJECT_FIELD(locale, us, "US", "Ljava/util/Locale;");
- TC3_GET_METHOD(locale, init_string, "<init>", "(Ljava/lang/String;)V");
- TC3_GET_OPTIONAL_STATIC_METHOD(locale, for_language_tag, "forLanguageTag",
- "(Ljava/lang/String;)Ljava/util/Locale;");
-
- // BreakIterator
- TC3_GET_CLASS(breakiterator, "java/text/BreakIterator");
- TC3_GET_STATIC_METHOD(breakiterator, getwordinstance, "getWordInstance",
- "(Ljava/util/Locale;)Ljava/text/BreakIterator;");
- TC3_GET_METHOD(breakiterator, settext, "setText", "(Ljava/lang/String;)V");
- TC3_GET_METHOD(breakiterator, next, "next", "()I");
-
- // Integer
- TC3_GET_CLASS(integer, "java/lang/Integer");
- TC3_GET_STATIC_METHOD(integer, parse_int, "parseInt",
- "(Ljava/lang/String;)I");
-
- // Calendar.
- TC3_GET_CLASS(calendar, "java/util/Calendar");
- TC3_GET_STATIC_METHOD(
- calendar, get_instance, "getInstance",
- "(Ljava/util/TimeZone;Ljava/util/Locale;)Ljava/util/Calendar;");
- TC3_GET_METHOD(calendar, get_first_day_of_week, "getFirstDayOfWeek", "()I");
- TC3_GET_METHOD(calendar, get_time_in_millis, "getTimeInMillis", "()J");
- TC3_GET_METHOD(calendar, set_time_in_millis, "setTimeInMillis", "(J)V");
- TC3_GET_METHOD(calendar, add, "add", "(II)V");
- TC3_GET_METHOD(calendar, get, "get", "(I)I");
- TC3_GET_METHOD(calendar, set, "set", "(II)V");
- TC3_GET_STATIC_INT_FIELD(calendar, zone_offset, "ZONE_OFFSET");
- TC3_GET_STATIC_INT_FIELD(calendar, dst_offset, "DST_OFFSET");
- TC3_GET_STATIC_INT_FIELD(calendar, year, "YEAR");
- TC3_GET_STATIC_INT_FIELD(calendar, month, "MONTH");
- TC3_GET_STATIC_INT_FIELD(calendar, day_of_year, "DAY_OF_YEAR");
- TC3_GET_STATIC_INT_FIELD(calendar, day_of_month, "DAY_OF_MONTH");
- TC3_GET_STATIC_INT_FIELD(calendar, day_of_week, "DAY_OF_WEEK");
- TC3_GET_STATIC_INT_FIELD(calendar, hour_of_day, "HOUR_OF_DAY");
- TC3_GET_STATIC_INT_FIELD(calendar, minute, "MINUTE");
- TC3_GET_STATIC_INT_FIELD(calendar, second, "SECOND");
- TC3_GET_STATIC_INT_FIELD(calendar, millisecond, "MILLISECOND");
- TC3_GET_STATIC_INT_FIELD(calendar, sunday, "SUNDAY");
- TC3_GET_STATIC_INT_FIELD(calendar, monday, "MONDAY");
- TC3_GET_STATIC_INT_FIELD(calendar, tuesday, "TUESDAY");
- TC3_GET_STATIC_INT_FIELD(calendar, wednesday, "WEDNESDAY");
- TC3_GET_STATIC_INT_FIELD(calendar, thursday, "THURSDAY");
- TC3_GET_STATIC_INT_FIELD(calendar, friday, "FRIDAY");
- TC3_GET_STATIC_INT_FIELD(calendar, saturday, "SATURDAY");
-
- // TimeZone.
- TC3_GET_CLASS(timezone, "java/util/TimeZone");
- TC3_GET_STATIC_METHOD(timezone, get_timezone, "getTimeZone",
- "(Ljava/lang/String;)Ljava/util/TimeZone;");
-
- // URLEncoder.
- TC3_GET_CLASS(urlencoder, "java/net/URLEncoder");
- TC3_GET_STATIC_METHOD(
- urlencoder, encode, "encode",
- "(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;");
-
-#ifdef __ANDROID__
- // Context.
- TC3_GET_CLASS(context, "android/content/Context");
- TC3_GET_METHOD(context, get_package_name, "getPackageName",
- "()Ljava/lang/String;");
- TC3_GET_METHOD(context, get_system_service, "getSystemService",
- "(Ljava/lang/String;)Ljava/lang/Object;");
-
- // Uri.
- TC3_GET_CLASS(uri, "android/net/Uri");
- TC3_GET_STATIC_METHOD(uri, parse, "parse",
- "(Ljava/lang/String;)Landroid/net/Uri;");
- TC3_GET_METHOD(uri, get_scheme, "getScheme", "()Ljava/lang/String;");
- TC3_GET_METHOD(uri, get_host, "getHost", "()Ljava/lang/String;");
-
- // UserManager.
- TC3_GET_OPTIONAL_CLASS(usermanager, "android/os/UserManager");
- TC3_GET_OPTIONAL_METHOD(usermanager, get_user_restrictions,
- "getUserRestrictions", "()Landroid/os/Bundle;");
-
- // Bundle.
- TC3_GET_CLASS(bundle, "android/os/Bundle");
- TC3_GET_METHOD(bundle, get_boolean, "getBoolean", "(Ljava/lang/String;)Z");
-
- // String resources.
- TC3_GET_CLASS(resources, "android/content/res/Resources");
- TC3_GET_STATIC_METHOD(resources, get_system, "getSystem",
- "()Landroid/content/res/Resources;");
- TC3_GET_METHOD(resources, get_identifier, "getIdentifier",
- "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)I");
- TC3_GET_METHOD(resources, get_string, "getString", "(I)Ljava/lang/String;");
-#endif
-
- return result;
-}
-
-#undef TC3_GET_STATIC_INT_FIELD
-#undef TC3_GET_STATIC_OBJECT_FIELD
-#undef TC3_GET_STATIC_METHOD
-#undef TC3_GET_METHOD
-#undef TC3_GET_CLASS
-#undef TC3_CHECK_JNI_PTR
-
-JNIEnv* JniCache::GetEnv() const {
- void* env;
- if (JNI_OK == jvm->GetEnv(&env, JNI_VERSION_1_4)) {
- return reinterpret_cast<JNIEnv*>(env);
- } else {
- TC3_LOG(ERROR) << "JavaICU UniLib used on unattached thread";
- return nullptr;
- }
-}
-
-bool JniCache::ExceptionCheckAndClear() const {
- JNIEnv* env = GetEnv();
- TC3_CHECK(env != nullptr);
- const bool result = env->ExceptionCheck();
- if (result) {
- env->ExceptionDescribe();
- env->ExceptionClear();
- }
- return result;
-}
-
-ScopedLocalRef<jstring> JniCache::ConvertToJavaString(
- const char* utf8_text, const int utf8_text_size_bytes) const {
- // Create java byte array.
- JNIEnv* jenv = GetEnv();
- const ScopedLocalRef<jbyteArray> text_java_utf8(
- jenv->NewByteArray(utf8_text_size_bytes), jenv);
- if (!text_java_utf8) {
- return nullptr;
- }
-
- jenv->SetByteArrayRegion(text_java_utf8.get(), 0, utf8_text_size_bytes,
- reinterpret_cast<const jbyte*>(utf8_text));
-
- // Create the string with a UTF-8 charset.
- return ScopedLocalRef<jstring>(
- reinterpret_cast<jstring>(
- jenv->NewObject(string_class.get(), string_init_bytes_charset,
- text_java_utf8.get(), string_utf8.get())),
- jenv);
-}
-
-ScopedLocalRef<jstring> JniCache::ConvertToJavaString(
- StringPiece utf8_text) const {
- return ConvertToJavaString(utf8_text.data(), utf8_text.size());
-}
-
-ScopedLocalRef<jstring> JniCache::ConvertToJavaString(
- const UnicodeText& text) const {
- return ConvertToJavaString(text.data(), text.size_bytes());
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/java/jni-cache.h b/utils/java/jni-cache.h
deleted file mode 100644
index 609ddb1..0000000
--- a/utils/java/jni-cache.h
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_CACHE_H_
-#define LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_CACHE_H_
-
-#include <jni.h>
-#include "utils/java/scoped_global_ref.h"
-#include "utils/java/scoped_local_ref.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unicodetext.h"
-
-namespace libtextclassifier3 {
-
-// A helper class to cache class and method pointers for calls from JNI to Java.
-// (for implementations such as Java ICU that need to make calls from C++ to
-// Java)
-struct JniCache {
- static std::unique_ptr<JniCache> Create(JNIEnv* env);
-
- JNIEnv* GetEnv() const;
- bool ExceptionCheckAndClear() const;
-
- JavaVM* jvm = nullptr;
-
- // java.lang.String
- ScopedGlobalRef<jclass> string_class;
- jmethodID string_init_bytes_charset = nullptr;
- jmethodID string_code_point_count = nullptr;
- jmethodID string_length = nullptr;
- ScopedGlobalRef<jstring> string_utf8;
-
- // java.util.regex.Pattern
- ScopedGlobalRef<jclass> pattern_class;
- jmethodID pattern_compile = nullptr;
- jmethodID pattern_matcher = nullptr;
-
- // java.util.regex.Matcher
- ScopedGlobalRef<jclass> matcher_class;
- jmethodID matcher_matches = nullptr;
- jmethodID matcher_find = nullptr;
- jmethodID matcher_reset = nullptr;
- jmethodID matcher_start_idx = nullptr;
- jmethodID matcher_end_idx = nullptr;
- jmethodID matcher_group = nullptr;
- jmethodID matcher_group_idx = nullptr;
-
- // java.util.Locale
- ScopedGlobalRef<jclass> locale_class;
- ScopedGlobalRef<jobject> locale_us;
- jmethodID locale_init_string = nullptr;
- jmethodID locale_for_language_tag = nullptr;
-
- // java.text.BreakIterator
- ScopedGlobalRef<jclass> breakiterator_class;
- jmethodID breakiterator_getwordinstance = nullptr;
- jmethodID breakiterator_settext = nullptr;
- jmethodID breakiterator_next = nullptr;
-
- // java.lang.Integer
- ScopedGlobalRef<jclass> integer_class;
- jmethodID integer_parse_int = nullptr;
-
- // java.util.Calendar
- ScopedGlobalRef<jclass> calendar_class;
- jmethodID calendar_get_instance = nullptr;
- jmethodID calendar_get_first_day_of_week = nullptr;
- jmethodID calendar_get_time_in_millis = nullptr;
- jmethodID calendar_set_time_in_millis = nullptr;
- jmethodID calendar_add = nullptr;
- jmethodID calendar_get = nullptr;
- jmethodID calendar_set = nullptr;
- jint calendar_zone_offset;
- jint calendar_dst_offset;
- jint calendar_year;
- jint calendar_month;
- jint calendar_day_of_year;
- jint calendar_day_of_month;
- jint calendar_day_of_week;
- jint calendar_hour_of_day;
- jint calendar_minute;
- jint calendar_second;
- jint calendar_millisecond;
- jint calendar_sunday;
- jint calendar_monday;
- jint calendar_tuesday;
- jint calendar_wednesday;
- jint calendar_thursday;
- jint calendar_friday;
- jint calendar_saturday;
-
- // java.util.TimeZone
- ScopedGlobalRef<jclass> timezone_class;
- jmethodID timezone_get_timezone = nullptr;
-
- // java.net.URLEncoder
- ScopedGlobalRef<jclass> urlencoder_class;
- jmethodID urlencoder_encode = nullptr;
-
- // android.content.Context
- ScopedGlobalRef<jclass> context_class;
- jmethodID context_get_package_name = nullptr;
- jmethodID context_get_system_service = nullptr;
-
- // android.net.Uri
- ScopedGlobalRef<jclass> uri_class;
- jmethodID uri_parse = nullptr;
- jmethodID uri_get_scheme = nullptr;
- jmethodID uri_get_host = nullptr;
-
- // android.os.UserManager
- ScopedGlobalRef<jclass> usermanager_class;
- jmethodID usermanager_get_user_restrictions = nullptr;
-
- // android.os.Bundle
- ScopedGlobalRef<jclass> bundle_class;
- jmethodID bundle_get_boolean = nullptr;
-
- // android.content.res.Resources
- ScopedGlobalRef<jclass> resources_class;
- jmethodID resources_get_system = nullptr;
- jmethodID resources_get_identifier = nullptr;
- jmethodID resources_get_string = nullptr;
-
- // Helper to convert lib3 UnicodeText to Java strings.
- ScopedLocalRef<jstring> ConvertToJavaString(
- const char* utf8_text, const int utf8_text_size_bytes) const;
- ScopedLocalRef<jstring> ConvertToJavaString(StringPiece utf8_text) const;
- ScopedLocalRef<jstring> ConvertToJavaString(const UnicodeText& text) const;
-
- private:
- explicit JniCache(JavaVM* jvm);
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_CACHE_H_
diff --git a/utils/java/scoped_global_ref.h b/utils/java/scoped_global_ref.h
deleted file mode 100644
index de0608e..0000000
--- a/utils/java/scoped_global_ref.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_GLOBAL_REF_H_
-#define LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_GLOBAL_REF_H_
-
-#include <jni.h>
-#include <memory>
-#include <type_traits>
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-// A deleter to be used with std::unique_ptr to delete JNI global references.
-class GlobalRefDeleter {
- public:
- GlobalRefDeleter() : jvm_(nullptr) {}
-
- // Style guide violating implicit constructor so that the GlobalRefDeleter
- // is implicitly constructed from the second argument to ScopedGlobalRef.
- GlobalRefDeleter(JavaVM* jvm) : jvm_(jvm) {} // NOLINT(runtime/explicit)
-
- GlobalRefDeleter(const GlobalRefDeleter& orig) = default;
-
- // Copy assignment to allow move semantics in ScopedGlobalRef.
- GlobalRefDeleter& operator=(const GlobalRefDeleter& rhs) {
- TC3_CHECK_EQ(jvm_, rhs.jvm_);
- return *this;
- }
-
- // The delete operator.
- void operator()(jobject object) const {
- JNIEnv* env;
- if (object != nullptr && jvm_ != nullptr &&
- JNI_OK ==
- jvm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_4)) {
- env->DeleteGlobalRef(object);
- }
- }
-
- private:
- // The jvm_ stashed to use for deletion.
- JavaVM* const jvm_;
-};
-
-// A smart pointer that deletes a JNI global reference when it goes out
-// of scope. Usage is:
-// ScopedGlobalRef<jobject> scoped_global(env->JniFunction(), jvm);
-template <typename T>
-using ScopedGlobalRef =
- std::unique_ptr<typename std::remove_pointer<T>::type, GlobalRefDeleter>;
-
-// A helper to create global references. Assumes the object has a local
-// reference, which it deletes.
-template <typename T>
-ScopedGlobalRef<T> MakeGlobalRef(T object, JNIEnv* env, JavaVM* jvm) {
- const jobject global_object = env->NewGlobalRef(object);
- env->DeleteLocalRef(object);
- return ScopedGlobalRef<T>(reinterpret_cast<T>(global_object), jvm);
-}
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_GLOBAL_REF_H_
diff --git a/utils/java/scoped_local_ref.h b/utils/java/scoped_local_ref.h
deleted file mode 100644
index f439c45..0000000
--- a/utils/java/scoped_local_ref.h
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_LOCAL_REF_H_
-#define LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_LOCAL_REF_H_
-
-#include <jni.h>
-#include <memory>
-#include <type_traits>
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-// A deleter to be used with std::unique_ptr to delete JNI local references.
-class LocalRefDeleter {
- public:
- LocalRefDeleter() : env_(nullptr) {}
-
- // Style guide violating implicit constructor so that the LocalRefDeleter
- // is implicitly constructed from the second argument to ScopedLocalRef.
- LocalRefDeleter(JNIEnv* env) : env_(env) {} // NOLINT(runtime/explicit)
-
- LocalRefDeleter(const LocalRefDeleter& orig) = default;
-
- // Copy assignment to allow move semantics in ScopedLocalRef.
- LocalRefDeleter& operator=(const LocalRefDeleter& rhs) {
- // As the deleter and its state are thread-local, ensure the envs
- // are consistent but do nothing.
- TC3_CHECK_EQ(env_, rhs.env_);
- return *this;
- }
-
- // The delete operator.
- void operator()(jobject object) const {
- if (env_) {
- env_->DeleteLocalRef(object);
- }
- }
-
- private:
- // The env_ stashed to use for deletion. Thread-local, don't share!
- JNIEnv* const env_;
-};
-
-// A smart pointer that deletes a JNI local reference when it goes out
-// of scope. Usage is:
-// ScopedLocalRef<jobject> scoped_local(env->JniFunction(), env);
-//
-// Note that this class is not thread-safe since it caches JNIEnv in
-// the deleter. Do not use the same jobject across different threads.
-template <typename T>
-using ScopedLocalRef =
- std::unique_ptr<typename std::remove_pointer<T>::type, LocalRefDeleter>;
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_LOCAL_REF_H_
diff --git a/utils/java/string_utils.cc b/utils/java/string_utils.cc
deleted file mode 100644
index 457a667..0000000
--- a/utils/java/string_utils.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/java/string_utils.h"
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-bool JByteArrayToString(JNIEnv* env, const jbyteArray& array,
- std::string* result) {
- jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE);
- if (array_bytes == nullptr) {
- return false;
- }
-
- const int array_length = env->GetArrayLength(array);
- *result = std::string(reinterpret_cast<char*>(array_bytes), array_length);
-
- env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
-
- return true;
-}
-
-bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
- std::string* result) {
- if (jstr == nullptr) {
- *result = std::string();
- return false;
- }
-
- jclass string_class = env->FindClass("java/lang/String");
- if (!string_class) {
- TC3_LOG(ERROR) << "Can't find String class";
- return false;
- }
-
- jmethodID get_bytes_id =
- env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
-
- jstring encoding = env->NewStringUTF("UTF-8");
-
- jbyteArray array = reinterpret_cast<jbyteArray>(
- env->CallObjectMethod(jstr, get_bytes_id, encoding));
-
- JByteArrayToString(env, array, result);
-
- // Release the array.
- env->DeleteLocalRef(array);
- env->DeleteLocalRef(string_class);
- env->DeleteLocalRef(encoding);
-
- return true;
-}
-
-ScopedStringChars GetScopedStringChars(JNIEnv* env, jstring string,
- jboolean* is_copy) {
- return ScopedStringChars(env->GetStringUTFChars(string, is_copy),
- StringCharsReleaser(env, string));
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/lua-utils.cc b/utils/lua-utils.cc
deleted file mode 100644
index 64071ca..0000000
--- a/utils/lua-utils.cc
+++ /dev/null
@@ -1,303 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/lua-utils.h"
-
-// lua_dump takes an extra argument "strip" in 5.3, but not in 5.2.
-#ifndef TC3_AOSP
-#define lua_dump(L, w, d, s) lua_dump((L), (w), (d))
-#endif
-
-namespace libtextclassifier3 {
-namespace {
-// Upvalue indices for the flatbuffer callback.
-static constexpr int kSchemaArgId = 1;
-static constexpr int kTypeArgId = 2;
-static constexpr int kTableArgId = 3;
-
-static constexpr luaL_Reg defaultlibs[] = {{"_G", luaopen_base},
- {LUA_TABLIBNAME, luaopen_table},
- {LUA_STRLIBNAME, luaopen_string},
- {LUA_BITLIBNAME, luaopen_bit32},
- {LUA_MATHLIBNAME, luaopen_math},
- {nullptr, nullptr}};
-
-// Implementation of a lua_Writer that appends the data to a string.
-int LuaStringWriter(lua_State *state, const void *data, size_t size,
- void *result) {
- std::string *const result_string = static_cast<std::string *>(result);
- result_string->insert(result_string->size(), static_cast<const char *>(data),
- size);
- return LUA_OK;
-}
-
-} // namespace
-
-LuaEnvironment::LuaEnvironment() { state_ = luaL_newstate(); }
-
-LuaEnvironment::~LuaEnvironment() {
- if (state_ != nullptr) {
- lua_close(state_);
- }
-}
-
-int LuaEnvironment::Iterator::NextCallback(lua_State *state) {
- return FromUpValue<Iterator *>(kIteratorArgId, state)->Next(state);
-}
-
-int LuaEnvironment::Iterator::LengthCallback(lua_State *state) {
- return FromUpValue<Iterator *>(kIteratorArgId, state)->Length(state);
-}
-
-int LuaEnvironment::Iterator::ItemCallback(lua_State *state) {
- return FromUpValue<Iterator *>(kIteratorArgId, state)->Item(state);
-}
-
-int LuaEnvironment::Iterator::IteritemsCallback(lua_State *state) {
- return FromUpValue<Iterator *>(kIteratorArgId, state)->Iteritems(state);
-}
-
-void LuaEnvironment::PushFlatbuffer(const char *name,
- const reflection::Schema *schema,
- const reflection::Object *type,
- const flatbuffers::Table *table,
- lua_State *state) {
- lua_newtable(state);
- luaL_newmetatable(state, name);
- lua_pushlightuserdata(state, AsUserData(schema));
- lua_pushlightuserdata(state, AsUserData(type));
- lua_pushlightuserdata(state, AsUserData(table));
- lua_pushcclosure(state, &GetFieldCallback, 3);
- lua_setfield(state, -2, kIndexKey);
- lua_setmetatable(state, -2);
-}
-
-int LuaEnvironment::GetFieldCallback(lua_State *state) {
- // Fetch the arguments.
- const reflection::Schema *schema =
- FromUpValue<reflection::Schema *>(kSchemaArgId, state);
- const reflection::Object *type =
- FromUpValue<reflection::Object *>(kTypeArgId, state);
- const flatbuffers::Table *table =
- FromUpValue<flatbuffers::Table *>(kTableArgId, state);
- return GetField(schema, type, table, state);
-}
-
-int LuaEnvironment::GetField(const reflection::Schema *schema,
- const reflection::Object *type,
- const flatbuffers::Table *table,
- lua_State *state) {
- const char *field_name = lua_tostring(state, -1);
- const reflection::Field *field = type->fields()->LookupByKey(field_name);
- if (field == nullptr) {
- lua_error(state);
- return 0;
- }
- // Provide primitive fields directly.
- const reflection::BaseType field_type = field->type()->base_type();
- switch (field_type) {
- case reflection::Bool:
- lua_pushboolean(state, table->GetField<uint8_t>(
- field->offset(), field->default_integer()));
- break;
- case reflection::Int:
- lua_pushinteger(state, table->GetField<int32>(field->offset(),
- field->default_integer()));
- break;
- case reflection::Long:
- lua_pushinteger(state, table->GetField<int64>(field->offset(),
- field->default_integer()));
- break;
- case reflection::Float:
- lua_pushnumber(state, table->GetField<float>(field->offset(),
- field->default_real()));
- break;
- case reflection::Double:
- lua_pushnumber(state, table->GetField<double>(field->offset(),
- field->default_real()));
- break;
- case reflection::String: {
- const flatbuffers::String *string_value =
- table->GetPointer<const flatbuffers::String *>(field->offset());
- if (string_value != nullptr) {
- lua_pushlstring(state, string_value->data(), string_value->Length());
- } else {
- lua_pushlstring(state, "", 0);
- }
- break;
- }
- case reflection::Obj: {
- const flatbuffers::Table *field_table =
- table->GetPointer<const flatbuffers::Table *>(field->offset());
- if (field_table == nullptr) {
- TC3_LOG(ERROR) << "Field was not set in entity data.";
- lua_error(state);
- return 0;
- }
- const reflection::Object *field_type =
- schema->objects()->Get(field->type()->index());
- PushFlatbuffer(field->name()->c_str(), schema, field_type, field_table,
- state);
- break;
- }
- default:
- TC3_LOG(ERROR) << "Unsupported type: " << field_type;
- lua_error(state);
- return 0;
- }
- return 1;
-}
-
-int LuaEnvironment::ReadFlatbuffer(ReflectiveFlatbuffer *buffer) {
- if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
- TC3_LOG(ERROR) << "Expected actions table, got: "
- << lua_type(state_, /*idx=*/-1);
- lua_error(state_);
- return LUA_ERRRUN;
- }
-
- lua_pushnil(state_);
- while (lua_next(state_, /*idx=*/-2)) {
- const StringPiece key = ReadString(/*index=*/-2);
- const reflection::Field *field = buffer->GetFieldOrNull(key);
- if (field == nullptr) {
- TC3_LOG(ERROR) << "Unknown field: " << key.ToString();
- lua_error(state_);
- return LUA_ERRRUN;
- }
- switch (field->type()->base_type()) {
- case reflection::Obj:
- return ReadFlatbuffer(buffer->Mutable(field));
- case reflection::Bool:
- buffer->Set(field,
- static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
- break;
- case reflection::Int:
- buffer->Set(field, static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
- break;
- case reflection::Long:
- buffer->Set(field,
- static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
- break;
- case reflection::Float:
- buffer->Set(field,
- static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
- break;
- case reflection::Double:
- buffer->Set(field,
- static_cast<double>(lua_tonumber(state_, /*idx=*/-1)));
- break;
- case reflection::String: {
- buffer->Set(field, ReadString(/*index=*/-1));
- break;
- }
- default:
- TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type();
- lua_error(state_);
- return LUA_ERRRUN;
- }
- lua_pop(state_, 1);
- }
- // lua_pop(state_, /*n=*/1);
- return LUA_OK;
-}
-
-void LuaEnvironment::LoadDefaultLibraries() {
- for (const luaL_Reg *lib = defaultlibs; lib->func; lib++) {
- luaL_requiref(state_, lib->name, lib->func, 1);
- lua_pop(state_, 1); /* remove lib */
- }
-}
-
-void LuaEnvironment::PushValue(const Variant &value) {
- if (value.HasInt()) {
- lua_pushnumber(state_, value.IntValue());
- } else if (value.HasInt64()) {
- lua_pushnumber(state_, value.Int64Value());
- } else if (value.HasBool()) {
- lua_pushboolean(state_, value.BoolValue());
- } else if (value.HasFloat()) {
- lua_pushnumber(state_, value.FloatValue());
- } else if (value.HasDouble()) {
- lua_pushnumber(state_, value.DoubleValue());
- } else if (value.HasString()) {
- lua_pushlstring(state_, value.StringValue().data(),
- value.StringValue().size());
- } else {
- TC3_LOG(FATAL) << "Unknown value type.";
- }
-}
-
-StringPiece LuaEnvironment::ReadString(const int index) const {
- size_t length = 0;
- const char *data = lua_tolstring(state_, index, &length);
- return StringPiece(data, length);
-}
-
-void LuaEnvironment::PushString(const StringPiece str) {
- lua_pushlstring(state_, str.data(), str.size());
-}
-
-void LuaEnvironment::PushFlatbuffer(const reflection::Schema *schema,
- const flatbuffers::Table *table) {
- PushFlatbuffer(schema->root_table()->name()->c_str(), schema,
- schema->root_table(), table, state_);
-}
-
-int LuaEnvironment::RunProtected(const std::function<int()> &func,
- const int num_args, const int num_results) {
- struct ProtectedCall {
- std::function<int()> func;
-
- static int run(lua_State *state) {
- // Read the pointer to the ProtectedCall struct.
- ProtectedCall *p = static_cast<ProtectedCall *>(
- lua_touserdata(state, lua_upvalueindex(1)));
- return p->func();
- }
- };
- ProtectedCall protected_call = {func};
- lua_pushlightuserdata(state_, &protected_call);
- lua_pushcclosure(state_, &ProtectedCall::run, /*n=*/1);
- // Put the closure before the arguments on the stack.
- if (num_args > 0) {
- lua_insert(state_, -(1 + num_args));
- }
- return lua_pcall(state_, num_args, num_results, /*errorfunc=*/0);
-}
-
-bool LuaEnvironment::Compile(StringPiece snippet, std::string *bytecode) {
- if (luaL_loadbuffer(state_, snippet.data(), snippet.size(),
- /*name=*/nullptr) != LUA_OK) {
- TC3_LOG(ERROR) << "Could not compile lua snippet: "
- << ReadString(/*index=*/-1).ToString();
- lua_pop(state_, 1);
- return false;
- }
- if (lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) != LUA_OK) {
- TC3_LOG(ERROR) << "Could not dump compiled lua snippet.";
- lua_pop(state_, 1);
- return false;
- }
- lua_pop(state_, 1);
- return true;
-}
-
-bool Compile(StringPiece snippet, std::string *bytecode) {
- return LuaEnvironment().Compile(snippet, bytecode);
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/lua-utils.h b/utils/lua-utils.h
deleted file mode 100644
index d825cb9..0000000
--- a/utils/lua-utils.h
+++ /dev/null
@@ -1,264 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
-#define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
-
-#include <functional>
-#include <vector>
-
-#include "utils/flatbuffers.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/variant.h"
-#include "flatbuffers/reflection_generated.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-#include "lauxlib.h"
-#include "lua.h"
-#include "lualib.h"
-#ifdef __cplusplus
-}
-#endif
-
-namespace libtextclassifier3 {
-
-static constexpr const char *kLengthKey = "__len";
-static constexpr const char *kPairsKey = "__pairs";
-static constexpr const char *kIndexKey = "__index";
-
-// Casts to the lua user data type.
-template <typename T>
-void *AsUserData(const T *value) {
- return static_cast<void *>(const_cast<T *>(value));
-}
-template <typename T>
-void *AsUserData(const T value) {
- return reinterpret_cast<void *>(value);
-}
-
-// Retrieves up-values.
-template <typename T>
-T FromUpValue(const int index, lua_State *state) {
- return static_cast<T>(lua_touserdata(state, lua_upvalueindex(index)));
-}
-
-class LuaEnvironment {
- public:
- // Wrapper for handling an iterator.
- class Iterator {
- public:
- virtual ~Iterator() {}
- static int NextCallback(lua_State *state);
- static int LengthCallback(lua_State *state);
- static int ItemCallback(lua_State *state);
- static int IteritemsCallback(lua_State *state);
-
- // Called when the next element of an iterator is fetched.
- virtual int Next(lua_State *state) const = 0;
-
- // Called when the length of the iterator is queried.
- virtual int Length(lua_State *state) const = 0;
-
- // Called when an item is queried.
- virtual int Item(lua_State *state) const = 0;
-
- // Called when a new iterator is started.
- virtual int Iteritems(lua_State *state) const = 0;
-
- protected:
- static constexpr int kIteratorArgId = 1;
- };
-
- template <typename T>
- class ItemIterator : public Iterator {
- public:
- void NewIterator(StringPiece name, const T *items, lua_State *state) const {
- lua_newtable(state);
- luaL_newmetatable(state, name.data());
- lua_pushlightuserdata(state, AsUserData(this));
- lua_pushlightuserdata(state, AsUserData(items));
- lua_pushcclosure(state, &Iterator::ItemCallback, 2);
- lua_setfield(state, -2, kIndexKey);
- lua_pushlightuserdata(state, AsUserData(this));
- lua_pushlightuserdata(state, AsUserData(items));
- lua_pushcclosure(state, &Iterator::LengthCallback, 2);
- lua_setfield(state, -2, kLengthKey);
- lua_pushlightuserdata(state, AsUserData(this));
- lua_pushlightuserdata(state, AsUserData(items));
- lua_pushcclosure(state, &Iterator::IteritemsCallback, 2);
- lua_setfield(state, -2, kPairsKey);
- lua_setmetatable(state, -2);
- }
-
- int Iteritems(lua_State *state) const override {
- lua_pushlightuserdata(state, AsUserData(this));
- lua_pushlightuserdata(
- state, lua_touserdata(state, lua_upvalueindex(kItemsArgId)));
- lua_pushnumber(state, 0);
- lua_pushcclosure(state, &Iterator::NextCallback, 3);
- return /*num results=*/1;
- }
-
- int Length(lua_State *state) const override {
- lua_pushinteger(state, FromUpValue<T *>(kItemsArgId, state)->size());
- return /*num results=*/1;
- }
-
- int Next(lua_State *state) const override {
- return Next(FromUpValue<T *>(kItemsArgId, state),
- lua_tointeger(state, lua_upvalueindex(kIterValueArgId)),
- state);
- }
-
- int Next(const T *items, const int64 pos, lua_State *state) const {
- if (pos >= items->size()) {
- return 0;
- }
-
- // Update iterator value.
- lua_pushnumber(state, pos + 1);
- lua_replace(state, lua_upvalueindex(3));
-
- // Push key.
- lua_pushinteger(state, pos + 1);
-
- // Push item.
- return 1 + Item(items, pos, state);
- }
-
- int Item(lua_State *state) const override {
- const T *items = FromUpValue<T *>(kItemsArgId, state);
- switch (lua_type(state, -1)) {
- case LUA_TNUMBER: {
- // Lua is one based, so adjust the index here.
- const int64 index =
- static_cast<int64>(lua_tonumber(state, /*idx=*/-1)) - 1;
- if (index < 0 || index >= items->size()) {
- TC3_LOG(ERROR) << "Invalid index: " << index;
- lua_error(state);
- return 0;
- }
- return Item(items, index, state);
- }
- case LUA_TSTRING: {
- size_t key_length = 0;
- const char *key = lua_tolstring(state, /*idx=*/-1, &key_length);
- return Item(items, StringPiece(key, key_length), state);
- }
- default:
- TC3_LOG(ERROR) << "Unexpected access type: " << lua_type(state, -1);
- lua_error(state);
- return 0;
- }
- }
-
- virtual int Item(const T *items, const int64 pos,
- lua_State *state) const = 0;
-
- virtual int Item(const T *items, StringPiece key, lua_State *state) const {
- TC3_LOG(ERROR) << "Unexpected key access: " << key.ToString();
- lua_error(state);
- return 0;
- }
-
- protected:
- static constexpr int kItemsArgId = 2;
- static constexpr int kIterValueArgId = 3;
- };
-
- virtual ~LuaEnvironment();
- LuaEnvironment();
-
- // Compile a lua snippet into binary bytecode.
- // NOTE: The compiled bytecode might not be compatible across Lua versions
- // and platforms.
- bool Compile(StringPiece snippet, std::string *bytecode);
-
- typedef int (*CallbackHandler)(lua_State *);
-
- // Loads default libraries.
- void LoadDefaultLibraries();
-
- // Provides a callback to Lua.
- template <typename T, int (T::*handler)()>
- void Bind() {
- lua_pushlightuserdata(state_, static_cast<void *>(this));
- lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
- }
-
- // Setup a named table that callsback whenever a member is accessed.
- // This allows to lazily provide required information to the script.
- template <typename T, int (T::*handler)()>
- void BindTable(const char *name) {
- lua_newtable(state_);
- luaL_newmetatable(state_, name);
- lua_pushlightuserdata(state_, static_cast<void *>(this));
- lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
- lua_setfield(state_, -2, kIndexKey);
- lua_setmetatable(state_, -2);
- }
-
- void PushValue(const Variant &value);
-
- // Reads a string from the stack.
- StringPiece ReadString(const int index) const;
-
- // Pushes a string to the stack.
- void PushString(const StringPiece str);
-
- // Pushes a flatbuffer to the stack.
- void PushFlatbuffer(const reflection::Schema *schema,
- const flatbuffers::Table *table);
-
- // Reads a flatbuffer from the stack.
- int ReadFlatbuffer(ReflectiveFlatbuffer *buffer);
-
- // Runs a closure in protected mode.
- // `func`: closure to run in protected mode.
- // `num_lua_args`: number of arguments from the lua stack to process.
- // `num_results`: number of result values pushed on the stack.
- int RunProtected(const std::function<int()> &func, const int num_args = 0,
- const int num_results = 0);
-
- lua_State *state() const { return state_; }
-
- protected:
- lua_State *state_;
-
- private:
- // Auxiliary methods to expose (reflective) flatbuffer based data to Lua.
- static void PushFlatbuffer(const char *name, const reflection::Schema *schema,
- const reflection::Object *type,
- const flatbuffers::Table *table, lua_State *state);
- static int GetFieldCallback(lua_State *state);
- static int GetField(const reflection::Schema *schema,
- const reflection::Object *type,
- const flatbuffers::Table *table, lua_State *state);
-
- template <typename T, int (T::*handler)()>
- static int Dispatch(lua_State *state) {
- T *env = FromUpValue<T *>(1, state);
- return ((*env).*handler)();
- }
-};
-
-bool Compile(StringPiece snippet, std::string *bytecode);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
diff --git a/utils/math/fastexp.h b/utils/math/fastexp.h
deleted file mode 100644
index f690c73..0000000
--- a/utils/math/fastexp.h
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Fast approximation for exp.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_MATH_FASTEXP_H_
-#define LIBTEXTCLASSIFIER_UTILS_MATH_FASTEXP_H_
-
-#include <cassert>
-#include <cmath>
-#include <limits>
-
-#include "utils/base/casts.h"
-#include "utils/base/integral_types.h"
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-class FastMathClass {
- private:
- static const int kBits = 7;
- static const int kMask1 = (1 << kBits) - 1;
- static const int kMask2 = 0xFF << kBits;
- static constexpr float kLogBase2OfE = 1.44269504088896340736f;
-
- struct Table {
- int32 exp1[1 << kBits];
- };
-
- public:
- float VeryFastExp2(float f) const {
- TC3_DCHECK_LE(fabs(f), 126);
- const float g = f + (127 + (1 << (23 - kBits)));
- const int32 x = bit_cast<int32>(g);
- int32 ret = ((x & kMask2) << (23 - kBits))
- | cache_.exp1[x & kMask1];
- return bit_cast<float>(ret);
- }
-
- float VeryFastExp(float f) const {
- return VeryFastExp2(f * kLogBase2OfE);
- }
-
- private:
- static const Table cache_;
-};
-
-extern FastMathClass FastMathInstance;
-
-inline float VeryFastExp(float f) { return FastMathInstance.VeryFastExp(f); }
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_MATH_FASTEXP_H_
diff --git a/utils/memory/mmap.h b/utils/memory/mmap.h
deleted file mode 100644
index acce7db..0000000
--- a/utils/memory/mmap.h
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_MEMORY_MMAP_H_
-#define LIBTEXTCLASSIFIER_UTILS_MEMORY_MMAP_H_
-
-#include <stddef.h>
-
-#include <string>
-
-#include "utils/base/integral_types.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-// Handle for a memory area where a file has been mmapped.
-//
-// Similar to a pointer: you "allocate" it using MmapFile(filename) and "delete"
-// it using Unmap(). Just like a pointer, it is passed around by value (see
-// signature of MmapFile and Unmap; fortunately, it's a small class, so there
-// shouldn't be any significant performance penalty) and its usage is not
-// necessarily scoped (that's why the destructor is not performing the unmap).
-//
-// Note: on program termination, each still unmapped file is automatically
-// unmapped. Hence, it is not an error if you don't call Unmap() (provided you
-// are ok keeping that file in memory the whole time).
-class MmapHandle {
- public:
- MmapHandle(void *start, size_t num_bytes, void *unmap_addr = nullptr)
- : start_(start), num_bytes_(num_bytes), unmap_addr_(unmap_addr) {}
-
- // Returns start address for the memory area where a file has been mmapped.
- void *start() const { return start_; }
-
- // Returns address to use for munmap call. If unmap_addr was not specified
- // the start address is used.
- void *unmap_addr() const {
- if (unmap_addr_ != nullptr) {
- return unmap_addr_;
- } else {
- return start_;
- }
- }
-
- // Returns number of bytes of the memory area from start().
- size_t num_bytes() const { return num_bytes_; }
-
- // Shortcut to simplify checking success of MmapFile(). See usage example
- // from the doc of that function.
- bool ok() const { return start() != nullptr; }
-
- // Returns a StringPiece pointing to the same underlying bytes.
- StringPiece to_stringpiece() const {
- return StringPiece(reinterpret_cast<char *>(start_), num_bytes_);
- }
-
- private:
- // See doc for start(). Not owned.
- void *const start_;
-
- // See doc for num_bytes().
- const size_t num_bytes_;
-
- // Address to use for unmapping.
- void *const unmap_addr_;
-};
-
-// Maps the full content of a file in memory (using mmap).
-//
-// When done using the file content, one can unmap using Unmap(). Otherwise,
-// all mapped files are unmapped when the program terminates.
-//
-// Sample usage:
-//
-// MmapHandle mmap_handle = MmapFile(filename);
-// TC3_DCHECK(mmap_handle.ok()) << "Unable to mmap " << filename;
-//
-// ... use data from addresses
-// ... [mmap_handle.start, mmap_handle.start + mmap_handle.num_bytes)
-//
-// Unmap(mmap_handle); // Unmap logs errors internally.
-//
-// Note: one can read *and* write the num_bytes bytes from start, but those
-// writes are not propagated to the underlying file, nor to other processes that
-// may have mmapped that file (all changes are local to current process).
-MmapHandle MmapFile(const std::string &filename);
-
-// Like MmapFile(const std::string &filename), but uses a file descriptor.
-MmapHandle MmapFile(int fd);
-
-// Maps a segment of a file to memory. File is given by a file descriptor, and
-// offset (relative to the beginning of the file) and size specify the segment
-// to be mapped. NOTE: Internally, we align the offset for the call to mmap
-// system call to be a multiple of page size, so offset does NOT have to be a
-// multiply of the page size.
-MmapHandle MmapFile(int fd, int64 segment_offset, int64 segment_size);
-
-// Unmaps a file mapped using MmapFile. Returns true on success, false
-// otherwise.
-bool Unmap(MmapHandle mmap_handle);
-
-// Scoped mmapping of a file. Mmaps a file on construction, unmaps it on
-// destruction.
-class ScopedMmap {
- public:
- explicit ScopedMmap(const std::string &filename)
- : handle_(MmapFile(filename)) {}
-
- explicit ScopedMmap(int fd) : handle_(MmapFile(fd)) {}
-
- ScopedMmap(int fd, int segment_offset, int segment_size)
- : handle_(MmapFile(fd, segment_offset, segment_size)) {}
-
- ~ScopedMmap() {
- if (handle_.ok()) {
- Unmap(handle_);
- }
- }
-
- const MmapHandle &handle() { return handle_; }
-
- private:
- MmapHandle handle_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_MEMORY_MMAP_H_
diff --git a/utils/regex-match.cc b/utils/regex-match.cc
deleted file mode 100644
index 8c55e6b..0000000
--- a/utils/regex-match.cc
+++ /dev/null
@@ -1,180 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/regex-match.h"
-
-#include <memory>
-
-#include "annotator/types.h"
-#include "utils/lua-utils.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-#include "lauxlib.h"
-#include "lualib.h"
-#ifdef __cplusplus
-}
-#endif
-
-namespace libtextclassifier3 {
-namespace {
-
-// Provide a lua environment for running regex match post verification.
-// It sets up and exposes the match data as well as the context.
-class LuaVerifier : private LuaEnvironment {
- public:
- static std::unique_ptr<LuaVerifier> Create(
- const std::string& context, const std::string& verifier_code,
- const UniLib::RegexMatcher* matcher);
-
- bool Verify(bool* result);
-
- private:
- explicit LuaVerifier(const std::string& context,
- const std::string& verifier_code,
- const UniLib::RegexMatcher* matcher)
- : context_(context), verifier_code_(verifier_code), matcher_(matcher) {}
- bool Initialize();
-
- // Provides details of a capturing group to lua.
- int GetCapturingGroup();
-
- const std::string& context_;
- const std::string& verifier_code_;
- const UniLib::RegexMatcher* matcher_;
-};
-
-bool LuaVerifier::Initialize() {
- // Run protected to not lua panic in case of setup failure.
- return RunProtected([this] {
- LoadDefaultLibraries();
-
- // Expose context of the match as `context` global variable.
- PushString(context_);
- lua_setglobal(state_, "context");
-
- // Expose match array as `match` global variable.
- // Each entry `match[i]` exposes the ith capturing group as:
- // * `begin`: span start
- // * `end`: span end
- // * `text`: the text
- BindTable<LuaVerifier, &LuaVerifier::GetCapturingGroup>("match");
- lua_setglobal(state_, "match");
- return LUA_OK;
- }) == LUA_OK;
-}
-
-std::unique_ptr<LuaVerifier> LuaVerifier::Create(
- const std::string& context, const std::string& verifier_code,
- const UniLib::RegexMatcher* matcher) {
- auto verifier = std::unique_ptr<LuaVerifier>(
- new LuaVerifier(context, verifier_code, matcher));
- if (!verifier->Initialize()) {
- TC3_LOG(ERROR) << "Could not initialize lua environment.";
- return nullptr;
- }
- return verifier;
-}
-
-int LuaVerifier::GetCapturingGroup() {
- if (lua_type(state_, /*idx=*/-1) != LUA_TNUMBER) {
- TC3_LOG(ERROR) << "Unexpected type for match group lookup: "
- << lua_type(state_, /*idx=*/-1);
- lua_error(state_);
- return 0;
- }
- const int group_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
- int status = UniLib::RegexMatcher::kNoError;
- const CodepointSpan span = {matcher_->Start(group_id, &status),
- matcher_->End(group_id, &status)};
- std::string text = matcher_->Group(group_id, &status).ToUTF8String();
- if (status != UniLib::RegexMatcher::kNoError) {
- TC3_LOG(ERROR) << "Could not extract span from capturing group.";
- lua_error(state_);
- return 0;
- }
- lua_newtable(state_);
- lua_pushinteger(state_, span.first);
- lua_setfield(state_, /*idx=*/-2, "begin");
- lua_pushinteger(state_, span.second);
- lua_setfield(state_, /*idx=*/-2, "end");
- PushString(text);
- lua_setfield(state_, /*idx=*/-2, "text");
- return 1;
-}
-
-bool LuaVerifier::Verify(bool* result) {
- if (luaL_loadbuffer(state_, verifier_code_.data(), verifier_code_.size(),
- /*name=*/nullptr) != LUA_OK) {
- TC3_LOG(ERROR) << "Could not load verifier snippet.";
- return false;
- }
-
- if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
- TC3_LOG(ERROR) << "Could not run verifier snippet.";
- return false;
- }
-
- if (RunProtected(
- [this, result] {
- if (lua_type(state_, /*idx=*/-1) != LUA_TBOOLEAN) {
- TC3_LOG(ERROR) << "Unexpected verification result type: "
- << lua_type(state_, /*idx=*/-1);
- lua_error(state_);
- return LUA_ERRRUN;
- }
- *result = lua_toboolean(state_, /*idx=*/-1);
- return LUA_OK;
- },
- /*num_args=*/1) != LUA_OK) {
- TC3_LOG(ERROR) << "Could not read lua result.";
- return false;
- }
- return true;
-}
-
-} // namespace
-
-bool SetFieldFromCapturingGroup(const int group_id,
- const FlatbufferFieldPath* field_path,
- const UniLib::RegexMatcher* matcher,
- ReflectiveFlatbuffer* flatbuffer) {
- int status = UniLib::RegexMatcher::kNoError;
- std::string group_text = matcher->Group(group_id, &status).ToUTF8String();
- if (status != UniLib::RegexMatcher::kNoError || group_text.empty()) {
- return false;
- }
- return flatbuffer->ParseAndSet(field_path, group_text);
-}
-
-bool VerifyMatch(const std::string& context,
- const UniLib::RegexMatcher* matcher,
- const std::string& lua_verifier_code) {
- bool status = false;
- auto verifier = LuaVerifier::Create(context, lua_verifier_code, matcher);
- if (verifier == nullptr) {
- TC3_LOG(ERROR) << "Could not create verifier.";
- return false;
- }
- if (!verifier->Verify(&status)) {
- TC3_LOG(ERROR) << "Could not create verifier.";
- return false;
- }
- return status;
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/regex-match.h b/utils/regex-match.h
deleted file mode 100644
index f77f6b1..0000000
--- a/utils/regex-match.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_REGEX_MATCH_H_
-#define LIBTEXTCLASSIFIER_UTILS_REGEX_MATCH_H_
-
-#include "utils/flatbuffers.h"
-#include "utils/flatbuffers_generated.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3 {
-// Sets a field in the flatbuffer from a regex match group.
-// Returns true if successful, and false if the field couldn't be set.
-bool SetFieldFromCapturingGroup(const int group_id,
- const FlatbufferFieldPath* field_path,
- const UniLib::RegexMatcher* matcher,
- ReflectiveFlatbuffer* flatbuffer);
-
-// Post-checks a regular expression match with a lua verifier script.
-// The verifier can access:
-// * `context`: The context as a string.
-// * `match`: The groups of the regex match as an array, each group gives
-// * `begin`: span start
-// * `end`: span end
-// * `text`: the text
-// The verifier is expected to return a boolean, indicating whether the
-// verification succeeded or not.
-// Returns true if the verification was successful, false if not.
-bool VerifyMatch(const std::string& context,
- const UniLib::RegexMatcher* matcher,
- const std::string& lua_verifier_code);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_REGEX_MATCH_H_
diff --git a/utils/regex-match_test.cc b/utils/regex-match_test.cc
deleted file mode 100644
index ef86d65..0000000
--- a/utils/regex-match_test.cc
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/regex-match.h"
-
-#include <memory>
-
-#include "utils/utf8/unilib.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-class LuaVerifierTest : public testing::Test {
- protected:
- LuaVerifierTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
- UniLib unilib_;
-};
-
-#ifdef TC3_UNILIB_ICU
-TEST_F(LuaVerifierTest, HandlesSimpleVerification) {
- EXPECT_TRUE(VerifyMatch(/*context=*/"", /*matcher=*/nullptr, "return true;"));
-}
-
-TEST_F(LuaVerifierTest, HandlesCustomVerification) {
- UnicodeText pattern = UTF8ToUnicodeText("(\\d{16})",
- /*do_copy=*/true);
- UnicodeText message = UTF8ToUnicodeText("cc: 4012888888881881",
- /*do_copy=*/true);
- const std::string verifier = R"(
-function luhn(candidate)
- local sum = 0
- local num_digits = string.len(candidate)
- local parity = num_digits % 2
- for pos = 1,num_digits do
- d = tonumber(string.sub(candidate, pos, pos))
- if pos % 2 ~= parity then
- d = d * 2
- end
- if d > 9 then
- d = d - 9
- end
- sum = sum + d
- end
- return (sum % 10) == 0
-end
-return luhn(match[1].text);
- )";
- auto regex_pattern = unilib_.CreateRegexPattern(pattern);
- ASSERT_TRUE(regex_pattern != nullptr);
- auto matcher = regex_pattern->Matcher(message);
- ASSERT_TRUE(matcher != nullptr);
- int status = UniLib::RegexMatcher::kNoError;
- ASSERT_TRUE(matcher->Find(&status) &&
- status == UniLib::RegexMatcher::kNoError);
-
- EXPECT_TRUE(VerifyMatch(message.ToUTF8String(), matcher.get(), verifier));
-}
-#endif
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/resources.cc b/utils/resources.cc
deleted file mode 100644
index ddfa499..0000000
--- a/utils/resources.cc
+++ /dev/null
@@ -1,217 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/resources.h"
-#include "utils/base/logging.h"
-#include "utils/zlib/buffer_generated.h"
-#include "utils/zlib/zlib.h"
-
-namespace libtextclassifier3 {
-namespace {
-bool isWildcardMatch(const flatbuffers::String* left,
- const std::string& right) {
- return (left == nullptr || right.empty());
-}
-
-bool isExactMatch(const flatbuffers::String* left, const std::string& right) {
- if (left == nullptr) {
- return right.empty();
- }
- return left->str() == right;
-}
-
-} // namespace
-
-int Resources::LocaleMatch(const Locale& locale,
- const LanguageTag* entry_locale) const {
- int match = LOCALE_NO_MATCH;
- if (isExactMatch(entry_locale->language(), locale.Language())) {
- match |= LOCALE_LANGUAGE_MATCH;
- } else if (isWildcardMatch(entry_locale->language(), locale.Language())) {
- match |= LOCALE_LANGUAGE_WILDCARD_MATCH;
- }
-
- if (isExactMatch(entry_locale->script(), locale.Script())) {
- match |= LOCALE_SCRIPT_MATCH;
- } else if (isWildcardMatch(entry_locale->script(), locale.Script())) {
- match |= LOCALE_SCRIPT_WILDCARD_MATCH;
- }
-
- if (isExactMatch(entry_locale->region(), locale.Region())) {
- match |= LOCALE_REGION_MATCH;
- } else if (isWildcardMatch(entry_locale->region(), locale.Region())) {
- match |= LOCALE_REGION_WILDCARD_MATCH;
- }
-
- return match;
-}
-
-const ResourceEntry* Resources::FindResource(
- const StringPiece resource_name) const {
- if (resources_ == nullptr || resources_->resource_entry() == nullptr) {
- TC3_LOG(ERROR) << "No resources defined.";
- return nullptr;
- }
- const ResourceEntry* entry =
- resources_->resource_entry()->LookupByKey(resource_name.data());
- if (entry == nullptr) {
- TC3_LOG(ERROR) << "Resource " << resource_name.ToString() << " not found";
- return nullptr;
- }
- return entry;
-}
-
-int Resources::BestResourceForLocales(
- const ResourceEntry* resource, const std::vector<Locale>& locales) const {
- // Find best match based on locale.
- int resource_id = -1;
- int locale_match = LOCALE_NO_MATCH;
- const auto* resources = resource->resource();
- for (int user_locale = 0; user_locale < locales.size(); user_locale++) {
- if (!locales[user_locale].IsValid()) {
- continue;
- }
- for (int i = 0; i < resources->size(); i++) {
- for (const int locale_id : *resources->Get(i)->locale()) {
- const int candidate_match = LocaleMatch(
- locales[user_locale], resources_->locale()->Get(locale_id));
-
- // Only consider if at least the language matches.
- if ((candidate_match & LOCALE_LANGUAGE_MATCH) == 0 &&
- (candidate_match & LOCALE_LANGUAGE_WILDCARD_MATCH) == 0) {
- continue;
- }
-
- if (candidate_match > locale_match) {
- locale_match = candidate_match;
- resource_id = i;
- }
- }
- }
-
- // If the language matches exactly, we are already finished.
- // We found an exact language match.
- if (locale_match & LOCALE_LANGUAGE_MATCH) {
- return resource_id;
- }
- }
- return resource_id;
-}
-
-bool Resources::GetResourceContent(const std::vector<Locale>& locales,
- const StringPiece resource_name,
- std::string* result) const {
- const ResourceEntry* entry = FindResource(resource_name);
- if (entry == nullptr || entry->resource() == nullptr) {
- return false;
- }
-
- int resource_id = BestResourceForLocales(entry, locales);
- if (resource_id < 0) {
- return false;
- }
- const auto* resource = entry->resource()->Get(resource_id);
- if (resource->content() != nullptr) {
- *result = resource->content()->str();
- return true;
- } else if (resource->compressed_content() != nullptr) {
- std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance(
- resources_->compression_dictionary()->data(),
- resources_->compression_dictionary()->size());
- if (decompressor != nullptr &&
- decompressor->MaybeDecompress(resource->compressed_content(), result)) {
- return true;
- }
- }
- return false;
-}
-
-bool CompressResources(ResourcePoolT* resources,
- const bool build_compression_dictionary,
- const int dictionary_sample_every) {
- std::vector<unsigned char> dictionary;
- if (build_compression_dictionary) {
- {
- // Build up a compression dictionary.
- std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
- int i = 0;
- for (auto& entry : resources->resource_entry) {
- for (auto& resource : entry->resource) {
- if (resource->content.empty()) {
- continue;
- }
- i++;
-
- // Use a sample of the entries to build up a custom compression
- // dictionary. Using all entries will generally not give a benefit
- // for small data sizes, so we subsample here.
- if (i % dictionary_sample_every != 0) {
- continue;
- }
- CompressedBufferT compressed_content;
- compressor->Compress(resource->content, &compressed_content);
- }
- }
- compressor->GetDictionary(&dictionary);
- resources->compression_dictionary.assign(
- dictionary.data(), dictionary.data() + dictionary.size());
- }
- }
-
- for (auto& entry : resources->resource_entry) {
- for (auto& resource : entry->resource) {
- if (resource->content.empty()) {
- continue;
- }
- // Try compressing the data.
- std::unique_ptr<ZlibCompressor> compressor =
- build_compression_dictionary
- ? ZlibCompressor::Instance(dictionary.data(), dictionary.size())
- : ZlibCompressor::Instance();
- if (!compressor) {
- TC3_LOG(ERROR) << "Cannot create zlib compressor.";
- return false;
- }
-
- CompressedBufferT compressed_content;
- compressor->Compress(resource->content, &compressed_content);
-
- // Only keep compressed version if smaller.
- if (compressed_content.uncompressed_size >
- compressed_content.buffer.size()) {
- resource->content.clear();
- resource->compressed_content.reset(new CompressedBufferT);
- *resource->compressed_content = compressed_content;
- }
- }
- }
- return true;
-}
-
-std::string CompressSerializedResources(const std::string& resources,
- const int dictionary_sample_every) {
- std::unique_ptr<ResourcePoolT> unpacked_resources(
- flatbuffers::GetRoot<ResourcePool>(resources.data())->UnPack());
- TC3_CHECK(unpacked_resources != nullptr);
- TC3_CHECK(
- CompressResources(unpacked_resources.get(), dictionary_sample_every));
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(ResourcePool::Pack(builder, unpacked_resources.get()));
- return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/resources.fbs b/utils/resources.fbs
deleted file mode 100755
index a88c56d..0000000
--- a/utils/resources.fbs
+++ /dev/null
@@ -1,46 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-include "utils/zlib/buffer.fbs";
-
-namespace libtextclassifier3;
-table Resource {
- locale:[int];
- content:string;
- compressed_content:CompressedBuffer;
-}
-
-namespace libtextclassifier3;
-table ResourceEntry {
- name:string (key);
- resource:[Resource];
-}
-
-// BCP 47 tag for the supported locale.
-namespace libtextclassifier3;
-table LanguageTag {
- language:string;
- script:string;
- region:string;
-}
-
-namespace libtextclassifier3;
-table ResourcePool {
- locale:[LanguageTag];
- resource_entry:[ResourceEntry];
- compression_dictionary:[ubyte];
-}
-
diff --git a/utils/resources.h b/utils/resources.h
deleted file mode 100644
index 28db0cc..0000000
--- a/utils/resources.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_RESOURCES_H_
-#define LIBTEXTCLASSIFIER_UTILS_RESOURCES_H_
-
-#include <vector>
-
-#include "utils/i18n/locale.h"
-#include "utils/resources_generated.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-// Class for accessing localized model resources.
-class Resources {
- public:
- explicit Resources(const ResourcePool* resources) : resources_(resources) {}
-
- // Returns the string value associated with the particular resource.
- // `locales` are locales in preference order.
- bool GetResourceContent(const std::vector<Locale>& locales,
- const StringPiece resource_name,
- std::string* result) const;
-
- private:
- // Match priorities: language > script > region with wildcard matches being
- // weaker than an exact match.
- // For a resource lookup, at least language needs to (weakly) match.
- // c.f. developer.android.com/guide/topics/resources/multilingual-support
- enum LocaleMatch {
- LOCALE_NO_MATCH = 0,
- LOCALE_REGION_WILDCARD_MATCH = 1 << 0,
- LOCALE_REGION_MATCH = 1 << 1,
- LOCALE_SCRIPT_WILDCARD_MATCH = 1 << 2,
- LOCALE_SCRIPT_MATCH = 1 << 3,
- LOCALE_LANGUAGE_WILDCARD_MATCH = 1 << 4,
- LOCALE_LANGUAGE_MATCH = 1 << 5
- };
- int LocaleMatch(const Locale& locale, const LanguageTag* entry_locale) const;
-
- // Finds a resource entry by name.
- const ResourceEntry* FindResource(const StringPiece resource_name) const;
-
- // Finds the best locale matching resource from a resource entry.
- int BestResourceForLocales(const ResourceEntry* resource,
- const std::vector<Locale>& locales) const;
-
- const ResourcePool* resources_;
-};
-
-// Compresses resources in place.
-bool CompressResources(ResourcePoolT* resources,
- const bool build_compression_dictionary = false,
- const int dictionary_sample_every = 1);
-std::string CompressSerializedResources(
- const std::string& resources,
- const bool build_compression_dictionary = false,
- const int dictionary_sample_every = 1);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_RESOURCES_H_
diff --git a/utils/sentencepiece/double_array_trie.cc b/utils/sentencepiece/double_array_trie.cc
deleted file mode 100644
index a2b66ea..0000000
--- a/utils/sentencepiece/double_array_trie.cc
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/sentencepiece/double_array_trie.h"
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-bool DoubleArrayTrie::GatherPrefixMatches(
- StringPiece input, const std::function<void(TrieMatch)>& update_fn) const {
- uint32 pos = 0;
- if (nodes_length_ == 0) {
- TC3_LOG(WARNING) << "Trie is empty. Skipping.";
- return true;
- }
- pos = offset(0);
- for (int i = 0; i < input.size(); i++) {
- if (input[i] == 0) {
- break;
- }
- pos ^= static_cast<unsigned char>(input[i]);
- // We exhausted the trie, no more matches possible.
- if (pos < 0 || pos >= nodes_length_) {
- break;
- }
- if (label(pos) != input[i]) {
- break;
- }
- const bool node_has_leaf = has_leaf(pos);
- pos ^= offset(pos);
- if (pos < 0 || pos > nodes_length_) {
- TC3_LOG(ERROR) << "Out-of-bounds trie search position.";
- return false;
- }
- if (node_has_leaf) {
- update_fn(TrieMatch(/*id=*/value(pos), /*match_length=*/i + 1));
- }
- }
- return true;
-}
-
-bool DoubleArrayTrie::FindAllPrefixMatches(
- StringPiece input, std::vector<TrieMatch>* matches) const {
- return GatherPrefixMatches(
- input, [matches](const TrieMatch match) { matches->push_back(match); });
-}
-
-bool DoubleArrayTrie::LongestPrefixMatch(StringPiece input,
- TrieMatch* longest_match) const {
- *longest_match = TrieMatch();
- return GatherPrefixMatches(input, [longest_match](const TrieMatch match) {
- *longest_match = match;
- });
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/double_array_trie.h b/utils/sentencepiece/double_array_trie.h
deleted file mode 100644
index 0614fb4..0000000
--- a/utils/sentencepiece/double_array_trie.h
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
-#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
-
-#include <functional>
-#include <vector>
-
-#include "utils/base/endian.h"
-#include "utils/base/integral_types.h"
-#include "utils/sentencepiece/matcher.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-// A trie node specifies a node in the tree, either an intermediate node or
-// a leaf node.
-// A leaf node contains the id as an int of the string match. This id is encoded
-// in the lower 30 bits, thus the number of distinct ids is 2^30.
-// An intermediate node has an associated label and an offset to it's children.
-// The label is encoded in the least significant byte and must match the input
-// character during matching.
-// We account for endianness when using the node values, as they are serialized
-// (in little endian) as bytes in the flatbuffer model.
-typedef uint32 TrieNode;
-
-// A memory mappable trie, compatible with Darts::DoubleArray.
-class DoubleArrayTrie : public SentencePieceMatcher {
- public:
- // nodes and nodes_length specify the array of the nodes of the trie.
- DoubleArrayTrie(const TrieNode* nodes, const int nodes_length)
- : nodes_(nodes), nodes_length_(nodes_length) {}
-
- // Find matches that are prefixes of a string.
- bool FindAllPrefixMatches(StringPiece input,
- std::vector<TrieMatch>* matches) const override;
- // Find the longest prefix match of a string.
- bool LongestPrefixMatch(StringPiece input,
- TrieMatch* longest_match) const override;
-
- private:
- // Returns whether a node as a leaf as a child.
- bool has_leaf(uint32 i) const { return nodes_[i] & 0x100; }
-
- // Available when a node is a leaf.
- int value(uint32 i) const {
- return static_cast<int>(LittleEndian::ToHost32(nodes_[i]) & 0x7fffffff);
- }
-
- // Label associated with a node.
- // A leaf node will have the MSB set and thus return an invalid label.
- uint32 label(uint32 i) const {
- return LittleEndian::ToHost32(nodes_[i]) & 0x800000ff;
- }
-
- // Returns offset to children.
- uint32 offset(uint32 i) const {
- const uint32 node = LittleEndian::ToHost32(nodes_[i]);
- return (node >> 10) << ((node & 0x200) >> 6);
- }
-
- bool GatherPrefixMatches(
- StringPiece input, const std::function<void(TrieMatch)>& update_fn) const;
-
- const TrieNode* nodes_;
- const int nodes_length_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
diff --git a/utils/sentencepiece/double_array_trie_test.cc b/utils/sentencepiece/double_array_trie_test.cc
deleted file mode 100644
index d7fc44b..0000000
--- a/utils/sentencepiece/double_array_trie_test.cc
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <fstream>
-#include <string>
-#include <vector>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-#include "utils/sentencepiece/double_array_trie.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-std::string GetTestConfigPath() {
- return "";
-}
-
-TEST(DoubleArrayTest, Lookup) {
- // Test trie that contains pieces "hell", "hello", "o", "there".
- std::ifstream test_config_stream(GetTestConfigPath());
- std::string config((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
- DoubleArrayTrie trie(reinterpret_cast<const TrieNode*>(config.data()),
- config.size() / sizeof(TrieNode));
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(trie.FindAllPrefixMatches("hello there", &matches));
- EXPECT_EQ(matches.size(), 2);
- EXPECT_EQ(matches[0].id, 0 /*hell*/);
- EXPECT_EQ(matches[0].match_length, 4 /*hell*/);
- EXPECT_EQ(matches[1].id, 1 /*hello*/);
- EXPECT_EQ(matches[1].match_length, 5 /*hello*/);
- }
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(trie.FindAllPrefixMatches("he", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(trie.FindAllPrefixMatches("abcd", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(trie.FindAllPrefixMatches("", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(trie.FindAllPrefixMatches("hi there", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(trie.FindAllPrefixMatches(StringPiece("\0", 1), &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(
- trie.FindAllPrefixMatches(StringPiece("\xff, \xfe", 2), &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- TrieMatch match;
- EXPECT_TRUE(trie.LongestPrefixMatch("hella there", &match));
- EXPECT_EQ(match.id, 0 /*hell*/);
- }
-
- {
- TrieMatch match;
- EXPECT_TRUE(trie.LongestPrefixMatch("hello there", &match));
- EXPECT_EQ(match.id, 1 /*hello*/);
- }
-
- {
- TrieMatch match;
- EXPECT_TRUE(trie.LongestPrefixMatch("abcd", &match));
- EXPECT_EQ(match.id, -1);
- }
-
- {
- TrieMatch match;
- EXPECT_TRUE(trie.LongestPrefixMatch("", &match));
- EXPECT_EQ(match.id, -1);
- }
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/encoder.cc b/utils/sentencepiece/encoder.cc
deleted file mode 100644
index 51cda30..0000000
--- a/utils/sentencepiece/encoder.cc
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/sentencepiece/encoder.h"
-
-namespace libtextclassifier3 {
-
-bool Encoder::Encode(StringPiece normalized_text,
- std::vector<int>* encoded_text) const {
- const int len = normalized_text.size();
- if (len <= 0) {
- *encoded_text = {start_code_, end_code_};
- return true;
- }
- // We use `previous_pos` to indicate whether a dynamic programming state was
- // reachable.
- std::vector<SegmentationEntry> segmentation(
- len + 1, {/*score=*/0, /*previous_pos=*/-1, /*piece_id=*/-1,
- /*num_pieces=*/0});
- for (int i = 0; i < len; i++) {
- // State couldn't be reached.
- if (i > 0 && segmentation[i].previous_pos < 0) {
- // Advance position.
- normalized_text.RemovePrefix(1);
- continue;
- }
- // Check whether we can use the unknown token.
- if (unknown_code_ >= 0) {
- const int pos = i + 1;
- const float unknown_penalty = segmentation[i].score + unknown_score_;
- if (segmentation[pos].previous_pos < 0 ||
- segmentation[pos].score < unknown_penalty) {
- // Merge multiple unknown tokens into one.
- if (segmentation[i].piece_id == unknown_code_) {
- segmentation[pos] = {/*score=*/unknown_penalty,
- /*previous_pos=*/segmentation[i].previous_pos,
- /*piece_id=*/unknown_code_,
- /*num_pieces=*/segmentation[i].num_pieces};
- } else {
- segmentation[pos] = {/*score=*/unknown_penalty,
- /*previous_pos=*/i,
- /*piece_id=*/unknown_code_,
- /*num_pieces=*/segmentation[i].num_pieces + 1};
- }
- }
- }
- std::vector<TrieMatch> matches;
- if (!matcher_->FindAllPrefixMatches(normalized_text, &matches)) {
- TC3_LOG(ERROR)
- << "Couldn't successfully gather prefix sentence piece matches.";
- return false;
- }
- for (const auto& match : matches) {
- TC3_CHECK(match.id >= 0 && match.id < num_pieces_);
- const int pos = i + match.match_length;
- const float candidate_score = segmentation[i].score + scores_[match.id];
- if (segmentation[pos].previous_pos < 0 ||
- segmentation[pos].score < candidate_score) {
- segmentation[pos] = {/*score=*/candidate_score, /*previous_pos=*/i,
- /*piece_id=*/match.id + encoding_offset_,
- /*num_pieces=*/segmentation[i].num_pieces + 1};
- }
- }
- // Advance position.
- normalized_text.RemovePrefix(1);
- }
- if (segmentation[len].num_pieces <= 0) {
- *encoded_text = {start_code_, end_code_};
- return true;
- }
- const int num_pieces = segmentation[len].num_pieces;
- encoded_text->resize(num_pieces + 2);
- (*encoded_text)[num_pieces + 1] = end_code_;
- int pos = len;
- for (int i = num_pieces; i > 0; i--) {
- (*encoded_text)[i] = segmentation[pos].piece_id;
- pos = segmentation[pos].previous_pos;
- }
- (*encoded_text)[0] = start_code_;
- return true;
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/encoder.h b/utils/sentencepiece/encoder.h
deleted file mode 100644
index 6c69077..0000000
--- a/utils/sentencepiece/encoder.h
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
-#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
-
-#include <vector>
-
-#include "utils/base/logging.h"
-#include "utils/sentencepiece/matcher.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-// Encoder to segment/tokenize strings into pieces such that the sum of the
-// scores of the pieces used is maximized.
-class Encoder {
- public:
- // matcher: the list of valid sentence pieces represented as a matcher, e.g.
- // a trie.
- // num_pieces: the number of pieces in the trie.
- // pieces_scores: the scores of the individual pieces.
- // start_code: code that is used as encoding of the start of input.
- // end_code: code that is used as encoding of the end of input.
- // encoding_offset: value added to the sentence piece ids to make them
- // not interesecting with start_code and end_code.
- // unknown_code: code that is used for out-of-dictionary characters.
- // unknown_score: the penality score associated with the unknown code.
- Encoder(const SentencePieceMatcher* matcher, const int num_pieces,
- const float* pieces_scores, int start_code = 0, int end_code = 1,
- int encoding_offset = 2, int unknown_code = -1,
- float unknown_score = 0.f)
- : num_pieces_(num_pieces),
- scores_(pieces_scores),
- matcher_(matcher),
- start_code_(start_code),
- end_code_(end_code),
- encoding_offset_(encoding_offset),
- unknown_code_(unknown_code),
- unknown_score_(unknown_score) {}
-
- // Segment the input so that the total score of the pieces used is maximized.
- // This is a simplified implementation of the general Viterbi algorithm,
- // assuming independence between individual pieces.
- bool Encode(StringPiece normalized_text,
- std::vector<int>* encoded_text) const;
-
- private:
- // State in the dynamic programming algorithm.
- struct SegmentationEntry {
- // Accumulated score.
- float score;
-
- // Position before last piece.
- int previous_pos;
-
- // Last piece used.
- int piece_id;
-
- // Total number of pieces used.
- int num_pieces;
- };
-
- const int num_pieces_;
- const float* scores_;
- const SentencePieceMatcher* matcher_;
- const int start_code_;
- const int end_code_;
- const int encoding_offset_;
- const int unknown_code_;
- const int unknown_score_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
diff --git a/utils/sentencepiece/encoder_test.cc b/utils/sentencepiece/encoder_test.cc
deleted file mode 100644
index 9082cca..0000000
--- a/utils/sentencepiece/encoder_test.cc
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <memory>
-#include <vector>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-#include "utils/base/integral_types.h"
-#include "utils/sentencepiece/encoder.h"
-#include "utils/sentencepiece/sorted_strings_table.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::ElementsAre;
-
-TEST(EncoderTest, SimpleTokenization) {
- const char pieces[] = "hell\0hello\0o\0there\0";
- const uint32 offsets[] = {0, 5, 11, 13};
- float scores[] = {-0.5, -1.0, -10.0, -1.0};
- std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
- /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
- const Encoder encoder(matcher.get(),
- /*num_pieces=*/4, scores);
-
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellothere", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 3, 5, 1));
- }
-
- // Make probability of hello very low:
- // hello gets now tokenized as hell + o.
- scores[1] = -100.0;
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellothere", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 2, 4, 5, 1));
- }
-}
-
-TEST(EncoderTest, HandlesEdgeCases) {
- const char pieces[] = "hell\0hello\0o\0there\0";
- const uint32 offsets[] = {0, 5, 11, 13};
- float scores[] = {-0.5, -1.0, -10.0, -1.0};
- std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
- /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
- const Encoder encoder(matcher.get(),
- /*num_pieces=*/4, scores);
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellhello", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 2, 3, 1));
- }
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellohell", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 3, 2, 1));
- }
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 1));
- }
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellathere", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 1));
- }
-}
-
-TEST(EncoderTest, HandlesOutOfDictionary) {
- const char pieces[] = "hell\0hello\0o\0there\0";
- const uint32 offsets[] = {0, 5, 11, 13};
- float scores[] = {-0.5, -1.0, -10.0, -1.0};
- std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
- /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
- const Encoder encoder(matcher.get(),
- /*num_pieces=*/4, scores,
- /*start_code=*/0, /*end_code=*/1,
- /*encoding_offset=*/3, /*unknown_code=*/2,
- /*unknown_score=*/-100.0);
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellhello", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 3, 4, 1));
- }
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellohell", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 4, 3, 1));
- }
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("", &encoded_text));
- EXPECT_THAT(encoded_text, ElementsAre(0, 1));
- }
- {
- std::vector<int> encoded_text;
- EXPECT_TRUE(encoder.Encode("hellathere", &encoded_text));
- EXPECT_THAT(encoded_text,
- ElementsAre(0, /*hell*/ 3, /*unknown*/ 2, /*there*/ 6, 1));
- }
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/matcher.h b/utils/sentencepiece/matcher.h
deleted file mode 100644
index 47e6560..0000000
--- a/utils/sentencepiece/matcher.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
-#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
-
-#include <vector>
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-struct TrieMatch {
- TrieMatch() {}
- TrieMatch(int id, int match_length) : id(id), match_length(match_length) {}
- int id = -1;
- int match_length = -1;
-};
-
-class SentencePieceMatcher {
- public:
- virtual ~SentencePieceMatcher() {}
-
- // Find matches that are prefixes of a string.
- virtual bool FindAllPrefixMatches(StringPiece input,
- std::vector<TrieMatch>* matches) const = 0;
-
- // Find the longest prefix match of a string.
- virtual bool LongestPrefixMatch(StringPiece input,
- TrieMatch* longest_match) const = 0;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
diff --git a/utils/sentencepiece/normalizer.cc b/utils/sentencepiece/normalizer.cc
deleted file mode 100644
index 9d893fd..0000000
--- a/utils/sentencepiece/normalizer.cc
+++ /dev/null
@@ -1,151 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/sentencepiece/normalizer.h"
-
-#include "utils/base/logging.h"
-#include "utils/strings/utf8.h"
-
-namespace libtextclassifier3 {
-
-bool SentencePieceNormalizer::Normalize(StringPiece input,
- std::string* normalized_input) const {
- // Ignores heading space.
- if (remove_extra_whitespaces_) {
- while (!input.empty()) {
- std::pair<StringPiece, int> suffix_and_length;
- if (!NormalizePrefix(input, &suffix_and_length)) {
- TC3_LOG(ERROR) << "Couldn't find match in normalization table.";
- return false;
- }
- if (suffix_and_length.second <= 0) {
- TC3_LOG(ERROR) << "Consumed string is empty.";
- return false;
- }
- if (suffix_and_length.first.size() != 1 ||
- suffix_and_length.first[0] != ' ') {
- break;
- }
- input.RemovePrefix(suffix_and_length.second);
- }
- }
-
- if (input.empty()) {
- *normalized_input = "";
- return true;
- }
-
- // Reserves the output buffer to avoid re-allocations.
- const int kReservedSize = input.size() * 3;
- normalized_input->reserve(kReservedSize);
-
- // Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK)
- // if escape_whitespaces() is set (default = true).
- const StringPiece kSpaceSymbol = "\xe2\x96\x81";
-
- // Adds a space symbol as a prefix (default is true)
- // With this prefix, "world" and "hello world" are converted into
- // "_world" and "_hello_world", which help the trainer to extract
- // "_world" as one symbol.
- if (add_dummy_prefix_) {
- if (escape_whitespaces_) {
- normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size());
- } else {
- normalized_input->append(" ");
- }
- }
-
- bool is_prev_space = remove_extra_whitespaces_;
- while (!input.empty()) {
- std::pair<StringPiece, int> p;
- if (!NormalizePrefix(input, &p)) {
- TC3_LOG(ERROR) << "Couldn't normalize string.";
- return false;
- }
- if (p.second <= 0) {
- TC3_LOG(ERROR) << "Consumed string is empty.";
- return false;
- }
-
- StringPiece sp = p.first;
-
- // Removes heading spaces in sentence piece,
- // if the previous sentence piece ends with whitespace.
- while (is_prev_space && ConsumePrefix(&sp, " ")) {
- }
-
- if (!sp.empty()) {
- const char* data = sp.data();
- for (int n = 0; n < sp.size(); ++n) {
- if (escape_whitespaces_ && data[n] == ' ') {
- normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size());
- } else {
- *normalized_input += data[n];
- }
- }
- // Checks whether the last character of sp is whitespace.
- is_prev_space = EndsWith(sp, " ");
- }
- input.RemovePrefix(p.second);
- is_prev_space = is_prev_space && remove_extra_whitespaces_;
- }
-
- // Ignores tailing space.
- if (remove_extra_whitespaces_) {
- const StringPiece space = escape_whitespaces_ ? kSpaceSymbol : " ";
- while (EndsWith(*normalized_input, space)) {
- const int length = normalized_input->size() - space.size();
- normalized_input->resize(length);
- }
- }
- return true;
-}
-
-bool SentencePieceNormalizer::NormalizePrefix(
- StringPiece input, std::pair<StringPiece, int>* prefix) const {
- if (input.empty()) return true;
- TrieMatch match;
- if (!charsmap_trie_.LongestPrefixMatch(input, &match)) {
- TC3_LOG(ERROR) << "Couldn't find match in normalization table.";
- return false;
- }
- const bool no_match = match.match_length <= 0;
- if (no_match) {
- const int char_length = ValidUTF8CharLength(input.data(), input.size());
- if (char_length <= 0) {
- // Found a malformed utf8.
- // The rune is set to be 0xFFFD (REPLACEMENT CHARACTER),
- // which is a valid Unicode of three bytes in utf8,
- // but here we only consume one byte.
- static const char kReplacementChar[] = "\xEF\xBF\xBD";
- prefix->first = StringPiece(kReplacementChar, 3);
- prefix->second = 1; // Consumes 1 byte, buts emit 0xFFFD.
- } else {
- prefix->first = StringPiece(input.data(), char_length);
- prefix->second = char_length;
- }
- } else {
- if (match.id < 0 || match.id >= charsmap_normalized_.size()) {
- TC3_LOG(ERROR) << "Invalid entry in normalization table.";
- return false;
- }
- prefix->first = StringPiece(&charsmap_normalized_.data()[match.id]);
- prefix->second = match.match_length;
- }
- return true;
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/normalizer.h b/utils/sentencepiece/normalizer.h
deleted file mode 100644
index 1d3aeb5..0000000
--- a/utils/sentencepiece/normalizer.h
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_NORMALIZER_H_
-#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_NORMALIZER_H_
-
-#include <memory>
-#include <string>
-
-#include "utils/sentencepiece/double_array_trie.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-// Normalizer implements a simple text normalizer with user-defined
-// string-to-string rules and leftmost longest matching.
-class SentencePieceNormalizer {
- public:
- // charsmap_trie and charsmap_normalized specify the normalization/replacement
- // string-to-string rules in the following way:
- // A match in the trie for a string will return the offset in
- // charsmap_normalized that contains the replacement string.
- //
- // add_dummy_prefix: Whether to add dummy whitespace at the beginning of the
- // text in order to treat "world" in "world" and "hello world" uniformly.
- //
- // remove_extra_whitespaces: Whether to remove leading, trailing and duplicate
- // internal whitespace.
- //
- // escape_whitespaces: Whether to replace whitespace with a meta symbol.
- SentencePieceNormalizer(const DoubleArrayTrie& charsmap_trie,
- StringPiece charsmap_normalized,
- bool add_dummy_prefix = true,
- bool remove_extra_whitespaces = true,
- bool escape_whitespaces = true)
- : charsmap_trie_(charsmap_trie),
- charsmap_normalized_(charsmap_normalized),
- add_dummy_prefix_(add_dummy_prefix),
- remove_extra_whitespaces_(remove_extra_whitespaces),
- escape_whitespaces_(escape_whitespaces) {}
-
- // Normalizes a plain utf8 string into an internal representation for
- // Sentencepiece model.
- bool Normalize(StringPiece input, std::string* normalized_input) const;
-
- private:
- // Normalizes the prefix of `input` and returns the pair of
- // normalized prefix and the length of the prefix of `input` processed in the
- // normalization.
- bool NormalizePrefix(StringPiece input,
- std::pair<StringPiece, int>* prefix) const;
-
- // Internal trie for efficient longest prefix string matching.
- DoubleArrayTrie charsmap_trie_;
-
- // "\0" delimitered concatenated normalized strings.
- // the value of `charsmap_trie_` stores offsets into this string.
- StringPiece charsmap_normalized_;
-
- const bool add_dummy_prefix_;
- const bool remove_extra_whitespaces_;
- const bool escape_whitespaces_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_NORMALIZER_H_
diff --git a/utils/sentencepiece/normalizer_test.cc b/utils/sentencepiece/normalizer_test.cc
deleted file mode 100644
index a5d6bf9..0000000
--- a/utils/sentencepiece/normalizer_test.cc
+++ /dev/null
@@ -1,198 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <fstream>
-#include <string>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-#include "utils/sentencepiece/double_array_trie.h"
-#include "utils/sentencepiece/normalizer.h"
-#include "utils/sentencepiece/test_utils.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-std::string GetTestConfigPath() {
- return "";
-}
-
-TEST(NormalizerTest, NormalizesAsReferenceNormalizer) {
- std::ifstream test_config_stream(GetTestConfigPath());
- std::string config((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
- SentencePieceNormalizer normalizer =
- NormalizerFromSpec(config, /*add_dummy_prefix=*/true,
- /*remove_extra_whitespaces=*/true,
- /*escape_whitespaces=*/true);
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
- EXPECT_EQ(normalized, "▁hello▁there");
- }
-
- // Redundant whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
- EXPECT_EQ(normalized, "▁when▁is▁the▁world▁cup?");
- }
-
- // Different whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
- EXPECT_EQ(normalized, "▁general▁kenobi");
- }
-
- // NFKC char to multi-char normalization.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("㍿", &normalized));
- EXPECT_EQ(normalized, "▁株式会社");
- }
-
- // Half width katakana, character composition happens.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize(" グーグル ", &normalized));
- EXPECT_EQ(normalized, "▁グーグル");
- }
-
- // NFKC char to char normalization.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("①②③", &normalized));
- EXPECT_EQ(normalized, "▁123");
- }
-}
-
-TEST(NormalizerTest, NoDummyPrefix) {
- std::ifstream test_config_stream(GetTestConfigPath());
- std::string config((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
- SentencePieceNormalizer normalizer =
- NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
- /*remove_extra_whitespaces=*/true,
- /*escape_whitespaces=*/true);
-
- // NFKC char to char normalization.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
- EXPECT_EQ(normalized, "hello▁there");
- }
-
- // Redundant whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
- EXPECT_EQ(normalized, "when▁is▁the▁world▁cup?");
- }
-
- // Different whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
- EXPECT_EQ(normalized, "general▁kenobi");
- }
-
- // NFKC char to multi-char normalization.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("㍿", &normalized));
- EXPECT_EQ(normalized, "株式会社");
- }
-
- // Half width katakana, character composition happens.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize(" グーグル ", &normalized));
- EXPECT_EQ(normalized, "グーグル");
- }
-
- // NFKC char to char normalization.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("①②③", &normalized));
- EXPECT_EQ(normalized, "123");
- }
-}
-
-TEST(NormalizerTest, NoRemoveExtraWhitespace) {
- std::ifstream test_config_stream(GetTestConfigPath());
- std::string config((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
- SentencePieceNormalizer normalizer =
- NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
- /*remove_extra_whitespaces=*/false,
- /*escape_whitespaces=*/true);
-
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
- EXPECT_EQ(normalized, "hello▁there");
- }
-
- // Redundant whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
- EXPECT_EQ(normalized, "when▁is▁▁the▁▁world▁cup?");
- }
-
- // Different whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
- EXPECT_EQ(normalized, "general▁kenobi");
- }
-}
-
-TEST(NormalizerTest, NoEscapeWhitespaces) {
- std::ifstream test_config_stream(GetTestConfigPath());
- std::string config((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
- SentencePieceNormalizer normalizer =
- NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
- /*remove_extra_whitespaces=*/false,
- /*escape_whitespaces=*/false);
-
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("hello there", &normalized));
- EXPECT_EQ(normalized, "hello there");
- }
-
- // Redundant whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("when is the world cup?", &normalized));
- EXPECT_EQ(normalized, "when is the world cup?");
- }
-
- // Different whitespace.
- {
- std::string normalized;
- EXPECT_TRUE(normalizer.Normalize("general\tkenobi", &normalized));
- EXPECT_EQ(normalized, "general kenobi");
- }
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/sorted_strings_table.cc b/utils/sentencepiece/sorted_strings_table.cc
deleted file mode 100644
index 8e7e9ba..0000000
--- a/utils/sentencepiece/sorted_strings_table.cc
+++ /dev/null
@@ -1,111 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/sentencepiece/sorted_strings_table.h"
-
-#include <algorithm>
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-void SortedStringsTable::GatherPrefixMatches(
- StringPiece input, const std::function<void(TrieMatch)>& update_fn) const {
- int left = 0;
- int right = num_pieces_;
- int span_size = right - left;
- int match_length = 0;
-
- // Loop invariant:
- // at the ith iteration, all strings from `left` ... `right` match the input
- // on the first `match_length` characters.
- while (span_size > use_linear_scan_threshold_) {
- if (match_length >= input.length()) {
- return;
- }
-
- // We find the possible range of pieces in `left` ... `right` matching the
- // `match_length` + 1 character with two binary searches:
- // `lower_bound` to find the start of the range of matching pieces.
- // `upper_bound` to find the non-inclusive end of the range.
- left = (std::lower_bound(
- offsets_ + left, offsets_ + right,
- static_cast<unsigned char>(input[match_length]),
- [this, match_length](uint32 piece_offset, uint32 c) -> bool {
- return static_cast<unsigned char>(
- pieces_[piece_offset + match_length]) < c;
- }) -
- offsets_);
- right = (std::upper_bound(
- offsets_ + left, offsets_ + right,
- static_cast<unsigned char>(input[match_length]),
- [this, match_length](uint32 c, uint32 piece_offset) -> bool {
- return c < static_cast<unsigned char>(
- pieces_[piece_offset + match_length]);
- }) -
- offsets_);
- span_size = right - left;
- if (span_size <= 0) {
- return;
- }
- ++match_length;
-
- // Due to the loop invariant and the fact that the strings are sorted, there
- // can only be one piece matching completely now, namely at left.
- if (pieces_[offsets_[left] + match_length] == 0) {
- update_fn(TrieMatch(/*id=*/left,
- /*match_length=*/match_length));
- left++;
- }
- }
-
- // Use linear scan for small problem instances.
- // By the loop invariant characters 0...`match_length` of all pieces in
- // in `left`...`right` match the input on 0...`match_length`.
- for (int i = left; i < right; i++) {
- bool matches = true;
- int piece_match_length = match_length;
- for (int k = offsets_[i] + piece_match_length; pieces_[k] != 0; k++) {
- if (match_length >= input.size() ||
- input[piece_match_length] != pieces_[k]) {
- matches = false;
- break;
- }
- piece_match_length++;
- }
- if (matches) {
- update_fn(TrieMatch(/*id=*/i,
- /*match_length=*/piece_match_length));
- }
- }
-}
-
-bool SortedStringsTable::FindAllPrefixMatches(
- StringPiece input, std::vector<TrieMatch>* matches) const {
- GatherPrefixMatches(
- input, [matches](const TrieMatch match) { matches->push_back(match); });
- return true;
-}
-
-bool SortedStringsTable::LongestPrefixMatch(StringPiece input,
- TrieMatch* longest_match) const {
- *longest_match = TrieMatch();
- GatherPrefixMatches(input, [longest_match](const TrieMatch match) {
- *longest_match = match;
- });
- return true;
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/sorted_strings_table.h b/utils/sentencepiece/sorted_strings_table.h
deleted file mode 100644
index 69f638a..0000000
--- a/utils/sentencepiece/sorted_strings_table.h
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_SORTED_STRINGS_TABLE_H_
-#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_SORTED_STRINGS_TABLE_H_
-
-#include <functional>
-#include <vector>
-
-#include "utils/base/integral_types.h"
-#include "utils/sentencepiece/matcher.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-// A matcher to find string pieces matching prefixes of an input string.
-// The list of reference strings are kept in sorted order in a zero separated
-// string.
-// binary search is used to find all prefix matches.
-// num_pieces: Number of sentence pieces.
-// offsets: Offsets into `pieces` where a string starts.
-// pieces: String pieces, concatenated in sorted order and zero byte separated.
-// use_linear_scan_threshold: Minimum size of binary search range before
-// switching to a linear sweep for prefix match testing.
-class SortedStringsTable : public SentencePieceMatcher {
- public:
- SortedStringsTable(const int num_pieces, const uint32* offsets,
- StringPiece pieces,
- const int use_linear_scan_threshold = 10)
- : num_pieces_(num_pieces),
- offsets_(offsets),
- pieces_(pieces),
- use_linear_scan_threshold_(use_linear_scan_threshold) {}
-
- // Find matches that are prefixes of a string.
- bool FindAllPrefixMatches(StringPiece input,
- std::vector<TrieMatch>* matches) const override;
- // Find the longest prefix match of a string.
- bool LongestPrefixMatch(StringPiece input,
- TrieMatch* longest_match) const override;
-
- private:
- void GatherPrefixMatches(
- StringPiece input, const std::function<void(TrieMatch)>& update_fn) const;
-
- const int num_pieces_;
- const uint32* offsets_;
- const StringPiece pieces_;
- const int use_linear_scan_threshold_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_SORTED_STRINGS_TABLE_H_
diff --git a/utils/sentencepiece/sorted_strings_table_test.cc b/utils/sentencepiece/sorted_strings_table_test.cc
deleted file mode 100644
index 4dff29d..0000000
--- a/utils/sentencepiece/sorted_strings_table_test.cc
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <vector>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-#include "utils/base/integral_types.h"
-#include "utils/sentencepiece/sorted_strings_table.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(SortedStringsTest, Lookup) {
- const char pieces[] = "hell\0hello\0o\0there\0";
- const uint32 offsets[] = {0, 5, 11, 13};
-
- SortedStringsTable table(/*num_pieces=*/4, offsets, StringPiece(pieces, 18),
- /*use_linear_scan_threshold=*/1);
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches("hello there", &matches));
- EXPECT_EQ(matches.size(), 2);
- EXPECT_EQ(matches[0].id, 0 /*hell*/);
- EXPECT_EQ(matches[0].match_length, 4 /*hell*/);
- EXPECT_EQ(matches[1].id, 1 /*hello*/);
- EXPECT_EQ(matches[1].match_length, 5 /*hello*/);
- }
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches("he", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches("he", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches("abcd", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches("", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches("hi there", &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(table.FindAllPrefixMatches(StringPiece("\0", 1), &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- std::vector<TrieMatch> matches;
- EXPECT_TRUE(
- table.FindAllPrefixMatches(StringPiece("\xff, \xfe", 2), &matches));
- EXPECT_THAT(matches, testing::IsEmpty());
- }
-
- {
- TrieMatch match;
- EXPECT_TRUE(table.LongestPrefixMatch("hella there", &match));
- EXPECT_EQ(match.id, 0 /*hell*/);
- }
-
- {
- TrieMatch match;
- EXPECT_TRUE(table.LongestPrefixMatch("hello there", &match));
- EXPECT_EQ(match.id, 1 /*hello*/);
- }
-
- {
- TrieMatch match;
- EXPECT_TRUE(table.LongestPrefixMatch("abcd", &match));
- EXPECT_EQ(match.id, -1);
- }
-
- {
- TrieMatch match;
- EXPECT_TRUE(table.LongestPrefixMatch("", &match));
- EXPECT_EQ(match.id, -1);
- }
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/test_utils.cc b/utils/sentencepiece/test_utils.cc
deleted file mode 100644
index 1ed2bf3..0000000
--- a/utils/sentencepiece/test_utils.cc
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/sentencepiece/test_utils.h"
-
-#include <memory>
-
-#include "utils/base/integral_types.h"
-#include "utils/sentencepiece/double_array_trie.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-SentencePieceNormalizer NormalizerFromSpec(StringPiece spec,
- bool add_dummy_prefix,
- bool remove_extra_whitespaces,
- bool escape_whitespaces) {
- const uint32 trie_blob_size = reinterpret_cast<const uint32*>(spec.data())[0];
- spec.RemovePrefix(sizeof(trie_blob_size));
- const TrieNode* trie_blob = reinterpret_cast<const TrieNode*>(spec.data());
- spec.RemovePrefix(trie_blob_size);
- const int num_nodes = trie_blob_size / sizeof(TrieNode);
- return SentencePieceNormalizer(
- DoubleArrayTrie(trie_blob, num_nodes),
- /*charsmap_normalized=*/StringPiece(spec.data(), spec.size()),
- add_dummy_prefix, remove_extra_whitespaces, escape_whitespaces);
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/test_utils.h b/utils/sentencepiece/test_utils.h
deleted file mode 100644
index 0c833da..0000000
--- a/utils/sentencepiece/test_utils.h
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
-#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
-
-#include <string>
-#include <vector>
-
-#include "utils/sentencepiece/normalizer.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-SentencePieceNormalizer NormalizerFromSpec(StringPiece spec,
- bool add_dummy_prefix,
- bool remove_extra_whitespaces,
- bool escape_whitespaces);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_TEST_UTILS_H_
diff --git a/utils/strings/numbers.cc b/utils/strings/numbers.cc
deleted file mode 100644
index 3028c69..0000000
--- a/utils/strings/numbers.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/strings/numbers.h"
-
-#ifdef COMPILER_MSVC
-#include <sstream>
-#endif // COMPILER_MSVC
-
-#include <stdlib.h>
-
-namespace libtextclassifier3 {
-
-bool ParseInt32(const char *c_str, int32 *value) {
- char *temp;
-
- // Short version of man strtol:
- //
- // strtol parses some optional whitespaces, an optional +/- sign, and next a
- // succession of digits. If it finds some digits, it sets temp to point to
- // the first character after that succession of digits and returns the parsed
- // integer.
- //
- // If there were no digits at all, strtol() sets temp to be c_str (the start
- // address) and returns 0.
- *value = strtol(c_str, &temp, 0); // NOLINT
-
- // temp != c_str means that the input string contained at least one digit (see
- // above). *temp == '\0' means the input string does not contain any random
- // chars after the number.
- return (temp != c_str) && (*temp == '\0');
-}
-
-bool ParseInt64(const char *c_str, int64 *value) {
- char *temp;
- *value = strtoll(c_str, &temp, 0); // NOLINT
-
- // See comments inside ParseInt32.
- return (temp != c_str) && (*temp == '\0');
-}
-
-bool ParseDouble(const char *c_str, double *value) {
- char *temp;
- *value = strtod(c_str, &temp);
-
- // See comments inside ParseInt32.
- return (temp != c_str) && (*temp == '\0');
-}
-
-#ifdef COMPILER_MSVC
-std::string IntToString(int64 input) {
- std::stringstream stream;
- stream << input;
- return stream.str();
-}
-#else
-std::string IntToString(int64 input) {
- return std::to_string(input);
-}
-#endif // COMPILER_MSVC
-
-} // namespace libtextclassifier3
diff --git a/utils/strings/numbers_test.cc b/utils/strings/numbers_test.cc
deleted file mode 100644
index 57e812f..0000000
--- a/utils/strings/numbers_test.cc
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/strings/numbers.h"
-
-#include "utils/base/integral_types.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-void TestParseInt32(const char *c_str, bool expected_parsing_success,
- int32 expected_parsed_value = 0) {
- int32 parsed_value = 0;
- EXPECT_EQ(expected_parsing_success, ParseInt32(c_str, &parsed_value));
- if (expected_parsing_success) {
- EXPECT_EQ(expected_parsed_value, parsed_value);
- }
-}
-
-TEST(ParseInt32Test, Normal) {
- TestParseInt32("2", true, 2);
- TestParseInt32("-357", true, -357);
- TestParseInt32("7", true, 7);
- TestParseInt32("+7", true, 7);
- TestParseInt32(" +7", true, 7);
- TestParseInt32("-23", true, -23);
- TestParseInt32(" -23", true, -23);
-}
-
-TEST(ParseInt32Test, ErrorCases) {
- TestParseInt32("", false);
- TestParseInt32(" ", false);
- TestParseInt32("not-a-number", false);
- TestParseInt32("123a", false);
-}
-
-void TestParseInt64(const char *c_str, bool expected_parsing_success,
- int64 expected_parsed_value = 0) {
- int64 parsed_value = 0;
- EXPECT_EQ(expected_parsing_success, ParseInt64(c_str, &parsed_value));
- if (expected_parsing_success) {
- EXPECT_EQ(expected_parsed_value, parsed_value);
- }
-}
-
-TEST(ParseInt64Test, Normal) {
- TestParseInt64("2", true, 2);
- TestParseInt64("-357", true, -357);
- TestParseInt64("7", true, 7);
- TestParseInt64("+7", true, 7);
- TestParseInt64(" +7", true, 7);
- TestParseInt64("-23", true, -23);
- TestParseInt64(" -23", true, -23);
-}
-
-TEST(ParseInt64Test, ErrorCases) {
- TestParseInt64("", false);
- TestParseInt64(" ", false);
- TestParseInt64("not-a-number", false);
- TestParseInt64("23z", false);
-}
-
-void TestParseDouble(const char *c_str, bool expected_parsing_success,
- double expected_parsed_value = 0.0) {
- double parsed_value = 0.0;
- EXPECT_EQ(expected_parsing_success, ParseDouble(c_str, &parsed_value));
- if (expected_parsing_success) {
- EXPECT_NEAR(expected_parsed_value, parsed_value, 0.00001);
- }
-}
-
-TEST(ParseDoubleTest, Normal) {
- TestParseDouble("2", true, 2.0);
- TestParseDouble("-357.023", true, -357.023);
- TestParseDouble("7.04", true, 7.04);
- TestParseDouble("+7.2", true, 7.2);
- TestParseDouble(" +7.236", true, 7.236);
- TestParseDouble("-23.4", true, -23.4);
- TestParseDouble(" -23.4", true, -23.4);
-}
-
-TEST(ParseDoubleTest, ErrorCases) {
- TestParseDouble("", false);
- TestParseDouble(" ", false);
- TestParseDouble("not-a-number", false);
- TestParseDouble("23.5a", false);
-}
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/strings/split.h b/utils/strings/split.h
deleted file mode 100644
index b565258..0000000
--- a/utils/strings/split.h
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_SPLIT_H_
-#define LIBTEXTCLASSIFIER_UTILS_STRINGS_SPLIT_H_
-
-#include <string>
-#include <vector>
-
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-namespace strings {
-
-std::vector<StringPiece> Split(const StringPiece &text, char delim);
-
-} // namespace strings
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_SPLIT_H_
diff --git a/utils/strings/stringpiece.h b/utils/strings/stringpiece.h
deleted file mode 100644
index 0dec1b8..0000000
--- a/utils/strings/stringpiece.h
+++ /dev/null
@@ -1,108 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_STRINGPIECE_H_
-#define LIBTEXTCLASSIFIER_UTILS_STRINGS_STRINGPIECE_H_
-
-#include <stddef.h>
-#include <string>
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-// Read-only "view" of a piece of data. Does not own the underlying data.
-class StringPiece {
- public:
- StringPiece() : StringPiece(nullptr, 0) {}
-
- StringPiece(const char *str) // NOLINT(runtime/explicit)
- : start_(str), size_(str == nullptr ? 0 : strlen(str)) {}
-
- StringPiece(const char *start, size_t size) : start_(start), size_(size) {}
-
- // Intentionally no "explicit" keyword: in function calls, we want strings to
- // be converted to StringPiece implicitly.
- StringPiece(const std::string &s) // NOLINT(runtime/explicit)
- : StringPiece(s.data(), s.size()) {}
-
- StringPiece(const std::string &s, int offset, int len)
- : StringPiece(s.data() + offset, len) {}
-
- char operator[](size_t i) const { return start_[i]; }
-
- // Returns start address of underlying data.
- const char *data() const { return start_; }
-
- // Returns number of bytes of underlying data.
- size_t size() const { return size_; }
- size_t length() const { return size_; }
-
- bool empty() const { return size_ == 0; }
-
- // Returns a std::string containing a copy of the underlying data.
- std::string ToString() const { return std::string(data(), size()); }
-
- // Returns whether string ends with a given suffix.
- bool EndsWith(StringPiece suffix) const {
- return suffix.empty() || (size_ >= suffix.size() &&
- memcmp(start_ + (size_ - suffix.size()),
- suffix.data(), suffix.size()) == 0);
- }
-
- // Returns whether the string begins with a given prefix.
- bool StartsWith(StringPiece prefix) const {
- return prefix.empty() ||
- (size_ >= prefix.size() &&
- memcmp(start_, prefix.data(), prefix.size()) == 0);
- }
-
- bool Equals(StringPiece other) const {
- return size() == other.size() && memcmp(start_, other.data(), size_) == 0;
- }
-
- // Removes the first `n` characters from the string piece. Note that the
- // underlying string is not changed, only the view.
- void RemovePrefix(int n) {
- TC3_CHECK_LE(n, size_);
- start_ += n;
- size_ -= n;
- }
-
- private:
- const char *start_; // Not owned.
- size_t size_;
-};
-
-inline bool EndsWith(StringPiece text, StringPiece suffix) {
- return text.EndsWith(suffix);
-}
-
-inline bool StartsWith(StringPiece text, StringPiece prefix) {
- return text.StartsWith(prefix);
-}
-
-inline bool ConsumePrefix(StringPiece *text, StringPiece prefix) {
- if (!text->StartsWith(prefix)) {
- return false;
- }
- text->RemovePrefix(prefix.size());
- return true;
-}
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_STRINGPIECE_H_
diff --git a/utils/strings/stringpiece_test.cc b/utils/strings/stringpiece_test.cc
deleted file mode 100644
index 713a7f9..0000000
--- a/utils/strings/stringpiece_test.cc
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(StringPieceTest, EndsWith) {
- EXPECT_TRUE(EndsWith("hello there!", "there!"));
- EXPECT_TRUE(EndsWith("hello there!", "!"));
- EXPECT_FALSE(EndsWith("hello there!", "there"));
- EXPECT_FALSE(EndsWith("hello there!", " hello there!"));
- EXPECT_TRUE(EndsWith("hello there!", ""));
- EXPECT_FALSE(EndsWith("", "hello there!"));
-}
-
-TEST(StringPieceTest, StartsWith) {
- EXPECT_TRUE(StartsWith("hello there!", "hello"));
- EXPECT_TRUE(StartsWith("hello there!", "hello "));
- EXPECT_FALSE(StartsWith("hello there!", "there!"));
- EXPECT_FALSE(StartsWith("hello there!", " hello there! "));
- EXPECT_TRUE(StartsWith("hello there!", ""));
- EXPECT_FALSE(StartsWith("", "hello there!"));
-}
-
-TEST(StringPieceTest, ConsumePrefix) {
- StringPiece str("hello there!");
- EXPECT_TRUE(ConsumePrefix(&str, "hello "));
- EXPECT_EQ(str.ToString(), "there!");
- EXPECT_TRUE(ConsumePrefix(&str, "there"));
- EXPECT_EQ(str.ToString(), "!");
- EXPECT_FALSE(ConsumePrefix(&str, "!!"));
- EXPECT_TRUE(ConsumePrefix(&str, ""));
- EXPECT_TRUE(ConsumePrefix(&str, "!"));
- EXPECT_EQ(str.ToString(), "");
- EXPECT_TRUE(ConsumePrefix(&str, ""));
- EXPECT_FALSE(ConsumePrefix(&str, "!"));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/strings/utf8.cc b/utils/strings/utf8.cc
deleted file mode 100644
index faaf854..0000000
--- a/utils/strings/utf8.cc
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/strings/utf8.h"
-
-namespace libtextclassifier3 {
-bool IsValidUTF8(const char *src, int size) {
- for (int i = 0; i < size;) {
- const int char_length = ValidUTF8CharLength(src + i, size - i);
- if (char_length <= 0) {
- return false;
- }
- i += char_length;
- }
- return true;
-}
-
-int ValidUTF8CharLength(const char *src, int size) {
- // Unexpected trail byte.
- if (IsTrailByte(src[0])) {
- return -1;
- }
-
- const int num_codepoint_bytes = GetNumBytesForUTF8Char(&src[0]);
- if (num_codepoint_bytes <= 0 || num_codepoint_bytes > size) {
- return -1;
- }
-
- // Check that remaining bytes in the codepoint are trailing bytes.
- for (int k = 1; k < num_codepoint_bytes; k++) {
- if (!IsTrailByte(src[k])) {
- return -1;
- }
- }
-
- return num_codepoint_bytes;
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/strings/utf8.h b/utils/strings/utf8.h
deleted file mode 100644
index 6c4c8a0..0000000
--- a/utils/strings/utf8.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_UTF8_H_
-#define LIBTEXTCLASSIFIER_UTILS_STRINGS_UTF8_H_
-
-namespace libtextclassifier3 {
-
-// Returns the length (number of bytes) of the Unicode code point starting at
-// src, based on inspecting just that one byte. Preconditions: src != NULL,
-// *src can be read, and *src is not '\0', and src points to a well-formed UTF-8
-// std::string.
-static inline int GetNumBytesForNonZeroUTF8Char(const char *src) {
- // On most platforms, char is unsigned by default, but iOS is an exception.
- // The cast below makes sure we always interpret *src as an unsigned char.
- return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"
- [(*(reinterpret_cast<const unsigned char *>(src)) & 0xFF) >> 4];
-}
-
-// Like GetNumBytesForNonZeroUTF8Char, but *src may be '\0'; returns 0 in that
-// case.
-static inline int GetNumBytesForUTF8Char(const char *src) {
- if (*src == '\0') return 0;
- return GetNumBytesForNonZeroUTF8Char(src);
-}
-
-// Returns true if this byte is a trailing UTF-8 byte (10xx xxxx)
-static inline bool IsTrailByte(char x) {
- // return (x & 0xC0) == 0x80;
- // Since trail bytes are always in [0x80, 0xBF], we can optimize:
- return static_cast<signed char>(x) < -0x40;
-}
-
-// Returns true iff src points to a well-formed UTF-8 string.
-bool IsValidUTF8(const char *src, int size);
-
-// Returns byte length of the first valid codepoint in the string, otherwise -1
-// if pointing to an ill-formed UTF-8 character.
-int ValidUTF8CharLength(const char *src, int size);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_UTF8_H_
diff --git a/utils/strings/utf8_test.cc b/utils/strings/utf8_test.cc
deleted file mode 100644
index a71d4f2..0000000
--- a/utils/strings/utf8_test.cc
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-#include "utils/strings/utf8.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(Utf8Test, GetNumBytesForUTF8Char) {
- EXPECT_EQ(GetNumBytesForUTF8Char("\x00"), 0);
- EXPECT_EQ(GetNumBytesForUTF8Char("h"), 1);
- EXPECT_EQ(GetNumBytesForUTF8Char("😋"), 4);
- EXPECT_EQ(GetNumBytesForUTF8Char("㍿"), 3);
-}
-
-TEST(Utf8Test, IsValidUTF8) {
- EXPECT_TRUE(IsValidUTF8("1234😋hello", 13));
- EXPECT_TRUE(IsValidUTF8("\u304A\u00B0\u106B", 8));
- EXPECT_TRUE(IsValidUTF8("this is a test😋😋😋", 26));
- EXPECT_TRUE(IsValidUTF8("\xf0\x9f\x98\x8b", 4));
- // Too short (string is too short).
- EXPECT_FALSE(IsValidUTF8("\xf0\x9f", 2));
- // Too long (too many trailing bytes).
- EXPECT_FALSE(IsValidUTF8("\xf0\x9f\x98\x8b\x8b", 5));
- // Too short (too few trailing bytes).
- EXPECT_FALSE(IsValidUTF8("\xf0\x9f\x98\x61\x61", 5));
-}
-
-TEST(Utf8Test, ValidUTF8CharLength) {
- EXPECT_EQ(ValidUTF8CharLength("1234😋hello", 13), 1);
- EXPECT_EQ(ValidUTF8CharLength("\u304A\u00B0\u106B", 8), 3);
- EXPECT_EQ(ValidUTF8CharLength("this is a test😋😋😋", 26), 1);
- EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x8b", 4), 4);
- // Too short (string is too short).
- EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f", 2), -1);
- // Too long (too many trailing bytes). First character is valid.
- EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x8b\x8b", 5), 4);
- // Too short (too few trailing bytes).
- EXPECT_EQ(ValidUTF8CharLength("\xf0\x9f\x98\x61\x61", 5), -1);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/test-utils.cc b/utils/test-utils.cc
deleted file mode 100644
index e37105a..0000000
--- a/utils/test-utils.cc
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/test-utils.h"
-
-#include <iterator>
-
-#include "utils/strings/split.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-using libtextclassifier3::Token;
-
-// Returns a list of Tokens for given input string. Can't handle non-ASCII
-// input.
-std::vector<Token> TokenizeAsciiOnSpace(const std::string& text) {
- std::vector<Token> result;
- for (const StringPiece token : strings::Split(text, ' ')) {
- const int start_offset = std::distance(text.data(), token.data());
- const int token_length = token.length();
- result.push_back(
- Token{token.ToString(), start_offset, start_offset + token_length});
- }
- return result;
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/test-utils.h b/utils/test-utils.h
deleted file mode 100644
index 7e227dc..0000000
--- a/utils/test-utils.h
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Utilities for tests.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
-#define LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
-
-#include <string>
-
-#include "annotator/types.h"
-
-namespace libtextclassifier3 {
-
-// Returns a list of Tokens for given input string. Can't handle non-ASCII
-// input.
-std::vector<Token> TokenizeAsciiOnSpace(const std::string& text);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
diff --git a/utils/testing/annotator.h b/utils/testing/annotator.h
deleted file mode 100644
index b988d0b..0000000
--- a/utils/testing/annotator.h
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Helper utilities for testing Annotator.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
-#define LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
-
-#include <memory>
-#include <string>
-
-#include "annotator/model_generated.h"
-#include "annotator/types.h"
-#include "flatbuffers/flatbuffers.h"
-
-namespace libtextclassifier3 {
-
-// Loads FlatBuffer model, unpacks it and passes it to the visitor_fn so that it
-// can modify it. Afterwards the modified unpacked model is serialized back to a
-// flatbuffer.
-template <typename Fn>
-std::string ModifyAnnotatorModel(const std::string& model_flatbuffer,
- Fn visitor_fn) {
- std::unique_ptr<ModelT> unpacked_model =
- UnPackModel(model_flatbuffer.c_str());
-
- visitor_fn(unpacked_model.get());
-
- flatbuffers::FlatBufferBuilder builder;
- FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
-
- return std::string(reinterpret_cast<char*>(builder.GetBufferPointer()),
- builder.GetSize());
-}
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
diff --git a/utils/testing/logging_event_listener.h b/utils/testing/logging_event_listener.h
deleted file mode 100644
index 2663a9c..0000000
--- a/utils/testing/logging_event_listener.h
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_TESTING_LOGGING_EVENT_LISTENER_H_
-#define LIBTEXTCLASSIFIER_UTILS_TESTING_LOGGING_EVENT_LISTENER_H_
-
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-
-// TestEventListener that writes test results to the log so that they will be
-// visible in the logcat output in Sponge.
-// The formatting of the output is patterend after the output produced by the
-// standard PrettyUnitTestResultPrinter.
-class LoggingEventListener : public ::testing::TestEventListener {
- public:
- void OnTestProgramStart(const testing::UnitTest& unit_test) override;
-
- void OnTestIterationStart(const testing::UnitTest& unit_test,
- int iteration) override;
-
- void OnEnvironmentsSetUpStart(const testing::UnitTest& unit_test) override;
-
- void OnEnvironmentsSetUpEnd(const testing::UnitTest& unit_test) override;
-
- void OnTestCaseStart(const testing::TestCase& test_case) override;
-
- void OnTestStart(const testing::TestInfo& test_info) override;
-
- void OnTestPartResult(
- const testing::TestPartResult& test_part_result) override;
-
- void OnTestEnd(const testing::TestInfo& test_info) override;
-
- void OnTestCaseEnd(const testing::TestCase& test_case) override;
-
- void OnEnvironmentsTearDownStart(const testing::UnitTest& unit_test) override;
-
- void OnEnvironmentsTearDownEnd(const testing::UnitTest& unit_test) override;
-
- void OnTestIterationEnd(const testing::UnitTest& unit_test,
- int iteration) override;
-
- void OnTestProgramEnd(const testing::UnitTest& unit_test) override;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_TESTING_LOGGING_EVENT_LISTENER_H_
diff --git a/utils/tflite-model-executor.cc b/utils/tflite-model-executor.cc
deleted file mode 100644
index 9ba232e..0000000
--- a/utils/tflite-model-executor.cc
+++ /dev/null
@@ -1,248 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/tflite-model-executor.h"
-
-#include "utils/base/logging.h"
-#include "tensorflow/lite/kernels/register.h"
-
-// Forward declaration of custom TensorFlow Lite ops for registration.
-namespace tflite {
-namespace ops {
-namespace builtin {
-TfLiteRegistration* Register_ADD();
-TfLiteRegistration* Register_CONCATENATION();
-TfLiteRegistration* Register_CONV_2D();
-TfLiteRegistration* Register_FULLY_CONNECTED();
-TfLiteRegistration* Register_L2_NORMALIZATION();
-TfLiteRegistration* Register_MUL();
-TfLiteRegistration* Register_RESHAPE();
-TfLiteRegistration* Register_SOFTMAX();
-TfLiteRegistration* Register_GATHER();
-TfLiteRegistration* Register_TRANSPOSE();
-TfLiteRegistration* Register_SUB();
-TfLiteRegistration* Register_DIV();
-TfLiteRegistration* Register_STRIDED_SLICE();
-TfLiteRegistration* Register_EXP();
-TfLiteRegistration* Register_TOPK_V2();
-TfLiteRegistration* Register_SPLIT();
-TfLiteRegistration* Register_CAST();
-TfLiteRegistration* Register_MAXIMUM();
-TfLiteRegistration* Register_MINIMUM();
-TfLiteRegistration* Register_NEG();
-TfLiteRegistration* Register_SLICE();
-TfLiteRegistration* Register_LOG();
-TfLiteRegistration* Register_SUM();
-TfLiteRegistration* Register_PACK();
-TfLiteRegistration* Register_DEQUANTIZE();
-TfLiteRegistration* Register_MEAN();
-} // namespace builtin
-} // namespace ops
-} // namespace tflite
-
-#ifdef TC3_WITH_ACTIONS_OPS
-#include "utils/tflite/dist_diversification.h"
-#include "utils/tflite/text_encoder.h"
-#include "utils/tflite/token_encoder.h"
-
-void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
- resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
- tflite::ops::builtin::Register_ADD(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_CONCATENATION,
- tflite::ops::builtin::Register_CONCATENATION(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_CONV_2D,
- tflite::ops::builtin::Register_CONV_2D(),
- /*min_version=*/1,
- /*max_version=*/3);
- resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
- tflite::ops::builtin::Register_FULLY_CONNECTED(),
- /*min_version=*/1,
- /*max_version=*/4);
- resolver->AddBuiltin(tflite::BuiltinOperator_L2_NORMALIZATION,
- tflite::ops::builtin::Register_L2_NORMALIZATION(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_MUL,
- tflite::ops::builtin::Register_MUL());
- resolver->AddBuiltin(tflite::BuiltinOperator_RESHAPE,
- tflite::ops::builtin::Register_RESHAPE());
- resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
- tflite::ops::builtin::Register_SOFTMAX(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_GATHER,
- tflite::ops::builtin::Register_GATHER(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE,
- tflite::ops::builtin::Register_TRANSPOSE(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_SUB,
- tflite::ops::builtin::Register_SUB(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_DIV,
- tflite::ops::builtin::Register_DIV());
- resolver->AddBuiltin(tflite::BuiltinOperator_STRIDED_SLICE,
- tflite::ops::builtin::Register_STRIDED_SLICE(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_EXP,
- tflite::ops::builtin::Register_EXP());
- resolver->AddBuiltin(tflite::BuiltinOperator_TOPK_V2,
- tflite::ops::builtin::Register_TOPK_V2(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_SPLIT,
- tflite::ops::builtin::Register_SPLIT(),
- /*min_version=*/1,
- /*max_version=*/3);
- resolver->AddBuiltin(tflite::BuiltinOperator_CAST,
- tflite::ops::builtin::Register_CAST());
- resolver->AddBuiltin(tflite::BuiltinOperator_MAXIMUM,
- tflite::ops::builtin::Register_MAXIMUM(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_MINIMUM,
- tflite::ops::builtin::Register_MINIMUM(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_NEG,
- tflite::ops::builtin::Register_NEG());
- resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
- tflite::ops::builtin::Register_SLICE(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_LOG,
- tflite::ops::builtin::Register_LOG());
- resolver->AddBuiltin(tflite::BuiltinOperator_SUM,
- tflite::ops::builtin::Register_SUM());
- resolver->AddBuiltin(tflite::BuiltinOperator_PACK,
- tflite::ops::builtin::Register_PACK(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_DEQUANTIZE,
- tflite::ops::builtin::Register_DEQUANTIZE(),
- /*min_version=*/1,
- /*max_version=*/2);
- resolver->AddBuiltin(tflite::BuiltinOperator_MEAN,
- tflite::ops::builtin::Register_MEAN());
-}
-#else
-void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
- resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
- tflite::ops::builtin::Register_FULLY_CONNECTED());
-}
-#endif // TC3_WITH_ACTIONS_OPS
-
-namespace libtextclassifier3 {
-
-inline std::unique_ptr<tflite::OpResolver> BuildOpResolver() {
-#ifdef TC3_USE_SELECTIVE_REGISTRATION
- std::unique_ptr<tflite::MutableOpResolver> resolver(
- new tflite::MutableOpResolver);
- RegisterSelectedOps(resolver.get());
-#else
- std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver(
- new tflite::ops::builtin::BuiltinOpResolver);
-#endif
-#ifdef TC3_WITH_ACTIONS_OPS
- resolver->AddCustom("DistanceDiversification",
- tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
- resolver->AddCustom("TextEncoder",
- tflite::ops::custom::Register_TEXT_ENCODER());
- resolver->AddCustom("TokenEncoder",
- tflite::ops::custom::Register_TOKEN_ENCODER());
-#endif // TC3_WITH_ACTIONS_OPS
- return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
-}
-
-std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
- const tflite::Model* model_spec) {
- std::unique_ptr<const tflite::FlatBufferModel> model(
- tflite::FlatBufferModel::BuildFromModel(model_spec));
- if (!model || !model->initialized()) {
- TC3_LOG(ERROR) << "Could not build TFLite model from a model spec.";
- return nullptr;
- }
- return model;
-}
-
-std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
- const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
- const tflite::Model* model =
- flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data());
- flatbuffers::Verifier verifier(model_spec_buffer->data(),
- model_spec_buffer->Length());
- if (!model->Verify(verifier)) {
- return nullptr;
- }
- return TfLiteModelFromModelSpec(model);
-}
-
-TfLiteModelExecutor::TfLiteModelExecutor(
- std::unique_ptr<const tflite::FlatBufferModel> model)
- : model_(std::move(model)), resolver_(BuildOpResolver()) {}
-
-std::unique_ptr<tflite::Interpreter> TfLiteModelExecutor::CreateInterpreter()
- const {
- std::unique_ptr<tflite::Interpreter> interpreter;
- tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter);
- return interpreter;
-}
-
-template <>
-void TfLiteModelExecutor::SetInput(const int input_index,
- const std::vector<std::string>& input_data,
- tflite::Interpreter* interpreter) const {
- tflite::DynamicBuffer buf;
- for (const std::string& s : input_data) {
- buf.AddString(s.data(), s.length());
- }
- buf.WriteToTensorAsVector(
- interpreter->tensor(interpreter->inputs()[input_index]));
-}
-
-template <>
-std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
- const int output_index, const tflite::Interpreter* interpreter) const {
- const TfLiteTensor* output_tensor =
- interpreter->tensor(interpreter->outputs()[output_index]);
- const int num_strings = tflite::GetStringCount(output_tensor);
- std::vector<tflite::StringRef> output(num_strings);
- for (int i = 0; i < num_strings; i++) {
- output[i] = tflite::GetString(output_tensor, i);
- }
- return output;
-}
-
-template <>
-std::vector<std::string> TfLiteModelExecutor::Output(
- const int output_index, const tflite::Interpreter* interpreter) const {
- std::vector<std::string> output;
- for (const tflite::StringRef& s :
- Output<tflite::StringRef>(output_index, interpreter)) {
- output.push_back(std::string(s.str, s.len));
- }
- return output;
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/tflite-model-executor.h b/utils/tflite-model-executor.h
deleted file mode 100644
index 10d4233..0000000
--- a/utils/tflite-model-executor.h
+++ /dev/null
@@ -1,155 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// Contains classes that can execute different models/parts of a model.
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
-#define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
-
-#include <memory>
-
-#include "utils/base/logging.h"
-#include "utils/tensor-view.h"
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/kernels/register.h"
-#include "tensorflow/lite/model.h"
-#include "tensorflow/lite/op_resolver.h"
-#include "tensorflow/lite/string_util.h"
-
-namespace libtextclassifier3 {
-
-std::unique_ptr<tflite::OpResolver> BuildOpResolver();
-std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
- const tflite::Model*);
-std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
- const flatbuffers::Vector<uint8_t>*);
-
-// Executor for the text selection prediction and classification models.
-class TfLiteModelExecutor {
- public:
- static std::unique_ptr<TfLiteModelExecutor> FromModelSpec(
- const tflite::Model* model_spec) {
- auto model = TfLiteModelFromModelSpec(model_spec);
- if (!model) {
- return nullptr;
- }
- return std::unique_ptr<TfLiteModelExecutor>(
- new TfLiteModelExecutor(std::move(model)));
- }
-
- static std::unique_ptr<TfLiteModelExecutor> FromBuffer(
- const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
- auto model = TfLiteModelFromBuffer(model_spec_buffer);
- if (!model) {
- return nullptr;
- }
- return std::unique_ptr<TfLiteModelExecutor>(
- new TfLiteModelExecutor(std::move(model)));
- }
-
- // Creates an Interpreter for the model that serves as a scratch-pad for the
- // inference. The Interpreter is NOT thread-safe.
- std::unique_ptr<tflite::Interpreter> CreateInterpreter() const;
-
- template <typename T>
- void SetInput(const int input_index, const TensorView<T>& input_data,
- tflite::Interpreter* interpreter) const {
- input_data.copy_to(interpreter->typed_input_tensor<T>(input_index),
- input_data.size());
- }
-
- template <typename T>
- void SetInput(const int input_index, const std::vector<T>& input_data,
- tflite::Interpreter* interpreter) const {
- std::copy(input_data.begin(), input_data.end(),
- interpreter->typed_input_tensor<T>(input_index));
- }
-
- template <typename T>
- void SetInput(const int input_index, const T input_value,
- tflite::Interpreter* interpreter) const {
- TfLiteTensor* input_tensor =
- interpreter->tensor(interpreter->inputs()[input_index]);
- switch (input_tensor->type) {
- case kTfLiteFloat32:
- *(input_tensor->data.f) = input_value;
- break;
- case kTfLiteInt32:
- *(input_tensor->data.i32) = input_value;
- break;
- case kTfLiteUInt8:
- *(input_tensor->data.uint8) = input_value;
- break;
- case kTfLiteInt64:
- *(input_tensor->data.i64) = input_value;
- break;
- case kTfLiteBool:
- *(input_tensor->data.b) = input_value;
- break;
- case kTfLiteInt16:
- *(input_tensor->data.i16) = input_value;
- break;
- case kTfLiteInt8:
- *(input_tensor->data.int8) = input_value;
- break;
- default:
- break;
- }
- }
-
- template <typename T>
- TensorView<T> OutputView(const int output_index,
- const tflite::Interpreter* interpreter) const {
- const TfLiteTensor* output_tensor =
- interpreter->tensor(interpreter->outputs()[output_index]);
- return TensorView<T>(interpreter->typed_output_tensor<T>(output_index),
- std::vector<int>(output_tensor->dims->data,
- output_tensor->dims->data +
- output_tensor->dims->size));
- }
-
- template <typename T>
- std::vector<T> Output(const int output_index,
- const tflite::Interpreter* interpreter) const {
- TensorView<T> output_view = OutputView<T>(output_index, interpreter);
- return std::vector<T>(output_view.data(),
- output_view.data() + output_view.size());
- }
-
- protected:
- explicit TfLiteModelExecutor(
- std::unique_ptr<const tflite::FlatBufferModel> model);
-
- std::unique_ptr<const tflite::FlatBufferModel> model_;
- std::unique_ptr<tflite::OpResolver> resolver_;
-};
-
-template <>
-void TfLiteModelExecutor::SetInput(const int input_index,
- const std::vector<std::string>& input_data,
- tflite::Interpreter* interpreter) const;
-
-template <>
-std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
- const int output_index, const tflite::Interpreter* interpreter) const;
-
-template <>
-std::vector<std::string> TfLiteModelExecutor::Output(
- const int output_index, const tflite::Interpreter* interpreter) const;
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
diff --git a/utils/tflite/dist_diversification_test.cc b/utils/tflite/dist_diversification_test.cc
deleted file mode 100644
index 2380116..0000000
--- a/utils/tflite/dist_diversification_test.cc
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/tflite/dist_diversification.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/kernels/register.h"
-#include "tensorflow/lite/kernels/test_util.h"
-#include "tensorflow/lite/model.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-class DistanceDiversificationOpModel : public tflite::SingleOpModel {
- public:
- explicit DistanceDiversificationOpModel(int matrix_rows);
- void SetDistanceMatrix(const std::initializer_list<float>& values) {
- PopulateTensor(distance_matrix_, values);
- }
- void SetNumOutput(int length) { PopulateTensor(num_results_, {length}); }
- void SetMinDistance(float min_distance) {
- PopulateTensor(min_distance_, {min_distance});
- }
- int GetOutputLen() { return ExtractVector<int>(output_len_).front(); }
- std::vector<int> GetOutputIndexes(int output_length) {
- auto res = ExtractVector<int>(output_indexes_);
- res.resize(output_length);
- return res;
- }
-
- private:
- int distance_matrix_;
- int num_results_;
- int min_distance_;
-
- int output_len_;
- int output_indexes_;
-};
-
-DistanceDiversificationOpModel::DistanceDiversificationOpModel(
- int matrix_rows) {
- distance_matrix_ = AddInput(tflite::TensorType_FLOAT32);
- min_distance_ = AddInput(tflite::TensorType_FLOAT32);
- num_results_ = AddInput(tflite::TensorType_INT32);
-
- output_indexes_ = AddOutput(tflite::TensorType_INT32);
- output_len_ = AddOutput(tflite::TensorType_INT32);
- SetCustomOp("DistanceDiversification", {},
- tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION);
- BuildInterpreter({{matrix_rows, matrix_rows}, {1}, {1}});
-}
-
-// Tests
-TEST(DistanceDiversificationOp, Simple) {
- DistanceDiversificationOpModel m(5);
- m.SetDistanceMatrix({0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.0, 0.1, 0.2,
- 0.3, 0.2, 0.1, 0.0, 0.1, 0.2, 0.3, 0.2, 0.1,
- 0.0, 0.1, 0.4, 0.3, 0.2, 0.1, 0.0});
- m.SetMinDistance(0.21);
- m.SetNumOutput(3);
- m.Invoke();
- const int output_length = m.GetOutputLen();
- EXPECT_EQ(output_length, 2);
- EXPECT_THAT(m.GetOutputIndexes(output_length), testing::ElementsAre(0, 3));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/tflite/encoder_common_test.cc b/utils/tflite/encoder_common_test.cc
deleted file mode 100644
index 247689f..0000000
--- a/utils/tflite/encoder_common_test.cc
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/tflite/encoder_common.h"
-
-#include "gtest/gtest.h"
-#include "tensorflow/lite/model.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-TEST(EncoderUtilsTest, CreateIntArray) {
- TfLiteIntArray* a = CreateIntArray({1, 2, 3});
- EXPECT_EQ(a->data[0], 1);
- EXPECT_EQ(a->data[1], 2);
- EXPECT_EQ(a->data[2], 3);
- TfLiteIntArrayFree(a);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/tflite/text_encoder.cc b/utils/tflite/text_encoder.cc
deleted file mode 100644
index c7811ea..0000000
--- a/utils/tflite/text_encoder.cc
+++ /dev/null
@@ -1,298 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <memory>
-#include <vector>
-
-#include "utils/base/logging.h"
-#include "utils/sentencepiece/double_array_trie.h"
-#include "utils/sentencepiece/encoder.h"
-#include "utils/sentencepiece/normalizer.h"
-#include "utils/sentencepiece/sorted_strings_table.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/tflite/encoder_common.h"
-#include "utils/tflite/text_encoder.h"
-#include "utils/tflite/text_encoder_config_generated.h"
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/flexbuffers.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/model.h"
-#include "tensorflow/lite/string_util.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-struct TextEncoderOp {
- std::unique_ptr<SentencePieceNormalizer> normalizer;
- std::unique_ptr<Encoder> encoder;
- std::unique_ptr<SentencePieceMatcher> matcher;
-};
-
-// Input parameters for the op.
-// The conversation message as a (1, conversation length) string tensor.
-constexpr const int kInputTexts = 0;
-
-// The number of messages, the conversation length, int scalar.
-constexpr const int kInputNumInputs = 1;
-
-// Maximum output length of the encoding, int scalar.
-constexpr const int kInputMaxLength = 2;
-
-// Additional attributes to align to the sentence pieces, e.g. user ids per
-// message.
-constexpr const int kInputAttr = 3;
-
-// Output parameters for the op.
-// The text sentence piece encodings as ids, (1, max output length) int tensor.
-constexpr const int kOutputEncoded = 0;
-
-// Relative position of each sentence piece in the input text,
-// (1, max output length) int tensor.
-constexpr const int kOutputPosition = 1;
-
-// Output length after trimming to the maximum output length specified.
-// int scalar.
-constexpr const int kOutputLengths = 2;
-
-// Padded and sentence piece aligned provided attributes, e.g. user id per
-// sentence piece.
-constexpr const int kOutputAttr = 3;
-
-const char kTextEncoderConfigAttr[] = "text_encoder_config";
-
-// Initializes text encoder object from serialized options:
-// The options are a flexbuffers attribute map that contain the op config
-// with the key `text_encoder_config` as `TextEncoderConfig`.
-void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
- const flexbuffers::Map& attr_map =
- flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
- .AsMap();
- const flexbuffers::Blob serialized_config =
- attr_map[kTextEncoderConfigAttr].AsBlob();
- const TextEncoderConfig* config =
- flatbuffers::GetRoot<TextEncoderConfig>(serialized_config.data());
-
- std::unique_ptr<TextEncoderOp> encoder_op(new TextEncoderOp());
-
- // Create normalizer from options.
- const TrieNode* charsmap_trie_nodes = reinterpret_cast<const TrieNode*>(
- config->normalization_charsmap()->Data());
- const int charsmap_trie_nodes_length =
- config->normalization_charsmap()->Length() / sizeof(TrieNode);
- encoder_op->normalizer.reset(new SentencePieceNormalizer(
- DoubleArrayTrie(charsmap_trie_nodes, charsmap_trie_nodes_length),
- StringPiece(config->normalization_charsmap_values()->data(),
- config->normalization_charsmap_values()->size()),
- config->add_dummy_prefix(), config->remove_extra_whitespaces(),
- config->escape_whitespaces()));
-
- const int num_pieces = config->pieces_scores()->Length();
-
- switch (config->matcher_type()) {
- case SentencePieceMatcherType_MAPPED_TRIE: {
- const TrieNode* pieces_trie_nodes =
- reinterpret_cast<const TrieNode*>(config->pieces()->Data());
- const int pieces_trie_nodes_length =
- config->pieces()->Length() / sizeof(TrieNode);
- encoder_op->matcher.reset(
- new DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length));
- break;
- }
- case SentencePieceMatcherType_SORTED_STRING_TABLE: {
- encoder_op->matcher.reset(new SortedStringsTable(
- num_pieces, config->pieces_offsets()->data(),
- StringPiece(config->pieces()->data(), config->pieces()->Length())));
- break;
- }
- default: {
- TC3_LOG(ERROR) << "Unknown sentence piece matcher type.";
- return nullptr;
- }
- }
- encoder_op->encoder.reset(new Encoder(
- encoder_op->matcher.get(), num_pieces, config->pieces_scores()->data(),
- config->start_code(), config->end_code(), config->encoding_offset(),
- config->unknown_code(), config->unknown_score()));
- return encoder_op.release();
-}
-
-void Free(TfLiteContext* context, void* buffer) {
- delete reinterpret_cast<TextEncoderOp*>(buffer);
-}
-
-TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
- int max_output_length) {
- TF_LITE_ENSURE_OK(
- context,
- ResizeOutputTensor(max_output_length,
- &context->tensors[node->outputs->data[kOutputEncoded]],
- context));
-
- TF_LITE_ENSURE_OK(
- context,
- ResizeOutputTensor(
- max_output_length,
- &context->tensors[node->outputs->data[kOutputPosition]], context));
-
- const int num_output_attrs = node->outputs->size - kOutputAttr;
- for (int i = 0; i < num_output_attrs; ++i) {
- TF_LITE_ENSURE_OK(
- context,
- ResizeOutputTensor(
- max_output_length,
- &context->tensors[node->outputs->data[kOutputAttr + i]], context));
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- // Check that the batch dimension is kBatchSize.
- const TfLiteTensor& input_text =
- context->tensors[node->inputs->data[kInputTexts]];
- TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank);
- TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize);
-
- TfLiteTensor& output_lengths =
- context->tensors[node->outputs->data[kOutputLengths]];
- TfLiteTensor& output_encoded =
- context->tensors[node->outputs->data[kOutputEncoded]];
- TfLiteTensor& output_positions =
- context->tensors[node->outputs->data[kOutputPosition]];
-
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, &output_lengths,
- CreateIntArray({kEncoderBatchSize})));
-
- // Check that there are enough outputs for attributes.
- const int num_output_attrs = node->outputs->size - kOutputAttr;
- TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
-
- // Copy attribute types from input to output tensors.
- for (int i = 0; i < num_output_attrs; ++i) {
- TfLiteTensor& input = context->tensors[node->inputs->data[kInputAttr + i]];
- TfLiteTensor& output =
- context->tensors[node->outputs->data[kOutputAttr + i]];
- output.type = input.type;
- }
-
- const TfLiteTensor& output_length =
- context->tensors[node->inputs->data[kInputMaxLength]];
-
- if (tflite::IsConstantTensor(&output_length)) {
- return ResizeOutputTensors(context, node, output_length.data.i64[0]);
- } else {
- tflite::SetTensorToDynamic(&output_encoded);
- tflite::SetTensorToDynamic(&output_positions);
- for (int i = 0; i < num_output_attrs; ++i) {
- TfLiteTensor& output_attr =
- context->tensors[node->outputs->data[kOutputAttr + i]];
- tflite::SetTensorToDynamic(&output_attr);
- }
- }
-
- return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- if (node->user_data == nullptr) {
- return kTfLiteError;
- }
- const TextEncoderOp* encoder_op =
- reinterpret_cast<TextEncoderOp*>(node->user_data);
- const TfLiteTensor& input_text =
- context->tensors[node->inputs->data[kInputTexts]];
- const int num_strings = tflite::GetStringCount(&input_text);
- // Check that the number of strings matches the length parameter.
- const int num_strings_param =
- context->tensors[node->inputs->data[kInputNumInputs]].data.i32[0];
- TF_LITE_ENSURE_EQ(context, num_strings, num_strings_param);
-
- TfLiteTensor& output_encoded =
- context->tensors[node->outputs->data[kOutputEncoded]];
- if (tflite::IsDynamicTensor(&output_encoded)) {
- const TfLiteTensor& output_length =
- context->tensors[node->inputs->data[kInputMaxLength]];
- TF_LITE_ENSURE_OK(
- context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
- }
- TfLiteTensor& output_positions =
- context->tensors[node->outputs->data[kOutputPosition]];
-
- std::vector<int> encoded_total;
- std::vector<int> encoded_offsets;
- std::vector<int> encoded_positions;
- encoded_offsets.reserve(num_strings);
- const int max_output_length = output_encoded.dims->data[1];
- const int max_encoded_position = max_output_length;
-
- for (int i = 0; i < num_strings; ++i) {
- const auto& strref = tflite::GetString(&input_text, i);
- std::string normalized;
- TF_LITE_ENSURE(context,
- encoder_op->normalizer->Normalize(
- StringPiece(strref.str, strref.len), &normalized));
- std::vector<int> encoded;
- TF_LITE_ENSURE(context, encoder_op->encoder->Encode(normalized, &encoded));
- encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
- encoded_offsets.push_back(encoded_total.size());
- for (int i = 0; i < encoded.size(); i++) {
- encoded_positions.push_back(std::min(i, max_encoded_position - 1));
- }
- }
-
- const int num_skip = CopyDataToTensorAndPadOrTruncate(
- max_output_length, encoded_total,
- /*padding_value=*/encoded_total.back(), &output_encoded);
- TfLiteTensor& output_lengths =
- context->tensors[node->outputs->data[kOutputLengths]];
- output_lengths.data.i32[0] = encoded_total.size() - num_skip;
- CopyDataToTensorAndPadOrTruncate(max_output_length, encoded_positions,
- /*padding_value=*/max_encoded_position,
- &output_positions);
-
- // Process attributes, all checks of sizes and types are done in Prepare.
- const int num_output_attrs = node->outputs->size - kOutputAttr;
- TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
- for (int i = 0; i < num_output_attrs; ++i) {
- TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
- context->tensors[node->inputs->data[kInputAttr + i]], encoded_offsets,
- num_skip, context,
- &context->tensors[node->outputs->data[kOutputAttr + i]]);
- if (attr_status != kTfLiteOk) {
- return attr_status;
- }
- }
-
- return kTfLiteOk;
-}
-
-} // namespace
-} // namespace libtextclassifier3
-
-namespace tflite {
-namespace ops {
-namespace custom {
-
-TfLiteRegistration* Register_TEXT_ENCODER() {
- static TfLiteRegistration registration = {
- libtextclassifier3::Initialize, libtextclassifier3::Free,
- libtextclassifier3::Prepare, libtextclassifier3::Eval};
- return ®istration;
-}
-
-} // namespace custom
-} // namespace ops
-} // namespace tflite
diff --git a/utils/tflite/text_encoder_test.cc b/utils/tflite/text_encoder_test.cc
deleted file mode 100644
index ae752f5..0000000
--- a/utils/tflite/text_encoder_test.cc
+++ /dev/null
@@ -1,176 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <fstream>
-#include <string>
-#include <vector>
-
-#include "utils/tflite/text_encoder.h"
-#include "gtest/gtest.h"
-#include "third_party/absl/flags/flag.h"
-#include "flatbuffers/flexbuffers.h"
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/kernels/register.h"
-#include "tensorflow/lite/kernels/test_util.h"
-#include "tensorflow/lite/model.h"
-#include "tensorflow/lite/string_util.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-std::string GetTestConfigPath() {
- return "";
-}
-
-class TextEncoderOpModel : public tflite::SingleOpModel {
- public:
- TextEncoderOpModel(std::initializer_list<int> input_strings_shape,
- std::initializer_list<int> attribute_shape);
- void SetInputText(const std::initializer_list<string>& strings) {
- PopulateStringTensor(input_string_, strings);
- PopulateTensor(input_length_, {static_cast<int32_t>(strings.size())});
- }
- void SetMaxOutputLength(int length) {
- PopulateTensor(input_output_maxlength_, {length});
- }
- void SetInt32Attribute(const std::initializer_list<int>& attribute) {
- PopulateTensor(input_attributes_int32_, attribute);
- }
- void SetFloatAttribute(const std::initializer_list<float>& attribute) {
- PopulateTensor(input_attributes_float_, attribute);
- }
-
- std::vector<int> GetOutputEncoding() {
- return ExtractVector<int>(output_encoding_);
- }
- std::vector<int> GetOutputPositions() {
- return ExtractVector<int>(output_positions_);
- }
- std::vector<int> GetOutputAttributeInt32() {
- return ExtractVector<int>(output_attributes_int32_);
- }
- std::vector<float> GetOutputAttributeFloat() {
- return ExtractVector<float>(output_attributes_float_);
- }
- int GetEncodedLength() { return ExtractVector<int>(output_length_)[0]; }
-
- private:
- int input_string_;
- int input_length_;
- int input_output_maxlength_;
- int input_attributes_int32_;
- int input_attributes_float_;
-
- int output_encoding_;
- int output_positions_;
- int output_length_;
- int output_attributes_int32_;
- int output_attributes_float_;
-};
-
-TextEncoderOpModel::TextEncoderOpModel(
- std::initializer_list<int> input_strings_shape,
- std::initializer_list<int> attribute_shape) {
- input_string_ = AddInput(tflite::TensorType_STRING);
- input_length_ = AddInput(tflite::TensorType_INT32);
- input_output_maxlength_ = AddInput(tflite::TensorType_INT32);
- input_attributes_int32_ = AddInput(tflite::TensorType_INT32);
- input_attributes_float_ = AddInput(tflite::TensorType_FLOAT32);
-
- output_encoding_ = AddOutput(tflite::TensorType_INT32);
- output_positions_ = AddOutput(tflite::TensorType_INT32);
- output_length_ = AddOutput(tflite::TensorType_INT32);
- output_attributes_int32_ = AddOutput(tflite::TensorType_INT32);
- output_attributes_float_ = AddOutput(tflite::TensorType_FLOAT32);
-
- std::ifstream test_config_stream(GetTestConfigPath());
- std::string config((std::istreambuf_iterator<char>(test_config_stream)),
- (std::istreambuf_iterator<char>()));
- flexbuffers::Builder builder;
- builder.Map([&]() { builder.String("text_encoder_config", config); });
- builder.Finish();
- SetCustomOp("TextEncoder", builder.GetBuffer(),
- tflite::ops::custom::Register_TEXT_ENCODER);
- BuildInterpreter(
- {input_strings_shape, {1}, {1}, attribute_shape, attribute_shape});
-}
-
-// Tests
-TEST(TextEncoderTest, SimpleEncoder) {
- TextEncoderOpModel m({1, 1}, {1, 1});
- m.SetInputText({"Hello"});
- m.SetMaxOutputLength(10);
- m.SetInt32Attribute({7});
- m.SetFloatAttribute({3.f});
-
- m.Invoke();
-
- EXPECT_EQ(m.GetEncodedLength(), 5);
- EXPECT_THAT(m.GetOutputEncoding(),
- testing::ElementsAre(1, 90, 547, 58, 2, 2, 2, 2, 2, 2));
- EXPECT_THAT(m.GetOutputPositions(),
- testing::ElementsAre(0, 1, 2, 3, 4, 10, 10, 10, 10, 10));
- EXPECT_THAT(m.GetOutputAttributeInt32(),
- testing::ElementsAre(7, 7, 7, 7, 7, 7, 7, 7, 7, 7));
- EXPECT_THAT(
- m.GetOutputAttributeFloat(),
- testing::ElementsAre(3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f));
-}
-
-TEST(TextEncoderTest, ManyStrings) {
- TextEncoderOpModel m({1, 3}, {1, 3});
- m.SetInt32Attribute({1, 2, 3});
- m.SetFloatAttribute({5.f, 4.f, 3.f});
- m.SetInputText({"Hello", "Hi", "Bye"});
- m.SetMaxOutputLength(10);
-
- m.Invoke();
-
- EXPECT_EQ(m.GetEncodedLength(), 10);
- EXPECT_THAT(m.GetOutputEncoding(),
- testing::ElementsAre(547, 58, 2, 1, 862, 2, 1, 1919, 19, 2));
- EXPECT_THAT(m.GetOutputPositions(),
- testing::ElementsAre(2, 3, 4, 0, 1, 2, 0, 1, 2, 3));
- EXPECT_THAT(m.GetOutputAttributeInt32(),
- testing::ElementsAre(1, 1, 1, 2, 2, 2, 3, 3, 3, 3));
- EXPECT_THAT(
- m.GetOutputAttributeFloat(),
- testing::ElementsAre(5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 3.f));
-}
-
-TEST(TextEncoderTest, LongStrings) {
- TextEncoderOpModel m({1, 4}, {1, 4});
- m.SetInt32Attribute({1, 2, 3, 4});
- m.SetFloatAttribute({5.f, 4.f, 3.f, 2.f});
- m.SetInputText({"Hello", "Hi", "Bye", "Hi"});
- m.SetMaxOutputLength(9);
-
- m.Invoke();
-
- EXPECT_EQ(m.GetEncodedLength(), 9);
- EXPECT_THAT(m.GetOutputEncoding(),
- testing::ElementsAre(862, 2, 1, 1919, 19, 2, 1, 862, 2));
- EXPECT_THAT(m.GetOutputPositions(),
- testing::ElementsAre(1, 2, 0, 1, 2, 3, 0, 1, 2));
- EXPECT_THAT(m.GetOutputAttributeInt32(),
- testing::ElementsAre(2, 2, 3, 3, 3, 3, 4, 4, 4));
- EXPECT_THAT(
- m.GetOutputAttributeFloat(),
- testing::ElementsAre(4.f, 4.f, 3.f, 3.f, 3.f, 3.f, 2.f, 2.f, 2.f));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/tflite/token_encoder_test.cc b/utils/tflite/token_encoder_test.cc
deleted file mode 100644
index c7f51a1..0000000
--- a/utils/tflite/token_encoder_test.cc
+++ /dev/null
@@ -1,148 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <vector>
-
-#include "utils/tflite/token_encoder.h"
-#include "gtest/gtest.h"
-#include "third_party/absl/flags/flag.h"
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/kernels/register.h"
-#include "tensorflow/lite/kernels/test_util.h"
-#include "tensorflow/lite/model.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-class TokenEncoderOpModel : public tflite::SingleOpModel {
- public:
- TokenEncoderOpModel(std::initializer_list<int> input_shape,
- std::initializer_list<int> attribute_shape);
- void SetNumTokens(const std::initializer_list<int>& num_tokens) {
- PopulateTensor(input_num_tokens_, num_tokens);
- PopulateTensor(input_length_, {static_cast<int32_t>(num_tokens.size())});
- }
- void SetMaxOutputLength(int length) {
- PopulateTensor(input_output_maxlength_, {length});
- }
- void SetInt32Attribute(const std::initializer_list<int>& attribute) {
- PopulateTensor(input_attributes_int32_, attribute);
- }
- void SetFloatAttribute(const std::initializer_list<float>& attribute) {
- PopulateTensor(input_attributes_float_, attribute);
- }
- std::vector<int> GetOutputPositions() {
- return ExtractVector<int>(output_positions_);
- }
- std::vector<int> GetOutputAttributeInt32() {
- return ExtractVector<int>(output_attributes_int32_);
- }
- std::vector<float> GetOutputAttributeFloat() {
- return ExtractVector<float>(output_attributes_float_);
- }
- int GetOutputLength() { return ExtractVector<int>(output_length_)[0]; }
-
- private:
- int input_num_tokens_;
- int input_length_;
- int input_output_maxlength_;
- int input_attributes_int32_;
- int input_attributes_float_;
-
- int output_positions_;
- int output_length_;
- int output_attributes_int32_;
- int output_attributes_float_;
-};
-
-TokenEncoderOpModel::TokenEncoderOpModel(
- std::initializer_list<int> input_shape,
- std::initializer_list<int> attribute_shape) {
- input_num_tokens_ = AddInput(tflite::TensorType_INT32);
- input_length_ = AddInput(tflite::TensorType_INT32);
- input_output_maxlength_ = AddInput(tflite::TensorType_INT32);
- input_attributes_int32_ = AddInput(tflite::TensorType_INT32);
- input_attributes_float_ = AddInput(tflite::TensorType_FLOAT32);
-
- output_positions_ = AddOutput(tflite::TensorType_INT32);
- output_length_ = AddOutput(tflite::TensorType_INT32);
- output_attributes_int32_ = AddOutput(tflite::TensorType_INT32);
- output_attributes_float_ = AddOutput(tflite::TensorType_FLOAT32);
-
- SetCustomOp("TokenEncoder", {}, tflite::ops::custom::Register_TOKEN_ENCODER);
- BuildInterpreter({input_shape, {1}, {1}, attribute_shape, attribute_shape});
-}
-
-// Tests
-TEST(TokenEncoderTest, SimpleEncoder) {
- TokenEncoderOpModel m({1, 1}, {1, 1});
- m.SetNumTokens({1});
- m.SetMaxOutputLength(10);
- m.SetInt32Attribute({7});
- m.SetFloatAttribute({3.f});
-
- m.Invoke();
-
- EXPECT_EQ(m.GetOutputLength(), 3);
- EXPECT_THAT(m.GetOutputPositions(),
- testing::ElementsAre(0, 1, 2, 10, 10, 10, 10, 10, 10, 10));
- EXPECT_THAT(m.GetOutputAttributeInt32(),
- testing::ElementsAre(7, 7, 7, 7, 7, 7, 7, 7, 7, 7));
- EXPECT_THAT(
- m.GetOutputAttributeFloat(),
- testing::ElementsAre(3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f));
-}
-
-TEST(TokenEncoderTest, ManyMessages) {
- TokenEncoderOpModel m({1, 3}, {1, 3});
- m.SetInt32Attribute({1, 2, 3});
- m.SetFloatAttribute({5.f, 4.f, 3.f});
- m.SetNumTokens({1, 1, 1});
- m.SetMaxOutputLength(10);
-
- m.Invoke();
-
- EXPECT_EQ(m.GetOutputLength(), 9);
- EXPECT_THAT(m.GetOutputPositions(),
- testing::ElementsAre(0, 1, 2, 0, 1, 2, 0, 1, 2, 10));
- EXPECT_THAT(m.GetOutputAttributeInt32(),
- testing::ElementsAre(1, 1, 1, 2, 2, 2, 3, 3, 3, 3));
- EXPECT_THAT(
- m.GetOutputAttributeFloat(),
- testing::ElementsAre(5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 3.f));
-}
-
-TEST(TokenEncoderTest, ManyMessagesMultipleTokens) {
- TokenEncoderOpModel m({1, 4}, {1, 4});
- m.SetInt32Attribute({1, 2, 3, 4});
- m.SetFloatAttribute({5.f, 4.f, 3.f, 2.f});
- m.SetNumTokens({1, 2, 3, 4});
- m.SetMaxOutputLength(9);
-
- m.Invoke();
-
- EXPECT_EQ(m.GetOutputLength(), 9);
- EXPECT_THAT(m.GetOutputPositions(),
- testing::ElementsAre(2, 3, 4, 0, 1, 2, 3, 4, 5));
- EXPECT_THAT(m.GetOutputAttributeInt32(),
- testing::ElementsAre(3, 3, 3, 4, 4, 4, 4, 4, 4));
- EXPECT_THAT(
- m.GetOutputAttributeFloat(),
- testing::ElementsAre(3.f, 3.f, 3.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f));
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/token-feature-extractor.cc b/utils/token-feature-extractor.cc
deleted file mode 100644
index 9faebca..0000000
--- a/utils/token-feature-extractor.cc
+++ /dev/null
@@ -1,311 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/token-feature-extractor.h"
-
-#include <cctype>
-#include <string>
-
-#include "utils/base/logging.h"
-#include "utils/hash/farmhash.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unicodetext.h"
-
-namespace libtextclassifier3 {
-
-namespace {
-
-std::string RemapTokenAscii(const std::string& token,
- const TokenFeatureExtractorOptions& options) {
- if (!options.remap_digits && !options.lowercase_tokens) {
- return token;
- }
-
- std::string copy = token;
- for (int i = 0; i < token.size(); ++i) {
- if (options.remap_digits && isdigit(copy[i])) {
- copy[i] = '0';
- }
- if (options.lowercase_tokens) {
- copy[i] = tolower(copy[i]);
- }
- }
- return copy;
-}
-
-void RemapTokenUnicode(const std::string& token,
- const TokenFeatureExtractorOptions& options,
- const UniLib& unilib, UnicodeText* remapped) {
- if (!options.remap_digits && !options.lowercase_tokens) {
- // Leave remapped untouched.
- return;
- }
-
- UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false);
- remapped->clear();
- for (auto it = word.begin(); it != word.end(); ++it) {
- if (options.remap_digits && unilib.IsDigit(*it)) {
- remapped->push_back('0');
- } else if (options.lowercase_tokens) {
- remapped->push_back(unilib.ToLower(*it));
- } else {
- remapped->push_back(*it);
- }
- }
-}
-
-} // namespace
-
-TokenFeatureExtractor::TokenFeatureExtractor(
- const TokenFeatureExtractorOptions& options, const UniLib& unilib)
- : options_(options), unilib_(unilib) {
- for (const std::string& pattern : options.regexp_features) {
- regex_patterns_.push_back(std::unique_ptr<UniLib::RegexPattern>(
- unilib_.CreateRegexPattern(UTF8ToUnicodeText(
- pattern.c_str(), pattern.size(), /*do_copy=*/false))));
- }
-}
-
-bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
- std::vector<int>* sparse_features,
- std::vector<float>* dense_features) const {
- if (!dense_features) {
- return false;
- }
- if (sparse_features) {
- *sparse_features = ExtractCharactergramFeatures(token);
- }
- *dense_features = ExtractDenseFeatures(token, is_in_span);
- return true;
-}
-
-std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
- const Token& token) const {
- if (options_.unicode_aware_features) {
- return ExtractCharactergramFeaturesUnicode(token);
- } else {
- return ExtractCharactergramFeaturesAscii(token);
- }
-}
-
-std::vector<float> TokenFeatureExtractor::ExtractDenseFeatures(
- const Token& token, bool is_in_span) const {
- std::vector<float> dense_features;
-
- if (options_.extract_case_feature) {
- if (options_.unicode_aware_features) {
- UnicodeText token_unicode =
- UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- const bool is_upper = unilib_.IsUpper(*token_unicode.begin());
- if (!token.value.empty() && is_upper) {
- dense_features.push_back(1.0);
- } else {
- dense_features.push_back(-1.0);
- }
- } else {
- if (!token.value.empty() && isupper(*token.value.begin())) {
- dense_features.push_back(1.0);
- } else {
- dense_features.push_back(-1.0);
- }
- }
- }
-
- if (options_.extract_selection_mask_feature) {
- if (is_in_span) {
- dense_features.push_back(1.0);
- } else {
- if (options_.unicode_aware_features) {
- dense_features.push_back(-1.0);
- } else {
- dense_features.push_back(0.0);
- }
- }
- }
-
- // Add regexp features.
- if (!regex_patterns_.empty()) {
- UnicodeText token_unicode =
- UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- for (int i = 0; i < regex_patterns_.size(); ++i) {
- if (!regex_patterns_[i].get()) {
- dense_features.push_back(-1.0);
- continue;
- }
- auto matcher = regex_patterns_[i]->Matcher(token_unicode);
- int status;
- if (matcher->Matches(&status)) {
- dense_features.push_back(1.0);
- } else {
- dense_features.push_back(-1.0);
- }
- }
- }
-
- return dense_features;
-}
-
-int TokenFeatureExtractor::HashToken(StringPiece token) const {
- if (options_.allowed_chargrams.empty()) {
- return tc3farmhash::Fingerprint64(token) % options_.num_buckets;
- } else {
- // Padding and out-of-vocabulary tokens have extra buckets reserved because
- // they are special and important tokens, and we don't want them to share
- // embedding with other charactergrams.
- // TODO(zilka): Experimentally verify.
- const int kNumExtraBuckets = 2;
- const std::string token_string = token.ToString();
- if (token_string == "<PAD>") {
- return 1;
- } else if (options_.allowed_chargrams.find(token_string) ==
- options_.allowed_chargrams.end()) {
- return 0; // Out-of-vocabulary.
- } else {
- return (tc3farmhash::Fingerprint64(token) %
- (options_.num_buckets - kNumExtraBuckets)) +
- kNumExtraBuckets;
- }
- }
-}
-
-std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii(
- const Token& token) const {
- std::vector<int> result;
- if (token.is_padding || token.value.empty()) {
- result.push_back(HashToken("<PAD>"));
- } else {
- const std::string word = RemapTokenAscii(token.value, options_);
-
- // Trim words that are over max_word_length characters.
- const int max_word_length = options_.max_word_length;
- std::string feature_word;
- if (word.size() > max_word_length) {
- feature_word =
- "^" + word.substr(0, max_word_length / 2) + "\1" +
- word.substr(word.size() - max_word_length / 2, max_word_length / 2) +
- "$";
- } else {
- // Add a prefix and suffix to the word.
- feature_word = "^" + word + "$";
- }
-
- // Upper-bound the number of charactergram extracted to avoid resizing.
- result.reserve(options_.chargram_orders.size() * feature_word.size());
-
- if (options_.chargram_orders.empty()) {
- result.push_back(HashToken(feature_word));
- } else {
- // Generate the character-grams.
- for (int chargram_order : options_.chargram_orders) {
- if (chargram_order == 1) {
- for (int i = 1; i < feature_word.size() - 1; ++i) {
- result.push_back(
- HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
- }
- } else {
- for (int i = 0;
- i < static_cast<int>(feature_word.size()) - chargram_order + 1;
- ++i) {
- result.push_back(HashToken(StringPiece(feature_word, /*offset=*/i,
- /*len=*/chargram_order)));
- }
- }
- }
- }
- }
- return result;
-}
-
-std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
- const Token& token) const {
- std::vector<int> result;
- if (token.is_padding || token.value.empty()) {
- result.push_back(HashToken("<PAD>"));
- } else {
- UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- RemapTokenUnicode(token.value, options_, unilib_, &word);
-
- // Trim the word if needed by finding a left-cut point and right-cut point.
- auto left_cut = word.begin();
- auto right_cut = word.end();
- for (int i = 0; i < options_.max_word_length / 2; i++) {
- if (left_cut < right_cut) {
- ++left_cut;
- }
- if (left_cut < right_cut) {
- --right_cut;
- }
- }
-
- std::string feature_word;
- if (left_cut == right_cut) {
- feature_word = "^" + word.UTF8Substring(word.begin(), word.end()) + "$";
- } else {
- // clang-format off
- feature_word = "^" +
- word.UTF8Substring(word.begin(), left_cut) +
- "\1" +
- word.UTF8Substring(right_cut, word.end()) +
- "$";
- // clang-format on
- }
-
- const UnicodeText feature_word_unicode =
- UTF8ToUnicodeText(feature_word, /*do_copy=*/false);
-
- // Upper-bound the number of charactergram extracted to avoid resizing.
- result.reserve(options_.chargram_orders.size() * feature_word.size());
-
- if (options_.chargram_orders.empty()) {
- result.push_back(HashToken(feature_word));
- } else {
- // Generate the character-grams.
- for (int chargram_order : options_.chargram_orders) {
- UnicodeText::const_iterator it_start = feature_word_unicode.begin();
- UnicodeText::const_iterator it_end = feature_word_unicode.end();
- if (chargram_order == 1) {
- ++it_start;
- --it_end;
- }
-
- UnicodeText::const_iterator it_chargram_start = it_start;
- UnicodeText::const_iterator it_chargram_end = it_start;
- bool chargram_is_complete = true;
- for (int i = 0; i < chargram_order; ++i) {
- if (it_chargram_end == it_end) {
- chargram_is_complete = false;
- break;
- }
- ++it_chargram_end;
- }
- if (!chargram_is_complete) {
- continue;
- }
-
- for (; it_chargram_end <= it_end;
- ++it_chargram_start, ++it_chargram_end) {
- const int length_bytes =
- it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
- result.push_back(HashToken(
- StringPiece(it_chargram_start.utf8_data(), length_bytes)));
- }
- }
- }
- }
- return result;
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/token-feature-extractor.h b/utils/token-feature-extractor.h
deleted file mode 100644
index fed113b..0000000
--- a/utils/token-feature-extractor.h
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_TOKEN_FEATURE_EXTRACTOR_H_
-#define LIBTEXTCLASSIFIER_UTILS_TOKEN_FEATURE_EXTRACTOR_H_
-
-#include <memory>
-#include <unordered_set>
-#include <vector>
-
-#include "annotator/types.h"
-#include "utils/strings/stringpiece.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3 {
-
-struct TokenFeatureExtractorOptions {
- // Number of buckets used for hashing charactergrams.
- int num_buckets = 0;
-
- // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3
- // character trigrams etc.
- std::vector<int> chargram_orders;
-
- // Whether to extract the token case feature.
- bool extract_case_feature = false;
-
- // If true, will use the unicode-aware functionality for extracting features.
- bool unicode_aware_features = false;
-
- // Whether to extract the selection mask feature.
- bool extract_selection_mask_feature = false;
-
- // Regexp features to extract.
- std::vector<std::string> regexp_features;
-
- // Whether to remap digits to a single number.
- bool remap_digits = false;
-
- // Whether to lowercase all tokens.
- bool lowercase_tokens = false;
-
- // Maximum length of a word.
- int max_word_length = 20;
-
- // List of allowed charactergrams. The extracted charactergrams are filtered
- // using this list, and charactergrams that are not present are interpreted as
- // out-of-vocabulary.
- // If no allowed_chargrams are specified, all charactergrams are allowed.
- std::unordered_set<std::string> allowed_chargrams;
-};
-
-class TokenFeatureExtractor {
- public:
- TokenFeatureExtractor(const TokenFeatureExtractorOptions& options,
- const UniLib& unilib);
-
- // Extracts both the sparse (charactergram) and the dense features from a
- // token. is_in_span is a bool indicator whether the token is a part of the
- // selection span (true) or not (false).
- // The sparse_features output is optional. Fails and returns false if
- // dense_fatures in a nullptr.
- bool Extract(const Token& token, bool is_in_span,
- std::vector<int>* sparse_features,
- std::vector<float>* dense_features) const;
-
- // Extracts the sparse (charactergram) features from the token.
- std::vector<int> ExtractCharactergramFeatures(const Token& token) const;
-
- // Extracts the dense features from the token. is_in_span is a bool indicator
- // whether the token is a part of the selection span (true) or not (false).
- std::vector<float> ExtractDenseFeatures(const Token& token,
- bool is_in_span) const;
-
- int DenseFeaturesCount() const {
- int feature_count =
- options_.extract_case_feature + options_.extract_selection_mask_feature;
- feature_count += regex_patterns_.size();
- return feature_count;
- }
-
- protected:
- // Hashes given token to given number of buckets.
- int HashToken(StringPiece token) const;
-
- // Extracts the charactergram features from the token in a non-unicode-aware
- // way.
- std::vector<int> ExtractCharactergramFeaturesAscii(const Token& token) const;
-
- // Extracts the charactergram features from the token in a unicode-aware way.
- std::vector<int> ExtractCharactergramFeaturesUnicode(
- const Token& token) const;
-
- private:
- TokenFeatureExtractorOptions options_;
- std::vector<std::unique_ptr<UniLib::RegexPattern>> regex_patterns_;
- const UniLib& unilib_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_TOKEN_FEATURE_EXTRACTOR_H_
diff --git a/utils/token-feature-extractor_test.cc b/utils/token-feature-extractor_test.cc
deleted file mode 100644
index 9a97e42..0000000
--- a/utils/token-feature-extractor_test.cc
+++ /dev/null
@@ -1,556 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/token-feature-extractor.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-class TokenFeatureExtractorTest : public ::testing::Test {
- protected:
- TokenFeatureExtractorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
- UniLib unilib_;
-};
-
-class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
- public:
- using TokenFeatureExtractor::HashToken;
- using TokenFeatureExtractor::TokenFeatureExtractor;
-};
-
-TEST_F(TokenFeatureExtractorTest, ExtractAscii) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2, 3};
- options.extract_case_feature = true;
- options.unicode_aware_features = false;
- options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
-
- extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({
- // clang-format off
- extractor.HashToken("H"),
- extractor.HashToken("e"),
- extractor.HashToken("l"),
- extractor.HashToken("l"),
- extractor.HashToken("o"),
- extractor.HashToken("^H"),
- extractor.HashToken("He"),
- extractor.HashToken("el"),
- extractor.HashToken("ll"),
- extractor.HashToken("lo"),
- extractor.HashToken("o$"),
- extractor.HashToken("^He"),
- extractor.HashToken("Hel"),
- extractor.HashToken("ell"),
- extractor.HashToken("llo"),
- extractor.HashToken("lo$")
- // clang-format on
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({
- // clang-format off
- extractor.HashToken("w"),
- extractor.HashToken("o"),
- extractor.HashToken("r"),
- extractor.HashToken("l"),
- extractor.HashToken("d"),
- extractor.HashToken("!"),
- extractor.HashToken("^w"),
- extractor.HashToken("wo"),
- extractor.HashToken("or"),
- extractor.HashToken("rl"),
- extractor.HashToken("ld"),
- extractor.HashToken("d!"),
- extractor.HashToken("!$"),
- extractor.HashToken("^wo"),
- extractor.HashToken("wor"),
- extractor.HashToken("orl"),
- extractor.HashToken("rld"),
- extractor.HashToken("ld!"),
- extractor.HashToken("d!$"),
- // clang-format on
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
-}
-
-TEST_F(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{};
- options.extract_case_feature = true;
- options.unicode_aware_features = false;
- options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
-
- extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({extractor.HashToken("^Hello$")}));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({extractor.HashToken("^world!$")}));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
-}
-
-TEST_F(TokenFeatureExtractorTest, ExtractUnicode) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2, 3};
- options.extract_case_feature = true;
- options.unicode_aware_features = true;
- options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
-
- extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({
- // clang-format off
- extractor.HashToken("H"),
- extractor.HashToken("ě"),
- extractor.HashToken("l"),
- extractor.HashToken("l"),
- extractor.HashToken("ó"),
- extractor.HashToken("^H"),
- extractor.HashToken("Hě"),
- extractor.HashToken("ěl"),
- extractor.HashToken("ll"),
- extractor.HashToken("ló"),
- extractor.HashToken("ó$"),
- extractor.HashToken("^Hě"),
- extractor.HashToken("Hěl"),
- extractor.HashToken("ěll"),
- extractor.HashToken("lló"),
- extractor.HashToken("ló$")
- // clang-format on
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({
- // clang-format off
- extractor.HashToken("w"),
- extractor.HashToken("o"),
- extractor.HashToken("r"),
- extractor.HashToken("l"),
- extractor.HashToken("d"),
- extractor.HashToken("!"),
- extractor.HashToken("^w"),
- extractor.HashToken("wo"),
- extractor.HashToken("or"),
- extractor.HashToken("rl"),
- extractor.HashToken("ld"),
- extractor.HashToken("d!"),
- extractor.HashToken("!$"),
- extractor.HashToken("^wo"),
- extractor.HashToken("wor"),
- extractor.HashToken("orl"),
- extractor.HashToken("rld"),
- extractor.HashToken("ld!"),
- extractor.HashToken("d!$"),
- // clang-format on
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
-}
-
-TEST_F(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{};
- options.extract_case_feature = true;
- options.unicode_aware_features = true;
- options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
-
- extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({extractor.HashToken("^Hělló$")}));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features, testing::ElementsAreArray({
- extractor.HashToken("^world!$"),
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
-}
-
-#ifdef TC3_TEST_ICU
-TEST_F(TokenFeatureExtractorTest, ICUCaseFeature) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.extract_case_feature = true;
- options.unicode_aware_features = true;
- options.extract_selection_mask_feature = false;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"Ř", 23, 29}, false, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"ř", 23, 29}, false, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
-}
-#endif
-
-TEST_F(TokenFeatureExtractorTest, DigitRemapping) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.remap_digits = true;
- options.unicode_aware_features = false;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
- &dense_features);
-
- std::vector<int> sparse_features2;
- extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
-
- extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features,
- testing::Not(testing::ElementsAreArray(sparse_features2)));
-}
-
-TEST_F(TokenFeatureExtractorTest, DigitRemappingUnicode) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.remap_digits = true;
- options.unicode_aware_features = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
- &dense_features);
-
- std::vector<int> sparse_features2;
- extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
-
- extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features,
- testing::Not(testing::ElementsAreArray(sparse_features2)));
-}
-
-TEST_F(TokenFeatureExtractorTest, LowercaseAscii) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.lowercase_tokens = true;
- options.unicode_aware_features = false;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features,
- &dense_features);
-
- std::vector<int> sparse_features2;
- extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
-
- extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
-}
-
-#ifdef TC3_TEST_ICU
-TEST_F(TokenFeatureExtractorTest, LowercaseUnicode) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.lowercase_tokens = true;
- options.unicode_aware_features = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"ŘŘ", 0, 6}, true, &sparse_features, &dense_features);
-
- std::vector<int> sparse_features2;
- extractor.Extract(Token{"řř", 0, 6}, true, &sparse_features2,
- &dense_features);
- EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
-}
-#endif
-
-#ifdef TC3_TEST_ICU
-TEST_F(TokenFeatureExtractorTest, RegexFeatures) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.remap_digits = false;
- options.unicode_aware_features = false;
- options.regexp_features.push_back("^[a-z]+$"); // all lower case.
- options.regexp_features.push_back("^[0-9]+$"); // all digits.
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
-
- dense_features.clear();
- extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));
-
- dense_features.clear();
- extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
-
- dense_features.clear();
- extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features,
- &dense_features);
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
-}
-#endif
-
-TEST_F(TokenFeatureExtractorTest, ExtractTooLongWord) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{22};
- options.extract_case_feature = true;
- options.unicode_aware_features = true;
- options.extract_selection_mask_feature = true;
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- // Test that this runs. ASAN should catch problems.
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
- extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true,
- &sparse_features, &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({
- // clang-format off
- extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
- extractor.HashToken("abcdefghij\1qřstuvwxyz$"),
- // clang-format on
- }));
-}
-
-TEST_F(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
- options.extract_case_feature = true;
- options.unicode_aware_features = true;
- options.extract_selection_mask_feature = true;
-
- TestingTokenFeatureExtractor extractor_unicode(options, unilib_);
-
- options.unicode_aware_features = false;
- TestingTokenFeatureExtractor extractor_ascii(options, unilib_);
-
- for (const std::string& input :
- {"https://www.abcdefgh.com/in/xxxkkkvayio",
- "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
- "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w", "abcd",
- "x", "Hello", "Hey,", "Hi", ""}) {
- std::vector<int> sparse_features_unicode;
- std::vector<float> dense_features_unicode;
- extractor_unicode.Extract(Token{input, 0, 0}, true,
- &sparse_features_unicode,
- &dense_features_unicode);
-
- std::vector<int> sparse_features_ascii;
- std::vector<float> dense_features_ascii;
- extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii,
- &dense_features_ascii);
-
- EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input;
- EXPECT_THAT(dense_features_unicode, dense_features_ascii) << input;
- }
-}
-
-TEST_F(TokenFeatureExtractorTest, ExtractForPadToken) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2};
- options.extract_case_feature = true;
- options.unicode_aware_features = false;
- options.extract_selection_mask_feature = true;
-
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
-
- extractor.Extract(Token(), false, &sparse_features, &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
-}
-
-TEST_F(TokenFeatureExtractorTest, ExtractFiltered) {
- TokenFeatureExtractorOptions options;
- options.num_buckets = 1000;
- options.chargram_orders = std::vector<int>{1, 2, 3};
- options.extract_case_feature = true;
- options.unicode_aware_features = false;
- options.extract_selection_mask_feature = true;
- options.allowed_chargrams.insert("^H");
- options.allowed_chargrams.insert("ll");
- options.allowed_chargrams.insert("llo");
- options.allowed_chargrams.insert("w");
- options.allowed_chargrams.insert("!");
- options.allowed_chargrams.insert("\xc4"); // UTF8 control character.
-
- TestingTokenFeatureExtractor extractor(options, unilib_);
-
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
-
- extractor.Extract(Token{"Hěllo", 0, 5}, true, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features,
- testing::ElementsAreArray({
- // clang-format off
- 0,
- extractor.HashToken("\xc4"),
- 0,
- 0,
- 0,
- 0,
- extractor.HashToken("^H"),
- 0,
- 0,
- 0,
- extractor.HashToken("ll"),
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- extractor.HashToken("llo"),
- 0
- // clang-format on
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
-
- sparse_features.clear();
- dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
- &dense_features);
-
- EXPECT_THAT(sparse_features, testing::ElementsAreArray({
- // clang-format off
- extractor.HashToken("w"),
- 0,
- 0,
- 0,
- 0,
- extractor.HashToken("!"),
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- // clang-format on
- }));
- EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
- EXPECT_EQ(extractor.HashToken("<PAD>"), 1);
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/tokenizer.cc b/utils/tokenizer.cc
deleted file mode 100644
index 87a5c8d..0000000
--- a/utils/tokenizer.cc
+++ /dev/null
@@ -1,261 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/tokenizer.h"
-
-#include <algorithm>
-
-#include "utils/base/logging.h"
-#include "utils/base/macros.h"
-#include "utils/strings/utf8.h"
-
-namespace libtextclassifier3 {
-
-Tokenizer::Tokenizer(
- const TokenizationType type, const UniLib* unilib,
- const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
- const std::vector<const CodepointRange*>&
- internal_tokenizer_codepoint_ranges,
- const bool split_on_script_change,
- const bool icu_preserve_whitespace_tokens)
- : type_(type),
- unilib_(unilib),
- split_on_script_change_(split_on_script_change),
- icu_preserve_whitespace_tokens_(icu_preserve_whitespace_tokens) {
- for (const TokenizationCodepointRange* range : codepoint_ranges) {
- codepoint_ranges_.emplace_back(range->UnPack());
- }
-
- std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
- [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
- const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
- return a->start < b->start;
- });
-
- SortCodepointRanges(internal_tokenizer_codepoint_ranges,
- &internal_tokenizer_codepoint_ranges_);
-}
-
-const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange(
- int codepoint) const {
- auto it = std::lower_bound(
- codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
- [](const std::unique_ptr<const TokenizationCodepointRangeT>& range,
- int codepoint) {
- // This function compares range with the codepoint for the purpose of
- // finding the first greater or equal range. Because of the use of
- // std::lower_bound it needs to return true when range < codepoint;
- // the first time it will return false the lower bound is found and
- // returned.
- //
- // It might seem weird that the condition is range.end <= codepoint
- // here but when codepoint == range.end it means it's actually just
- // outside of the range, thus the range is less than the codepoint.
- return range->end <= codepoint;
- });
- if (it != codepoint_ranges_.end() && (*it)->start <= codepoint &&
- (*it)->end > codepoint) {
- return it->get();
- } else {
- return nullptr;
- }
-}
-
-void Tokenizer::GetScriptAndRole(char32 codepoint,
- TokenizationCodepointRange_::Role* role,
- int* script) const {
- const TokenizationCodepointRangeT* range = FindTokenizationRange(codepoint);
- if (range) {
- *role = range->role;
- *script = range->script_id;
- } else {
- *role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- *script = kUnknownScript;
- }
-}
-
-std::vector<Token> Tokenizer::Tokenize(const std::string& text) const {
- UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
- return Tokenize(text_unicode);
-}
-
-std::vector<Token> Tokenizer::Tokenize(const UnicodeText& text_unicode) const {
- switch (type_) {
- case TokenizationType_INTERNAL_TOKENIZER:
- return InternalTokenize(text_unicode);
- case TokenizationType_ICU:
- TC3_FALLTHROUGH_INTENDED;
- case TokenizationType_MIXED: {
- std::vector<Token> result;
- if (!ICUTokenize(text_unicode, &result)) {
- return {};
- }
- if (type_ == TokenizationType_MIXED) {
- InternalRetokenize(text_unicode, &result);
- }
- return result;
- }
- default:
- TC3_LOG(ERROR) << "Unknown tokenization type specified. Using internal.";
- return InternalTokenize(text_unicode);
- }
-}
-
-std::vector<Token> Tokenizer::InternalTokenize(
- const UnicodeText& text_unicode) const {
- std::vector<Token> result;
- Token new_token("", 0, 0);
- int codepoint_index = 0;
-
- int last_script = kInvalidScript;
- for (auto it = text_unicode.begin(); it != text_unicode.end();
- ++it, ++codepoint_index) {
- TokenizationCodepointRange_::Role role;
- int script;
- GetScriptAndRole(*it, &role, &script);
-
- if (role & TokenizationCodepointRange_::Role_SPLIT_BEFORE ||
- (split_on_script_change_ && last_script != kInvalidScript &&
- last_script != script)) {
- if (!new_token.value.empty()) {
- result.push_back(new_token);
- }
- new_token = Token("", codepoint_index, codepoint_index);
- }
- if (!(role & TokenizationCodepointRange_::Role_DISCARD_CODEPOINT)) {
- new_token.value += std::string(
- it.utf8_data(),
- it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data()));
- ++new_token.end;
- }
- if (role & TokenizationCodepointRange_::Role_SPLIT_AFTER) {
- if (!new_token.value.empty()) {
- result.push_back(new_token);
- }
- new_token = Token("", codepoint_index + 1, codepoint_index + 1);
- }
-
- last_script = script;
- }
- if (!new_token.value.empty()) {
- result.push_back(new_token);
- }
-
- return result;
-}
-
-void Tokenizer::TokenizeSubstring(const UnicodeText& unicode_text,
- CodepointSpan span,
- std::vector<Token>* result) const {
- if (span.first < 0) {
- // There is no span to tokenize.
- return;
- }
-
- // Extract the substring.
- UnicodeText text = UnicodeText::Substring(unicode_text, span.first,
- span.second, /*do_copy=*/false);
-
- // Run the tokenizer and update the token bounds to reflect the offset of the
- // substring.
- std::vector<Token> tokens = InternalTokenize(text);
-
- // Avoids progressive capacity increases in the for loop.
- result->reserve(result->size() + tokens.size());
- for (Token& token : tokens) {
- token.start += span.first;
- token.end += span.first;
- result->emplace_back(std::move(token));
- }
-}
-
-void Tokenizer::InternalRetokenize(const UnicodeText& unicode_text,
- std::vector<Token>* tokens) const {
- std::vector<Token> result;
- CodepointSpan span(-1, -1);
- for (Token& token : *tokens) {
- const UnicodeText unicode_token_value =
- UTF8ToUnicodeText(token.value, /*do_copy=*/false);
- bool should_retokenize = true;
- for (const int codepoint : unicode_token_value) {
- if (!IsCodepointInRanges(codepoint,
- internal_tokenizer_codepoint_ranges_)) {
- should_retokenize = false;
- break;
- }
- }
-
- if (should_retokenize) {
- if (span.first < 0) {
- span.first = token.start;
- }
- span.second = token.end;
- } else {
- TokenizeSubstring(unicode_text, span, &result);
- span.first = -1;
- result.emplace_back(std::move(token));
- }
- }
- TokenizeSubstring(unicode_text, span, &result);
-
- *tokens = std::move(result);
-}
-
-bool Tokenizer::ICUTokenize(const UnicodeText& context_unicode,
- std::vector<Token>* result) const {
- std::unique_ptr<UniLib::BreakIterator> break_iterator =
- unilib_->CreateBreakIterator(context_unicode);
- if (!break_iterator) {
- return false;
- }
- int last_break_index = 0;
- int break_index = 0;
- int last_unicode_index = 0;
- int unicode_index = 0;
- auto token_begin_it = context_unicode.begin();
- while ((break_index = break_iterator->Next()) !=
- UniLib::BreakIterator::kDone) {
- const int token_length = break_index - last_break_index;
- unicode_index = last_unicode_index + token_length;
-
- auto token_end_it = token_begin_it;
- std::advance(token_end_it, token_length);
-
- // Determine if the whole token is whitespace.
- bool is_whitespace = true;
- for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
- if (!unilib_->IsWhitespace(*char_it)) {
- is_whitespace = false;
- break;
- }
- }
-
- const std::string token =
- context_unicode.UTF8Substring(token_begin_it, token_end_it);
-
- if (!is_whitespace || icu_preserve_whitespace_tokens_) {
- result->push_back(Token(token, last_unicode_index, unicode_index));
- }
-
- last_break_index = break_index;
- last_unicode_index = unicode_index;
- token_begin_it = token_end_it;
- }
-
- return true;
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/tokenizer.fbs b/utils/tokenizer.fbs
deleted file mode 100755
index 2a19999..0000000
--- a/utils/tokenizer.fbs
+++ /dev/null
@@ -1,70 +0,0 @@
-//
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-// Controls the type of tokenization the model will use for the input text.
-namespace libtextclassifier3;
-enum TokenizationType : int {
- INVALID_TOKENIZATION_TYPE = 0,
-
- // Use the internal tokenizer for tokenization.
- INTERNAL_TOKENIZER = 1,
-
- // Use ICU for tokenization.
- ICU = 2,
-
- // First apply ICU tokenization. Then identify stretches of tokens
- // consisting only of codepoints in internal_tokenizer_codepoint_ranges
- // and re-tokenize them using the internal tokenizer.
- MIXED = 3,
-}
-
-// Role of the codepoints in the range.
-namespace libtextclassifier3.TokenizationCodepointRange_;
-enum Role : int {
- // Concatenates the codepoint to the current run of codepoints.
- DEFAULT_ROLE = 0,
-
- // Splits a run of codepoints before the current codepoint.
- SPLIT_BEFORE = 1,
-
- // Splits a run of codepoints after the current codepoint.
- SPLIT_AFTER = 2,
-
- // Each codepoint will be a separate token. Good e.g. for Chinese
- // characters.
- TOKEN_SEPARATOR = 3,
-
- // Discards the codepoint.
- DISCARD_CODEPOINT = 4,
-
- // Common values:
- // Splits on the characters and discards them. Good e.g. for the space
- // character.
- WHITESPACE_SEPARATOR = 7,
-}
-
-// Represents a codepoint range [start, end) with its role for tokenization.
-namespace libtextclassifier3;
-table TokenizationCodepointRange {
- start:int;
- end:int;
- role:TokenizationCodepointRange_.Role;
-
- // Integer identifier of the script this range denotes. Negative values are
- // reserved for Tokenizer's internal use.
- script_id:int;
-}
-
diff --git a/utils/tokenizer.h b/utils/tokenizer.h
deleted file mode 100644
index 3a9ef6c..0000000
--- a/utils/tokenizer.h
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_TOKENIZER_H_
-#define LIBTEXTCLASSIFIER_UTILS_TOKENIZER_H_
-
-#include <string>
-#include <vector>
-
-#include "annotator/types.h"
-#include "utils/base/integral_types.h"
-#include "utils/codepoint-range.h"
-#include "utils/tokenizer_generated.h"
-#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-
-namespace libtextclassifier3 {
-
-const int kInvalidScript = -1;
-const int kUnknownScript = -2;
-
-// Tokenizer splits the input string into a sequence of tokens, according to
-// the configuration.
-class Tokenizer {
- public:
- // `codepoint_ranges`: Codepoint ranges that determine how different
- // codepoints are tokenized. The ranges must not overlap.
- // `internal_tokenizer_codepoint_ranges`: Codepoint ranges that define which
- // tokens should be re-tokenized with the internal tokenizer in the mixed
- // tokenization mode.
- // `split_on_script_change`: Whether to consider a change of codepoint script
- // in a sequence of characters as a token boundary. If True, will treat
- // script change as a token boundary.
- // `icu_preserve_whitespace_tokens`: If true, will include empty tokens in the
- // output (in the ICU tokenization mode).
- Tokenizer(
- const TokenizationType type, const UniLib* unilib,
- const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
- const std::vector<const CodepointRange*>&
- internal_tokenizer_codepoint_ranges,
- const bool split_on_script_change,
- const bool icu_preserve_whitespace_tokens);
-
- Tokenizer(
- const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
- const bool split_on_script_change)
- : Tokenizer(TokenizationType_INTERNAL_TOKENIZER, /*unilib=*/nullptr,
- codepoint_ranges, /*internal_tokenizer_codepoint_ranges=*/{},
- split_on_script_change,
- /*icu_preserve_whitespace_tokens=*/false) {}
-
- // Tokenizes the input string using the selected tokenization method.
- std::vector<Token> Tokenize(const std::string& text) const;
-
- // Same as above but takes UnicodeText.
- std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
-
- protected:
- // Finds the tokenization codepoint range config for given codepoint.
- // Internally uses binary search so should be O(log(# of codepoint_ranges)).
- const TokenizationCodepointRangeT* FindTokenizationRange(int codepoint) const;
-
- // Finds the role and script for given codepoint. If not found, DEFAULT_ROLE
- // and kUnknownScript are assigned.
- void GetScriptAndRole(char32 codepoint,
- TokenizationCodepointRange_::Role* role,
- int* script) const;
-
- // Tokenizes a substring of the unicode string, appending the resulting tokens
- // to the output vector. The resulting tokens have bounds relative to the full
- // string. Does nothing if the start of the span is negative.
- void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
- std::vector<Token>* result) const;
-
- std::vector<Token> InternalTokenize(const UnicodeText& text_unicode) const;
-
- // Takes the result of ICU tokenization and retokenizes stretches of tokens
- // made of a specific subset of characters using the internal tokenizer.
- void InternalRetokenize(const UnicodeText& unicode_text,
- std::vector<Token>* tokens) const;
-
- // Tokenizes the input text using ICU tokenizer.
- bool ICUTokenize(const UnicodeText& context_unicode,
- std::vector<Token>* result) const;
-
- private:
- const TokenizationType type_;
-
- const UniLib* unilib_;
-
- // Codepoint ranges that determine how different codepoints are tokenized.
- // The ranges must not overlap.
- std::vector<std::unique_ptr<const TokenizationCodepointRangeT>>
- codepoint_ranges_;
-
- // Codepoint ranges that define which tokens (consisting of which codepoints)
- // should be re-tokenized with the internal tokenizer in the mixed
- // tokenization mode.
- // NOTE: Must be sorted.
- std::vector<CodepointRangeStruct> internal_tokenizer_codepoint_ranges_;
-
- // If true, tokens will be additionally split when the codepoint's script_id
- // changes.
- const bool split_on_script_change_;
-
- const bool icu_preserve_whitespace_tokens_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_TOKENIZER_H_
diff --git a/utils/tokenizer_test.cc b/utils/tokenizer_test.cc
deleted file mode 100644
index 4f4f763..0000000
--- a/utils/tokenizer_test.cc
+++ /dev/null
@@ -1,485 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/tokenizer.h"
-
-#include <vector>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-using testing::ElementsAreArray;
-
-class TestingTokenizer : public Tokenizer {
- public:
- TestingTokenizer(
- const TokenizationType type, const UniLib* unilib,
- const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
- const std::vector<const CodepointRange*>&
- internal_tokenizer_codepoint_ranges,
- const bool split_on_script_change,
- const bool icu_preserve_whitespace_tokens)
- : Tokenizer(type, unilib, codepoint_ranges,
- internal_tokenizer_codepoint_ranges, split_on_script_change,
- icu_preserve_whitespace_tokens) {}
-
- using Tokenizer::FindTokenizationRange;
-};
-
-class TestingTokenizerProxy {
- public:
- TestingTokenizerProxy(
- TokenizationType type,
- const std::vector<TokenizationCodepointRangeT>& codepoint_range_configs,
- const std::vector<CodepointRangeT>& internal_codepoint_range_configs,
- const bool split_on_script_change,
- const bool icu_preserve_whitespace_tokens)
- : INIT_UNILIB_FOR_TESTING(unilib_) {
- const int num_configs = codepoint_range_configs.size();
- std::vector<const TokenizationCodepointRange*> configs_fb;
- configs_fb.reserve(num_configs);
- const int num_internal_configs = internal_codepoint_range_configs.size();
- std::vector<const CodepointRange*> internal_configs_fb;
- internal_configs_fb.reserve(num_internal_configs);
- buffers_.reserve(num_configs + num_internal_configs);
- for (int i = 0; i < num_configs; i++) {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(CreateTokenizationCodepointRange(
- builder, &codepoint_range_configs[i]));
- buffers_.push_back(builder.Release());
- configs_fb.push_back(flatbuffers::GetRoot<TokenizationCodepointRange>(
- buffers_.back().data()));
- }
- for (int i = 0; i < num_internal_configs; i++) {
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(
- CreateCodepointRange(builder, &internal_codepoint_range_configs[i]));
- buffers_.push_back(builder.Release());
- internal_configs_fb.push_back(
- flatbuffers::GetRoot<CodepointRange>(buffers_.back().data()));
- }
- tokenizer_ = std::unique_ptr<TestingTokenizer>(new TestingTokenizer(
- type, &unilib_, configs_fb, internal_configs_fb, split_on_script_change,
- icu_preserve_whitespace_tokens));
- }
-
- TokenizationCodepointRange_::Role TestFindTokenizationRole(int c) const {
- const TokenizationCodepointRangeT* range =
- tokenizer_->FindTokenizationRange(c);
- if (range != nullptr) {
- return range->role;
- } else {
- return TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- }
- }
-
- std::vector<Token> Tokenize(const std::string& utf8_text) const {
- return tokenizer_->Tokenize(utf8_text);
- }
-
- private:
- UniLib unilib_;
- std::vector<flatbuffers::DetachedBuffer> buffers_;
- std::unique_ptr<TestingTokenizer> tokenizer_;
-};
-
-TEST(TokenizerTest, FindTokenizationRange) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- configs.emplace_back();
- config = &configs.back();
- config->start = 0;
- config->end = 10;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
-
- configs.emplace_back();
- config = &configs.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- configs.emplace_back();
- config = &configs.back();
- config->start = 1234;
- config->end = 12345;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
-
- TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
- {}, /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false);
-
- // Test hits to the first group.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(0),
- TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(5),
- TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(10),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
-
- // Test a hit to the second group.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(31),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(32),
- TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(33),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
-
- // Test hits to the third group.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(1233),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(1234),
- TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(12344),
- TokenizationCodepointRange_::Role_TOKEN_SEPARATOR);
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(12345),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
-
- // Test a hit outside.
- EXPECT_EQ(tokenizer.TestFindTokenizationRole(99),
- TokenizationCodepointRange_::Role_DEFAULT_ROLE);
-}
-
-TEST(TokenizerTest, TokenizeOnSpace) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- configs.emplace_back();
- config = &configs.back();
- // Space character.
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
- {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false);
- std::vector<Token> tokens = tokenizer.Tokenize("Hello world!");
-
- EXPECT_THAT(tokens,
- ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)}));
-}
-
-TEST(TokenizerTest, TokenizeOnSpaceAndScriptChange) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- // Latin.
- configs.emplace_back();
- config = &configs.back();
- config->start = 0;
- config->end = 32;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- config->script_id = 1;
- configs.emplace_back();
- config = &configs.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
- config->script_id = 1;
- configs.emplace_back();
- config = &configs.back();
- config->start = 33;
- config->end = 0x77F + 1;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- config->script_id = 1;
-
- TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
- {},
- /*split_on_script_change=*/true,
- /*icu_preserve_whitespace_tokens=*/false);
- EXPECT_THAT(tokenizer.Tokenize("앨라배마 주 전화(123) 456-789웹사이트"),
- std::vector<Token>({Token("앨라배마", 0, 4), Token("주", 5, 6),
- Token("전화", 7, 10), Token("(123)", 10, 15),
- Token("456-789", 16, 23),
- Token("웹사이트", 23, 28)}));
-} // namespace
-
-TEST(TokenizerTest, TokenizeComplex) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- // Source: http://www.unicode.org/Public/10.0.0/ucd/Blocks-10.0.0d1.txt
- // Latin - cyrilic.
- // 0000..007F; Basic Latin
- // 0080..00FF; Latin-1 Supplement
- // 0100..017F; Latin Extended-A
- // 0180..024F; Latin Extended-B
- // 0250..02AF; IPA Extensions
- // 02B0..02FF; Spacing Modifier Letters
- // 0300..036F; Combining Diacritical Marks
- // 0370..03FF; Greek and Coptic
- // 0400..04FF; Cyrillic
- // 0500..052F; Cyrillic Supplement
- // 0530..058F; Armenian
- // 0590..05FF; Hebrew
- // 0600..06FF; Arabic
- // 0700..074F; Syriac
- // 0750..077F; Arabic Supplement
- configs.emplace_back();
- config = &configs.back();
- config->start = 0;
- config->end = 32;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
- configs.emplace_back();
- config = &configs.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 33;
- config->end = 0x77F + 1;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
-
- // CJK
- // 2E80..2EFF; CJK Radicals Supplement
- // 3000..303F; CJK Symbols and Punctuation
- // 3040..309F; Hiragana
- // 30A0..30FF; Katakana
- // 3100..312F; Bopomofo
- // 3130..318F; Hangul Compatibility Jamo
- // 3190..319F; Kanbun
- // 31A0..31BF; Bopomofo Extended
- // 31C0..31EF; CJK Strokes
- // 31F0..31FF; Katakana Phonetic Extensions
- // 3200..32FF; Enclosed CJK Letters and Months
- // 3300..33FF; CJK Compatibility
- // 3400..4DBF; CJK Unified Ideographs Extension A
- // 4DC0..4DFF; Yijing Hexagram Symbols
- // 4E00..9FFF; CJK Unified Ideographs
- // A000..A48F; Yi Syllables
- // A490..A4CF; Yi Radicals
- // A4D0..A4FF; Lisu
- // A500..A63F; Vai
- // F900..FAFF; CJK Compatibility Ideographs
- // FE30..FE4F; CJK Compatibility Forms
- // 20000..2A6DF; CJK Unified Ideographs Extension B
- // 2A700..2B73F; CJK Unified Ideographs Extension C
- // 2B740..2B81F; CJK Unified Ideographs Extension D
- // 2B820..2CEAF; CJK Unified Ideographs Extension E
- // 2CEB0..2EBEF; CJK Unified Ideographs Extension F
- // 2F800..2FA1F; CJK Compatibility Ideographs Supplement
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2E80;
- config->end = 0x2EFF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x3000;
- config->end = 0xA63F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0xF900;
- config->end = 0xFAFF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0xFE30;
- config->end = 0xFE4F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x20000;
- config->end = 0x2A6DF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2A700;
- config->end = 0x2B73F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2B740;
- config->end = 0x2B81F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2B820;
- config->end = 0x2CEAF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2CEB0;
- config->end = 0x2EBEF + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x2F800;
- config->end = 0x2FA1F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
-
- // Thai.
- // 0E00..0E7F; Thai
- configs.emplace_back();
- config = &configs.back();
- config->start = 0x0E00;
- config->end = 0x0E7F + 1;
- config->role = TokenizationCodepointRange_::Role_TOKEN_SEPARATOR;
-
- TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER, configs,
- {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false);
- std::vector<Token> tokens;
-
- tokens = tokenizer.Tokenize(
- "問少目木輸走猶術権自京門録球変。細開括省用掲情結傍走愛明氷。");
- EXPECT_EQ(tokens.size(), 30);
-
- tokens = tokenizer.Tokenize("問少目 hello 木輸ยามきゃ");
- // clang-format off
- EXPECT_THAT(
- tokens,
- ElementsAreArray({Token("問", 0, 1),
- Token("少", 1, 2),
- Token("目", 2, 3),
- Token("hello", 4, 9),
- Token("木", 10, 11),
- Token("輸", 11, 12),
- Token("ย", 12, 13),
- Token("า", 13, 14),
- Token("ม", 14, 15),
- Token("き", 15, 16),
- Token("ゃ", 16, 17)}));
- // clang-format on
-}
-
-#ifdef TC3_TEST_ICU
-TEST(TokenizerTest, ICUTokenize) {
- TestingTokenizerProxy tokenizer(TokenizationType_ICU, {}, {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false);
- std::vector<Token> tokens = tokenizer.Tokenize("พระบาทสมเด็จพระปรมิ");
- ASSERT_EQ(tokens,
- // clang-format off
- std::vector<Token>({Token("พระบาท", 0, 6),
- Token("สมเด็จ", 6, 12),
- Token("พระ", 12, 15),
- Token("ปร", 15, 17),
- Token("มิ", 17, 19)}));
- // clang-format on
-}
-
-TEST(TokenizerTest, ICUTokenizeWithWhitespaces) {
- TestingTokenizerProxy tokenizer(TokenizationType_ICU, {}, {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/true);
- std::vector<Token> tokens = tokenizer.Tokenize("พระบาท สมเด็จ พระ ปร มิ");
- ASSERT_EQ(tokens,
- // clang-format off
- std::vector<Token>({Token("พระบาท", 0, 6),
- Token(" ", 6, 7),
- Token("สมเด็จ", 7, 13),
- Token(" ", 13, 14),
- Token("พระ", 14, 17),
- Token(" ", 17, 18),
- Token("ปร", 18, 20),
- Token(" ", 20, 21),
- Token("มิ", 21, 23)}));
- // clang-format on
-}
-
-TEST(TokenizerTest, MixedTokenize) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- configs.emplace_back();
- config = &configs.back();
- config->start = 32;
- config->end = 33;
- config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
-
- std::vector<CodepointRangeT> internal_configs;
- CodepointRangeT* interal_config;
-
- internal_configs.emplace_back();
- interal_config = &internal_configs.back();
- interal_config->start = 0;
- interal_config->end = 128;
-
- internal_configs.emplace_back();
- interal_config = &internal_configs.back();
- interal_config->start = 128;
- interal_config->end = 256;
-
- internal_configs.emplace_back();
- interal_config = &internal_configs.back();
- interal_config->start = 256;
- interal_config->end = 384;
-
- internal_configs.emplace_back();
- interal_config = &internal_configs.back();
- interal_config->start = 384;
- interal_config->end = 592;
-
- TestingTokenizerProxy tokenizer(TokenizationType_MIXED, configs,
- internal_configs,
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false);
-
- std::vector<Token> tokens = tokenizer.Tokenize(
- "こんにちはJapanese-ląnguagę text 世界 http://www.google.com/");
- ASSERT_EQ(tokens,
- // clang-format off
- std::vector<Token>({Token("こんにちは", 0, 5),
- Token("Japanese-ląnguagę", 5, 22),
- Token("text", 23, 27),
- Token("世界", 28, 30),
- Token("http://www.google.com/", 31, 53)}));
- // clang-format on
-}
-
-TEST(TokenizerTest, InternalTokenizeOnScriptChange) {
- std::vector<TokenizationCodepointRangeT> configs;
- TokenizationCodepointRangeT* config;
-
- configs.emplace_back();
- config = &configs.back();
- config->start = 0;
- config->end = 256;
- config->role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
-
- {
- TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER,
- configs, {},
- /*split_on_script_change=*/false,
- /*icu_preserve_whitespace_tokens=*/false);
-
- EXPECT_EQ(tokenizer.Tokenize("앨라배마123웹사이트"),
- std::vector<Token>({Token("앨라배마123웹사이트", 0, 11)}));
- }
-
- {
- TestingTokenizerProxy tokenizer(TokenizationType_INTERNAL_TOKENIZER,
- configs, {},
- /*split_on_script_change=*/true,
- /*icu_preserve_whitespace_tokens=*/false);
- EXPECT_EQ(tokenizer.Tokenize("앨라배마123웹사이트"),
- std::vector<Token>({Token("앨라배마", 0, 4), Token("123", 4, 7),
- Token("웹사이트", 7, 11)}));
- }
-}
-#endif
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/utf8/UniLibJavaIcuTest.java b/utils/utf8/UniLibJavaIcuTest.java
deleted file mode 100644
index d6a0a06..0000000
--- a/utils/utf8/UniLibJavaIcuTest.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.android.textclassifier.utils.utf8;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-/** This is just a launcher of the tests because we need a valid JNIEnv in the C++ test. */
-@RunWith(JUnit4.class)
-public class UniLibJavaIcuTest {
-
- @Before
- public void setUp() throws Exception {
- System.loadLibrary("unilib-javaicu_test-jni");
- }
-
- private native boolean testsMain();
-
- @Test
- public void testNative() {
- assertThat(testsMain()).isTrue();
- }
-}
diff --git a/utils/utf8/unicodetext.cc b/utils/utf8/unicodetext.cc
deleted file mode 100644
index b3b092e..0000000
--- a/utils/utf8/unicodetext.cc
+++ /dev/null
@@ -1,327 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/utf8/unicodetext.h"
-
-#include <string.h>
-
-#include <algorithm>
-
-#include "utils/base/logging.h"
-#include "utils/strings/utf8.h"
-
-namespace libtextclassifier3 {
-
-// *************** Data representation **********
-// Note: the copy constructor is undefined.
-
-UnicodeText::Repr& UnicodeText::Repr::operator=(Repr&& src) {
- if (ours_ && data_) delete[] data_;
- data_ = src.data_;
- size_ = src.size_;
- capacity_ = src.capacity_;
- ours_ = src.ours_;
- src.ours_ = false;
- return *this;
-}
-
-void UnicodeText::Repr::PointTo(const char* data, int size) {
- if (ours_ && data_) delete[] data_; // If we owned the old buffer, free it.
- data_ = const_cast<char*>(data);
- size_ = size;
- capacity_ = size;
- ours_ = false;
-}
-
-void UnicodeText::Repr::Copy(const char* data, int size) {
- resize(size);
- memcpy(data_, data, size);
-}
-
-void UnicodeText::Repr::resize(int new_size) {
- if (new_size == 0) {
- clear();
- } else {
- if (!ours_ || new_size > capacity_) reserve(new_size);
- // Clear the memory in the expanded part.
- if (size_ < new_size) memset(data_ + size_, 0, new_size - size_);
- size_ = new_size;
- ours_ = true;
- }
-}
-
-void UnicodeText::Repr::reserve(int new_capacity) {
- // If there's already enough capacity, and we're an owner, do nothing.
- if (capacity_ >= new_capacity && ours_) return;
-
- // Otherwise, allocate a new buffer.
- capacity_ = std::max(new_capacity, (3 * capacity_) / 2 + 20);
- char* new_data = new char[capacity_];
-
- // If there is an old buffer, copy it into the new buffer.
- if (data_) {
- memcpy(new_data, data_, size_);
- if (ours_) delete[] data_; // If we owned the old buffer, free it.
- }
- data_ = new_data;
- ours_ = true; // We own the new buffer.
- // size_ is unchanged.
-}
-
-void UnicodeText::Repr::append(const char* bytes, int byte_length) {
- reserve(size_ + byte_length);
- memcpy(data_ + size_, bytes, byte_length);
- size_ += byte_length;
-}
-
-void UnicodeText::Repr::clear() {
- if (ours_) delete[] data_;
- data_ = nullptr;
- size_ = capacity_ = 0;
- ours_ = true;
-}
-
-// *************** UnicodeText ******************
-
-UnicodeText::UnicodeText() {}
-
-UnicodeText::UnicodeText(const UnicodeText& src) { Copy(src); }
-
-UnicodeText& UnicodeText::operator=(UnicodeText&& src) {
- this->repr_ = std::move(src.repr_);
- return *this;
-}
-
-UnicodeText& UnicodeText::Copy(const UnicodeText& src) {
- repr_.Copy(src.repr_.data_, src.repr_.size_);
- return *this;
-}
-
-UnicodeText& UnicodeText::PointToUTF8(const char* buffer, int byte_length) {
- repr_.PointTo(buffer, byte_length);
- return *this;
-}
-
-UnicodeText& UnicodeText::CopyUTF8(const char* buffer, int byte_length) {
- repr_.Copy(buffer, byte_length);
- return *this;
-}
-
-UnicodeText& UnicodeText::AppendUTF8(const char* utf8, int len) {
- repr_.append(utf8, len);
- return *this;
-}
-
-const char* UnicodeText::data() const { return repr_.data_; }
-
-int UnicodeText::size_bytes() const { return repr_.size_; }
-
-namespace {
-
-enum {
- RuneError = 0xFFFD, // Decoding error in UTF.
- RuneMax = 0x10FFFF, // Maximum rune value.
-};
-
-int runetochar(const char32 rune, char* dest) {
- // Convert to unsigned for range check.
- uint32 c;
-
- // 1 char 00-7F
- c = rune;
- if (c <= 0x7F) {
- dest[0] = static_cast<char>(c);
- return 1;
- }
-
- // 2 char 0080-07FF
- if (c <= 0x07FF) {
- dest[0] = 0xC0 | static_cast<char>(c >> 1 * 6);
- dest[1] = 0x80 | (c & 0x3F);
- return 2;
- }
-
- // Range check
- if (c > RuneMax) {
- c = RuneError;
- }
-
- // 3 char 0800-FFFF
- if (c <= 0xFFFF) {
- dest[0] = 0xE0 | static_cast<char>(c >> 2 * 6);
- dest[1] = 0x80 | ((c >> 1 * 6) & 0x3F);
- dest[2] = 0x80 | (c & 0x3F);
- return 3;
- }
-
- // 4 char 10000-1FFFFF
- dest[0] = 0xF0 | static_cast<char>(c >> 3 * 6);
- dest[1] = 0x80 | ((c >> 2 * 6) & 0x3F);
- dest[2] = 0x80 | ((c >> 1 * 6) & 0x3F);
- dest[3] = 0x80 | (c & 0x3F);
- return 4;
-}
-
-} // namespace
-
-UnicodeText& UnicodeText::push_back(char32 ch) {
- char str[4];
- int char_len = runetochar(ch, str);
- repr_.append(str, char_len);
- return *this;
-}
-
-void UnicodeText::clear() { repr_.clear(); }
-
-int UnicodeText::size_codepoints() const {
- return std::distance(begin(), end());
-}
-
-bool UnicodeText::empty() const { return size_bytes() == 0; }
-
-bool UnicodeText::is_valid() const {
- return IsValidUTF8(repr_.data_, repr_.size_);
-}
-
-bool UnicodeText::operator==(const UnicodeText& other) const {
- if (repr_.size_ != other.repr_.size_) {
- return false;
- }
- return memcmp(repr_.data_, other.repr_.data_, repr_.size_) == 0;
-}
-
-std::string UnicodeText::ToUTF8String() const {
- return UTF8Substring(begin(), end());
-}
-
-std::string UnicodeText::UTF8Substring(int begin_codepoint,
- int end_codepoint) const {
- auto span_begin = begin();
- std::advance(span_begin, begin_codepoint);
- auto span_end = begin();
- std::advance(span_end, end_codepoint);
- return UTF8Substring(span_begin, span_end);
-}
-
-std::string UnicodeText::UTF8Substring(const const_iterator& it_begin,
- const const_iterator& it_end) {
- return std::string(it_begin.it_, it_end.it_ - it_begin.it_);
-}
-
-UnicodeText UnicodeText::Substring(const UnicodeText& text, int begin_codepoint,
- int end_codepoint, bool do_copy) {
- auto it_begin = text.begin();
- std::advance(it_begin, begin_codepoint);
- auto it_end = text.begin();
- std::advance(it_end, end_codepoint);
-
- if (do_copy) {
- UnicodeText result;
- result.repr_.Copy(it_begin.it_, it_end.it_ - it_begin.it_);
- return result;
- } else {
- UnicodeText result;
- result.repr_.PointTo(it_begin.it_, it_end.it_ - it_begin.it_);
- return result;
- }
-}
-
-UnicodeText::~UnicodeText() {}
-
-// ******************* UnicodeText::const_iterator *********************
-
-// The implementation of const_iterator would be nicer if it
-// inherited from boost::iterator_facade
-// (http://boost.org/libs/iterator/doc/iterator_facade.html).
-
-UnicodeText::const_iterator::const_iterator() : it_(0) {}
-
-UnicodeText::const_iterator& UnicodeText::const_iterator::operator=(
- const const_iterator& other) {
- if (&other != this) it_ = other.it_;
- return *this;
-}
-
-UnicodeText::const_iterator UnicodeText::begin() const {
- return const_iterator(repr_.data_);
-}
-
-UnicodeText::const_iterator UnicodeText::end() const {
- return const_iterator(repr_.data_ + repr_.size_);
-}
-
-bool operator<(const UnicodeText::const_iterator& lhs,
- const UnicodeText::const_iterator& rhs) {
- return lhs.it_ < rhs.it_;
-}
-
-char32 UnicodeText::const_iterator::operator*() const {
- // (We could call chartorune here, but that does some
- // error-checking, and we're guaranteed that our data is valid
- // UTF-8. Also, we expect this routine to be called very often. So
- // for speed, we do the calculation ourselves.)
-
- // Convert from UTF-8
- unsigned char byte1 = static_cast<unsigned char>(it_[0]);
- if (byte1 < 0x80) return byte1;
-
- unsigned char byte2 = static_cast<unsigned char>(it_[1]);
- if (byte1 < 0xE0) return ((byte1 & 0x1F) << 6) | (byte2 & 0x3F);
-
- unsigned char byte3 = static_cast<unsigned char>(it_[2]);
- if (byte1 < 0xF0) {
- return ((byte1 & 0x0F) << 12) | ((byte2 & 0x3F) << 6) | (byte3 & 0x3F);
- }
-
- unsigned char byte4 = static_cast<unsigned char>(it_[3]);
- return ((byte1 & 0x07) << 18) | ((byte2 & 0x3F) << 12) |
- ((byte3 & 0x3F) << 6) | (byte4 & 0x3F);
-}
-
-UnicodeText::const_iterator& UnicodeText::const_iterator::operator++() {
- it_ += GetNumBytesForNonZeroUTF8Char(it_);
- return *this;
-}
-
-UnicodeText::const_iterator& UnicodeText::const_iterator::operator--() {
- while (IsTrailByte(*--it_)) {
- }
- return *this;
-}
-
-UnicodeText UTF8ToUnicodeText(const char* utf8_buf, int len, bool do_copy) {
- UnicodeText t;
- if (do_copy) {
- t.CopyUTF8(utf8_buf, len);
- } else {
- t.PointToUTF8(utf8_buf, len);
- }
- return t;
-}
-
-UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy) {
- return UTF8ToUnicodeText(utf8_buf, strlen(utf8_buf), do_copy);
-}
-
-UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy) {
- return UTF8ToUnicodeText(str.data(), str.size(), do_copy);
-}
-
-UnicodeText UTF8ToUnicodeText(const std::string& str) {
- return UTF8ToUnicodeText(str, /*do_copy=*/true);
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/utf8/unicodetext.h b/utils/utf8/unicodetext.h
deleted file mode 100644
index 310fd38..0000000
--- a/utils/utf8/unicodetext.h
+++ /dev/null
@@ -1,228 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNICODETEXT_H_
-#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNICODETEXT_H_
-
-#include <iterator>
-#include <string>
-#include <utility>
-
-#include "utils/base/integral_types.h"
-
-namespace libtextclassifier3 {
-
-// ***************************** UnicodeText **************************
-//
-// A UnicodeText object is a wrapper around a sequence of Unicode
-// codepoint values that allows iteration over these values.
-//
-// The internal representation of the text is UTF-8. Since UTF-8 is a
-// variable-width format, UnicodeText does not provide random access
-// to the text, and changes to the text are permitted only at the end.
-//
-// The UnicodeText class defines a const_iterator. The dereferencing
-// operator (*) returns a codepoint (int32). The iterator is a
-// read-only iterator. It becomes invalid if the text is changed.
-//
-// Codepoints are integers in the range [0, 0xD7FF] or [0xE000,
-// 0x10FFFF], but UnicodeText has the additional restriction that it
-// can contain only those characters that are valid for interchange on
-// the Web. This excludes all of the control codes except for carriage
-// return, line feed, and horizontal tab. It also excludes
-// non-characters, but codepoints that are in the Private Use regions
-// are allowed, as are codepoints that are unassigned. (See the
-// Unicode reference for details.)
-//
-// MEMORY MANAGEMENT:
-//
-// PointToUTF8(buffer, size) creates an alias pointing to buffer.
-//
-// The purpose of an alias is to avoid making an unnecessary copy of a
-// UTF-8 buffer while still providing access to the Unicode values
-// within that text through iterators. The lifetime of an alias must not
-// exceed the lifetime of the buffer from which it was constructed.
-//
-// Aliases should be used with care. If the source from which an alias
-// was created is freed, or if the contents are changed, while the
-// alias is still in use, fatal errors could result. But it can be
-// quite useful to have a UnicodeText "window" through which to see a
-// UTF-8 buffer without having to pay the price of making a copy.
-
-class UnicodeText {
- public:
- class const_iterator;
-
- UnicodeText(); // Create an empty text.
- UnicodeText(const UnicodeText& src);
- UnicodeText& operator=(UnicodeText&& src);
- ~UnicodeText();
-
- class const_iterator {
- typedef const_iterator CI;
-
- public:
- typedef std::bidirectional_iterator_tag iterator_category;
- typedef char32 value_type;
- typedef int difference_type;
- typedef void pointer; // (Not needed.)
- typedef const char32 reference; // (Needed for const_reverse_iterator)
-
- // Iterators are default-constructible.
- const_iterator();
-
- // It's safe to make multiple passes over a UnicodeText.
- const_iterator& operator=(const const_iterator& other);
-
- char32 operator*() const; // Dereference
-
- const_iterator& operator++(); // Advance (++iter)
- const_iterator operator++(int) { // (iter++)
- const_iterator result(*this);
- ++*this;
- return result;
- }
-
- const_iterator& operator--(); // Retreat (--iter)
- const_iterator operator--(int) { // (iter--)
- const_iterator result(*this);
- --*this;
- return result;
- }
-
- friend bool operator==(const CI& lhs, const CI& rhs) {
- return lhs.it_ == rhs.it_;
- }
- friend bool operator!=(const CI& lhs, const CI& rhs) {
- return !(lhs == rhs);
- }
- friend bool operator<(const CI& lhs, const CI& rhs);
- friend bool operator>(const CI& lhs, const CI& rhs) { return rhs < lhs; }
- friend bool operator<=(const CI& lhs, const CI& rhs) {
- return !(rhs < lhs);
- }
- friend bool operator>=(const CI& lhs, const CI& rhs) {
- return !(lhs < rhs);
- }
-
- int utf8_length() const {
- if (it_[0] < 0x80) {
- return 1;
- } else if (it_[0] < 0xE0) {
- return 2;
- } else if (it_[0] < 0xF0) {
- return 3;
- } else {
- return 4;
- }
- }
- const char* utf8_data() const { return it_; }
-
- private:
- friend class UnicodeText;
- explicit const_iterator(const char* it) : it_(it) {}
-
- const char* it_;
- };
-
- const_iterator begin() const;
- const_iterator end() const;
-
- // Gets pointer to the underlying utf8 data.
- const char* data() const;
-
- // Gets length (in bytes) of the underlying utf8 data.
- int size_bytes() const;
-
- // Computes length (in number of Unicode codepoints) of the underlying utf8
- // data.
- // NOTE: Complexity O(n).
- int size_codepoints() const;
-
- bool empty() const;
-
- // Checks whether the underlying data is valid utf8 data.
- bool is_valid() const;
-
- bool operator==(const UnicodeText& other) const;
-
- // x.PointToUTF8(buf,len) changes x so that it points to buf
- // ("becomes an alias"). It does not take ownership or copy buf.
- // This function assumes that the input is interchange valid UTF8.
- UnicodeText& Copy(const UnicodeText& src);
- UnicodeText& PointToUTF8(const char* utf8_buffer, int byte_length);
- UnicodeText& CopyUTF8(const char* utf8_buffer, int byte_length);
-
- // Calling this may invalidate pointers to underlying data.
- UnicodeText& AppendUTF8(const char* utf8, int len);
- UnicodeText& push_back(char32 ch);
- void clear();
-
- std::string ToUTF8String() const;
- std::string UTF8Substring(int begin_codepoint, int end_codepoint) const;
- static std::string UTF8Substring(const const_iterator& it_begin,
- const const_iterator& it_end);
- static UnicodeText Substring(const UnicodeText& text, int begin_codepoint,
- int end_codepoint, bool do_copy = true);
-
- private:
- friend class const_iterator;
-
- class Repr { // A byte-string.
- public:
- char* data_;
- int size_;
- int capacity_;
- bool ours_; // Do we own data_?
-
- Repr() : data_(nullptr), size_(0), capacity_(0), ours_(true) {}
- Repr& operator=(Repr&& src);
- ~Repr() {
- if (ours_) delete[] data_;
- }
-
- void clear();
- void reserve(int capacity);
- void resize(int size);
-
- void append(const char* bytes, int byte_length);
- void Copy(const char* data, int size);
- void PointTo(const char* data, int size);
-
- private:
- Repr& operator=(const Repr&);
- Repr(const Repr& other);
- };
-
- Repr repr_;
-};
-
-typedef std::pair<UnicodeText::const_iterator, UnicodeText::const_iterator>
- UnicodeTextRange;
-
-// NOTE: The following are needed to avoid implicit conversion from char* to
-// std::string, or from ::string to std::string, because if this happens it
-// often results in invalid memory access to a temporary object created during
-// such conversion (if do_copy == false).
-UnicodeText UTF8ToUnicodeText(const char* utf8_buf, int len,
- bool do_copy = true);
-UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy = true);
-UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy = true);
-UnicodeText UTF8ToUnicodeText(const std::string& str);
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNICODETEXT_H_
diff --git a/utils/utf8/unicodetext_test.cc b/utils/utf8/unicodetext_test.cc
deleted file mode 100644
index e6926ce..0000000
--- a/utils/utf8/unicodetext_test.cc
+++ /dev/null
@@ -1,198 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/utf8/unicodetext.h"
-
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-class UnicodeTextTest : public testing::Test {
- protected:
- UnicodeTextTest() : empty_text_() {
- text_.push_back(0x1C0);
- text_.push_back(0x4E8C);
- text_.push_back(0xD7DB);
- text_.push_back(0x34);
- text_.push_back(0x1D11E);
- }
-
- UnicodeText empty_text_;
- UnicodeText text_;
-};
-
-// Tests for our modifications of UnicodeText.
-TEST(UnicodeTextTest, Custom) {
- UnicodeText text = UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false);
- EXPECT_EQ(text.ToUTF8String(), "1234😋hello");
- EXPECT_EQ(text.size_codepoints(), 10);
- EXPECT_EQ(text.size_bytes(), 13);
-
- auto it_begin = text.begin();
- std::advance(it_begin, 4);
- auto it_end = text.begin();
- std::advance(it_end, 6);
- EXPECT_EQ(text.UTF8Substring(it_begin, it_end), "😋h");
-}
-
-TEST(UnicodeTextTest, Substring) {
- UnicodeText text = UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false);
-
- EXPECT_EQ(UnicodeText::Substring(text, 4, 6, /*do_copy=*/true),
- UTF8ToUnicodeText("😋h"));
- EXPECT_EQ(UnicodeText::Substring(text, 4, 6, /*do_copy=*/false),
- UTF8ToUnicodeText("😋h"));
-}
-
-TEST(UnicodeTextTest, Ownership) {
- const std::string src = "\u304A\u00B0\u106B";
-
- UnicodeText alias;
- alias.PointToUTF8(src.data(), src.size());
- EXPECT_EQ(alias.data(), src.data());
- UnicodeText::const_iterator it = alias.begin();
- EXPECT_EQ(*it++, 0x304A);
- EXPECT_EQ(*it++, 0x00B0);
- EXPECT_EQ(*it++, 0x106B);
- EXPECT_EQ(it, alias.end());
-
- UnicodeText t = alias; // Copy initialization copies the data.
- EXPECT_NE(t.data(), alias.data());
-}
-
-TEST(UnicodeTextTest, Validation) {
- EXPECT_TRUE(UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false).is_valid());
- EXPECT_TRUE(
- UTF8ToUnicodeText("\u304A\u00B0\u106B", /*do_copy=*/false).is_valid());
- EXPECT_TRUE(
- UTF8ToUnicodeText("this is a test😋😋😋", /*do_copy=*/false).is_valid());
- EXPECT_TRUE(
- UTF8ToUnicodeText("\xf0\x9f\x98\x8b", /*do_copy=*/false).is_valid());
- // Too short (string is too short).
- EXPECT_FALSE(UTF8ToUnicodeText("\xf0\x9f", /*do_copy=*/false).is_valid());
- // Too long (too many trailing bytes).
- EXPECT_FALSE(
- UTF8ToUnicodeText("\xf0\x9f\x98\x8b\x8b", /*do_copy=*/false).is_valid());
- // Too short (too few trailing bytes).
- EXPECT_FALSE(
- UTF8ToUnicodeText("\xf0\x9f\x98\x61\x61", /*do_copy=*/false).is_valid());
- // Invalid with context.
- EXPECT_FALSE(
- UTF8ToUnicodeText("hello \xf0\x9f\x98\x61\x61 world1", /*do_copy=*/false)
- .is_valid());
-}
-
-class IteratorTest : public UnicodeTextTest {};
-
-TEST_F(IteratorTest, Iterates) {
- UnicodeText::const_iterator iter = text_.begin();
- EXPECT_EQ(0x1C0, *iter);
- EXPECT_EQ(&iter, &++iter); // operator++ returns *this.
- EXPECT_EQ(0x4E8C, *iter++);
- EXPECT_EQ(0xD7DB, *iter);
- // Make sure you can dereference more than once.
- EXPECT_EQ(0xD7DB, *iter);
- EXPECT_EQ(0x34, *++iter);
- EXPECT_EQ(0x1D11E, *++iter);
- ASSERT_TRUE(iter != text_.end());
- iter++;
- EXPECT_TRUE(iter == text_.end());
-}
-
-TEST_F(IteratorTest, MultiPass) {
- // Also tests Default Constructible and Assignable.
- UnicodeText::const_iterator i1, i2;
- i1 = text_.begin();
- i2 = i1;
- EXPECT_EQ(0x4E8C, *++i1);
- EXPECT_TRUE(i1 != i2);
- EXPECT_EQ(0x1C0, *i2);
- ++i2;
- EXPECT_TRUE(i1 == i2);
- EXPECT_EQ(0x4E8C, *i2);
-}
-
-TEST_F(IteratorTest, ReverseIterates) {
- UnicodeText::const_iterator iter = text_.end();
- EXPECT_TRUE(iter == text_.end());
- iter--;
- ASSERT_TRUE(iter != text_.end());
- EXPECT_EQ(0x1D11E, *iter--);
- EXPECT_EQ(0x34, *iter);
- EXPECT_EQ(0xD7DB, *--iter);
- // Make sure you can dereference more than once.
- EXPECT_EQ(0xD7DB, *iter);
- --iter;
- EXPECT_EQ(0x4E8C, *iter--);
- EXPECT_EQ(0x1C0, *iter);
- EXPECT_TRUE(iter == text_.begin());
-}
-
-TEST_F(IteratorTest, Comparable) {
- UnicodeText::const_iterator i1, i2;
- i1 = text_.begin();
- i2 = i1;
- ++i2;
-
- EXPECT_TRUE(i1 < i2);
- EXPECT_TRUE(text_.begin() <= i1);
- EXPECT_FALSE(i1 >= i2);
- EXPECT_FALSE(i1 > text_.end());
-}
-
-TEST_F(IteratorTest, Advance) {
- UnicodeText::const_iterator iter = text_.begin();
- EXPECT_EQ(0x1C0, *iter);
- std::advance(iter, 4);
- EXPECT_EQ(0x1D11E, *iter);
- ++iter;
- EXPECT_TRUE(iter == text_.end());
-}
-
-TEST_F(IteratorTest, Distance) {
- UnicodeText::const_iterator iter = text_.begin();
- EXPECT_EQ(0, std::distance(text_.begin(), iter));
- EXPECT_EQ(5, std::distance(iter, text_.end()));
- ++iter;
- ++iter;
- EXPECT_EQ(2, std::distance(text_.begin(), iter));
- EXPECT_EQ(3, std::distance(iter, text_.end()));
- ++iter;
- ++iter;
- EXPECT_EQ(4, std::distance(text_.begin(), iter));
- ++iter;
- EXPECT_EQ(0, std::distance(iter, text_.end()));
-}
-
-class OperatorTest : public UnicodeTextTest {};
-
-TEST_F(OperatorTest, Clear) {
- UnicodeText empty_text(UTF8ToUnicodeText("", /*do_copy=*/false));
- EXPECT_FALSE(text_ == empty_text);
- text_.clear();
- EXPECT_TRUE(text_ == empty_text);
-}
-
-TEST_F(OperatorTest, Empty) {
- EXPECT_TRUE(empty_text_.empty());
- EXPECT_FALSE(text_.empty());
- text_.clear();
- EXPECT_TRUE(text_.empty());
-}
-
-} // namespace
-} // namespace libtextclassifier3
diff --git a/utils/utf8/unilib-javaicu.cc b/utils/utf8/unilib-javaicu.cc
deleted file mode 100644
index 8cddddd..0000000
--- a/utils/utf8/unilib-javaicu.cc
+++ /dev/null
@@ -1,728 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/utf8/unilib-javaicu.h"
-
-#include <algorithm>
-#include <cassert>
-#include <cctype>
-#include <map>
-
-#include "utils/java/string_utils.h"
-
-namespace libtextclassifier3 {
-namespace {
-
-// -----------------------------------------------------------------------------
-// Native implementations.
-// -----------------------------------------------------------------------------
-
-#define ARRAYSIZE(a) sizeof(a) / sizeof(*a)
-
-// Derived from http://www.unicode.org/Public/UNIDATA/UnicodeData.txt
-// grep -E "Ps" UnicodeData.txt | \
-// sed -rne "s/^([0-9A-Z]{4});.*(PAREN|BRACKET|BRAKCET|BRACE).*/0x\1, /p"
-// IMPORTANT: entries with the same offsets in kOpeningBrackets and
-// kClosingBrackets must be counterparts.
-constexpr char32 kOpeningBrackets[] = {
- 0x0028, 0x005B, 0x007B, 0x0F3C, 0x2045, 0x207D, 0x208D, 0x2329, 0x2768,
- 0x276A, 0x276C, 0x2770, 0x2772, 0x2774, 0x27E6, 0x27E8, 0x27EA, 0x27EC,
- 0x27EE, 0x2983, 0x2985, 0x2987, 0x2989, 0x298B, 0x298D, 0x298F, 0x2991,
- 0x2993, 0x2995, 0x2997, 0x29FC, 0x2E22, 0x2E24, 0x2E26, 0x2E28, 0x3008,
- 0x300A, 0x300C, 0x300E, 0x3010, 0x3014, 0x3016, 0x3018, 0x301A, 0xFD3F,
- 0xFE17, 0xFE35, 0xFE37, 0xFE39, 0xFE3B, 0xFE3D, 0xFE3F, 0xFE41, 0xFE43,
- 0xFE47, 0xFE59, 0xFE5B, 0xFE5D, 0xFF08, 0xFF3B, 0xFF5B, 0xFF5F, 0xFF62};
-constexpr int kNumOpeningBrackets = ARRAYSIZE(kOpeningBrackets);
-
-// grep -E "Pe" UnicodeData.txt | \
-// sed -rne "s/^([0-9A-Z]{4});.*(PAREN|BRACKET|BRAKCET|BRACE).*/0x\1, /p"
-constexpr char32 kClosingBrackets[] = {
- 0x0029, 0x005D, 0x007D, 0x0F3D, 0x2046, 0x207E, 0x208E, 0x232A, 0x2769,
- 0x276B, 0x276D, 0x2771, 0x2773, 0x2775, 0x27E7, 0x27E9, 0x27EB, 0x27ED,
- 0x27EF, 0x2984, 0x2986, 0x2988, 0x298A, 0x298C, 0x298E, 0x2990, 0x2992,
- 0x2994, 0x2996, 0x2998, 0x29FD, 0x2E23, 0x2E25, 0x2E27, 0x2E29, 0x3009,
- 0x300B, 0x300D, 0x300F, 0x3011, 0x3015, 0x3017, 0x3019, 0x301B, 0xFD3E,
- 0xFE18, 0xFE36, 0xFE38, 0xFE3A, 0xFE3C, 0xFE3E, 0xFE40, 0xFE42, 0xFE44,
- 0xFE48, 0xFE5A, 0xFE5C, 0xFE5E, 0xFF09, 0xFF3D, 0xFF5D, 0xFF60, 0xFF63};
-constexpr int kNumClosingBrackets = ARRAYSIZE(kClosingBrackets);
-
-// grep -E "WS" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
-constexpr char32 kWhitespaces[] = {
- 0x000C, 0x0020, 0x1680, 0x2000, 0x2001, 0x2002, 0x2003, 0x2004,
- 0x2005, 0x2006, 0x2007, 0x2008, 0x2009, 0x200A, 0x2028, 0x205F,
- 0x21C7, 0x21C8, 0x21C9, 0x21CA, 0x21F6, 0x2B31, 0x2B84, 0x2B85,
- 0x2B86, 0x2B87, 0x2B94, 0x3000, 0x4DCC, 0x10344, 0x10347, 0x1DA0A,
- 0x1DA0B, 0x1DA0C, 0x1DA0D, 0x1DA0E, 0x1DA0F, 0x1DA10, 0x1F4F0, 0x1F500,
- 0x1F501, 0x1F502, 0x1F503, 0x1F504, 0x1F5D8, 0x1F5DE};
-constexpr int kNumWhitespaces = ARRAYSIZE(kWhitespaces);
-
-// grep -E "Nd" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
-// As the name suggests, these ranges are always 10 codepoints long, so we just
-// store the end of the range.
-constexpr char32 kDecimalDigitRangesEnd[] = {
- 0x0039, 0x0669, 0x06f9, 0x07c9, 0x096f, 0x09ef, 0x0a6f, 0x0aef,
- 0x0b6f, 0x0bef, 0x0c6f, 0x0cef, 0x0d6f, 0x0def, 0x0e59, 0x0ed9,
- 0x0f29, 0x1049, 0x1099, 0x17e9, 0x1819, 0x194f, 0x19d9, 0x1a89,
- 0x1a99, 0x1b59, 0x1bb9, 0x1c49, 0x1c59, 0xa629, 0xa8d9, 0xa909,
- 0xa9d9, 0xa9f9, 0xaa59, 0xabf9, 0xff19, 0x104a9, 0x1106f, 0x110f9,
- 0x1113f, 0x111d9, 0x112f9, 0x11459, 0x114d9, 0x11659, 0x116c9, 0x11739,
- 0x118e9, 0x11c59, 0x11d59, 0x16a69, 0x16b59, 0x1d7ff};
-constexpr int kNumDecimalDigitRangesEnd = ARRAYSIZE(kDecimalDigitRangesEnd);
-
-// grep -E "Lu" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
-// There are three common ways in which upper/lower case codepoint ranges
-// were introduced: one offs, dense ranges, and ranges that alternate between
-// lower and upper case. For the sake of keeping out binary size down, we
-// treat each independently.
-constexpr char32 kUpperSingles[] = {
- 0x01b8, 0x01bc, 0x01c4, 0x01c7, 0x01ca, 0x01f1, 0x0376, 0x037f,
- 0x03cf, 0x03f4, 0x03fa, 0x10c7, 0x10cd, 0x2102, 0x2107, 0x2115,
- 0x2145, 0x2183, 0x2c72, 0x2c75, 0x2cf2, 0xa7b6};
-constexpr int kNumUpperSingles = ARRAYSIZE(kUpperSingles);
-constexpr char32 kUpperRanges1Start[] = {
- 0x0041, 0x00c0, 0x00d8, 0x0181, 0x018a, 0x018e, 0x0193, 0x0196,
- 0x019c, 0x019f, 0x01b2, 0x01f7, 0x023a, 0x023d, 0x0244, 0x0389,
- 0x0392, 0x03a3, 0x03d2, 0x03fd, 0x0531, 0x10a0, 0x13a0, 0x1f08,
- 0x1f18, 0x1f28, 0x1f38, 0x1f48, 0x1f68, 0x1fb8, 0x1fc8, 0x1fd8,
- 0x1fe8, 0x1ff8, 0x210b, 0x2110, 0x2119, 0x212b, 0x2130, 0x213e,
- 0x2c00, 0x2c63, 0x2c6e, 0x2c7e, 0xa7ab, 0xa7b0};
-constexpr int kNumUpperRanges1Start = ARRAYSIZE(kUpperRanges1Start);
-constexpr char32 kUpperRanges1End[] = {
- 0x005a, 0x00d6, 0x00de, 0x0182, 0x018b, 0x0191, 0x0194, 0x0198,
- 0x019d, 0x01a0, 0x01b3, 0x01f8, 0x023b, 0x023e, 0x0246, 0x038a,
- 0x03a1, 0x03ab, 0x03d4, 0x042f, 0x0556, 0x10c5, 0x13f5, 0x1f0f,
- 0x1f1d, 0x1f2f, 0x1f3f, 0x1f4d, 0x1f6f, 0x1fbb, 0x1fcb, 0x1fdb,
- 0x1fec, 0x1ffb, 0x210d, 0x2112, 0x211d, 0x212d, 0x2133, 0x213f,
- 0x2c2e, 0x2c64, 0x2c70, 0x2c80, 0xa7ae, 0xa7b4};
-constexpr int kNumUpperRanges1End = ARRAYSIZE(kUpperRanges1End);
-constexpr char32 kUpperRanges2Start[] = {
- 0x0100, 0x0139, 0x014a, 0x0179, 0x0184, 0x0187, 0x01a2, 0x01a7, 0x01ac,
- 0x01af, 0x01b5, 0x01cd, 0x01de, 0x01f4, 0x01fa, 0x0241, 0x0248, 0x0370,
- 0x0386, 0x038c, 0x038f, 0x03d8, 0x03f7, 0x0460, 0x048a, 0x04c1, 0x04d0,
- 0x1e00, 0x1e9e, 0x1f59, 0x2124, 0x2c60, 0x2c67, 0x2c82, 0x2ceb, 0xa640,
- 0xa680, 0xa722, 0xa732, 0xa779, 0xa77e, 0xa78b, 0xa790, 0xa796};
-constexpr int kNumUpperRanges2Start = ARRAYSIZE(kUpperRanges2Start);
-constexpr char32 kUpperRanges2End[] = {
- 0x0136, 0x0147, 0x0178, 0x017d, 0x0186, 0x0189, 0x01a6, 0x01a9, 0x01ae,
- 0x01b1, 0x01b7, 0x01db, 0x01ee, 0x01f6, 0x0232, 0x0243, 0x024e, 0x0372,
- 0x0388, 0x038e, 0x0391, 0x03ee, 0x03f9, 0x0480, 0x04c0, 0x04cd, 0x052e,
- 0x1e94, 0x1efe, 0x1f5f, 0x212a, 0x2c62, 0x2c6d, 0x2ce2, 0x2ced, 0xa66c,
- 0xa69a, 0xa72e, 0xa76e, 0xa77d, 0xa786, 0xa78d, 0xa792, 0xa7aa};
-constexpr int kNumUpperRanges2End = ARRAYSIZE(kUpperRanges2End);
-
-// grep -E "Lu" UnicodeData.txt | \
-// sed -rne "s/^([0-9A-Z]+);.*;([0-9A-Z]+);$/(0x\1, 0x\2), /p"
-// We have two strategies for mapping from upper to lower case. We have single
-// character lookups that do not follow a pattern, and ranges for which there
-// is a constant codepoint shift.
-// Note that these ranges ignore anything that's not an upper case character,
-// so when applied to a non-uppercase character the result is incorrect.
-constexpr int kToLowerSingles[] = {
- 0x0130, 0x0178, 0x0181, 0x0186, 0x018b, 0x018e, 0x018f, 0x0190, 0x0191,
- 0x0194, 0x0196, 0x0197, 0x0198, 0x019c, 0x019d, 0x019f, 0x01a6, 0x01a9,
- 0x01ae, 0x01b7, 0x01f6, 0x01f7, 0x0220, 0x023a, 0x023d, 0x023e, 0x0243,
- 0x0244, 0x0245, 0x037f, 0x0386, 0x038c, 0x03cf, 0x03f4, 0x03f9, 0x04c0,
- 0x1e9e, 0x1fec, 0x2126, 0x212a, 0x212b, 0x2132, 0x2183, 0x2c60, 0x2c62,
- 0x2c63, 0x2c64, 0x2c6d, 0x2c6e, 0x2c6f, 0x2c70, 0xa77d, 0xa78d, 0xa7aa,
- 0xa7ab, 0xa7ac, 0xa7ad, 0xa7ae, 0xa7b0, 0xa7b1, 0xa7b2, 0xa7b3};
-constexpr int kNumToLowerSingles = ARRAYSIZE(kToLowerSingles);
-constexpr int kToLowerSinglesOffsets[] = {
- -199, -121, 210, 206, 1, 79, 202, 203, 1,
- 207, 211, 209, 1, 211, 213, 214, 218, 218,
- 218, 219, -97, -56, -130, 10795, -163, 10792, -195,
- 69, 71, 116, 38, 64, 8, -60, -7, 15,
- -7615, -7, -7517, -8383, -8262, 28, 1, 1, -10743,
- -3814, -10727, -10780, -10749, -10783, -10782, -35332, -42280, -42308,
- -42319, -42315, -42305, -42308, -42258, -42282, -42261, 928};
-constexpr int kNumToLowerSinglesOffsets = ARRAYSIZE(kToLowerSinglesOffsets);
-constexpr int kToLowerRangesStart[] = {
- 0x0041, 0x0100, 0x0189, 0x01a0, 0x01b1, 0x01b3, 0x0388, 0x038e, 0x0391,
- 0x03d8, 0x03fd, 0x0400, 0x0410, 0x0460, 0x0531, 0x10a0, 0x13a0, 0x13f0,
- 0x1e00, 0x1f08, 0x1fba, 0x1fc8, 0x1fd8, 0x1fda, 0x1fe8, 0x1fea, 0x1ff8,
- 0x1ffa, 0x2c00, 0x2c67, 0x2c7e, 0x2c80, 0xff21, 0x10400, 0x10c80, 0x118a0};
-constexpr int kNumToLowerRangesStart = ARRAYSIZE(kToLowerRangesStart);
-constexpr int kToLowerRangesEnd[] = {
- 0x00de, 0x0187, 0x019f, 0x01af, 0x01b2, 0x0386, 0x038c, 0x038f, 0x03cf,
- 0x03fa, 0x03ff, 0x040f, 0x042f, 0x052e, 0x0556, 0x10cd, 0x13ef, 0x13f5,
- 0x1efe, 0x1fb9, 0x1fbb, 0x1fcb, 0x1fd9, 0x1fdb, 0x1fe9, 0x1fec, 0x1ff9,
- 0x2183, 0x2c64, 0x2c75, 0x2c7f, 0xa7b6, 0xff3a, 0x104d3, 0x10cb2, 0x118bf};
-constexpr int kNumToLowerRangesEnd = ARRAYSIZE(kToLowerRangesEnd);
-constexpr int kToLowerRangesOffsets[] = {
- 32, 1, 205, 1, 217, 1, 37, 63, 32, 1, -130, 80,
- 32, 1, 48, 7264, 38864, 8, 1, -8, -74, -86, -8, -100,
- -8, -112, -128, -126, 48, 1, -10815, 1, 32, 40, 64, 32};
-constexpr int kNumToLowerRangesOffsets = ARRAYSIZE(kToLowerRangesOffsets);
-
-#undef ARRAYSIZE
-
-static_assert(kNumOpeningBrackets == kNumClosingBrackets,
- "mismatching number of opening and closing brackets");
-static_assert(kNumUpperRanges1Start == kNumUpperRanges1End,
- "number of uppercase stride 1 range starts/ends doesn't match");
-static_assert(kNumUpperRanges2Start == kNumUpperRanges2End,
- "number of uppercase stride 2 range starts/ends doesn't match");
-static_assert(kNumToLowerSingles == kNumToLowerSinglesOffsets,
- "number of to lower singles and offsets doesn't match");
-static_assert(kNumToLowerRangesStart == kNumToLowerRangesEnd,
- "mismatching number of range starts/ends for to lower ranges");
-static_assert(kNumToLowerRangesStart == kNumToLowerRangesOffsets,
- "number of to lower ranges and offsets doesn't match");
-
-constexpr int kNoMatch = -1;
-
-// Returns the index of the element in the array that matched the given
-// codepoint, or kNoMatch if the element didn't exist.
-// The input array must be in sorted order.
-int GetMatchIndex(const char32* array, int array_length, char32 c) {
- const char32* end = array + array_length;
- const auto find_it = std::lower_bound(array, end, c);
- if (find_it != end && *find_it == c) {
- return find_it - array;
- } else {
- return kNoMatch;
- }
-}
-
-// Returns the index of the range in the array that overlapped the given
-// codepoint, or kNoMatch if no such range existed.
-// The input array must be in sorted order.
-int GetOverlappingRangeIndex(const char32* arr, int arr_length,
- int range_length, char32 c) {
- const char32* end = arr + arr_length;
- const auto find_it = std::lower_bound(arr, end, c);
- if (find_it == end) {
- return kNoMatch;
- }
- // The end is inclusive, we so subtract one less than the range length.
- const char32 range_end = *find_it;
- const char32 range_start = range_end - (range_length - 1);
- if (c < range_start || range_end < c) {
- return kNoMatch;
- } else {
- return find_it - arr;
- }
-}
-
-// As above, but with explicit codepoint start and end indices for the range.
-// The input array must be in sorted order.
-int GetOverlappingRangeIndex(const char32* start_arr, const char32* end_arr,
- int arr_length, int stride, char32 c) {
- const char32* end_arr_end = end_arr + arr_length;
- const auto find_it = std::lower_bound(end_arr, end_arr_end, c);
- if (find_it == end_arr_end) {
- return kNoMatch;
- }
- // Find the corresponding start.
- const int range_index = find_it - end_arr;
- const char32 range_start = start_arr[range_index];
- const char32 range_end = *find_it;
- if (c < range_start || range_end < c) {
- return kNoMatch;
- }
- if ((c - range_start) % stride == 0) {
- return range_index;
- } else {
- return kNoMatch;
- }
-}
-
-} // anonymous namespace
-
-UniLib::UniLib() {
- TC3_LOG(FATAL) << "Java ICU UniLib must be initialized with a JniCache.";
-}
-
-UniLib::UniLib(const std::shared_ptr<JniCache>& jni_cache)
- : jni_cache_(jni_cache) {}
-
-bool UniLib::IsOpeningBracket(char32 codepoint) const {
- return GetMatchIndex(kOpeningBrackets, kNumOpeningBrackets, codepoint) >= 0;
-}
-
-bool UniLib::IsClosingBracket(char32 codepoint) const {
- return GetMatchIndex(kClosingBrackets, kNumClosingBrackets, codepoint) >= 0;
-}
-
-bool UniLib::IsWhitespace(char32 codepoint) const {
- return GetMatchIndex(kWhitespaces, kNumWhitespaces, codepoint) >= 0;
-}
-
-bool UniLib::IsDigit(char32 codepoint) const {
- return GetOverlappingRangeIndex(kDecimalDigitRangesEnd,
- kNumDecimalDigitRangesEnd,
- /*range_length=*/10, codepoint) >= 0;
-}
-
-bool UniLib::IsUpper(char32 codepoint) const {
- if (GetMatchIndex(kUpperSingles, kNumUpperSingles, codepoint) >= 0) {
- return true;
- } else if (GetOverlappingRangeIndex(kUpperRanges1Start, kUpperRanges1End,
- kNumUpperRanges1Start, /*stride=*/1,
- codepoint) >= 0) {
- return true;
- } else if (GetOverlappingRangeIndex(kUpperRanges2Start, kUpperRanges2End,
- kNumUpperRanges2Start, /*stride=*/2,
- codepoint) >= 0) {
- return true;
- } else {
- return false;
- }
-}
-
-char32 UniLib::ToLower(char32 codepoint) const {
- // Make sure we still produce output even if the method is called for a
- // codepoint that's not an uppercase character.
- if (!IsUpper(codepoint)) {
- return codepoint;
- }
- const int singles_idx =
- GetMatchIndex(kToLowerSingles, kNumToLowerSingles, codepoint);
- if (singles_idx >= 0) {
- return codepoint + kToLowerSinglesOffsets[singles_idx];
- }
- const int ranges_idx =
- GetOverlappingRangeIndex(kToLowerRangesStart, kToLowerRangesEnd,
- kNumToLowerRangesStart, /*stride=*/1, codepoint);
- if (ranges_idx >= 0) {
- return codepoint + kToLowerRangesOffsets[ranges_idx];
- }
- return codepoint;
-}
-
-char32 UniLib::GetPairedBracket(char32 codepoint) const {
- const int open_offset =
- GetMatchIndex(kOpeningBrackets, kNumOpeningBrackets, codepoint);
- if (open_offset >= 0) {
- return kClosingBrackets[open_offset];
- }
- const int close_offset =
- GetMatchIndex(kClosingBrackets, kNumClosingBrackets, codepoint);
- if (close_offset >= 0) {
- return kOpeningBrackets[close_offset];
- }
- return codepoint;
-}
-
-// -----------------------------------------------------------------------------
-// Implementations that call out to JVM. Behold the beauty.
-// -----------------------------------------------------------------------------
-
-bool UniLib::ParseInt32(const UnicodeText& text, int* result) const {
- if (jni_cache_) {
- JNIEnv* env = jni_cache_->GetEnv();
- const ScopedLocalRef<jstring> text_java =
- jni_cache_->ConvertToJavaString(text);
- jint res = env->CallStaticIntMethod(jni_cache_->integer_class.get(),
- jni_cache_->integer_parse_int,
- text_java.get());
- if (jni_cache_->ExceptionCheckAndClear()) {
- return false;
- }
- *result = res;
- return true;
- }
- return false;
-}
-
-std::unique_ptr<UniLib::RegexPattern> UniLib::CreateRegexPattern(
- const UnicodeText& regex) const {
- return std::unique_ptr<UniLib::RegexPattern>(
- new UniLib::RegexPattern(jni_cache_.get(), regex, /*lazy=*/false));
-}
-
-std::unique_ptr<UniLib::RegexPattern> UniLib::CreateLazyRegexPattern(
- const UnicodeText& regex) const {
- return std::unique_ptr<UniLib::RegexPattern>(
- new UniLib::RegexPattern(jni_cache_.get(), regex, /*lazy=*/true));
-}
-
-UniLib::RegexPattern::RegexPattern(const JniCache* jni_cache,
- const UnicodeText& pattern, bool lazy)
- : jni_cache_(jni_cache),
- pattern_(nullptr, jni_cache ? jni_cache->jvm : nullptr),
- initialized_(false),
- initialization_failure_(false),
- pattern_text_(pattern) {
- if (!lazy) {
- LockedInitializeIfNotAlready();
- }
-}
-
-void UniLib::RegexPattern::LockedInitializeIfNotAlready() const {
- std::lock_guard<std::mutex> guard(mutex_);
- if (initialized_ || initialization_failure_) {
- return;
- }
-
- if (jni_cache_) {
- JNIEnv* jenv = jni_cache_->GetEnv();
- const ScopedLocalRef<jstring> regex_java =
- jni_cache_->ConvertToJavaString(pattern_text_);
- pattern_ = MakeGlobalRef(jenv->CallStaticObjectMethod(
- jni_cache_->pattern_class.get(),
- jni_cache_->pattern_compile, regex_java.get()),
- jenv, jni_cache_->jvm);
-
- if (jni_cache_->ExceptionCheckAndClear() || pattern_ == nullptr) {
- initialization_failure_ = true;
- pattern_.reset();
- return;
- }
-
- initialized_ = true;
- pattern_text_.clear(); // We don't need this anymore.
- }
-}
-
-constexpr int UniLib::RegexMatcher::kError;
-constexpr int UniLib::RegexMatcher::kNoError;
-
-std::unique_ptr<UniLib::RegexMatcher> UniLib::RegexPattern::Matcher(
- const UnicodeText& context) const {
- LockedInitializeIfNotAlready(); // Possibly lazy initialization.
- if (initialization_failure_) {
- return nullptr;
- }
-
- if (jni_cache_) {
- JNIEnv* env = jni_cache_->GetEnv();
- const jstring context_java =
- jni_cache_->ConvertToJavaString(context).release();
- if (!context_java) {
- return nullptr;
- }
- const jobject matcher = env->CallObjectMethod(
- pattern_.get(), jni_cache_->pattern_matcher, context_java);
- if (jni_cache_->ExceptionCheckAndClear() || !matcher) {
- return nullptr;
- }
- return std::unique_ptr<UniLib::RegexMatcher>(new RegexMatcher(
- jni_cache_, MakeGlobalRef(matcher, env, jni_cache_->jvm),
- MakeGlobalRef(context_java, env, jni_cache_->jvm)));
- } else {
- // NOTE: A valid object needs to be created here to pass the interface
- // tests.
- return std::unique_ptr<UniLib::RegexMatcher>(
- new RegexMatcher(jni_cache_, nullptr, nullptr));
- }
-}
-
-UniLib::RegexMatcher::RegexMatcher(const JniCache* jni_cache,
- ScopedGlobalRef<jobject> matcher,
- ScopedGlobalRef<jstring> text)
- : jni_cache_(jni_cache),
- matcher_(std::move(matcher)),
- text_(std::move(text)) {}
-
-bool UniLib::RegexMatcher::Matches(int* status) const {
- if (jni_cache_) {
- *status = kNoError;
- const bool result = jni_cache_->GetEnv()->CallBooleanMethod(
- matcher_.get(), jni_cache_->matcher_matches);
- if (jni_cache_->ExceptionCheckAndClear()) {
- *status = kError;
- return false;
- }
- return result;
- } else {
- *status = kError;
- return false;
- }
-}
-
-bool UniLib::RegexMatcher::ApproximatelyMatches(int* status) {
- *status = kNoError;
-
- jni_cache_->GetEnv()->CallObjectMethod(matcher_.get(),
- jni_cache_->matcher_reset);
- if (jni_cache_->ExceptionCheckAndClear()) {
- *status = kError;
- return kError;
- }
-
- if (!Find(status) || *status != kNoError) {
- return false;
- }
-
- const int found_start = jni_cache_->GetEnv()->CallIntMethod(
- matcher_.get(), jni_cache_->matcher_start_idx, 0);
- if (jni_cache_->ExceptionCheckAndClear()) {
- *status = kError;
- return kError;
- }
-
- const int found_end = jni_cache_->GetEnv()->CallIntMethod(
- matcher_.get(), jni_cache_->matcher_end_idx, 0);
- if (jni_cache_->ExceptionCheckAndClear()) {
- *status = kError;
- return kError;
- }
-
- int context_length_bmp = jni_cache_->GetEnv()->CallIntMethod(
- text_.get(), jni_cache_->string_length);
- if (jni_cache_->ExceptionCheckAndClear()) {
- *status = kError;
- return false;
- }
-
- if (found_start != 0 || found_end != context_length_bmp) {
- return false;
- }
-
- return true;
-}
-
-bool UniLib::RegexMatcher::UpdateLastFindOffset() const {
- if (!last_find_offset_dirty_) {
- return true;
- }
-
- const int find_offset = jni_cache_->GetEnv()->CallIntMethod(
- matcher_.get(), jni_cache_->matcher_start_idx, 0);
- if (jni_cache_->ExceptionCheckAndClear()) {
- return false;
- }
-
- const int codepoint_count = jni_cache_->GetEnv()->CallIntMethod(
- text_.get(), jni_cache_->string_code_point_count, last_find_offset_,
- find_offset);
- if (jni_cache_->ExceptionCheckAndClear()) {
- return false;
- }
-
- last_find_offset_codepoints_ += codepoint_count;
- last_find_offset_ = find_offset;
- last_find_offset_dirty_ = false;
-
- return true;
-}
-
-bool UniLib::RegexMatcher::Find(int* status) {
- if (jni_cache_) {
- const bool result = jni_cache_->GetEnv()->CallBooleanMethod(
- matcher_.get(), jni_cache_->matcher_find);
- if (jni_cache_->ExceptionCheckAndClear()) {
- *status = kError;
- return false;
- }
-
- last_find_offset_dirty_ = true;
- *status = kNoError;
- return result;
- } else {
- *status = kError;
- return false;
- }
-}
-
-int UniLib::RegexMatcher::Start(int* status) const {
- return Start(/*group_idx=*/0, status);
-}
-
-int UniLib::RegexMatcher::Start(int group_idx, int* status) const {
- if (jni_cache_) {
- *status = kNoError;
-
- if (!UpdateLastFindOffset()) {
- *status = kError;
- return kError;
- }
-
- const int java_index = jni_cache_->GetEnv()->CallIntMethod(
- matcher_.get(), jni_cache_->matcher_start_idx, group_idx);
- if (jni_cache_->ExceptionCheckAndClear()) {
- *status = kError;
- return kError;
- }
-
- // If the group didn't participate in the match the index is -1.
- if (java_index == -1) {
- return -1;
- }
-
- const int unicode_index = jni_cache_->GetEnv()->CallIntMethod(
- text_.get(), jni_cache_->string_code_point_count, last_find_offset_,
- java_index);
- if (jni_cache_->ExceptionCheckAndClear()) {
- *status = kError;
- return kError;
- }
-
- return unicode_index + last_find_offset_codepoints_;
- } else {
- *status = kError;
- return kError;
- }
-}
-
-int UniLib::RegexMatcher::End(int* status) const {
- return End(/*group_idx=*/0, status);
-}
-
-int UniLib::RegexMatcher::End(int group_idx, int* status) const {
- if (jni_cache_) {
- *status = kNoError;
-
- if (!UpdateLastFindOffset()) {
- *status = kError;
- return kError;
- }
-
- const int java_index = jni_cache_->GetEnv()->CallIntMethod(
- matcher_.get(), jni_cache_->matcher_end_idx, group_idx);
- if (jni_cache_->ExceptionCheckAndClear()) {
- *status = kError;
- return kError;
- }
-
- // If the group didn't participate in the match the index is -1.
- if (java_index == -1) {
- return -1;
- }
-
- const int unicode_index = jni_cache_->GetEnv()->CallIntMethod(
- text_.get(), jni_cache_->string_code_point_count, last_find_offset_,
- java_index);
- if (jni_cache_->ExceptionCheckAndClear()) {
- *status = kError;
- return kError;
- }
-
- return unicode_index + last_find_offset_codepoints_;
- } else {
- *status = kError;
- return kError;
- }
-}
-
-UnicodeText UniLib::RegexMatcher::Group(int* status) const {
- if (jni_cache_) {
- JNIEnv* jenv = jni_cache_->GetEnv();
- const ScopedLocalRef<jstring> java_result(
- reinterpret_cast<jstring>(
- jenv->CallObjectMethod(matcher_.get(), jni_cache_->matcher_group)),
- jenv);
- if (jni_cache_->ExceptionCheckAndClear() || !java_result) {
- *status = kError;
- return UTF8ToUnicodeText("", /*do_copy=*/false);
- }
-
- std::string result;
- if (!JStringToUtf8String(jenv, java_result.get(), &result)) {
- *status = kError;
- return UTF8ToUnicodeText("", /*do_copy=*/false);
- }
- *status = kNoError;
- return UTF8ToUnicodeText(result, /*do_copy=*/true);
- } else {
- *status = kError;
- return UTF8ToUnicodeText("", /*do_copy=*/false);
- }
-}
-
-UnicodeText UniLib::RegexMatcher::Group(int group_idx, int* status) const {
- if (jni_cache_) {
- JNIEnv* jenv = jni_cache_->GetEnv();
- const ScopedLocalRef<jstring> java_result(
- reinterpret_cast<jstring>(jenv->CallObjectMethod(
- matcher_.get(), jni_cache_->matcher_group_idx, group_idx)),
- jenv);
- if (jni_cache_->ExceptionCheckAndClear()) {
- *status = kError;
- TC3_LOG(ERROR) << "Exception occurred";
- return UTF8ToUnicodeText("", /*do_copy=*/false);
- }
-
- // java_result is nullptr when the group did not participate in the match.
- // For these cases other UniLib implementations return empty string, and
- // the participation can be checked by checking if Start() == -1.
- if (!java_result) {
- *status = kNoError;
- return UTF8ToUnicodeText("", /*do_copy=*/false);
- }
-
- std::string result;
- if (!JStringToUtf8String(jenv, java_result.get(), &result)) {
- *status = kError;
- return UTF8ToUnicodeText("", /*do_copy=*/false);
- }
- *status = kNoError;
- return UTF8ToUnicodeText(result, /*do_copy=*/true);
- } else {
- *status = kError;
- return UTF8ToUnicodeText("", /*do_copy=*/false);
- }
-}
-
-constexpr int UniLib::BreakIterator::kDone;
-
-UniLib::BreakIterator::BreakIterator(const JniCache* jni_cache,
- const UnicodeText& text)
- : jni_cache_(jni_cache),
- text_(nullptr, jni_cache ? jni_cache->jvm : nullptr),
- iterator_(nullptr, jni_cache ? jni_cache->jvm : nullptr),
- last_break_index_(0),
- last_unicode_index_(0) {
- if (jni_cache_) {
- JNIEnv* jenv = jni_cache_->GetEnv();
- text_ = MakeGlobalRef(jni_cache_->ConvertToJavaString(text).release(), jenv,
- jni_cache->jvm);
- if (!text_) {
- return;
- }
-
- iterator_ = MakeGlobalRef(
- jenv->CallStaticObjectMethod(jni_cache->breakiterator_class.get(),
- jni_cache->breakiterator_getwordinstance,
- jni_cache->locale_us.get()),
- jenv, jni_cache->jvm);
- if (!iterator_) {
- return;
- }
- jenv->CallVoidMethod(iterator_.get(), jni_cache->breakiterator_settext,
- text_.get());
- }
-}
-
-int UniLib::BreakIterator::Next() {
- if (jni_cache_) {
- const int break_index = jni_cache_->GetEnv()->CallIntMethod(
- iterator_.get(), jni_cache_->breakiterator_next);
- if (jni_cache_->ExceptionCheckAndClear() ||
- break_index == BreakIterator::kDone) {
- return BreakIterator::kDone;
- }
-
- const int token_unicode_length = jni_cache_->GetEnv()->CallIntMethod(
- text_.get(), jni_cache_->string_code_point_count, last_break_index_,
- break_index);
- if (jni_cache_->ExceptionCheckAndClear()) {
- return BreakIterator::kDone;
- }
-
- last_break_index_ = break_index;
- return last_unicode_index_ += token_unicode_length;
- }
- return BreakIterator::kDone;
-}
-
-std::unique_ptr<UniLib::BreakIterator> UniLib::CreateBreakIterator(
- const UnicodeText& text) const {
- return std::unique_ptr<UniLib::BreakIterator>(
- new UniLib::BreakIterator(jni_cache_.get(), text));
-}
-
-} // namespace libtextclassifier3
diff --git a/utils/utf8/unilib-javaicu.h b/utils/utf8/unilib-javaicu.h
deleted file mode 100644
index 0a5d339..0000000
--- a/utils/utf8/unilib-javaicu.h
+++ /dev/null
@@ -1,183 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// An implementation of Unilib that uses Android Java interfaces via JNI. The
-// performance critical ops have been re-implemented in C++.
-// Specifically, this class must be compatible with API level 14 (ICS).
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_JAVAICU_H_
-#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_JAVAICU_H_
-
-#include <jni.h>
-#include <memory>
-#include <mutex> // NOLINT
-#include <string>
-
-#include "utils/base/integral_types.h"
-#include "utils/java/jni-cache.h"
-#include "utils/java/scoped_global_ref.h"
-#include "utils/java/scoped_local_ref.h"
-#include "utils/java/string_utils.h"
-#include "utils/utf8/unicodetext.h"
-
-namespace libtextclassifier3 {
-
-class UniLib {
- public:
- UniLib();
- explicit UniLib(const std::shared_ptr<JniCache>& jni_cache);
-
- bool ParseInt32(const UnicodeText& text, int* result) const;
- bool IsOpeningBracket(char32 codepoint) const;
- bool IsClosingBracket(char32 codepoint) const;
- bool IsWhitespace(char32 codepoint) const;
- bool IsDigit(char32 codepoint) const;
- bool IsUpper(char32 codepoint) const;
-
- char32 ToLower(char32 codepoint) const;
- char32 GetPairedBracket(char32 codepoint) const;
-
- // Forward declaration for friend.
- class RegexPattern;
-
- class RegexMatcher {
- public:
- static constexpr int kError = -1;
- static constexpr int kNoError = 0;
-
- // Checks whether the input text matches the pattern exactly.
- bool Matches(int* status) const;
-
- // Approximate Matches() implementation implemented using Find(). It uses
- // the first Find() result and then checks that it spans the whole input.
- // NOTE: Unlike Matches() it can result in false negatives.
- // NOTE: Resets the matcher, so the current Find() state will be lost.
- bool ApproximatelyMatches(int* status);
-
- // Finds occurrences of the pattern in the input text.
- // Can be called repeatedly to find all occurences. A call will update
- // internal state, so that 'Start', 'End' and 'Group' can be called to get
- // information about the match.
- // NOTE: Any call to ApproximatelyMatches() in between Find() calls will
- // modify the state.
- bool Find(int* status);
-
- // Gets the start offset of the last match (from 'Find').
- // Sets status to 'kError' if 'Find'
- // was not called previously.
- int Start(int* status) const;
-
- // Gets the start offset of the specified group of the last match.
- // (from 'Find').
- // Sets status to 'kError' if an invalid group was specified or if 'Find'
- // was not called previously.
- int Start(int group_idx, int* status) const;
-
- // Gets the end offset of the last match (from 'Find').
- // Sets status to 'kError' if 'Find'
- // was not called previously.
- int End(int* status) const;
-
- // Gets the end offset of the specified group of the last match.
- // (from 'Find').
- // Sets status to 'kError' if an invalid group was specified or if 'Find'
- // was not called previously.
- int End(int group_idx, int* status) const;
-
- // Gets the text of the last match (from 'Find').
- // Sets status to 'kError' if 'Find' was not called previously.
- UnicodeText Group(int* status) const;
-
- // Gets the text of the specified group of the last match (from 'Find').
- // Sets status to 'kError' if an invalid group was specified or if 'Find'
- // was not called previously.
- UnicodeText Group(int group_idx, int* status) const;
-
- // Returns the matched text (the 0th capturing group).
- std::string Text() const {
- ScopedStringChars text_str =
- GetScopedStringChars(jni_cache_->GetEnv(), text_.get());
- return text_str.get();
- }
-
- private:
- friend class RegexPattern;
- RegexMatcher(const JniCache* jni_cache, ScopedGlobalRef<jobject> matcher,
- ScopedGlobalRef<jstring> text);
- bool UpdateLastFindOffset() const;
-
- const JniCache* jni_cache_;
- ScopedGlobalRef<jobject> matcher_;
- ScopedGlobalRef<jstring> text_;
- mutable int last_find_offset_ = 0;
- mutable int last_find_offset_codepoints_ = 0;
- mutable bool last_find_offset_dirty_ = true;
- };
-
- class RegexPattern {
- public:
- std::unique_ptr<RegexMatcher> Matcher(const UnicodeText& context) const;
-
- private:
- friend class UniLib;
- RegexPattern(const JniCache* jni_cache, const UnicodeText& pattern,
- bool lazy);
- void LockedInitializeIfNotAlready() const;
-
- const JniCache* jni_cache_;
-
- // These members need to be mutable because of the lazy initialization.
- // NOTE: The Matcher method first ensures (using a lock) that the
- // initialization was attempted (by using LockedInitializeIfNotAlready) and
- // then can access them without locking.
- mutable std::mutex mutex_;
- mutable ScopedGlobalRef<jobject> pattern_;
- mutable bool initialized_;
- mutable bool initialization_failure_;
- mutable UnicodeText pattern_text_;
- };
-
- class BreakIterator {
- public:
- int Next();
-
- static constexpr int kDone = -1;
-
- private:
- friend class UniLib;
- BreakIterator(const JniCache* jni_cache, const UnicodeText& text);
-
- const JniCache* jni_cache_;
- ScopedGlobalRef<jstring> text_;
- ScopedGlobalRef<jobject> iterator_;
- int last_break_index_;
- int last_unicode_index_;
- };
-
- std::unique_ptr<RegexPattern> CreateRegexPattern(
- const UnicodeText& regex) const;
- std::unique_ptr<RegexPattern> CreateLazyRegexPattern(
- const UnicodeText& regex) const;
- std::unique_ptr<BreakIterator> CreateBreakIterator(
- const UnicodeText& text) const;
-
- private:
- std::shared_ptr<JniCache> jni_cache_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_JAVAICU_H_
diff --git a/utils/utf8/unilib.h b/utils/utf8/unilib.h
deleted file mode 100644
index ec1f329..0000000
--- a/utils/utf8/unilib.h
+++ /dev/null
@@ -1,23 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_H_
-#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_H_
-
-#include "utils/utf8/unilib-javaicu.h"
-#define INIT_UNILIB_FOR_TESTING(VAR) VAR(nullptr)
-
-#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_H_
diff --git a/utils/utf8/unilib_test-include.cc b/utils/utf8/unilib_test-include.cc
deleted file mode 100644
index bd53208..0000000
--- a/utils/utf8/unilib_test-include.cc
+++ /dev/null
@@ -1,221 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/utf8/unilib_test-include.h"
-
-#include "gmock/gmock.h"
-
-namespace libtextclassifier3 {
-namespace test_internal {
-
-using ::testing::ElementsAre;
-
-TEST_F(UniLibTest, CharacterClassesAscii) {
- EXPECT_TRUE(unilib_.IsOpeningBracket('('));
- EXPECT_TRUE(unilib_.IsClosingBracket(')'));
- EXPECT_FALSE(unilib_.IsWhitespace(')'));
- EXPECT_TRUE(unilib_.IsWhitespace(' '));
- EXPECT_FALSE(unilib_.IsDigit(')'));
- EXPECT_TRUE(unilib_.IsDigit('0'));
- EXPECT_TRUE(unilib_.IsDigit('9'));
- EXPECT_FALSE(unilib_.IsUpper(')'));
- EXPECT_TRUE(unilib_.IsUpper('A'));
- EXPECT_TRUE(unilib_.IsUpper('Z'));
- EXPECT_EQ(unilib_.ToLower('A'), 'a');
- EXPECT_EQ(unilib_.ToLower('Z'), 'z');
- EXPECT_EQ(unilib_.ToLower(')'), ')');
- EXPECT_EQ(unilib_.GetPairedBracket(')'), '(');
- EXPECT_EQ(unilib_.GetPairedBracket('}'), '{');
-}
-
-TEST_F(UniLibTest, CharacterClassesUnicode) {
- EXPECT_TRUE(unilib_.IsOpeningBracket(0x0F3C)); // TIBET ANG KHANG GYON
- EXPECT_TRUE(unilib_.IsClosingBracket(0x0F3D)); // TIBET ANG KHANG GYAS
- EXPECT_FALSE(unilib_.IsWhitespace(0x23F0)); // ALARM CLOCK
- EXPECT_TRUE(unilib_.IsWhitespace(0x2003)); // EM SPACE
- EXPECT_FALSE(unilib_.IsDigit(0xA619)); // VAI SYMBOL JONG
- EXPECT_TRUE(unilib_.IsDigit(0xA620)); // VAI DIGIT ZERO
- EXPECT_TRUE(unilib_.IsDigit(0xA629)); // VAI DIGIT NINE
- EXPECT_FALSE(unilib_.IsDigit(0xA62A)); // VAI SYLLABLE NDOLE MA
- EXPECT_FALSE(unilib_.IsUpper(0x0211)); // SMALL R WITH DOUBLE GRAVE
- EXPECT_TRUE(unilib_.IsUpper(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
- EXPECT_TRUE(unilib_.IsUpper(0x0391)); // GREEK CAPITAL ALPHA
- EXPECT_TRUE(unilib_.IsUpper(0x03AB)); // GREEK CAPITAL UPSILON W DIAL
- EXPECT_FALSE(unilib_.IsUpper(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
- EXPECT_EQ(unilib_.ToLower(0x0391), 0x03B1); // GREEK ALPHA
- EXPECT_EQ(unilib_.ToLower(0x03AB), 0x03CB); // GREEK UPSILON WITH DIALYTIKA
- EXPECT_EQ(unilib_.ToLower(0x03C0), 0x03C0); // GREEK SMALL PI
-
- EXPECT_EQ(unilib_.GetPairedBracket(0x0F3C), 0x0F3D);
- EXPECT_EQ(unilib_.GetPairedBracket(0x0F3D), 0x0F3C);
-}
-
-TEST_F(UniLibTest, RegexInterface) {
- const UnicodeText regex_pattern =
- UTF8ToUnicodeText("[0-9]+", /*do_copy=*/true);
- std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateRegexPattern(regex_pattern);
- const UnicodeText input = UTF8ToUnicodeText("hello 0123", /*do_copy=*/false);
- int status;
- std::unique_ptr<UniLib::RegexMatcher> matcher = pattern->Matcher(input);
- TC3_LOG(INFO) << matcher->Matches(&status);
- TC3_LOG(INFO) << matcher->Find(&status);
- TC3_LOG(INFO) << matcher->Start(0, &status);
- TC3_LOG(INFO) << matcher->End(0, &status);
- TC3_LOG(INFO) << matcher->Group(0, &status).size_codepoints();
-}
-
-TEST_F(UniLibTest, Regex) {
- // The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to
- // test the regex functionality with it to verify we are handling the indices
- // correctly.
- const UnicodeText regex_pattern =
- UTF8ToUnicodeText("[0-9]+😋", /*do_copy=*/false);
- std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateRegexPattern(regex_pattern);
- int status;
- std::unique_ptr<UniLib::RegexMatcher> matcher;
-
- matcher = pattern->Matcher(UTF8ToUnicodeText("0123😋", /*do_copy=*/false));
- EXPECT_TRUE(matcher->Matches(&status));
- EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_TRUE(matcher->Matches(&status)); // Check that the state is reset.
- EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-
- matcher = pattern->Matcher(
- UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false));
- EXPECT_FALSE(matcher->Matches(&status));
- EXPECT_FALSE(matcher->ApproximatelyMatches(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-
- matcher = pattern->Matcher(
- UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false));
- EXPECT_TRUE(matcher->Find(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Start(0, &status), 8);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->End(0, &status), 13);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123😋");
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-}
-
-TEST_F(UniLibTest, RegexLazy) {
- std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateLazyRegexPattern(
- UTF8ToUnicodeText("[a-z][0-9]", /*do_copy=*/false));
- int status;
- std::unique_ptr<UniLib::RegexMatcher> matcher;
-
- matcher = pattern->Matcher(UTF8ToUnicodeText("a3", /*do_copy=*/false));
- EXPECT_TRUE(matcher->Matches(&status));
- EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_TRUE(matcher->Matches(&status)); // Check that the state is reset.
- EXPECT_TRUE(matcher->ApproximatelyMatches(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-
- matcher = pattern->Matcher(UTF8ToUnicodeText("3a", /*do_copy=*/false));
- EXPECT_FALSE(matcher->Matches(&status));
- EXPECT_FALSE(matcher->ApproximatelyMatches(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-}
-
-TEST_F(UniLibTest, RegexGroups) {
- // The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to
- // test the regex functionality with it to verify we are handling the indices
- // correctly.
- const UnicodeText regex_pattern =
- UTF8ToUnicodeText("([0-9])([0-9]+)😋", /*do_copy=*/false);
- std::unique_ptr<UniLib::RegexPattern> pattern =
- unilib_.CreateRegexPattern(regex_pattern);
- int status;
- std::unique_ptr<UniLib::RegexMatcher> matcher;
-
- matcher = pattern->Matcher(
- UTF8ToUnicodeText("hello😋😋 0123😋 world", /*do_copy=*/false));
- EXPECT_TRUE(matcher->Find(&status));
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Start(0, &status), 8);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Start(1, &status), 8);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Start(2, &status), 9);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->End(0, &status), 13);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->End(1, &status), 9);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->End(2, &status), 12);
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123😋");
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Group(1, &status).ToUTF8String(), "0");
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
- EXPECT_EQ(matcher->Group(2, &status).ToUTF8String(), "123");
- EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
-}
-
-TEST_F(UniLibTest, BreakIterator) {
- const UnicodeText text = UTF8ToUnicodeText("some text", /*do_copy=*/false);
- std::unique_ptr<UniLib::BreakIterator> iterator =
- unilib_.CreateBreakIterator(text);
- std::vector<int> break_indices;
- int break_index = 0;
- while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) {
- break_indices.push_back(break_index);
- }
- EXPECT_THAT(break_indices, ElementsAre(4, 5, 9));
-}
-
-TEST_F(UniLibTest, BreakIterator4ByteUTF8) {
- const UnicodeText text = UTF8ToUnicodeText("😀😂😋", /*do_copy=*/false);
- std::unique_ptr<UniLib::BreakIterator> iterator =
- unilib_.CreateBreakIterator(text);
- std::vector<int> break_indices;
- int break_index = 0;
- while ((break_index = iterator->Next()) != UniLib::BreakIterator::kDone) {
- break_indices.push_back(break_index);
- }
- EXPECT_THAT(break_indices, ElementsAre(1, 2, 3));
-}
-
-TEST_F(UniLibTest, IntegerParse) {
- int result;
- EXPECT_TRUE(
- unilib_.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false), &result));
- EXPECT_EQ(result, 123);
-}
-
-TEST_F(UniLibTest, IntegerParseFullWidth) {
- int result;
- // The input string here is full width
- EXPECT_TRUE(unilib_.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false),
- &result));
- EXPECT_EQ(result, 123);
-}
-
-TEST_F(UniLibTest, IntegerParseFullWidthWithAlpha) {
- int result;
- // The input string here is full width
- EXPECT_FALSE(unilib_.ParseInt32(UTF8ToUnicodeText("1a3", /*do_copy=*/false),
- &result));
-}
-
-} // namespace test_internal
-} // namespace libtextclassifier3
diff --git a/utils/utf8/unilib_test-include.h b/utils/utf8/unilib_test-include.h
deleted file mode 100644
index 151a6f0..0000000
--- a/utils/utf8/unilib_test-include.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
-#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
-
-// Include the version of UniLib depending on the macro.
-#if defined TC3_UNILIB_ICU
-#include "utils/utf8/unilib-icu.h"
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
-#elif defined TC3_UNILIB_JAVAICU
-#include <jni.h>
-extern JNIEnv* g_jenv;
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR(JniCache::Create(g_jenv))
-#include "utils/utf8/unilib-javaicu.h"
-#elif defined TC3_UNILIB_DUMMY
-#include "utils/utf8/unilib-dummy.h"
-#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
-#endif
-
-#include "utils/base/logging.h"
-#include "gtest/gtest.h"
-
-namespace libtextclassifier3 {
-namespace test_internal {
-
-class UniLibTest : public ::testing::Test {
- protected:
- UniLibTest() : TC3_TESTING_CREATE_UNILIB_INSTANCE(unilib_) {}
- UniLib unilib_;
-};
-
-} // namespace test_internal
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
diff --git a/utils/variant.h b/utils/variant.h
deleted file mode 100644
index 68bb04b..0000000
--- a/utils/variant.h
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_VARIANT_H_
-#define LIBTEXTCLASSIFIER_UTILS_VARIANT_H_
-
-#include <map>
-#include <string>
-
-#include "utils/base/integral_types.h"
-#include "utils/base/logging.h"
-#include "utils/strings/stringpiece.h"
-
-namespace libtextclassifier3 {
-
-// Represents a type-tagged union of different basic types.
-class Variant {
- public:
- enum Type {
- TYPE_EMPTY = 0,
- TYPE_INT_VALUE = 1,
- TYPE_INT64_VALUE = 2,
- TYPE_FLOAT_VALUE = 3,
- TYPE_DOUBLE_VALUE = 4,
- TYPE_BOOL_VALUE = 5,
- TYPE_STRING_VALUE = 6,
- };
-
- Variant() : type_(TYPE_EMPTY) {}
- explicit Variant(const int value)
- : type_(TYPE_INT_VALUE), int_value_(value) {}
- explicit Variant(const int64 value)
- : type_(TYPE_INT64_VALUE), long_value_(value) {}
- explicit Variant(const float value)
- : type_(TYPE_FLOAT_VALUE), float_value_(value) {}
- explicit Variant(const double value)
- : type_(TYPE_DOUBLE_VALUE), double_value_(value) {}
- explicit Variant(const StringPiece value)
- : type_(TYPE_STRING_VALUE), string_value_(value.ToString()) {}
- explicit Variant(const std::string value)
- : type_(TYPE_STRING_VALUE), string_value_(value) {}
- explicit Variant(const char* value)
- : type_(TYPE_STRING_VALUE), string_value_(value) {}
- explicit Variant(const bool value)
- : type_(TYPE_BOOL_VALUE), bool_value_(value) {}
-
- Variant& operator=(const Variant&) = default;
-
- int IntValue() const {
- TC3_CHECK(HasInt());
- return int_value_;
- }
-
- int64 Int64Value() const {
- TC3_CHECK(HasInt64());
- return long_value_;
- }
-
- float FloatValue() const {
- TC3_CHECK(HasFloat());
- return float_value_;
- }
-
- double DoubleValue() const {
- TC3_CHECK(HasDouble());
- return double_value_;
- }
-
- bool BoolValue() const {
- TC3_CHECK(HasBool());
- return bool_value_;
- }
-
- const std::string& StringValue() const {
- TC3_CHECK(HasString());
- return string_value_;
- }
-
- bool HasInt() const { return type_ == TYPE_INT_VALUE; }
-
- bool HasInt64() const { return type_ == TYPE_INT64_VALUE; }
-
- bool HasFloat() const { return type_ == TYPE_FLOAT_VALUE; }
-
- bool HasDouble() const { return type_ == TYPE_DOUBLE_VALUE; }
-
- bool HasBool() const { return type_ == TYPE_BOOL_VALUE; }
-
- bool HasString() const { return type_ == TYPE_STRING_VALUE; }
-
- Type GetType() const { return type_; }
-
- bool HasValue() const { return type_ != TYPE_EMPTY; }
-
- private:
- Type type_;
- union {
- int int_value_;
- int64 long_value_;
- float float_value_;
- double double_value_;
- bool bool_value_;
- };
- std::string string_value_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_VARIANT_H_
diff --git a/utils/zlib/zlib_regex.cc b/utils/zlib/zlib_regex.cc
deleted file mode 100644
index bfe3f5b..0000000
--- a/utils/zlib/zlib_regex.cc
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "utils/zlib/zlib_regex.h"
-
-#include <memory>
-
-#include "utils/base/logging.h"
-#include "utils/flatbuffers.h"
-
-namespace libtextclassifier3 {
-
-std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern(
- const UniLib& unilib, const flatbuffers::String* uncompressed_pattern,
- const CompressedBuffer* compressed_pattern, bool lazy_compile_regex,
- ZlibDecompressor* decompressor, std::string* result_pattern_text) {
- UnicodeText unicode_regex_pattern;
- std::string decompressed_pattern;
- if (compressed_pattern != nullptr &&
- compressed_pattern->buffer() != nullptr) {
- if (decompressor == nullptr ||
- !decompressor->MaybeDecompress(compressed_pattern,
- &decompressed_pattern)) {
- TC3_LOG(ERROR) << "Cannot decompress pattern.";
- return nullptr;
- }
- unicode_regex_pattern =
- UTF8ToUnicodeText(decompressed_pattern.data(),
- decompressed_pattern.size(), /*do_copy=*/false);
- } else {
- if (uncompressed_pattern == nullptr) {
- TC3_LOG(ERROR) << "Cannot load uncompressed pattern.";
- return nullptr;
- }
- unicode_regex_pattern =
- UTF8ToUnicodeText(uncompressed_pattern->c_str(),
- uncompressed_pattern->Length(), /*do_copy=*/false);
- }
-
- if (result_pattern_text != nullptr) {
- *result_pattern_text = unicode_regex_pattern.ToUTF8String();
- }
-
- std::unique_ptr<UniLib::RegexPattern> regex_pattern;
- if (lazy_compile_regex) {
- regex_pattern = unilib.CreateLazyRegexPattern(unicode_regex_pattern);
- } else {
- regex_pattern = unilib.CreateRegexPattern(unicode_regex_pattern);
- }
-
- if (!regex_pattern) {
- TC3_LOG(ERROR) << "Could not create pattern: "
- << unicode_regex_pattern.ToUTF8String();
- }
- return regex_pattern;
-}
-
-} // namespace libtextclassifier3