Export libtextclassifier to Android (generated by the export script)
Test: Compile and boot
Change-Id: I0433e6fb549ba0b32bc55933b3c11562e61a0b4d
diff --git a/actions/actions-suggestions.cc b/actions/actions-suggestions.cc
index f79b58e..eacf991 100644
--- a/actions/actions-suggestions.cc
+++ b/actions/actions-suggestions.cc
@@ -22,45 +22,14 @@
namespace {
const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
- const ActionsModel* model = GetActionsModel(addr);
flatbuffers::Verifier verifier(addr, size);
- if (model->Verify(verifier)) {
- return model;
+ if (VerifyActionsModelBuffer(verifier)) {
+ return GetActionsModel(addr);
} else {
return nullptr;
}
}
-// Indices of the TensorFlow Lite model inputs.
-enum SmartReplyModelInputs {
- SMART_REPLY_MODEL_INPUT_USER_ID = 0,
- SMART_REPLY_MODEL_INPUT_CONTEXT = 1,
- SMART_REPLY_MODEL_INPUT_CONTEXT_LENGTH = 2,
- SMART_REPLY_MODEL_INPUT_TIME_DIFFS = 3,
- SMART_REPLY_MODEL_INPUT_NUM_SUGGESTIONS = 4
-};
-
-// Indices of the TensorFlow Lite model outputs.
-enum SmartReplyModelOutputs {
- SMART_REPLY_MODEL_OUTPUT_REPLIES = 0,
- SMART_REPLY_MODEL_OUTPUT_SCORES = 1,
- SMART_REPLY_MODEL_OUTPUT_EMBEDDINGS = 2,
- SMART_REPLY_MODEL_OUTPUT_SENSITIVE_TOPIC_SCORE = 3,
- SMART_REPLY_MODEL_OUTPUT_TRIGGERING_SCORE = 4,
-};
-
-// Indices of the TensorFlow Lite actions suggestion model inputs.
-enum ActionsSuggestionsModelInputs {
- ACTIONS_SUGGESTIONS_MODEL_INPUT_EMBEDDINGS = 0,
-};
-
-// Indices of the TensorFlow Lite actions suggestion model outputss.
-enum ActionsSuggestionsModelOutputs {
- ACTIONS_SUGGESTIONS_MODEL_OUTPUT_SCORES = 0,
-};
-
-const char* kOtherCategory = "other";
-
} // namespace
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
@@ -127,149 +96,189 @@
return false;
}
- smart_reply_executor_ = TfLiteModelExecutor::FromBuffer(
- model_->smart_reply_model()->tflite_model());
- if (!smart_reply_executor_) {
- TC3_LOG(ERROR) << "Could not initialize smart reply model executor.";
- return false;
- }
-
- actions_suggestions_executor_ = TfLiteModelExecutor::FromBuffer(
- model_->actions_suggestions_model()->tflite_model());
- if (!actions_suggestions_executor_) {
- TC3_LOG(ERROR)
- << "Could not initialize actions suggestions model executor.";
- return false;
+ if (model_->tflite_model_spec()) {
+ model_executor_ = TfLiteModelExecutor::FromBuffer(
+ model_->tflite_model_spec()->tflite_model());
+ if (!model_executor_) {
+ TC3_LOG(ERROR) << "Could not initialize model executor.";
+ return false;
+ }
}
return true;
}
-void ActionsSuggestions::SetupSmartReplyModelInput(
+void ActionsSuggestions::SetupModelInput(
const std::vector<std::string>& context, const std::vector<int>& user_ids,
- const int num_suggestions, tflite::Interpreter* interpreter) {
- smart_reply_executor_->SetInput<std::string>(SMART_REPLY_MODEL_INPUT_CONTEXT,
- context, interpreter);
- *interpreter
- ->tensor(interpreter->inputs()[SMART_REPLY_MODEL_INPUT_CONTEXT_LENGTH])
- ->data.i64 = context.size();
+ const int num_suggestions, tflite::Interpreter* interpreter) const {
+ if (model_->tflite_model_spec()->input_context() >= 0) {
+ model_executor_->SetInput<std::string>(
+ model_->tflite_model_spec()->input_context(), context, interpreter);
+ }
+ if (model_->tflite_model_spec()->input_context_length() >= 0) {
+ *interpreter
+ ->tensor(interpreter->inputs()[model_->tflite_model_spec()
+ ->input_context_length()])
+ ->data.i64 = context.size();
+ }
- smart_reply_executor_->SetInput<int>(SMART_REPLY_MODEL_INPUT_USER_ID,
- user_ids, interpreter);
+ if (model_->tflite_model_spec()->input_user_id() >= 0) {
+ model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
+ user_ids, interpreter);
+ }
- *interpreter
- ->tensor(interpreter->inputs()[SMART_REPLY_MODEL_INPUT_NUM_SUGGESTIONS])
- ->data.i64 = num_suggestions;
+ if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
+ *interpreter
+ ->tensor(interpreter->inputs()[model_->tflite_model_spec()
+ ->input_num_suggestions()])
+ ->data.i64 = num_suggestions;
+ }
}
-void ActionsSuggestions::ReadSmartReplyModelOutput(
+bool ActionsSuggestions::ShouldSuppressPredictions(
+ tflite::Interpreter* interpreter) const {
+ const TensorView<float>& triggering_score =
+ model_executor_->OutputView<float>(
+ model_->tflite_model_spec()->output_triggering_score(), interpreter);
+ if (!triggering_score.is_valid() || triggering_score.dim(0) != 1) {
+ TC3_LOG(ERROR) << "Could not compute triggering score.";
+ return true;
+ }
+ if (triggering_score.data()[0] <= model_->min_triggering_confidence()) {
+ return true;
+ }
+
+ const TensorView<float>& sensitive_topic_score =
+ model_executor_->OutputView<float>(
+ model_->tflite_model_spec()->output_sensitive_topic_score(),
+ interpreter);
+ if (!sensitive_topic_score.is_valid() || sensitive_topic_score.dim(0) != 1) {
+ TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
+ return true;
+ }
+ if (sensitive_topic_score.data()[0] > model_->max_sensitive_topic_score()) {
+ return true;
+ }
+ return false;
+}
+
+void ActionsSuggestions::ReadModelOutput(
tflite::Interpreter* interpreter,
- std::vector<ActionSuggestion>* suggestions) {
+ std::vector<ActionSuggestion>* suggestions) const {
+ // Read smart reply predictions.
const std::vector<tflite::StringRef> replies =
- smart_reply_executor_->Output<tflite::StringRef>(
- SMART_REPLY_MODEL_OUTPUT_REPLIES, interpreter);
- TensorView<float> scores = smart_reply_executor_->OutputView<float>(
- SMART_REPLY_MODEL_OUTPUT_SCORES, interpreter);
+ model_executor_->Output<tflite::StringRef>(
+ model_->tflite_model_spec()->output_replies(), interpreter);
+ TensorView<float> scores = model_executor_->OutputView<float>(
+ model_->tflite_model_spec()->output_replies_scores(), interpreter);
std::vector<ActionSuggestion> text_replies;
for (int i = 0; i < replies.size(); i++) {
suggestions->push_back({std::string(replies[i].str, replies[i].len),
- model_->smart_reply_model()->action_type()->str(),
+ model_->smart_reply_action_type()->str(),
scores.data()[i]});
}
+
+ // Read actions suggestions.
+ const TensorView<float> actions_scores = model_executor_->OutputView<float>(
+ model_->tflite_model_spec()->output_actions_scores(), interpreter);
+ for (int i = 0; i < model_->action_type()->Length(); i++) {
+    // Skip disabled action classes, such as the default "other" category.
+ if (!(*model_->action_type())[i]->enabled()) {
+ continue;
+ }
+ const float score = actions_scores.data()[i];
+ if (score < (*model_->action_type())[i]->min_triggering_score()) {
+ continue;
+ }
+ const std::string& output_class =
+ (*model_->action_type())[i]->name()->str();
+ if (score >= model_->min_actions_confidence()) {
+ suggestions->push_back({/*response_text=*/"", output_class, score});
+ }
+ }
}
-void ActionsSuggestions::SuggestActionsFromConversationEmbedding(
- const TensorView<float>& conversation_embedding,
- const ActionSuggestionOptions& options,
- std::vector<ActionSuggestion>* actions) {
- std::unique_ptr<tflite::Interpreter> actions_suggestions_interpreter =
- actions_suggestions_executor_->CreateInterpreter();
- if (!actions_suggestions_interpreter) {
+void ActionsSuggestions::SuggestActionsFromModel(
+ const Conversation& conversation,
+ std::vector<ActionSuggestion>* suggestions) const {
+ if (!model_executor_) {
+ return;
+ }
+ std::unique_ptr<tflite::Interpreter> interpreter =
+ model_executor_->CreateInterpreter();
+
+ if (!interpreter) {
TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
- "action suggestions model.";
+ "actions suggestions model.";
return;
}
- const int embedding_size = conversation_embedding.shape().back();
- actions_suggestions_interpreter->ResizeInputTensor(
- ACTIONS_SUGGESTIONS_MODEL_INPUT_EMBEDDINGS, {1, embedding_size});
- if (actions_suggestions_interpreter->AllocateTensors() != kTfLiteOk) {
- TC3_LOG(ERROR) << "Failed to allocate TensorFlow Lite tensors for the "
- "action suggestions model.";
+ if (interpreter->AllocateTensors() != kTfLiteOk) {
+ TC3_LOG(ERROR)
+ << "Failed to allocate TensorFlow Lite tensors for the actions "
+ "suggestions model.";
return;
}
- actions_suggestions_executor_->SetInput(
- ACTIONS_SUGGESTIONS_MODEL_INPUT_EMBEDDINGS, conversation_embedding,
- actions_suggestions_interpreter.get());
+ // Use only last message for now.
+ SetupModelInput({conversation.messages.back().text},
+ {conversation.messages.back().user_id},
+ /*num_suggestions=*/model_->num_smart_replies(),
+ interpreter.get());
- if (actions_suggestions_interpreter->Invoke() != kTfLiteOk) {
+ if (interpreter->Invoke() != kTfLiteOk) {
TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
return;
}
- const TensorView<float> output =
- actions_suggestions_executor_->OutputView<float>(
- ACTIONS_SUGGESTIONS_MODEL_OUTPUT_SCORES,
- actions_suggestions_interpreter.get());
- for (int i = 0; i < model_->actions_suggestions_model()->classes()->Length();
- i++) {
- const std::string& output_class =
- (*model_->actions_suggestions_model()->classes())[i]->str();
- if (output_class == kOtherCategory) {
+ if (ShouldSuppressPredictions(interpreter.get())) {
+ return;
+ }
+
+ ReadModelOutput(interpreter.get(), suggestions);
+}
+
+void ActionsSuggestions::SuggestActionsFromAnnotations(
+ const Conversation& conversation,
+ std::vector<ActionSuggestion>* suggestions) const {
+ // Create actions based on the annotations present in the last message.
+ // TODO(smillius): Make this configurable.
+ for (const AnnotatedSpan& annotation :
+ conversation.messages.back().annotations) {
+ if (annotation.classification.empty() ||
+ annotation.classification[0].collection.empty()) {
continue;
}
- const float score = output.data()[i];
- if (score >= model_->actions_suggestions_model()->min_confidence()) {
- actions->push_back({/*response_text=*/"", output_class, score});
- }
+ const ClassificationResult& classification_result =
+ annotation.classification[0];
+ suggestions->push_back({/*response_text=*/"",
+ /*type=*/classification_result.collection,
+ /*score=*/classification_result.score});
}
}
std::vector<ActionSuggestion> ActionsSuggestions::SuggestActions(
- const Conversation& conversation, const ActionSuggestionOptions& options) {
+ const Conversation& conversation,
+ const ActionSuggestionOptions& options) const {
std::vector<ActionSuggestion> suggestions;
if (conversation.messages.empty()) {
return suggestions;
}
- std::unique_ptr<tflite::Interpreter> smart_reply_interpreter =
- smart_reply_executor_->CreateInterpreter();
+ SuggestActionsFromModel(conversation, &suggestions);
+ SuggestActionsFromAnnotations(conversation, &suggestions);
- if (!smart_reply_interpreter) {
- TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
- "smart reply model.";
- return suggestions;
- }
-
- if (smart_reply_interpreter->AllocateTensors() != kTfLiteOk) {
- TC3_LOG(ERROR)
- << "Failed to allocate TensorFlow Lite tensors for the smart "
- "reply model.";
- return suggestions;
- }
-
- // Use only last message for now.
- SetupSmartReplyModelInput({conversation.messages.back().text},
- {conversation.messages.back().user_id},
- /*num_suggestions=*/3,
- smart_reply_interpreter.get());
-
- if (smart_reply_interpreter->Invoke() != kTfLiteOk) {
- TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
- return suggestions;
- }
-
- ReadSmartReplyModelOutput(smart_reply_interpreter.get(), &suggestions);
-
- // Add action predictions.
- const TensorView<float> conversation_embedding =
- smart_reply_executor_->OutputView<float>(
- SMART_REPLY_MODEL_OUTPUT_EMBEDDINGS, smart_reply_interpreter.get());
- SuggestActionsFromConversationEmbedding(conversation_embedding, options,
- &suggestions);
+ // TODO(smillius): Properly rank the actions.
return suggestions;
}
+const ActionsModel* ViewActionsModel(const void* buffer, int size) {
+ if (buffer == nullptr) {
+ return nullptr;
+ }
+
+ return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
+}
+
} // namespace libtextclassifier3
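
For illustration, a minimal sketch of the new ViewActionsModel entry point,
which lets callers read model metadata without constructing a full
ActionsSuggestions instance (buffer and size are assumed to hold a serialized
ActionsModel; the JNI helpers added below use exactly this pattern):

    // Assumes: const void* buffer and int size describe a serialized ActionsModel.
    const libtextclassifier3::ActionsModel* model =
        libtextclassifier3::ViewActionsModel(buffer, size);
    if (model != nullptr && model->locales() != nullptr) {
      const std::string locales = model->locales()->str();  // e.g. "en,fr".
    }
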
diff --git a/actions/actions-suggestions.h b/actions/actions-suggestions.h
index 531d131..187ff9e 100644
--- a/actions/actions-suggestions.h
+++ b/actions/actions-suggestions.h
@@ -22,6 +22,7 @@
#include <vector>
#include "actions/actions_model_generated.h"
+#include "annotator/types.h"
#include "utils/memory/mmap.h"
#include "utils/tflite-model-executor.h"
@@ -43,6 +44,8 @@
int user_id;
// Text of the message.
std::string text;
+ // Annotations on the text.
+ std::vector<AnnotatedSpan> annotations;
};
// Conversation between multiple users.
@@ -69,33 +72,41 @@
std::vector<ActionSuggestion> SuggestActions(
const Conversation& conversation,
- const ActionSuggestionOptions& options = ActionSuggestionOptions());
+ const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
private:
// Checks that model contains all required fields, and initializes internal
// datastructures.
bool ValidateAndInitialize();
- void SetupSmartReplyModelInput(const std::vector<std::string>& context,
- const std::vector<int>& user_ids,
- const int num_suggestions,
- tflite::Interpreter* interpreter);
- void ReadSmartReplyModelOutput(tflite::Interpreter* interpreter,
- std::vector<ActionSuggestion>* suggestions);
+ void SetupModelInput(const std::vector<std::string>& context,
+ const std::vector<int>& user_ids,
+ const int num_suggestions,
+ tflite::Interpreter* interpreter) const;
+ void ReadModelOutput(tflite::Interpreter* interpreter,
+ std::vector<ActionSuggestion>* suggestions) const;
- void SuggestActionsFromConversationEmbedding(
- const TensorView<float>& conversation_embedding,
- const ActionSuggestionOptions& options,
- std::vector<ActionSuggestion>* actions);
+ void SuggestActionsFromModel(
+ const Conversation& conversation,
+ std::vector<ActionSuggestion>* suggestions) const;
+
+ void SuggestActionsFromAnnotations(
+ const Conversation& conversation,
+ std::vector<ActionSuggestion>* suggestions) const;
+
+  // Checks whether all predictions should be suppressed.
+ bool ShouldSuppressPredictions(tflite::Interpreter* interpreter) const;
const ActionsModel* model_;
std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
// Tensorflow Lite models.
- std::unique_ptr<const TfLiteModelExecutor> smart_reply_executor_;
- std::unique_ptr<const TfLiteModelExecutor> actions_suggestions_executor_;
+ std::unique_ptr<const TfLiteModelExecutor> model_executor_;
};
+// Interprets the buffer as an ActionsModel flatbuffer and returns it for reading.
+const ActionsModel* ViewActionsModel(const void* buffer, int size);
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
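
A short usage sketch of the extended message struct: annotations attached to a
ConversationMessage now flow into the output via SuggestActionsFromAnnotations
(the actions_suggestions instance is assumed to be already loaded; the values
are illustrative and mirror the test below):

    libtextclassifier3::AnnotatedSpan annotation;
    annotation.span = {11, 15};
    annotation.classification = {
        libtextclassifier3::ClassificationResult("address", 1.0)};
    const std::vector<libtextclassifier3::ActionSuggestion> actions =
        actions_suggestions->SuggestActions(
            {{{/*user_id=*/1, "are you at home?", /*annotations=*/{annotation}}}});
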
diff --git a/actions/actions-suggestions_test.cc b/actions/actions-suggestions_test.cc
index 37f82f2..5696b25 100644
--- a/actions/actions-suggestions_test.cc
+++ b/actions/actions-suggestions_test.cc
@@ -16,21 +16,32 @@
#include "actions/actions-suggestions.h"
+#include <fstream>
+#include <iterator>
#include <memory>
+#include "actions/actions_model_generated.h"
+#include "annotator/types.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "flatbuffers/flatbuffers.h"
namespace libtextclassifier3 {
namespace {
+constexpr char kModelFileName[] = "actions_suggestions_test.model";
+
+std::string ReadFile(const std::string& file_name) {
+ std::ifstream file_stream(file_name);
+ return std::string(std::istreambuf_iterator<char>(file_stream), {});
+}
+
std::string GetModelPath() {
return "";
}
std::unique_ptr<ActionsSuggestions> LoadTestModel() {
- return ActionsSuggestions::FromPath(GetModelPath() +
- "actions_suggestions_test.model");
+ return ActionsSuggestions::FromPath(GetModelPath() + kModelFileName);
}
TEST(ActionsSuggestionsTest, InstantiateActionSuggestions) {
@@ -39,10 +50,57 @@
TEST(ActionsSuggestionsTest, SuggestActions) {
std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
- const auto actions = actions_suggestions->SuggestActions(
- {{{/*user_id=*/1, "where are you?"}}});
+ const std::vector<ActionSuggestion>& actions =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "where are you?"}}});
EXPECT_EQ(actions.size(), 6);
}
+TEST(ActionsSuggestionsTest, SuggestActionsFromAnnotations) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {ClassificationResult("address", 1.0)};
+ const std::vector<ActionSuggestion>& actions =
+ actions_suggestions->SuggestActions({{{/*user_id=*/1, "are you at home?",
+ /*annotations=*/{annotation}}}});
+ EXPECT_EQ(actions.size(), 7);
+ EXPECT_EQ(actions.back().type, "address");
+ EXPECT_EQ(actions.back().score, 1.0);
+}
+
+void TestSuggestActionsWithThreshold(
+ const std::function<void(ActionsModelT*)>& set_value_fn) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ set_value_fn(actions_model.get());
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize());
+ ASSERT_TRUE(actions_suggestions);
+ const std::vector<ActionSuggestion>& actions =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "where are you?"}}});
+ EXPECT_THAT(actions, testing::IsEmpty());
+}
+
+TEST(ActionsSuggestionsTest, SuggestActionsWithTriggeringScore) {
+ TestSuggestActionsWithThreshold([](ActionsModelT* actions_model) {
+ actions_model->min_triggering_confidence = 1.0;
+ });
+}
+
+TEST(ActionsSuggestionsTest, SuggestActionsWithSensitiveTopicScore) {
+ TestSuggestActionsWithThreshold([](ActionsModelT* actions_model) {
+ actions_model->max_sensitive_topic_score = 0.0;
+ });
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/actions/actions_jni.cc b/actions/actions_jni.cc
index 920a3ea..a789883 100644
--- a/actions/actions_jni.cc
+++ b/actions/actions_jni.cc
@@ -25,6 +25,7 @@
#include "actions/actions-suggestions.h"
#include "utils/base/integral_types.h"
#include "utils/java/scoped_local_ref.h"
+#include "utils/memory/mmap.h"
using libtextclassifier3::ActionsSuggestions;
using libtextclassifier3::ActionSuggestion;
@@ -127,6 +128,42 @@
conversation.messages = messages;
return conversation;
}
+
+jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+ if (!mmap->handle().ok()) {
+ return env->NewStringUTF("");
+ }
+ const ActionsModel* model = libtextclassifier3::ViewActionsModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model || !model->locales()) {
+ return env->NewStringUTF("");
+ }
+ return env->NewStringUTF(model->locales()->c_str());
+}
+
+jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+ if (!mmap->handle().ok()) {
+ return 0;
+ }
+ const ActionsModel* model = libtextclassifier3::ViewActionsModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model) {
+ return 0;
+ }
+ return model->version();
+}
+
+jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+ if (!mmap->handle().ok()) {
+ return env->NewStringUTF("");
+ }
+ const ActionsModel* model = libtextclassifier3::ViewActionsModel(
+ mmap->handle().start(), mmap->handle().num_bytes());
+ if (!model || !model->name()) {
+ return env->NewStringUTF("");
+ }
+ return env->NewStringUTF(model->name()->c_str());
+}
} // namespace
} // namespace libtextclassifier3
@@ -156,7 +193,7 @@
}
TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
-(JNIEnv* env, jobject thiz, jlong ptr, jobject jconversation,
+(JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation,
jobject joptions) {
if (!ptr) {
return nullptr;
@@ -172,7 +209,28 @@
}
TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
-(JNIEnv* env, jobject thiz, jlong ptr) {
+(JNIEnv* env, jobject clazz, jlong ptr) {
ActionsSuggestions* model = reinterpret_cast<ActionsSuggestions*>(ptr);
delete model;
}
+
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
+(JNIEnv* env, jobject clazz, jint fd) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd));
+ return libtextclassifier3::GetLocalesFromMmap(env, mmap.get());
+}
+
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
+(JNIEnv* env, jobject clazz, jint fd) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd));
+ return libtextclassifier3::GetNameFromMmap(env, mmap.get());
+}
+
+TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd) {
+ const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
+ new libtextclassifier3::ScopedMmap(fd));
+ return libtextclassifier3::GetVersionFromMmap(env, mmap.get());
+}
diff --git a/actions/actions_jni.h b/actions/actions_jni.h
index 8c7fcb5..48d50db 100644
--- a/actions/actions_jni.h
+++ b/actions/actions_jni.h
@@ -47,6 +47,15 @@
TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
(JNIEnv* env, jobject thiz, jlong ptr);
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
+(JNIEnv* env, jobject clazz, jint fd);
+
+TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
+(JNIEnv* env, jobject clazz, jint fd);
+
#ifdef __cplusplus
}
#endif
diff --git a/actions/actions_model.fbs b/actions/actions_model.fbs
index a4dc0a1..f0e8b3d 100755
--- a/actions/actions_model.fbs
+++ b/actions/actions_model.fbs
@@ -16,28 +16,41 @@
file_identifier "TC3A";
+// Options to specify triggering behaviour per action class.
namespace libtextclassifier3;
-table ActionsSuggestionsModel {
+table ActionTypeOptions {
+ // The name of the predicted action.
+ name:string;
+
+ // Triggering behaviour.
+ // Whether the action class is considered in the model output or not.
+ enabled:bool = true;
+
+ // Minimal output score threshold.
+ min_triggering_score:float = 0;
+}
+
+// Specification of the TensorFlow Lite model and its input/output tensors.
+namespace libtextclassifier3;
+table TensorflowLiteModelSpec {
// TensorFlow Lite model for suggesting actions.
tflite_model:[ubyte] (force_align: 16);
- // Output classes.
- classes:[string];
+ // Input specification.
+ input_user_id:int = 0;
- // Lower bound threshold for model prediction output.
- min_confidence:float;
-}
+ input_context:int = 1;
+ input_context_length:int = 2;
+ input_time_diffs:int = 3;
+ input_num_suggestions:int = 4;
-namespace libtextclassifier3;
-table SmartReplyModel {
- // TensorFlow Lite model for suggesting smart replies.
- tflite_model:[ubyte] (force_align: 16);
+ // Output specification.
+ output_replies:int = 0;
- // Output type.
- action_type:string;
-
- // Lower bound threshold for model prediction output.
- min_confidence:float;
+ output_replies_scores:int = 1;
+ output_sensitive_topic_score:int = 3;
+ output_triggering_score:int = 4;
+ output_actions_scores:int = 5;
}
namespace libtextclassifier3;
@@ -51,11 +64,23 @@
// A name for the model that can be used e.g. for logging.
name:string;
- // Model for suggesting smart replies.
- smart_reply_model:libtextclassifier3.SmartReplyModel;
+ tflite_model_spec:libtextclassifier3.TensorflowLiteModelSpec;
- // Model for suggesting actions.
- actions_suggestions_model:libtextclassifier3.ActionsSuggestionsModel;
+ // Output classes.
+ smart_reply_action_type:string;
+
+ action_type:[libtextclassifier3.ActionTypeOptions];
+
+ // Lower bound thresholds for model prediction output.
+ min_actions_confidence:float;
+
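+  // Overall triggering threshold; suggestions are suppressed when the
+  // triggering score is at or below it.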
+ min_triggering_confidence:float;
+
+ // Maximum sensitive score for which actions and smart replies are shown.
+ max_sensitive_topic_score:float = 1;
+
+ // Default number of smart reply predictions.
+ num_smart_replies:int = 3;
}
root_type libtextclassifier3.ActionsModel;
diff --git a/actions/test_data/actions_suggestions_test.model b/actions/test_data/actions_suggestions_test.model
index f625a09..893dd84 100644
--- a/actions/test_data/actions_suggestions_test.model
+++ b/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/annotator/annotator.cc b/annotator/annotator.cc
index 3c3f16b..562d58e 100644
--- a/annotator/annotator.cc
+++ b/annotator/annotator.cc
@@ -39,11 +39,9 @@
namespace {
const Model* LoadAndVerifyModel(const void* addr, int size) {
- const Model* model = GetModel(addr);
-
flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
- if (model->Verify(verifier)) {
- return model;
+ if (VerifyModelBuffer(verifier)) {
+ return GetModel(addr);
} else {
return nullptr;
}
diff --git a/annotator/annotator_test.cc b/annotator/annotator_test.cc
index 8598ea4..b6290d5 100644
--- a/annotator/annotator_test.cc
+++ b/annotator/annotator_test.cc
@@ -123,7 +123,7 @@
unpacked_model->triggering_options->enabled_modes = ModeFlag_SELECTION;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -142,7 +142,7 @@
ModeFlag_ANNOTATION_AND_SELECTION;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -172,7 +172,7 @@
"phone");
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -227,7 +227,7 @@
unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -291,7 +291,7 @@
unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -325,7 +325,7 @@
unpacked_model->regex_model->patterns.back()->priority_score = 0.5;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -356,7 +356,7 @@
unpacked_model->regex_model->patterns.back()->priority_score = 1.1;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -393,7 +393,7 @@
verified_pattern->verification_options->verify_luhn_checksum = true;
unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -479,7 +479,7 @@
unpacked_model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -499,7 +499,7 @@
unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -538,7 +538,7 @@
unpacked_model->selection_options->always_classify_suggested_selection = true;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -746,7 +746,7 @@
// Set the batch size.
unpacked_model->selection_options->batch_size = 4;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -778,7 +778,7 @@
unpacked_model->triggering_options->min_annotate_confidence =
2.f; // Discards all results.
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -803,7 +803,7 @@
0.f; // Keeps all results.
unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -823,7 +823,7 @@
// Disable the model for annotation.
unpacked_model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -860,7 +860,7 @@
"phone");
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -905,7 +905,7 @@
/*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -1014,7 +1014,7 @@
ModeFlag_ANNOTATION_AND_CLASSIFICATION;
}
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
@@ -1186,7 +1186,7 @@
unpacked_model->classification_options->max_num_tokens = -1;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize(), &unilib_, &calendarlib_);
@@ -1200,7 +1200,7 @@
unpacked_model->classification_options->max_num_tokens = 3;
flatbuffers::FlatBufferBuilder builder2;
- builder2.Finish(Model::Pack(builder2, unpacked_model.get()));
+ FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder2.GetBufferPointer()),
builder2.GetSize(), &unilib_, &calendarlib_);
@@ -1223,7 +1223,7 @@
unpacked_model->classification_options->address_min_num_tokens = 0;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(Model::Pack(builder, unpacked_model.get()));
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize(), &unilib_, &calendarlib_);
@@ -1237,7 +1237,7 @@
unpacked_model->classification_options->address_min_num_tokens = 5;
flatbuffers::FlatBufferBuilder builder2;
- builder2.Finish(Model::Pack(builder2, unpacked_model.get()));
+ FinishModelBuffer(builder2, Model::Pack(builder2, unpacked_model.get()));
classifier = Annotator::FromUnownedBuffer(
reinterpret_cast<const char*>(builder2.GetBufferPointer()),
builder2.GetSize(), &unilib_, &calendarlib_);
diff --git a/annotator/datetime/parser_test.cc b/annotator/datetime/parser_test.cc
index 6bd6d10..efe7306 100644
--- a/annotator/datetime/parser_test.cc
+++ b/annotator/datetime/parser_test.cc
@@ -124,7 +124,7 @@
{{expected_start_index, expected_end_index},
{expected_ms_utc, expected_granularity},
/*target_classification_score=*/1.0,
- /*priority_score=*/0.0}};
+ /*priority_score=*/0.1}};
const bool matches =
testing::Matches(ElementsAreArray(expected))(filtered_results);
if (!matches) {
diff --git a/annotator/feature-processor.h b/annotator/feature-processor.h
index ce44372..2d04253 100644
--- a/annotator/feature-processor.h
+++ b/annotator/feature-processor.h
@@ -88,10 +88,7 @@
// identical.
typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;
- // If unilib is nullptr, will create and own an instance of a UniLib,
- // otherwise will use what's passed in.
- explicit FeatureProcessor(const FeatureProcessorOptions* options,
- const UniLib* unilib)
+ FeatureProcessor(const FeatureProcessorOptions* options, const UniLib* unilib)
: unilib_(unilib),
feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
*unilib_),
diff --git a/annotator/test_data/test_model.fb b/annotator/test_data/test_model.fb
index ca6d9bf..f25b950 100644
--- a/annotator/test_data/test_model.fb
+++ b/annotator/test_data/test_model.fb
Binary files differ
diff --git a/annotator/test_data/test_model_cc.fb b/annotator/test_data/test_model_cc.fb
index a1b73fe..cfe10cf 100644
--- a/annotator/test_data/test_model_cc.fb
+++ b/annotator/test_data/test_model_cc.fb
Binary files differ
diff --git a/annotator/test_data/wrong_embeddings.fb b/annotator/test_data/wrong_embeddings.fb
index 38b6969..7e990ed 100644
--- a/annotator/test_data/wrong_embeddings.fb
+++ b/annotator/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/annotator/token-feature-extractor.cc b/annotator/token-feature-extractor.cc
index 86ab03a..77ad7a4 100644
--- a/annotator/token-feature-extractor.cc
+++ b/annotator/token-feature-extractor.cc
@@ -58,11 +58,11 @@
remapped->clear();
for (auto it = word.begin(); it != word.end(); ++it) {
if (options.remap_digits && unilib.IsDigit(*it)) {
- remapped->AppendCodepoint('0');
+ remapped->push_back('0');
} else if (options.lowercase_tokens) {
- remapped->AppendCodepoint(unilib.ToLower(*it));
+ remapped->push_back(unilib.ToLower(*it));
} else {
- remapped->AppendCodepoint(*it);
+ remapped->push_back(*it);
}
}
}
@@ -160,7 +160,7 @@
int TokenFeatureExtractor::HashToken(StringPiece token) const {
if (options_.allowed_chargrams.empty()) {
- return tc2farmhash::Fingerprint64(token) % options_.num_buckets;
+ return tc3farmhash::Fingerprint64(token) % options_.num_buckets;
} else {
// Padding and out-of-vocabulary tokens have extra buckets reserved because
// they are special and important tokens, and we don't want them to share
@@ -174,7 +174,7 @@
options_.allowed_chargrams.end()) {
return 0; // Out-of-vocabulary.
} else {
- return (tc2farmhash::Fingerprint64(token) %
+ return (tc3farmhash::Fingerprint64(token) %
(options_.num_buckets - kNumExtraBuckets)) +
kNumExtraBuckets;
}
diff --git a/annotator/zlib-utils.cc b/annotator/zlib-utils.cc
index f1de08a..d0fb0d0 100644
--- a/annotator/zlib-utils.cc
+++ b/annotator/zlib-utils.cc
@@ -156,6 +156,9 @@
bool DecompressBuffer(const CompressedBufferT* compressed_pattern,
ZlibDecompressor* zlib_decompressor,
std::string* uncompressed_pattern) {
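+  // A missing compressed pattern is treated as success with no output.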
+ if (!compressed_pattern) {
+ return true;
+ }
std::string packed_pattern =
PackFlatbuffer<CompressedBuffer>(compressed_pattern);
if (!zlib_decompressor->Decompress(
diff --git a/generate_flatbuffers.mk b/generate_flatbuffers.mk
index 3bad2bc..f256de3 100644
--- a/generate_flatbuffers.mk
+++ b/generate_flatbuffers.mk
@@ -2,6 +2,7 @@
define transform-fbs-to-cpp
@echo "Flatc: $@ <= $(PRIVATE_INPUT_FBS)"
+@rm -f $@
@mkdir -p $(dir $@)
$(hide) $(FLATC) \
--cpp \
diff --git a/java/com/google/android/textclassifier/ActionsSuggestionsModel.java b/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
index 89b9c0c..b3599ea 100644
--- a/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
+++ b/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
@@ -16,6 +16,8 @@
package com.google.android.textclassifier;
+import java.util.concurrent.atomic.AtomicBoolean;
+
/**
* Java wrapper for ActionsSuggestions native library interface. This library is used to suggest
* actions and replies in a given conversation.
@@ -23,12 +25,13 @@
* @hide
*/
public final class ActionsSuggestionsModel implements AutoCloseable {
+ private final AtomicBoolean isClosed = new AtomicBoolean(false);
static {
System.loadLibrary("textclassifier");
}
- private final long actionsModelPtr;
+ private long actionsModelPtr;
/**
* Creates a new instance of Actions predictor, using the provided model image, given as a file
@@ -61,12 +64,38 @@
/** Frees up the allocated memory. */
@Override
public void close() {
- nativeCloseActionsModel(actionsModelPtr);
+ if (isClosed.compareAndSet(false, true)) {
+ nativeCloseActionsModel(actionsModelPtr);
+ actionsModelPtr = 0L;
+ }
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ close();
+ } finally {
+ super.finalize();
+ }
+ }
+
+ /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
+ public static String getLocales(int fd) {
+ return nativeGetLocales(fd);
+ }
+
+ /** Returns the version of the model. */
+ public static int getVersion(int fd) {
+ return nativeGetVersion(fd);
+ }
+
+ /** Returns the name of the model. */
+ public static String getName(int fd) {
+ return nativeGetName(fd);
}
/** Action suggestion that contains a response text and the type of the response. */
public static final class ActionSuggestion {
-
private final String responseText;
private final String actionType;
private final float score;
@@ -131,8 +160,14 @@
private static native long nativeNewActionsModelFromPath(String path);
- private static native ActionSuggestion[] nativeSuggestActions(
+ private static native String nativeGetLocales(int fd);
+
+ private static native int nativeGetVersion(int fd);
+
+ private static native String nativeGetName(int fd);
+
+ private native ActionSuggestion[] nativeSuggestActions(
long context, Conversation conversation, ActionSuggestionOptions options);
- private static native void nativeCloseActionsModel(long context);
+ private native void nativeCloseActionsModel(long context);
}
diff --git a/java/com/google/android/textclassifier/AnnotatorModel.java b/java/com/google/android/textclassifier/AnnotatorModel.java
index ee8dc50..b268a28 100644
--- a/java/com/google/android/textclassifier/AnnotatorModel.java
+++ b/java/com/google/android/textclassifier/AnnotatorModel.java
@@ -16,6 +16,8 @@
package com.google.android.textclassifier;
+import java.util.concurrent.atomic.AtomicBoolean;
+
/**
* Java wrapper for Annotator native library interface. This library is used for detecting entities
* in text.
@@ -23,6 +25,7 @@
* @hide
*/
public final class AnnotatorModel implements AutoCloseable {
+ private final AtomicBoolean isClosed = new AtomicBoolean(false);
static {
System.loadLibrary("textclassifier");
@@ -39,7 +42,7 @@
static final String TYPE_DATE_TIME = "datetime";
static final String TYPE_FLIGHT_NUMBER = "flight";
- private final long annotatorPtr;
+ private long annotatorPtr;
/**
* Creates a new instance of SmartSelect predictor, using the provided model image, given as a
@@ -109,7 +112,19 @@
/** Frees up the allocated memory. */
@Override
public void close() {
- nativeCloseAnnotator(annotatorPtr);
+ if (isClosed.compareAndSet(false, true)) {
+ nativeCloseAnnotator(annotatorPtr);
+ annotatorPtr = 0L;
+ }
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ close();
+ } finally {
+ super.finalize();
+ }
}
/** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
@@ -294,27 +309,26 @@
private static native long nativeNewAnnotatorFromPath(String path);
- private static native boolean nativeInitializeKnowledgeEngine(
- long context, byte[] serializedConfig);
+ private static native String nativeGetLocales(int fd);
- private static native int[] nativeSuggestSelection(
+ private static native int nativeGetVersion(int fd);
+
+ private static native String nativeGetName(int fd);
+
+ private native boolean nativeInitializeKnowledgeEngine(long context, byte[] serializedConfig);
+
+ private native int[] nativeSuggestSelection(
long context, String text, int selectionBegin, int selectionEnd, SelectionOptions options);
- private static native ClassificationResult[] nativeClassifyText(
+ private native ClassificationResult[] nativeClassifyText(
long context,
String text,
int selectionBegin,
int selectionEnd,
ClassificationOptions options);
- private static native AnnotatedSpan[] nativeAnnotate(
+ private native AnnotatedSpan[] nativeAnnotate(
long context, String text, AnnotationOptions options);
- private static native void nativeCloseAnnotator(long context);
-
- private static native String nativeGetLocales(int fd);
-
- private static native int nativeGetVersion(int fd);
-
- private static native String nativeGetName(int fd);
+ private native void nativeCloseAnnotator(long context);
}
diff --git a/java/com/google/android/textclassifier/LangIdModel.java b/java/com/google/android/textclassifier/LangIdModel.java
index d016bd5..864535e 100644
--- a/java/com/google/android/textclassifier/LangIdModel.java
+++ b/java/com/google/android/textclassifier/LangIdModel.java
@@ -16,44 +16,59 @@
package com.google.android.textclassifier;
+import java.util.concurrent.atomic.AtomicBoolean;
+
/**
* Java wrapper for LangId native library interface. This class is used to detect languages in text.
*
* @hide
*/
public final class LangIdModel implements AutoCloseable {
+ private final AtomicBoolean isClosed = new AtomicBoolean(false);
static {
System.loadLibrary("textclassifier");
}
- private final long mModelPtr;
+ private long modelPtr;
/** Creates a new instance of LangId predictor, using the provided model image. */
public LangIdModel(int fd) {
- mModelPtr = nativeNewLangIdModel(fd);
- if (mModelPtr == 0L) {
+ modelPtr = nativeNewLangIdModel(fd);
+ if (modelPtr == 0L) {
throw new IllegalArgumentException("Couldn't initialize LangId from given file descriptor.");
}
}
/** Creates a new instance of LangId predictor, using the provided model image. */
public LangIdModel(String modelPath) {
- mModelPtr = nativeNewLangIdModelFromPath(modelPath);
- if (mModelPtr == 0L) {
+ modelPtr = nativeNewLangIdModelFromPath(modelPath);
+ if (modelPtr == 0L) {
throw new IllegalArgumentException("Couldn't initialize LangId from given file.");
}
}
/** Detects the languages for given text. */
public LanguageResult[] detectLanguages(String text) {
- return nativeDetectLanguages(mModelPtr, text);
+ return nativeDetectLanguages(modelPtr, text);
}
/** Frees up the allocated memory. */
@Override
public void close() {
- nativeCloseLangIdModel(mModelPtr);
+ if (isClosed.compareAndSet(false, true)) {
+ nativeCloseLangIdModel(modelPtr);
+ modelPtr = 0L;
+ }
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ close();
+ } finally {
+ super.finalize();
+ }
}
/** Result for detectLanguages method. */
@@ -77,16 +92,16 @@
/** Returns the version of the LangId model used. */
public int getVersion() {
- return nativeGetLangIdModelVersion(mModelPtr);
+ return nativeGetLangIdModelVersion(modelPtr);
}
private static native long nativeNewLangIdModel(int fd);
private static native long nativeNewLangIdModelFromPath(String path);
- private static native LanguageResult[] nativeDetectLanguages(long context, String text);
+ private native LanguageResult[] nativeDetectLanguages(long context, String text);
- private static native void nativeCloseLangIdModel(long context);
+ private native void nativeCloseLangIdModel(long context);
- private static native int nativeGetLangIdModelVersion(long context);
+ private native int nativeGetLangIdModelVersion(long context);
}
diff --git a/lang_id/common/embedding-feature-extractor.h b/lang_id/common/embedding-feature-extractor.h
index f672cf6..f51b6e5 100644
--- a/lang_id/common/embedding-feature-extractor.h
+++ b/lang_id/common/embedding-feature-extractor.h
@@ -42,11 +42,14 @@
// Read() or updated via UpdateMapsForExample.
class GenericEmbeddingFeatureExtractor {
public:
- virtual ~GenericEmbeddingFeatureExtractor() {}
+ // Constructs this GenericEmbeddingFeatureExtractor.
+ //
+ // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
+ // avoid name clashes. See GetParamName().
+ explicit GenericEmbeddingFeatureExtractor(const string &arg_prefix)
+ : arg_prefix_(arg_prefix) {}
- // Get the prefix string to put in front of all arguments, so they don't
- // conflict with other embedding models.
- virtual const string ArgPrefix() const = 0;
+ virtual ~GenericEmbeddingFeatureExtractor() {}
// Sets/inits up predicate maps and embedding space names that are common for
// all embedding based feature extractors.
@@ -59,43 +62,23 @@
// implemented in the typed class.
virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
- // Number of predicates for the embedding at a given index (vocabulary size.)
- int EmbeddingSize(int index) const {
- return generic_feature_extractor(index).GetDomainSize();
- }
-
// Returns number of embedding spaces.
int NumEmbeddings() const { return embedding_dims_.size(); }
- // Returns the number of features in the embedding space.
- const int FeatureSize(int idx) const {
- return generic_feature_extractor(idx).feature_types();
- }
-
- // Returns the dimensionality of the embedding space.
- int EmbeddingDims(int index) const { return embedding_dims_[index]; }
-
- // Accessor for embedding dims (dimensions of the embedding spaces).
- const std::vector<int> &embedding_dims() const { return embedding_dims_; }
-
const std::vector<string> &embedding_fml() const { return embedding_fml_; }
// Get parameter name by concatenating the prefix and the original name.
string GetParamName(const string &param_name) const {
- string full_name = ArgPrefix();
+ string full_name = arg_prefix_;
full_name.push_back('_');
full_name.append(param_name);
return full_name;
}
- protected:
- // Provides the generic class with access to the templated extractors. This is
- // used to get the type information out of the feature extractor without
- // knowing the specific calling arguments of the extractor itself.
- virtual const GenericFeatureExtractor &generic_feature_extractor(
- int idx) const = 0;
-
private:
+ // Prefix for TaskContext parameters.
+ const string arg_prefix_;
+
// Embedding space names for parameter sharing.
std::vector<string> embedding_names_;
@@ -119,6 +102,13 @@
template <class EXTRACTOR, class OBJ, class... ARGS>
class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
public:
+ // Constructs this EmbeddingFeatureExtractor.
+ //
+ // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
+ // avoid name clashes. See GetParamName().
+ explicit EmbeddingFeatureExtractor(const string &arg_prefix)
+ : GenericEmbeddingFeatureExtractor(arg_prefix) {}
+
// Sets up all predicate maps, feature extractors, and flags.
SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override {
if (!GenericEmbeddingFeatureExtractor::Setup(context)) {
@@ -173,15 +163,6 @@
}
}
- protected:
- // Provides generic access to the feature extractors.
- const GenericFeatureExtractor &generic_feature_extractor(
- int idx) const override {
- // DCHECK_LT(idx, feature_extractors_.size());
- // DCHECK_GE(idx, 0);
- return *feature_extractors_[idx];
- }
-
private:
// Templated feature extractor class.
std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_;
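
For orientation, a hypothetical subclass sketch showing the effect of this
refactoring: the argument prefix is now forwarded to the base constructor
instead of being supplied by overriding the removed virtual ArgPrefix()
(all names here are placeholders):

    class MyFeatures
        : public EmbeddingFeatureExtractor<MyFeatureFunction, MyObject> {
     public:
      // "my_prefix" plays the role the ArgPrefix() override used to play.
      MyFeatures() : EmbeddingFeatureExtractor("my_prefix") {}
    };
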
diff --git a/lang_id/common/embedding-feature-interface.h b/lang_id/common/embedding-feature-interface.h
new file mode 100644
index 0000000..87576c6
--- /dev/null
+++ b/lang_id/common/embedding-feature-interface.h
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
+#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
+
+#include <string>
+#include <vector>
+
+#include "lang_id/common/embedding-feature-extractor.h"
+#include "lang_id/common/fel/feature-extractor.h"
+#include "lang_id/common/fel/task-context.h"
+#include "lang_id/common/fel/workspace.h"
+#include "lang_id/common/lite_base/attributes.h"
+
+namespace libtextclassifier3 {
+namespace mobile {
+
+template <class EXTRACTOR, class OBJ, class... ARGS>
+class EmbeddingFeatureInterface {
+ public:
+ // Constructs this EmbeddingFeatureInterface.
+ //
+ // |arg_prefix| is a string prefix for the TaskContext parameters, passed to
+ // the underlying EmbeddingFeatureExtractor.
+ explicit EmbeddingFeatureInterface(const string &arg_prefix)
+ : feature_extractor_(arg_prefix) {}
+
+ // Sets up feature extractors and flags for processing (inference).
+ SAFTM_MUST_USE_RESULT bool SetupForProcessing(TaskContext *context) {
+ return feature_extractor_.Setup(context);
+ }
+
+ // Initializes feature extractor resources for processing (inference)
+ // including requesting a workspace for caching extracted features.
+ SAFTM_MUST_USE_RESULT bool InitForProcessing(TaskContext *context) {
+ if (!feature_extractor_.Init(context)) return false;
+ feature_extractor_.RequestWorkspaces(&workspace_registry_);
+ return true;
+ }
+
+ // Preprocesses *obj using the internal workspace registry.
+ void Preprocess(WorkspaceSet *workspace, OBJ *obj) const {
+ workspace->Reset(workspace_registry_);
+ feature_extractor_.Preprocess(workspace, obj);
+ }
+
+ // Extract features from |obj|. On return, FeatureVector features[i]
+ // contains the features for the embedding space #i.
+ //
+ // This function uses the precomputed info from |workspace|. Usage pattern:
+ //
+ // EmbeddingFeatureInterface<...> feature_interface;
+ // ...
+ // OBJ obj;
+ // WorkspaceSet workspace;
+ // feature_interface.Preprocess(&workspace, &obj);
+ //
+ // // For the same obj, but with different args:
+ // std::vector<FeatureVector> features;
+ // feature_interface.GetFeatures(obj, args, workspace, &features);
+ //
+ // This pattern is useful (more efficient) if you can pre-compute some info
+ // for the entire |obj|, which is reused by the feature extraction performed
+ // for different args. If that is not the case, you can use the simpler
+ // version GetFeaturesNoCaching below.
+ void GetFeatures(const OBJ &obj, ARGS... args, const WorkspaceSet &workspace,
+ std::vector<FeatureVector> *features) const {
+ feature_extractor_.ExtractFeatures(workspace, obj, args..., features);
+ }
+
+ // Simpler version of GetFeatures(), for cases when there is no opportunity to
+ // reuse computation between feature extractions for the same |obj|, but with
+ // different |args|. Returns the extracted features. For more info, see the
+ // doc for GetFeatures().
+ std::vector<FeatureVector> GetFeaturesNoCaching(OBJ *obj,
+ ARGS... args) const {
+ // Technically, we still use a workspace, because
+ // feature_extractor_.ExtractFeatures requires one. But there is no real
+ // caching here, as we start from scratch for each call to ExtractFeatures.
+ WorkspaceSet workspace;
+ Preprocess(&workspace, obj);
+ std::vector<FeatureVector> features(NumEmbeddings());
+ GetFeatures(*obj, args..., workspace, &features);
+ return features;
+ }
+
+ // Returns number of embedding spaces.
+ int NumEmbeddings() const { return feature_extractor_.NumEmbeddings(); }
+
+ private:
+ // Typed feature extractor for embeddings.
+ EmbeddingFeatureExtractor<EXTRACTOR, OBJ, ARGS...> feature_extractor_;
+
+ // The registry of shared workspaces in the feature extractor.
+ WorkspaceRegistry workspace_registry_;
+};
+
+} // namespace mobile
+} // namespace libtextclassifier3
+
+#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_
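
A condensed instantiation sketch of the new interface, following the usage
pattern documented in GetFeatures() above (MyFeatureFunction and MyObject are
placeholders):

    mobile::EmbeddingFeatureInterface<MyFeatureFunction, MyObject, int>
        iface("my_prefix");
    mobile::TaskContext context;
    if (iface.SetupForProcessing(&context) && iface.InitForProcessing(&context)) {
      MyObject obj;
      const std::vector<mobile::FeatureVector> features =
          iface.GetFeaturesNoCaching(&obj, /*args=*/0);
    }
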
diff --git a/lang_id/common/fel/feature-descriptors.h b/lang_id/common/fel/feature-descriptors.h
index bd40b6f..a9408c9 100644
--- a/lang_id/common/fel/feature-descriptors.h
+++ b/lang_id/common/fel/feature-descriptors.h
@@ -53,14 +53,12 @@
// Accessors for the feature function type. The function type is the string
// that the feature extractor code is registered under.
void set_type(const string &type) { type_ = type; }
- bool has_type() const { return !type_.empty(); }
const string &type() const { return type_; }
// Accessors for the feature function name. The function name (if available)
// is used for some log messages. Otherwise, a more precise, but also more
// verbose name based on the feature specification is used.
void set_name(const string &name) { name_ = name; }
- bool has_name() const { return !name_.empty(); }
const string &name() const { return name_; }
// Accessors for the default (name-less) parameter.
diff --git a/lang_id/common/fel/feature-extractor.cc b/lang_id/common/fel/feature-extractor.cc
index c9633c5..c256257 100644
--- a/lang_id/common/fel/feature-extractor.cc
+++ b/lang_id/common/fel/feature-extractor.cc
@@ -61,21 +61,6 @@
return true;
}
-FeatureValue GenericFeatureExtractor::GetDomainSize() const {
- // Domain size of the set of features is equal to:
- // [largest domain size of any feature types] * [number of feature types]
- FeatureValue max_feature_type_dsize = 0;
- for (size_t i = 0; i < feature_types_.size(); ++i) {
- FeatureType *ft = feature_types_[i];
- const FeatureValue feature_type_dsize = ft->GetDomainSize();
- if (feature_type_dsize > max_feature_type_dsize) {
- max_feature_type_dsize = feature_type_dsize;
- }
- }
-
- return max_feature_type_dsize;
-}
-
string GenericFeatureFunction::GetParameter(const string &name,
const string &default_value) const {
// Find named parameter in feature descriptor.
diff --git a/lang_id/common/fel/feature-extractor.h b/lang_id/common/fel/feature-extractor.h
index 8e36352..8763852 100644
--- a/lang_id/common/fel/feature-extractor.h
+++ b/lang_id/common/fel/feature-extractor.h
@@ -136,17 +136,6 @@
// before Init() has been called.
int feature_types() const { return feature_types_.size(); }
- // Returns a feature type used in the extractor. Invalid before Init() has
- // been called.
- const FeatureType *feature_type(int index) const {
- return feature_types_[index];
- }
-
- // Returns the feature domain size of this feature extractor.
- // NOTE: The way that domain size is calculated is, for some, unintuitive. It
- // is the largest domain size of any feature type.
- FeatureValue GetDomainSize() const;
-
protected:
// Initializes the feature types used by the extractor. Called from
// FeatureExtractor<>::Init().
@@ -216,11 +205,6 @@
// null. Invalid before Init() has been called.
virtual FeatureType *GetFeatureType() const;
- // Returns the name of the registry used for creating the feature function.
- // This can be used for checking if two feature functions are of the same
- // kind.
- virtual const char *RegistryName() const = 0;
-
// Returns value of parameter |name| from the feature function descriptor.
// If the parameter is not present, returns the indicated |default_value|.
string GetParameter(const string &name, const string &default_value) const;
@@ -366,9 +350,6 @@
return f;
}
- // Returns the name of the registry for the feature function.
- const char *RegistryName() const override { return Self::registry()->name(); }
-
private:
// Special feature function class for resolving variable references. The type
// of the feature function is used for resolving the variable reference. When
diff --git a/lang_id/common/fel/feature-types.h b/lang_id/common/fel/feature-types.h
index fa8f35d..18cf69a 100644
--- a/lang_id/common/fel/feature-types.h
+++ b/lang_id/common/fel/feature-types.h
@@ -77,66 +77,6 @@
bool is_continuous_;
};
-// Templated generic resource based feature type. This feature type delegates
-// look up of feature value names to an unknown resource class, which is not
-// owned. Optionally, this type can also store a mapping of extra values which
-// are not in the resource.
-//
-// Note: this class assumes that Resource->GetFeatureValueName() will return
-// successfully for values ONLY in the range [0, Resource->NumValues()) Any
-// feature value not in the extra value map and not in the above range of
-// Resource will result in a ERROR and return of "<INVALID>".
-template <class Resource>
-class ResourceBasedFeatureType : public FeatureType {
- public:
- // Creates a new type with given name, resource object, and a mapping of
- // special values. The values must be greater or equal to
- // resource->NumValues() so as to avoid collisions; this is verified with
- // SAFTM_CHECK at creation.
- ResourceBasedFeatureType(const string &name, const Resource *resource,
- const std::map<FeatureValue, string> &values)
- : FeatureType(name), resource_(resource), values_(values) {
- max_value_ = resource->NumValues() - 1;
- for (const auto &pair : values) {
- SAFTM_CHECK_GE(pair.first, resource->NumValues())
- << "Invalid extra value: " << pair.first << ", " << pair.second;
- max_value_ = pair.first > max_value_ ? pair.first : max_value_;
- }
- }
-
- // Creates a new type with no special values.
- ResourceBasedFeatureType(const string &name, const Resource *resource)
- : ResourceBasedFeatureType(name, resource, {}) {}
-
- // Returns the feature name for a given feature value. First checks the values
- // map, then checks the resource to look up the name.
- string GetFeatureValueName(FeatureValue value) const override {
- if (values_.find(value) != values_.end()) {
- return values_.find(value)->second;
- }
- if (value >= 0 && value < resource_->NumValues()) {
- return resource_->GetFeatureValueName(value);
- } else {
- // LOG(ERROR) << "Invalid feature value " << value << " for " << name();
- return "<INVALID>";
- }
- }
-
- // Returns the number of possible values for this feature type. This is the
- // based on the largest value that was observed in the extra values.
- FeatureValue GetDomainSize() const override { return max_value_ + 1; }
-
- protected:
- // Shared resource. Not owned.
- const Resource *resource_ = nullptr;
-
- // Maximum possible value this feature could take.
- FeatureValue max_value_;
-
- // Mapping for extra feature values not in the resource.
- std::map<FeatureValue, string> values_;
-};
-
// Feature type that is defined using an explicit map from FeatureValue to
// string values. This can reduce some of the boilerplate when defining
// features that generate enum values. Example usage:
diff --git a/lang_id/common/fel/workspace.cc b/lang_id/common/fel/workspace.cc
index e422776..8cab281 100644
--- a/lang_id/common/fel/workspace.cc
+++ b/lang_id/common/fel/workspace.cc
@@ -54,10 +54,5 @@
string VectorIntWorkspace::TypeName() { return "Vector"; }
-VectorVectorIntWorkspace::VectorVectorIntWorkspace(int size)
- : elements_(size) {}
-
-string VectorVectorIntWorkspace::TypeName() { return "VectorVector"; }
-
} // namespace mobile
} // namespace nlp_saft
diff --git a/lang_id/common/fel/workspace.h b/lang_id/common/fel/workspace.h
index 910abaa..09095e4 100644
--- a/lang_id/common/fel/workspace.h
+++ b/lang_id/common/fel/workspace.h
@@ -168,29 +168,6 @@
std::vector<std::vector<Workspace *> > workspaces_;
};
-// A workspace that wraps around a single int.
-class SingletonIntWorkspace : public Workspace {
- public:
- // Default-initializes the int value.
- SingletonIntWorkspace() {}
-
- // Initializes the int with the given value.
- explicit SingletonIntWorkspace(int value) : value_(value) {}
-
- // Returns the name of this type of workspace.
- static string TypeName() { return "SingletonInt"; }
-
- // Returns the int value.
- int get() const { return value_; }
-
- // Sets the int value.
- void set(int value) { value_ = value; }
-
- private:
- // The enclosed int.
- int value_ = 0;
-};
-
// A workspace that wraps around a vector of int.
class VectorIntWorkspace : public Workspace {
public:
@@ -221,26 +198,6 @@
std::vector<int> elements_;
};
-// A workspace that wraps around a vector of vector of int.
-class VectorVectorIntWorkspace : public Workspace {
- public:
- // Creates a vector of empty vectors of the given size.
- explicit VectorVectorIntWorkspace(int size);
-
- // Returns the name of this type of workspace.
- static string TypeName();
-
- // Returns the i'th vector of elements.
- const std::vector<int> &elements(int i) const { return elements_[i]; }
-
- // Mutable access to the i'th vector of elements.
- std::vector<int> *mutable_elements(int i) { return &(elements_[i]); }
-
- private:
- // The enclosed vector of vector of elements.
- std::vector<std::vector<int> > elements_;
-};
-
} // namespace mobile
} // namespace nlp_saft
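A minimal sketch of the workspace machinery that survives this cleanup, based
only on the calls visible in this change (registries are populated via the
extractor's RequestWorkspaces(); a WorkspaceSet must be Reset against the
registry before first use, as in GetFeaturesNoCaching above):

  WorkspaceRegistry registry;
  // e.g. feature_extractor.RequestWorkspaces(&registry);
  WorkspaceSet workspaces;
  workspaces.Reset(registry);  // binds the set to the registered workspaces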
diff --git a/lang_id/common/math/fastexp.h b/lang_id/common/math/fastexp.h
index 4d942bc..05b654a 100644
--- a/lang_id/common/math/fastexp.h
+++ b/lang_id/common/math/fastexp.h
@@ -16,24 +16,6 @@
// Fast approximation for exp.
//
-// Note: this file is based on util/math/fastmath.h; we trimmed it down to
-// contain only vfexp() and vfexp2() (plus their transitive dependencies), and
-// renamed those functions to VeryFastExp() / VeryFastExp2().
-//
-// Both functions are based on a table lookup. "vf" stands for "very fast". In
-// terms of precision, VeryFastExp(x) differs from expf(x) by less than 1%
-// (relative to expf(x)); we have tests for that.
-//
-// The functions produce undefined results on overflow/underflow.
-// Bounds checking is only done in debug mode.
-//
-// To microbenchmark, run
-//
-// blaze run -c opt --dynamic_mode=off --run_under=perflab \
-// //lang_id/common:fastexp_benchmark \
-// -- --benchmarks=all --heap_check=
-//
-// You will receive an email when the results are ready.
#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_FASTEXP_H_
@@ -80,7 +62,6 @@
extern FastMathClass FastMathInstance;
-inline float VeryFastExp2(float f) { return FastMathInstance.VeryFastExp2(f); }
inline float VeryFastExp(float f) { return FastMathInstance.VeryFastExp(f); }
} // namespace mobile
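Usage stays a drop-in approximation of expf(); per the trimmed header
comment, the relative error of the table lookup is below 1% (namespace
qualifiers omitted in this sketch):

  #include <cmath>
  const float fast = VeryFastExp(1.5f);  // table-based approximation
  const float ref = std::exp(1.5f);      // |fast - ref| < 0.01f * ref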
diff --git a/lang_id/lang-id-brain-interface.h b/lang_id/lang-id-brain-interface.h
deleted file mode 100644
index e247f9c..0000000
--- a/lang_id/lang-id-brain-interface.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_BRAIN_INTERFACE_H_
-#define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_BRAIN_INTERFACE_H_
-
-#include <string>
-#include <vector>
-
-#include "lang_id/common/embedding-feature-extractor.h"
-#include "lang_id/common/fel/feature-extractor.h"
-#include "lang_id/common/fel/task-context.h"
-#include "lang_id/common/fel/workspace.h"
-#include "lang_id/common/lite_base/attributes.h"
-#include "lang_id/features/light-sentence-features.h"
-#include "lang_id/light-sentence.h"
-
-// TODO(abakalov): Add a test.
-namespace libtextclassifier3 {
-namespace mobile {
-namespace lang_id {
-
-// Specialization of EmbeddingFeatureExtractor that extracts from LightSentence.
-class LangIdEmbeddingFeatureExtractor
- : public EmbeddingFeatureExtractor<LightSentenceExtractor, LightSentence> {
- public:
- const string ArgPrefix() const override { return "language_identifier"; }
-};
-
-// Similar to the inference (processing) part of SaftBrainInterface from
-// nlp/saft/components/common/brain/saft-brain-interface.h
-//
-// Handles sentence -> numeric_features and numeric_prediction -> language
-// conversions.
-class LangIdBrainInterface {
- public:
- // Requests/initializes resources and parameters.
- SAFTM_MUST_USE_RESULT bool SetupForProcessing(TaskContext *context) {
- return feature_extractor_.Setup(context);
- }
-
- SAFTM_MUST_USE_RESULT bool InitForProcessing(TaskContext *context) {
- if (!feature_extractor_.Init(context)) return false;
- feature_extractor_.RequestWorkspaces(&workspace_registry_);
- return true;
- }
-
- // Extract features from sentence. On return, FeatureVector features[i]
- // contains the features for the embedding space #i.
- void GetFeatures(LightSentence *sentence,
- std::vector<FeatureVector> *features) const {
- WorkspaceSet workspace;
- workspace.Reset(workspace_registry_);
- feature_extractor_.Preprocess(&workspace, sentence);
- return feature_extractor_.ExtractFeatures(workspace, *sentence, features);
- }
-
- int NumEmbeddings() const {
- return feature_extractor_.NumEmbeddings();
- }
-
- private:
- // Typed feature extractor for embeddings.
- LangIdEmbeddingFeatureExtractor feature_extractor_;
-
- // The registry of shared workspaces in the feature extractor.
- WorkspaceRegistry workspace_registry_;
-};
-
-} // namespace lang_id
-} // namespace mobile
-} // namespace nlp_saft
-
-#endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_LANG_ID_BRAIN_INTERFACE_H_
diff --git a/lang_id/lang-id.cc b/lang_id/lang-id.cc
index 217662b..ebc88ec 100644
--- a/lang_id/lang-id.cc
+++ b/lang_id/lang-id.cc
@@ -24,6 +24,7 @@
#include <unordered_map>
#include <vector>
+#include "lang_id/common/embedding-feature-interface.h"
#include "lang_id/common/embedding-network-params.h"
#include "lang_id/common/embedding-network.h"
#include "lang_id/common/fel/feature-extractor.h"
@@ -34,7 +35,7 @@
#include "lang_id/common/math/algorithm.h"
#include "lang_id/common/math/softmax.h"
#include "lang_id/custom-tokenizer.h"
-#include "lang_id/lang-id-brain-interface.h"
+#include "lang_id/features/light-sentence-features.h"
#include "lang_id/light-sentence.h"
namespace libtextclassifier3 {
@@ -55,7 +56,8 @@
class LangIdImpl {
public:
explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider)
- : model_provider_(std::move(model_provider)) {
+ : model_provider_(std::move(model_provider)),
+ lang_id_brain_interface_("language_identifier") {
// Note: in the code below, we set valid_ to true only if all initialization
// steps completed successfully. Otherwise, we return early, leaving valid_
// to its default value false.
@@ -139,18 +141,9 @@
// language code string in ascending order.
std::vector<float> softmax = ComputeSoftmax(scores);
- // We will need to renormalize after removing items from the support.
- // Keep track of the normalization constant.
- float normalization_z = 0.0f;
for (int i = 0; i < softmax.size(); ++i) {
result->predictions.emplace_back(GetLanguageForSoftmaxLabel(i),
softmax[i]);
- normalization_z += softmax[i];
- }
-
- // Renormalize prediction probabilities.
- for (auto &prediction : result->predictions) {
- prediction.second /= normalization_z;
}
// Sort the resulting language predictions by probability in descending
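The removed loop was a no-op: a softmax is already a probability
distribution, so the normalization constant it accumulated is 1 up to
floating-point error. A sketch of the invariant, assuming ComputeSoftmax is a
standard softmax over a vector of scores:

  #include <numeric>
  std::vector<float> softmax = ComputeSoftmax({1.0f, 2.0f, 3.0f});
  const float z = std::accumulate(softmax.begin(), softmax.end(), 0.0f);
  // z ~= 1.0f, so dividing every entry by z (as the removed code did)
  // leaves the predictions unchanged.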
@@ -207,13 +200,10 @@
void ComputeScores(StringPiece text, std::vector<float> *scores) const {
// Create a Sentence storing the input text.
LightSentence sentence;
-
tokenizer_.Tokenize(text, &sentence);
- // TODO(salcianu): reuse vector<FeatureVector>.
- std::vector<FeatureVector> features(
- lang_id_brain_interface_.NumEmbeddings());
- lang_id_brain_interface_.GetFeatures(&sentence, &features);
+ std::vector<FeatureVector> features =
+ lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);
// Run feed-forward neural network to compute scores.
network_->ComputeFinalScores(features, scores);
@@ -235,7 +225,8 @@
TokenizerForLangId tokenizer_;
- LangIdBrainInterface lang_id_brain_interface_;
+ EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence>
+ lang_id_brain_interface_;
// Neural network to use for scoring.
std::unique_ptr<EmbeddingNetwork> network_;
diff --git a/lang_id/lang-id.h b/lang_id/lang-id.h
index a907897..3f656f2 100644
--- a/lang_id/lang-id.h
+++ b/lang_id/lang-id.h
@@ -69,20 +69,6 @@
virtual ~LangId();
- // Returns language code for the most likely language for a text.
- //
- // The input text consists of the |num_bytes| bytes that starts at |data|.
- //
- // Note: if this LangId object is not valid (see is_valid()) or if this LangId
- // object can't make a prediction, then this method returns
- // LangId::kUnknownLanguageCode.
- string FindLanguage(const char *data, size_t num_bytes) const;
-
- // Convenience version of FindLanguage(const char *, size_t).
- string FindLanguage(const string &text) const {
- return FindLanguage(text.data(), text.size());
- }
-
// Computes an n-best list of language codes and probabilities
// corresponding to the most likely languages the given input text is written
// in. The list is sorted in descending order by language probability.
@@ -100,6 +86,28 @@
FindLanguages(text.data(), text.size(), result);
}
+ // Returns language code for the most likely language for a piece of text.
+ //
+ // The input text consists of the |num_bytes| bytes that start at |data|.
+ //
+ // Note: this method reports the most likely (1-best) language only if its
+ // probability is high enough; otherwise, it returns
+ // LangId::kUnknownLanguageCode. The specific probability threshold is tuned
+ // to the needs of an early client. If you need a different threshold, you
+ // can use FindLanguages (plural) to get the full LangIdResult, and apply your
+ // own threshold.
+ //
+ // Note: if this LangId object is not valid (see is_valid()) or if this LangId
+ // object can't make a prediction, then this method returns
+ // LangId::kUnknownLanguageCode.
+ //
+ string FindLanguage(const char *data, size_t num_bytes) const;
+
+ // Convenience version of FindLanguage(const char *, size_t).
+ string FindLanguage(const string &text) const {
+ return FindLanguage(text.data(), text.size());
+ }
+
// Returns true if this object has been correctly initialized and is ready to
// perform predictions. For more info, see doc for LangId
// constructor above.
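A sketch of the custom-threshold pattern the comment above recommends. The
predictions field (a descending-sorted list of (language, probability) pairs)
follows the FindLanguages() documentation; the result type name is assumed:

  LangIdResult result;  // type name assumed
  lang_id.FindLanguages(text, &result);
  string best = LangId::kUnknownLanguageCode;
  if (!result.predictions.empty() &&
      result.predictions[0].second >= /*my_threshold=*/0.9f) {
    best = result.predictions[0].first;
  }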
diff --git a/lang_id/lang-id_jni.cc b/lang_id/lang-id_jni.cc
index b5168dd..a66cf29 100644
--- a/lang_id/lang-id_jni.cc
+++ b/lang_id/lang-id_jni.cc
@@ -83,7 +83,7 @@
}
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring text) {
+(JNIEnv* env, jobject clazz, jlong ptr, jstring text) {
LangId* model = reinterpret_cast<LangId*>(ptr);
if (!model) {
return nullptr;
@@ -97,7 +97,7 @@
}
TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeCloseLangIdModel)
-(JNIEnv* env, jobject thiz, jlong ptr) {
+(JNIEnv* env, jobject clazz, jlong ptr) {
if (!ptr) {
TC3_LOG(ERROR) << "Trying to close null LangId.";
return;
@@ -107,7 +107,7 @@
}
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdModelVersion)
-(JNIEnv* env, jobject thiz, jlong ptr) {
+(JNIEnv* env, jobject clazz, jlong ptr) {
if (!ptr) {
return -1;
}
diff --git a/lang_id/lang-id_jni.h b/lang_id/lang-id_jni.h
index f8689f5..d447d65 100644
--- a/lang_id/lang-id_jni.h
+++ b/lang_id/lang-id_jni.h
@@ -40,13 +40,13 @@
(JNIEnv* env, jobject thiz, jstring path);
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
-(JNIEnv* env, jobject thiz, jlong ptr, jstring text);
+(JNIEnv* env, jobject clazz, jlong ptr, jstring text);
TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeCloseLangIdModel)
-(JNIEnv* env, jobject thiz, jlong ptr);
+(JNIEnv* env, jobject clazz, jlong ptr);
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdModelVersion)
-(JNIEnv* env, jobject thiz, jlong ptr);
+(JNIEnv* env, jobject clazz, jlong ptr);
#ifdef __cplusplus
}
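The thiz -> clazz rename reflects how these natives are bound: for a native
method declared static on the Java side, JNI passes the defining class, not
an instance, as the second argument. An illustrative (hypothetical) Java-side
declaration for one of them:

  // static native int nativeGetLangIdModelVersion(long ptr);
  // => native side receives (JNIEnv* env, jclass clazz, jlong ptr)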
diff --git a/models/actions_suggestions.model b/models/actions_suggestions.model
index f625a09..893dd84 100644
--- a/models/actions_suggestions.model
+++ b/models/actions_suggestions.model
Binary files differ
diff --git a/models/lang_id.model b/models/lang_id.model
index fc6d4fe..e577a69 100644
--- a/models/lang_id.model
+++ b/models/lang_id.model
Binary files differ
diff --git a/utils/hash/farmhash.h b/utils/hash/farmhash.h
index 4c9d2fe..f374c0b 100644
--- a/utils/hash/farmhash.h
+++ b/utils/hash/farmhash.h
@@ -24,7 +24,7 @@
#include <utility>
#ifndef NAMESPACE_FOR_HASH_FUNCTIONS
-#define NAMESPACE_FOR_HASH_FUNCTIONS tc2farmhash
+#define NAMESPACE_FOR_HASH_FUNCTIONS tc3farmhash
#endif
namespace NAMESPACE_FOR_HASH_FUNCTIONS {
diff --git a/utils/java/jni-base.cc b/utils/java/jni-base.cc
index 8073c5a..330732c 100644
--- a/utils/java/jni-base.cc
+++ b/utils/java/jni-base.cc
@@ -36,21 +36,7 @@
return result;
}
-jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) {
- // Get system-level file descriptor from AssetFileDescriptor.
- ScopedLocalRef<jclass> afd_class(
- env->FindClass("android/content/res/AssetFileDescriptor"), env);
- if (afd_class == nullptr) {
- TC3_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
- return reinterpret_cast<jlong>(nullptr);
- }
- jmethodID afd_class_getFileDescriptor = env->GetMethodID(
- afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
- if (afd_class_getFileDescriptor == nullptr) {
- TC3_LOG(ERROR) << "Couldn't find getFileDescriptor.";
- return reinterpret_cast<jlong>(nullptr);
- }
-
+jint GetFdFromFileDescriptor(JNIEnv* env, jobject fd) {
ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"),
env);
if (fd_class == nullptr) {
@@ -63,9 +49,24 @@
TC3_LOG(ERROR) << "Couldn't find descriptor.";
return reinterpret_cast<jlong>(nullptr);
}
+ return env->GetIntField(fd, fd_class_descriptor);
+}
+jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) {
+ ScopedLocalRef<jclass> afd_class(
+ env->FindClass("android/content/res/AssetFileDescriptor"), env);
+ if (afd_class == nullptr) {
+ TC3_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
+ return reinterpret_cast<jlong>(nullptr);
+ }
+ jmethodID afd_class_getFileDescriptor = env->GetMethodID(
+ afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
+ if (afd_class_getFileDescriptor == nullptr) {
+ TC3_LOG(ERROR) << "Couldn't find getFileDescriptor.";
+ return reinterpret_cast<jlong>(nullptr);
+ }
jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor);
- return env->GetIntField(bundle_jfd, fd_class_descriptor);
+ return GetFdFromFileDescriptor(env, bundle_jfd);
}
} // namespace libtextclassifier3
diff --git a/utils/java/jni-base.h b/utils/java/jni-base.h
index 147ae08..23658a3 100644
--- a/utils/java/jni-base.h
+++ b/utils/java/jni-base.h
@@ -72,7 +72,12 @@
}
std::string ToStlString(JNIEnv* env, const jstring& str);
+
+// Get system-level file descriptor from AssetFileDescriptor.
jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd);
+
+// Get system-level file descriptor from FileDescriptor.
+jint GetFdFromFileDescriptor(JNIEnv* env, jobject fd);
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_
diff --git a/utils/java/string_utils.cc b/utils/java/string_utils.cc
index 6d4a5d7..ef865a8 100644
--- a/utils/java/string_utils.cc
+++ b/utils/java/string_utils.cc
@@ -20,6 +20,21 @@
namespace libtextclassifier3 {
+bool JByteArrayToString(JNIEnv* env, const jbyteArray& array,
+ std::string* result) {
+ jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE);
+ if (array_bytes == nullptr) {
+ return false;
+ }
+
+ const int array_length = env->GetArrayLength(array);
+ *result = std::string(reinterpret_cast<char*>(array_bytes), array_length);
+
+ env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
+
+ return true;
+}
+
bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
std::string* result) {
if (jstr == nullptr) {
@@ -37,16 +52,13 @@
env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
jstring encoding = env->NewStringUTF("UTF-8");
+
jbyteArray array = reinterpret_cast<jbyteArray>(
env->CallObjectMethod(jstr, get_bytes_id, encoding));
- jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE);
- int length = env->GetArrayLength(array);
-
- *result = std::string(reinterpret_cast<char*>(array_bytes), length);
+ JByteArrayToString(env, array, result);
// Release the array.
- env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
env->DeleteLocalRef(array);
env->DeleteLocalRef(string_class);
env->DeleteLocalRef(encoding);
diff --git a/utils/java/string_utils.h b/utils/java/string_utils.h
index c4fd97a..e4f2bd8 100644
--- a/utils/java/string_utils.h
+++ b/utils/java/string_utils.h
@@ -22,6 +22,8 @@
namespace libtextclassifier3 {
+bool JByteArrayToString(JNIEnv* env, const jbyteArray& array,
+ std::string* result);
bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, std::string* result);
} // namespace libtextclassifier3
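Usage sketch for the newly exposed helper, e.g. from a JNI entry point that
received a jbyteArray |array| (hypothetical caller):

  std::string bytes;
  if (!JByteArrayToString(env, array, &bytes)) {
    TC3_LOG(ERROR) << "Couldn't copy byte array.";
    return;
  }
  // |bytes| now holds a copy; the array elements were released with
  // JNI_ABORT, i.e. without writing anything back.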
diff --git a/utils/sentencepiece/double_array_trie.h b/utils/sentencepiece/double_array_trie.h
index a0ad4a4..050c466 100644
--- a/utils/sentencepiece/double_array_trie.h
+++ b/utils/sentencepiece/double_array_trie.h
@@ -20,7 +20,7 @@
#include <functional>
#include <vector>
-#include "utils/sentencepiece/match.h"
+#include "utils/sentencepiece/matcher.h"
#include "utils/strings/stringpiece.h"
namespace libtextclassifier3 {
@@ -35,17 +35,17 @@
typedef unsigned int TrieNode;
// A memory mappable trie, compatible with Darts::DoubleArray.
-class DoubleArrayTrie {
+class DoubleArrayTrie : public SentencePieceMatcher {
public:
// nodes and nodes_length specify the array of the nodes of the trie.
DoubleArrayTrie(const TrieNode* nodes, const int nodes_length)
: nodes_(nodes), nodes_length_(nodes_length) {}
// Find matches that are prefixes of a string.
- std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const;
+ std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const override;
// Find the longest prefix match of a string.
- TrieMatch LongestPrefixMatch(StringPiece input) const;
+ TrieMatch LongestPrefixMatch(StringPiece input) const override;
private:
// Returns whether a node has a leaf as a child.
diff --git a/utils/sentencepiece/encoder.cc b/utils/sentencepiece/encoder.cc
new file mode 100644
index 0000000..96fb868
--- /dev/null
+++ b/utils/sentencepiece/encoder.cc
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/encoder.h"
+
+namespace libtextclassifier3 {
+
+std::vector<int> Encoder::Encode(StringPiece normalized_text) const {
+ const int len = normalized_text.size();
+ if (len <= 0) {
+ return {start_code_, end_code_};
+ }
+ // We use `previous_pos` to indicate whether a dynamic programming state was
+ // reachable.
+ std::vector<SegmentationEntry> segmentation(
+ len + 1, {/*score=*/0, /*previous_pos=*/-1, /*piece_id=*/-1,
+ /*num_pieces=*/0});
+ for (int i = 0; i < len; i++) {
+ // State couldn't be reached.
+ if (i > 0 && segmentation[i].previous_pos < 0) {
+ // Advance position.
+ normalized_text.RemovePrefix(1);
+ continue;
+ }
+ for (const auto& match : matcher_->FindAllPrefixMatches(normalized_text)) {
+ TC3_CHECK(match.id >= 0 && match.id < num_pieces_);
+ const int pos = i + match.match_length;
+ const float candidate_score = segmentation[i].score + scores_[match.id];
+ if (segmentation[pos].previous_pos < 0 ||
+ segmentation[pos].score < candidate_score) {
+ segmentation[pos] = {/*score=*/candidate_score, /*previous_pos=*/i,
+ /*piece_id=*/match.id,
+ /*num_pieces=*/segmentation[i].num_pieces + 1};
+ }
+ }
+ // Advance position.
+ normalized_text.RemovePrefix(1);
+ }
+ if (segmentation[len].num_pieces <= 0) {
+ return {start_code_, end_code_};
+ }
+ const int num_pieces = segmentation[len].num_pieces;
+ std::vector<int> result(num_pieces + 2);
+ result[num_pieces + 1] = end_code_;
+ int pos = len;
+ for (int i = num_pieces; i > 0; i--) {
+ result[i] = segmentation[pos].piece_id + encoding_offset_;
+ pos = segmentation[pos].previous_pos;
+ }
+ result[0] = start_code_;
+ return result;
+}
+
+} // namespace libtextclassifier3
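A standalone usage sketch, mirroring encoder_test.cc below: the matcher is
now owned by the caller and must outlive the Encoder, which only borrows it:

  const char pieces[] = "hell\0hello\0o\0there\0";
  const int offsets[] = {0, 5, 11, 13};
  const float scores[] = {-0.5, -1.0, -10.0, -1.0};
  std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
      /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
  const Encoder encoder(matcher.get(), /*num_pieces=*/4, scores);
  // {0, 3, 5, 1}: start code, "hello"+offset, "there"+offset, end code.
  std::vector<int> ids = encoder.Encode("hellothere");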
diff --git a/utils/sentencepiece/encoder.h b/utils/sentencepiece/encoder.h
index 4aa7582..fffd86f 100644
--- a/utils/sentencepiece/encoder.h
+++ b/utils/sentencepiece/encoder.h
@@ -17,19 +17,16 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
-#include <memory>
-#include <string>
#include <vector>
#include "utils/base/logging.h"
-#include "utils/sentencepiece/double_array_trie.h"
+#include "utils/sentencepiece/matcher.h"
#include "utils/strings/stringpiece.h"
namespace libtextclassifier3 {
// Encoder to segment/tokenize strings into pieces such that the sum of the
// scores of the pieces used is maximized.
-template <typename TMatcher = DoubleArrayTrie>
class Encoder {
public:
// matcher: the list of valid sentence pieces represented as a matcher, e.g.
@@ -40,7 +37,7 @@
// end_code: Code that is used as encoding of the end of input.
// encoding_offset: Value added to the sentence piece ids to make them
// not intersect with start_code and end_code.
- Encoder(const TMatcher& matcher, const int num_pieces,
+ Encoder(const SentencePieceMatcher* matcher, const int num_pieces,
const float* pieces_scores, int start_code = 0, int end_code = 1,
int encoding_offset = 2)
: num_pieces_(num_pieces),
@@ -49,6 +46,10 @@
start_code_(start_code),
end_code_(end_code),
encoding_offset_(encoding_offset) {}
+
+ // Segment the input so that the total score of the pieces used is maximized.
+ // This is a simplified implementation of the general Viterbi algorithm,
+ // assuming independence between individual pieces.
std::vector<int> Encode(StringPiece normalized_text) const;
private:
@@ -69,62 +70,12 @@
const int num_pieces_;
const float* scores_;
- TMatcher matcher_;
+ const SentencePieceMatcher* matcher_;
const int start_code_;
const int end_code_;
const int encoding_offset_;
};
-// Segment the input such that the total score of the pieces used is maximized.
-// This is a simplified implementation of the general Viterbi algorithm,
-// assuming independence between individual pieces.
-template <typename TMatcher>
-std::vector<int> Encoder<TMatcher>::Encode(StringPiece normalized_text) const {
- const int len = normalized_text.size();
- if (len <= 0) {
- return {start_code_, end_code_};
- }
- // We use `previous_pos` to indicate whether a dynamic programming state was
- // reachable.
- std::vector<SegmentationEntry> segmentation(
- len + 1, {/*score=*/0, /*previous_pos=*/-1, /*piece_id=*/-1,
- /*num_pieces=*/0});
- for (int i = 0; i < len; i++) {
- // State couldn't be reached.
- if (i > 0 && segmentation[i].previous_pos < 0) {
- // Advance position.
- normalized_text.RemovePrefix(1);
- continue;
- }
- for (const auto& match : matcher_.FindAllPrefixMatches(normalized_text)) {
- TC3_CHECK(match.id >= 0 && match.id < num_pieces_);
- const int pos = i + match.match_length;
- const float candidate_score = segmentation[i].score + scores_[match.id];
- if (segmentation[pos].previous_pos < 0 ||
- segmentation[pos].score < candidate_score) {
- segmentation[pos] = {/*score=*/candidate_score, /*previous_pos=*/i,
- /*piece_id=*/match.id,
- /*num_pieces=*/segmentation[i].num_pieces + 1};
- }
- }
- // Advance position.
- normalized_text.RemovePrefix(1);
- }
- if (segmentation[len].num_pieces <= 0) {
- return {start_code_, end_code_};
- }
- const int num_pieces = segmentation[len].num_pieces;
- std::vector<int> result(num_pieces + 2);
- result[num_pieces + 1] = end_code_;
- int pos = len;
- for (int i = num_pieces; i > 0; i--) {
- result[i] = segmentation[pos].piece_id + encoding_offset_;
- pos = segmentation[pos].previous_pos;
- }
- result[0] = start_code_;
- return result;
-}
-
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
diff --git a/utils/sentencepiece/encoder_test.cc b/utils/sentencepiece/encoder_test.cc
index 5697758..59c12ad 100644
--- a/utils/sentencepiece/encoder_test.cc
+++ b/utils/sentencepiece/encoder_test.cc
@@ -14,6 +14,7 @@
* limitations under the License.
*/
+#include <memory>
#include <vector>
#include "gmock/gmock.h"
@@ -32,9 +33,10 @@
const char pieces[] = "hell\0hello\0o\0there\0";
const int offsets[] = {0, 5, 11, 13};
float scores[] = {-0.5, -1.0, -10.0, -1.0};
- const Encoder<SortedStringsTable> encoder(
- SortedStringsTable(/*num_pieces=*/4, offsets, StringPiece(pieces, 18)),
- /*num_pieces=*/4, scores);
+ std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
+ /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
+ const Encoder encoder(matcher.get(),
+ /*num_pieces=*/4, scores);
EXPECT_THAT(encoder.Encode("hellothere"), ElementsAreArray({0, 3, 5, 1}));
@@ -48,9 +50,10 @@
const char pieces[] = "hell\0hello\0o\0there\0";
const int offsets[] = {0, 5, 11, 13};
float scores[] = {-0.5, -1.0, -10.0, -1.0};
- const Encoder<SortedStringsTable> encoder(
- SortedStringsTable(/*num_pieces=*/4, offsets, StringPiece(pieces, 18)),
- /*num_pieces=*/4, scores);
+ std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
+ /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
+ const Encoder encoder(matcher.get(),
+ /*num_pieces=*/4, scores);
EXPECT_THAT(encoder.Encode("hellhello"), ElementsAreArray({0, 2, 3, 1}));
EXPECT_THAT(encoder.Encode("hellohell"), ElementsAreArray({0, 3, 2, 1}));
EXPECT_THAT(encoder.Encode(""), ElementsAreArray({0, 1}));
diff --git a/utils/sentencepiece/match.h b/utils/sentencepiece/matcher.h
similarity index 63%
rename from utils/sentencepiece/match.h
rename to utils/sentencepiece/matcher.h
index c1dc475..b538d69 100644
--- a/utils/sentencepiece/match.h
+++ b/utils/sentencepiece/matcher.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCH_H_
-#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCH_H_
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
#include <vector>
#include "utils/strings/stringpiece.h"
@@ -29,6 +29,18 @@
int match_length = -1;
};
+class SentencePieceMatcher {
+ public:
+ virtual ~SentencePieceMatcher() {}
+
+ // Find matches that are prefixes of a string.
+ virtual std::vector<TrieMatch> FindAllPrefixMatches(
+ StringPiece input) const = 0;
+
+ // Find the longest prefix match of a string.
+ virtual TrieMatch LongestPrefixMatch(StringPiece input) const = 0;
+};
+
} // namespace libtextclassifier3
-#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCH_H_
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
diff --git a/utils/sentencepiece/sorted_strings_table.h b/utils/sentencepiece/sorted_strings_table.h
index da4209d..82cda5c 100644
--- a/utils/sentencepiece/sorted_strings_table.h
+++ b/utils/sentencepiece/sorted_strings_table.h
@@ -20,7 +20,7 @@
#include <functional>
#include <vector>
-#include "utils/sentencepiece/match.h"
+#include "utils/sentencepiece/matcher.h"
#include "utils/strings/stringpiece.h"
namespace libtextclassifier3 {
@@ -34,7 +34,7 @@
// pieces: String pieces, concatenated in sorted order and zero byte separated.
// use_linear_scan_threshold: Minimum size of binary search range before
// switching to a linear sweep for prefix match testing.
-class SortedStringsTable {
+class SortedStringsTable : public SentencePieceMatcher {
public:
SortedStringsTable(const int num_pieces, const int* offsets,
StringPiece pieces,
@@ -45,10 +45,10 @@
use_linear_scan_threshold_(use_linear_scan_threshold) {}
// Find matches that are prefixes of a string.
- std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const;
+ std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const override;
// Find the longest prefix match of a string.
- TrieMatch LongestPrefixMatch(StringPiece input) const;
+ TrieMatch LongestPrefixMatch(StringPiece input) const override;
private:
void GatherPrefixMatches(
diff --git a/utils/tflite/text_encoder.cc b/utils/tflite/text_encoder.cc
index 727991f..9554283 100644
--- a/utils/tflite/text_encoder.cc
+++ b/utils/tflite/text_encoder.cc
@@ -21,6 +21,7 @@
#include "utils/sentencepiece/double_array_trie.h"
#include "utils/sentencepiece/encoder.h"
#include "utils/sentencepiece/normalizer.h"
+#include "utils/sentencepiece/sorted_strings_table.h"
#include "utils/strings/stringpiece.h"
#include "utils/tflite/text_encoder.h"
#include "utils/tflite/text_encoder_config_generated.h"
@@ -35,7 +36,8 @@
struct TextEncoderOp {
std::unique_ptr<Normalizer> normalizer;
- std::unique_ptr<Encoder<DoubleArrayTrie>> encoder;
+ std::unique_ptr<Encoder> encoder;
+ std::unique_ptr<SentencePieceMatcher> matcher;
};
// Input parameters for the op.
@@ -49,12 +51,17 @@
// Output parameters for the op.
enum SmartReplyModelOutputs {
TEXT_ENCODER_OUTPUT_ENCODED = 0,
- TEXT_ENCODER_OUTPUT_LENGTHS = 1,
- TEXT_ENCODER_OUTPUT_ATTR = 2,
+ TEXT_ENCODER_OUTPUT_POSITION = 1,
+ TEXT_ENCODER_OUTPUT_LENGTHS = 2,
+ TEXT_ENCODER_OUTPUT_ATTR = 3,
};
const char kTextEncoderConfigAttr[] = "text_encoder_config";
+// Input rank is 2 since there is a dummy batch dimension of 1.
+const int kInputRank = 2;
+const int kBatchSize = 1;
+
// Initializes text encoder object from serialized options:
// The options are a flexbuffers attribute map that contain the op config
// with the key `text_encoder_config` as `TextEncoderConfig`.
@@ -81,15 +88,32 @@
config->add_dummy_prefix(), config->remove_extra_whitespaces(),
config->escape_whitespaces()));
- const TrieNode* pieces_trie_nodes =
- reinterpret_cast<const TrieNode*>(config->pieces()->Data());
- const int pieces_trie_nodes_length =
- config->pieces()->Length() / sizeof(TrieNode);
const int num_pieces = config->pieces_scores()->Length();
- encoder_op->encoder.reset(new Encoder<DoubleArrayTrie>(
- DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length), num_pieces,
- config->pieces_scores()->data(), config->start_code(), config->end_code(),
- config->encoding_offset()));
+
+ switch (config->matcher_type()) {
+ case SentencePieceMatcherType_MAPPED_TRIE: {
+ const TrieNode* pieces_trie_nodes =
+ reinterpret_cast<const TrieNode*>(config->pieces()->Data());
+ const int pieces_trie_nodes_length =
+ config->pieces()->Length() / sizeof(TrieNode);
+ encoder_op->matcher.reset(
+ new DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length));
+ break;
+ }
+ case SentencePieceMatcherType_SORTED_STRING_TABLE: {
+ encoder_op->matcher.reset(new SortedStringsTable(
+ num_pieces, config->pieces_offsets()->data(),
+ StringPiece(config->pieces()->data(), config->pieces()->Length())));
+ break;
+ }
+ default: {
+ TC3_LOG(ERROR) << "Unknown sentence piece matcher type.";
+ return nullptr;
+ }
+ }
+ encoder_op->encoder.reset(new Encoder(
+ encoder_op->matcher.get(), num_pieces, config->pieces_scores()->data(),
+ config->start_code(), config->end_code(), config->encoding_offset()));
return encoder_op.release();
}
@@ -112,8 +136,8 @@
const std::vector<int>& encoding_end_offsets,
int start_offset, TfLiteContext* context,
TfLiteTensor* out) {
- TF_LITE_ENSURE_EQ(context, in.dims->size, 2);
- TF_LITE_ENSURE_EQ(context, in.dims->data[0], 1);
+ TF_LITE_ENSURE_EQ(context, in.dims->size, kInputRank);
+ TF_LITE_ENSURE_EQ(context, in.dims->data[0], kBatchSize);
const int output_size = out->dims->data[1];
int output_offset = 0;
for (int value_index = 0;
@@ -162,7 +186,8 @@
(output_offset > 0) ? out->data.f[output_offset - 1] : 0;
std::fill(out->data.f + output_offset, out->data.f + output_size, value);
} break;
- default: {}
+ default:
+ break;
}
return kTfLiteOk;
}
@@ -173,16 +198,26 @@
context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ENCODED]];
TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, &output_encoded,
- CreateSizeArray({1, max_output_length})));
+ context,
+ context->ResizeTensor(context, &output_encoded,
+ CreateSizeArray({kBatchSize, max_output_length})));
+
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_POSITION]];
+
+ TF_LITE_ENSURE_OK(
+ context,
+ context->ResizeTensor(context, &output_positions,
+ CreateSizeArray({kBatchSize, max_output_length})));
const int num_output_attrs = node->outputs->size - TEXT_ENCODER_OUTPUT_ATTR;
for (int i = 0; i < num_output_attrs; ++i) {
TfLiteTensor& output =
context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ATTR + i]];
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(
- context, &output,
- CreateSizeArray({1, max_output_length})));
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(
+ context, &output,
+ CreateSizeArray({kBatchSize, max_output_length})));
}
return kTfLiteOk;
}
@@ -190,19 +225,22 @@
} // namespace
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- // Check that the batch dimension is 1.
+ // Check that the batch dimension is kBatchSize.
const TfLiteTensor& input_text =
context->tensors[node->inputs->data[TEXT_ENCODER_INPUT_TEXTS]];
- TF_LITE_ENSURE_EQ(context, input_text.dims->size, 2);
- TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], 1);
+ TF_LITE_ENSURE_EQ(context, input_text.dims->size, kInputRank);
+ TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kBatchSize);
TfLiteTensor& output_lengths =
context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_LENGTHS]];
TfLiteTensor& output_encoded =
context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ENCODED]];
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_POSITION]];
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, &output_lengths,
- CreateSizeArray({1})));
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, &output_lengths,
+ CreateSizeArray({kBatchSize})));
// Check that there are enough outputs for attributes.
const int num_output_attrs = node->outputs->size - TEXT_ENCODER_OUTPUT_ATTR;
@@ -225,6 +263,7 @@
return ResizeOutputTensors(context, node, output_length.data.i64[0]);
} else {
tflite::SetTensorToDynamic(&output_encoded);
+ tflite::SetTensorToDynamic(&output_positions);
for (int i = 0; i < num_output_attrs; ++i) {
TfLiteTensor& output_attr =
context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ATTR + i]];
@@ -253,10 +292,15 @@
TF_LITE_ENSURE_OK(
context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
}
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_POSITION]];
std::vector<int> encoded_total;
std::vector<int> encoded_offsets;
+ std::vector<int> encoded_positions;
encoded_offsets.reserve(num_strings);
+ const int max_output_length = output_encoded.dims->data[1];
+ const int max_encoded_position = max_output_length;
for (int i = 0; i < num_strings; ++i) {
const auto& strref = tflite::GetString(&input_text, i);
@@ -264,16 +308,20 @@
encoder_op->normalizer->Normalize(StringPiece(strref.str, strref.len)));
encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
encoded_offsets.push_back(encoded_total.size());
+ for (int k = 0; k < encoded.size(); k++) {
+ encoded_positions.push_back(std::min(k, max_encoded_position - 1));
+ }
}
- const int max_output_length = output_encoded.dims->data[1];
// Copy encoding to output tensor.
const int start_offset =
std::max(0, static_cast<int>(encoded_total.size()) - max_output_length);
int output_offset = 0;
int32_t* output_buffer = output_encoded.data.i32;
+ int32_t* output_positions_buffer = output_positions.data.i32;
for (int i = start_offset; i < encoded_total.size(); ++i, ++output_offset) {
output_buffer[output_offset] = encoded_total[i];
+ output_positions_buffer[output_offset] = encoded_positions[i];
}
// Save output encoded length.
@@ -284,6 +332,7 @@
// Do padding.
for (; output_offset < max_output_length; ++output_offset) {
output_buffer[output_offset] = encoded_total.back();
+ output_positions_buffer[output_offset] = max_encoded_position;
}
// Process attributes, all checks of sizes and types are done in Prepare.
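The new positions output records, for each emitted id, its offset within its
source string; real tokens are clamped to at most max_encoded_position - 1,
and padding slots get the sentinel value max_encoded_position. For the first
test case below (5 ids, max_output_length == 10) this yields:

  // tokens:  min(i, max_encoded_position - 1)  -> 0, 1, 2, 3, 4
  // padding: max_encoded_position              -> 10, 10, 10, 10, 10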
diff --git a/utils/tflite/text_encoder_config.fbs b/utils/tflite/text_encoder_config.fbs
index 9017116..462da21 100644
--- a/utils/tflite/text_encoder_config.fbs
+++ b/utils/tflite/text_encoder_config.fbs
@@ -17,6 +17,12 @@
// Configuration for the text encoder op.
namespace libtextclassifier3;
+
+enum SentencePieceMatcherType : byte {
+ MAPPED_TRIE = 0,
+ SORTED_STRING_TABLE = 1,
+}
+
table TextEncoderConfig {
// Code that is used as encoding of the start of input.
start_code:int32 = 0;
@@ -46,6 +52,8 @@
// Sentence pieces scores.
pieces_scores:[float];
- // Serialized sentence pieces trie.
+ // Serialized sentence pieces.
pieces:string;
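+ // Start offsets of the individual pieces within |pieces|; used by the
+ // SORTED_STRING_TABLE matcher.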
+ pieces_offsets:[int32];
+ matcher_type: SentencePieceMatcherType = MAPPED_TRIE;
}
diff --git a/utils/tflite/text_encoder_test.cc b/utils/tflite/text_encoder_test.cc
index d1892c7..0b6ff71 100644
--- a/utils/tflite/text_encoder_test.cc
+++ b/utils/tflite/text_encoder_test.cc
@@ -55,6 +55,9 @@
std::vector<int> GetOutputEncoding() {
return ExtractVector<int>(output_encoding_);
}
+ std::vector<int> GetOutputPositions() {
+ return ExtractVector<int>(output_positions_);
+ }
std::vector<int> GetOutputAttributeInt32() {
return ExtractVector<int>(output_attributes_int32_);
}
@@ -71,6 +74,7 @@
int input_attributes_float_;
int output_encoding_;
+ int output_positions_;
int output_length_;
int output_attributes_int32_;
int output_attributes_float_;
@@ -86,6 +90,7 @@
input_attributes_float_ = AddInput(tflite::TensorType_FLOAT32);
output_encoding_ = AddOutput(tflite::TensorType_INT32);
+ output_positions_ = AddOutput(tflite::TensorType_INT32);
output_length_ = AddOutput(tflite::TensorType_INT32);
output_attributes_int32_ = AddOutput(tflite::TensorType_INT32);
output_attributes_float_ = AddOutput(tflite::TensorType_FLOAT32);
@@ -113,6 +118,8 @@
EXPECT_EQ(m.GetEncodedLength(), 5);
EXPECT_THAT(m.GetOutputEncoding(),
testing::ElementsAre(1, 90, 547, 58, 2, 2, 2, 2, 2, 2));
+ EXPECT_THAT(m.GetOutputPositions(),
+ testing::ElementsAre(0, 1, 2, 3, 4, 10, 10, 10, 10, 10));
EXPECT_THAT(m.GetOutputAttributeInt32(),
testing::ElementsAre(7, 7, 7, 7, 7, 7, 7, 7, 7, 7));
EXPECT_THAT(
@@ -130,6 +137,8 @@
EXPECT_EQ(m.GetEncodedLength(), 10);
EXPECT_THAT(m.GetOutputEncoding(),
testing::ElementsAre(547, 58, 2, 1, 862, 2, 1, 1919, 19, 2));
+ EXPECT_THAT(m.GetOutputPositions(),
+ testing::ElementsAre(2, 3, 4, 0, 1, 2, 0, 1, 2, 3));
EXPECT_THAT(m.GetOutputAttributeInt32(),
testing::ElementsAre(1, 1, 1, 2, 2, 2, 3, 3, 3, 3));
EXPECT_THAT(
@@ -147,6 +156,8 @@
EXPECT_EQ(m.GetEncodedLength(), 9);
EXPECT_THAT(m.GetOutputEncoding(),
testing::ElementsAre(862, 2, 1, 1919, 19, 2, 1, 862, 2));
+ EXPECT_THAT(m.GetOutputPositions(),
+ testing::ElementsAre(1, 2, 0, 1, 2, 3, 0, 1, 2));
EXPECT_THAT(m.GetOutputAttributeInt32(),
testing::ElementsAre(2, 2, 3, 3, 3, 3, 4, 4, 4));
EXPECT_THAT(
diff --git a/utils/utf8/unicodetext.cc b/utils/utf8/unicodetext.cc
index 057703a..81492d8 100644
--- a/utils/utf8/unicodetext.cc
+++ b/utils/utf8/unicodetext.cc
@@ -176,7 +176,7 @@
} // namespace
-UnicodeText& UnicodeText::AppendCodepoint(char32 ch) {
+UnicodeText& UnicodeText::push_back(char32 ch) {
char str[4];
int char_len = runetochar(ch, str);
repr_.append(str, char_len);
diff --git a/utils/utf8/unicodetext.h b/utils/utf8/unicodetext.h
index f7790f5..eb206b8 100644
--- a/utils/utf8/unicodetext.h
+++ b/utils/utf8/unicodetext.h
@@ -168,7 +168,7 @@
// Calling this may invalidate pointers to underlying data.
UnicodeText& AppendUTF8(const char* utf8, int len);
- UnicodeText& AppendCodepoint(char32 ch);
+ UnicodeText& push_back(char32 ch);
void clear();
std::string ToUTF8String() const;
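The STL-style name makes call sites read like standard container code; a
minimal sketch:

  UnicodeText text;
  text.push_back(0x4E8C);    // appends one codepoint, re-encoded
  text.push_back(0x1D11E);   // as UTF-8 in the underlying string
  const std::string utf8 = text.ToUTF8String();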
diff --git a/utils/utf8/unicodetext_test.cc b/utils/utf8/unicodetext_test.cc
index 9cdc850..7ebb415 100644
--- a/utils/utf8/unicodetext_test.cc
+++ b/utils/utf8/unicodetext_test.cc
@@ -24,11 +24,11 @@
class UnicodeTextTest : public testing::Test {
protected:
UnicodeTextTest() : empty_text_() {
- text_.AppendCodepoint(0x1C0);
- text_.AppendCodepoint(0x4E8C);
- text_.AppendCodepoint(0xD7DB);
- text_.AppendCodepoint(0x34);
- text_.AppendCodepoint(0x1D11E);
+ text_.push_back(0x1C0);
+ text_.push_back(0x4E8C);
+ text_.push_back(0xD7DB);
+ text_.push_back(0x34);
+ text_.push_back(0x1D11E);
}
UnicodeText empty_text_;