Export lib3 to AOSP (external/libtextclassifier part)
1. Include both annotator (existing one) and actions (new one for smart
reply and actions)
2. One more model file: actions_suggestions.model is dropped to
/etc/textclassifier/. It is around 7.5 MB for now; we will slim it
down later.
3. The Java counterpart of the JNI is now moved from frameworks/base
to here.
Test: atest android.view.textclassifier.TextClassificationManagerTest
Change-Id: Icb2458967ef51efa2952b3eaddefbf1f7b359930
diff --git a/utils/tflite/text_encoder_test.cc b/utils/tflite/text_encoder_test.cc
new file mode 100644
index 0000000..d1892c7
--- /dev/null
+++ b/utils/tflite/text_encoder_test.cc
@@ -0,0 +1,158 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "utils/tflite/text_encoder.h"
+#include "gtest/gtest.h"
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+std::string GetTestConfigPath() {  // Path opened by the TextEncoderOpModel ctor.
+ return "";  // NOTE(review): empty placeholder? Confirm the real path.
+}
+
+class TextEncoderOpModel : public tflite::SingleOpModel {  // Op test harness.
+ public:  // Setters below fill the op's input tensors.
+  TextEncoderOpModel(std::initializer_list<int> input_strings_shape,
+                     std::initializer_list<int> attribute_shape);
+  void SetInputText(const std::initializer_list<std::string>& strings) {
+    PopulateStringTensor(input_string_, strings);  // Count is stored next.
+    PopulateTensor(input_length_, {static_cast<int32_t>(strings.size())});
+  }
+  void SetMaxOutputLength(int length) {
+    PopulateTensor(input_output_maxlength_, {length});
+  }
+  void SetInt32Attribute(const std::initializer_list<int>& attribute) {
+    PopulateTensor(input_attributes_int32_, attribute);
+  }
+  void SetFloatAttribute(const std::initializer_list<float>& attribute) {
+    PopulateTensor(input_attributes_float_, attribute);
+  }
+  // Output readers; the tests in this file call them after Invoke().
+  std::vector<int> GetOutputEncoding() {
+    return ExtractVector<int>(output_encoding_);
+  }
+  std::vector<int> GetOutputAttributeInt32() {
+    return ExtractVector<int>(output_attributes_int32_);
+  }
+  std::vector<float> GetOutputAttributeFloat() {
+    return ExtractVector<float>(output_attributes_float_);
+  }
+  int GetEncodedLength() { return ExtractVector<int>(output_length_)[0]; }
+
+ private:  // Tensor indices; every one is set by the constructor.
+  int input_string_;
+  int input_length_;
+  int input_output_maxlength_;
+  int input_attributes_int32_;
+  int input_attributes_float_;
+  // Output tensors.
+  int output_encoding_;
+  int output_length_;
+  int output_attributes_int32_;
+  int output_attributes_float_;
+};
+
+// Adds tensors in member-declaration order, then wires up the TextEncoder op.
+TextEncoderOpModel::TextEncoderOpModel(
+    std::initializer_list<int> input_strings_shape,  // Text tensor shape.
+    std::initializer_list<int> attribute_shape)  // Shape of each attribute.
+    : input_string_(AddInput(tflite::TensorType_STRING)),
+      input_length_(AddInput(tflite::TensorType_INT32)),
+      input_output_maxlength_(AddInput(tflite::TensorType_INT32)),
+      input_attributes_int32_(AddInput(tflite::TensorType_INT32)),
+      input_attributes_float_(AddInput(tflite::TensorType_FLOAT32)),
+      output_encoding_(AddOutput(tflite::TensorType_INT32)),
+      output_length_(AddOutput(tflite::TensorType_INT32)),
+      output_attributes_int32_(AddOutput(tflite::TensorType_INT32)),
+      output_attributes_float_(AddOutput(tflite::TensorType_FLOAT32)) {
+  // The config file's bytes reach the op under the "text_encoder_config" key.
+  std::ifstream cfg_in(GetTestConfigPath());
+  const std::string serialized((std::istreambuf_iterator<char>(cfg_in)),
+                               std::istreambuf_iterator<char>());
+  flexbuffers::Builder opts;
+  opts.Map([&]() { opts.String("text_encoder_config", serialized); });
+  opts.Finish();
+  SetCustomOp("TextEncoder", opts.GetBuffer(),
+              tflite::ops::custom::Register_TEXT_ENCODER);
+  BuildInterpreter(
+      {input_strings_shape, {1}, {1}, attribute_shape, attribute_shape});
+}
+
+// Single text shorter than the max length: trailing output slots hold id 2.
+TEST(TextEncoderTest, SimpleEncoder) {
+  using testing::ElementsAre;
+  TextEncoderOpModel encoder({1, 1}, {1, 1});
+  encoder.SetInt32Attribute({7});
+  encoder.SetFloatAttribute({3.f});
+  encoder.SetInputText({"Hello"});
+  encoder.SetMaxOutputLength(10);
+  encoder.Invoke();
+  EXPECT_EQ(encoder.GetEncodedLength(), 5);
+  EXPECT_THAT(encoder.GetOutputEncoding(),
+              ElementsAre(1, 90, 547, 58, 2, 2, 2, 2, 2, 2));
+  EXPECT_THAT(encoder.GetOutputAttributeInt32(),
+              ElementsAre(7, 7, 7, 7, 7, 7, 7, 7, 7, 7));
+  EXPECT_THAT(encoder.GetOutputAttributeFloat(),
+              ElementsAre(3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f));
+}
+
+TEST(TextEncoderTest, ManyStrings) {  // Three texts, ten output slots.
+  using testing::ElementsAre;
+  TextEncoderOpModel encoder({1, 3}, {1, 3});
+  encoder.SetInputText({"Hello", "Hi", "Bye"});
+  encoder.SetMaxOutputLength(10);
+  encoder.SetInt32Attribute({1, 2, 3});
+  encoder.SetFloatAttribute({5.f, 4.f, 3.f});
+  encoder.Invoke();
+  EXPECT_EQ(encoder.GetEncodedLength(), 10);
+  EXPECT_THAT(encoder.GetOutputEncoding(),  // Only the tail of "Hello".
+              ElementsAre(547, 58, 2, 1, 862, 2, 1, 1919, 19, 2));
+  EXPECT_THAT(encoder.GetOutputAttributeInt32(),
+              ElementsAre(1, 1, 1, 2, 2, 2, 3, 3, 3, 3));
+  EXPECT_THAT(encoder.GetOutputAttributeFloat(),
+              ElementsAre(5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 3.f));
+}
+
+TEST(TextEncoderTest, LongStrings) {  // Four texts, only nine output slots.
+  using testing::ElementsAre;
+  TextEncoderOpModel encoder({1, 4}, {1, 4});
+  encoder.SetInputText({"Hello", "Hi", "Bye", "Hi"});
+  encoder.SetMaxOutputLength(9);
+  encoder.SetInt32Attribute({1, 2, 3, 4});
+  encoder.SetFloatAttribute({5.f, 4.f, 3.f, 2.f});
+  encoder.Invoke();
+  EXPECT_EQ(encoder.GetEncodedLength(), 9);
+  EXPECT_THAT(encoder.GetOutputEncoding(),
+              ElementsAre(862, 2, 1, 1919, 19, 2, 1, 862, 2));
+  EXPECT_THAT(encoder.GetOutputAttributeInt32(),  // No slot carries 1.
+              ElementsAre(2, 2, 3, 3, 3, 3, 4, 4, 4));
+  EXPECT_THAT(encoder.GetOutputAttributeFloat(),  // Likewise none has 5.f.
+              ElementsAre(4.f, 4.f, 3.f, 3.f, 3.f, 3.f, 2.f, 2.f, 2.f));
+}
+
+} // namespace
+} // namespace libtextclassifier3