Export libtextclassifier to Android (generated by the export script)
Test: Compile and boot
Change-Id: I0433e6fb549ba0b32bc55933b3c11562e61a0b4d
diff --git a/utils/hash/farmhash.h b/utils/hash/farmhash.h
index 4c9d2fe..f374c0b 100644
--- a/utils/hash/farmhash.h
+++ b/utils/hash/farmhash.h
@@ -24,7 +24,7 @@
#include <utility>
#ifndef NAMESPACE_FOR_HASH_FUNCTIONS
-#define NAMESPACE_FOR_HASH_FUNCTIONS tc2farmhash
+#define NAMESPACE_FOR_HASH_FUNCTIONS tc3farmhash
#endif
namespace NAMESPACE_FOR_HASH_FUNCTIONS {
diff --git a/utils/java/jni-base.cc b/utils/java/jni-base.cc
index 8073c5a..330732c 100644
--- a/utils/java/jni-base.cc
+++ b/utils/java/jni-base.cc
@@ -36,21 +36,14 @@
return result;
}
-jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) {
- // Get system-level file descriptor from AssetFileDescriptor.
- ScopedLocalRef<jclass> afd_class(
- env->FindClass("android/content/res/AssetFileDescriptor"), env);
- if (afd_class == nullptr) {
- TC3_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
- return reinterpret_cast<jlong>(nullptr);
- }
- jmethodID afd_class_getFileDescriptor = env->GetMethodID(
- afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
- if (afd_class_getFileDescriptor == nullptr) {
- TC3_LOG(ERROR) << "Couldn't find getFileDescriptor.";
- return reinterpret_cast<jlong>(nullptr);
- }
-
+// Returns the system-level descriptor held by a java.io.FileDescriptor, or 0
+// if it cannot be read.
+jint GetFdFromFileDescriptor(JNIEnv* env, jobject fd) {
+ // GetIntField on a null object is invalid; treat it like the other lookup
+ // failures below.
+ if (fd == nullptr) {
+ TC3_LOG(ERROR) << "FileDescriptor is null.";
+ return reinterpret_cast<jlong>(nullptr);
+ }
ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"),
env);
if (fd_class == nullptr) {
@@ -63,9 +56,30 @@
TC3_LOG(ERROR) << "Couldn't find descriptor.";
return reinterpret_cast<jlong>(nullptr);
}
+ return env->GetIntField(fd, fd_class_descriptor);
+}
+// Returns the system-level descriptor behind an AssetFileDescriptor, or 0 if
+// it cannot be obtained.
+jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) {
+ ScopedLocalRef<jclass> afd_class(
+ env->FindClass("android/content/res/AssetFileDescriptor"), env);
+ if (afd_class == nullptr) {
+ TC3_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
+ return reinterpret_cast<jlong>(nullptr);
+ }
+ jmethodID afd_class_getFileDescriptor = env->GetMethodID(
+ afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
+ if (afd_class_getFileDescriptor == nullptr) {
+ TC3_LOG(ERROR) << "Couldn't find getFileDescriptor.";
+ return reinterpret_cast<jlong>(nullptr);
+ }
jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor);
- return env->GetIntField(bundle_jfd, fd_class_descriptor);
+ // bundle_jfd may be null (e.g. if the Java call threw), which
+ // GetFdFromFileDescriptor handles. Release the local reference either way.
+ const jint fd = GetFdFromFileDescriptor(env, bundle_jfd);
+ env->DeleteLocalRef(bundle_jfd);
+ return fd;
}
} // namespace libtextclassifier3
diff --git a/utils/java/jni-base.h b/utils/java/jni-base.h
index 147ae08..23658a3 100644
--- a/utils/java/jni-base.h
+++ b/utils/java/jni-base.h
@@ -72,7 +72,12 @@
}
std::string ToStlString(JNIEnv* env, const jstring& str);
+
+// Get system-level file descriptor from AssetFileDescriptor.
+// Returns 0 when a required Java class or method lookup fails.
+// NOTE(review): 0 is also a valid descriptor (stdin), so callers cannot
+// distinguish this failure value from a real descriptor.
jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd);
+
+// Get system-level file descriptor from FileDescriptor by reading an int field
+// of the given object. Returns 0 when a required class or field lookup fails.
+jint GetFdFromFileDescriptor(JNIEnv* env, jobject fd);
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_
diff --git a/utils/java/string_utils.cc b/utils/java/string_utils.cc
index 6d4a5d7..ef865a8 100644
--- a/utils/java/string_utils.cc
+++ b/utils/java/string_utils.cc
@@ -20,6 +20,29 @@
namespace libtextclassifier3 {
+// Copies the contents of a Java byte array into *result.
+// Returns false (leaving *result untouched) if the array is null or its
+// elements cannot be accessed.
+bool JByteArrayToString(JNIEnv* env, const jbyteArray& array,
+ std::string* result) {
+ if (array == nullptr) {
+ return false;
+ }
+ // The second argument is the optional jboolean* isCopy out-parameter.
+ jbyte* const array_bytes = env->GetByteArrayElements(array, nullptr);
+ if (array_bytes == nullptr) {
+ return false;
+ }
+
+ const int array_length = env->GetArrayLength(array);
+ *result = std::string(reinterpret_cast<char*>(array_bytes), array_length);
+
+ // JNI_ABORT: the bytes were only read, so nothing needs to be copied back.
+ env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
+
+ return true;
+}
+
bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
std::string* result) {
if (jstr == nullptr) {
@@ -37,16 +60,20 @@
env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
jstring encoding = env->NewStringUTF("UTF-8");
+
jbyteArray array = reinterpret_cast<jbyteArray>(
env->CallObjectMethod(jstr, get_bytes_id, encoding));
- jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE);
- int length = env->GetArrayLength(array);
-
- *result = std::string(reinterpret_cast<char*>(array_bytes), length);
+ if (!JByteArrayToString(env, array, result)) {
+ // getBytes() failed or its bytes were inaccessible: report the failure
+ // instead of silently leaving *result unchanged.
+ env->DeleteLocalRef(array);
+ env->DeleteLocalRef(string_class);
+ env->DeleteLocalRef(encoding);
+ return false;
+ }
// Release the array.
- env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
env->DeleteLocalRef(array);
env->DeleteLocalRef(string_class);
env->DeleteLocalRef(encoding);
diff --git a/utils/java/string_utils.h b/utils/java/string_utils.h
index c4fd97a..e4f2bd8 100644
--- a/utils/java/string_utils.h
+++ b/utils/java/string_utils.h
@@ -22,6 +22,8 @@
namespace libtextclassifier3 {
+// Copies the bytes of a Java byte array into *result. Returns false if the
+// array elements could not be obtained from the JVM.
+bool JByteArrayToString(JNIEnv* env, const jbyteArray& array,
+ std::string* result);
bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, std::string* result);
} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/double_array_trie.h b/utils/sentencepiece/double_array_trie.h
index a0ad4a4..050c466 100644
--- a/utils/sentencepiece/double_array_trie.h
+++ b/utils/sentencepiece/double_array_trie.h
@@ -20,7 +20,7 @@
#include <functional>
#include <vector>
-#include "utils/sentencepiece/match.h"
+#include "utils/sentencepiece/matcher.h"
#include "utils/strings/stringpiece.h"
namespace libtextclassifier3 {
@@ -35,17 +35,17 @@
typedef unsigned int TrieNode;
// A memory mappable trie, compatible with Darts::DoubleArray.
-class DoubleArrayTrie {
+class DoubleArrayTrie : public SentencePieceMatcher {
public:
// nodes and nodes_length specify the array of the nodes of the trie.
DoubleArrayTrie(const TrieNode* nodes, const int nodes_length)
: nodes_(nodes), nodes_length_(nodes_length) {}
// Find matches that are prefixes of a string.
- std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const;
+ std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const override;
// Find the longest prefix match of a string.
- TrieMatch LongestPrefixMatch(StringPiece input) const;
+ TrieMatch LongestPrefixMatch(StringPiece input) const override;
private:
// Returns whether a node as a leaf as a child.
diff --git a/utils/sentencepiece/encoder.cc b/utils/sentencepiece/encoder.cc
new file mode 100644
index 0000000..96fb868
--- /dev/null
+++ b/utils/sentencepiece/encoder.cc
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/encoder.h"
+
+namespace libtextclassifier3 {
+
+std::vector<int> Encoder::Encode(StringPiece normalized_text) const {
+ const int len = normalized_text.size();
+ if (len <= 0) {
+ return {start_code_, end_code_};
+ }
+ // We use `previous_pos` to indicate whether a dynamic programming state was
+ // reachable.
+ std::vector<SegmentationEntry> segmentation(
+ len + 1, {/*score=*/0, /*previous_pos=*/-1, /*piece_id=*/-1,
+ /*num_pieces=*/0});
+ for (int i = 0; i < len; i++) {
+ // State couldn't be reached.
+ if (i > 0 && segmentation[i].previous_pos < 0) {
+ // Advance position.
+ normalized_text.RemovePrefix(1);
+ continue;
+ }
+ for (const auto& match : matcher_->FindAllPrefixMatches(normalized_text)) {
+ TC3_CHECK(match.id >= 0 && match.id < num_pieces_);
+ const int pos = i + match.match_length;
+ // Each match must consume at least one character and end inside the input:
+ // a negative (e.g. TrieMatch's default of -1) or overlong length would index
+ // `segmentation` out of bounds, and an empty one would make a state its own
+ // predecessor.
+ TC3_CHECK(match.match_length > 0 && pos <= len);
+ const float candidate_score = segmentation[i].score + scores_[match.id];
+ if (segmentation[pos].previous_pos < 0 ||
+ segmentation[pos].score < candidate_score) {
+ segmentation[pos] = {/*score=*/candidate_score, /*previous_pos=*/i,
+ /*piece_id=*/match.id,
+ /*num_pieces=*/segmentation[i].num_pieces + 1};
+ }
+ }
+ // Advance position.
+ normalized_text.RemovePrefix(1);
+ }
+ if (segmentation[len].num_pieces <= 0) {
+ return {start_code_, end_code_};
+ }
+ const int num_pieces = segmentation[len].num_pieces;
+ std::vector<int> result(num_pieces + 2);
+ result[num_pieces + 1] = end_code_;
+ int pos = len;
+ for (int i = num_pieces; i > 0; i--) {
+ result[i] = segmentation[pos].piece_id + encoding_offset_;
+ pos = segmentation[pos].previous_pos;
+ }
+ result[0] = start_code_;
+ return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/utils/sentencepiece/encoder.h b/utils/sentencepiece/encoder.h
index 4aa7582..fffd86f 100644
--- a/utils/sentencepiece/encoder.h
+++ b/utils/sentencepiece/encoder.h
@@ -17,19 +17,16 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
-#include <memory>
-#include <string>
#include <vector>
#include "utils/base/logging.h"
-#include "utils/sentencepiece/double_array_trie.h"
+#include "utils/sentencepiece/matcher.h"
#include "utils/strings/stringpiece.h"
namespace libtextclassifier3 {
// Encoder to segment/tokenize strings into pieces such that the sum of the
// scores of the pieces used is maximized.
-template <typename TMatcher = DoubleArrayTrie>
class Encoder {
public:
// matcher: the list of valid sentence pieces represented as a matcher, e.g.
@@ -40,7 +37,7 @@
// end_code: Code that is used as encoding of the end of input.
// encoding_offset: Value added to the sentence piece ids to make them
// not interesecting with start_code and end_code.
- Encoder(const TMatcher& matcher, const int num_pieces,
+ // The matcher is referenced through a raw pointer and is not copied, so it
+ // must outlive this Encoder.
+ Encoder(const SentencePieceMatcher* matcher, const int num_pieces,
const float* pieces_scores, int start_code = 0, int end_code = 1,
int encoding_offset = 2)
: num_pieces_(num_pieces),
@@ -49,6 +46,10 @@
start_code_(start_code),
end_code_(end_code),
encoding_offset_(encoding_offset) {}
+
+ // Segment the input so that the total score of the pieces used is maximized.
+ // This is a simplified implementation of the general Viterbi algorithm,
+ // assuming independence between individual pieces.
std::vector<int> Encode(StringPiece normalized_text) const;
private:
@@ -69,62 +70,12 @@
const int num_pieces_;
const float* scores_;
- TMatcher matcher_;
+ // Not owned; must outlive this Encoder.
+ const SentencePieceMatcher* matcher_;
const int start_code_;
const int end_code_;
const int encoding_offset_;
};
-// Segment the input such that the total score of the pieces used is maximized.
-// This is a simplified implementation of the general Viterbi algorithm,
-// assuming independence between individual pieces.
-template <typename TMatcher>
-std::vector<int> Encoder<TMatcher>::Encode(StringPiece normalized_text) const {
- const int len = normalized_text.size();
- if (len <= 0) {
- return {start_code_, end_code_};
- }
- // We use `previous_pos` to indicate whether a dynamic programming state was
- // reachable.
- std::vector<SegmentationEntry> segmentation(
- len + 1, {/*score=*/0, /*previous_pos=*/-1, /*piece_id=*/-1,
- /*num_pieces=*/0});
- for (int i = 0; i < len; i++) {
- // State couldn't be reached.
- if (i > 0 && segmentation[i].previous_pos < 0) {
- // Advance position.
- normalized_text.RemovePrefix(1);
- continue;
- }
- for (const auto& match : matcher_.FindAllPrefixMatches(normalized_text)) {
- TC3_CHECK(match.id >= 0 && match.id < num_pieces_);
- const int pos = i + match.match_length;
- const float candidate_score = segmentation[i].score + scores_[match.id];
- if (segmentation[pos].previous_pos < 0 ||
- segmentation[pos].score < candidate_score) {
- segmentation[pos] = {/*score=*/candidate_score, /*previous_pos=*/i,
- /*piece_id=*/match.id,
- /*num_pieces=*/segmentation[i].num_pieces + 1};
- }
- }
- // Advance position.
- normalized_text.RemovePrefix(1);
- }
- if (segmentation[len].num_pieces <= 0) {
- return {start_code_, end_code_};
- }
- const int num_pieces = segmentation[len].num_pieces;
- std::vector<int> result(num_pieces + 2);
- result[num_pieces + 1] = end_code_;
- int pos = len;
- for (int i = num_pieces; i > 0; i--) {
- result[i] = segmentation[pos].piece_id + encoding_offset_;
- pos = segmentation[pos].previous_pos;
- }
- result[0] = start_code_;
- return result;
-}
-
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
diff --git a/utils/sentencepiece/encoder_test.cc b/utils/sentencepiece/encoder_test.cc
index 5697758..59c12ad 100644
--- a/utils/sentencepiece/encoder_test.cc
+++ b/utils/sentencepiece/encoder_test.cc
@@ -14,6 +14,7 @@
* limitations under the License.
*/
+#include <memory>
#include <vector>
#include "gmock/gmock.h"
@@ -32,9 +33,10 @@
const char pieces[] = "hell\0hello\0o\0there\0";
const int offsets[] = {0, 5, 11, 13};
float scores[] = {-0.5, -1.0, -10.0, -1.0};
- const Encoder<SortedStringsTable> encoder(
- SortedStringsTable(/*num_pieces=*/4, offsets, StringPiece(pieces, 18)),
- /*num_pieces=*/4, scores);
+ std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
+ /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
+ const Encoder encoder(matcher.get(),
+ /*num_pieces=*/4, scores);
EXPECT_THAT(encoder.Encode("hellothere"), ElementsAreArray({0, 3, 5, 1}));
@@ -48,9 +50,10 @@
const char pieces[] = "hell\0hello\0o\0there\0";
const int offsets[] = {0, 5, 11, 13};
float scores[] = {-0.5, -1.0, -10.0, -1.0};
- const Encoder<SortedStringsTable> encoder(
- SortedStringsTable(/*num_pieces=*/4, offsets, StringPiece(pieces, 18)),
- /*num_pieces=*/4, scores);
+ std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
+ /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
+ const Encoder encoder(matcher.get(),
+ /*num_pieces=*/4, scores);
EXPECT_THAT(encoder.Encode("hellhello"), ElementsAreArray({0, 2, 3, 1}));
EXPECT_THAT(encoder.Encode("hellohell"), ElementsAreArray({0, 3, 2, 1}));
EXPECT_THAT(encoder.Encode(""), ElementsAreArray({0, 1}));
diff --git a/utils/sentencepiece/match.h b/utils/sentencepiece/matcher.h
similarity index 63%
rename from utils/sentencepiece/match.h
rename to utils/sentencepiece/matcher.h
index c1dc475..b538d69 100644
--- a/utils/sentencepiece/match.h
+++ b/utils/sentencepiece/matcher.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCH_H_
-#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCH_H_
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
#include <vector>
#include "utils/strings/stringpiece.h"
@@ -29,6 +29,20 @@
int match_length = -1;
};
+// Interface for data structures that look up sentence pieces by prefix, e.g.
+// DoubleArrayTrie and SortedStringsTable.
+class SentencePieceMatcher {
+ public:
+ virtual ~SentencePieceMatcher() = default;
+
+ // Find matches that are prefixes of a string.
+ virtual std::vector<TrieMatch> FindAllPrefixMatches(
+ StringPiece input) const = 0;
+
+ // Find the longest prefix match of a string.
+ virtual TrieMatch LongestPrefixMatch(StringPiece input) const = 0;
+};
+
} // namespace libtextclassifier3
-#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCH_H_
+#endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
diff --git a/utils/sentencepiece/sorted_strings_table.h b/utils/sentencepiece/sorted_strings_table.h
index da4209d..82cda5c 100644
--- a/utils/sentencepiece/sorted_strings_table.h
+++ b/utils/sentencepiece/sorted_strings_table.h
@@ -20,7 +20,7 @@
#include <functional>
#include <vector>
-#include "utils/sentencepiece/match.h"
+#include "utils/sentencepiece/matcher.h"
#include "utils/strings/stringpiece.h"
namespace libtextclassifier3 {
@@ -34,7 +34,7 @@
// pieces: String pieces, concatenated in sorted order and zero byte separated.
// use_linear_scan_threshold: Minimum size of binary search range before
// switching to a linear sweep for prefix match testing.
-class SortedStringsTable {
+class SortedStringsTable : public SentencePieceMatcher {
public:
SortedStringsTable(const int num_pieces, const int* offsets,
StringPiece pieces,
@@ -45,10 +45,10 @@
use_linear_scan_threshold_(use_linear_scan_threshold) {}
// Find matches that are prefixes of a string.
- std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const;
+ std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const override;
// Find the longest prefix match of a string.
- TrieMatch LongestPrefixMatch(StringPiece input) const;
+ TrieMatch LongestPrefixMatch(StringPiece input) const override;
private:
void GatherPrefixMatches(
diff --git a/utils/tflite/text_encoder.cc b/utils/tflite/text_encoder.cc
index 727991f..9554283 100644
--- a/utils/tflite/text_encoder.cc
+++ b/utils/tflite/text_encoder.cc
@@ -21,6 +21,7 @@
#include "utils/sentencepiece/double_array_trie.h"
#include "utils/sentencepiece/encoder.h"
#include "utils/sentencepiece/normalizer.h"
+#include "utils/sentencepiece/sorted_strings_table.h"
#include "utils/strings/stringpiece.h"
#include "utils/tflite/text_encoder.h"
#include "utils/tflite/text_encoder_config_generated.h"
@@ -35,7 +36,11 @@
struct TextEncoderOp {
std::unique_ptr<Normalizer> normalizer;
- std::unique_ptr<Encoder<DoubleArrayTrie>> encoder;
+ // `encoder` keeps a raw pointer to `matcher`. Members are destroyed in
+ // reverse declaration order, so declaring `matcher` first guarantees that it
+ // outlives `encoder`.
+ std::unique_ptr<SentencePieceMatcher> matcher;
+ std::unique_ptr<Encoder> encoder;
};
// Input parameters for the op.
@@ -49,12 +51,17 @@
// Output parameters for the op.
enum SmartReplyModelOutputs {
TEXT_ENCODER_OUTPUT_ENCODED = 0,
- TEXT_ENCODER_OUTPUT_LENGTHS = 1,
- TEXT_ENCODER_OUTPUT_ATTR = 2,
+ TEXT_ENCODER_OUTPUT_POSITION = 1,
+ TEXT_ENCODER_OUTPUT_LENGTHS = 2,
+ TEXT_ENCODER_OUTPUT_ATTR = 3,
};
const char kTextEncoderConfigAttr[] = "text_encoder_config";
+// Input rank is 2 since there is a dummy batch dimension of 1.
+const int kInputRank = 2;
+const int kBatchSize = 1;
+
// Initializes text encoder object from serialized options:
// The options are a flexbuffers attribute map that contain the op config
// with the key `text_encoder_config` as `TextEncoderConfig`.
@@ -81,15 +88,39 @@
config->add_dummy_prefix(), config->remove_extra_whitespaces(),
config->escape_whitespaces()));
- const TrieNode* pieces_trie_nodes =
- reinterpret_cast<const TrieNode*>(config->pieces()->Data());
- const int pieces_trie_nodes_length =
- config->pieces()->Length() / sizeof(TrieNode);
const int num_pieces = config->pieces_scores()->Length();
- encoder_op->encoder.reset(new Encoder<DoubleArrayTrie>(
- DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length), num_pieces,
- config->pieces_scores()->data(), config->start_code(), config->end_code(),
- config->encoding_offset()));
+
+ switch (config->matcher_type()) {
+ case SentencePieceMatcherType_MAPPED_TRIE: {
+ const TrieNode* pieces_trie_nodes =
+ reinterpret_cast<const TrieNode*>(config->pieces()->Data());
+ const int pieces_trie_nodes_length =
+ config->pieces()->Length() / sizeof(TrieNode);
+ encoder_op->matcher.reset(
+ new DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length));
+ break;
+ }
+ case SentencePieceMatcherType_SORTED_STRING_TABLE: {
+ // The table is built from num_pieces and a raw pointer into the offsets
+ // vector, so that optional vector must exist and hold num_pieces entries.
+ if (config->pieces_offsets() == nullptr ||
+ static_cast<int>(config->pieces_offsets()->Length()) < num_pieces) {
+ TC3_LOG(ERROR) << "Missing or incomplete sentence piece offsets.";
+ return nullptr;
+ }
+ encoder_op->matcher.reset(new SortedStringsTable(
+ num_pieces, config->pieces_offsets()->data(),
+ StringPiece(config->pieces()->data(), config->pieces()->Length())));
+ break;
+ }
+ default: {
+ TC3_LOG(ERROR) << "Unknown sentence piece matcher type.";
+ return nullptr;
+ }
+ }
+ encoder_op->encoder.reset(new Encoder(
+ encoder_op->matcher.get(), num_pieces, config->pieces_scores()->data(),
+ config->start_code(), config->end_code(), config->encoding_offset()));
return encoder_op.release();
}
@@ -112,8 +136,8 @@
const std::vector<int>& encoding_end_offsets,
int start_offset, TfLiteContext* context,
TfLiteTensor* out) {
- TF_LITE_ENSURE_EQ(context, in.dims->size, 2);
- TF_LITE_ENSURE_EQ(context, in.dims->data[0], 1);
+ TF_LITE_ENSURE_EQ(context, in.dims->size, kInputRank);
+ TF_LITE_ENSURE_EQ(context, in.dims->data[0], kBatchSize);
const int output_size = out->dims->data[1];
int output_offset = 0;
for (int value_index = 0;
@@ -162,7 +186,8 @@
(output_offset > 0) ? out->data.f[output_offset - 1] : 0;
std::fill(out->data.f + output_offset, out->data.f + output_size, value);
} break;
- default: {}
+ default:
+ break;
}
return kTfLiteOk;
}
@@ -173,16 +198,26 @@
context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ENCODED]];
TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, &output_encoded,
- CreateSizeArray({1, max_output_length})));
+ context,
+ context->ResizeTensor(context, &output_encoded,
+ CreateSizeArray({kBatchSize, max_output_length})));
+
+ // Positions hold one entry per encoded code, so they share its shape.
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_POSITION]];
+
+ TF_LITE_ENSURE_OK(
+ context,
+ context->ResizeTensor(context, &output_positions,
+ CreateSizeArray({kBatchSize, max_output_length})));
const int num_output_attrs = node->outputs->size - TEXT_ENCODER_OUTPUT_ATTR;
for (int i = 0; i < num_output_attrs; ++i) {
TfLiteTensor& output =
context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ATTR + i]];
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(
- context, &output,
- CreateSizeArray({1, max_output_length})));
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(
+ context, &output,
+ CreateSizeArray({kBatchSize, max_output_length})));
}
return kTfLiteOk;
}
@@ -190,19 +225,22 @@
} // namespace
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- // Check that the batch dimension is 1.
+ // Check that the batch dimension is kBatchSize.
const TfLiteTensor& input_text =
context->tensors[node->inputs->data[TEXT_ENCODER_INPUT_TEXTS]];
- TF_LITE_ENSURE_EQ(context, input_text.dims->size, 2);
- TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], 1);
+ TF_LITE_ENSURE_EQ(context, input_text.dims->size, kInputRank);
+ TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kBatchSize);
TfLiteTensor& output_lengths =
context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_LENGTHS]];
TfLiteTensor& output_encoded =
context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ENCODED]];
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_POSITION]];
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, &output_lengths,
- CreateSizeArray({1})));
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, &output_lengths,
+ CreateSizeArray({kBatchSize})));
// Check that there are enough outputs for attributes.
const int num_output_attrs = node->outputs->size - TEXT_ENCODER_OUTPUT_ATTR;
@@ -225,6 +263,7 @@
return ResizeOutputTensors(context, node, output_length.data.i64[0]);
} else {
tflite::SetTensorToDynamic(&output_encoded);
+ tflite::SetTensorToDynamic(&output_positions);
for (int i = 0; i < num_output_attrs; ++i) {
TfLiteTensor& output_attr =
context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ATTR + i]];
@@ -253,10 +292,15 @@
TF_LITE_ENSURE_OK(
context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
}
+ TfLiteTensor& output_positions =
+ context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_POSITION]];
std::vector<int> encoded_total;
std::vector<int> encoded_offsets;
+ std::vector<int> encoded_positions;
encoded_offsets.reserve(num_strings);
+ const int max_output_length = output_encoded.dims->data[1];
+ const int max_encoded_position = max_output_length;
for (int i = 0; i < num_strings; ++i) {
const auto& strref = tflite::GetString(&input_text, i);
@@ -264,16 +308,23 @@
encoder_op->normalizer->Normalize(StringPiece(strref.str, strref.len)));
encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
encoded_offsets.push_back(encoded_total.size());
+ // Position of each code within this string's encoding, capped one below
+ // max_encoded_position, which is the value used for padding.
+ const int num_encoded = static_cast<int>(encoded.size());
+ for (int position = 0; position < num_encoded; ++position) {
+ encoded_positions.push_back(std::min(position, max_encoded_position - 1));
+ }
}
- const int max_output_length = output_encoded.dims->data[1];
// Copy encoding to output tensor.
const int start_offset =
std::max(0, static_cast<int>(encoded_total.size()) - max_output_length);
int output_offset = 0;
int32_t* output_buffer = output_encoded.data.i32;
+ int32_t* output_positions_buffer = output_positions.data.i32;
for (int i = start_offset; i < encoded_total.size(); ++i, ++output_offset) {
output_buffer[output_offset] = encoded_total[i];
+ output_positions_buffer[output_offset] = encoded_positions[i];
}
// Save output encoded length.
@@ -284,6 +335,7 @@
// Do padding.
for (; output_offset < max_output_length; ++output_offset) {
output_buffer[output_offset] = encoded_total.back();
+ output_positions_buffer[output_offset] = max_encoded_position;
}
// Process attributes, all checks of sizes and types are done in Prepare.
diff --git a/utils/tflite/text_encoder_config.fbs b/utils/tflite/text_encoder_config.fbs
index 9017116..462da21 100644
--- a/utils/tflite/text_encoder_config.fbs
+++ b/utils/tflite/text_encoder_config.fbs
@@ -17,6 +17,12 @@
// Configuration for the text encoder op.
namespace libtextclassifier3;
+
+// Data structure used to look up sentence pieces; selects how `pieces` is
+// interpreted.
+enum SentencePieceMatcherType : byte {
+ // `pieces` holds serialized double array trie nodes.
+ MAPPED_TRIE = 0,
+ // `pieces` holds the pieces concatenated in sorted order and zero byte
+ // separated, with their start offsets in `pieces_offsets`.
+ SORTED_STRING_TABLE = 1,
+}
+
table TextEncoderConfig {
// Code that is used as encoding of the start code.
start_code:int32 = 0;
@@ -46,6 +52,8 @@
// Sentence pieces scores.
pieces_scores:[float];
- // Serialized sentence pieces trie.
+ // Serialized sentence pieces; the format depends on `matcher_type`.
pieces:string;
+ // Start offset of each piece in `pieces`; only used by SORTED_STRING_TABLE.
+ pieces_offsets:[int32];
+ // Matcher used to interpret `pieces`.
+ matcher_type: SentencePieceMatcherType = MAPPED_TRIE;
}
diff --git a/utils/tflite/text_encoder_test.cc b/utils/tflite/text_encoder_test.cc
index d1892c7..0b6ff71 100644
--- a/utils/tflite/text_encoder_test.cc
+++ b/utils/tflite/text_encoder_test.cc
@@ -55,6 +55,9 @@
std::vector<int> GetOutputEncoding() {
return ExtractVector<int>(output_encoding_);
}
+ std::vector<int> GetOutputPositions() {
+ return ExtractVector<int>(output_positions_);
+ }
std::vector<int> GetOutputAttributeInt32() {
return ExtractVector<int>(output_attributes_int32_);
}
@@ -71,6 +74,7 @@
int input_attributes_float_;
int output_encoding_;
+ int output_positions_;
int output_length_;
int output_attributes_int32_;
int output_attributes_float_;
@@ -86,6 +90,7 @@
input_attributes_float_ = AddInput(tflite::TensorType_FLOAT32);
output_encoding_ = AddOutput(tflite::TensorType_INT32);
+ output_positions_ = AddOutput(tflite::TensorType_INT32);
output_length_ = AddOutput(tflite::TensorType_INT32);
output_attributes_int32_ = AddOutput(tflite::TensorType_INT32);
output_attributes_float_ = AddOutput(tflite::TensorType_FLOAT32);
@@ -113,6 +118,8 @@
EXPECT_EQ(m.GetEncodedLength(), 5);
EXPECT_THAT(m.GetOutputEncoding(),
testing::ElementsAre(1, 90, 547, 58, 2, 2, 2, 2, 2, 2));
+ EXPECT_THAT(m.GetOutputPositions(),
+ testing::ElementsAre(0, 1, 2, 3, 4, 10, 10, 10, 10, 10));
EXPECT_THAT(m.GetOutputAttributeInt32(),
testing::ElementsAre(7, 7, 7, 7, 7, 7, 7, 7, 7, 7));
EXPECT_THAT(
@@ -130,6 +137,8 @@
EXPECT_EQ(m.GetEncodedLength(), 10);
EXPECT_THAT(m.GetOutputEncoding(),
testing::ElementsAre(547, 58, 2, 1, 862, 2, 1, 1919, 19, 2));
+ EXPECT_THAT(m.GetOutputPositions(),
+ testing::ElementsAre(2, 3, 4, 0, 1, 2, 0, 1, 2, 3));
EXPECT_THAT(m.GetOutputAttributeInt32(),
testing::ElementsAre(1, 1, 1, 2, 2, 2, 3, 3, 3, 3));
EXPECT_THAT(
@@ -147,6 +156,8 @@
EXPECT_EQ(m.GetEncodedLength(), 9);
EXPECT_THAT(m.GetOutputEncoding(),
testing::ElementsAre(862, 2, 1, 1919, 19, 2, 1, 862, 2));
+ EXPECT_THAT(m.GetOutputPositions(),
+ testing::ElementsAre(1, 2, 0, 1, 2, 3, 0, 1, 2));
EXPECT_THAT(m.GetOutputAttributeInt32(),
testing::ElementsAre(2, 2, 3, 3, 3, 3, 4, 4, 4));
EXPECT_THAT(
diff --git a/utils/utf8/unicodetext.cc b/utils/utf8/unicodetext.cc
index 057703a..81492d8 100644
--- a/utils/utf8/unicodetext.cc
+++ b/utils/utf8/unicodetext.cc
@@ -176,7 +176,7 @@
} // namespace
-UnicodeText& UnicodeText::AppendCodepoint(char32 ch) {
+UnicodeText& UnicodeText::push_back(char32 ch) {
char str[4];
int char_len = runetochar(ch, str);
repr_.append(str, char_len);
diff --git a/utils/utf8/unicodetext.h b/utils/utf8/unicodetext.h
index f7790f5..eb206b8 100644
--- a/utils/utf8/unicodetext.h
+++ b/utils/utf8/unicodetext.h
@@ -168,7 +168,7 @@
// Calling this may invalidate pointers to underlying data.
UnicodeText& AppendUTF8(const char* utf8, int len);
- UnicodeText& AppendCodepoint(char32 ch);
+ // Appends the UTF-8 encoding of the codepoint `ch`.
+ UnicodeText& push_back(char32 ch);
void clear();
std::string ToUTF8String() const;
diff --git a/utils/utf8/unicodetext_test.cc b/utils/utf8/unicodetext_test.cc
index 9cdc850..7ebb415 100644
--- a/utils/utf8/unicodetext_test.cc
+++ b/utils/utf8/unicodetext_test.cc
@@ -24,11 +24,11 @@
class UnicodeTextTest : public testing::Test {
protected:
UnicodeTextTest() : empty_text_() {
- text_.AppendCodepoint(0x1C0);
- text_.AppendCodepoint(0x4E8C);
- text_.AppendCodepoint(0xD7DB);
- text_.AppendCodepoint(0x34);
- text_.AppendCodepoint(0x1D11E);
+ text_.push_back(0x1C0);
+ text_.push_back(0x4E8C);
+ text_.push_back(0xD7DB);
+ text_.push_back(0x34);
+ text_.push_back(0x1D11E);
}
UnicodeText empty_text_;