Export libtextclassifier to Android
Test: atest android.view.textclassifier.TextClassificationManagerTest
Change-Id: Id7a31dc60c8f6625ff8f2a9c85689e13b121a5a4
diff --git a/Android.mk b/Android.mk
index acc01b3..c89e0b9 100644
--- a/Android.mk
+++ b/Android.mk
@@ -85,8 +85,8 @@
LOCAL_ADDITIONAL_DEPENDENCIES += $(LOCAL_PATH)/jni.lds
LOCAL_LDFLAGS += -Wl,-version-script=$(LOCAL_PATH)/jni.lds
-LOCAL_CPPFLAGS_32 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\""
-LOCAL_CPPFLAGS_64 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\""
+LOCAL_CPPFLAGS_32 += -DTC3_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\""
+LOCAL_CPPFLAGS_64 += -DTC3_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\""
include $(BUILD_SHARED_LIBRARY)
@@ -109,8 +109,8 @@
LOCAL_TEST_DATA := $(call find-test-data-in-subdirs, $(LOCAL_PATH), *, annotator/test_data, actions/test_data)
-LOCAL_CPPFLAGS_32 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\""
-LOCAL_CPPFLAGS_64 += -DLIBTEXTCLASSIFIER_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\""
+LOCAL_CPPFLAGS_32 += -DTC3_TEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/test_data/\""
+LOCAL_CPPFLAGS_64 += -DTC3_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\""
# TODO: Do not filter out tflite test once the dependency issue is resolved.
LOCAL_SRC_FILES := $(filter-out utils/tflite/%_test.cc,$(call all-subdir-cpp-files))
diff --git a/actions/actions-suggestions.cc b/actions/actions-suggestions.cc
index eacf991..7b6e6e2 100644
--- a/actions/actions-suggestions.cc
+++ b/actions/actions-suggestions.cc
@@ -90,6 +90,10 @@
return FromScopedMmap(std::move(mmap));
}
+void ActionsSuggestions::SetAnnotator(const Annotator* annotator) {
+ annotator_ = annotator;  // Raw pointer is only stored here; nullptr disables on-demand annotation.
+}
+
bool ActionsSuggestions::ValidateAndInitialize() {
if (model_ == nullptr) {
TC3_LOG(ERROR) << "No model specified.";
@@ -110,7 +114,8 @@
void ActionsSuggestions::SetupModelInput(
const std::vector<std::string>& context, const std::vector<int>& user_ids,
- const int num_suggestions, tflite::Interpreter* interpreter) const {
+ const std::vector<float>& time_diffs, const int num_suggestions,
+ tflite::Interpreter* interpreter) const {
if (model_->tflite_model_spec()->input_context() >= 0) {
model_executor_->SetInput<std::string>(
model_->tflite_model_spec()->input_context(), context, interpreter);
@@ -121,18 +126,21 @@
->input_context_length()])
->data.i64 = context.size();
}
-
if (model_->tflite_model_spec()->input_user_id() >= 0) {
model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
user_ids, interpreter);
}
-
if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
*interpreter
->tensor(interpreter->inputs()[model_->tflite_model_spec()
->input_num_suggestions()])
->data.i64 = num_suggestions;
}
+ if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
+ model_executor_->SetInput<float>(
+ model_->tflite_model_spec()->input_time_diffs(), time_diffs,
+ interpreter);
+ }
}
bool ActionsSuggestions::ShouldSuppressPredictions(
@@ -220,9 +228,31 @@
return;
}
- // Use only last message for now.
- SetupModelInput({conversation.messages.back().text},
- {conversation.messages.back().user_id},
+ int num_messages = conversation.messages.size();
+ if (model_->max_conversation_history_length() >= 0 &&
+ num_messages > model_->max_conversation_history_length()) {
+ num_messages = model_->max_conversation_history_length();
+ }
+
+ if (num_messages <= 0) {
+ TC3_LOG(INFO) << "No messages provided for actions suggestions.";
+ return;
+ }
+
+ std::vector<std::string> context;
+ std::vector<int> user_ids;
+ std::vector<float> time_diffs;
+
+ // Gather the last `num_messages` messages from the conversation.
+ for (int i = conversation.messages.size() - num_messages;
+ i < conversation.messages.size(); i++) {
+ const ConversationMessage& message = conversation.messages[i];
+ context.push_back(message.text);
+ user_ids.push_back(message.user_id);
+ time_diffs.push_back(message.time_diff_secs);
+ }
+
+ SetupModelInput(context, user_ids, time_diffs,
/*num_suggestions=*/model_->num_smart_replies(),
interpreter.get());
@@ -239,12 +269,17 @@
}
void ActionsSuggestions::SuggestActionsFromAnnotations(
- const Conversation& conversation,
+ const Conversation& conversation, const ActionSuggestionOptions& options,
std::vector<ActionSuggestion>* suggestions) const {
// Create actions based on the annotations present in the last message.
// TODO(smillius): Make this configurable.
- for (const AnnotatedSpan& annotation :
- conversation.messages.back().annotations) {
+ std::vector<AnnotatedSpan> annotations =
+ conversation.messages.back().annotations;
+ if (annotations.empty() && annotator_ != nullptr) {
+ annotations = annotator_->Annotate(conversation.messages.back().text,
+ options.annotation_options);
+ }
+ for (const AnnotatedSpan& annotation : annotations) {
if (annotation.classification.empty() ||
annotation.classification[0].collection.empty()) {
continue;
@@ -266,7 +301,7 @@
}
SuggestActionsFromModel(conversation, &suggestions);
- SuggestActionsFromAnnotations(conversation, &suggestions);
+ SuggestActionsFromAnnotations(conversation, options, &suggestions);
// TODO(smillius): Properly rank the actions.
diff --git a/actions/actions-suggestions.h b/actions/actions-suggestions.h
index 187ff9e..67ef9cc 100644
--- a/actions/actions-suggestions.h
+++ b/actions/actions-suggestions.h
@@ -22,6 +22,7 @@
#include <vector>
#include "actions/actions_model_generated.h"
+#include "annotator/annotator.h"
#include "annotator/types.h"
#include "utils/memory/mmap.h"
#include "utils/tflite-model-executor.h"
@@ -44,6 +45,8 @@
int user_id;
// Text of the message.
std::string text;
+ // Relative time to previous message.
+ float time_diff_secs;
// Annotations on the text.
std::vector<AnnotatedSpan> annotations;
};
@@ -55,7 +58,12 @@
};
// Options for suggesting actions.
-struct ActionSuggestionOptions {};
+struct ActionSuggestionOptions {
+ // Options passed to the annotator when a message has to be annotated on demand.
+ AnnotationOptions annotation_options = AnnotationOptions::Default();
+ // Returns options with every member left at its in-class default.
+ static ActionSuggestionOptions Default() { return ActionSuggestionOptions(); }
+};
// Class for predicting actions following a conversation.
class ActionsSuggestions {
@@ -72,7 +80,11 @@
std::vector<ActionSuggestion> SuggestActions(
const Conversation& conversation,
- const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
+ const ActionSuggestionOptions& options =
+ ActionSuggestionOptions::Default()) const;
+
+ // Provides the annotator used to annotate the last message when it carries no annotations.
+ void SetAnnotator(const Annotator* annotator);
private:
// Checks that model contains all required fields, and initializes internal
@@ -81,6 +93,7 @@
void SetupModelInput(const std::vector<std::string>& context,
const std::vector<int>& user_ids,
+ const std::vector<float>& time_diffs,
const int num_suggestions,
tflite::Interpreter* interpreter) const;
void ReadModelOutput(tflite::Interpreter* interpreter,
@@ -91,7 +104,7 @@
std::vector<ActionSuggestion>* suggestions) const;
void SuggestActionsFromAnnotations(
- const Conversation& conversation,
+ const Conversation& conversation, const ActionSuggestionOptions& options,
std::vector<ActionSuggestion>* suggestions) const;
// Check whether we shouldn't produce any predictions.
@@ -102,6 +115,9 @@
// Tensorflow Lite models.
std::unique_ptr<const TfLiteModelExecutor> model_executor_;
+
+ // Annotator.
+ const Annotator* annotator_ = nullptr;
};
// Interprets the buffer as a Model flatbuffer and returns it for reading.
diff --git a/actions/actions-suggestions_test.cc b/actions/actions-suggestions_test.cc
index 5696b25..5f97af6 100644
--- a/actions/actions-suggestions_test.cc
+++ b/actions/actions-suggestions_test.cc
@@ -63,6 +63,7 @@
annotation.classification = {ClassificationResult("address", 1.0)};
const std::vector<ActionSuggestion>& actions =
actions_suggestions->SuggestActions({{{/*user_id=*/1, "are you at home?",
+ /*time_diff_secs=*/0,
/*annotations=*/{annotation}}}});
EXPECT_EQ(actions.size(), 7);
EXPECT_EQ(actions.back().type, "address");
@@ -102,5 +103,35 @@
});
}
+TEST(ActionsSuggestionsTest, SuggestActionsWithLongerConversation) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+
+ // Raise the history limit to 10 so both messages below fit in the model context.
+ actions_model->max_conversation_history_length = 10;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize());
+ AnnotatedSpan annotation;
+ annotation.span = {17, 21};  // "home" in "good! are you at home?".
+ annotation.classification = {ClassificationResult("address", 1.0)};
+ const std::vector<ActionSuggestion>& actions =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/0, "hi, how are you?", /*time_diff_secs=*/0},
+ {/*user_id=*/1, "good! are you at home?",
+ /*time_diff_secs=*/60,
+ /*annotations=*/{annotation}}}});
+ EXPECT_EQ(actions.size(), 7);
+ EXPECT_EQ(actions.back().type, "address");
+ EXPECT_EQ(actions.back().score, 1.0);
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/actions/actions_jni.cc b/actions/actions_jni.cc
index a789883..33710cd 100644
--- a/actions/actions_jni.cc
+++ b/actions/actions_jni.cc
@@ -23,6 +23,7 @@
#include <vector>
#include "actions/actions-suggestions.h"
+#include "annotator/annotator.h"
#include "utils/base/integral_types.h"
#include "utils/java/scoped_local_ref.h"
#include "utils/memory/mmap.h"
@@ -30,6 +31,7 @@
using libtextclassifier3::ActionsSuggestions;
using libtextclassifier3::ActionSuggestion;
using libtextclassifier3::ActionSuggestionOptions;
+using libtextclassifier3::Annotator;
using libtextclassifier3::Conversation;
using libtextclassifier3::ScopedLocalRef;
using libtextclassifier3::ToStlString;
@@ -234,3 +236,13 @@
new libtextclassifier3::ScopedMmap(fd));
return libtextclassifier3::GetVersionFromMmap(env, mmap.get());
}
+
+TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeSetAnnotator)
+(JNIEnv* env, jobject clazz, jlong ptr, jlong annotatorPtr) {
+ // Both jlong arguments carry native addresses; a zero model handle is a no-op.
+ if (ptr == 0) return;
+ ActionsSuggestions* const actions = reinterpret_cast<ActionsSuggestions*>(ptr);
+ // annotatorPtr is forwarded unchecked, so a zero handle passes nullptr on.
+ Annotator* const annotator_handle = reinterpret_cast<Annotator*>(annotatorPtr);
+ actions->SetAnnotator(annotator_handle);
+}
diff --git a/actions/actions_model.fbs b/actions/actions_model.fbs
index f0e8b3d..fb6565d 100755
--- a/actions/actions_model.fbs
+++ b/actions/actions_model.fbs
@@ -81,6 +81,9 @@
// Default number of smart reply predictions.
num_smart_replies:int = 3;
+
+ // Length of message history to consider, -1 if unbounded.
+ max_conversation_history_length:int = 1;
}
root_type libtextclassifier3.ActionsModel;
diff --git a/actions/test_data/actions_suggestions_test.model b/actions/test_data/actions_suggestions_test.model
index 893dd84..4809a7b 100644
--- a/actions/test_data/actions_suggestions_test.model
+++ b/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/annotator/annotator.cc b/annotator/annotator.cc
index 562d58e..a8a7d8b 100644
--- a/annotator/annotator.cc
+++ b/annotator/annotator.cc
@@ -111,6 +111,7 @@
return classifier;
}
+
std::unique_ptr<Annotator> Annotator::FromScopedMmap(
std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
const CalendarLib* calendarlib) {
diff --git a/annotator/annotator_jni.cc b/annotator/annotator_jni.cc
index 57580fa..6907398 100644
--- a/annotator/annotator_jni.cc
+++ b/annotator/annotator_jni.cc
@@ -30,8 +30,8 @@
#include "utils/memory/mmap.h"
#include "utils/utf8/unilib.h"
-#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
-#ifndef LIBTEXTCLASSIFIER_CALENDAR_JAVAICU
+#ifdef TC3_UNILIB_JAVAICU
+#ifndef TC3_CALENDAR_JAVAICU
#error Inconsistent usage of Java ICU components
#else
#define TC3_USE_JAVAICU
diff --git a/annotator/annotator_test.cc b/annotator/annotator_test.cc
index b6290d5..fbaf039 100644
--- a/annotator/annotator_test.cc
+++ b/annotator/annotator_test.cc
@@ -52,7 +52,7 @@
}
std::string GetModelPath() {
- return LIBTEXTCLASSIFIER_TEST_DATA_DIR;
+ return TC3_TEST_DATA_DIR;
}
class AnnotatorTest : public ::testing::TestWithParam<const char*> {
@@ -205,7 +205,7 @@
return result;
}
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_P(AnnotatorTest, ClassifyTextRegularExpression) {
const std::string test_model = ReadFile(GetModelPath() + GetParam());
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
@@ -266,9 +266,9 @@
"www.google.com every today!|Call me at (800) 123-456 today.",
{51, 65})));
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_P(AnnotatorTest, SuggestSelectionRegularExpression) {
const std::string test_model = ReadFile(GetModelPath() + GetParam());
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
@@ -308,9 +308,9 @@
EXPECT_EQ(classifier->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}),
std::make_pair(4, 23));
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) {
const std::string test_model = ReadFile(GetModelPath() + GetParam());
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
@@ -339,9 +339,9 @@
{55, 57}),
std::make_pair(26, 62));
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) {
const std::string test_model = ReadFile(GetModelPath() + GetParam());
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
@@ -370,9 +370,9 @@
{55, 57}),
std::make_pair(55, 62));
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_P(AnnotatorTest, AnnotateRegex) {
const std::string test_model = ReadFile(GetModelPath() + GetParam());
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
@@ -409,7 +409,7 @@
IsAnnotatedSpan(79, 91, "phone"),
IsAnnotatedSpan(107, 126, "payment_card")}));
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
TEST_P(AnnotatorTest, PhoneFiltering) {
std::unique_ptr<Annotator> classifier =
@@ -739,6 +739,7 @@
.empty());
}
+
TEST_P(AnnotatorTest, AnnotateSmallBatches) {
const std::string test_model = ReadFile(GetModelPath() + GetParam());
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
@@ -768,7 +769,7 @@
EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
}
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_P(AnnotatorTest, AnnotateFilteringDiscardAll) {
const std::string test_model = ReadFile(GetModelPath() + GetParam());
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
@@ -791,7 +792,7 @@
EXPECT_EQ(classifier->Annotate(test_string).size(), 0);
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
TEST_P(AnnotatorTest, AnnotateFilteringKeepAll) {
const std::string test_model = ReadFile(GetModelPath() + GetParam());
@@ -873,7 +874,7 @@
}));
}
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_P(AnnotatorTest, AnnotateFilteredCollectionsSuppress) {
const std::string test_model = ReadFile(GetModelPath() + GetParam());
@@ -917,9 +918,9 @@
IsAnnotatedSpan(28, 55, "address"),
}));
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
-#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
+#ifdef TC3_CALENDAR_ICU
TEST_P(AnnotatorTest, ClassifyTextDate) {
std::unique_ptr<Annotator> classifier =
Annotator::FromPath(GetModelPath() + GetParam());
@@ -968,9 +969,9 @@
EXPECT_EQ(result[0].datetime_parse_result.granularity,
DatetimeGranularity::GRANULARITY_DAY);
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_CALENDAR_ICU
-#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
+#ifdef TC3_CALENDAR_ICU
TEST_P(AnnotatorTest, ClassifyTextDatePriorities) {
std::unique_ptr<Annotator> classifier =
Annotator::FromPath(GetModelPath() + GetParam());
@@ -1001,9 +1002,9 @@
EXPECT_EQ(result[0].datetime_parse_result.granularity,
DatetimeGranularity::GRANULARITY_DAY);
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_CALENDAR_ICU
-#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
+#ifdef TC3_CALENDAR_ICU
TEST_P(AnnotatorTest, SuggestTextDateDisabled) {
const std::string test_model = ReadFile(GetModelPath() + GetParam());
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
@@ -1027,7 +1028,7 @@
EXPECT_THAT(classifier->Annotate("january 1, 2017"),
ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_CALENDAR_ICU
class TestingAnnotator : public Annotator {
public:
@@ -1123,7 +1124,7 @@
EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
}
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_P(AnnotatorTest, LongInput) {
std::unique_ptr<Annotator> classifier =
Annotator::FromPath(GetModelPath() + GetParam(), &unilib_, &calendarlib_);
@@ -1152,9 +1153,9 @@
input_100k, {50000, 50000 + value_length})));
}
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
// These coarse tests are there only to make sure the execution happens in
// reasonable amount of time.
TEST_P(AnnotatorTest, LongInputNoResultCheck) {
@@ -1173,9 +1174,9 @@
classifier->ClassifyText(input_100k, {50000, 50000 + value_length});
}
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_P(AnnotatorTest, MaxTokenLength) {
const std::string test_model = ReadFile(GetModelPath() + GetParam());
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
@@ -1210,9 +1211,9 @@
"I live at 350 Third Street, Cambridge.", {10, 37})),
"other");
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_P(AnnotatorTest, MinAddressTokenLength) {
const std::string test_model = ReadFile(GetModelPath() + GetParam());
std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
@@ -1247,7 +1248,7 @@
"I live at 350 Third Street, Cambridge.", {10, 37})),
"other");
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
} // namespace
} // namespace libtextclassifier3
diff --git a/annotator/datetime/parser_test.cc b/annotator/datetime/parser_test.cc
index efe7306..d46accf 100644
--- a/annotator/datetime/parser_test.cc
+++ b/annotator/datetime/parser_test.cc
@@ -34,7 +34,7 @@
namespace {
std::string GetModelPath() {
- return LIBTEXTCLASSIFIER_TEST_DATA_DIR;
+ return TC3_TEST_DATA_DIR;
}
std::string ReadFile(const std::string& file_name) {
diff --git a/annotator/feature-processor_test.cc b/annotator/feature-processor_test.cc
index 1788906..c9f0e0d 100644
--- a/annotator/feature-processor_test.cc
+++ b/annotator/feature-processor_test.cc
@@ -867,7 +867,7 @@
Token("웹사이트", 7, 11)}));
}
-#ifdef LIBTEXTCLASSIFIER_TEST_ICU
+#ifdef TC3_TEST_ICU
TEST_F(FeatureProcessorTest, ICUTokenize) {
FeatureProcessorOptionsT options;
options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU;
@@ -889,7 +889,7 @@
}
#endif
-#ifdef LIBTEXTCLASSIFIER_TEST_ICU
+#ifdef TC3_TEST_ICU
TEST_F(FeatureProcessorTest, ICUTokenizeWithWhitespaces) {
FeatureProcessorOptionsT options;
options.tokenization_type = FeatureProcessorOptions_::TokenizationType_ICU;
@@ -917,7 +917,7 @@
}
#endif
-#ifdef LIBTEXTCLASSIFIER_TEST_ICU
+#ifdef TC3_TEST_ICU
TEST_F(FeatureProcessorTest, MixedTokenize) {
FeatureProcessorOptionsT options;
options.tokenization_type = FeatureProcessorOptions_::TokenizationType_MIXED;
diff --git a/annotator/model.fbs b/annotator/model.fbs
index a3d26f8..1f0a292 100755
--- a/annotator/model.fbs
+++ b/annotator/model.fbs
@@ -14,6 +14,8 @@
// limitations under the License.
//
+include "utils/intents/intent-config.fbs";
+
file_identifier "TC2 ";
// The possible model modes, represents a bit field.
@@ -129,59 +131,6 @@
GROUP_DUMMY2 = 13,
}
-// The type of variable to fetch.
-namespace libtextclassifier3;
-enum AndroidSimpleIntentGeneratorVariableType : int {
- INVALID_VARIABLE = 0,
-
- // The raw text that was classified.
- RAW_TEXT = 1,
-
- // Text as a URL with explicit protocol. If no protocol was specified, http
- // is prepended.
- URL_TEXT = 2,
-
- // The raw text, but URL encoded.
- URL_ENCODED_TEXT = 3,
-
- // For dates/times: the instant of the event in UTC millis.
- EVENT_TIME_MS_UTC = 4,
-
- // For dates/times: the start of the event in UTC millis.
- EVENT_START_MS_UTC = 5,
-
- // For dates/times: the end of the event in UTC millis.
- EVENT_END_MS_UTC = 6,
-
- // Name of the package that's running the classifier.
- PACKAGE_NAME = 7,
-}
-
-// Enumerates the possible extra types for the simple intent generator.
-namespace libtextclassifier3;
-enum AndroidSimpleIntentGeneratorExtraType : int {
- INVALID_EXTRA_TYPE = 0,
- STRING = 1,
- BOOL = 2,
- VARIABLE_AS_LONG = 3,
-}
-
-// Enumerates the possible condition types for the simple intent generator.
-namespace libtextclassifier3;
-enum AndroidSimpleIntentGeneratorConditionType : int {
- INVALID_CONDITION_TYPE = 0,
-
- // Queries the UserManager for the given boolean restriction. The condition
- // passes if the result is of getBoolean is false. The name of the
- // restriction to check is in the string_ field.
- USER_RESTRICTION_NOT_SET = 1,
-
- // Checks that the parsed event start time is at least a give number of
- // milliseconds in the future. (Only valid if there is a parsed event
- // time) The offset is stored in the int64_ field.
- EVENT_START_IN_FUTURE_MS = 2,
-}
-
namespace libtextclassifier3;
table CompressedBuffer {
buffer:[ubyte];
@@ -633,109 +582,4 @@
tokenize_on_script_change:bool = false;
}
-// Describes how intents for the various entity types should be generated on
-// Android. This is distributed through the model, but not used by
-// libtextclassifier yet - rather, it's passed to the calling Java code, which
-// implements the Intent generation logic.
-namespace libtextclassifier3;
-table AndroidIntentFactoryOptions {
- entity:[libtextclassifier3.AndroidIntentFactoryEntityOptions];
-}
-
-// Describes how intents should be generated for a particular entity type.
-namespace libtextclassifier3;
-table AndroidIntentFactoryEntityOptions {
- // The entity type as defined by one of the TextClassifier ENTITY_TYPE
- // constants. (e.g. "address", "phone", etc.)
- entity_type:string;
-
- // List of generators for all the different types of intents that should
- // be made available for the entity type.
- generator:[libtextclassifier3.AndroidIntentGeneratorOptions];
-}
-
-// Configures a single Android Intent generator.
-namespace libtextclassifier3;
-table AndroidIntentGeneratorOptions {
- // Strings for UI elements.
- strings:[libtextclassifier3.AndroidIntentGeneratorStrings];
-
- // Generator specific configuration.
- simple:libtextclassifier3.AndroidSimpleIntentGeneratorOptions;
-}
-
-// Language dependent configuration for an Android Intent generator.
-namespace libtextclassifier3;
-table AndroidIntentGeneratorStrings {
- // BCP 47 tag for the supported locale. Note that because of API level
- // restrictions, this must /not/ use wildcards. To e.g. match all English
- // locales, use only "en" and not "en_*". Reference the java.util.Locale
- // constructor for details.
- language_tag:string;
-
- // Title shown for the action (see RemoteAction.getTitle).
- title:string;
-
- // Description shown for the action (see
- // RemoteAction.getContentDescription).
- description:string;
-}
-
-// An extra to set on a simple intent generator Intent.
-namespace libtextclassifier3;
-table AndroidSimpleIntentGeneratorExtra {
- // The name of the extra to set.
- name:string;
-
- // The type of the extra to set.
- type:libtextclassifier3.AndroidSimpleIntentGeneratorExtraType;
-
- string_:string;
-
- bool_:bool;
- int32_:int;
-}
-
-// A condition that needs to be fulfilled for an Intent to get generated.
-namespace libtextclassifier3;
-table AndroidSimpleIntentGeneratorCondition {
- type:libtextclassifier3.AndroidSimpleIntentGeneratorConditionType;
-
- string_:string;
-
- int32_:int;
- int64_:long;
-}
-
-// Configures an intent generator where the logic is simple to be expressed with
-// basic rules - which covers the vast majority of use cases and is analogous
-// to Android Actions.
-// Most strings (action, data, type, ...) may contain variable references. To
-// use them, the generator must first declare all the variables it wishes to use
-// in the variables field. The values then become available as numbered
-// arguments (using the normal java.util.Formatter syntax) in the order they
-// were specified.
-namespace libtextclassifier3;
-table AndroidSimpleIntentGeneratorOptions {
- // The action to set on the Intent (see Intent.setAction). Supports variables.
- action:string;
-
- // The data to set on the Intent (see Intent.setData). Supports variables.
- data:string;
-
- // The type to set on the Intent (see Intent.setType). Supports variables.
- type:string;
-
- // The list of all the extras to add to the Intent.
- extra:[libtextclassifier3.AndroidSimpleIntentGeneratorExtra];
-
- // The list of all the variables that become available for substitution in
- // the action, data, type and extra strings. To e.g. set a field to the value
- // of the first variable, use "%0$s".
- variable:[libtextclassifier3.AndroidSimpleIntentGeneratorVariableType];
-
- // The list of all conditions that need to be fulfilled for Intent generation.
- condition:[libtextclassifier3.AndroidSimpleIntentGeneratorCondition];
-}
-
root_type libtextclassifier3.Model;
diff --git a/annotator/token-feature-extractor_test.cc b/annotator/token-feature-extractor_test.cc
index d669129..32383a9 100644
--- a/annotator/token-feature-extractor_test.cc
+++ b/annotator/token-feature-extractor_test.cc
@@ -233,7 +233,7 @@
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
}
-#ifdef LIBTEXTCLASSIFIER_TEST_ICU
+#ifdef TC3_TEST_ICU
TEST_F(TokenFeatureExtractorTest, ICUCaseFeature) {
TokenFeatureExtractorOptions options;
options.num_buckets = 1000;
@@ -340,7 +340,7 @@
EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
}
-#ifdef LIBTEXTCLASSIFIER_TEST_ICU
+#ifdef TC3_TEST_ICU
TEST_F(TokenFeatureExtractorTest, LowercaseUnicode) {
TokenFeatureExtractorOptions options;
options.num_buckets = 1000;
@@ -360,7 +360,7 @@
}
#endif
-#ifdef LIBTEXTCLASSIFIER_TEST_ICU
+#ifdef TC3_TEST_ICU
TEST_F(TokenFeatureExtractorTest, RegexFeatures) {
TokenFeatureExtractorOptions options;
options.num_buckets = 1000;
diff --git a/generate_flatbuffers.mk b/generate_flatbuffers.mk
index 963dfef..ca24807 100644
--- a/generate_flatbuffers.mk
+++ b/generate_flatbuffers.mk
@@ -4,22 +4,34 @@
@echo "Flatc: $@ <= $(PRIVATE_INPUT_FBS)"
@rm -f $@
@mkdir -p $(dir $@)
-$(hide) $(FLATC) \
+$(FLATC) \
--cpp \
--no-union-value-namespacing \
--gen-object-api \
+ --keep-prefix \
+ -I $(INPUT_DIR) \
-o $(dir $@) \
$(PRIVATE_INPUT_FBS) \
|| exit 33
-$(hide) [ -f $@ ] || exit 33
+[ -f $@ ] || exit 33
endef
intermediates := $(call local-generated-sources-dir)
+# Generate utils/intents/intent-config_generated.h using FlatBuffer schema compiler.
+INTENT_CONFIG_FBS := $(LOCAL_PATH)/utils/intents/intent-config.fbs
+INTENT_CONFIG_H := $(intermediates)/utils/intents/intent-config_generated.h
+$(INTENT_CONFIG_H): PRIVATE_INPUT_FBS := $(INTENT_CONFIG_FBS)
+$(INTENT_CONFIG_H): INPUT_DIR := $(LOCAL_PATH)
+$(INTENT_CONFIG_H): $(FLATC)
+ $(transform-fbs-to-cpp)
+LOCAL_GENERATED_SOURCES += $(INTENT_CONFIG_H)
+
# Generate annotator/model_generated.h using FlatBuffer schema compiler.
ANNOTATOR_MODEL_FBS := $(LOCAL_PATH)/annotator/model.fbs
ANNOTATOR_MODEL_H := $(intermediates)/annotator/model_generated.h
$(ANNOTATOR_MODEL_H): PRIVATE_INPUT_FBS := $(ANNOTATOR_MODEL_FBS)
+$(ANNOTATOR_MODEL_H): INPUT_DIR := $(LOCAL_PATH)
$(ANNOTATOR_MODEL_H): $(FLATC)
$(transform-fbs-to-cpp)
LOCAL_GENERATED_SOURCES += $(ANNOTATOR_MODEL_H)
@@ -28,6 +40,7 @@
ACTIONS_MODEL_FBS := $(LOCAL_PATH)/actions/actions_model.fbs
ACTIONS_MODEL_H := $(intermediates)/actions/actions_model_generated.h
$(ACTIONS_MODEL_H): PRIVATE_INPUT_FBS := $(ACTIONS_MODEL_FBS)
+$(ACTIONS_MODEL_H): INPUT_DIR := $(LOCAL_PATH)
$(ACTIONS_MODEL_H): $(FLATC)
$(transform-fbs-to-cpp)
LOCAL_GENERATED_SOURCES += $(ACTIONS_MODEL_H)
@@ -36,6 +49,7 @@
UTILS_TFLITE_TEXT_ENCODER_CONFIG_FBS := $(LOCAL_PATH)/utils/tflite/text_encoder_config.fbs
UTILS_TFLITE_TEXT_ENCODER_CONFIG_H := $(intermediates)/utils/tflite/text_encoder_config_generated.h
$(UTILS_TFLITE_TEXT_ENCODER_CONFIG_H): PRIVATE_INPUT_FBS := $(UTILS_TFLITE_TEXT_ENCODER_CONFIG_FBS)
+$(UTILS_TFLITE_TEXT_ENCODER_CONFIG_H): INPUT_DIR := $(LOCAL_PATH)
$(UTILS_TFLITE_TEXT_ENCODER_CONFIG_H): $(FLATC)
$(transform-fbs-to-cpp)
LOCAL_GENERATED_SOURCES += $(UTILS_TFLITE_TEXT_ENCODER_CONFIG_H)
@@ -44,6 +58,7 @@
LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_FBS := $(LOCAL_PATH)/lang_id/common/flatbuffers/embedding-network.fbs
LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_H := $(intermediates)/lang_id/common/flatbuffers/embedding-network_generated.h
$(LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_H): PRIVATE_INPUT_FBS := $(LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_FBS)
+$(LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_H): INPUT_DIR := $(LOCAL_PATH)
$(LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_H): $(FLATC)
$(transform-fbs-to-cpp)
LOCAL_GENERATED_SOURCES += $(LANG_ID_COMMON_FLATBUFFERS_EMBEDDING_NETWORK_H)
@@ -52,7 +67,7 @@
LANG_ID_COMMON_FLATBUFFERS_MODEL_FBS := $(LOCAL_PATH)/lang_id/common/flatbuffers/model.fbs
LANG_ID_COMMON_FLATBUFFERS_MODEL_H := $(intermediates)/lang_id/common/flatbuffers/model_generated.h
$(LANG_ID_COMMON_FLATBUFFERS_MODEL_H): PRIVATE_INPUT_FBS := $(LANG_ID_COMMON_FLATBUFFERS_MODEL_FBS)
+$(LANG_ID_COMMON_FLATBUFFERS_MODEL_H): INPUT_DIR := $(LOCAL_PATH)
$(LANG_ID_COMMON_FLATBUFFERS_MODEL_H): $(FLATC)
$(transform-fbs-to-cpp)
LOCAL_GENERATED_SOURCES += $(LANG_ID_COMMON_FLATBUFFERS_MODEL_H)
-
diff --git a/java/com/google/android/textclassifier/ActionsSuggestionsModel.java b/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
index b3599ea..3b6d033 100644
--- a/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
+++ b/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
@@ -32,16 +32,22 @@
}
private long actionsModelPtr;
+ private AnnotatorModel annotator;
/**
* Creates a new instance of Actions predictor, using the provided model image, given as a file
* descriptor.
*/
public ActionsSuggestionsModel(int fileDescriptor) {
+ this(fileDescriptor, null);
+ }
+
+ public ActionsSuggestionsModel(int fileDescriptor, AnnotatorModel annotator) {
actionsModelPtr = nativeNewActionsModel(fileDescriptor);
if (actionsModelPtr == 0L) {
throw new IllegalArgumentException("Couldn't initialize actions model from file descriptor.");
}
+ setAnnotator(annotator);
}
/**
@@ -49,10 +55,15 @@
* path.
*/
public ActionsSuggestionsModel(String path) {
+ this(path, null);
+ }
+
+ public ActionsSuggestionsModel(String path, AnnotatorModel annotator) {
actionsModelPtr = nativeNewActionsModelFromPath(path);
if (actionsModelPtr == 0L) {
throw new IllegalArgumentException("Couldn't initialize actions model from given file.");
}
+ setAnnotator(annotator);
}
/** Suggests actions / replies to the given conversation. */
@@ -156,6 +167,14 @@
/** Represents options for the SuggestActions call. */
public static final class ActionSuggestionOptions {}
+ /** Sets an annotator to use for actions suggestions. */
+ private void setAnnotator(AnnotatorModel annotator) {
+ this.annotator = annotator;
+ if (annotator != null) {
+ nativeSetAnnotator(annotator.getNativeAnnotator());
+ }
+ }
+
private static native long nativeNewActionsModel(int fd);
private static native long nativeNewActionsModelFromPath(String path);
@@ -170,4 +189,6 @@
long context, Conversation conversation, ActionSuggestionOptions options);
private native void nativeCloseActionsModel(long context);
+
+ private native void nativeSetAnnotator(long annotatorPtr);
}
diff --git a/java/com/google/android/textclassifier/AnnotatorModel.java b/java/com/google/android/textclassifier/AnnotatorModel.java
index b268a28..08a4455 100644
--- a/java/com/google/android/textclassifier/AnnotatorModel.java
+++ b/java/com/google/android/textclassifier/AnnotatorModel.java
@@ -305,6 +305,14 @@
}
}
+ /**
+ * Retrieves the pointer to the native object. Note: Need to keep the AnnotatorModel alive as long
+ * as the pointer is used.
+ */
+ long getNativeAnnotator() {
+ return annotatorPtr;
+ }
+
private static native long nativeNewAnnotator(int fd);
private static native long nativeNewAnnotatorFromPath(String path);
diff --git a/lang_id/common/fel/task-context.cc b/lang_id/common/fel/task-context.cc
index 75aeaad..f8b0701 100644
--- a/lang_id/common/fel/task-context.cc
+++ b/lang_id/common/fel/task-context.cc
@@ -43,10 +43,6 @@
return defval;
}
-string TaskContext::Get(const string &name, const string &defval) const {
- return Get(name, defval.c_str());
-}
-
int TaskContext::Get(const string &name, int defval) const {
const string s = Get(name, "");
int value = defval;
diff --git a/lang_id/common/fel/task-context.h b/lang_id/common/fel/task-context.h
index f271095..ddc8cfe 100644
--- a/lang_id/common/fel/task-context.h
+++ b/lang_id/common/fel/task-context.h
@@ -51,7 +51,6 @@
// Returns parameter value. If the parameter is not specified in this
// context, the default value is returned.
string Get(const string &name, const char *defval) const;
- string Get(const string &name, const string &defval) const;
int Get(const string &name, int defval) const;
float Get(const string &name, float defval) const;
bool Get(const string &name, bool defval) const;
diff --git a/models/actions_suggestions.model b/models/actions_suggestions.model
index 893dd84..4809a7b 100644
--- a/models/actions_suggestions.model
+++ b/models/actions_suggestions.model
Binary files differ
diff --git a/utils/base/logging.h b/utils/base/logging.h
index e197780..e8bde39 100644
--- a/utils/base/logging.h
+++ b/utils/base/logging.h
@@ -155,7 +155,7 @@
#endif // NDEBUG
-#ifdef LIBTEXTCLASSIFIER_VLOG
+#ifdef TC3_VLOG
#define TC3_VLOG(severity) \
::libtextclassifier3::logging::LogMessage( \
::libtextclassifier3::logging::INFO, __FILE__, __LINE__) \
diff --git a/utils/calendar/calendar_test.cc b/utils/calendar/calendar_test.cc
index 02ce63f..a8c3af8 100644
--- a/utils/calendar/calendar_test.cc
+++ b/utils/calendar/calendar_test.cc
@@ -45,7 +45,7 @@
TC3_LOG(INFO) << result;
}
-#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
+#ifdef TC3_CALENDAR_ICU
TEST_F(CalendarTest, RoundingToGranularity) {
int64 time;
DateParseData data;
@@ -238,7 +238,7 @@
/*granularity=*/GRANULARITY_DAY, &time));
EXPECT_EQ(time, 1523397600000L /* 11 April 2018 00:00:00 */);
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_DUMMY
+#endif // TC3_CALENDAR_ICU
} // namespace
} // namespace libtextclassifier3
diff --git a/utils/intents/intent-config.fbs b/utils/intents/intent-config.fbs
new file mode 100755
index 0000000..d350ae4
--- /dev/null
+++ b/utils/intents/intent-config.fbs
@@ -0,0 +1,174 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// The type of variable to fetch.
+namespace libtextclassifier3;
+enum AndroidSimpleIntentGeneratorVariableType : int {
+ INVALID_VARIABLE = 0,
+
+ // The raw text that was classified.
+ RAW_TEXT = 1,
+
+ // Text as a URL with explicit protocol. If no protocol was specified, http
+ // is prepended.
+ URL_TEXT = 2,
+
+ // The raw text, but URL encoded.
+ URL_ENCODED_TEXT = 3,
+
+ // For dates/times: the instant of the event in UTC millis.
+ EVENT_TIME_MS_UTC = 4,
+
+ // For dates/times: the start of the event in UTC millis.
+ EVENT_START_MS_UTC = 5,
+
+ // For dates/times: the end of the event in UTC millis.
+ EVENT_END_MS_UTC = 6,
+
+ // Name of the package that's running the classifier.
+ PACKAGE_NAME = 7,
+}
+
+// Enumerates the possible extra types for the simple intent generator.
+namespace libtextclassifier3;
+enum AndroidSimpleIntentGeneratorExtraType : int {
+ INVALID_EXTRA_TYPE = 0,
+ STRING = 1,
+ BOOL = 2,
+ VARIABLE_AS_LONG = 3,
+}
+
+// Enumerates the possible condition types for the simple intent generator.
+namespace libtextclassifier3;
+enum AndroidSimpleIntentGeneratorConditionType : int {
+ INVALID_CONDITION_TYPE = 0,
+
+ // Queries the UserManager for the given boolean restriction. The condition
+ // passes if the result of getBoolean is false. The name of the
+ // restriction to check is in the string_ field.
+ USER_RESTRICTION_NOT_SET = 1,
+
+ // Checks that the parsed event start time is at least a given number of
+ // milliseconds in the future. (Only valid if there is a parsed event
+ // time) The offset is stored in the int64_ field.
+ EVENT_START_IN_FUTURE_MS = 2,
+}
+
+// Describes how intents for the various entity types should be generated on
+// Android. This is distributed through the model, but not used by
+// libtextclassifier yet - rather, it's passed to the calling Java code, which
+// implements the Intent generation logic.
+namespace libtextclassifier3;
+table AndroidIntentFactoryOptions {
+ entity:[libtextclassifier3.AndroidIntentFactoryEntityOptions];
+}
+
+// Describes how intents should be generated for a particular entity type.
+namespace libtextclassifier3;
+table AndroidIntentFactoryEntityOptions {
+ // The entity type as defined by one of the TextClassifier ENTITY_TYPE
+ // constants. (e.g. "address", "phone", etc.)
+ entity_type:string;
+
+ // List of generators for all the different types of intents that should
+ // be made available for the entity type.
+ generator:[libtextclassifier3.AndroidIntentGeneratorOptions];
+}
+
+// Configures a single Android Intent generator.
+namespace libtextclassifier3;
+table AndroidIntentGeneratorOptions {
+ // Strings for UI elements.
+ strings:[libtextclassifier3.AndroidIntentGeneratorStrings];
+
+ // Generator specific configuration.
+ simple:libtextclassifier3.AndroidSimpleIntentGeneratorOptions;
+}
+
+// Language dependent configuration for an Android Intent generator.
+namespace libtextclassifier3;
+table AndroidIntentGeneratorStrings {
+ // BCP 47 tag for the supported locale. Note that because of API level
+ // restrictions, this must /not/ use wildcards. To e.g. match all English
+ // locales, use only "en" and not "en_*". Reference the java.util.Locale
+ // constructor for details.
+ language_tag:string;
+
+ // Title shown for the action (see RemoteAction.getTitle).
+ title:string;
+
+ // Description shown for the action (see
+ // RemoteAction.getContentDescription).
+ description:string;
+}
+
+// An extra to set on a simple intent generator Intent.
+namespace libtextclassifier3;
+table AndroidSimpleIntentGeneratorExtra {
+ // The name of the extra to set.
+ name:string;
+
+ // The type of the extra to set.
+ type:libtextclassifier3.AndroidSimpleIntentGeneratorExtraType;
+
+ string_:string;
+
+ bool_:bool;
+ int32_:int;
+}
+
+// A condition that needs to be fulfilled for an Intent to get generated.
+namespace libtextclassifier3;
+table AndroidSimpleIntentGeneratorCondition {
+ type:libtextclassifier3.AndroidSimpleIntentGeneratorConditionType;
+
+ string_:string;
+
+ int32_:int;
+ int64_:long;
+}
+
+// Configures an intent generator whose logic is simple enough to express with
+// basic rules - which covers the vast majority of use cases and is analogous
+// to Android Actions.
+// Most strings (action, data, type, ...) may contain variable references. To
+// use them, the generator must first declare all the variables it wishes to use
+// in the variables field. The values then become available as numbered
+// arguments (using the normal java.util.Formatter syntax) in the order they
+// were specified.
+namespace libtextclassifier3;
+table AndroidSimpleIntentGeneratorOptions {
+ // The action to set on the Intent (see Intent.setAction). Supports variables.
+ action:string;
+
+ // The data to set on the Intent (see Intent.setData). Supports variables.
+ data:string;
+
+ // The type to set on the Intent (see Intent.setType). Supports variables.
+ type:string;
+
+ // The list of all the extras to add to the Intent.
+ extra:[libtextclassifier3.AndroidSimpleIntentGeneratorExtra];
+
+ // The list of all the variables that become available for substitution in
+ // the action, data, type and extra strings. To e.g. set a field to the value
+ // of the first variable, use "%1$s".
+ variable:[libtextclassifier3.AndroidSimpleIntentGeneratorVariableType];
+
+ // The list of all conditions that need to be fulfilled for Intent generation.
+ condition:[libtextclassifier3.AndroidSimpleIntentGeneratorCondition];
+}
+
diff --git a/utils/sentencepiece/encoder.cc b/utils/sentencepiece/encoder.cc
index 96fb868..6ffb0c7 100644
--- a/utils/sentencepiece/encoder.cc
+++ b/utils/sentencepiece/encoder.cc
@@ -35,6 +35,17 @@
normalized_text.RemovePrefix(1);
continue;
}
+ // Check whether we can use the unknown token.
+ if (unknown_code_ >= 0) {
+ const int pos = i + 1;
+ const float unknown_penalty = segmentation[i].score + unknown_score_;
+ if (segmentation[pos].previous_pos < 0 ||
+ segmentation[pos].score < unknown_penalty) {
+ segmentation[pos] = {/*score=*/unknown_penalty, /*previous_pos=*/i,
+ /*piece_id=*/unknown_code_,
+ /*num_pieces=*/segmentation[i].num_pieces + 1};
+ }
+ }
for (const auto& match : matcher_->FindAllPrefixMatches(normalized_text)) {
TC3_CHECK(match.id >= 0 && match.id < num_pieces_);
const int pos = i + match.match_length;
@@ -42,7 +53,7 @@
if (segmentation[pos].previous_pos < 0 ||
segmentation[pos].score < candidate_score) {
segmentation[pos] = {/*score=*/candidate_score, /*previous_pos=*/i,
- /*piece_id=*/match.id,
+ /*piece_id=*/match.id + encoding_offset_,
/*num_pieces=*/segmentation[i].num_pieces + 1};
}
}
@@ -57,7 +68,7 @@
result[num_pieces + 1] = end_code_;
int pos = len;
for (int i = num_pieces; i > 0; i--) {
- result[i] = segmentation[pos].piece_id + encoding_offset_;
+ result[i] = segmentation[pos].piece_id;
pos = segmentation[pos].previous_pos;
}
result[0] = start_code_;
diff --git a/utils/sentencepiece/encoder.h b/utils/sentencepiece/encoder.h
index fffd86f..0f1bfd3 100644
--- a/utils/sentencepiece/encoder.h
+++ b/utils/sentencepiece/encoder.h
@@ -33,19 +33,24 @@
// a trie.
// num_pieces: the number of pieces in the trie.
// pieces_scores: the scores of the individual pieces.
- // start_code: Code that is used as encoding of the start of input.
- // end_code: Code that is used as encoding of the end of input.
- // encoding_offset: Value added to the sentence piece ids to make them
+ // start_code: code that is used as encoding of the start of input.
+ // end_code: code that is used as encoding of the end of input.
+ // encoding_offset: value added to the sentence piece ids to make them
// not interesecting with start_code and end_code.
+ // unknown_code: code that is used for out-of-dictionary characters.
+ // unknown_score: the penalty score associated with the unknown code.
Encoder(const SentencePieceMatcher* matcher, const int num_pieces,
const float* pieces_scores, int start_code = 0, int end_code = 1,
- int encoding_offset = 2)
+ int encoding_offset = 2, int unknown_code = -1,
+ float unknown_score = 0.f)
: num_pieces_(num_pieces),
scores_(pieces_scores),
matcher_(matcher),
start_code_(start_code),
end_code_(end_code),
- encoding_offset_(encoding_offset) {}
+ encoding_offset_(encoding_offset),
+ unknown_code_(unknown_code),
+ unknown_score_(unknown_score) {}
// Segment the input so that the total score of the pieces used is maximized.
// This is a simplified implementation of the general Viterbi algorithm,
@@ -74,6 +79,8 @@
const int start_code_;
const int end_code_;
const int encoding_offset_;
+ const int unknown_code_;
+ const float unknown_score_;  // Float like the ctor argument; an int would truncate fractional penalties.
};
} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/encoder_test.cc b/utils/sentencepiece/encoder_test.cc
index 59c12ad..6bc9aeb 100644
--- a/utils/sentencepiece/encoder_test.cc
+++ b/utils/sentencepiece/encoder_test.cc
@@ -26,7 +26,7 @@
namespace libtextclassifier3 {
namespace {
-using testing::ElementsAreArray;
+using testing::ElementsAre;
using testing::IsEmpty;
TEST(EncoderTest, SimpleTokenization) {
@@ -38,12 +38,12 @@
const Encoder encoder(matcher.get(),
/*num_pieces=*/4, scores);
- EXPECT_THAT(encoder.Encode("hellothere"), ElementsAreArray({0, 3, 5, 1}));
+ EXPECT_THAT(encoder.Encode("hellothere"), ElementsAre(0, 3, 5, 1));
// Make probability of hello very low:
// hello gets now tokenized as hell + o.
scores[1] = -100.0;
- EXPECT_THAT(encoder.Encode("hellothere"), ElementsAreArray({0, 2, 4, 5, 1}));
+ EXPECT_THAT(encoder.Encode("hellothere"), ElementsAre(0, 2, 4, 5, 1));
}
TEST(EncoderTest, HandlesEdgeCases) {
@@ -54,10 +54,28 @@
/*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
const Encoder encoder(matcher.get(),
/*num_pieces=*/4, scores);
- EXPECT_THAT(encoder.Encode("hellhello"), ElementsAreArray({0, 2, 3, 1}));
- EXPECT_THAT(encoder.Encode("hellohell"), ElementsAreArray({0, 3, 2, 1}));
- EXPECT_THAT(encoder.Encode(""), ElementsAreArray({0, 1}));
- EXPECT_THAT(encoder.Encode("hellathere"), ElementsAreArray({0, 1}));
+ EXPECT_THAT(encoder.Encode("hellhello"), ElementsAre(0, 2, 3, 1));
+ EXPECT_THAT(encoder.Encode("hellohell"), ElementsAre(0, 3, 2, 1));
+ EXPECT_THAT(encoder.Encode(""), ElementsAre(0, 1));
+ EXPECT_THAT(encoder.Encode("hellathere"), ElementsAre(0, 1));
+}
+
+TEST(EncoderTest, HandlesOutOfDictionary) {
+ const char pieces[] = "hell\0hello\0o\0there\0";
+ const int offsets[] = {0, 5, 11, 13};
+ float scores[] = {-0.5, -1.0, -10.0, -1.0};
+ std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
+ /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
+ const Encoder encoder(matcher.get(),
+ /*num_pieces=*/4, scores,
+ /*start_code=*/0, /*end_code=*/1,
+ /*encoding_offset=*/3, /*unknown_code=*/2,
+ /*unknown_score=*/-100.0);
+ EXPECT_THAT(encoder.Encode("hellhello"), ElementsAre(0, 3, 4, 1));
+ EXPECT_THAT(encoder.Encode("hellohell"), ElementsAre(0, 4, 3, 1));
+ EXPECT_THAT(encoder.Encode(""), ElementsAre(0, 1));
+ EXPECT_THAT(encoder.Encode("hellathere"),
+ ElementsAre(0, /*hell*/ 3, /*unknown*/ 2, /*there*/ 6, 1));
}
} // namespace
diff --git a/utils/sentencepiece/normalizer.cc b/utils/sentencepiece/normalizer.cc
index 9fcc1e5..1dd20da 100644
--- a/utils/sentencepiece/normalizer.cc
+++ b/utils/sentencepiece/normalizer.cc
@@ -21,7 +21,7 @@
namespace libtextclassifier3 {
-std::string Normalizer::Normalize(StringPiece input) const {
+std::string SentencePieceNormalizer::Normalize(StringPiece input) const {
std::string normalized;
// Ignores heading space.
@@ -106,7 +106,7 @@
return normalized;
}
-std::pair<StringPiece, int> Normalizer::NormalizePrefix(
+std::pair<StringPiece, int> SentencePieceNormalizer::NormalizePrefix(
StringPiece input) const {
std::pair<StringPiece, int> result;
if (input.empty()) return result;
diff --git a/utils/sentencepiece/normalizer.h b/utils/sentencepiece/normalizer.h
index 582d563..227e09b 100644
--- a/utils/sentencepiece/normalizer.h
+++ b/utils/sentencepiece/normalizer.h
@@ -27,7 +27,7 @@
// Normalizer implements a simple text normalizer with user-defined
// string-to-string rules and leftmost longest matching.
-class Normalizer {
+class SentencePieceNormalizer {
public:
// charsmap_trie and charsmap_normalized specify the normalization/replacement
// string-to-string rules in the following way:
@@ -41,10 +41,11 @@
// internal whitespace.
//
// escape_whitespaces: Whether to replace whitespace with a meta symbol.
- Normalizer(const DoubleArrayTrie &charsmap_trie,
- StringPiece charsmap_normalized, bool add_dummy_prefix = true,
- bool remove_extra_whitespaces = true,
- bool escape_whitespaces = true)
+ SentencePieceNormalizer(const DoubleArrayTrie &charsmap_trie,
+ StringPiece charsmap_normalized,
+ bool add_dummy_prefix = true,
+ bool remove_extra_whitespaces = true,
+ bool escape_whitespaces = true)
: charsmap_trie_(charsmap_trie),
charsmap_normalized_(charsmap_normalized),
add_dummy_prefix_(add_dummy_prefix),
diff --git a/utils/sentencepiece/normalizer_test.cc b/utils/sentencepiece/normalizer_test.cc
index 143e795..f6018ab 100644
--- a/utils/sentencepiece/normalizer_test.cc
+++ b/utils/sentencepiece/normalizer_test.cc
@@ -36,9 +36,10 @@
std::ifstream test_config_stream(GetTestConfigPath());
std::string config((std::istreambuf_iterator<char>(test_config_stream)),
(std::istreambuf_iterator<char>()));
- Normalizer normalizer = NormalizerFromSpec(config, /*add_dummy_prefix=*/true,
- /*remove_extra_whitespaces=*/true,
- /*escape_whitespaces=*/true);
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/true,
+ /*remove_extra_whitespaces=*/true,
+ /*escape_whitespaces=*/true);
EXPECT_EQ(normalizer.Normalize("hello there"), "▁hello▁there");
@@ -63,9 +64,10 @@
std::ifstream test_config_stream(GetTestConfigPath());
std::string config((std::istreambuf_iterator<char>(test_config_stream)),
(std::istreambuf_iterator<char>()));
- Normalizer normalizer = NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
- /*remove_extra_whitespaces=*/true,
- /*escape_whitespaces=*/true);
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
+ /*remove_extra_whitespaces=*/true,
+ /*escape_whitespaces=*/true);
EXPECT_EQ(normalizer.Normalize("hello there"), "hello▁there");
@@ -90,9 +92,10 @@
std::ifstream test_config_stream(GetTestConfigPath());
std::string config((std::istreambuf_iterator<char>(test_config_stream)),
(std::istreambuf_iterator<char>()));
- Normalizer normalizer = NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
- /*remove_extra_whitespaces=*/false,
- /*escape_whitespaces=*/true);
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
+ /*remove_extra_whitespaces=*/false,
+ /*escape_whitespaces=*/true);
EXPECT_EQ(normalizer.Normalize("hello there"), "hello▁there");
@@ -108,9 +111,10 @@
std::ifstream test_config_stream(GetTestConfigPath());
std::string config((std::istreambuf_iterator<char>(test_config_stream)),
(std::istreambuf_iterator<char>()));
- Normalizer normalizer = NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
- /*remove_extra_whitespaces=*/false,
- /*escape_whitespaces=*/false);
+ SentencePieceNormalizer normalizer =
+ NormalizerFromSpec(config, /*add_dummy_prefix=*/false,
+ /*remove_extra_whitespaces=*/false,
+ /*escape_whitespaces=*/false);
EXPECT_EQ(normalizer.Normalize("hello there"), "hello there");
diff --git a/utils/sentencepiece/test_utils.cc b/utils/sentencepiece/test_utils.cc
index 1b766ac..1ed2bf3 100644
--- a/utils/sentencepiece/test_utils.cc
+++ b/utils/sentencepiece/test_utils.cc
@@ -24,15 +24,16 @@
namespace libtextclassifier3 {
-Normalizer NormalizerFromSpec(StringPiece spec, bool add_dummy_prefix,
- bool remove_extra_whitespaces,
- bool escape_whitespaces) {
+SentencePieceNormalizer NormalizerFromSpec(StringPiece spec,
+ bool add_dummy_prefix,
+ bool remove_extra_whitespaces,
+ bool escape_whitespaces) {
const uint32 trie_blob_size = reinterpret_cast<const uint32*>(spec.data())[0];
spec.RemovePrefix(sizeof(trie_blob_size));
const TrieNode* trie_blob = reinterpret_cast<const TrieNode*>(spec.data());
spec.RemovePrefix(trie_blob_size);
const int num_nodes = trie_blob_size / sizeof(TrieNode);
- return Normalizer(
+ return SentencePieceNormalizer(
DoubleArrayTrie(trie_blob, num_nodes),
/*charsmap_normalized=*/StringPiece(spec.data(), spec.size()),
add_dummy_prefix, remove_extra_whitespaces, escape_whitespaces);
diff --git a/utils/sentencepiece/test_utils.h b/utils/sentencepiece/test_utils.h
index 71a4994..0c833da 100644
--- a/utils/sentencepiece/test_utils.h
+++ b/utils/sentencepiece/test_utils.h
@@ -25,9 +25,10 @@
namespace libtextclassifier3 {
-Normalizer NormalizerFromSpec(StringPiece spec, bool add_dummy_prefix,
- bool remove_extra_whitespaces,
- bool escape_whitespaces);
+SentencePieceNormalizer NormalizerFromSpec(StringPiece spec,
+ bool add_dummy_prefix,
+ bool remove_extra_whitespaces,
+ bool escape_whitespaces);
} // namespace libtextclassifier3
diff --git a/utils/tflite/text_encoder.cc b/utils/tflite/text_encoder.cc
index 9554283..734b5b0 100644
--- a/utils/tflite/text_encoder.cc
+++ b/utils/tflite/text_encoder.cc
@@ -35,7 +35,7 @@
namespace {
struct TextEncoderOp {
- std::unique_ptr<Normalizer> normalizer;
+ std::unique_ptr<SentencePieceNormalizer> normalizer;
std::unique_ptr<Encoder> encoder;
std::unique_ptr<SentencePieceMatcher> matcher;
};
@@ -81,7 +81,7 @@
config->normalization_charsmap()->Data());
const int charsmap_trie_nodes_length =
config->normalization_charsmap()->Length() / sizeof(TrieNode);
- encoder_op->normalizer.reset(new Normalizer(
+ encoder_op->normalizer.reset(new SentencePieceNormalizer(
DoubleArrayTrie(charsmap_trie_nodes, charsmap_trie_nodes_length),
StringPiece(config->normalization_charsmap_values()->data(),
config->normalization_charsmap_values()->size()),
@@ -113,7 +113,8 @@
}
encoder_op->encoder.reset(new Encoder(
encoder_op->matcher.get(), num_pieces, config->pieces_scores()->data(),
- config->start_code(), config->end_code(), config->encoding_offset()));
+ config->start_code(), config->end_code(), config->encoding_offset(),
+ config->unknown_code(), config->unknown_score()));
return encoder_op.release();
}
diff --git a/utils/tflite/text_encoder_config.fbs b/utils/tflite/text_encoder_config.fbs
index 462da21..8ae8fc5 100644
--- a/utils/tflite/text_encoder_config.fbs
+++ b/utils/tflite/text_encoder_config.fbs
@@ -34,6 +34,12 @@
// `start_code` and `end_code`.
encoding_offset:int32 = 2;
+ // Code that is used for out-of-dictionary characters.
+ unknown_code:int32 = -1;
+
+ // Penalty associated with the unknown code.
+ unknown_score:float;
+
// Normalization options.
// Serialized normalization charsmap.
normalization_charsmap:string;
diff --git a/utils/tflite/text_encoder_test.cc b/utils/tflite/text_encoder_test.cc
index 0b6ff71..0cd67ce 100644
--- a/utils/tflite/text_encoder_test.cc
+++ b/utils/tflite/text_encoder_test.cc
@@ -20,6 +20,7 @@
#include "utils/tflite/text_encoder.h"
#include "gtest/gtest.h"
+#include "third_party/absl/flags/flag.h"
#include "flatbuffers/flexbuffers.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
diff --git a/utils/utf8/unilib_test.cc b/utils/utf8/unilib_test.cc
index e2ad26b..96b2c2d 100644
--- a/utils/utf8/unilib_test.cc
+++ b/utils/utf8/unilib_test.cc
@@ -50,7 +50,7 @@
EXPECT_EQ(unilib_.GetPairedBracket('}'), '{');
}
-#ifndef LIBTEXTCLASSIFIER_UNILIB_DUMMY
+#ifndef TC3_UNILIB_DUMMY
TEST_F(UniLibTest, CharacterClassesUnicode) {
EXPECT_TRUE(unilib_.IsOpeningBracket(0x0F3C)); // TIBET ANG KHANG GYON
EXPECT_TRUE(unilib_.IsClosingBracket(0x0F3D)); // TIBET ANG KHANG GYAS
@@ -72,7 +72,7 @@
EXPECT_EQ(unilib_.GetPairedBracket(0x0F3C), 0x0F3D);
EXPECT_EQ(unilib_.GetPairedBracket(0x0F3D), 0x0F3C);
}
-#endif // ndef LIBTEXTCLASSIFIER_UNILIB_DUMMY
+#endif // ndef TC3_UNILIB_DUMMY
TEST_F(UniLibTest, RegexInterface) {
const UnicodeText regex_pattern =
@@ -89,7 +89,7 @@
TC3_LOG(INFO) << matcher->Group(0, &status).size_codepoints();
}
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_F(UniLibTest, Regex) {
// The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to
// test the regex functionality with it to verify we are handling the indices
@@ -126,9 +126,9 @@
EXPECT_EQ(matcher->Group(0, &status).ToUTF8String(), "0123😋");
EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_F(UniLibTest, RegexGroups) {
// The smiley face is a 4-byte UTF8 codepoint 0x1F60B, and it's important to
// test the regex functionality with it to verify we are handling the indices
@@ -163,9 +163,9 @@
EXPECT_EQ(matcher->Group(2, &status).ToUTF8String(), "123");
EXPECT_EQ(status, UniLib::RegexMatcher::kNoError);
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_F(UniLibTest, BreakIterator) {
const UnicodeText text = UTF8ToUnicodeText("some text", /*do_copy=*/false);
@@ -178,9 +178,9 @@
}
EXPECT_THAT(break_indices, ElementsAre(4, 5, 9));
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_F(UniLibTest, BreakIterator4ByteUTF8) {
const UnicodeText text = UTF8ToUnicodeText("😀😂😋", /*do_copy=*/false);
std::unique_ptr<UniLib::BreakIterator> iterator =
@@ -192,18 +192,18 @@
}
EXPECT_THAT(break_indices, ElementsAre(1, 2, 3));
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
-#ifndef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
+#ifndef TC3_UNILIB_JAVAICU
TEST_F(UniLibTest, IntegerParse) {
int result;
EXPECT_TRUE(
unilib_.ParseInt32(UTF8ToUnicodeText("123", /*do_copy=*/false), &result));
EXPECT_EQ(result, 123);
}
-#endif // ndef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
+#endif // ndef TC3_UNILIB_JAVAICU
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_F(UniLibTest, IntegerParseFullWidth) {
int result;
// The input string here is full width
@@ -211,16 +211,16 @@
&result));
EXPECT_EQ(result, 123);
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
-#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
+#ifdef TC3_UNILIB_ICU
TEST_F(UniLibTest, IntegerParseFullWidthWithAlpha) {
int result;
// The input string here is full width
EXPECT_FALSE(unilib_.ParseInt32(UTF8ToUnicodeText("1a3", /*do_copy=*/false),
&result));
}
-#endif // LIBTEXTCLASSIFIER_UNILIB_ICU
+#endif // TC3_UNILIB_ICU
} // namespace
} // namespace libtextclassifier3