Merge "Export libtextclassifier"
diff --git a/Android.mk b/Android.mk
index 17f0373..6b5fccb 100644
--- a/Android.mk
+++ b/Android.mk
@@ -64,7 +64,8 @@
LOCAL_CFLAGS += $(MY_LIBTEXTCLASSIFIER_CFLAGS)
LOCAL_STRIP_MODULE := $(LIBTEXTCLASSIFIER_STRIP_OPTS)
-LOCAL_SRC_FILES := $(filter-out tests/% %_test.cc test-util.%,$(call all-subdir-cpp-files))
+EXCLUDED_FILES := %_test.cc test-util.% utils/testing/% %-test-lib.cc
+LOCAL_SRC_FILES := $(filter-out $(EXCLUDED_FILES),$(call all-subdir-cpp-files))
LOCAL_C_INCLUDES := $(TOP)/external/zlib
LOCAL_C_INCLUDES += $(TOP)/external/tensorflow
@@ -77,6 +78,7 @@
LOCAL_SHARED_LIBRARIES += libz
LOCAL_STATIC_LIBRARIES += libutf
+LOCAL_STATIC_LIBRARIES += liblua
LOCAL_REQUIRED_MODULES := libtextclassifier_annotator_en_model
LOCAL_REQUIRED_MODULES += libtextclassifier_annotator_universal_model
@@ -113,6 +115,8 @@
LOCAL_CPPFLAGS_64 += -DTC3_TEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/test_data/\""
# TODO: Do not filter out tflite test once the dependency issue is resolved.
+EXCLUDED_FILES := utils/tflite/%_test.cc
+
-LOCAL_SRC_FILES := $(filter-out utils/tflite/%_test.cc,$(call all-subdir-cpp-files))
+LOCAL_SRC_FILES := $(filter-out $(EXCLUDED_FILES),$(call all-subdir-cpp-files))
LOCAL_C_INCLUDES := $(TOP)/external/zlib
@@ -127,6 +131,7 @@
LOCAL_STATIC_LIBRARIES += libgmock
LOCAL_STATIC_LIBRARIES += libutf
+LOCAL_STATIC_LIBRARIES += liblua
include $(BUILD_NATIVE_TEST)
diff --git a/actions/actions-suggestions.cc b/actions/actions-suggestions.cc
index 281909d..8c8fe3f 100644
--- a/actions/actions-suggestions.cc
+++ b/actions/actions-suggestions.cc
@@ -192,9 +192,7 @@
return true;
}
- const int num_rules = model_->rules()->rule()->size();
- for (int i = 0; i < num_rules; i++) {
- const auto* rule = model_->rules()->rule()->Get(i);
+ for (const RulesModel_::Rule* rule : *model_->rules()->rule()) {
std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
UncompressMakeRegexPattern(*unilib_, rule->pattern(),
rule->compressed_pattern(), decompressor);
@@ -202,7 +200,7 @@
TC3_LOG(ERROR) << "Failed to load rule pattern.";
return false;
}
- rules_.push_back({/*rule_id=*/i, std::move(compiled_pattern)});
+ rules_.push_back({rule, std::move(compiled_pattern)});
}
return true;
@@ -259,7 +257,7 @@
}
void ActionsSuggestions::ReadModelOutput(
- tflite::Interpreter* interpreter,
+ tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
ActionsSuggestionsResponse* response) const {
// Read sensitivity and triggering score predictions.
if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
@@ -273,6 +271,7 @@
}
response->triggering_score = triggering_score.data()[0];
response->output_filtered_min_triggering_score =
+ !options.ignore_min_replies_triggering_threshold &&
(response->triggering_score <
model_->preconditions()->min_smart_reply_triggering_score());
}
@@ -336,6 +335,7 @@
void ActionsSuggestions::SuggestActionsFromModel(
const Conversation& conversation, const int num_messages,
+ const ActionSuggestionOptions& options,
ActionsSuggestionsResponse* response) const {
TC3_CHECK_LE(num_messages, conversation.messages.size());
@@ -393,7 +393,7 @@
return;
}
- ReadModelOutput(interpreter.get(), response);
+ ReadModelOutput(interpreter.get(), options, response);
}
void ActionsSuggestions::SuggestActionsFromAnnotations(
@@ -442,10 +442,13 @@
const float score =
(mapping->use_annotation_score() ? classification_result.score
: mapping->action()->score());
- suggestions->actions.push_back({/*response_text=*/"",
- /*type=*/mapping->action()->type()->str(),
- /*score=*/score,
- /*annotations=*/{suggestion_annotation}});
+ suggestions->actions.push_back({
+ /*response_text=*/"",
+ /*type=*/mapping->action()->type()->str(),
+ /*score=*/score,
+ /*annotations=*/{suggestion_annotation},
+ /*extra=*/AsVariantMap(mapping->action()->extra()),
+ });
}
}
}
@@ -457,21 +460,21 @@
const std::string& message = conversation.messages.back().text;
const UnicodeText message_unicode(
UTF8ToUnicodeText(message, /*do_copy=*/false));
- for (int i = 0; i < rules_.size(); i++) {
+ for (const CompiledRule& rule : rules_) {
const std::unique_ptr<UniLib::RegexMatcher> matcher =
- rules_[i].pattern->Matcher(message_unicode);
+ rule.pattern->Matcher(message_unicode);
int status = UniLib::RegexMatcher::kNoError;
if (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
- const auto actions =
- model_->rules()->rule()->Get(rules_[i].rule_id)->actions();
- for (int k = 0; k < actions->size(); k++) {
- const ActionSuggestionSpec* action = actions->Get(k);
- suggestions->actions.push_back(
- {/*response_text=*/(action->response_text() != nullptr
- ? action->response_text()->str()
- : ""),
- /*type=*/action->type()->str(),
- /*score=*/action->score()});
+ for (const ActionSuggestionSpec* action : *rule.rule->actions()) {
+ suggestions->actions.push_back({
+ /*response_text=*/(action->response_text() != nullptr
+ ? action->response_text()->str()
+ : ""),
+ /*type=*/action->type()->str(),
+ /*score=*/action->score(),
+ /*annotations=*/{},
+ /*extra=*/AsVariantMap(action->extra()),
+ });
}
}
}
@@ -515,7 +518,7 @@
SuggestActionsFromRules(conversation, &response);
- SuggestActionsFromModel(conversation, num_messages, &response);
+ SuggestActionsFromModel(conversation, num_messages, options, &response);
// Suppress all predictions if the conversation was deemed sensitive.
if (model_->preconditions()->suppress_on_sensitive_topic() &&
@@ -536,6 +539,10 @@
return SuggestActions(conversation, /*annotator=*/nullptr, options);
}
+float ActionsSuggestions::GetMinRepliesTriggeringThreshold() const {
+  // The threshold is stored in the model's triggering preconditions.
+  const auto* const preconditions = model_->preconditions();
+  return preconditions->min_smart_reply_triggering_score();
+}
+
const ActionsModel* ViewActionsModel(const void* buffer, int size) {
if (buffer == nullptr) {
return nullptr;
diff --git a/actions/actions-suggestions.h b/actions/actions-suggestions.h
index b5f0c2e..068a4a1 100644
--- a/actions/actions-suggestions.h
+++ b/actions/actions-suggestions.h
@@ -17,6 +17,7 @@
#ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
#define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
+#include <map>
#include <memory>
#include <string>
#include <vector>
@@ -27,6 +28,7 @@
#include "utils/memory/mmap.h"
#include "utils/tflite-model-executor.h"
#include "utils/utf8/unilib.h"
+#include "utils/variant.h"
#include "utils/zlib/zlib.h"
namespace libtextclassifier3 {
@@ -62,6 +64,9 @@
// The associated annotations.
std::vector<ActionSuggestionAnnotation> annotations;
+
+ // Extras information.
+ std::map<std::string, Variant> extra;
};
// Actions suggestions result containing meta-information and the suggested
@@ -116,6 +121,7 @@
struct ActionSuggestionOptions {
// Options for annotation of the messages.
AnnotationOptions annotation_options = AnnotationOptions::Default();
+ // If true, the model's min_smart_reply_triggering_score precondition is
+ // ignored: output_filtered_min_triggering_score is left false instead of
+ // being derived from comparing the triggering score against that threshold.
+ bool ignore_min_replies_triggering_threshold = false;
static ActionSuggestionOptions Default() { return ActionSuggestionOptions(); }
};
@@ -150,6 +156,8 @@
// Provide an annotator.
void SetAnnotator(const Annotator* annotator);
+ float GetMinRepliesTriggeringThreshold() const;
+
// Should be in sync with those defined in Android.
// android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
static const std::string& kViewCalendarType;
@@ -177,10 +185,12 @@
const int num_suggestions,
tflite::Interpreter* interpreter) const;
void ReadModelOutput(tflite::Interpreter* interpreter,
+ const ActionSuggestionOptions& options,
ActionsSuggestionsResponse* response) const;
void SuggestActionsFromModel(const Conversation& conversation,
const int num_messages,
+ const ActionSuggestionOptions& options,
ActionsSuggestionsResponse* response) const;
void SuggestActionsFromAnnotations(
@@ -206,7 +216,7 @@
// Rules.
struct CompiledRule {
- int rule_id;
+ // Rule configuration (pattern source and actions to emit); points into the
+ // model flatbuffer.
+ const RulesModel_::Rule* rule;
+ // Regular expression compiled from the rule's (possibly compressed) pattern.
std::unique_ptr<UniLib::RegexPattern> pattern;
};
std::vector<CompiledRule> rules_;
@@ -218,6 +228,23 @@
// Interprets the buffer as a Model flatbuffer and returns it for reading.
const ActionsModel* ViewActionsModel(const void* buffer, int size);
+// Opens the model at `path`, runs `function` once on the loaded ActionsModel
+// flatbuffer and returns the function's result. `function` receives nullptr
+// when the file cannot be mapped or the buffer cannot be viewed as a model.
+// The model points into a mapping owned by this call, so it must not be
+// retained after `function` returns. This is mainly useful to read a few
+// flatbuffer values without paying for the full model initialization.
+template <typename ReturnType, typename Func>
+ReturnType VisitActionsModel(const std::string& path, Func function) {
+  ScopedMmap mmap(path);
+  if (!mmap.handle().ok()) {
+    // The mapping is unusable: report "no model" once and stop here.
+    return function(/*model=*/nullptr);
+  }
+  const ActionsModel* model =
+      ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
+  return function(model);
+}
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
diff --git a/actions/actions-suggestions_test.cc b/actions/actions-suggestions_test.cc
index c82763d..e04155a 100644
--- a/actions/actions-suggestions_test.cc
+++ b/actions/actions-suggestions_test.cc
@@ -21,6 +21,7 @@
#include <memory>
#include "actions/actions_model_generated.h"
+#include "annotator/collections.h"
#include "annotator/types.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -155,7 +156,7 @@
AnnotatedSpan annotation;
annotation.span = {11, 15};
annotation.classification = {
- ClassificationResult(Annotator::kAddressCollection, 1.0)};
+ ClassificationResult(Collections::kAddress, 1.0)};
const ActionsSuggestionsResponse& response =
actions_suggestions->SuggestActions({{{/*user_id=*/1, "are you at home?",
/*time_diff_secs=*/0,
@@ -182,7 +183,7 @@
AnnotatedSpan annotation;
annotation.span = {11, 15};
annotation.classification = {
- ClassificationResult(Annotator::kAddressCollection, 1.0)};
+ ClassificationResult(Collections::kAddress, 1.0)};
const ActionsSuggestionsResponse& response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/0, "hi, how are you?", /*reference_time=*/10000},
@@ -198,8 +199,7 @@
std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
AnnotatedSpan annotation;
annotation.span = {13, 16};
- annotation.classification = {
- ClassificationResult(Annotator::kPhoneCollection, 1.0)};
+ annotation.classification = {ClassificationResult(Collections::kPhone, 1.0)};
const ActionsSuggestionsResponse& response =
actions_suggestions->SuggestActions({{{/*user_id=*/1, "can you call 911?",
@@ -236,6 +236,18 @@
actions_model->rules->rule.back()->actions.back()->response_text =
"General Kenobi!";
actions_model->rules->rule.back()->actions.back()->score = 1.f;
+ actions_model->rules->rule.back()->actions.back()->extra.emplace_back(
+ new NamedVariantT);
+ actions_model->rules->rule.back()->actions.back()->extra.back()->name =
+ "person";
+ actions_model->rules->rule.back()->actions.back()->extra.back()->value.reset(
+ new VariantValueT);
+ actions_model->rules->rule.back()
+ ->actions.back()
+ ->extra.back()
+ ->value->string_value = "Kenobi";
+ actions_model->rules->rule.back()->actions.back()->extra.back()->value->type =
+ VariantValue_::Type_STRING_VALUE;
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
@@ -247,8 +259,10 @@
const ActionsSuggestionsResponse& response =
actions_suggestions->SuggestActions({{{/*user_id=*/1, "hello there"}}});
- EXPECT_EQ(response.actions.size(), 1);
+ EXPECT_GE(response.actions.size(), 1);
EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
+ EXPECT_EQ(response.actions[0].extra.size(), 1);
+ EXPECT_EQ(response.actions[0].extra.at("person").StringValue(), "Kenobi");
}
TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
@@ -298,5 +312,22 @@
}
#endif
+TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
+  // The visitor only reports whether a model was handed to it.
+  const auto model_was_loaded = [](const ActionsModel* model) {
+    return model != nullptr;
+  };
+  EXPECT_TRUE(VisitActionsModel<bool>(GetModelPath() + kModelFileName,
+                                      model_was_loaded));
+  EXPECT_FALSE(VisitActionsModel<bool>(
+      GetModelPath() + "non_existing_model.fb", model_was_loaded));
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/actions/actions_model.fbs b/actions/actions_model.fbs
index 88dff79..ede3d59 100755
--- a/actions/actions_model.fbs
+++ b/actions/actions_model.fbs
@@ -14,6 +14,7 @@
// limitations under the License.
//
+include "utils/named-extra.fbs";
include "utils/zlib/buffer.fbs";
file_identifier "TC3A";
@@ -85,6 +86,9 @@
// Score.
score:float;
+
+ // Extra information.
+ extra:[NamedVariant];
}
namespace libtextclassifier3.AnnotationActionsSpec_;
@@ -93,7 +97,7 @@
annotation_collection:string;
// The action name to use.
- action:libtextclassifier3.ActionSuggestionSpec;
+ action:ActionSuggestionSpec;
// Whether to use the score of the annotation as the action score.
use_annotation_score:bool = true;
@@ -105,7 +109,7 @@
// Configuration for actions based on annotatations.
namespace libtextclassifier3;
table AnnotationActionsSpec {
- annotation_mapping:[libtextclassifier3.AnnotationActionsSpec_.AnnotationMapping];
+ annotation_mapping:[AnnotationActionsSpec_.AnnotationMapping];
}
// List of regular expression matchers.
@@ -114,16 +118,33 @@
// The regular expression pattern.
pattern:string;
- compressed_pattern:libtextclassifier3.CompressedBuffer;
+ compressed_pattern:CompressedBuffer;
// The actions to produce upon triggering.
- actions:[libtextclassifier3.ActionSuggestionSpec];
+ actions:[ActionSuggestionSpec];
}
// Rule based actions.
namespace libtextclassifier3;
table RulesModel {
- rule:[libtextclassifier3.RulesModel_.Rule];
+ rule:[RulesModel_.Rule];
+}
+
+// Describes how intents should be generated for a particular action type.
+namespace libtextclassifier3.ActionsIntentFactoryModel_;
+table IntentGenerator {
+ // The action type this generator produces intents for.
+ action_type:string;
+
+ // Lua code of the template generator, stored either as text source or as
+ // precompiled bytecode.
+ lua_template_generator:[ubyte];
+}
+
+// Describes how intents for the various action types should be generated.
+namespace libtextclassifier3;
+table ActionsIntentFactoryModel {
+ // Intent generators, each identified by its action_type.
+ action:[ActionsIntentFactoryModel_.IntentGenerator];
}
namespace libtextclassifier3;
@@ -137,15 +158,15 @@
// A name for the model that can be used e.g. for logging.
name:string;
- tflite_model_spec:libtextclassifier3.TensorflowLiteModelSpec;
+ tflite_model_spec:TensorflowLiteModelSpec;
// Output classes.
smart_reply_action_type:string;
- action_type:[libtextclassifier3.ActionTypeOptions];
+ action_type:[ActionTypeOptions];
// Triggering conditions of the model.
- preconditions:libtextclassifier3.TriggeringPreconditions;
+ preconditions:TriggeringPreconditions;
// Default number of smart reply predictions.
num_smart_replies:int = 3;
@@ -154,10 +175,13 @@
max_conversation_history_length:int = 1;
// Configuration for mapping annotations to action suggestions.
- annotation_actions_spec:libtextclassifier3.AnnotationActionsSpec;
+ annotation_actions_spec:AnnotationActionsSpec;
// Configuration for rules.
- rules:libtextclassifier3.RulesModel;
+ rules:RulesModel;
+
+ // Configuration for intent generation on Android.
+ android_intent_options:ActionsIntentFactoryModel;
}
root_type libtextclassifier3.ActionsModel;
diff --git a/actions/test_data/actions_suggestions_test.model b/actions/test_data/actions_suggestions_test.model
index ee60ce2..051a193 100644
--- a/actions/test_data/actions_suggestions_test.model
+++ b/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/annotator/annotator.cc b/annotator/annotator.cc
index 2be9d3c..330cf0b 100644
--- a/annotator/annotator.cc
+++ b/annotator/annotator.cc
@@ -21,15 +21,15 @@
#include <cmath>
#include <iterator>
#include <numeric>
+#include "annotator/types.h"
+#include "annotator/collections.h"
#include "utils/base/logging.h"
#include "utils/checksum.h"
#include "utils/math/softmax.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
-const std::string& Annotator::kOtherCollection =
- *[]() { return new std::string("other"); }();
const std::string& Annotator::kPhoneCollection =
*[]() { return new std::string("phone"); }();
const std::string& Annotator::kAddressCollection =
@@ -38,18 +38,8 @@
*[]() { return new std::string("date"); }();
const std::string& Annotator::kUrlCollection =
*[]() { return new std::string("url"); }();
-const std::string& Annotator::kFlightCollection =
- *[]() { return new std::string("flight"); }();
const std::string& Annotator::kEmailCollection =
*[]() { return new std::string("email"); }();
-const std::string& Annotator::kIbanCollection =
- *[]() { return new std::string("iban"); }();
-const std::string& Annotator::kPaymentCardCollection =
- *[]() { return new std::string("payment_card"); }();
-const std::string& Annotator::kIsbnCollection =
- *[]() { return new std::string("isbn"); }();
-const std::string& Annotator::kTrackingNumberCollection =
- *[]() { return new std::string("tracking_number"); }();
namespace {
const Model* LoadAndVerifyModel(const void* addr, int size) {
@@ -373,15 +363,9 @@
selection_regex_patterns_.push_back(regex_pattern_id);
}
regex_patterns_.push_back({
- regex_pattern->collection_name()->str(),
- regex_pattern->target_classification_score(),
- regex_pattern->priority_score(),
+ regex_pattern,
std::move(compiled_pattern),
- regex_pattern->verification_options(),
});
- if (regex_pattern->use_approximate_matching()) {
- regex_approximate_match_pattern_ids_.insert(regex_pattern_id);
- }
++regex_pattern_id;
}
@@ -400,6 +384,16 @@
return true;
}
+bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
+  auto engine = std::unique_ptr<ContactEngine>(new ContactEngine());
+  // Only install the new engine once it has initialized successfully.
+  if (engine->Initialize(serialized_config)) {
+    contact_engine_ = std::move(engine);
+    return true;
+  }
+  TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
+  return false;
+}
+
namespace {
int CountDigits(const std::string& str, CodepointSpan selection_indices) {
@@ -571,6 +565,11 @@
TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
return original_click_indices;
}
+ if (contact_engine_ &&
+ !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ TC3_LOG(ERROR) << "Contact suggest selection failed.";
+ return original_click_indices;
+ }
// Sort candidates according to their position in the input, so that the next
// code can assume that any connected component of overlapping spans forms a
@@ -671,7 +670,7 @@
inline bool ClassifiedAsOther(
const std::vector<ClassificationResult>& classification) {
return !classification.empty() &&
- classification[0].collection == Annotator::kOtherCollection;
+ classification[0].collection == Collections::kOther;
}
float GetPriorityScore(
@@ -937,7 +936,7 @@
if (model_->classification_options()->max_num_tokens() > 0 &&
model_->classification_options()->max_num_tokens() <
selection_num_tokens) {
- *classification_results = {{kOtherCollection, 1.0}};
+ *classification_results = {{Collections::kOther, 1.0}};
return true;
}
@@ -977,7 +976,7 @@
if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
tokens, extraction_span)) {
- *classification_results = {{kOtherCollection, 1.0}};
+ *classification_results = {{Collections::kOther, 1.0}};
return true;
}
@@ -1031,22 +1030,22 @@
// Phone class sanity check.
if (!classification_results->empty() &&
- classification_results->begin()->collection == kPhoneCollection) {
+ classification_results->begin()->collection == Collections::kPhone) {
const int digit_count = CountDigits(context, selection_indices);
if (digit_count <
model_->classification_options()->phone_min_num_digits() ||
digit_count >
model_->classification_options()->phone_max_num_digits()) {
- *classification_results = {{kOtherCollection, 1.0}};
+ *classification_results = {{Collections::kOther, 1.0}};
}
}
// Address class sanity check.
if (!classification_results->empty() &&
- classification_results->begin()->collection == kAddressCollection) {
+ classification_results->begin()->collection == Collections::kAddress) {
if (selection_num_tokens <
model_->classification_options()->address_min_num_tokens()) {
- *classification_results = {{kOtherCollection, 1.0}};
+ *classification_results = {{Collections::kOther, 1.0}};
}
}
@@ -1068,8 +1067,7 @@
regex_pattern.pattern->Matcher(selection_text_unicode);
int status = UniLib::RegexMatcher::kNoError;
bool matches;
- if (regex_approximate_match_pattern_ids_.find(pattern_id) !=
- regex_approximate_match_pattern_ids_.end()) {
+ if (regex_pattern.config->use_approximate_matching()) {
matches = matcher->ApproximatelyMatches(&status);
} else {
matches = matcher->Matches(&status);
@@ -1077,11 +1075,12 @@
if (status != UniLib::RegexMatcher::kNoError) {
return false;
}
- if (matches &&
- VerifyCandidate(regex_pattern.verification_options, selection_text)) {
- *classification_result = {regex_pattern.collection_name,
- regex_pattern.target_classification_score,
- regex_pattern.priority_score};
+ if (matches && VerifyCandidate(regex_pattern.config->verification_options(),
+ selection_text)) {
+ *classification_result = {
+ regex_pattern.config->collection_name()->str(),
+ regex_pattern.config->target_classification_score(),
+ regex_pattern.config->priority_score()};
return true;
}
if (status != UniLib::RegexMatcher::kNoError) {
@@ -1095,7 +1094,7 @@
bool Annotator::DatetimeClassifyText(
const std::string& context, CodepointSpan selection_indices,
const ClassificationOptions& options,
- ClassificationResult* classification_result) const {
+ std::vector<ClassificationResult>* classification_results) const {
if (!datetime_parser_) {
return false;
}
@@ -1117,9 +1116,11 @@
if (std::make_pair(datetime_span.span.first + selection_indices.first,
datetime_span.span.second + selection_indices.first) ==
selection_indices) {
- *classification_result = {kDateCollection,
- datetime_span.target_classification_score};
- classification_result->datetime_parse_result = datetime_span.data;
+ for (const DatetimeParseResult& parse_result : datetime_span.data) {
+ classification_results->emplace_back(
+ Collections::kDate, datetime_span.target_classification_score);
+ classification_results->back().datetime_parse_result = parse_result;
+ }
return true;
}
}
@@ -1156,7 +1157,18 @@
if (!FilteredForClassification(knowledge_result)) {
return {knowledge_result};
} else {
- return {{kOtherCollection, 1.0}};
+ return {{Collections::kOther, 1.0}};
+ }
+ }
+
+ // Try the contact engine.
+ ClassificationResult contact_result;
+ if (contact_engine_ && contact_engine_->ClassifyText(
+ context, selection_indices, &contact_result)) {
+ if (!FilteredForClassification(contact_result)) {
+ return {contact_result};
+ } else {
+ return {{Collections::kOther, 1.0}};
}
}
@@ -1166,18 +1178,25 @@
if (!FilteredForClassification(regex_result)) {
return {regex_result};
} else {
- return {{kOtherCollection, 1.0}};
+ return {{Collections::kOther, 1.0}};
}
}
// Try the date model.
- ClassificationResult datetime_result;
+ std::vector<ClassificationResult> datetime_results;
if (DatetimeClassifyText(context, selection_indices, options,
- &datetime_result)) {
- if (!FilteredForClassification(datetime_result)) {
- return {datetime_result};
+ &datetime_results)) {
+ for (int i = 0; i < datetime_results.size(); i++) {
+ if (FilteredForClassification(datetime_results[i])) {
+ datetime_results.erase(datetime_results.begin() + i);
+ i--;
+ }
+ }
+
+ if (!datetime_results.empty()) {
+ return datetime_results;
} else {
- return {{kOtherCollection, 1.0}};
+ return {{Collections::kOther, 1.0}};
}
}
@@ -1192,7 +1211,7 @@
if (!FilteredForClassification(model_result[0])) {
return model_result;
} else {
- return {{kOtherCollection, 1.0}};
+ return {{Collections::kOther, 1.0}};
}
}
@@ -1317,7 +1336,9 @@
return {};
}
- if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ if (!context_unicode.is_valid()) {
return {};
}
@@ -1351,6 +1372,13 @@
return {};
}
+ // Annotate with the contact engine.
+ if (contact_engine_ &&
+ !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
+ TC3_LOG(ERROR) << "Couldn't run contact engine Chunk.";
+ return {};
+ }
+
// Sort candidates according to their position in the input, so that the next
// code can assume that any connected component of overlapping spans forms a
// contiguous block.
@@ -1379,6 +1407,47 @@
return result;
}
+CodepointSpan Annotator::ComputeSelectionBoundaries(
+    const UniLib::RegexMatcher* match,
+    const RegexModel_::Pattern* config) const {
+  const CodepointSpan invalid_span = {kInvalidIndex, kInvalidIndex};
+  const auto* const groups = config->capturing_group();
+
+  // No capturing group configuration: the first capturing group is the
+  // selection.
+  if (groups == nullptr) {
+    int status = UniLib::RegexMatcher::kNoError;
+    const int begin = match->Start(1, &status);
+    const int end = match->End(1, &status);
+    if (status != UniLib::RegexMatcher::kNoError) {
+      return invalid_span;
+    }
+    return {begin, end};
+  }
+
+  // Otherwise the selection is the smallest span enclosing every group that is
+  // marked extend_selection and has a valid position in the match.
+  CodepointSpan selection = invalid_span;
+  const int num_groups = groups->size();
+  for (int group = 0; group < num_groups; ++group) {
+    if (!groups->Get(group)->extend_selection()) {
+      continue;
+    }
+    int status = UniLib::RegexMatcher::kNoError;
+    const int begin = match->Start(group, &status);
+    const int end = match->End(group, &status);
+    if (status != UniLib::RegexMatcher::kNoError) {
+      return invalid_span;
+    }
+    // NOTE(review): kInvalidIndex presumably marks a group that did not take
+    // part in this match; such groups do not affect the selection.
+    if (begin == kInvalidIndex || end == kInvalidIndex) {
+      continue;
+    }
+    const bool is_first = (selection.first == kInvalidIndex);
+    selection.first = is_first ? begin : std::min(selection.first, begin);
+    selection.second = is_first ? end : std::max(selection.second, end);
+  }
+  return selection;
+}
+
bool Annotator::RegexChunk(const UnicodeText& context_unicode,
const std::vector<int>& rules,
std::vector<AnnotatedSpan>* result) const {
@@ -1393,21 +1462,23 @@
int status = UniLib::RegexMatcher::kNoError;
while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
- if (regex_pattern.verification_options) {
- if (!VerifyCandidate(regex_pattern.verification_options,
+ if (regex_pattern.config->verification_options()) {
+ if (!VerifyCandidate(regex_pattern.config->verification_options(),
matcher->Group(1, &status).ToUTF8String())) {
continue;
}
}
result->emplace_back();
+
// Selection/annotation regular expressions need to specify a capturing
// group specifying the selection.
- result->back().span = {matcher->Start(1, &status),
- matcher->End(1, &status)};
+ result->back().span =
+ ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
+
result->back().classification = {
- {regex_pattern.collection_name,
- regex_pattern.target_classification_score,
- regex_pattern.priority_score}};
+ {regex_pattern.config->collection_name()->str(),
+ regex_pattern.config->target_classification_score(),
+ regex_pattern.config->priority_score()}};
}
}
return true;
@@ -1672,18 +1743,22 @@
return false;
}
for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
- AnnotatedSpan annotated_span;
- annotated_span.span = datetime_span.span;
- annotated_span.classification = {{kDateCollection,
- datetime_span.target_classification_score,
- datetime_span.priority_score}};
- annotated_span.classification[0].datetime_parse_result = datetime_span.data;
+ for (const DatetimeParseResult& parse_result : datetime_span.data) {
+ AnnotatedSpan annotated_span;
+ annotated_span.span = datetime_span.span;
+ annotated_span.classification = {
+ {Collections::kDate, datetime_span.target_classification_score,
+ datetime_span.priority_score}};
+ annotated_span.classification[0].datetime_parse_result = parse_result;
- result->push_back(std::move(annotated_span));
+ result->push_back(std::move(annotated_span));
+ }
}
return true;
}
+const Model* Annotator::ViewModel() const {
+  // Exposes the model flatbuffer this annotator reads its configuration from.
+  return model_;
+}
+
const Model* ViewModel(const void* buffer, int size) {
if (!buffer) {
return nullptr;
diff --git a/annotator/annotator.h b/annotator/annotator.h
index c58c03d..fa3e3d2 100644
--- a/annotator/annotator.h
+++ b/annotator/annotator.h
@@ -24,6 +24,7 @@
#include <string>
#include <vector>
+#include "annotator/contact/contact-engine.h"
#include "annotator/datetime/parser.h"
#include "annotator/feature-processor.h"
#include "annotator/knowledge/knowledge-engine.h"
@@ -133,6 +134,9 @@
// Initializes the knowledge engine with the given config.
bool InitializeKnowledgeEngine(const std::string& serialized_config);
+ // Initializes the contact engine with the given config.
+ bool InitializeContactEngine(const std::string& serialized_config);
+
// Runs inference for given a context and current selection (i.e. index
// of the first and one past last selected characters (utf8 codepoint
// offsets)). Returns the indices (utf8 codepoint offsets) of the selection
@@ -158,6 +162,8 @@
const std::string& context,
const AnnotationOptions& options = AnnotationOptions::Default()) const;
+ const Model* ViewModel() const;
+
// Exposes the feature processor for tests and evaluations.
const FeatureProcessor* SelectionFeatureProcessorForTests() const;
const FeatureProcessor* ClassificationFeatureProcessorForTests() const;
@@ -165,18 +171,11 @@
// Exposes the date time parser for tests and evaluations.
const DatetimeParser* DatetimeParserForTests() const;
- // String collection names for various classes.
- static const std::string& kOtherCollection;
static const std::string& kPhoneCollection;
static const std::string& kAddressCollection;
static const std::string& kDateCollection;
static const std::string& kUrlCollection;
- static const std::string& kFlightCollection;
static const std::string& kEmailCollection;
- static const std::string& kIbanCollection;
- static const std::string& kPaymentCardCollection;
- static const std::string& kIsbnCollection;
- static const std::string& kTrackingNumberCollection;
protected:
struct ScoredChunk {
@@ -258,10 +257,10 @@
// Classifies the selected text with the date time model.
// Returns true if there was a match and the result was set.
- bool DatetimeClassifyText(const std::string& context,
- CodepointSpan selection_indices,
- const ClassificationOptions& options,
- ClassificationResult* classification_result) const;
+ bool DatetimeClassifyText(
+ const std::string& context, CodepointSpan selection_indices,
+ const ClassificationOptions& options,
+ std::vector<ClassificationResult>* classification_results) const;
// Chunks given input text with the selection model and classifies the spans
// with the classification model.
@@ -324,6 +323,11 @@
const ClassificationResult& classification) const;
bool FilteredForSelection(const AnnotatedSpan& span) const;
+ // Computes the selection boundaries from a regular expression match.
+ CodepointSpan ComputeSelectionBoundaries(
+ const UniLib::RegexMatcher* match,
+ const RegexModel_::Pattern* config) const;
+
const Model* model_;
std::unique_ptr<const ModelExecutor> selection_executor_;
@@ -337,11 +341,8 @@
private:
struct CompiledRegexPattern {
- std::string collection_name;
- float target_classification_score;
- float priority_score;
+ // Pattern configuration consulted at match time (collection name, scores,
+ // verification and approximate-matching options). Not owned by this struct.
+ const RegexModel_::Pattern* config;
std::unique_ptr<UniLib::RegexPattern> pattern;
- const VerificationOptions* verification_options;
};
std::unique_ptr<ScopedMmap> mmap_;
@@ -354,7 +355,6 @@
std::unordered_set<std::string> filtered_collections_selection_;
std::vector<CompiledRegexPattern> regex_patterns_;
- std::unordered_set<int> regex_approximate_match_pattern_ids_;
// Indices into regex_patterns_ for the different modes.
std::vector<int> annotation_regex_patterns_, classification_regex_patterns_,
@@ -366,6 +366,7 @@
const CalendarLib* calendarlib_;
std::unique_ptr<const KnowledgeEngine> knowledge_engine_;
+ std::unique_ptr<const ContactEngine> contact_engine_;
};
namespace internal {
@@ -388,6 +389,23 @@
// Interprets the buffer as a Model flatbuffer and returns it for reading.
const Model* ViewModel(const void* buffer, int size);
+// Opens the model file at `path` and runs `function`, passing the loaded Model
+// flatbuffer as an argument, or nullptr if the file could not be mapped.
+// Returns whatever `function` returns. The Model pointer is only valid during
+// the call: the mapping is owned by a local ScopedMmap in this function.
+//
+// This is mainly useful if we don't want to pay the cost for the model
+// initialization because we'll be only reading some flatbuffer values from the
+// file.
+template <typename ReturnType, typename Func>
+ReturnType VisitAnnotatorModel(const std::string& path, Func function) {
+ ScopedMmap mmap(path);
+ if (!mmap.handle().ok()) {
+ // Report the failure to the visitor exactly once and stop here; there is
+ // no valid mapping to hand to ViewModel.
+ return function(/*model=*/nullptr);
+ }
+ const Model* model =
+ ViewModel(mmap.handle().start(), mmap.handle().num_bytes());
+ return function(model);
+}
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_ANNOTATOR_H_
diff --git a/annotator/annotator_jni.cc b/annotator/annotator_jni.cc
index 9bda35a..690c3ff 100644
--- a/annotator/annotator_jni.cc
+++ b/annotator/annotator_jni.cc
@@ -26,9 +26,13 @@
#include "annotator/annotator_jni_common.h"
#include "utils/base/integral_types.h"
#include "utils/calendar/calendar.h"
+#include "utils/intents/intent-generator.h"
+#include "utils/intents/jni.h"
+#include "utils/java/jni-cache.h"
#include "utils/java/scoped_local_ref.h"
#include "utils/java/string_utils.h"
#include "utils/memory/mmap.h"
+#include "utils/strings/stringpiece.h"
#include "utils/utf8/unilib.h"
#ifdef TC3_UNILIB_JAVAICU
@@ -59,8 +63,9 @@
namespace {
-jobjectArray ClassificationResultsToJObjectArray(
- JNIEnv* env,
+// Converts `classification_result` into a Java array of ClassificationResult
+// objects. When `intent_generator` is non-null, each row additionally carries
+// the RemoteActionTemplates generated for it; `options` and `selection_text`
+// are only consulted for that intent generation.
+jobjectArray ClassificationResultsWithIntentsToJObjectArray(
+ JNIEnv* env, const IntentGenerator* intent_generator,
+ const ClassificationOptions* options, StringPiece selection_text,
const std::vector<ClassificationResult>& classification_result) {
const ScopedLocalRef<jclass> result_class(
env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
@@ -82,12 +87,20 @@
const jmethodID result_class_constructor = env->GetMethodID(
result_class.get(), "<init>",
"(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$DatetimeResult;[B)V");
+ "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/String;"
+ "Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
+ TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";)V");
const jmethodID datetime_parse_class_constructor =
env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
const jobjectArray results = env->NewObjectArray(classification_result.size(),
result_class.get(), nullptr);
+
+ std::unique_ptr<RemoteActionTemplatesHandler> remote_action_templates_handler;
+ if (intent_generator != nullptr) {
+ remote_action_templates_handler = RemoteActionTemplatesHandler::Create(env);
+ }
+
for (int i = 0; i < classification_result.size(); i++) {
jstring row_string =
env->NewStringUTF(classification_result[i].collection.c_str());
@@ -112,16 +125,70 @@
serialized_knowledge_result_string.data()));
}
- jobject result =
- env->NewObject(result_class.get(), result_class_constructor, row_string,
- static_cast<jfloat>(classification_result[i].score),
- row_datetime_parse, serialized_knowledge_result);
+ jstring contact_name = nullptr;
+ if (!classification_result[i].contact_name.empty()) {
+ contact_name =
+ env->NewStringUTF(classification_result[i].contact_name.c_str());
+ }
+
+ jstring contact_given_name = nullptr;
+ if (!classification_result[i].contact_given_name.empty()) {
+ contact_given_name = env->NewStringUTF(
+ classification_result[i].contact_given_name.c_str());
+ }
+
+ jstring contact_nickname = nullptr;
+ if (!classification_result[i].contact_nickname.empty()) {
+ contact_nickname =
+ env->NewStringUTF(classification_result[i].contact_nickname.c_str());
+ }
+
+ jstring contact_email_address = nullptr;
+ if (!classification_result[i].contact_email_address.empty()) {
+ contact_email_address = env->NewStringUTF(
+ classification_result[i].contact_email_address.c_str());
+ }
+
+ jstring contact_phone_number = nullptr;
+ if (!classification_result[i].contact_phone_number.empty()) {
+ contact_phone_number = env->NewStringUTF(
+ classification_result[i].contact_phone_number.c_str());
+ }
+
+ jobject remote_action_templates_result = nullptr;
+
+ if (intent_generator != nullptr &&
+ remote_action_templates_handler != nullptr) {
+ std::vector<RemoteActionTemplate> remote_action_templates =
+ intent_generator->GenerateIntents(classification_result[i],
+ options->reference_time_ms_utc,
+ selection_text);
+ remote_action_templates_result =
+ remote_action_templates_handler->RemoteActionTemplatesToJObjectArray(
+ remote_action_templates);
+ }
+
+ jobject result = env->NewObject(
+ result_class.get(), result_class_constructor, row_string,
+ static_cast<jfloat>(classification_result[i].score), row_datetime_parse,
+ serialized_knowledge_result, contact_name, contact_given_name,
+ contact_nickname, contact_email_address, contact_phone_number,
+ remote_action_templates_result);
env->SetObjectArrayElement(results, i, result);
env->DeleteLocalRef(result);
+ // Release the per-row local references created in this iteration. The
+ // Java result object already holds them, and keeping them alive would let a
+ // long result list exhaust the JNI local reference table.
+ const jobject row_local_refs[] = {
+ contact_name, contact_given_name, contact_nickname,
+ contact_email_address, contact_phone_number,
+ remote_action_templates_result};
+ for (const jobject local_ref : row_local_refs) {
+ if (local_ref != nullptr) {
+ env->DeleteLocalRef(local_ref);
+ }
+ }
}
return results;
}
+// Converts classification results to Java objects without generating any
+// RemoteActionTemplates: no intent generator is passed, so the options and
+// selection text arguments are never consulted.
+jobjectArray ClassificationResultsToJObjectArray(
+ JNIEnv* env,
+ const std::vector<ClassificationResult>& classification_result) {
+ return ClassificationResultsWithIntentsToJObjectArray(
+ env, /*(unused) intent_generator=*/nullptr,
+ /*(unused) options=*/nullptr,
+ /*(unused) selection_text=*/"", classification_result);
+}
+
CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
CodepointSpan orig_indices,
bool from_utf8) {
@@ -218,6 +285,7 @@
} // namespace libtextclassifier3
using libtextclassifier3::ClassificationResultsToJObjectArray;
+using libtextclassifier3::ClassificationResultsWithIntentsToJObjectArray;
using libtextclassifier3::ConvertIndicesBMPToUTF8;
using libtextclassifier3::ConvertIndicesUTF8ToBMP;
using libtextclassifier3::FromJavaAnnotationOptions;
@@ -290,6 +358,25 @@
return model->InitializeKnowledgeEngine(serialized_config_string);
}
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializeContactEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config) {
+ if (!ptr) {
+ return false;
+ }
+
+ Annotator* model = reinterpret_cast<Annotator*>(ptr);
+
+ std::string serialized_config_string;
+ const int length = env->GetArrayLength(serialized_config);
+ serialized_config_string.resize(length);
+ env->GetByteArrayRegion(serialized_config, 0, length,
+ reinterpret_cast<jbyte*>(const_cast<char*>(
+ serialized_config_string.data())));
+
+ return model->InitializeContactEngine(serialized_config_string);
+}
+
TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
jint selection_end, jobject options) {
@@ -314,7 +401,7 @@
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jobject options) {
+ jint selection_end, jobject options, jobject app_context) {
if (!ptr) {
return nullptr;
}
@@ -323,11 +410,24 @@
const std::string context_utf8 = ToStlString(env, context);
const CodepointSpan input_indices =
ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
+ const libtextclassifier3::ClassificationOptions classification_options =
+ FromJavaClassificationOptions(env, options);
const std::vector<ClassificationResult> classification_result =
ff_model->ClassifyText(context_utf8, input_indices,
- FromJavaClassificationOptions(env, options));
+ classification_options);
- return ClassificationResultsToJObjectArray(env, classification_result);
+ std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
+ libtextclassifier3::JniCache::Create(env));
+ std::unique_ptr<libtextclassifier3::IntentGenerator> intent_generator(
+ new libtextclassifier3::IntentGenerator(
+ ff_model->ViewModel()->intent_options(), jni_cache, app_context));
+
+ // `input_indices` are codepoint offsets into `context_utf8` (the same span
+ // ClassifyText consumes), not byte offsets. They must be translated before
+ // slicing the UTF-8 buffer. Otherwise any non-ASCII text before or inside
+ // the selection shifts or truncates the selection text. Invalid (negative)
+ // or inverted spans yield an empty selection instead of out-of-range pointer
+ // arithmetic.
+ const auto codepoint_to_byte_offset = [&context_utf8](int codepoint_index) {
+ int codepoints_seen = 0;
+ for (int i = 0; i < static_cast<int>(context_utf8.size()); ++i) {
+ // Every byte that is not a UTF-8 continuation byte (10xxxxxx) starts a
+ // new codepoint.
+ if ((static_cast<unsigned char>(context_utf8[i]) & 0xC0) != 0x80) {
+ if (codepoints_seen == codepoint_index) {
+ return i;
+ }
+ ++codepoints_seen;
+ }
+ }
+ // Index at or past the last codepoint: clamp to the end of the buffer.
+ return static_cast<int>(context_utf8.size());
+ };
+ int selection_begin_byte = 0;
+ int selection_end_byte = 0;
+ if (input_indices.first >= 0 && input_indices.first <= input_indices.second) {
+ selection_begin_byte = codepoint_to_byte_offset(input_indices.first);
+ selection_end_byte = codepoint_to_byte_offset(input_indices.second);
+ }
+ libtextclassifier3::StringPiece selection_text(
+ context_utf8.data() + selection_begin_byte,
+ selection_end_byte - selection_begin_byte);
+ return ClassificationResultsWithIntentsToJObjectArray(
+ env, intent_generator.get(), &classification_options, selection_text,
+ classification_result);
}
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
diff --git a/annotator/annotator_jni.h b/annotator/annotator_jni.h
index 47715b4..b084b26 100644
--- a/annotator/annotator_jni.h
+++ b/annotator/annotator_jni.h
@@ -42,13 +42,17 @@
nativeInitializeKnowledgeEngine)
(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);
+TC3_JNI_METHOD(jboolean, TC3_ANNOTATOR_CLASS_NAME,
+ nativeInitializeContactEngine)
+(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray serialized_config);
+
TC3_JNI_METHOD(jintArray, TC3_ANNOTATOR_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
jint selection_end, jobject options);
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
- jint selection_end, jobject options);
+ jint selection_end, jobject options, jobject app_context);
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options);
diff --git a/annotator/annotator_test.cc b/annotator/annotator_test.cc
index fbaf039..be71f84 100644
--- a/annotator/annotator_test.cc
+++ b/annotator/annotator_test.cc
@@ -308,6 +308,49 @@
EXPECT_EQ(classifier->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}),
std::make_pair(4, 23));
}
+
+TEST_P(AnnotatorTest, SuggestSelectionRegularExpressionCustomSelectionBounds) {
+ const std::string test_model = ReadFile(GetModelPath() + GetParam());
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
+
+ // Add test regex models.
+ std::unique_ptr<RegexModel_::PatternT> custom_selection_bounds_pattern =
+ MakePattern("date_range",
+ "(?:(?:from )?(\\d{2}\\/\\d{2}\\/\\d{4}) to "
+ "(\\d{2}\\/\\d{2}\\/\\d{4}))|(for ever)",
+ /*enabled_for_classification=*/false,
+ /*enabled_for_selection=*/true,
+ /*enabled_for_annotation=*/false, 1.0);
+ custom_selection_bounds_pattern->capturing_group.emplace_back(
+ new RegexModel_::Pattern_::CapturingGroupT);
+ custom_selection_bounds_pattern->capturing_group.emplace_back(
+ new RegexModel_::Pattern_::CapturingGroupT);
+ custom_selection_bounds_pattern->capturing_group.emplace_back(
+ new RegexModel_::Pattern_::CapturingGroupT);
+ custom_selection_bounds_pattern->capturing_group.emplace_back(
+ new RegexModel_::Pattern_::CapturingGroupT);
+ custom_selection_bounds_pattern->capturing_group[0]->extend_selection = false;
+ custom_selection_bounds_pattern->capturing_group[1]->extend_selection = true;
+ custom_selection_bounds_pattern->capturing_group[2]->extend_selection = true;
+ custom_selection_bounds_pattern->capturing_group[3]->extend_selection = true;
+ unpacked_model->regex_model->patterns.push_back(
+ std::move(custom_selection_bounds_pattern));
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+
+ std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
+ reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_, &calendarlib_);
+ ASSERT_TRUE(classifier);
+
+ // Check regular expression selection.
+ EXPECT_EQ(classifier->SuggestSelection("it's from 04/30/1789 to 03/04/1797",
+ {21, 23}),
+ std::make_pair(10, 34));
+ EXPECT_EQ(classifier->SuggestSelection("it takes for ever", {9, 12}),
+ std::make_pair(9, 17));
+}
#endif // TC3_UNILIB_ICU
#ifdef TC3_UNILIB_ICU
@@ -931,7 +974,6 @@
options.reference_timezone = "Europe/Zurich";
result = classifier->ClassifyText("january 1, 2017", {0, 15}, options);
-
ASSERT_EQ(result.size(), 1);
EXPECT_THAT(result[0].collection, "date");
EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
@@ -950,11 +992,24 @@
options.reference_timezone = "America/Los_Angeles";
result = classifier->ClassifyText("2018/01/01 10:30:20", {0, 19}, options);
- ASSERT_EQ(result.size(), 1);
+ ASSERT_EQ(result.size(), 2); // Has 2 interpretations - a.m. or p.m.
EXPECT_THAT(result[0].collection, "date");
EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514831420000);
EXPECT_EQ(result[0].datetime_parse_result.granularity,
DatetimeGranularity::GRANULARITY_SECOND);
+ EXPECT_THAT(result[1].collection, "date");
+ EXPECT_EQ(result[1].datetime_parse_result.time_ms_utc, 1514874620000);
+ EXPECT_EQ(result[1].datetime_parse_result.granularity,
+ DatetimeGranularity::GRANULARITY_SECOND);
+ result.clear();
+
+ options.reference_timezone = "America/Los_Angeles";
+ result = classifier->ClassifyText("2018/01/01 22:00", {0, 16}, options);
+ ASSERT_EQ(result.size(), 1); // Has only 1 interpretation - 10 p.m.
+ EXPECT_THAT(result[0].collection, "date");
+ EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514872800000);
+ EXPECT_EQ(result[0].datetime_parse_result.granularity,
+ DatetimeGranularity::GRANULARITY_MINUTE);
result.clear();
// Date on another line.
@@ -1034,7 +1089,8 @@
public:
TestingAnnotator(const std::string& model, const UniLib* unilib,
const CalendarLib* calendarlib)
- : Annotator(ViewModel(model.data(), model.size()), unilib, calendarlib) {}
+ : Annotator(libtextclassifier3::ViewModel(model.data(), model.size()),
+ unilib, calendarlib) {}
using Annotator::ResolveConflicts;
};
@@ -1250,5 +1306,22 @@
}
#endif // TC3_UNILIB_ICU
+TEST_F(AnnotatorTest, VisitAnnotatorModel) {
+ EXPECT_TRUE(VisitAnnotatorModel<bool>(GetModelPath() + "test_model.fb",
+ [](const Model* model) {
+ if (model == nullptr) {
+ return false;
+ }
+ return true;
+ }));
+ EXPECT_FALSE(VisitAnnotatorModel<bool>(
+ GetModelPath() + "non_existing_model.fb", [](const Model* model) {
+ if (model == nullptr) {
+ return false;
+ }
+ return true;
+ }));
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/annotator/collections.cc b/annotator/collections.cc
new file mode 100644
index 0000000..823306f
--- /dev/null
+++ b/annotator/collections.cc
@@ -0,0 +1,46 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/collections.h"
+
+namespace libtextclassifier3 {
+
+const std::string& Collections::kOther =
+ *[]() { return new std::string("other"); }();
+const std::string& Collections::kPhone =
+ *[]() { return new std::string("phone"); }();
+const std::string& Collections::kAddress =
+ *[]() { return new std::string("address"); }();
+const std::string& Collections::kDate =
+ *[]() { return new std::string("date"); }();
+const std::string& Collections::kUrl =
+ *[]() { return new std::string("url"); }();
+const std::string& Collections::kFlight =
+ *[]() { return new std::string("flight"); }();
+const std::string& Collections::kEmail =
+ *[]() { return new std::string("email"); }();
+const std::string& Collections::kIban =
+ *[]() { return new std::string("iban"); }();
+const std::string& Collections::kPaymentCard =
+ *[]() { return new std::string("payment_card"); }();
+const std::string& Collections::kIsbn =
+ *[]() { return new std::string("isbn"); }();
+const std::string& Collections::kTrackingNumber =
+ *[]() { return new std::string("tracking_number"); }();
+const std::string& Collections::kContact =
+ *[]() { return new std::string("contact"); }();
+
+} // namespace libtextclassifier3
diff --git a/annotator/collections.h b/annotator/collections.h
new file mode 100644
index 0000000..5a35231
--- /dev/null
+++ b/annotator/collections.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_COLLECTIONS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_COLLECTIONS_H_
+
+#include <string>
+
+namespace libtextclassifier3 {
+
+// String collection names for various classes.
+class Collections {
+ public:
+ static const std::string& kOther;
+ static const std::string& kPhone;
+ static const std::string& kAddress;
+ static const std::string& kDate;
+ static const std::string& kUrl;
+ static const std::string& kFlight;
+ static const std::string& kEmail;
+ static const std::string& kIban;
+ static const std::string& kPaymentCard;
+ static const std::string& kIsbn;
+ static const std::string& kTrackingNumber;
+ static const std::string& kContact;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_COLLECTIONS_H_
diff --git a/annotator/contact/contact-engine-dummy.h b/annotator/contact/contact-engine-dummy.h
new file mode 100644
index 0000000..cdf1ac3
--- /dev/null
+++ b/annotator/contact/contact-engine-dummy.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_DUMMY_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_DUMMY_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/base/logging.h"
+#include "utils/utf8/unicodetext.h"
+
+namespace libtextclassifier3 {
+
+// A dummy implementation of the contact engine. contact-engine.h currently
+// just includes this header, so this is the ContactEngine that gets built. It
+// never initializes and never produces classifications or annotations.
+class ContactEngine {
+ public:
+ // Always fails: logs an error and returns false. The config is ignored.
+ bool Initialize(const std::string& serialized_config) {
+ TC3_LOG(ERROR) << "No contact engine to initialize.";
+ return false;
+ }
+
+ // Never classifies anything: returns false and leaves
+ // `classification_result` untouched.
+ bool ClassifyText(const std::string& context, CodepointSpan selection_indices,
+ ClassificationResult* classification_result) const {
+ return false;
+ }
+
+ // Produces no chunks: leaves `result` untouched and reports success, so
+ // callers proceed as if no contacts were found.
+ bool Chunk(const UnicodeText& context_unicode,
+ const std::vector<Token>& tokens,
+ std::vector<AnnotatedSpan>* result) const {
+ return true;
+ }
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_DUMMY_H_
diff --git a/annotator/contact/contact-engine.h b/annotator/contact/contact-engine.h
new file mode 100644
index 0000000..01d3323
--- /dev/null
+++ b/annotator/contact/contact-engine.h
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_H_
+
+#include "annotator/contact/contact-engine-dummy.h"
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_CONTACT_CONTACT_ENGINE_H_
diff --git a/annotator/datetime/extractor.cc b/annotator/datetime/extractor.cc
index 31229dd..0328dcf 100644
--- a/annotator/datetime/extractor.cc
+++ b/annotator/datetime/extractor.cc
@@ -376,7 +376,7 @@
}
bool DatetimeExtractor::ParseAMPM(const UnicodeText& input,
- int* parsed_ampm) const {
+ DateParseData::AMPM* parsed_ampm) const {
return MapInput(input,
{
{DatetimeExtractorType_AM, DateParseData::AMPM::AM},
@@ -420,48 +420,56 @@
return MapInput(
input,
{
- {DatetimeExtractorType_MONDAY, DateParseData::MONDAY},
- {DatetimeExtractorType_TUESDAY, DateParseData::TUESDAY},
- {DatetimeExtractorType_WEDNESDAY, DateParseData::WEDNESDAY},
- {DatetimeExtractorType_THURSDAY, DateParseData::THURSDAY},
- {DatetimeExtractorType_FRIDAY, DateParseData::FRIDAY},
- {DatetimeExtractorType_SATURDAY, DateParseData::SATURDAY},
- {DatetimeExtractorType_SUNDAY, DateParseData::SUNDAY},
- {DatetimeExtractorType_DAY, DateParseData::DAY},
- {DatetimeExtractorType_WEEK, DateParseData::WEEK},
- {DatetimeExtractorType_MONTH, DateParseData::MONTH},
- {DatetimeExtractorType_YEAR, DateParseData::YEAR},
+ {DatetimeExtractorType_MONDAY, DateParseData::RelationType::MONDAY},
+ {DatetimeExtractorType_TUESDAY, DateParseData::RelationType::TUESDAY},
+ {DatetimeExtractorType_WEDNESDAY,
+ DateParseData::RelationType::WEDNESDAY},
+ {DatetimeExtractorType_THURSDAY,
+ DateParseData::RelationType::THURSDAY},
+ {DatetimeExtractorType_FRIDAY, DateParseData::RelationType::FRIDAY},
+ {DatetimeExtractorType_SATURDAY,
+ DateParseData::RelationType::SATURDAY},
+ {DatetimeExtractorType_SUNDAY, DateParseData::RelationType::SUNDAY},
+ {DatetimeExtractorType_DAY, DateParseData::RelationType::DAY},
+ {DatetimeExtractorType_WEEK, DateParseData::RelationType::WEEK},
+ {DatetimeExtractorType_MONTH, DateParseData::RelationType::MONTH},
+ {DatetimeExtractorType_YEAR, DateParseData::RelationType::YEAR},
},
parsed_relation_type);
}
-bool DatetimeExtractor::ParseTimeUnit(const UnicodeText& input,
- int* parsed_time_unit) const {
- return MapInput(input,
- {
- {DatetimeExtractorType_DAYS, DateParseData::DAYS},
- {DatetimeExtractorType_WEEKS, DateParseData::WEEKS},
- {DatetimeExtractorType_MONTHS, DateParseData::MONTHS},
- {DatetimeExtractorType_HOURS, DateParseData::HOURS},
- {DatetimeExtractorType_MINUTES, DateParseData::MINUTES},
- {DatetimeExtractorType_SECONDS, DateParseData::SECONDS},
- {DatetimeExtractorType_YEARS, DateParseData::YEARS},
- },
- parsed_time_unit);
-}
-
-bool DatetimeExtractor::ParseWeekday(const UnicodeText& input,
- int* parsed_weekday) const {
+bool DatetimeExtractor::ParseTimeUnit(
+ const UnicodeText& input, DateParseData::TimeUnit* parsed_time_unit) const {
return MapInput(
input,
{
- {DatetimeExtractorType_MONDAY, DateParseData::MONDAY},
- {DatetimeExtractorType_TUESDAY, DateParseData::TUESDAY},
- {DatetimeExtractorType_WEDNESDAY, DateParseData::WEDNESDAY},
- {DatetimeExtractorType_THURSDAY, DateParseData::THURSDAY},
- {DatetimeExtractorType_FRIDAY, DateParseData::FRIDAY},
- {DatetimeExtractorType_SATURDAY, DateParseData::SATURDAY},
- {DatetimeExtractorType_SUNDAY, DateParseData::SUNDAY},
+ {DatetimeExtractorType_DAYS, DateParseData::TimeUnit::DAYS},
+ {DatetimeExtractorType_WEEKS, DateParseData::TimeUnit::WEEKS},
+ {DatetimeExtractorType_MONTHS, DateParseData::TimeUnit::MONTHS},
+ {DatetimeExtractorType_HOURS, DateParseData::TimeUnit::HOURS},
+ {DatetimeExtractorType_MINUTES, DateParseData::TimeUnit::MINUTES},
+ {DatetimeExtractorType_SECONDS, DateParseData::TimeUnit::SECONDS},
+ {DatetimeExtractorType_YEARS, DateParseData::TimeUnit::YEARS},
+ },
+ parsed_time_unit);
+}
+
+bool DatetimeExtractor::ParseWeekday(
+ const UnicodeText& input,
+ DateParseData::RelationType* parsed_weekday) const {
+ return MapInput(
+ input,
+ {
+ {DatetimeExtractorType_MONDAY, DateParseData::RelationType::MONDAY},
+ {DatetimeExtractorType_TUESDAY, DateParseData::RelationType::TUESDAY},
+ {DatetimeExtractorType_WEDNESDAY,
+ DateParseData::RelationType::WEDNESDAY},
+ {DatetimeExtractorType_THURSDAY,
+ DateParseData::RelationType::THURSDAY},
+ {DatetimeExtractorType_FRIDAY, DateParseData::RelationType::FRIDAY},
+ {DatetimeExtractorType_SATURDAY,
+ DateParseData::RelationType::SATURDAY},
+ {DatetimeExtractorType_SUNDAY, DateParseData::RelationType::SUNDAY},
},
parsed_weekday);
}
diff --git a/annotator/datetime/extractor.h b/annotator/datetime/extractor.h
index 4c17aa7..95e7f7c 100644
--- a/annotator/datetime/extractor.h
+++ b/annotator/datetime/extractor.h
@@ -86,16 +86,19 @@
bool ParseWrittenNumber(const UnicodeText& input, int* parsed_number) const;
bool ParseYear(const UnicodeText& input, int* parsed_year) const;
bool ParseMonth(const UnicodeText& input, int* parsed_month) const;
- bool ParseAMPM(const UnicodeText& input, int* parsed_ampm) const;
+ bool ParseAMPM(const UnicodeText& input,
+ DateParseData::AMPM* parsed_ampm) const;
bool ParseRelation(const UnicodeText& input,
DateParseData::Relation* parsed_relation) const;
bool ParseRelationDistance(const UnicodeText& input,
int* parsed_distance) const;
- bool ParseTimeUnit(const UnicodeText& input, int* parsed_time_unit) const;
+ bool ParseTimeUnit(const UnicodeText& input,
+ DateParseData::TimeUnit* parsed_time_unit) const;
bool ParseRelationType(
const UnicodeText& input,
DateParseData::RelationType* parsed_relation_type) const;
- bool ParseWeekday(const UnicodeText& input, int* parsed_weekday) const;
+ bool ParseWeekday(const UnicodeText& input,
+ DateParseData::RelationType* parsed_weekday) const;
const CompiledRule& rule_;
const UniLib::RegexMatcher& matcher_;
diff --git a/annotator/datetime/parser.cc b/annotator/datetime/parser.cc
index ac3a62d..e2a2266 100644
--- a/annotator/datetime/parser.cc
+++ b/annotator/datetime/parser.cc
@@ -103,6 +103,8 @@
}
use_extractors_for_locating_ = model->use_extractors_for_locating();
+ generate_alternative_interpretations_when_ambiguous_ =
+ model->generate_alternative_interpretations_when_ambiguous();
initialized_ = true;
}
@@ -168,10 +170,9 @@
}
std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
- int counter = 0;
- for (const auto& found_span : found_spans) {
- indexed_found_spans.push_back({found_span, counter});
- counter++;
+ indexed_found_spans.reserve(found_spans.size());
+ for (int i = 0; i < found_spans.size(); i++) {
+ indexed_found_spans.push_back({found_spans[i], i});
}
// Resolve conflicts by always picking the longer span and breaking ties by
@@ -224,21 +225,28 @@
}
DatetimeParseResultSpan parse_result;
+ std::vector<DatetimeParseResult> alternatives;
if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
- reference_locale, locale_id, &(parse_result.data),
+ reference_locale, locale_id, &alternatives,
&parse_result.span)) {
return false;
}
+
if (!use_extractors_for_locating_) {
parse_result.span = {start, end};
}
+
if (parse_result.span.first != kInvalidIndex &&
parse_result.span.second != kInvalidIndex) {
parse_result.target_classification_score =
rule.pattern->target_classification_score();
parse_result.priority_score = rule.pattern->priority_score();
- result->push_back(parse_result);
+
+ for (DatetimeParseResult& alternative : alternatives) {
+ parse_result.data.push_back(alternative);
+ }
}
+ result->push_back(parse_result);
return true;
}
@@ -375,6 +383,49 @@
return granularity;
}
+// Expands `parse` into one or more concrete interpretations and appends them
+// to `interpretations`:
+// - A parse with granularity above GRANULARITY_DAY (a time of day), an hour
+// <= 12 and no explicit AM/PM field yields two entries, one AM and one PM.
+// - Otherwise, a parse with a relation type (e.g. a weekday) but no explicit
+// relation yields three entries: LAST, NEXT and NEXT_OR_SAME.
+// - Anything else is appended unchanged as a single entry.
+// NOTE(review): an hour of 0 also satisfies `hour <= 12` and is therefore
+// expanded into AM/PM variants too; confirm that is intended.
+void FillInterpretations(const DateParseData& parse,
+ std::vector<DateParseData>* interpretations) {
+ DatetimeGranularity granularity = GetGranularity(parse);
+
+ // Multiple interpretations of ambiguous datetime expressions are generated
+ // here.
+ if (granularity > DatetimeGranularity::GRANULARITY_DAY && parse.hour <= 12 &&
+ (parse.field_set_mask & DateParseData::Fields::AMPM_FIELD) == 0) {
+ // If it's not clear if the time is AM or PM, generate all variants.
+ interpretations->push_back(parse);
+ interpretations->back().field_set_mask |= DateParseData::Fields::AMPM_FIELD;
+ interpretations->back().ampm = DateParseData::AMPM::AM;
+
+ interpretations->push_back(parse);
+ interpretations->back().field_set_mask |= DateParseData::Fields::AMPM_FIELD;
+ interpretations->back().ampm = DateParseData::AMPM::PM;
+ } else if (((parse.field_set_mask &
+ DateParseData::Fields::RELATION_TYPE_FIELD) != 0) &&
+ ((parse.field_set_mask & DateParseData::Fields::RELATION_FIELD) ==
+ 0)) {
+ // If it's not clear whether it means last, next, or this-or-next (e.g.
+ // "monday"), generate all variants.
+ interpretations->push_back(parse);
+ interpretations->back().field_set_mask |=
+ DateParseData::Fields::RELATION_FIELD;
+ interpretations->back().relation = DateParseData::Relation::LAST;
+
+ interpretations->push_back(parse);
+ interpretations->back().field_set_mask |=
+ DateParseData::Fields::RELATION_FIELD;
+ interpretations->back().relation = DateParseData::Relation::NEXT;
+
+ interpretations->push_back(parse);
+ interpretations->back().field_set_mask |=
+ DateParseData::Fields::RELATION_FIELD;
+ interpretations->back().relation = DateParseData::Relation::NEXT_OR_SAME;
+ } else {
+ // Otherwise just generate 1 variant.
+ interpretations->push_back(parse);
+ }
+}
+
} // namespace
bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
@@ -382,7 +433,8 @@
const int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& reference_locale,
- int locale_id, DatetimeParseResult* result,
+ int locale_id,
+ std::vector<DatetimeParseResult>* results,
CodepointSpan* result_span) const {
DateParseData parse;
DatetimeExtractor extractor(rule, matcher, locale_id, unilib_,
@@ -392,14 +444,24 @@
return false;
}
- result->granularity = GetGranularity(parse);
-
- if (!calendarlib_.InterpretParseData(
- parse, reference_time_ms_utc, reference_timezone, reference_locale,
- result->granularity, &(result->time_ms_utc))) {
- return false;
+ std::vector<DateParseData> interpretations;
+ if (generate_alternative_interpretations_when_ambiguous_) {
+ FillInterpretations(parse, &interpretations);
+ } else {
+ interpretations.push_back(parse);
}
+ results->reserve(results->size() + interpretations.size());
+ for (const DateParseData& interpretation : interpretations) {
+ DatetimeParseResult result;
+ result.granularity = GetGranularity(interpretation);
+ if (!calendarlib_.InterpretParseData(
+ interpretation, reference_time_ms_utc, reference_timezone,
+ reference_locale, result.granularity, &(result.time_ms_utc))) {
+ return false;
+ }
+ results->push_back(result);
+ }
return true;
}
diff --git a/annotator/datetime/parser.h b/annotator/datetime/parser.h
index c7eaf1f..133d674 100644
--- a/annotator/datetime/parser.h
+++ b/annotator/datetime/parser.h
@@ -56,6 +56,12 @@
ModeFlag mode, bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const;
+#ifdef TC3_TEST_ONLY
+ void TestOnlySetGenerateAlternativeInterpretationsWhenAmbiguous(bool value) {
+ generate_alternative_interpretations_when_ambiguous_ = value;
+ }
+#endif // TC3_TEST_ONLY
+
protected:
DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
const CalendarLib& calendarlib,
@@ -88,7 +94,7 @@
int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& reference_locale, int locale_id,
- DatetimeParseResult* result,
+ std::vector<DatetimeParseResult>* results,
CodepointSpan* result_span) const;
// Parse and extract information from current match in 'matcher'.
@@ -111,6 +117,7 @@
std::unordered_map<std::string, int> locale_string_to_id_;
std::vector<int> default_locale_ids_;
bool use_extractors_for_locating_;
+ bool generate_alternative_interpretations_when_ambiguous_;
};
} // namespace libtextclassifier3
diff --git a/annotator/datetime/parser_test.cc b/annotator/datetime/parser_test.cc
index d46accf..997d780 100644
--- a/annotator/datetime/parser_test.cc
+++ b/annotator/datetime/parser_test.cc
@@ -27,6 +27,7 @@
#include "annotator/datetime/parser.h"
#include "annotator/model_generated.h"
#include "annotator/types-test-util.h"
+#include "utils/testing/annotator.h"
using testing::ElementsAreArray;
@@ -42,23 +43,23 @@
return std::string(std::istreambuf_iterator<char>(file_stream), {});
}
-std::string FormatMillis(int64 time_ms_utc) {
- long time_seconds = time_ms_utc / 1000; // NOLINT
- // Format time, "ddd yyyy-mm-dd hh:mm:ss zzz"
- char buffer[512];
- strftime(buffer, sizeof(buffer), "%a %Y-%m-%d %H:%M:%S %Z",
- localtime(&time_seconds));
- return std::string(buffer);
-}
-
class ParserTest : public testing::Test {
public:
void SetUp() override {
- model_buffer_ = ReadFile(GetModelPath() + "test_model.fb");
+ // Loads default unmodified model. Individual tests can call LoadModel to
+ // make changes.
+ LoadModel([](ModelT* model) {});
+ }
+
+ template <typename Fn>
+ void LoadModel(Fn model_visitor_fn) {
+ std::string model_buffer = ReadFile(GetModelPath() + "test_model.fb");
+ model_buffer_ = ModifyAnnotatorModel(model_buffer, model_visitor_fn);
classifier_ = Annotator::FromUnownedBuffer(model_buffer_.data(),
model_buffer_.size(), &unilib_);
TC3_CHECK(classifier_);
parser_ = classifier_->DatetimeParserForTests();
+ TC3_CHECK(parser_);
}
bool HasNoResult(const std::string& text, bool anchor_start_end = false,
@@ -73,7 +74,7 @@
}
bool ParsesCorrectly(const std::string& marked_text,
- const int64 expected_ms_utc,
+ const std::vector<int64>& expected_ms_utcs,
DatetimeGranularity expected_granularity,
bool anchor_start_end = false,
const std::string& timezone = "Europe/Zurich",
@@ -120,25 +121,48 @@
}
}
- const std::vector<DatetimeParseResultSpan> expected{
+ std::vector<DatetimeParseResultSpan> expected{
{{expected_start_index, expected_end_index},
- {expected_ms_utc, expected_granularity},
+ {},
/*target_classification_score=*/1.0,
/*priority_score=*/0.1}};
+ expected[0].data.resize(expected_ms_utcs.size());
+ for (int i = 0; i < expected_ms_utcs.size(); i++) {
+ expected[0].data[i] = {expected_ms_utcs[i], expected_granularity};
+ }
+
const bool matches =
testing::Matches(ElementsAreArray(expected))(filtered_results);
if (!matches) {
- TC3_LOG(ERROR) << "Expected: " << expected[0] << " which corresponds to: "
- << FormatMillis(expected[0].data.time_ms_utc);
- for (int i = 0; i < filtered_results.size(); ++i) {
- TC3_LOG(ERROR) << "Actual[" << i << "]: " << filtered_results[i]
- << " which corresponds to: "
- << FormatMillis(filtered_results[i].data.time_ms_utc);
+ TC3_LOG(ERROR) << "Expected: " << expected[0];
+ if (filtered_results.empty()) {
+ TC3_LOG(ERROR) << "But got no results.";
}
+      for (const auto& actual : filtered_results) { TC3_LOG(ERROR) << "Actual: " << actual; }  // Log every result; indexing [0] here was UB when the vector is empty.
}
+
return matches;
}
+  // Single-timestamp convenience overload: wraps the expected value in a
+  // one-element vector and forwards to the vector-based ParsesCorrectly.
+  bool ParsesCorrectly(const std::string& marked_text, const int64 expected_ms_utc,
+                       DatetimeGranularity expected_granularity,
+                       bool anchor_start_end = false,
+                       const std::string& timezone = "Europe/Zurich",
+                       const std::string& locales = "en-US") {
+    const std::vector<int64> expected_ms_utcs(1, expected_ms_utc);
+    return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity, anchor_start_end, timezone, locales);
+  }
+
+  // German-locale variant ("de", Europe/Zurich, unanchored) for multiple expected timestamps.
+  bool ParsesCorrectlyGerman(const std::string& marked_text, const std::vector<int64>& expected_ms_utcs,
+                             DatetimeGranularity expected_granularity) {
+    const bool anchor_start_end = false;
+    const std::string timezone = "Europe/Zurich", locales = "de";
+    return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity, anchor_start_end, timezone, locales);
+  }
+
bool ParsesCorrectlyGerman(const std::string& marked_text,
const int64 expected_ms_utc,
DatetimeGranularity expected_granularity) {
@@ -173,24 +197,32 @@
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{Jun 09 2011 15:28:14}", 1307626094000,
GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{Mar 16 08:12:04}", {6419524000, 6462724000},
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29}",
+ {1277512289000, 1277555489000},
+ GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{2006/01/22 04:11:05}",
+ {1137899465000, 1137942665000},
+ GRANULARITY_SECOND));
EXPECT_TRUE(
- ParsesCorrectly("{Mar 16 08:12:04}", 6419524000, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{2010-06-26 02:31:29}", 1277512289000,
+ ParsesCorrectly("{11:42:35}", {38555000, 81755000}, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{23/Apr 11:42:35}", {9715355000, 9758555000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{2006/01/22 04:11:05}", 1137899465000,
+ EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{11:42:35}", 38555000, GRANULARITY_SECOND));
- EXPECT_TRUE(
- ParsesCorrectly("{23/Apr 11:42:35}", 9715355000, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23/Apr/2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23-Apr-2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{23 Apr 2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectly("{04/23/15 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{04/23/15 11:42:35}", 1429782155000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectly("{04/23/2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectly("{9/28/2011 2:23:15 PM}", 1317212595000,
GRANULARITY_SECOND));
@@ -205,20 +237,22 @@
"think order event music. Incommode so intention defective at "
"convinced. Led income months itself and houses you. After nor "
"you leave might share court balls. ",
- 1271651775000, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}", 1514777400000,
+ {1271651775000, 1271694975000}, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}",
+ {1514777400000, 1514820600000},
GRANULARITY_MINUTE));
EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30 am}", 1514777400000,
GRANULARITY_MINUTE));
EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4pm}", 1514818800000,
GRANULARITY_HOUR));
- EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", -3600000, GRANULARITY_MINUTE));
- EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", -57600000, GRANULARITY_MINUTE,
- /*anchor_start_end=*/false,
- "America/Los_Angeles"));
- EXPECT_TRUE(
- ParsesCorrectly("{tomorrow at 4:00}", 97200000, GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly("{today at 0:00}", {-3600000, 39600000},
+ GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectly(
+ "{today at 0:00}", {-57600000, -14400000}, GRANULARITY_MINUTE,
+ /*anchor_start_end=*/false, "America/Los_Angeles"));
+ EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4:00}", {97200000, 140400000},
+ GRANULARITY_MINUTE));
EXPECT_TRUE(ParsesCorrectly("{tomorrow at 4am}", 97200000, GRANULARITY_HOUR));
EXPECT_TRUE(
ParsesCorrectly("{wednesday at 4am}", 529200000, GRANULARITY_HOUR));
@@ -244,39 +278,51 @@
ParsesCorrectlyGerman("{1 2 2018}", 1517439600000, GRANULARITY_DAY));
EXPECT_TRUE(ParsesCorrectlyGerman("lorem {1 Januar 2018} ipsum",
1514761200000, GRANULARITY_DAY));
- EXPECT_TRUE(ParsesCorrectlyGerman("{19/Apr/2010:06:36:15}", 1271651775000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{19/Apr/2010:06:36:15}",
+ {1271651775000, 1271694975000},
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectlyGerman("{09/März/2004 22:02:40}", 1078866160000,
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{Dez 2, 2010 2:39:58}", 1291253998000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{Dez 2, 2010 2:39:58}",
+ {1291253998000, 1291297198000},
GRANULARITY_SECOND));
EXPECT_TRUE(ParsesCorrectlyGerman("{Juni 09 2011 15:28:14}", 1307626094000,
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{März 16 08:12:04}", 6419524000,
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{März 16 08:12:04}", {6419524000, 6462724000}, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{2010-06-26 02:31:29}",
+ {1277512289000, 1277555489000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{2010-06-26 02:31:29}", 1277512289000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{2006/01/22 04:11:05}",
+ {1137899465000, 1137942665000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{2006/01/22 04:11:05}", 1137899465000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{11:42:35}", {38555000, 81755000},
GRANULARITY_SECOND));
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{11:42:35}", 38555000, GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr 11:42:35}", 9715355000,
+ EXPECT_TRUE(ParsesCorrectlyGerman(
+ "{23/Apr 11:42:35}", {9715355000, 9758555000}, GRANULARITY_SECOND));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015:11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015:11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23/Apr/2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23-Apr-2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{23 Apr 2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/15 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/15 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}",
+ {1429782155000, 1429825355000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{04/23/2015 11:42:35}", 1429782155000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{19/apr/2010:06:36:15}",
+ {1271651775000, 1271694975000},
GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{19/apr/2010:06:36:15}", 1271651775000,
- GRANULARITY_SECOND));
- EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30}", 1514777400000,
+ EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30}",
+ {1514777400000, 1514820600000},
GRANULARITY_MINUTE));
EXPECT_TRUE(ParsesCorrectlyGerman("{januar 1 2018 um 4:30 nachm}",
1514820600000, GRANULARITY_MINUTE));
@@ -284,10 +330,10 @@
GRANULARITY_HOUR));
EXPECT_TRUE(
ParsesCorrectlyGerman("{14.03.2017}", 1489446000000, GRANULARITY_DAY));
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{morgen 0:00}", 82800000, GRANULARITY_MINUTE));
- EXPECT_TRUE(
- ParsesCorrectlyGerman("{morgen um 4:00}", 97200000, GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{morgen 0:00}", {82800000, 126000000},
+ GRANULARITY_MINUTE));
+ EXPECT_TRUE(ParsesCorrectlyGerman("{morgen um 4:00}", {97200000, 140400000},
+ GRANULARITY_MINUTE));
EXPECT_TRUE(
ParsesCorrectlyGerman("{morgen um 4 vorm}", 97200000, GRANULARITY_HOUR));
}
@@ -320,6 +366,27 @@
/*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
}
+TEST_F(ParserTest, WhenEnabled_GeneratesAlternatives) {  // Flag on: "4:30" is expected to yield two timestamps 12 h apart.
+ LoadModel([](ModelT* model) {
+ model->datetime_model->generate_alternative_interpretations_when_ambiguous =
+ true;
+ });
+
+ EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}",
+ {1514777400000, 1514820600000},
+ GRANULARITY_MINUTE));
+}
+
+TEST_F(ParserTest, WhenDisabled_DoesNotGenerateAlternatives) {  // Flag off: only the single earlier timestamp is expected.
+ LoadModel([](ModelT* model) {
+ model->datetime_model->generate_alternative_interpretations_when_ambiguous =
+ false;
+ });
+
+ EXPECT_TRUE(ParsesCorrectly("{january 1 2018 at 4:30}", 1514777400000,
+ GRANULARITY_MINUTE));
+}
+
class ParserLocaleTest : public testing::Test {
public:
void SetUp() override;
diff --git a/annotator/model.fbs b/annotator/model.fbs
index 3682994..fac4c22 100755
--- a/annotator/model.fbs
+++ b/annotator/model.fbs
@@ -173,6 +173,14 @@
verify_luhn_checksum:bool = false;
}
+// Behaviour of capturing groups.
+namespace libtextclassifier3.RegexModel_.Pattern_;
+table CapturingGroup {
+ // If true, the span of the capturing group will be used to
+ // extend the selection.
+ extend_selection:bool = true;
+}
+
// List of regular expression matchers to check.
namespace libtextclassifier3.RegexModel_;
table Pattern {
@@ -180,11 +188,10 @@
collection_name:string;
// The pattern to check.
- // Can specify a single capturing group used as match boundaries.
pattern:string;
// The modes for which to apply the patterns.
- enabled_modes:libtextclassifier3.ModeFlag = ALL;
+ enabled_modes:ModeFlag = ALL;
// The final score to assign to the results of this pattern.
target_classification_score:float = 1;
@@ -197,15 +204,17 @@
// use the first Find() result and then check that it spans the whole input.
use_approximate_matching:bool = false;
- compressed_pattern:libtextclassifier3.CompressedBuffer;
+ compressed_pattern:CompressedBuffer;
// Verification to apply on a match.
- verification_options:libtextclassifier3.VerificationOptions;
+ verification_options:VerificationOptions;
+
+ capturing_group:[Pattern_.CapturingGroup];
}
namespace libtextclassifier3;
table RegexModel {
- patterns:[libtextclassifier3.RegexModel_.Pattern];
+ patterns:[RegexModel_.Pattern];
}
// List of regex patterns.
@@ -215,14 +224,14 @@
// The ith entry specifies the type of the ith capturing group.
// This is used to decide how the matched content has to be parsed.
- groups:[libtextclassifier3.DatetimeGroupType];
+ groups:[DatetimeGroupType];
- compressed_pattern:libtextclassifier3.CompressedBuffer;
+ compressed_pattern:CompressedBuffer;
}
namespace libtextclassifier3;
table DatetimeModelPattern {
- regexes:[libtextclassifier3.DatetimeModelPattern_.Regex];
+ regexes:[DatetimeModelPattern_.Regex];
// List of locale indices in DatetimeModel that represent the locales that
// these patterns should be used for. If empty, can be used for all locales.
@@ -235,15 +244,15 @@
priority_score:float = 0;
// The modes for which to apply the patterns.
- enabled_modes:libtextclassifier3.ModeFlag = ALL;
+ enabled_modes:ModeFlag = ALL;
}
namespace libtextclassifier3;
table DatetimeModelExtractor {
- extractor:libtextclassifier3.DatetimeExtractorType;
+ extractor:DatetimeExtractorType;
pattern:string;
locales:[int];
- compressed_pattern:libtextclassifier3.CompressedBuffer;
+ compressed_pattern:CompressedBuffer;
}
namespace libtextclassifier3;
@@ -252,8 +261,8 @@
// model. The individual patterns refer back to them using an index.
locales:[string];
- patterns:[libtextclassifier3.DatetimeModelPattern];
- extractors:[libtextclassifier3.DatetimeModelExtractor];
+ patterns:[DatetimeModelPattern];
+ extractors:[DatetimeModelExtractor];
// If true, will use the extractors for determining the match location as
// opposed to using the location where the global pattern matched.
@@ -262,18 +271,22 @@
// List of locale ids, rules of whose are always run, after the requested
// ones.
default_locales:[int];
+
+ // If true, will generate the alternative interpretations for ambiguous
+ // datetime expressions.
+ generate_alternative_interpretations_when_ambiguous:bool = false;
}
namespace libtextclassifier3.DatetimeModelLibrary_;
table Item {
key:string;
- value:libtextclassifier3.DatetimeModel;
+ value:DatetimeModel;
}
// A set of named DateTime models.
namespace libtextclassifier3;
table DatetimeModelLibrary {
- models:[libtextclassifier3.DatetimeModelLibrary_.Item];
+ models:[DatetimeModelLibrary_.Item];
}
// Options controlling the output of the Tensorflow Lite models.
@@ -283,7 +296,7 @@
min_annotate_confidence:float = 0;
// The modes for which to enable the models.
- enabled_modes:libtextclassifier3.ModeFlag = ALL;
+ enabled_modes:ModeFlag = ALL;
}
// Options controlling the output of the classifier.
@@ -310,8 +323,8 @@
// A name for the model that can be used for e.g. logging.
name:string;
- selection_feature_options:libtextclassifier3.FeatureProcessorOptions;
- classification_feature_options:libtextclassifier3.FeatureProcessorOptions;
+ selection_feature_options:FeatureProcessorOptions;
+ classification_feature_options:FeatureProcessorOptions;
// Tensorflow Lite models.
selection_model:[ubyte] (force_align: 16);
@@ -320,18 +333,18 @@
embedding_model:[ubyte] (force_align: 16);
// Options for the different models.
- selection_options:libtextclassifier3.SelectionModelOptions;
+ selection_options:SelectionModelOptions;
- classification_options:libtextclassifier3.ClassificationModelOptions;
- regex_model:libtextclassifier3.RegexModel;
- datetime_model:libtextclassifier3.DatetimeModel;
+ classification_options:ClassificationModelOptions;
+ regex_model:RegexModel;
+ datetime_model:DatetimeModel;
// Options controlling the output of the models.
- triggering_options:libtextclassifier3.ModelTriggeringOptions;
+ triggering_options:ModelTriggeringOptions;
// Global switch that controls if SuggestSelection(), ClassifyText() and
// Annotate() will run. If a mode is disabled it returns empty/no-op results.
- enabled_modes:libtextclassifier3.ModeFlag = ALL;
+ enabled_modes:ModeFlag = ALL;
// If true, will snap the selections that consist only of whitespaces to the
// containing suggested span. Otherwise, no suggestion is proposed, since the
@@ -340,13 +353,13 @@
// Global configuration for the output of SuggestSelection(), ClassifyText()
// and Annotate().
- output_options:libtextclassifier3.OutputOptions;
+ output_options:OutputOptions;
// Configures how Intents should be generated on Android.
// TODO(smillius): Remove deprecated factory options.
- android_intent_options:libtextclassifier3.AndroidIntentFactoryOptions;
+ android_intent_options:AndroidIntentFactoryOptions;
- intent_options:libtextclassifier3.IntentFactoryModel;
+ intent_options:IntentFactoryModel;
}
// Role of the codepoints in the range.
@@ -379,7 +392,7 @@
table TokenizationCodepointRange {
start:int;
end:int;
- role:libtextclassifier3.TokenizationCodepointRange_.Role;
+ role:TokenizationCodepointRange_.Role;
// Integer identifier of the script this range denotes. Negative values are
// reserved for Tokenizer's internal use.
@@ -530,20 +543,20 @@
// Codepoint ranges that determine how different codepoints are tokenized.
// The ranges must not overlap.
- tokenization_codepoint_config:[libtextclassifier3.TokenizationCodepointRange];
+ tokenization_codepoint_config:[TokenizationCodepointRange];
- center_token_selection_method:libtextclassifier3.FeatureProcessorOptions_.CenterTokenSelectionMethod;
+ center_token_selection_method:FeatureProcessorOptions_.CenterTokenSelectionMethod;
// If true, span boundaries will be snapped to containing tokens and not
// required to exactly match token boundaries.
snap_label_span_boundaries_to_containing_tokens:bool;
// A set of codepoint ranges supported by the model.
- supported_codepoint_ranges:[libtextclassifier3.FeatureProcessorOptions_.CodepointRange];
+ supported_codepoint_ranges:[FeatureProcessorOptions_.CodepointRange];
// A set of codepoint ranges to use in the mixed tokenization mode to identify
// stretches of tokens to re-tokenize using the internal tokenizer.
- internal_tokenizer_codepoint_ranges:[libtextclassifier3.FeatureProcessorOptions_.CodepointRange];
+ internal_tokenizer_codepoint_ranges:[FeatureProcessorOptions_.CodepointRange];
// Minimum ratio of supported codepoints in the input context. If the ratio
// is lower than this, the feature computation will fail.
@@ -559,14 +572,14 @@
// to it. So the resulting feature vector has two regions.
feature_version:int = 0;
- tokenization_type:libtextclassifier3.FeatureProcessorOptions_.TokenizationType = INTERNAL_TOKENIZER;
+ tokenization_type:FeatureProcessorOptions_.TokenizationType = INTERNAL_TOKENIZER;
icu_preserve_whitespace_tokens:bool = false;
// List of codepoints that will be stripped from beginning and end of
// predicted spans.
ignored_span_boundary_codepoints:[int];
- bounds_sensitive_features:libtextclassifier3.FeatureProcessorOptions_.BoundsSensitiveFeatures;
+ bounds_sensitive_features:FeatureProcessorOptions_.BoundsSensitiveFeatures;
// List of allowed charactergrams. The extracted charactergrams are filtered
// using this list, and charactergrams that are not present are interpreted as
diff --git a/annotator/test_data/test_model.fb b/annotator/test_data/test_model.fb
index fa9cec5..8744142 100644
--- a/annotator/test_data/test_model.fb
+++ b/annotator/test_data/test_model.fb
Binary files differ
diff --git a/annotator/test_data/test_model_cc.fb b/annotator/test_data/test_model_cc.fb
index b73d84f..1d1e473 100644
--- a/annotator/test_data/test_model_cc.fb
+++ b/annotator/test_data/test_model_cc.fb
Binary files differ
diff --git a/annotator/test_data/wrong_embeddings.fb b/annotator/test_data/wrong_embeddings.fb
index ba71cdd..7f39119 100644
--- a/annotator/test_data/wrong_embeddings.fb
+++ b/annotator/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/annotator/types.cc b/annotator/types.cc
new file mode 100644
index 0000000..78d72df
--- /dev/null
+++ b/annotator/types.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/types.h"
+
+namespace libtextclassifier3 {
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+                                         const Token& token) {
+  // Padding tokens are printed without their value or offsets.
+  if (token.is_padding) {
+    return stream << "Token()";
+  }
+  return stream << "Token(\"" << token.value << "\", " << token.start
+                << ", " << token.end << ")";
+}
+
+namespace {
+std::string FormatMillis(int64 time_ms_utc) {  // Renders a UTC millisecond timestamp as local time for debug output.
+  const time_t time_seconds = static_cast<time_t>(time_ms_utc / 1000);  // time_t, not long: localtime() takes const time_t*.
+  const struct tm* local_time = localtime(&time_seconds);  // Static storage (not thread-safe); null if the value cannot be converted.
+  char buffer[512];
+  if (local_time == nullptr || strftime(buffer, sizeof(buffer), "%a %Y-%m-%d %H:%M:%S %Z", local_time) == 0) return "<invalid time>";  // On a 0 return the buffer contents are indeterminate.
+  return std::string(buffer);
+}
+} // namespace
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+                                         const DatetimeParseResultSpan& value) {
+  // Prints the span, then each interpretation with a local-time rendering.
+  stream << "DatetimeParseResultSpan({" << value.span.first << ", " << value.span.second << "}, {";
+  for (const DatetimeParseResult& interpretation : value.data) {
+    stream << "{/*time_ms_utc=*/ " << interpretation.time_ms_utc << " /* "
+           << FormatMillis(interpretation.time_ms_utc) << " */, /*granularity=*/ " << interpretation.granularity << "}, ";
+  }
+  // A trailing ", " follows the last interpretation before the closing braces.
+  stream << "})";
+  return stream;
+}
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+                                         const ClassificationResult& result) {
+  // Only the collection name and score are printed.
+  return stream << "ClassificationResult(" << result.collection << ", " << result.score << ")";
+}
+
+logging::LoggingStringStream& operator<<(
+    logging::LoggingStringStream& stream,
+    const std::vector<ClassificationResult>& results) {  // Prints one result per line inside braces.
+  stream << "{\n";  // Append in place like the other printers; `stream = stream << ...` was a redundant self-assignment.
+  for (const ClassificationResult& result : results) {
+    stream << " " << result << "\n";
+  }
+  stream << "}";
+  return stream;
+}
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+                                         const AnnotatedSpan& span) {
+  // Reports only the first classification; an unclassified span prints an
+  // empty class name and a score of -1.
+  const bool has_classification = !span.classification.empty();
+  const std::string best_class =
+      has_classification ? span.classification[0].collection : std::string();
+  const float best_score = has_classification ? span.classification[0].score : -1.0f;
+  return stream << "Span(" << span.span.first << ", " << span.span.second
+                << ", " << best_class << ", " << best_score << ")";
+}
+
+} // namespace libtextclassifier3
diff --git a/annotator/types.h b/annotator/types.h
index 38bce41..71acaf4 100644
--- a/annotator/types.h
+++ b/annotator/types.h
@@ -17,6 +17,7 @@
#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
+#include <time.h>
#include <algorithm>
#include <cmath>
#include <functional>
@@ -147,15 +148,8 @@
};
// Pretty-printing function for Token.
-inline logging::LoggingStringStream& operator<<(
- logging::LoggingStringStream& stream, const Token& token) {
- if (!token.is_padding) {
- return stream << "Token(\"" << token.value << "\", " << token.start << ", "
- << token.end << ")";
- } else {
- return stream << "Token()";
- }
-}
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Token& token);
enum DatetimeGranularity {
GRANULARITY_UNKNOWN = -1, // GRANULARITY_UNKNOWN is used as a proxy for this
@@ -195,13 +189,12 @@
struct DatetimeParseResultSpan {
CodepointSpan span;
- DatetimeParseResult data;
+ std::vector<DatetimeParseResult> data;
float target_classification_score;
float priority_score;
bool operator==(const DatetimeParseResultSpan& other) const {
- return span == other.span && data.granularity == other.data.granularity &&
- data.time_ms_utc == other.data.time_ms_utc &&
+ return span == other.span && data == other.data &&
std::abs(target_classification_score -
other.target_classification_score) < kFloatCompareEpsilon &&
std::abs(priority_score - other.priority_score) <
@@ -210,20 +203,16 @@
};
// Pretty-printing function for DatetimeParseResultSpan.
-inline logging::LoggingStringStream& operator<<(
- logging::LoggingStringStream& stream,
- const DatetimeParseResultSpan& value) {
- return stream << "DatetimeParseResultSpan({" << value.span.first << ", "
- << value.span.second << "}, {/*time_ms_utc=*/ "
- << value.data.time_ms_utc << ", /*granularity=*/ "
- << value.data.granularity << "})";
-}
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const DatetimeParseResultSpan& value);
struct ClassificationResult {
std::string collection;
float score;
DatetimeParseResult datetime_parse_result;
std::string serialized_knowledge_result;
+ std::string contact_name, contact_given_name, contact_nickname,
+ contact_email_address, contact_phone_number;
// Internal score used for conflict resolution.
float priority_score;
@@ -246,23 +235,13 @@
};
// Pretty-printing function for ClassificationResult.
-inline logging::LoggingStringStream& operator<<(
- logging::LoggingStringStream& stream, const ClassificationResult& result) {
- return stream << "ClassificationResult(" << result.collection << ", "
- << result.score << ")";
-}
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const ClassificationResult& result);
// Pretty-printing function for std::vector<ClassificationResult>.
-inline logging::LoggingStringStream& operator<<(
+logging::LoggingStringStream& operator<<(
logging::LoggingStringStream& stream,
- const std::vector<ClassificationResult>& results) {
- stream = stream << "{\n";
- for (const ClassificationResult& result : results) {
- stream = stream << " " << result << "\n";
- }
- stream = stream << "}";
- return stream;
-}
+ const std::vector<ClassificationResult>& results);
// Represents a result of Annotate call.
struct AnnotatedSpan {
@@ -274,17 +253,8 @@
};
// Pretty-printing function for AnnotatedSpan.
-inline logging::LoggingStringStream& operator<<(
- logging::LoggingStringStream& stream, const AnnotatedSpan& span) {
- std::string best_class;
- float best_score = -1;
- if (!span.classification.empty()) {
- best_class = span.classification[0].collection;
- best_score = span.classification[0].score;
- }
- return stream << "Span(" << span.span.first << ", " << span.span.second
- << ", " << best_class << ", " << best_score << ")";
-}
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const AnnotatedSpan& span);
// StringPiece analogue for std::vector<T>.
template <class T>
@@ -312,7 +282,8 @@
};
struct DateParseData {
- enum Relation {
+ enum class Relation {
+ UNSPECIFIED = 0,
NEXT = 1,
NEXT_OR_SAME = 2,
LAST = 3,
@@ -323,7 +294,8 @@
FUTURE = 8
};
- enum RelationType {
+ enum class RelationType {
+ UNSPECIFIED = 0,
SUNDAY = 1,
MONDAY = 2,
TUESDAY = 3,
@@ -352,9 +324,9 @@
RELATION_DISTANCE_FIELD = 1 << 11
};
- enum AMPM { AM = 0, PM = 1 };
+ enum class AMPM { AM = 0, PM = 1 };
- enum TimeUnit {
+ enum class TimeUnit {
DAYS = 1,
WEEKS = 2,
MONTHS = 3,
@@ -365,36 +337,57 @@
};
// Bit mask of fields which have been set on the struct
- int field_set_mask;
+ int field_set_mask = 0;
// Fields describing absolute date fields.
// Year of the date seen in the text match.
- int year;
+ int year = 0;
// Month of the year starting with January = 1.
- int month;
+ int month = 0;
// Day of the month starting with 1.
- int day_of_month;
+ int day_of_month = 0;
// Hour of the day with a range of 0-23,
// values less than 12 need the AMPM field below or heuristics
// to definitively determine the time.
- int hour;
+ int hour = 0;
// Hour of the day with a range of 0-59.
- int minute;
+ int minute = 0;
// Hour of the day with a range of 0-59.
- int second;
+ int second = 0;
// 0 == AM, 1 == PM
- int ampm;
+ AMPM ampm = AMPM::AM;
// Number of hours offset from UTC this date time is in.
- int zone_offset;
+ int zone_offset = 0;
// Number of hours offest for DST
- int dst_offset;
+ int dst_offset = 0;
// The permutation from now that was made to find the date time.
- Relation relation;
+ Relation relation = Relation::UNSPECIFIED;
// The unit of measure of the change to the date time.
- RelationType relation_type;
+ RelationType relation_type = RelationType::UNSPECIFIED;
// The number of units of change that were made.
- int relation_distance;
+ int relation_distance = 0;
+
+ DateParseData() = default;
+
+  // Member-wise constructor: every field is set from the same-named argument.
+  DateParseData(int field_set_mask, int year, int month, int day_of_month,
+                int hour, int minute, int second, AMPM ampm, int zone_offset,
+                int dst_offset, Relation relation, RelationType relation_type,
+                int relation_distance)
+      : field_set_mask(field_set_mask),
+        year(year),
+        month(month),
+        day_of_month(day_of_month),
+        hour(hour),
+        minute(minute),
+        second(second),
+        ampm(ampm),
+        zone_offset(zone_offset),
+        dst_offset(dst_offset),
+        relation(relation),
+        relation_type(relation_type),
+        relation_distance(relation_distance) {}
};
} // namespace libtextclassifier3
diff --git a/generate_flatbuffers.mk b/generate_flatbuffers.mk
index 3d36e41..8f59bca 100644
--- a/generate_flatbuffers.mk
+++ b/generate_flatbuffers.mk
@@ -18,6 +18,15 @@
intermediates := $(call local-generated-sources-dir)
+# Generate utils/named-extra_generated.h using FlatBuffer schema compiler.
+NAMED_EXTRA_FBS := $(LOCAL_PATH)/utils/named-extra.fbs
+NAMED_EXTRA_H := $(intermediates)/utils/named-extra_generated.h
+$(NAMED_EXTRA_H): PRIVATE_INPUT_FBS := $(NAMED_EXTRA_FBS)
+$(NAMED_EXTRA_H): INPUT_DIR := $(LOCAL_PATH)
+$(NAMED_EXTRA_H): $(FLATC) $(NAMED_EXTRA_FBS)
+ $(transform-fbs-to-cpp)
+LOCAL_GENERATED_SOURCES += $(NAMED_EXTRA_H)
+
# Generate utils/zlib/buffer_generated.h using FlatBuffer schema compiler.
UTILS_ZLIB_BUFFER_FBS := $(LOCAL_PATH)/utils/zlib/buffer.fbs
UTILS_ZLIB_BUFFER_H := $(intermediates)/utils/zlib/buffer_generated.h
diff --git a/java/com/google/android/textclassifier/AnnotatorModel.java b/java/com/google/android/textclassifier/AnnotatorModel.java
index 08a4455..dac9176 100644
--- a/java/com/google/android/textclassifier/AnnotatorModel.java
+++ b/java/com/google/android/textclassifier/AnnotatorModel.java
@@ -73,6 +73,13 @@
}
}
+  /** Initializes the contact engine with the given serialized config; throws IllegalArgumentException if native initialization fails. */
+  public void initializeContactEngine(byte[] serializedConfig) {
+    if (!nativeInitializeContactEngine(annotatorPtr, serializedConfig)) {
+      throw new IllegalArgumentException("Couldn't initialize the contact engine");
+    }
+  }
+
/**
* Given a string context and current selection, computes the selection suggestion.
*
@@ -98,7 +105,20 @@
*/
public ClassificationResult[] classifyText(
String context, int selectionBegin, int selectionEnd, ClassificationOptions options) {
- return nativeClassifyText(annotatorPtr, context, selectionBegin, selectionEnd, options);
+ return classifyText(context, selectionBegin, selectionEnd, options, /*appContext=*/ null);
+ }
+
+ /**
+ * Classifies the selected span of {@code context}, optionally with an Android app context.
+ *
+ * @param context the text containing the selection
+ * @param selectionBegin start of the selection within {@code context}
+ * @param selectionEnd end of the selection within {@code context}
+ * @param options classification options forwarded to the native annotator
+ * @param appContext an {@code android.content.Context}, passed through as {@code Object} because
+ * this class cannot depend on Android directly; the overload without this parameter passes
+ * {@code null}
+ * @return the results produced by the native annotator
+ */
+ public ClassificationResult[] classifyText(
+ String context,
+ int selectionBegin,
+ int selectionEnd,
+ ClassificationOptions options,
+ Object appContext) {
+ return nativeClassifyText(
+ annotatorPtr, context, selectionBegin, selectionEnd, options, appContext);
}
/**
@@ -176,16 +196,34 @@
private final float score;
private final DatetimeResult datetimeResult;
private final byte[] serializedKnowledgeResult;
+ // Contact details attached to this result. NOTE(review): presumably null when the entity is not
+ // a contact — confirm against the native side.
+ private final String contactName;
+ private final String contactGivenName;
+ private final String contactNickname;
+ private final String contactEmailAddress;
+ private final String contactPhoneNumber;
+ // Templates for the actions offered for this result; stored as passed (no defensive copy).
+ private final RemoteActionTemplate[] remoteActionTemplates;
+ /** Creates a result; every argument is stored as-is, without copying or validation. */
public ClassificationResult(
String collection,
float score,
DatetimeResult datetimeResult,
- byte[] serializedKnowledgeResult) {
+ byte[] serializedKnowledgeResult,
+ String contactName,
+ String contactGivenName,
+ String contactNickname,
+ String contactEmailAddress,
+ String contactPhoneNumber,
+ RemoteActionTemplate[] remoteActionTemplates) {
this.collection = collection;
this.score = score;
this.datetimeResult = datetimeResult;
this.serializedKnowledgeResult = serializedKnowledgeResult;
+ this.contactName = contactName;
+ this.contactGivenName = contactGivenName;
+ this.contactNickname = contactNickname;
+ this.contactEmailAddress = contactEmailAddress;
+ this.contactPhoneNumber = contactPhoneNumber;
+ this.remoteActionTemplates = remoteActionTemplates;
}
/** Returns the classified entity type. */
@@ -215,6 +253,30 @@
byte[] getSerializedKnowledgeResult() {
return serializedKnowledgeResult;
}
+
+ /** Returns the contact name passed to the constructor. */
+ String getContactName() {
+ return contactName;
+ }
+
+ /** Returns the contact given name passed to the constructor. */
+ String getContactGivenName() {
+ return contactGivenName;
+ }
+
+ /** Returns the contact nickname passed to the constructor. */
+ String getContactNickname() {
+ return contactNickname;
+ }
+
+ /** Returns the contact email address passed to the constructor. */
+ String getContactEmailAddress() {
+ return contactEmailAddress;
+ }
+
+ /** Returns the contact phone number passed to the constructor. */
+ String getContactPhoneNumber() {
+ return contactPhoneNumber;
+ }
+
+ /**
+ * Returns the remote action templates passed to the constructor. The internal array itself is
+ * returned (not a copy), so callers should not modify it.
+ */
+ public RemoteActionTemplate[] getRemoteActionTemplates() {
+ return remoteActionTemplates;
+ }
}
/** Represents a result of Annotate call. */
@@ -325,6 +387,8 @@
private native boolean nativeInitializeKnowledgeEngine(long context, byte[] serializedConfig);
+ // Returns false when the native contact engine rejects the serialized config.
+ private native boolean nativeInitializeContactEngine(long context, byte[] serializedConfig);
+
private native int[] nativeSuggestSelection(
long context, String text, int selectionBegin, int selectionEnd, SelectionOptions options);
@@ -333,7 +397,8 @@
String text,
int selectionBegin,
int selectionEnd,
- ClassificationOptions options);
+ ClassificationOptions options,
+ Object appContext); // android.content.Context passed as Object; may be null.
private native AnnotatedSpan[] nativeAnnotate(
long context, String text, AnnotationOptions options);
diff --git a/java/com/google/android/textclassifier/NamedVariant.java b/java/com/google/android/textclassifier/NamedVariant.java
new file mode 100644
index 0000000..d04bb11
--- /dev/null
+++ b/java/com/google/android/textclassifier/NamedVariant.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier;
+
+/**
+ * Represents a union of different basic types.
+ *
+ * @hide
+ */
+public final class NamedVariant {
+ public static final int TYPE_EMPTY = 0;
+ public static final int TYPE_INT = 1;
+ public static final int TYPE_LONG = 2;
+ public static final int TYPE_FLOAT = 3;
+ public static final int TYPE_DOUBLE = 4;
+ public static final int TYPE_BOOL = 5;
+ public static final int TYPE_STRING = 6;
+
+ private final String name;
+ // One of the TYPE_* constants; selects which value slot below is meaningful.
+ private final int type;
+ // Only the slot matching `type` is set by a public constructor; the others hold Java defaults.
+ private final int intValue;
+ private final long longValue;
+ private final float floatValue;
+ private final double doubleValue;
+ private final boolean boolValue;
+ private final String stringValue;
+
+ /** Creates a variant holding an {@code int}. */
+ public NamedVariant(String name, int value) {
+ this(name, TYPE_INT, value, 0L, 0.0f, 0.0, false, null);
+ }
+
+ /** Creates a variant holding a {@code long}. */
+ public NamedVariant(String name, long value) {
+ this(name, TYPE_LONG, 0, value, 0.0f, 0.0, false, null);
+ }
+
+ /** Creates a variant holding a {@code float}. */
+ public NamedVariant(String name, float value) {
+ this(name, TYPE_FLOAT, 0, 0L, value, 0.0, false, null);
+ }
+
+ /** Creates a variant holding a {@code double}. */
+ public NamedVariant(String name, double value) {
+ this(name, TYPE_DOUBLE, 0, 0L, 0.0f, value, false, null);
+ }
+
+ /** Creates a variant holding a {@code boolean}. */
+ public NamedVariant(String name, boolean value) {
+ this(name, TYPE_BOOL, 0, 0L, 0.0f, 0.0, value, null);
+ }
+
+ /** Creates a variant holding a {@code String}. */
+ public NamedVariant(String name, String value) {
+ this(name, TYPE_STRING, 0, 0L, 0.0f, 0.0, false, value);
+ }
+
+ // Canonical constructor: every slot is assigned explicitly so all fields can be final.
+ private NamedVariant(
+ String name,
+ int type,
+ int intValue,
+ long longValue,
+ float floatValue,
+ double doubleValue,
+ boolean boolValue,
+ String stringValue) {
+ this.name = name;
+ this.type = type;
+ this.intValue = intValue;
+ this.longValue = longValue;
+ this.floatValue = floatValue;
+ this.doubleValue = doubleValue;
+ this.boolValue = boolValue;
+ this.stringValue = stringValue;
+ }
+
+ /** Returns the variant's name. */
+ public String getName() {
+ return name;
+ }
+
+ /** Returns the TYPE_* constant describing which value this variant holds. */
+ public int getType() {
+ return type;
+ }
+
+ /** Returns the int value; asserts (if assertions are enabled) that the type is TYPE_INT. */
+ public int getInt() {
+ assert type == TYPE_INT;
+ return intValue;
+ }
+
+ /** Returns the long value; asserts (if assertions are enabled) that the type is TYPE_LONG. */
+ public long getLong() {
+ assert type == TYPE_LONG;
+ return longValue;
+ }
+
+ /** Returns the float value; asserts (if assertions are enabled) that the type is TYPE_FLOAT. */
+ public float getFloat() {
+ assert type == TYPE_FLOAT;
+ return floatValue;
+ }
+
+ /** Returns the double value; asserts (if assertions are enabled) that the type is TYPE_DOUBLE. */
+ public double getDouble() {
+ assert type == TYPE_DOUBLE;
+ return doubleValue;
+ }
+
+ /** Returns the boolean value; asserts (if assertions are enabled) that the type is TYPE_BOOL. */
+ public boolean getBool() {
+ assert type == TYPE_BOOL;
+ return boolValue;
+ }
+
+ /** Returns the String value; asserts (if assertions are enabled) that the type is TYPE_STRING. */
+ public String getString() {
+ assert type == TYPE_STRING;
+ return stringValue;
+ }
+}
diff --git a/java/com/google/android/textclassifier/RemoteActionTemplate.java b/java/com/google/android/textclassifier/RemoteActionTemplate.java
new file mode 100644
index 0000000..11ad33a
--- /dev/null
+++ b/java/com/google/android/textclassifier/RemoteActionTemplate.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier;
+
+/**
+ * Represents a template for an Android RemoteAction.
+ *
+ * @hide
+ */
+public final class RemoteActionTemplate {
+ /** Title shown for the action (see: RemoteAction.getTitle). */
+ public final String title;
+
+ /** Description shown for the action (see: RemoteAction.getContentDescription). */
+ public final String description;
+
+ /** The action to set on the Intent (see: Intent.setAction). */
+ public final String action;
+
+ /** The data to set on the Intent (see: Intent.setData). */
+ public final String data;
+
+ /** The type to set on the Intent (see: Intent.setType). */
+ public final String type;
+
+ /** Flags for launching the Intent (see: Intent.setFlags). */
+ public final Integer flags;
+
+ /** Categories to set on the Intent (see: Intent.addCategory). */
+ public final String[] category;
+
+ /** Explicit application package to set on the Intent (see: Intent.setPackage). */
+ public final String packageName;
+
+ /** The list of all the extras to add to the Intent. */
+ public final NamedVariant[] extras;
+
+ /** Private request code to use for the Intent. */
+ public final Integer requestCode;
+
+ /** Creates a template; every argument is stored as-is, without copying or validation. */
+ public RemoteActionTemplate(
+ String title,
+ String description,
+ String action,
+ String data,
+ String type,
+ Integer flags,
+ String[] category,
+ String packageName,
+ NamedVariant[] extras,
+ Integer requestCode) {
+ // Intent-related fields.
+ this.action = action;
+ this.data = data;
+ this.type = type;
+ this.flags = flags;
+ this.category = category;
+ this.packageName = packageName;
+ this.extras = extras;
+ this.requestCode = requestCode;
+ // User-visible strings.
+ this.title = title;
+ this.description = description;
+ }
+}
diff --git a/lang_id/common/lite_strings/str-cat.h b/lang_id/common/lite_strings/str-cat.h
index f24e6e6..f0c1682 100644
--- a/lang_id/common/lite_strings/str-cat.h
+++ b/lang_id/common/lite_strings/str-cat.h
@@ -92,12 +92,6 @@
dest->append(LiteStrCat(v4)); // NOLINT
}
-template <typename T1, typename T2, typename T3, typename T4, typename T5>
-inline void LiteStrAppend(string *dest, T1 v1, T2 v2, T3 v3, T4 v4, T5 v5) {
- LiteStrAppend(dest, v1, v2, v3, v4);
- dest->append(LiteStrCat(v5)); // NOLINT
-}
-
} // namespace mobile
} // namespace nlp_saft
diff --git a/lang_id/common/lite_strings/stringpiece.h b/lang_id/common/lite_strings/stringpiece.h
index d19ea41..59a2176 100644
--- a/lang_id/common/lite_strings/stringpiece.h
+++ b/lang_id/common/lite_strings/stringpiece.h
@@ -62,7 +62,6 @@
// Returns number of bytes of underlying data.
size_t size() const { return size_; }
- size_t length() const { return size_; }
// Returns true if this StringPiece does not refer to any characters.
bool empty() const { return size() == 0; }
diff --git a/models/actions_suggestions.model b/models/actions_suggestions.model
index ee60ce2..051a193 100644
--- a/models/actions_suggestions.model
+++ b/models/actions_suggestions.model
Binary files differ
diff --git a/models/textclassifier.en.model b/models/textclassifier.en.model
index 887d1df..602848e 100644
--- a/models/textclassifier.en.model
+++ b/models/textclassifier.en.model
Binary files differ
diff --git a/utils/calendar/CalendarJavaIcuLocalTest.java b/utils/calendar/CalendarJavaIcuLocalTest.java
new file mode 100644
index 0000000..3e2bc4b
--- /dev/null
+++ b/utils/calendar/CalendarJavaIcuLocalTest.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier.utils.calendar;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import com.google.thirdparty.robolectric.GoogleRobolectricTestRunner;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/**
+ * Runs the native calendar tests from Java; they need a valid JNIEnv, so they cannot run as a
+ * plain C++ binary.
+ */
+@RunWith(GoogleRobolectricTestRunner.class)
+public class CalendarJavaIcuLocalTest {
+
+ /** Entry point of the native test suite; the test below requires it to return true. */
+ private native boolean testsMain();
+
+ @Before
+ public void setUp() throws Exception {
+ // The native test code lives in this shared library.
+ System.loadLibrary("calendar-javaicu-test-lib");
+ }
+
+ @Test
+ public void testNative() {
+ final boolean nativeTestsPassed = testsMain();
+ assertThat(nativeTestsPassed).isTrue();
+ }
+}
diff --git a/utils/calendar/CalendarJavaIcuTest.java b/utils/calendar/CalendarJavaIcuTest.java
new file mode 100644
index 0000000..0295fbc
--- /dev/null
+++ b/utils/calendar/CalendarJavaIcuTest.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier.utils.calendar;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Runs the native calendar tests from Java; they need a valid JNIEnv, so they cannot run as a
+ * plain C++ binary.
+ */
+@RunWith(JUnit4.class)
+public class CalendarJavaIcuTest {
+
+ /** Entry point of the native test suite; the test below requires it to return true. */
+ private native boolean testsMain();
+
+ @Before
+ public void setUp() throws Exception {
+ // The native test code lives in this shared library.
+ System.loadLibrary("calendar-javaicu-test-lib");
+ }
+
+ @Test
+ public void testNative() {
+ final boolean nativeTestsPassed = testsMain();
+ assertThat(nativeTestsPassed).isTrue();
+ }
+}
diff --git a/utils/calendar/calendar-common.h b/utils/calendar/calendar-common.h
index 7e606de..1c2c008 100644
--- a/utils/calendar/calendar-common.h
+++ b/utils/calendar/calendar-common.h
@@ -131,6 +131,9 @@
constexpr int relation_distance_mask =
DateParseData::Fields::RELATION_DISTANCE_FIELD;
switch (parse_data.relation) {
+ case DateParseData::Relation::UNSPECIFIED:
+ TC3_LOG(ERROR) << "UNSPECIFIED Relation.";
+ return false; // An unspecified relation cannot be applied to the date.
case DateParseData::Relation::NEXT:
if (parse_data.field_set_mask & relation_type_mask) {
TC3_CALENDAR_CHECK(AdjustByRelation(parse_data.relation_type,
@@ -226,13 +229,13 @@
TCalendar* calendar) const {
const int distance_sign = distance < 0 ? -1 : 1;
switch (relation_type) {
- case DateParseData::MONDAY:
- case DateParseData::TUESDAY:
- case DateParseData::WEDNESDAY:
- case DateParseData::THURSDAY:
- case DateParseData::FRIDAY:
- case DateParseData::SATURDAY:
- case DateParseData::SUNDAY:
+ case DateParseData::RelationType::MONDAY:
+ case DateParseData::RelationType::TUESDAY:
+ case DateParseData::RelationType::WEDNESDAY:
+ case DateParseData::RelationType::THURSDAY:
+ case DateParseData::RelationType::FRIDAY:
+ case DateParseData::RelationType::SATURDAY:
+ case DateParseData::RelationType::SUNDAY: // Steps a day at a time until the weekday has been hit |distance| times.
if (!allow_today) {
// If we're not including the same day as the reference, skip it.
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance_sign))
@@ -241,25 +244,25 @@
while (distance != 0) {
int day_of_week;
TC3_CALENDAR_CHECK(calendar->GetDayOfWeek(&day_of_week))
- if (day_of_week == relation_type) {
+ if (day_of_week == static_cast<int>(relation_type)) { // Relies on weekday RelationType values matching the calendar's day-of-week numbering.
distance += -distance_sign;
if (distance == 0) break;
}
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance_sign))
}
return true;
- case DateParseData::DAY:
+ case DateParseData::RelationType::DAY:
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(distance));
return true;
- case DateParseData::WEEK:
+ case DateParseData::RelationType::WEEK:
TC3_CALENDAR_CHECK(calendar->AddDayOfMonth(7 * distance))
TC3_CALENDAR_CHECK(calendar->SetDayOfWeek(1))
return true;
- case DateParseData::MONTH:
+ case DateParseData::RelationType::MONTH:
TC3_CALENDAR_CHECK(calendar->AddMonth(distance))
TC3_CALENDAR_CHECK(calendar->SetDayOfMonth(1))
return true;
- case DateParseData::YEAR:
+ case DateParseData::RelationType::YEAR:
TC3_CALENDAR_CHECK(calendar->AddYear(distance))
TC3_CALENDAR_CHECK(calendar->SetDayOfYear(1))
return true;
diff --git a/utils/calendar/calendar-javaicu.cc b/utils/calendar/calendar-javaicu.cc
index 7b7f2fa..a5ae0ec 100644
--- a/utils/calendar/calendar-javaicu.cc
+++ b/utils/calendar/calendar-javaicu.cc
@@ -67,13 +67,20 @@
}
// We'll assume the day indices match later on, so verify it here.
- if (jni_cache_->calendar_sunday != DateParseData::SUNDAY ||
- jni_cache_->calendar_monday != DateParseData::MONDAY ||
- jni_cache_->calendar_tuesday != DateParseData::TUESDAY ||
- jni_cache_->calendar_wednesday != DateParseData::WEDNESDAY ||
- jni_cache_->calendar_thursday != DateParseData::THURSDAY ||
- jni_cache_->calendar_friday != DateParseData::FRIDAY ||
- jni_cache_->calendar_saturday != DateParseData::SATURDAY) {
+ // Each RelationType weekday is cast to int before comparing (RelationType
+ // appears to be a scoped enum, judging by the qualified names).
+ if (jni_cache_->calendar_sunday !=
+ static_cast<int>(DateParseData::RelationType::SUNDAY) ||
+ jni_cache_->calendar_monday !=
+ static_cast<int>(DateParseData::RelationType::MONDAY) ||
+ jni_cache_->calendar_tuesday !=
+ static_cast<int>(DateParseData::RelationType::TUESDAY) ||
+ jni_cache_->calendar_wednesday !=
+ static_cast<int>(DateParseData::RelationType::WEDNESDAY) ||
+ jni_cache_->calendar_thursday !=
+ static_cast<int>(DateParseData::RelationType::THURSDAY) ||
+ jni_cache_->calendar_friday !=
+ static_cast<int>(DateParseData::RelationType::FRIDAY) ||
+ jni_cache_->calendar_saturday !=
+ static_cast<int>(DateParseData::RelationType::SATURDAY)) {
TC3_LOG(ERROR) << "day of the week indices mismatch";
return false;
}
diff --git a/utils/calendar/calendar_test.cc b/utils/calendar/calendar_test.cc
index a8c3af8..98a320d 100644
--- a/utils/calendar/calendar_test.cc
+++ b/utils/calendar/calendar_test.cc
@@ -37,7 +37,8 @@
bool result = calendarlib_.InterpretParseData(
DateParseData{/*field_set_mask=*/0, /*year=*/0, /*month=*/0,
/*day_of_month=*/0, /*hour=*/0, /*minute=*/0, /*second=*/0,
- /*ampm=*/0, /*zone_offset=*/0, /*dst_offset=*/0,
+ /*ampm=*/static_cast<DateParseData::AMPM>(0),
+ /*zone_offset=*/0, /*dst_offset=*/0,
static_cast<DateParseData::Relation>(0),
static_cast<DateParseData::RelationType>(0),
/*relation_distance=*/0},
@@ -146,7 +147,7 @@
/*hour=*/0,
/*minute=*/0,
/*second=*/0,
- /*ampm=*/0,
+ /*ampm=*/static_cast<DateParseData::AMPM>(0),
/*zone_offset=*/0,
/*dst_offset=*/0,
DateParseData::Relation::FUTURE,
@@ -166,7 +167,7 @@
/*hour=*/0,
/*minute=*/0,
/*second=*/0,
- /*ampm=*/0,
+ /*ampm=*/static_cast<DateParseData::AMPM>(0),
/*zone_offset=*/0,
/*dst_offset=*/0,
DateParseData::Relation::NEXT,
@@ -186,7 +187,7 @@
/*hour=*/0,
/*minute=*/0,
/*second=*/0,
- /*ampm=*/0,
+ /*ampm=*/static_cast<DateParseData::AMPM>(0),
/*zone_offset=*/0,
/*dst_offset=*/0,
DateParseData::Relation::NEXT_OR_SAME,
@@ -206,7 +207,7 @@
/*hour=*/0,
/*minute=*/0,
/*second=*/0,
- /*ampm=*/0,
+ /*ampm=*/static_cast<DateParseData::AMPM>(0),
/*zone_offset=*/0,
/*dst_offset=*/0,
DateParseData::Relation::LAST,
@@ -226,7 +227,7 @@
/*hour=*/0,
/*minute=*/0,
/*second=*/0,
- /*ampm=*/0,
+ /*ampm=*/static_cast<DateParseData::AMPM>(0),
/*zone_offset=*/0,
/*dst_offset=*/0,
DateParseData::Relation::PAST,
diff --git a/utils/intents/IntentGeneratorTest.java b/utils/intents/IntentGeneratorTest.java
new file mode 100644
index 0000000..f43ecc0
--- /dev/null
+++ b/utils/intents/IntentGeneratorTest.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier.utils.intents;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.Context;
+import androidx.test.InstrumentationRegistry;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Runs the native intent generator tests from Java; they need a valid JNIEnv and an Android
+ * Context, so they cannot run as a plain C++ binary.
+ */
+@RunWith(JUnit4.class)
+public final class IntentGeneratorTest {
+
+ /** Entry point of the native test suite; the test below requires it to return true. */
+ private native boolean testsMain(Context context);
+
+ @Before
+ public void setUp() throws Exception {
+ // The native test code lives in this shared library.
+ System.loadLibrary("intent-generator-test-lib");
+ }
+
+ @Test
+ public void testNative() {
+ final Context instrumentationContext = InstrumentationRegistry.getContext();
+ assertThat(testsMain(instrumentationContext)).isTrue();
+ }
+}
diff --git a/utils/intents/intent-config.fbs b/utils/intents/intent-config.fbs
index 93a6fc9..65e2a6b 100755
--- a/utils/intents/intent-config.fbs
+++ b/utils/intents/intent-config.fbs
@@ -73,7 +73,7 @@
// implements the Intent generation logic.
namespace libtextclassifier3;
table AndroidIntentFactoryOptions {
- entity:[libtextclassifier3.AndroidIntentFactoryEntityOptions];
+ entity:[AndroidIntentFactoryEntityOptions]; // Resolved in the enclosing namespace libtextclassifier3.
}
// Describes how intents should be generated for a particular entity type.
@@ -85,17 +85,17 @@
// List of generators for all the different types of intents that should
// be made available for the entity type.
- generator:[libtextclassifier3.AndroidIntentGeneratorOptions];
+ generator:[AndroidIntentGeneratorOptions];
}
// Configures a single Android Intent generator.
namespace libtextclassifier3;
table AndroidIntentGeneratorOptions {
// Strings for UI elements.
- strings:[libtextclassifier3.AndroidIntentGeneratorStrings];
+ strings:[AndroidIntentGeneratorStrings];
// Generator specific configuration.
- simple:libtextclassifier3.AndroidSimpleIntentGeneratorOptions;
+ simple:AndroidSimpleIntentGeneratorOptions;
}
// Language dependent configuration for an Android Intent generator.
@@ -122,7 +122,7 @@
name:string;
// The type of the extra to set.
- type:libtextclassifier3.AndroidSimpleIntentGeneratorExtraType;
+ type:AndroidSimpleIntentGeneratorExtraType;
string_:string;
@@ -133,7 +133,7 @@
// A condition that needs to be fulfilled for an Intent to get generated.
namespace libtextclassifier3;
table AndroidSimpleIntentGeneratorCondition {
- type:libtextclassifier3.AndroidSimpleIntentGeneratorConditionType;
+ type:AndroidSimpleIntentGeneratorConditionType;
string_:string;
@@ -161,15 +161,15 @@
type:string;
// The list of all the extras to add to the Intent.
- extra:[libtextclassifier3.AndroidSimpleIntentGeneratorExtra];
+ extra:[AndroidSimpleIntentGeneratorExtra];
// The list of all the variables that become available for substitution in
// the action, data, type and extra strings. To e.g. set a field to the value
// of the first variable, use "%0$s".
- variable:[libtextclassifier3.AndroidSimpleIntentGeneratorVariableType];
+ variable:[AndroidSimpleIntentGeneratorVariableType];
// The list of all conditions that need to be fulfilled for Intent generation.
- condition:[libtextclassifier3.AndroidSimpleIntentGeneratorCondition];
+ condition:[AndroidSimpleIntentGeneratorCondition];
}
// Describes how intents should be generated for a particular entity type.
@@ -187,6 +187,6 @@
// Describes how intents for the various entity types should be generated.
namespace libtextclassifier3;
table IntentFactoryModel {
- entities:[libtextclassifier3.IntentFactoryModel_.IntentGenerator];
+ entities:[IntentFactoryModel_.IntentGenerator]; // Table in namespace libtextclassifier3.IntentFactoryModel_.
}
diff --git a/utils/intents/intent-generator.cc b/utils/intents/intent-generator.cc
new file mode 100644
index 0000000..12ded7d
--- /dev/null
+++ b/utils/intents/intent-generator.cc
@@ -0,0 +1,520 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/intents/intent-generator.h"
+
+#include <map>
+
+#include "utils/base/logging.h"
+#include "utils/java/string_utils.h"
+#include "utils/lua-utils.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/variant.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lua.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+namespace {
+
+// Lua field under `external.extras` that yields the classified entity text.
+static constexpr const char* kEntityTextKey = "text";
+// Lua field under `external.extras` that yields the parsed event time.
+// NOTE(review): the constant name says "Usec" but the key says "ms"; confirm
+// the unit of the pushed value and align the name.
+static constexpr const char* kTimeUsecKey = "parsed_time_ms_utc";
+// Lua field directly under `external` that yields the reference time (same
+// ms-vs-usec naming caveat as above).
+static constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
+
+// An Android specific Lua environment with JNI backed callbacks.
+// Lua snippets see a global `external` table (built in the constructor) whose
+// field lookups are answered by the Handle*() methods below.
+class JniLuaEnvironment : public LuaEnvironment {
+ public:
+ // `extra` is stored by reference (see extra_) and must outlive this object.
+ // NOTE(review): a null jni_cache is tolerated by the constructor, but the
+ // callback handlers dereference jni_cache_ unconditionally.
+ JniLuaEnvironment(const JniCache* jni_cache, const jobject context,
+ StringPiece entity_text, int64 event_time_ms_usec,
+ int64 reference_time_ms_utc,
+ const std::map<std::string, Variant>& extra);
+
+ // Runs an intent generator snippet.
+ std::vector<RemoteActionTemplate> RunIntentGenerator(
+ const std::string& generator_snippet);
+
+ protected:
+ // Dispatches a registered callback id (see CallbackId) to its handler.
+ int HandleCallback(int callback_id) override;
+
+ private:
+ // Callback handlers. Each pushes its result and returns LUA_YIELD, or logs an
+ // error and returns LUA_ERRRUN.
+ int HandleExternalCallback(); // external.<key>
+ int HandleExtrasLookup(); // external.extras.<key>
+ int HandleAndroidCallback(); // external.android.<key>
+ int HandleUserRestrictionsCallback(); // external.android.user_restrictions.<key>
+ int HandleUrlEncode(); // external.android.urlencode(s)
+ int HandleUrlSchema(); // external.android.url_schema(s)
+
+ // Reads and creates a RemoteAction result from Lua.
+ RemoteActionTemplate ReadRemoteActionTemplateResult();
+
+ // Reads the extras from the Lua result.
+ void ReadExtras(std::map<std::string, Variant>* extra);
+
+ // Reads the intent categories array from a Lua result.
+ void ReadCategories(std::vector<std::string>* category);
+
+ // Retrieves user manager if not previously done.
+ bool RetrieveUserManager();
+
+ // Builtins.
+ enum CallbackId {
+ CALLBACK_ID_EXTERNAL = 0,
+ CALLBACK_ID_EXTRAS = 1,
+ CALLBACK_ID_ANDROID = 2,
+ CALLBACK_ID_USER_PERMISSIONS = 3,
+ CALLBACK_ID_URL_ENCODE = 4,
+ CALLBACK_ID_URL_SCHEMA = 5,
+ };
+
+ // JNIEnv obtained from jni_cache at construction; null if jni_cache is null.
+ JNIEnv* jenv_;
+ const JniCache* jni_cache_;
+ // Android context; may be null (RetrieveUserManager() checks for that).
+ jobject context_;
+ // View of the entity text; presumably non-owning, so the text must outlive
+ // this object.
+ StringPiece entity_text_;
+ // NOTE(review): the name mixes "ms" and "usec"; confirm the unit.
+ int64 event_time_ms_usec_;
+ int64 reference_time_ms_utc_;
+ // Reference to the caller's map; not copied.
+ const std::map<std::string, Variant>& extra_;
+
+ // Global ref to the "user" system service, fetched lazily by
+ // RetrieveUserManager().
+ ScopedGlobalRef<jobject> usermanager_;
+ // Whether we already attempted to retrieve the UserManager.
+ bool usermanager_retrieved_;
+};
+
+// Loads the default Lua libraries and installs the global `external` table
+// tree. Lookups on each table are routed (presumably via HandleCallback()) to
+// the handler registered for its CallbackId.
+JniLuaEnvironment::JniLuaEnvironment(
+ const JniCache* jni_cache, const jobject context, StringPiece entity_text,
+ int64 event_time_ms_usec, int64 reference_time_ms_utc,
+ const std::map<std::string, Variant>& extra)
+ : jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
+ jni_cache_(jni_cache),
+ context_(context),
+ entity_text_(entity_text),
+ event_time_ms_usec_(event_time_ms_usec),
+ reference_time_ms_utc_(reference_time_ms_utc),
+ extra_(extra),
+ usermanager_(/*object=*/nullptr,
+ /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
+ usermanager_retrieved_(false) {
+ LoadDefaultLibraries();
+
+ // Setup callbacks.
+ // This exposes an `external` object with the following fields:
+ // * extras: the bundle with all information about a classification.
+ // * android: callbacks into specific android provided methods.
+ // * android.user_restrictions: callbacks to check user permissions.
+ // SetupTableLookupCallback appears to leave the new table on top of the
+ // stack; the lua_settable(state_, -3) calls below rely on that.
+ SetupTableLookupCallback("external", CALLBACK_ID_EXTERNAL);
+ // Stack: [external]
+
+ // extras
+ lua_pushstring(state_, "extras");
+ SetupTableLookupCallback("extras", CALLBACK_ID_EXTRAS);
+ // Stack: [external, "extras", extras] -> external.extras = extras.
+ lua_settable(state_, -3);
+
+ // android
+ lua_pushstring(state_, "android");
+ SetupTableLookupCallback("android", CALLBACK_ID_ANDROID);
+ // Stack: [external, "android", android]; left open to add a nested table.
+
+ // android.user_restrictions
+ lua_pushstring(state_, "user_restrictions");
+ SetupTableLookupCallback("user_restrictions", CALLBACK_ID_USER_PERMISSIONS);
+ // First settable: android.user_restrictions = user_restrictions.
+ lua_settable(state_, -3);
+ // Second settable: external.android = android. Stack: [external]
+ lua_settable(state_, -3);
+
+ // Pops `external` and publishes it as a Lua global.
+ lua_setglobal(state_, "external");
+}
+
+// Routes a registered callback id to its handler. Unknown ids are logged and
+// reported to Lua as LUA_ERRRUN.
+int JniLuaEnvironment::HandleCallback(int callback_id) {
+ if (callback_id == CALLBACK_ID_EXTERNAL) {
+ return HandleExternalCallback();
+ }
+ if (callback_id == CALLBACK_ID_EXTRAS) {
+ return HandleExtrasLookup();
+ }
+ if (callback_id == CALLBACK_ID_ANDROID) {
+ return HandleAndroidCallback();
+ }
+ if (callback_id == CALLBACK_ID_USER_PERMISSIONS) {
+ return HandleUserRestrictionsCallback();
+ }
+ if (callback_id == CALLBACK_ID_URL_ENCODE) {
+ return HandleUrlEncode();
+ }
+ if (callback_id == CALLBACK_ID_URL_SCHEMA) {
+ return HandleUrlSchema();
+ }
+ TC3_LOG(ERROR) << "Unhandled callback: " << callback_id;
+ return LUA_ERRRUN;
+}
+
+// Serves `external.<key>`; the only field defined there is the reference time.
+int JniLuaEnvironment::HandleExternalCallback() {
+ const char* requested = luaL_checkstring(state_, 2);
+ const bool is_reference_time = strcmp(kReferenceTimeUsecKey, requested) == 0;
+ if (!is_reference_time) {
+ TC3_LOG(ERROR) << "Undefined external access " << requested;
+ return LUA_ERRRUN;
+ }
+ lua_pushinteger(state_, reference_time_ms_utc_);
+ return LUA_YIELD;
+}
+
+// Serves `external.extras.<key>`. The entity text and event time are built in;
+// any other key is looked up in the caller-provided extras map.
+int JniLuaEnvironment::HandleExtrasLookup() {
+ const char* name = luaL_checkstring(state_, 2);
+ if (strcmp(kEntityTextKey, name) == 0) {
+ lua_pushlstring(state_, entity_text_.data(), entity_text_.length());
+ return LUA_YIELD;
+ }
+ if (strcmp(kTimeUsecKey, name) == 0) {
+ lua_pushinteger(state_, event_time_ms_usec_);
+ return LUA_YIELD;
+ }
+ const auto entry = extra_.find(std::string(name));
+ if (entry == extra_.end()) {
+ TC3_LOG(ERROR) << "Undefined extra lookup " << name;
+ return LUA_ERRRUN;
+ }
+ PushValue(entry->second);
+ return LUA_YIELD;
+}
+
+// Serves `external.android.<key>`: the app package name, or the urlencode /
+// url_schema helper functions.
+int JniLuaEnvironment::HandleAndroidCallback() {
+ const char* key = luaL_checkstring(state_, 2);
+ if (strcmp("package_name", key) == 0) {
+ // context_ may be null (RetrieveUserManager() already guards against it).
+ // Invoking a Java method on a null receiver is invalid JNI usage, so fail
+ // the lookup instead.
+ if (context_ == nullptr) {
+ TC3_LOG(ERROR) << "No context available to get the package name";
+ return LUA_ERRRUN;
+ }
+ ScopedLocalRef<jstring> package_name_str(
+ static_cast<jstring>(jenv_->CallObjectMethod(
+ context_, jni_cache_->context_get_package_name)));
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ TC3_LOG(ERROR) << "Error calling Context.getPackageName";
+ return LUA_ERRRUN;
+ }
+ // Do not read characters from a null jstring (HandleUrlSchema guards its
+ // result the same way).
+ if (package_name_str == nullptr) {
+ TC3_LOG(ERROR) << "Context.getPackageName returned null";
+ return LUA_ERRRUN;
+ }
+ ScopedStringChars package_name =
+ GetScopedStringChars(jenv_, package_name_str.get());
+ lua_pushstring(state_, package_name.get());
+ return LUA_YIELD;
+ } else if (strcmp("urlencode", key) == 0) {
+ PushCallback(CALLBACK_ID_URL_ENCODE);
+ return LUA_YIELD;
+ } else if (strcmp("url_schema", key) == 0) {
+ PushCallback(CALLBACK_ID_URL_SCHEMA);
+ return LUA_YIELD;
+ } else {
+ TC3_LOG(ERROR) << "Undefined android reference " << key;
+ return LUA_ERRRUN;
+ }
+}
+
+// Serves `external.android.user_restrictions.<name>`: pushes the boolean that
+// Bundle.getBoolean(name) returns on the UserManager's restrictions bundle, or
+// false when that API is unavailable or the lookup throws.
+int JniLuaEnvironment::HandleUserRestrictionsCallback() {
+ if (jni_cache_->usermanager_class == nullptr ||
+ jni_cache_->usermanager_get_user_restrictions == nullptr) {
+ // UserManager is only available for API level >= 17 and
+ // getUserRestrictions only for API level >= 18, so we just return false
+ // normally here.
+ lua_pushboolean(state_, false);
+ return LUA_YIELD;
+ }
+
+ // Get user manager if not previously retrieved.
+ if (!RetrieveUserManager()) {
+ TC3_LOG(ERROR) << "Error retrieving user manager.";
+ return LUA_ERRRUN;
+ }
+
+ // Validate the Lua argument before any JNI local ref is held. On a bad
+ // argument luaL_checkstring raises a Lua error; with Lua built as C (see the
+ // extern "C" includes) that unwinds via longjmp and would skip the
+ // ScopedLocalRef destructors below.
+ const char* restriction_name = luaL_checkstring(state_, 2);
+
+ ScopedLocalRef<jobject> bundle(jenv_->CallObjectMethod(
+ usermanager_.get(), jni_cache_->usermanager_get_user_restrictions));
+ if (jni_cache_->ExceptionCheckAndClear() || bundle == nullptr) {
+ TC3_LOG(ERROR) << "Error calling getUserRestrictions";
+ return LUA_ERRRUN;
+ }
+ ScopedLocalRef<jstring> key(jenv_->NewStringUTF(restriction_name));
+ if (key == nullptr) {
+ TC3_LOG(ERROR) << "Expected string, got null.";
+ return LUA_ERRRUN;
+ }
+ const bool permission = jenv_->CallBooleanMethod(
+ bundle.get(), jni_cache_->bundle_get_boolean, key.get());
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ TC3_LOG(ERROR) << "Error getting bundle value";
+ lua_pushboolean(state_, false);
+ } else {
+ lua_pushboolean(state_, permission);
+ }
+ return LUA_YIELD;
+}
+
+// Implements `external.android.urlencode(s)`: URL-encodes the first Lua
+// argument as UTF-8 with the Java URL encoder cached in jni_cache_.
+int JniLuaEnvironment::HandleUrlEncode() {
+ // Call Java URL encoder.
+ ScopedLocalRef<jstring> input_str(
+ jenv_->NewStringUTF(luaL_checkstring(state_, 1)));
+ if (input_str == nullptr) {
+ TC3_LOG(ERROR) << "Expected string, got null.";
+ return LUA_ERRRUN;
+ }
+ ScopedLocalRef<jstring> encoding_str(jenv_->NewStringUTF("UTF-8"));
+ if (encoding_str == nullptr) {
+ TC3_LOG(ERROR) << "Could not create encoding string";
+ return LUA_ERRRUN;
+ }
+ ScopedLocalRef<jstring> encoded_str(
+ static_cast<jstring>(jenv_->CallStaticObjectMethod(
+ jni_cache_->urlencoder_class.get(), jni_cache_->urlencoder_encode,
+ input_str.get(), encoding_str.get())));
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
+ return LUA_ERRRUN;
+ }
+ // Do not read characters from a null jstring (HandleUrlSchema guards its
+ // result the same way).
+ if (encoded_str == nullptr) {
+ TC3_LOG(ERROR) << "UrlEncoder.encode returned null";
+ return LUA_ERRRUN;
+ }
+ ScopedStringChars encoded = GetScopedStringChars(jenv_, encoded_str.get());
+ lua_pushstring(state_, encoded.get());
+ return LUA_YIELD;
+}
+
+// Implements `external.android.url_schema(s)`: parses the first Lua argument
+// with the cached Java Uri class and pushes its scheme, or nil if it has none.
+int JniLuaEnvironment::HandleUrlSchema() {
+ // Hand the Lua string over to Java.
+ ScopedLocalRef<jstring> url_str(
+ jenv_->NewStringUTF(luaL_checkstring(state_, 1)));
+ if (url_str == nullptr) {
+ TC3_LOG(ERROR) << "Expected string, got null";
+ return LUA_ERRRUN;
+ }
+
+ // Parse it; a thrown exception and a null Uri are both treated as failure.
+ ScopedLocalRef<jobject> parsed_uri(jenv_->CallStaticObjectMethod(
+ jni_cache_->uri_class.get(), jni_cache_->uri_parse, url_str.get()));
+ const bool parse_failed =
+ jni_cache_->ExceptionCheckAndClear() || parsed_uri == nullptr;
+ if (parse_failed) {
+ TC3_LOG(ERROR) << "Error calling Uri.parse";
+ return LUA_ERRRUN;
+ }
+
+ // The scheme may legitimately be absent; that maps to nil in Lua.
+ ScopedLocalRef<jstring> scheme_str(static_cast<jstring>(
+ jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_scheme)));
+ if (jni_cache_->ExceptionCheckAndClear()) {
+ TC3_LOG(ERROR) << "Error calling Uri.getScheme";
+ return LUA_ERRRUN;
+ }
+ if (scheme_str != nullptr) {
+ ScopedStringChars scheme = GetScopedStringChars(jenv_, scheme_str.get());
+ lua_pushstring(state_, scheme.get());
+ } else {
+ lua_pushnil(state_);
+ }
+ return LUA_YIELD;
+}
+
+// Lazily looks up the "user" system service (UserManager) through
+// context_.getSystemService() and caches a global reference in usermanager_.
+// The lookup is attempted at most once (when a context is present); later
+// calls only report whether it succeeded. Returns false when there is no
+// context or the lookup failed.
+bool JniLuaEnvironment::RetrieveUserManager() {
+  if (context_ == nullptr) {
+    return false;
+  }
+  if (usermanager_retrieved_) {
+    return (usermanager_ != nullptr);
+  }
+  usermanager_retrieved_ = true;
+  ScopedLocalRef<jstring> service(jenv_->NewStringUTF("user"), jenv_);
+  // Hold the returned local reference in a scoped wrapper bound to jenv_, so
+  // it is released once the global reference below exists instead of
+  // lingering in the local reference table until the native frame returns.
+  ScopedLocalRef<jobject> usermanager_ref(
+      jenv_->CallObjectMethod(context_, jni_cache_->context_get_system_service,
+                              service.get()),
+      jenv_);
+  if (jni_cache_->ExceptionCheckAndClear()) {
+    TC3_LOG(ERROR) << "Error calling getSystemService.";
+    return false;
+  }
+  usermanager_ = MakeGlobalRef(usermanager_ref.get(), jenv_, jni_cache_->jvm);
+  return (usermanager_ != nullptr);
+}
+
+// Converts the remote action template table on top of the Lua stack into a
+// RemoteActionTemplate and pops that table. Unknown or malformed entries are
+// logged and skipped.
+RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() {
+  RemoteActionTemplate result;
+  // Read intent template.
+  lua_pushnil(state_);
+  while (lua_next(state_, /*idx=*/-2)) {
+    // Only string keys name template fields. lua_tostring() would convert a
+    // numeric key in place (which confuses lua_next) and returns nullptr for
+    // other key types, so such entries are skipped.
+    if (lua_type(state_, /*idx=*/-2) != LUA_TSTRING) {
+      TC3_LOG(INFO) << "Skipping entry with non-string key of type: "
+                    << lua_type(state_, /*idx=*/-2);
+      lua_pop(state_, 1);
+      continue;
+    }
+    const char* key = lua_tostring(state_, /*idx=*/-2);
+
+    // String-valued fields share the null-safe read below.
+    Optional<std::string>* string_field = nullptr;
+    if (strcmp("title", key) == 0) {
+      string_field = &result.title;
+    } else if (strcmp("description", key) == 0) {
+      string_field = &result.description;
+    } else if (strcmp("action", key) == 0) {
+      string_field = &result.action;
+    } else if (strcmp("data", key) == 0) {
+      string_field = &result.data;
+    } else if (strcmp("type", key) == 0) {
+      string_field = &result.type;
+    } else if (strcmp("package_name", key) == 0) {
+      string_field = &result.package_name;
+    } else if (strcmp("flags", key) == 0) {
+      result.flags = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
+    } else if (strcmp("request_code", key) == 0) {
+      result.request_code = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
+    } else if (strcmp("category", key) == 0) {
+      ReadCategories(&result.category);
+    } else if (strcmp("extra", key) == 0) {
+      ReadExtras(&result.extra);
+    } else {
+      TC3_LOG(INFO) << "Unknown entry: " << key;
+    }
+
+    if (string_field != nullptr) {
+      // lua_tostring() returns nullptr unless the value is a string or a
+      // number; building a std::string from nullptr is undefined behavior.
+      const char* value = lua_tostring(state_, /*idx=*/-1);
+      if (value != nullptr) {
+        *string_field = value;
+      } else {
+        TC3_LOG(ERROR) << "Expected string value for entry: " << key;
+      }
+    }
+    // Pop the value; the key stays for the next lua_next() call.
+    lua_pop(state_, 1);
+  }
+  // Pop the template table itself.
+  lua_pop(state_, 1);
+  return result;
+}
+
+// Appends the string elements of the category array on top of the Lua stack
+// to `category`. The array is always left on the stack: the caller
+// (ReadRemoteActionTemplateResult) pops it after this returns.
+void JniLuaEnvironment::ReadCategories(std::vector<std::string>* category) {
+  // Read category array.
+  if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+    TC3_LOG(ERROR) << "Expected categories table, got: "
+                   << lua_type(state_, /*idx=*/-1);
+    // Do not pop here. The caller pops this value anyway, and a second pop
+    // would remove the caller's iteration key and break its lua_next() loop.
+    return;
+  }
+  lua_pushnil(state_);
+  while (lua_next(state_, /*idx=*/-2)) {
+    // lua_tostring() returns nullptr for values that are neither strings nor
+    // numbers; constructing a std::string from nullptr is undefined behavior.
+    const char* value = lua_tostring(state_, /*idx=*/-1);
+    if (value != nullptr) {
+      category->push_back(value);
+    } else {
+      TC3_LOG(ERROR) << "Expected category string, got: "
+                     << lua_type(state_, /*idx=*/-1);
+    }
+    lua_pop(state_, 1);
+  }
+}
+
+// Reads the array of extras on top of the Lua stack into `extra`. Each entry
+// is a table with a "name" field and one type-specific value field
+// ("int_value", "long_value", "float_value", "bool_value" or
+// "string_value"). The array itself is always left on the stack: the caller
+// (ReadRemoteActionTemplateResult) pops it after this returns.
+void JniLuaEnvironment::ReadExtras(std::map<std::string, Variant>* extra) {
+  if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+    TC3_LOG(ERROR) << "Expected extras table, got: "
+                   << lua_type(state_, /*idx=*/-1);
+    // Do not pop here. The caller pops this value anyway, and a second pop
+    // would remove the caller's iteration key and break its lua_next() loop.
+    return;
+  }
+  lua_pushnil(state_);
+  while (lua_next(state_, /*idx=*/-2)) {
+    // Each entry is a table specifying name and value.
+    // The value is specified via a type specific field as Lua doesn't allow
+    // to easily distinguish between different number types.
+    if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+      TC3_LOG(ERROR) << "Expected a table for an extra, got: "
+                     << lua_type(state_, /*idx=*/-1);
+      // Drop only this entry and keep iterating. Returning here would leave
+      // the iteration key on the stack and unbalance the caller.
+      lua_pop(state_, 1);
+      continue;
+    }
+    std::string name;
+    Variant value;
+
+    lua_pushnil(state_);
+    while (lua_next(state_, /*idx=*/-2)) {
+      // Field names must be strings: lua_tostring() on a numeric key would
+      // convert it in place (confusing lua_next) and returns nullptr for
+      // other key types.
+      if (lua_type(state_, /*idx=*/-2) != LUA_TSTRING) {
+        TC3_LOG(INFO) << "Skipping extra field with non-string key.";
+        lua_pop(state_, 1);
+        continue;
+      }
+      const char* key = lua_tostring(state_, /*idx=*/-2);
+      if (strcmp("name", key) == 0) {
+        // A non-string name leaves `name` empty, so the entry is skipped.
+        const char* name_value = lua_tostring(state_, /*idx=*/-1);
+        if (name_value != nullptr) {
+          name = name_value;
+        }
+      } else if (strcmp("int_value", key) == 0) {
+        value = Variant(static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
+      } else if (strcmp("long_value", key) == 0) {
+        value = Variant(static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
+      } else if (strcmp("float_value", key) == 0) {
+        value = Variant(static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
+      } else if (strcmp("bool_value", key) == 0) {
+        value = Variant(static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
+      } else if (strcmp("string_value", key) == 0) {
+        const char* string_value = lua_tostring(state_, /*idx=*/-1);
+        if (string_value != nullptr) {
+          value = Variant(string_value);
+        } else {
+          TC3_LOG(ERROR) << "Expected string for extra field: " << key;
+        }
+      } else {
+        TC3_LOG(INFO) << "Unknown extra field: " << key;
+      }
+      lua_pop(state_, 1);
+    }
+    if (!name.empty()) {
+      (*extra)[name] = value;
+    } else {
+      TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
+    }
+    lua_pop(state_, 1);
+  }
+}
+} // namespace
+
+// Runs `generator_snippet`, which must evaluate to an array of remote action
+// template tables, and converts each table into a RemoteActionTemplate.
+// Returns an empty vector if the snippet fails to load or run, does not
+// return exactly one table, or leaves the Lua stack unbalanced. The stack is
+// cleared on every error path.
+std::vector<RemoteActionTemplate> JniLuaEnvironment::RunIntentGenerator(
+    const std::string& generator_snippet) {
+  // Returns the error object on top of the stack as text and clears the
+  // stack.
+  auto pop_error_message = [this]() -> std::string {
+    const char* message = lua_tostring(state_, /*idx=*/-1);
+    std::string text = (message != nullptr) ? message : "<non-string error>";
+    lua_settop(state_, 0);
+    return text;
+  };
+
+  // Load with an explicit length: luaL_loadstring would stop at the first
+  // embedded NUL byte in the snippet.
+  int status = luaL_loadbuffer(state_, generator_snippet.data(),
+                               generator_snippet.size(),
+                               /*name=*/"intent-generator");
+  if (status != LUA_OK) {
+    const std::string message = pop_error_message();
+    TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status << " ("
+                   << message << ")";
+    return {};
+  }
+  status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
+  if (status != LUA_OK) {
+    const std::string message = pop_error_message();
+    TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status << " ("
+                   << message << ")";
+    return {};
+  }
+  // Read result.
+  if (lua_gettop(state_) != 1 || lua_type(state_, 1) != LUA_TTABLE) {
+    TC3_LOG(ERROR) << "Unexpected result for snippet.";
+    lua_settop(state_, 0);
+    return {};
+  }
+
+  // Read remote action templates array. ReadRemoteActionTemplateResult()
+  // pops the value pushed by lua_next(), leaving the key for the next round.
+  std::vector<RemoteActionTemplate> result;
+  lua_pushnil(state_);
+  while (lua_next(state_, /*idx=*/-2)) {
+    if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
+      TC3_LOG(ERROR) << "Expected intent table, got: "
+                     << lua_type(state_, /*idx=*/-1);
+      lua_pop(state_, 1);
+      continue;
+    }
+    result.push_back(ReadRemoteActionTemplateResult());
+  }
+  // Pop the array itself.
+  lua_pop(state_, /*n=*/1);
+
+  // Check that we correctly cleaned-up the state.
+  if (lua_gettop(state_) > 0) {
+    TC3_LOG(ERROR) << "Unexpected stack size.";
+    lua_settop(state_, 0);
+    return {};
+  }
+
+  return result;
+}
+
+// Builds the entity-type -> Lua generator snippet table from `options`.
+// The table stays empty (so GenerateIntents() finds no generator) when the
+// model has no entities, or when a non-null `context` is not an
+// android.content.Context. Entries with missing fields are skipped.
+IntentGenerator::IntentGenerator(const IntentFactoryModel* options,
+                                 const std::shared_ptr<JniCache>& jni_cache,
+                                 const jobject context)
+    : options_(options), jni_cache_(jni_cache), context_(context) {
+  if (options_ == nullptr || options_->entities() == nullptr) {
+    return;
+  }
+
+  // Normally this check would be performed by the Java compiler and we wouldn't
+  // need to worry about it here. But we can't depend on Android's SDK in Java,
+  // so we check the instance type here.
+  if (context != nullptr && !jni_cache->GetEnv()->IsInstanceOf(
+                                context, jni_cache->context_class.get())) {
+    TC3_LOG(ERROR) << "Provided context is not an android.content.Context";
+    return;
+  }
+
+  // `options_` and `options_->entities()` are known to be non-null here.
+  for (const IntentFactoryModel_::IntentGenerator* generator :
+       *options_->entities()) {
+    // Skip malformed entries instead of dereferencing missing fields.
+    if (generator->entity_type() == nullptr ||
+        generator->lua_template_generator() == nullptr) {
+      TC3_LOG(ERROR) << "Skipping intent generator with missing fields.";
+      continue;
+    }
+    generators_[generator->entity_type()->str()] =
+        std::string(reinterpret_cast<const char*>(
+                        generator->lua_template_generator()->data()),
+                    generator->lua_template_generator()->size());
+  }
+}
+
+// Runs the Lua generator registered for `classification.collection` in a
+// fresh JniLuaEnvironment and returns the remote action templates it produces.
+// Returns an empty vector when no model is configured or no generator exists
+// for the entity.
+std::vector<RemoteActionTemplate> IntentGenerator::GenerateIntents(
+    const ClassificationResult& classification, int64 reference_time_ms_utc,
+    StringPiece entity_text) const {
+  if (options_ == nullptr) {
+    return {};
+  }
+
+  // Retrieve generator for specified entity.
+  const auto it = generators_.find(classification.collection);
+  if (it == generators_.end()) {
+    TC3_LOG(INFO) << "Unknown entity: " << classification.collection;
+    return {};
+  }
+
+  // The interpreter only lives for this call, so it is a plain local object.
+  JniLuaEnvironment interpreter(
+      jni_cache_.get(), context_, entity_text,
+      classification.datetime_parse_result.time_ms_utc, reference_time_ms_utc,
+      classification.extra);
+  return interpreter.RunIntentGenerator(it->second);
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/intents/intent-generator.h b/utils/intents/intent-generator.h
new file mode 100644
index 0000000..e779e94
--- /dev/null
+++ b/utils/intents/intent-generator.h
@@ -0,0 +1,91 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// TODO(smillius): Move intent generation code outside of utils.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
+
+#include <jni.h>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "annotator/types.h"
+#include "utils/intents/intent-config_generated.h"
+#include "utils/java/jni-cache.h"
+#include "utils/java/scoped_local_ref.h"
+#include "utils/optional.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// A template with parameters for an Android remote action.
+// When converted to Java, unset Optional fields and empty collections are
+// passed as null.
+struct RemoteActionTemplate {
+  // Title shown for the action (see: RemoteAction.getTitle).
+  Optional<std::string> title;
+
+  // Description shown for the action (see: RemoteAction.getContentDescription).
+  Optional<std::string> description;
+
+  // The action to set on the Intent (see: Intent.setAction).
+  Optional<std::string> action;
+
+  // The data to set on the Intent (see: Intent.setData).
+  Optional<std::string> data;
+
+  // The type to set on the Intent (see: Intent.setType).
+  Optional<std::string> type;
+
+  // Flags for launching the Intent (see: Intent.setFlags).
+  Optional<int> flags;
+
+  // Categories to set on the Intent (see: Intent.addCategory).
+  std::vector<std::string> category;
+
+  // Explicit application package to set on the Intent (see: Intent.setPackage).
+  Optional<std::string> package_name;
+
+  // The list of all the extras to add to the Intent, keyed by extra name.
+  std::map<std::string, Variant> extra;
+
+  // Private request code to use for the Intent.
+  Optional<int> request_code;
+};
+
+// Helper class to generate Android intents for text classifier results.
+// It holds one Lua generator snippet per entity type from the model;
+// GenerateIntents() runs the snippet that matches a classification's
+// collection.
+// NOTE(review): `options` is kept as a raw pointer and `context` as a plain
+// jobject, so both presumably must outlive this object. A JNI local reference
+// would become invalid once the creating native call returns — confirm that
+// callers pass a global reference.
+class IntentGenerator {
+ public:
+  // `options` may be null, in which case no intents are generated. A non-null
+  // `context` must be an android.content.Context; otherwise no generators are
+  // registered.
+  explicit IntentGenerator(const IntentFactoryModel* options,
+                           const std::shared_ptr<JniCache>& jni_cache,
+                           const jobject context);
+
+  // Generate intents for a classification result.
+  // Returns an empty vector when there is no generator for
+  // `classification.collection`.
+  std::vector<RemoteActionTemplate> GenerateIntents(
+      const ClassificationResult& classification, int64 reference_time_ms_utc,
+      StringPiece entity_text) const;
+
+ private:
+  const IntentFactoryModel* options_;
+  std::shared_ptr<JniCache> jni_cache_;
+  const jobject context_;
+  // Entity type -> Lua snippet that generates its remote action templates.
+  std::map<std::string, std::string> generators_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
diff --git a/utils/intents/jni.cc b/utils/intents/jni.cc
new file mode 100644
index 0000000..4a8f07b
--- /dev/null
+++ b/utils/intents/jni.cc
@@ -0,0 +1,200 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/intents/jni.h"
+#include <memory>
+#include "utils/intents/intent-generator.h"
+#include "utils/java/scoped_local_ref.h"
+
+namespace libtextclassifier3 {
+
+// The macros below are intended to reduce the boilerplate and avoid
+// easily introduced copy/paste errors.
+#define TC3_CHECK_JNI_PTR(PTR) TC3_CHECK((PTR) != nullptr)
+#define TC3_GET_CLASS(FIELD, NAME) \
+ handler->FIELD.reset(env->FindClass(NAME)); \
+ TC3_CHECK_JNI_PTR(handler->FIELD) << "Error finding class: " << NAME;
+#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
+ handler->FIELD = env->GetMethodID(handler->CLASS.get(), NAME, SIGNATURE); \
+ TC3_CHECK(handler->FIELD) << "Error finding method: " << NAME;
+
+// Creates a handler and resolves, up front, every Java class and constructor
+// it uses. Returns nullptr if `env` is null. Any class or method that cannot
+// be found aborts through the TC3_CHECK inside TC3_GET_CLASS/TC3_GET_METHOD.
+std::unique_ptr<RemoteActionTemplatesHandler>
+RemoteActionTemplatesHandler::Create(JNIEnv* env) {
+  if (env == nullptr) {
+    return nullptr;
+  }
+  std::unique_ptr<RemoteActionTemplatesHandler> handler(
+      new RemoteActionTemplatesHandler(env));
+
+  TC3_GET_CLASS(string_class_, "java/lang/String");
+  TC3_GET_CLASS(integer_class_, "java/lang/Integer");
+  // Integer(int), used to box optional ints.
+  TC3_GET_METHOD(integer_class_, integer_init_, "<init>", "(I)V");
+
+  // RemoteActionTemplate(String title, String description, String action,
+  //     String data, String type, Integer flags, String[] category,
+  //     String packageName, NamedVariant[] extras, Integer requestCode).
+  TC3_GET_CLASS(remote_action_template_class_,
+                TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR);
+  TC3_GET_METHOD(
+      remote_action_template_class_, remote_action_template_init_, "<init>",
+      "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/"
+      "String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/Integer;[Ljava/"
+      "lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
+          TC3_NAMED_VARIANT_CLASS_NAME_STR ";Ljava/lang/Integer;)V");
+
+  TC3_GET_CLASS(named_variant_class_,
+                TC3_PACKAGE_PATH TC3_NAMED_VARIANT_CLASS_NAME_STR);
+
+  // One NamedVariant(String name, <value>) constructor per value type.
+  TC3_GET_METHOD(named_variant_class_, named_variant_from_int_, "<init>",
+                 "(Ljava/lang/String;I)V");
+  TC3_GET_METHOD(named_variant_class_, named_variant_from_long_, "<init>",
+                 "(Ljava/lang/String;J)V");
+  TC3_GET_METHOD(named_variant_class_, named_variant_from_float_, "<init>",
+                 "(Ljava/lang/String;F)V");
+  TC3_GET_METHOD(named_variant_class_, named_variant_from_double_, "<init>",
+                 "(Ljava/lang/String;D)V");
+  TC3_GET_METHOD(named_variant_class_, named_variant_from_bool_, "<init>",
+                 "(Ljava/lang/String;Z)V");
+  TC3_GET_METHOD(named_variant_class_, named_variant_from_string_, "<init>",
+                 "(Ljava/lang/String;Ljava/lang/String;)V");
+
+  return handler;
+}
+
+// Returns a new Java string (local reference) holding the optional's value,
+// or nullptr when the optional is empty.
+// NOTE(review): NewStringUTF expects *modified* UTF-8, so text containing
+// supplementary characters (4-byte UTF-8 sequences, e.g. emoji) is presumably
+// not converted correctly here — confirm.
+jstring RemoteActionTemplatesHandler::AsUTF8String(
+    const Optional<std::string>& optional) {
+  if (!optional.has_value()) {
+    return nullptr;
+  }
+  return env_->NewStringUTF(optional.value().c_str());
+}
+
+// Boxes the optional's value into a new java.lang.Integer (local reference),
+// or returns nullptr when the optional is empty.
+jobject RemoteActionTemplatesHandler::AsInteger(const Optional<int>& optional) {
+  if (!optional.has_value()) {
+    return nullptr;
+  }
+  return env_->NewObject(integer_class_.get(), integer_init_, optional.value());
+}
+
+// Converts `values` into a new Java String[] (local reference). Returns
+// nullptr for an empty vector or when a JNI allocation fails.
+jobjectArray RemoteActionTemplatesHandler::AsStringArray(
+    const std::vector<std::string>& values) {
+  if (values.empty()) {
+    return nullptr;
+  }
+  jobjectArray result = env_->NewObjectArray(
+      static_cast<jsize>(values.size()), string_class_.get(), nullptr);
+  if (result == nullptr) {
+    return nullptr;
+  }
+  for (size_t k = 0; k < values.size(); ++k) {
+    // The array holds its own reference to each element, so release the
+    // element's local reference right away; otherwise large arrays can
+    // exhaust the local reference table.
+    ScopedLocalRef<jstring> element(env_->NewStringUTF(values[k].c_str()),
+                                    env_);
+    if (element == nullptr) {
+      // NewStringUTF failed and has thrown; stop making JNI calls and do not
+      // return a partially filled array.
+      env_->DeleteLocalRef(result);
+      return nullptr;
+    }
+    env_->SetObjectArrayElement(result, static_cast<jsize>(k), element.get());
+  }
+  return result;
+}
+
+// Creates a new Java NamedVariant (local reference) for `name` and `value`,
+// using the constructor that matches the variant's type. Returns nullptr for
+// a variant without a supported value, or when a JNI allocation fails.
+jobject RemoteActionTemplatesHandler::AsNamedVariant(const std::string& name,
+                                                     const Variant& value) {
+  // Temporary strings are released on return; the NamedVariant object keeps
+  // its own references to them.
+  ScopedLocalRef<jstring> jname(env_->NewStringUTF(name.c_str()), env_);
+  if (jname == nullptr) {
+    return nullptr;
+  }
+  switch (value.GetType()) {
+    case VariantValue_::Type_INT_VALUE:
+      return env_->NewObject(named_variant_class_.get(),
+                             named_variant_from_int_, jname.get(),
+                             value.IntValue());
+    case VariantValue_::Type_INT64_VALUE:
+      return env_->NewObject(named_variant_class_.get(),
+                             named_variant_from_long_, jname.get(),
+                             value.Int64Value());
+    case VariantValue_::Type_FLOAT_VALUE:
+      return env_->NewObject(named_variant_class_.get(),
+                             named_variant_from_float_, jname.get(),
+                             value.FloatValue());
+    case VariantValue_::Type_DOUBLE_VALUE:
+      return env_->NewObject(named_variant_class_.get(),
+                             named_variant_from_double_, jname.get(),
+                             value.DoubleValue());
+    case VariantValue_::Type_BOOL_VALUE:
+      return env_->NewObject(named_variant_class_.get(),
+                             named_variant_from_bool_, jname.get(),
+                             value.BoolValue());
+    case VariantValue_::Type_STRING_VALUE: {
+      ScopedLocalRef<jstring> jvalue(
+          env_->NewStringUTF(value.StringValue().c_str()), env_);
+      if (jvalue == nullptr) {
+        return nullptr;
+      }
+      return env_->NewObject(named_variant_class_.get(),
+                             named_variant_from_string_, jname.get(),
+                             jvalue.get());
+    }
+    default:
+      return nullptr;
+  }
+}
+
+// Converts `values` into a new Java NamedVariant[] (local reference), one
+// slot per map entry in key order. Entries whose Variant holds no value leave
+// a null slot. Returns nullptr for an empty map or when a JNI allocation
+// fails.
+jobjectArray RemoteActionTemplatesHandler::AsNamedVariantArray(
+    const std::map<std::string, Variant>& values) {
+  if (values.empty()) {
+    return nullptr;
+  }
+  jobjectArray result = env_->NewObjectArray(
+      static_cast<jsize>(values.size()), named_variant_class_.get(), nullptr);
+  if (result == nullptr) {
+    return nullptr;
+  }
+  jsize element_index = 0;
+  // Iterate by const reference so no map entry (string + Variant) is copied.
+  for (const auto& key_value_pair : values) {
+    if (key_value_pair.second.HasValue()) {
+      ScopedLocalRef<jobject> named_extra(
+          AsNamedVariant(key_value_pair.first, key_value_pair.second), env_);
+      if (named_extra == nullptr) {
+        // Do not leak the partially filled array.
+        env_->DeleteLocalRef(result);
+        return nullptr;
+      }
+      env_->SetObjectArrayElement(result, element_index, named_extra.get());
+    }
+    ++element_index;
+  }
+  return result;
+}
+
+// Converts `remote_actions` into a new Java RemoteActionTemplate[] (local
+// reference). Returns nullptr when a JNI allocation fails.
+jobjectArray RemoteActionTemplatesHandler::RemoteActionTemplatesToJObjectArray(
+    const std::vector<RemoteActionTemplate>& remote_actions) {
+  const jobjectArray results =
+      env_->NewObjectArray(static_cast<jsize>(remote_actions.size()),
+                           remote_action_template_class_.get(), nullptr);
+  if (results == nullptr) {
+    return nullptr;
+  }
+  for (size_t i = 0; i < remote_actions.size(); ++i) {
+    const RemoteActionTemplate& remote_action = remote_actions[i];
+    // Each constructor argument is a fresh local reference. Scope them to this
+    // iteration so a long list of actions cannot exhaust the local reference
+    // table.
+    ScopedLocalRef<jstring> title(AsUTF8String(remote_action.title), env_);
+    ScopedLocalRef<jstring> description(
+        AsUTF8String(remote_action.description), env_);
+    ScopedLocalRef<jstring> action(AsUTF8String(remote_action.action), env_);
+    ScopedLocalRef<jstring> data(AsUTF8String(remote_action.data), env_);
+    ScopedLocalRef<jstring> type(AsUTF8String(remote_action.type), env_);
+    ScopedLocalRef<jobject> flags(AsInteger(remote_action.flags), env_);
+    ScopedLocalRef<jobjectArray> category(
+        AsStringArray(remote_action.category), env_);
+    ScopedLocalRef<jstring> package(AsUTF8String(remote_action.package_name),
+                                    env_);
+    ScopedLocalRef<jobjectArray> extra(AsNamedVariantArray(remote_action.extra),
+                                       env_);
+    ScopedLocalRef<jobject> request_code(AsInteger(remote_action.request_code),
+                                         env_);
+    ScopedLocalRef<jobject> result(
+        env_->NewObject(remote_action_template_class_.get(),
+                        remote_action_template_init_, title.get(),
+                        description.get(), action.get(), data.get(),
+                        type.get(), flags.get(), category.get(), package.get(),
+                        extra.get(), request_code.get()),
+        env_);
+    if (result == nullptr) {
+      // Do not leak the partially filled array.
+      env_->DeleteLocalRef(results);
+      return nullptr;
+    }
+    env_->SetObjectArrayElement(results, static_cast<jsize>(i), result.get());
+  }
+  return results;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/intents/jni.h b/utils/intents/jni.h
new file mode 100644
index 0000000..d84a51a
--- /dev/null
+++ b/utils/intents/jni.h
@@ -0,0 +1,95 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_H_
+#define LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_H_
+
+#include <jni.h>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "utils/intents/intent-generator.h"
+#include "utils/java/jni-base.h"
+#include "utils/optional.h"
+#include "utils/variant.h"
+
+// Simple name of the Java class that receives remote action templates. It can
+// be overridden at build time; jni.cc prefixes it with TC3_PACKAGE_PATH for
+// the class lookup.
+#ifndef TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME
+#define TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME RemoteActionTemplate
+#endif
+
+// String-literal form of the class name (TC3_ADD_QUOTES presumably stringizes
+// its argument — see jni-base.h).
+#define TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR \
+  TC3_ADD_QUOTES(TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME)
+
+// Simple name of the Java class wrapping a named extra value; overridable at
+// build time in the same way as the remote action template class name.
+#ifndef TC3_NAMED_VARIANT_CLASS_NAME
+#define TC3_NAMED_VARIANT_CLASS_NAME NamedVariant
+#endif
+
+// String-literal form of the named variant class name.
+#define TC3_NAMED_VARIANT_CLASS_NAME_STR \
+  TC3_ADD_QUOTES(TC3_NAMED_VARIANT_CLASS_NAME)
+
+namespace libtextclassifier3 {
+
+// A helper class to create RemoteActionTemplate object from model results.
+// Instances are meant to be obtained via Create(), which resolves all Java
+// classes and constructor IDs.
+// NOTE(review): the Java classes are held as ScopedLocalRef local references,
+// so an instance presumably must not be used after the JNI call that created
+// it has returned — confirm with callers.
+class RemoteActionTemplatesHandler {
+ public:
+  // Returns nullptr if `env` is null.
+  static std::unique_ptr<RemoteActionTemplatesHandler> Create(JNIEnv* env);
+
+  // Leaves all classes and method IDs unset; Create() fills them in.
+  explicit RemoteActionTemplatesHandler(JNIEnv* env)
+      : env_(env),
+        string_class_(nullptr, env),
+        integer_class_(nullptr, env),
+        remote_action_template_class_(nullptr, env),
+        named_variant_class_(nullptr, env) {}
+
+  // Conversion helpers. Each returns a new local reference, or nullptr when
+  // the input is unset/empty (and possibly when a JNI allocation fails).
+  jstring AsUTF8String(const Optional<std::string>& optional);
+  jobject AsInteger(const Optional<int>& optional);
+  jobjectArray AsStringArray(const std::vector<std::string>& values);
+  jobject AsNamedVariant(const std::string& name, const Variant& value);
+  jobjectArray AsNamedVariantArray(
+      const std::map<std::string, Variant>& values);
+
+  // Converts `remote_actions` into a Java RemoteActionTemplate[] (new local
+  // reference), or nullptr on failure.
+  jobjectArray RemoteActionTemplatesToJObjectArray(
+      const std::vector<RemoteActionTemplate>& remote_actions);
+
+ private:
+  JNIEnv* env_;
+
+  // java.lang.String
+  ScopedLocalRef<jclass> string_class_;
+
+  // java.lang.Integer and its Integer(int) constructor.
+  ScopedLocalRef<jclass> integer_class_;
+  jmethodID integer_init_ = nullptr;
+
+  // RemoteActionTemplate and its constructor.
+  ScopedLocalRef<jclass> remote_action_template_class_;
+  jmethodID remote_action_template_init_ = nullptr;
+
+  // NamedVariant and one (String name, <value>) constructor per value type.
+  ScopedLocalRef<jclass> named_variant_class_;
+  jmethodID named_variant_from_int_ = nullptr;
+  jmethodID named_variant_from_long_ = nullptr;
+  jmethodID named_variant_from_float_ = nullptr;
+  jmethodID named_variant_from_double_ = nullptr;
+  jmethodID named_variant_from_bool_ = nullptr;
+  jmethodID named_variant_from_string_ = nullptr;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_H_
diff --git a/utils/java/jni-cache.h b/utils/java/jni-cache.h
index 18675fc..8163817 100644
--- a/utils/java/jni-cache.h
+++ b/utils/java/jni-cache.h
@@ -109,7 +109,6 @@
ScopedGlobalRef<jclass> urlencoder_class;
jmethodID urlencoder_encode = nullptr;
-#ifdef __ANDROID__
// android.content.Context
ScopedGlobalRef<jclass> context_class;
jmethodID context_get_package_name = nullptr;
@@ -127,7 +126,6 @@
// android.os.Bundle
ScopedGlobalRef<jclass> bundle_class;
jmethodID bundle_get_boolean = nullptr;
-#endif
// Helper to convert lib3 UnicodeText to Java strings.
ScopedLocalRef<jstring> ConvertToJavaString(const UnicodeText& text) const;
diff --git a/utils/lua-utils.cc b/utils/lua-utils.cc
new file mode 100644
index 0000000..ce44192
--- /dev/null
+++ b/utils/lua-utils.cc
@@ -0,0 +1,101 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/lua-utils.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lauxlib.h"
+#include "lualib.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+namespace {
+// Upvalue slots of the C closures created by LuaEnvironment::PushCallback:
+// slot 1 holds the owning environment, slot 2 the callback id.
+constexpr int kEnvIndex = 1;
+constexpr int kCallbackIdIndex = 2;
+
+// Standard libraries opened by LoadDefaultLibraries(); io, os and package are
+// deliberately absent. Terminated by a {nullptr, nullptr} sentinel.
+const luaL_Reg defaultlibs[] = {
+    {"_G", luaopen_base},
+    {LUA_TABLIBNAME, luaopen_table},
+    {LUA_STRLIBNAME, luaopen_string},
+    {LUA_BITLIBNAME, luaopen_bit32},
+    {LUA_MATHLIBNAME, luaopen_math},
+    {nullptr, nullptr},
+};
+
+}  // namespace
+
+// Opens every library listed in `defaultlibs` and registers it as a global
+// under its name. The Lua stack is left unchanged.
+void LuaEnvironment::LoadDefaultLibraries() {
+  const luaL_Reg *lib = defaultlibs;
+  while (lib->func != nullptr) {
+    luaL_requiref(state_, lib->name, lib->func, /*glb=*/1);
+    // luaL_requiref leaves a copy of the module on the stack; drop it.
+    lua_pop(state_, 1);
+    ++lib;
+  }
+}
+
+// lua_CFunction behind every closure created by PushCallback. It recovers the
+// owning environment and the callback id from the closure's upvalues, then
+// forwards to HandleCallback. The handler's return value goes back to Lua
+// unchanged, as the number of results.
+int LuaEnvironment::CallbackDispatch(lua_State *state) {
+  void *const env_ptr = lua_touserdata(state, lua_upvalueindex(kEnvIndex));
+  LuaEnvironment *const env = static_cast<LuaEnvironment *>(env_ptr);
+  TC3_CHECK_EQ(env->state_, state);
+  const int callback_id = static_cast<int>(
+      lua_tointeger(state, lua_upvalueindex(kCallbackIdIndex)));
+  return env->HandleCallback(callback_id);
+}
+
+// Creates a fresh Lua state with no libraries opened (see
+// LoadDefaultLibraries()). luaL_newstate() yields nullptr on allocation
+// failure, which the destructor tolerates.
+LuaEnvironment::LuaEnvironment() : state_(luaL_newstate()) {}
+
+// Closes the owned Lua state, if one was created.
+LuaEnvironment::~LuaEnvironment() {
+  if (state_ == nullptr) {
+    return;
+  }
+  lua_close(state_);
+}
+
+// Pushes a C closure that dispatches to HandleCallback(callback_id) on this
+// environment when Lua calls it. Upvalue 1 holds `this` (light userdata) and
+// upvalue 2 the callback id; CallbackDispatch reads them back.
+void LuaEnvironment::PushCallback(int callback_id) {
+  lua_pushlightuserdata(state_, static_cast<void *>(this));
+  // The id is an integer and CallbackDispatch reads it with lua_tointeger, so
+  // store it as a Lua integer rather than a floating-point number.
+  lua_pushinteger(state_, callback_id);
+  lua_pushcclosure(state_, CallbackDispatch, 2);
+}
+
+// Pushes a new empty table whose metatable sends every missing-field lookup
+// (__index) to HandleCallback(callback_id). The metatable is the one
+// registered in the Lua registry under `name` (created if absent). The new
+// table is left on top of the stack.
+void LuaEnvironment::SetupTableLookupCallback(const char *name,
+                                              int callback_id) {
+  lua_newtable(state_);
+  const int table_index = lua_gettop(state_);
+  luaL_newmetatable(state_, name);
+  const int metatable_index = lua_gettop(state_);
+  PushCallback(callback_id);
+  // metatable.__index = callback closure (pops the closure).
+  lua_setfield(state_, metatable_index, "__index");
+  // Attach the metatable to the table (pops the metatable).
+  lua_setmetatable(state_, table_index);
+}
+
+// Pushes `value` onto the Lua stack as a number, boolean or string. Exactly
+// one value is always pushed; a Variant that holds no value becomes nil.
+// NOTE(review): integers are pushed via lua_pushnumber (a double), so int64
+// values beyond 2^53 lose precision — confirm this is acceptable.
+void LuaEnvironment::PushValue(const Variant &value) {
+  if (value.HasInt()) {
+    lua_pushnumber(state_, value.IntValue());
+  } else if (value.HasInt64()) {
+    lua_pushnumber(state_, value.Int64Value());
+  } else if (value.HasBool()) {
+    lua_pushboolean(state_, value.BoolValue());
+  } else if (value.HasFloat()) {
+    lua_pushnumber(state_, value.FloatValue());
+  } else if (value.HasDouble()) {
+    lua_pushnumber(state_, value.DoubleValue());
+  } else if (value.HasString()) {
+    lua_pushstring(state_, value.StringValue().data());
+  } else {
+    // Empty Variants do occur (e.g. extras without a value), so push nil and
+    // keep the stack balanced instead of aborting the whole process.
+    TC3_LOG(ERROR) << "Unknown value type.";
+    lua_pushnil(state_);
+  }
+}
+
+// Base implementation: handles no callbacks; subclasses override this.
+// NOTE(review): CallbackDispatch returns this value to Lua as the result
+// count, so LUA_ERRRUN (== 2) yields two stack values rather than raising a
+// Lua error — confirm intent.
+int LuaEnvironment::HandleCallback(int /*callback_id*/) {
+  return LUA_ERRRUN;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/lua-utils.h b/utils/lua-utils.h
new file mode 100644
index 0000000..d7ce65c
--- /dev/null
+++ b/utils/lua-utils.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
+
+#include "utils/strings/stringpiece.h"
+#include "utils/variant.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "lua.h"
+#ifdef __cplusplus
+}
+#endif
+
+namespace libtextclassifier3 {
+
+// Owns a Lua interpreter state and provides helpers to expose native C++
+// callbacks to Lua code. Subclasses override HandleCallback() to implement
+// the callbacks they register.
+class LuaEnvironment {
+ public:
+  virtual ~LuaEnvironment();
+  LuaEnvironment();
+
+  // The destructor closes `state_`, so a copy would close the same state a
+  // second time. Copying is therefore disabled.
+  LuaEnvironment(const LuaEnvironment &) = delete;
+  LuaEnvironment &operator=(const LuaEnvironment &) = delete;
+
+ protected:
+  // Loads default libraries and registers them as globals.
+  void LoadDefaultLibraries();
+
+  // Provides a callback to Lua with given id, which will be dispatched to
+  // `HandleCallback(id)` when called. This is useful when we need to call
+  // native C++ code from within Lua code. The callback is pushed onto the Lua
+  // stack as a C closure.
+  void PushCallback(int callback_id);
+
+  // Sets up a table that calls back whenever a member is accessed.
+  // This allows to lazily provide required information to the script.
+  // `HandleCallback` will be called upon callback invocation with the
+  // callback identifier provided. `name` keys the table's metatable in the Lua
+  // registry; the new table is left on top of the stack.
+  void SetupTableLookupCallback(const char *name, int callback_id);
+
+  // Called from Lua when invoking a callback either by
+  // `PushCallback` or `SetupTableLookupCallback`. The return value is handed
+  // back to Lua as the number of results on the stack.
+  virtual int HandleCallback(int callback_id);
+
+  // Pushes `value` onto the Lua stack.
+  void PushValue(const Variant &value);
+
+  // Interpreter state owned by this environment (nullptr if creation failed).
+  lua_State *state_;
+
+ private:
+  // lua_CFunction behind every callback closure; forwards to HandleCallback()
+  // on the owning environment.
+  static int CallbackDispatch(lua_State *state);
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
diff --git a/utils/named-extra.fbs b/utils/named-extra.fbs
new file mode 100755
index 0000000..f1f6f21
--- /dev/null
+++ b/utils/named-extra.fbs
@@ -0,0 +1,44 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+namespace libtextclassifier3.VariantValue_;
+
+// Tag naming which value field of a VariantValue is meaningful; NONE means
+// no value.
+enum Type : int {
+  NONE = 0,
+  INT_VALUE = 1,
+  INT64_VALUE = 2,
+  FLOAT_VALUE = 3,
+  DOUBLE_VALUE = 4,
+  BOOL_VALUE = 5,
+  STRING_VALUE = 6,
+}
+
+namespace libtextclassifier3;
+
+// A tagged value: presumably only the *_value field selected by `type` is
+// populated — confirm with the writers of this table.
+table VariantValue {
+  type:VariantValue_.Type;
+  int_value:int;
+  int64_value:long;
+  float_value:float;
+  double_value:double;
+  bool_value:bool;
+  string_value:string;
+}
+
+namespace libtextclassifier3;
+
+// A VariantValue paired with the name it is stored under.
+table NamedVariant {
+  name:string;
+  value:VariantValue;
+}
+
diff --git a/utils/test-utils.cc b/utils/test-utils.cc
new file mode 100644
index 0000000..e37105a
--- /dev/null
+++ b/utils/test-utils.cc
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/test-utils.h"
+
+#include <iterator>
+
+#include "utils/strings/split.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Splits `text` on ' ' and returns one Token per piece. Start/end offsets are
+// byte positions in `text`; they equal codepoint positions only for ASCII
+// input, so non-ASCII text is not supported.
+std::vector<Token> TokenizeAsciiOnSpace(const std::string& text) {
+  std::vector<Token> tokens;
+  const char* const text_begin = text.data();
+  for (const StringPiece piece : strings::Split(text, ' ')) {
+    // Relies on Split returning views into `text`, so pointer distance is the
+    // byte offset of the piece.
+    const int begin = std::distance(text_begin, piece.data());
+    const int end = begin + static_cast<int>(piece.length());
+    tokens.push_back(Token{piece.ToString(), begin, end});
+  }
+  return tokens;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/test-utils.h b/utils/test-utils.h
new file mode 100644
index 0000000..7e227dc
--- /dev/null
+++ b/utils/test-utils.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utilities for tests.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
+#define LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
+
+#include <string>
+
+#include "annotator/types.h"
+
+namespace libtextclassifier3 {
+
+// Splits `text` on ASCII space characters and returns one Token per piece.
+// Token offsets are byte offsets into `text`, so the result is only correct
+// for ASCII input.
+std::vector<Token> TokenizeAsciiOnSpace(const std::string& text);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TEST_UTILS_H_
diff --git a/utils/testing/annotator.h b/utils/testing/annotator.h
new file mode 100644
index 0000000..b988d0b
--- /dev/null
+++ b/utils/testing/annotator.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Helper utilities for testing Annotator.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
+
+#include <memory>
+#include <string>
+
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace libtextclassifier3 {
+
+// Round-trips an Annotator model through the flatbuffers object API: unpacks
+// `model_flatbuffer` into a ModelT, calls `visitor_fn` with a pointer to it so
+// the caller can edit fields in place, then packs the edited ModelT into a
+// fresh flatbuffer and returns its bytes as a std::string.
+// NOTE(review): input is not verified before unpacking; assumed well-formed.
+template <typename Fn>
+std::string ModifyAnnotatorModel(const std::string& model_flatbuffer,
+                                 Fn visitor_fn) {
+  std::unique_ptr<ModelT> model = UnPackModel(model_flatbuffer.c_str());
+  visitor_fn(model.get());
+
+  flatbuffers::FlatBufferBuilder fbb;
+  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
+  const char* const bytes =
+      reinterpret_cast<const char*>(fbb.GetBufferPointer());
+  return std::string(bytes, fbb.GetSize());
+}
+
+}  // namespace libtextclassifier3
+
+#endif  // LIBTEXTCLASSIFIER_UTILS_TESTING_ANNOTATOR_H_
diff --git a/utils/utf8/UniLibJavaIcuTest.java b/utils/utf8/UniLibJavaIcuTest.java
new file mode 100644
index 0000000..245beee
--- /dev/null
+++ b/utils/utf8/UniLibJavaIcuTest.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.android.textclassifier.utils.utf8;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** JUnit launcher for the native C++ tests, which need a valid JNIEnv to run. */
+@RunWith(JUnit4.class)
+public class UniLibJavaIcuTest {
+  private native boolean testsMain();
+
+  @Before
+  public void setUp() throws Exception {
+    System.loadLibrary("unilib-javaicu-test-lib");
+  }
+
+  @Test
+  public void testNative() {
+    final boolean nativeTestsPassed = testsMain();
+    assertThat(nativeTestsPassed).isTrue();
+  }
+}
diff --git a/utils/variant.cc b/utils/variant.cc
new file mode 100644
index 0000000..30c268e
--- /dev/null
+++ b/utils/variant.cc
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/variant.h"
+
+namespace libtextclassifier3 {
+
+// Converts a flatbuffer vector of NamedVariant entries into a map keyed by
+// entry name; returns an empty map when `extra` is null. If a name repeats,
+// the last entry wins. Entries without a name are skipped and entries without
+// a value map to an empty Variant, since flatbuffer fields may be absent
+// (assuming the schema does not mark them required).
+std::map<std::string, Variant> AsVariantMap(
+    const flatbuffers::Vector<flatbuffers::Offset<NamedVariant>>* extra) {
+  std::map<std::string, Variant> result;
+  if (extra == nullptr) {
+    return result;
+  }
+  for (const NamedVariant* entry : *extra) {
+    if (entry->name() == nullptr) {
+      continue;
+    }
+    result[entry->name()->str()] = entry->value() != nullptr
+                                       ? Variant(entry->value())
+                                       : Variant();
+  }
+  return result;
+}
+
+}  // namespace libtextclassifier3
diff --git a/utils/variant.h b/utils/variant.h
index ddb0d60..c529aa9 100644
--- a/utils/variant.h
+++ b/utils/variant.h
@@ -17,48 +17,123 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_VARIANT_H_
#define LIBTEXTCLASSIFIER_UTILS_VARIANT_H_
+#include <map>
#include <string>
#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+#include "utils/named-extra_generated.h"
#include "utils/strings/stringpiece.h"
namespace libtextclassifier3 {
// Represents a type-tagged union of different basic types.
-struct Variant {
- Variant() : type(TYPE_INVALID) {}
- explicit Variant(int value) : type(TYPE_INT_VALUE), int_value(value) {}
- explicit Variant(int64 value) : type(TYPE_LONG_VALUE), long_value(value) {}
- explicit Variant(float value) : type(TYPE_FLOAT_VALUE), float_value(value) {}
+//
+// The tag is a VariantValue_::Type. Only the member matching the tag is
+// initialized; the typed accessors (IntValue() etc.) TC3_CHECK the tag first.
+class Variant {
+ public:
+  // Creates an empty Variant: Type_NONE, so HasValue() returns false.
+  Variant() : type_(VariantValue_::Type_NONE) {}
+  explicit Variant(int value)
+      : type_(VariantValue_::Type_INT_VALUE), int_value_(value) {}
+  explicit Variant(int64 value)
+      : type_(VariantValue_::Type_INT64_VALUE), long_value_(value) {}
+  explicit Variant(float value)
+      : type_(VariantValue_::Type_FLOAT_VALUE), float_value_(value) {}
explicit Variant(double value)
- : type(TYPE_DOUBLE_VALUE), double_value(value) {}
+      : type_(VariantValue_::Type_DOUBLE_VALUE), double_value_(value) {}
explicit Variant(StringPiece value)
- : type(TYPE_STRING_VALUE), string_value(value.ToString()) {}
+      : type_(VariantValue_::Type_STRING_VALUE),
+        string_value_(value.ToString()) {}
explicit Variant(std::string value)
- : type(TYPE_STRING_VALUE), string_value(value) {}
+      : type_(VariantValue_::Type_STRING_VALUE), string_value_(value) {}
explicit Variant(const char* value)
- : type(TYPE_STRING_VALUE), string_value(value) {}
- explicit Variant(bool value) : type(TYPE_BOOL_VALUE), bool_value(value) {}
- enum Type {
- TYPE_INVALID = 0,
- TYPE_INT_VALUE = 1,
- TYPE_LONG_VALUE = 2,
- TYPE_FLOAT_VALUE = 3,
- TYPE_DOUBLE_VALUE = 4,
- TYPE_BOOL_VALUE = 5,
- TYPE_STRING_VALUE = 6,
- };
- Type type;
+      : type_(VariantValue_::Type_STRING_VALUE), string_value_(value) {}
+  explicit Variant(bool value)
+      : type_(VariantValue_::Type_BOOL_VALUE), bool_value_(value) {}
+
+  // Creates a Variant from a flatbuffer VariantValue. A null `value`, a
+  // Type_NONE value, or an unrecognized type (which is logged) all produce an
+  // empty Variant, so HasValue() is false for them.
+  explicit Variant(const VariantValue* value)
+      : type_(VariantValue_::Type_NONE) {
+    if (value == nullptr) {
+      return;
+    }
+    switch (value->type()) {
+      case VariantValue_::Type_NONE:
+        break;
+      case VariantValue_::Type_INT_VALUE:
+        int_value_ = value->int_value();
+        break;
+      case VariantValue_::Type_INT64_VALUE:
+        long_value_ = value->int64_value();
+        break;
+      case VariantValue_::Type_FLOAT_VALUE:
+        float_value_ = value->float_value();
+        break;
+      case VariantValue_::Type_DOUBLE_VALUE:
+        double_value_ = value->double_value();
+        break;
+      case VariantValue_::Type_BOOL_VALUE:
+        bool_value_ = value->bool_value();
+        break;
+      case VariantValue_::Type_STRING_VALUE:
+        // An absent flatbuffer string field reads as null; keep "" then.
+        if (value->string_value() != nullptr) {
+          string_value_ = value->string_value()->str();
+        }
+        break;
+      default:
+        TC3_LOG(ERROR) << "Unknown variant type: " << value->type();
+        // Keep type_ as Type_NONE so HasValue() reports no value.
+        return;
+    }
+    type_ = value->type();
+  }
+
+  // Typed accessors. Each TC3_CHECK-fails unless the Variant holds that type.
+  int IntValue() const {
+    TC3_CHECK(HasInt());
+    return int_value_;
+  }
+
+  int64 Int64Value() const {
+    TC3_CHECK(HasInt64());
+    return long_value_;
+  }
+
+  float FloatValue() const {
+    TC3_CHECK(HasFloat());
+    return float_value_;
+  }
+
+  double DoubleValue() const {
+    TC3_CHECK(HasDouble());
+    return double_value_;
+  }
+
+  bool BoolValue() const {
+    TC3_CHECK(HasBool());
+    return bool_value_;
+  }
+
+  const std::string& StringValue() const {
+    TC3_CHECK(HasString());
+    return string_value_;
+  }
+
+  // Type predicates: true iff the Variant currently holds that type.
+  bool HasInt() const { return type_ == VariantValue_::Type_INT_VALUE; }
+
+  bool HasInt64() const { return type_ == VariantValue_::Type_INT64_VALUE; }
+
+  bool HasFloat() const { return type_ == VariantValue_::Type_FLOAT_VALUE; }
+
+  bool HasDouble() const { return type_ == VariantValue_::Type_DOUBLE_VALUE; }
+
+  bool HasBool() const { return type_ == VariantValue_::Type_BOOL_VALUE; }
+
+  bool HasString() const { return type_ == VariantValue_::Type_STRING_VALUE; }
+
+  VariantValue_::Type GetType() const { return type_; }
+
+  // False only for an empty (Type_NONE) Variant.
+  bool HasValue() const { return type_ != VariantValue_::Type_NONE; }
+
+ private:
+  VariantValue_::Type type_;
+  // Only the member selected by type_ holds a meaningful value.
union {
- int int_value;
- int64 long_value;
- float float_value;
- double double_value;
- bool bool_value;
+    int int_value_;
+    int64 long_value_;
+    float float_value_;
+    double double_value_;
+    bool bool_value_;
};
- std::string string_value;
+  std::string string_value_;
};
+
+// Converts a flatbuffer vector of NamedVariant entries into a map from entry
+// name to Variant. Returns an empty map if `extra` is null.
+std::map<std::string, Variant> AsVariantMap(
+    const flatbuffers::Vector<flatbuffers::Offset<NamedVariant>>* extra);
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_VARIANT_H_