Export lib3 to AOSP (external/libtextclassifier part)
1. Include both the annotator (existing) and actions (new, for smart
reply and actions).
2. One more model file: actions_suggestions.model is placed in
/etc/textclassifier/. It is around 7.5 MB for now; we will slim it
down later.
3. The Java counterpart of the JNI has been moved from frameworks/base
to here.
Test: atest android.view.textclassifier.TextClassificationManagerTest
Change-Id: Icb2458967ef51efa2952b3eaddefbf1f7b359930
diff --git a/annotator/quantization_test.cc b/annotator/quantization_test.cc
new file mode 100644
index 0000000..b995096
--- /dev/null
+++ b/annotator/quantization_test.cc
@@ -0,0 +1,163 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/quantization.h"
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using testing::ElementsAreArray;
+using testing::FloatEq;
+using testing::Matcher;
+
+namespace libtextclassifier3 {
+namespace {
+
+// Returns a matcher that compares a std::vector<float> element by element,
+// using gmock's FloatEq (approximate, ULP-based equality) for each element.
+Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
+  std::vector<Matcher<float>> per_element;
+  for (float expected : values) per_element.push_back(FloatEq(expected));
+  return ElementsAreArray(per_element);
+}
+
+TEST(QuantizationTest, DequantizeAdd8bit) {
+  // Three buckets of 4 bytes each. With 8-bit quantization every byte is one
+  // value, expected to dequantize to
+  //   scales[bucket_id] / num_sparse_features * (byte - 128).
+  std::vector<float> scales{0.1f, 9.0f, -7.0f};
+  std::vector<uint8> embeddings{/*0: */ 0x00, 0xFF, 0x09, 0x00,
+                                /*1: */ 0xFF, 0x09, 0x00, 0xFF,
+                                /*2: */ 0x09, 0x00, 0xFF, 0x09};
+
+  const int quantization_bits = 8;
+  const int bytes_per_embedding = 4;
+  const int num_sparse_features = 7;
+
+  // Dequantizes bucket `bucket_id` into a zero-initialized buffer, so the
+  // result holds just the dequantized values.
+  const auto dequantize = [&](int bucket_id) {
+    std::vector<float> dest(bytes_per_embedding, 0.0);
+    DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
+                  num_sparse_features, quantization_bits, bucket_id,
+                  dest.data(), dest.size());
+    return dest;
+  };
+
+  {
+    const float m = 1.0 / num_sparse_features * scales[0];
+    EXPECT_THAT(dequantize(0),
+                ElementsAreFloat({m * (0x00 - 128), m * (0xFF - 128),
+                                  m * (0x09 - 128), m * (0x00 - 128)}));
+  }
+
+  {
+    const float m = 1.0 / num_sparse_features * scales[1];
+    EXPECT_THAT(dequantize(1),
+                ElementsAreFloat({m * (0xFF - 128), m * (0x09 - 128),
+                                  m * (0x00 - 128), m * (0xFF - 128)}));
+  }
+
+  // Bucket 2 has a negative scale, which buckets 0 and 1 do not exercise.
+  {
+    const float m = 1.0 / num_sparse_features * scales[2];
+    EXPECT_THAT(dequantize(2),
+                ElementsAreFloat({m * (0x09 - 128), m * (0x00 - 128),
+                                  m * (0xFF - 128), m * (0x09 - 128)}));
+  }
+}
+
+TEST(QuantizationTest, DequantizeAdd1bitZeros) {
+  const int bytes_per_embedding = 4;
+  const int num_buckets = 3;
+  const int num_sparse_features = 7;
+  const int quantization_bits = 1;
+  const int bucket_id = 1;
+  // Number of 1-bit values packed into one bucket's bytes (4 * 8 = 32).
+  const int embedding_size = bytes_per_embedding * 8 / quantization_bits;
+
+  std::vector<float> scales(num_buckets, 1.0);
+  std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0);
+
+  std::vector<float> dest(embedding_size);
+  DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
+                num_sparse_features, quantization_bits, bucket_id, dest.data(),
+                dest.size());
+
+  // All bits are 0 and 1-bit values are offset by 1, so each output is -1/7.
+  const std::vector<float> expected(embedding_size,
+                                    1.0 / num_sparse_features * (0 - 1));
+  EXPECT_THAT(dest, ElementsAreFloat(expected));
+}
+
+TEST(QuantizationTest, DequantizeAdd1bitOnes) {
+  const int bytes_per_embedding = 4;
+  const int num_buckets = 3;
+  const int num_sparse_features = 7;
+  const int quantization_bits = 1;
+  const int bucket_id = 1;
+  const int embedding_size = bytes_per_embedding * 8 / quantization_bits;
+
+  std::vector<float> scales(num_buckets, 1.0);
+  std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0xFF);
+  std::vector<float> dest(embedding_size);
+  DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
+                num_sparse_features, quantization_bits, bucket_id, dest.data(),
+                dest.size());
+  // All bits are 1 and 1-bit values are offset by 1, so each output is 0.
+  const std::vector<float> expected(embedding_size,
+                                    1.0 / num_sparse_features * (1 - 1));
+  EXPECT_THAT(dest, ElementsAreFloat(expected));
+}
+
+TEST(QuantizationTest, DequantizeAdd3bit) {
+  const int bytes_per_embedding = 4;
+  const int num_buckets = 3;
+  const int num_sparse_features = 7;
+  const int quantization_bits = 3;
+  const int bucket_id = 1;
+
+  std::vector<float> scales(num_buckets, 1.0);
+  scales[bucket_id] = 9.0;
+  std::vector<uint8> embeddings(bytes_per_embedding * num_buckets, 0);
+  // Bucket 1 (bytes 4..7) packs 3-bit values LSB-first from byte 4; the bits
+  // set below encode 1, 2, 3, 4, 5, 6, 7, 0, 0, 0 (byte 7 stays zero).
+  embeddings[4] = (1 << 7) | (1 << 6) | (1 << 4) | 1;
+  embeddings[5] = (1 << 6) | (1 << 4) | (1 << 3);
+  embeddings[6] = (1 << 4) | (1 << 3) | (1 << 2) | (1 << 1) | 1;
+
+  // 4 bytes = 32 bits hold 32 / 3 = 10 whole 3-bit values.
+  std::vector<float> dest(bytes_per_embedding * 8 / quantization_bits);
+  DequantizeAdd(scales.data(), embeddings.data(), bytes_per_embedding,
+                num_sparse_features, quantization_bits, bucket_id, dest.data(),
+                dest.size());
+
+  // Each 3-bit value v is expected to dequantize to
+  //   (v - 4) * scales[bucket_id] / num_sparse_features,
+  // i.e. values are centered on 4 = 2^(quantization_bits - 1).
+  const std::vector<int> values{1, 2, 3, 4, 5, 6, 7, 0, 0, 0};
+  std::vector<float> expected;
+  for (const int value : values) {
+    expected.push_back(1.0 / num_sparse_features * (value - 4) *
+                       scales[bucket_id]);
+  }
+  EXPECT_THAT(dest, ElementsAreFloat(expected));
+}
+
+} // namespace
+} // namespace libtextclassifier3