Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (C) 2018 The Android Open Source Project |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
| 17 | #include <fstream> |
| 18 | #include <string> |
| 19 | #include <vector> |
| 20 | |
| 21 | #include "utils/tflite/text_encoder.h" |
| 22 | #include "gtest/gtest.h" |
Tony Mak | a0f598b | 2018-11-20 20:39:04 +0000 | [diff] [blame^] | 23 | #include "third_party/absl/flags/flag.h" |
Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 24 | #include "flatbuffers/flexbuffers.h" |
| 25 | #include "tensorflow/contrib/lite/interpreter.h" |
| 26 | #include "tensorflow/contrib/lite/kernels/register.h" |
| 27 | #include "tensorflow/contrib/lite/kernels/test_util.h" |
| 28 | #include "tensorflow/contrib/lite/model.h" |
| 29 | #include "tensorflow/contrib/lite/string_util.h" |
| 30 | |
| 31 | namespace libtextclassifier3 { |
| 32 | namespace { |
| 33 | |
// Returns the path of the text encoder config file that the test model loads.
// NOTE(review): the empty string looks like a placeholder for a real
// test-data path -- confirm before relying on the expectations below.
std::string GetTestConfigPath() {
  return std::string();
}
| 37 | |
| 38 | class TextEncoderOpModel : public tflite::SingleOpModel { |
| 39 | public: |
| 40 | TextEncoderOpModel(std::initializer_list<int> input_strings_shape, |
| 41 | std::initializer_list<int> attribute_shape); |
| 42 | void SetInputText(const std::initializer_list<string>& strings) { |
| 43 | PopulateStringTensor(input_string_, strings); |
| 44 | PopulateTensor(input_length_, {static_cast<int32_t>(strings.size())}); |
| 45 | } |
| 46 | void SetMaxOutputLength(int length) { |
| 47 | PopulateTensor(input_output_maxlength_, {length}); |
| 48 | } |
| 49 | void SetInt32Attribute(const std::initializer_list<int>& attribute) { |
| 50 | PopulateTensor(input_attributes_int32_, attribute); |
| 51 | } |
| 52 | void SetFloatAttribute(const std::initializer_list<float>& attribute) { |
| 53 | PopulateTensor(input_attributes_float_, attribute); |
| 54 | } |
| 55 | |
| 56 | std::vector<int> GetOutputEncoding() { |
| 57 | return ExtractVector<int>(output_encoding_); |
| 58 | } |
Tony Mak | 51a9e54 | 2018-11-02 13:36:22 +0000 | [diff] [blame] | 59 | std::vector<int> GetOutputPositions() { |
| 60 | return ExtractVector<int>(output_positions_); |
| 61 | } |
Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 62 | std::vector<int> GetOutputAttributeInt32() { |
| 63 | return ExtractVector<int>(output_attributes_int32_); |
| 64 | } |
| 65 | std::vector<float> GetOutputAttributeFloat() { |
| 66 | return ExtractVector<float>(output_attributes_float_); |
| 67 | } |
| 68 | int GetEncodedLength() { return ExtractVector<int>(output_length_)[0]; } |
| 69 | |
| 70 | private: |
| 71 | int input_string_; |
| 72 | int input_length_; |
| 73 | int input_output_maxlength_; |
| 74 | int input_attributes_int32_; |
| 75 | int input_attributes_float_; |
| 76 | |
| 77 | int output_encoding_; |
Tony Mak | 51a9e54 | 2018-11-02 13:36:22 +0000 | [diff] [blame] | 78 | int output_positions_; |
Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 79 | int output_length_; |
| 80 | int output_attributes_int32_; |
| 81 | int output_attributes_float_; |
| 82 | }; |
| 83 | |
| 84 | TextEncoderOpModel::TextEncoderOpModel( |
| 85 | std::initializer_list<int> input_strings_shape, |
| 86 | std::initializer_list<int> attribute_shape) { |
| 87 | input_string_ = AddInput(tflite::TensorType_STRING); |
| 88 | input_length_ = AddInput(tflite::TensorType_INT32); |
| 89 | input_output_maxlength_ = AddInput(tflite::TensorType_INT32); |
| 90 | input_attributes_int32_ = AddInput(tflite::TensorType_INT32); |
| 91 | input_attributes_float_ = AddInput(tflite::TensorType_FLOAT32); |
| 92 | |
| 93 | output_encoding_ = AddOutput(tflite::TensorType_INT32); |
Tony Mak | 51a9e54 | 2018-11-02 13:36:22 +0000 | [diff] [blame] | 94 | output_positions_ = AddOutput(tflite::TensorType_INT32); |
Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 95 | output_length_ = AddOutput(tflite::TensorType_INT32); |
| 96 | output_attributes_int32_ = AddOutput(tflite::TensorType_INT32); |
| 97 | output_attributes_float_ = AddOutput(tflite::TensorType_FLOAT32); |
| 98 | |
| 99 | std::ifstream test_config_stream(GetTestConfigPath()); |
| 100 | std::string config((std::istreambuf_iterator<char>(test_config_stream)), |
| 101 | (std::istreambuf_iterator<char>())); |
| 102 | flexbuffers::Builder builder; |
| 103 | builder.Map([&]() { builder.String("text_encoder_config", config); }); |
| 104 | builder.Finish(); |
| 105 | SetCustomOp("TextEncoder", builder.GetBuffer(), |
| 106 | tflite::ops::custom::Register_TEXT_ENCODER); |
| 107 | BuildInterpreter( |
| 108 | {input_strings_shape, {1}, {1}, attribute_shape, attribute_shape}); |
| 109 | } |
| 110 | |
| 111 | // Tests |
| 112 | TEST(TextEncoderTest, SimpleEncoder) { |
| 113 | TextEncoderOpModel m({1, 1}, {1, 1}); |
| 114 | m.SetInputText({"Hello"}); |
| 115 | m.SetMaxOutputLength(10); |
| 116 | m.SetInt32Attribute({7}); |
| 117 | m.SetFloatAttribute({3.f}); |
| 118 | m.Invoke(); |
| 119 | EXPECT_EQ(m.GetEncodedLength(), 5); |
| 120 | EXPECT_THAT(m.GetOutputEncoding(), |
| 121 | testing::ElementsAre(1, 90, 547, 58, 2, 2, 2, 2, 2, 2)); |
Tony Mak | 51a9e54 | 2018-11-02 13:36:22 +0000 | [diff] [blame] | 122 | EXPECT_THAT(m.GetOutputPositions(), |
| 123 | testing::ElementsAre(0, 1, 2, 3, 4, 10, 10, 10, 10, 10)); |
Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 124 | EXPECT_THAT(m.GetOutputAttributeInt32(), |
| 125 | testing::ElementsAre(7, 7, 7, 7, 7, 7, 7, 7, 7, 7)); |
| 126 | EXPECT_THAT( |
| 127 | m.GetOutputAttributeFloat(), |
| 128 | testing::ElementsAre(3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f)); |
| 129 | } |
| 130 | |
| 131 | TEST(TextEncoderTest, ManyStrings) { |
| 132 | TextEncoderOpModel m({1, 3}, {1, 3}); |
| 133 | m.SetInt32Attribute({1, 2, 3}); |
| 134 | m.SetFloatAttribute({5.f, 4.f, 3.f}); |
| 135 | m.SetInputText({"Hello", "Hi", "Bye"}); |
| 136 | m.SetMaxOutputLength(10); |
| 137 | m.Invoke(); |
| 138 | EXPECT_EQ(m.GetEncodedLength(), 10); |
| 139 | EXPECT_THAT(m.GetOutputEncoding(), |
| 140 | testing::ElementsAre(547, 58, 2, 1, 862, 2, 1, 1919, 19, 2)); |
Tony Mak | 51a9e54 | 2018-11-02 13:36:22 +0000 | [diff] [blame] | 141 | EXPECT_THAT(m.GetOutputPositions(), |
| 142 | testing::ElementsAre(2, 3, 4, 0, 1, 2, 0, 1, 2, 3)); |
Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 143 | EXPECT_THAT(m.GetOutputAttributeInt32(), |
| 144 | testing::ElementsAre(1, 1, 1, 2, 2, 2, 3, 3, 3, 3)); |
| 145 | EXPECT_THAT( |
| 146 | m.GetOutputAttributeFloat(), |
| 147 | testing::ElementsAre(5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 3.f)); |
| 148 | } |
| 149 | |
| 150 | TEST(TextEncoderTest, LongStrings) { |
| 151 | TextEncoderOpModel m({1, 4}, {1, 4}); |
| 152 | m.SetInt32Attribute({1, 2, 3, 4}); |
| 153 | m.SetFloatAttribute({5.f, 4.f, 3.f, 2.f}); |
| 154 | m.SetInputText({"Hello", "Hi", "Bye", "Hi"}); |
| 155 | m.SetMaxOutputLength(9); |
| 156 | m.Invoke(); |
| 157 | EXPECT_EQ(m.GetEncodedLength(), 9); |
| 158 | EXPECT_THAT(m.GetOutputEncoding(), |
| 159 | testing::ElementsAre(862, 2, 1, 1919, 19, 2, 1, 862, 2)); |
Tony Mak | 51a9e54 | 2018-11-02 13:36:22 +0000 | [diff] [blame] | 160 | EXPECT_THAT(m.GetOutputPositions(), |
| 161 | testing::ElementsAre(1, 2, 0, 1, 2, 3, 0, 1, 2)); |
Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 162 | EXPECT_THAT(m.GetOutputAttributeInt32(), |
| 163 | testing::ElementsAre(2, 2, 3, 3, 3, 3, 4, 4, 4)); |
| 164 | EXPECT_THAT( |
| 165 | m.GetOutputAttributeFloat(), |
| 166 | testing::ElementsAre(4.f, 4.f, 3.f, 3.f, 3.f, 3.f, 2.f, 2.f, 2.f)); |
| 167 | } |
| 168 | |
| 169 | } // namespace |
| 170 | } // namespace libtextclassifier3 |