Export libtextclassifier to Android (generated by the export script)

Test: Compile and boot

Change-Id: I0433e6fb549ba0b32bc55933b3c11562e61a0b4d
diff --git a/utils/hash/farmhash.h b/utils/hash/farmhash.h
index 4c9d2fe..f374c0b 100644
--- a/utils/hash/farmhash.h
+++ b/utils/hash/farmhash.h
@@ -24,7 +24,7 @@
 #include <utility>
 
 #ifndef NAMESPACE_FOR_HASH_FUNCTIONS
-#define NAMESPACE_FOR_HASH_FUNCTIONS tc2farmhash
+#define NAMESPACE_FOR_HASH_FUNCTIONS tc3farmhash
 #endif
 
 namespace NAMESPACE_FOR_HASH_FUNCTIONS {
diff --git a/utils/java/jni-base.cc b/utils/java/jni-base.cc
index 8073c5a..330732c 100644
--- a/utils/java/jni-base.cc
+++ b/utils/java/jni-base.cc
@@ -36,21 +36,7 @@
   return result;
 }
 
-jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) {
-  // Get system-level file descriptor from AssetFileDescriptor.
-  ScopedLocalRef<jclass> afd_class(
-      env->FindClass("android/content/res/AssetFileDescriptor"), env);
-  if (afd_class == nullptr) {
-    TC3_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
-    return reinterpret_cast<jlong>(nullptr);
-  }
-  jmethodID afd_class_getFileDescriptor = env->GetMethodID(
-      afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
-  if (afd_class_getFileDescriptor == nullptr) {
-    TC3_LOG(ERROR) << "Couldn't find getFileDescriptor.";
-    return reinterpret_cast<jlong>(nullptr);
-  }
-
+jint GetFdFromFileDescriptor(JNIEnv* env, jobject fd) {
   ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"),
                                   env);
   if (fd_class == nullptr) {
@@ -63,9 +49,30 @@
     TC3_LOG(ERROR) << "Couldn't find descriptor.";
     return reinterpret_cast<jlong>(nullptr);
   }
+  return env->GetIntField(fd, fd_class_descriptor);
+}

+jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) {
+  ScopedLocalRef<jclass> afd_class(
+      env->FindClass("android/content/res/AssetFileDescriptor"), env);
+  if (afd_class == nullptr) {
+    TC3_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
+    return reinterpret_cast<jlong>(nullptr);
+  }
+  jmethodID afd_class_getFileDescriptor = env->GetMethodID(
+      afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
+  if (afd_class_getFileDescriptor == nullptr) {
+    TC3_LOG(ERROR) << "Couldn't find getFileDescriptor.";
+    return reinterpret_cast<jlong>(nullptr);
+  }
   jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor);
-  return env->GetIntField(bundle_jfd, fd_class_descriptor);
+  // CallObjectMethod yields null if getFileDescriptor() threw or returned null.
+  if (bundle_jfd == nullptr) {
+    TC3_LOG(ERROR) << "Couldn't get FileDescriptor.";
+    return reinterpret_cast<jlong>(nullptr);
+  }
+  const jint fd = GetFdFromFileDescriptor(env, bundle_jfd);
+  env->DeleteLocalRef(bundle_jfd);  // Local ref is no longer needed.
+  return fd;
 }
 
 }  // namespace libtextclassifier3
diff --git a/utils/java/jni-base.h b/utils/java/jni-base.h
index 147ae08..23658a3 100644
--- a/utils/java/jni-base.h
+++ b/utils/java/jni-base.h
@@ -72,7 +72,12 @@
 }
 
 std::string ToStlString(JNIEnv* env, const jstring& str);
+
+// Get system-level file descriptor from AssetFileDescriptor.
 jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd);
+
+// Get system-level file descriptor from FileDescriptor.
+jint GetFdFromFileDescriptor(JNIEnv* env, jobject fd);
 }  // namespace libtextclassifier3
 
 #endif  // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_
diff --git a/utils/java/string_utils.cc b/utils/java/string_utils.cc
index 6d4a5d7..ef865a8 100644
--- a/utils/java/string_utils.cc
+++ b/utils/java/string_utils.cc
@@ -20,6 +20,21 @@
 
 namespace libtextclassifier3 {
 
+bool JByteArrayToString(JNIEnv* env, const jbyteArray& array,
+                        std::string* result) {
+  // Reading elements of a null jbyteArray is undefined in JNI; report failure.
+  if (array == nullptr) return false;
+  jbyte* const array_bytes = env->GetByteArrayElements(array, nullptr);
+  if (array_bytes == nullptr) return false;
+
+  const int array_length = env->GetArrayLength(array);
+  *result = std::string(reinterpret_cast<char*>(array_bytes), array_length);
+
+  // JNI_ABORT: release without copying back, the bytes were only read.
+  env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
+  return true;
+}
+
 bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
                          std::string* result) {
   if (jstr == nullptr) {
@@ -37,16 +52,13 @@
       env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
 
   jstring encoding = env->NewStringUTF("UTF-8");
+
   jbyteArray array = reinterpret_cast<jbyteArray>(
       env->CallObjectMethod(jstr, get_bytes_id, encoding));
 
-  jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE);
-  int length = env->GetArrayLength(array);
-
-  *result = std::string(reinterpret_cast<char*>(array_bytes), length);
+  JByteArrayToString(env, array, result);
 
   // Release the array.
-  env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
   env->DeleteLocalRef(array);
   env->DeleteLocalRef(string_class);
   env->DeleteLocalRef(encoding);
diff --git a/utils/java/string_utils.h b/utils/java/string_utils.h
index c4fd97a..e4f2bd8 100644
--- a/utils/java/string_utils.h
+++ b/utils/java/string_utils.h
@@ -22,6 +22,8 @@
 
 namespace libtextclassifier3 {
 
+bool JByteArrayToString(JNIEnv* env, const jbyteArray& array,
+                        std::string* result);
 bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, std::string* result);
 
 }  // namespace libtextclassifier3
diff --git a/utils/sentencepiece/double_array_trie.h b/utils/sentencepiece/double_array_trie.h
index a0ad4a4..050c466 100644
--- a/utils/sentencepiece/double_array_trie.h
+++ b/utils/sentencepiece/double_array_trie.h
@@ -20,7 +20,7 @@
 #include <functional>
 #include <vector>
 
-#include "utils/sentencepiece/match.h"
+#include "utils/sentencepiece/matcher.h"
 #include "utils/strings/stringpiece.h"
 
 namespace libtextclassifier3 {
@@ -35,17 +35,17 @@
 typedef unsigned int TrieNode;
 
 // A memory mappable trie, compatible with Darts::DoubleArray.
-class DoubleArrayTrie {
+class DoubleArrayTrie : public SentencePieceMatcher {
  public:
   // nodes and nodes_length specify the array of the nodes of the trie.
   DoubleArrayTrie(const TrieNode* nodes, const int nodes_length)
       : nodes_(nodes), nodes_length_(nodes_length) {}
 
   // Find matches that are prefixes of a string.
-  std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const;
+  std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const override;
 
   // Find the longest prefix match of a string.
-  TrieMatch LongestPrefixMatch(StringPiece input) const;
+  TrieMatch LongestPrefixMatch(StringPiece input) const override;
 
  private:
   // Returns whether a node as a leaf as a child.
diff --git a/utils/sentencepiece/encoder.cc b/utils/sentencepiece/encoder.cc
new file mode 100644
index 0000000..96fb868
--- /dev/null
+++ b/utils/sentencepiece/encoder.cc
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/sentencepiece/encoder.h"
+
+namespace libtextclassifier3 {
+
+// Segments the input into pieces maximizing the total piece score with a
+// simplified Viterbi pass: best[p] holds the highest-scoring segmentation of
+// the first p bytes of the input.
+std::vector<int> Encoder::Encode(StringPiece normalized_text) const {
+  const int text_length = normalized_text.size();
+  if (text_length <= 0) {
+    return {start_code_, end_code_};
+  }
+  // A negative `previous_pos` marks a state no segmentation has reached.
+  std::vector<SegmentationEntry> best(
+      text_length + 1, {/*score=*/0, /*previous_pos=*/-1, /*piece_id=*/-1,
+                        /*num_pieces=*/0});
+  for (int start = 0; start < text_length;
+       ++start, normalized_text.RemovePrefix(1)) {
+    // Position 0 is always reachable; later ones only through some piece.
+    const bool reachable = start == 0 || best[start].previous_pos >= 0;
+    if (!reachable) {
+      continue;
+    }
+    for (const auto& match : matcher_->FindAllPrefixMatches(normalized_text)) {
+      TC3_CHECK(match.id >= 0 && match.id < num_pieces_);
+      const int end = start + match.match_length;
+      const float score = best[start].score + scores_[match.id];
+      if (best[end].previous_pos < 0 || best[end].score < score) {
+        best[end] = {/*score=*/score, /*previous_pos=*/start,
+                     /*piece_id=*/match.id,
+                     /*num_pieces=*/best[start].num_pieces + 1};
+      }
+    }
+  }
+  const int piece_count = best[text_length].num_pieces;
+  if (piece_count <= 0) {
+    return {start_code_, end_code_};
+  }
+  // Frame the pieces with the start/end codes, then fill the piece codes
+  // right to left by following `previous_pos` back from the end.
+  std::vector<int> codes(piece_count + 2);
+  codes.front() = start_code_;
+  codes.back() = end_code_;
+  for (int slot = piece_count, pos = text_length; slot > 0; --slot) {
+    codes[slot] = best[pos].piece_id + encoding_offset_;
+    pos = best[pos].previous_pos;
+  }
+  return codes;
+}
+
+}  // namespace libtextclassifier3
diff --git a/utils/sentencepiece/encoder.h b/utils/sentencepiece/encoder.h
index 4aa7582..fffd86f 100644
--- a/utils/sentencepiece/encoder.h
+++ b/utils/sentencepiece/encoder.h
@@ -17,19 +17,16 @@
 #ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
 #define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
 
-#include <memory>
-#include <string>
 #include <vector>
 
 #include "utils/base/logging.h"
-#include "utils/sentencepiece/double_array_trie.h"
+#include "utils/sentencepiece/matcher.h"
 #include "utils/strings/stringpiece.h"
 
 namespace libtextclassifier3 {
 
 // Encoder to segment/tokenize strings into pieces such that the sum of the
 // scores of the pieces used is maximized.
-template <typename TMatcher = DoubleArrayTrie>
 class Encoder {
  public:
   // matcher: the list of valid sentence pieces represented as a matcher, e.g.
@@ -40,7 +37,7 @@
   // end_code: Code that is used as encoding of the end of input.
   // encoding_offset: Value added to the sentence piece ids to make them
   //     not interesecting with start_code and end_code.
-  Encoder(const TMatcher& matcher, const int num_pieces,
+  Encoder(const SentencePieceMatcher* matcher, const int num_pieces,
           const float* pieces_scores, int start_code = 0, int end_code = 1,
           int encoding_offset = 2)
       : num_pieces_(num_pieces),
@@ -49,6 +46,10 @@
         start_code_(start_code),
         end_code_(end_code),
         encoding_offset_(encoding_offset) {}
+
+  // Segment the input so that the total score of the pieces used is maximized.
+  // This is a simplified implementation of the general Viterbi algorithm,
+  // assuming independence between individual pieces.
   std::vector<int> Encode(StringPiece normalized_text) const;
 
  private:
@@ -69,62 +70,12 @@
 
   const int num_pieces_;
   const float* scores_;
-  TMatcher matcher_;
+  const SentencePieceMatcher* matcher_;
   const int start_code_;
   const int end_code_;
   const int encoding_offset_;
 };
 
-// Segment the input such that the total score of the pieces used is maximized.
-// This is a simplified implementation of the general Viterbi algorithm,
-// assuming independence between individual pieces.
-template <typename TMatcher>
-std::vector<int> Encoder<TMatcher>::Encode(StringPiece normalized_text) const {
-  const int len = normalized_text.size();
-  if (len <= 0) {
-    return {start_code_, end_code_};
-  }
-  // We use `previous_pos` to indicate whether a dynamic programming state was
-  // reachable.
-  std::vector<SegmentationEntry> segmentation(
-      len + 1, {/*score=*/0, /*previous_pos=*/-1, /*piece_id=*/-1,
-                /*num_pieces=*/0});
-  for (int i = 0; i < len; i++) {
-    // State couldn't be reached.
-    if (i > 0 && segmentation[i].previous_pos < 0) {
-      // Advance position.
-      normalized_text.RemovePrefix(1);
-      continue;
-    }
-    for (const auto& match : matcher_.FindAllPrefixMatches(normalized_text)) {
-      TC3_CHECK(match.id >= 0 && match.id < num_pieces_);
-      const int pos = i + match.match_length;
-      const float candidate_score = segmentation[i].score + scores_[match.id];
-      if (segmentation[pos].previous_pos < 0 ||
-          segmentation[pos].score < candidate_score) {
-        segmentation[pos] = {/*score=*/candidate_score, /*previous_pos=*/i,
-                             /*piece_id=*/match.id,
-                             /*num_pieces=*/segmentation[i].num_pieces + 1};
-      }
-    }
-    // Advance position.
-    normalized_text.RemovePrefix(1);
-  }
-  if (segmentation[len].num_pieces <= 0) {
-    return {start_code_, end_code_};
-  }
-  const int num_pieces = segmentation[len].num_pieces;
-  std::vector<int> result(num_pieces + 2);
-  result[num_pieces + 1] = end_code_;
-  int pos = len;
-  for (int i = num_pieces; i > 0; i--) {
-    result[i] = segmentation[pos].piece_id + encoding_offset_;
-    pos = segmentation[pos].previous_pos;
-  }
-  result[0] = start_code_;
-  return result;
-}
-
 }  // namespace libtextclassifier3
 
 #endif  // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_ENCODER_H_
diff --git a/utils/sentencepiece/encoder_test.cc b/utils/sentencepiece/encoder_test.cc
index 5697758..59c12ad 100644
--- a/utils/sentencepiece/encoder_test.cc
+++ b/utils/sentencepiece/encoder_test.cc
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include <memory>
 #include <vector>
 
 #include "gmock/gmock.h"
@@ -32,9 +33,10 @@
   const char pieces[] = "hell\0hello\0o\0there\0";
   const int offsets[] = {0, 5, 11, 13};
   float scores[] = {-0.5, -1.0, -10.0, -1.0};
-  const Encoder<SortedStringsTable> encoder(
-      SortedStringsTable(/*num_pieces=*/4, offsets, StringPiece(pieces, 18)),
-      /*num_pieces=*/4, scores);
+  std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
+      /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
+  const Encoder encoder(matcher.get(),
+                        /*num_pieces=*/4, scores);
 
   EXPECT_THAT(encoder.Encode("hellothere"), ElementsAreArray({0, 3, 5, 1}));
 
@@ -48,9 +50,10 @@
   const char pieces[] = "hell\0hello\0o\0there\0";
   const int offsets[] = {0, 5, 11, 13};
   float scores[] = {-0.5, -1.0, -10.0, -1.0};
-  const Encoder<SortedStringsTable> encoder(
-      SortedStringsTable(/*num_pieces=*/4, offsets, StringPiece(pieces, 18)),
-      /*num_pieces=*/4, scores);
+  std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
+      /*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
+  const Encoder encoder(matcher.get(),
+                        /*num_pieces=*/4, scores);
   EXPECT_THAT(encoder.Encode("hellhello"), ElementsAreArray({0, 2, 3, 1}));
   EXPECT_THAT(encoder.Encode("hellohell"), ElementsAreArray({0, 3, 2, 1}));
   EXPECT_THAT(encoder.Encode(""), ElementsAreArray({0, 1}));
diff --git a/utils/sentencepiece/match.h b/utils/sentencepiece/matcher.h
similarity index 63%
rename from utils/sentencepiece/match.h
rename to utils/sentencepiece/matcher.h
index c1dc475..b538d69 100644
--- a/utils/sentencepiece/match.h
+++ b/utils/sentencepiece/matcher.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCH_H_
-#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCH_H_
+#ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
+#define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
 
 #include <vector>
 #include "utils/strings/stringpiece.h"
@@ -29,6 +29,18 @@
   int match_length = -1;
 };
 
+// Interface for finding sentence pieces that are prefixes of an input.
+class SentencePieceMatcher {
+ public:
+  virtual ~SentencePieceMatcher() = default;
+  // Finds all matches that are prefixes of `input`.
+  virtual std::vector<TrieMatch> FindAllPrefixMatches(
+      StringPiece input) const = 0;
+
+  // Finds the longest match that is a prefix of `input`.
+  virtual TrieMatch LongestPrefixMatch(StringPiece input) const = 0;
+};
+
 }  // namespace libtextclassifier3
 
-#endif  // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCH_H_
+#endif  // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_MATCHER_H_
diff --git a/utils/sentencepiece/sorted_strings_table.h b/utils/sentencepiece/sorted_strings_table.h
index da4209d..82cda5c 100644
--- a/utils/sentencepiece/sorted_strings_table.h
+++ b/utils/sentencepiece/sorted_strings_table.h
@@ -20,7 +20,7 @@
 #include <functional>
 #include <vector>
 
-#include "utils/sentencepiece/match.h"
+#include "utils/sentencepiece/matcher.h"
 #include "utils/strings/stringpiece.h"
 
 namespace libtextclassifier3 {
@@ -34,7 +34,7 @@
 // pieces: String pieces, concatenated in sorted order and zero byte separated.
 // use_linear_scan_threshold: Minimum size of binary search range before
 //     switching to a linear sweep for prefix match testing.
-class SortedStringsTable {
+class SortedStringsTable : public SentencePieceMatcher {
  public:
   SortedStringsTable(const int num_pieces, const int* offsets,
                      StringPiece pieces,
@@ -45,10 +45,10 @@
         use_linear_scan_threshold_(use_linear_scan_threshold) {}
 
   // Find matches that are prefixes of a string.
-  std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const;
+  std::vector<TrieMatch> FindAllPrefixMatches(StringPiece input) const override;
 
   // Find the longest prefix match of a string.
-  TrieMatch LongestPrefixMatch(StringPiece input) const;
+  TrieMatch LongestPrefixMatch(StringPiece input) const override;
 
  private:
   void GatherPrefixMatches(
diff --git a/utils/tflite/text_encoder.cc b/utils/tflite/text_encoder.cc
index 727991f..9554283 100644
--- a/utils/tflite/text_encoder.cc
+++ b/utils/tflite/text_encoder.cc
@@ -21,6 +21,7 @@
 #include "utils/sentencepiece/double_array_trie.h"
 #include "utils/sentencepiece/encoder.h"
 #include "utils/sentencepiece/normalizer.h"
+#include "utils/sentencepiece/sorted_strings_table.h"
 #include "utils/strings/stringpiece.h"
 #include "utils/tflite/text_encoder.h"
 #include "utils/tflite/text_encoder_config_generated.h"
@@ -35,7 +36,8 @@
 
 struct TextEncoderOp {
   std::unique_ptr<Normalizer> normalizer;
-  std::unique_ptr<Encoder<DoubleArrayTrie>> encoder;
+  std::unique_ptr<Encoder> encoder;
+  std::unique_ptr<SentencePieceMatcher> matcher;
 };
 
 // Input parameters for the op.
@@ -49,12 +51,17 @@
 // Output parameters for the op.
 enum SmartReplyModelOutputs {
   TEXT_ENCODER_OUTPUT_ENCODED = 0,
-  TEXT_ENCODER_OUTPUT_LENGTHS = 1,
-  TEXT_ENCODER_OUTPUT_ATTR = 2,
+  TEXT_ENCODER_OUTPUT_POSITION = 1,
+  TEXT_ENCODER_OUTPUT_LENGTHS = 2,
+  TEXT_ENCODER_OUTPUT_ATTR = 3,
 };
 
 const char kTextEncoderConfigAttr[] = "text_encoder_config";
 
+// Input rank is 2 since there is a dummy batch dimension of 1.
+const int kInputRank = 2;
+const int kBatchSize = 1;
+
 // Initializes text encoder object from serialized options:
 //   The options are a flexbuffers attribute map that contain the op config
 //   with the key `text_encoder_config` as `TextEncoderConfig`.
@@ -81,15 +88,32 @@
       config->add_dummy_prefix(), config->remove_extra_whitespaces(),
       config->escape_whitespaces()));
 
-  const TrieNode* pieces_trie_nodes =
-      reinterpret_cast<const TrieNode*>(config->pieces()->Data());
-  const int pieces_trie_nodes_length =
-      config->pieces()->Length() / sizeof(TrieNode);
   const int num_pieces = config->pieces_scores()->Length();
-  encoder_op->encoder.reset(new Encoder<DoubleArrayTrie>(
-      DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length), num_pieces,
-      config->pieces_scores()->data(), config->start_code(), config->end_code(),
-      config->encoding_offset()));
+
+  switch (config->matcher_type()) {
+    case SentencePieceMatcherType_MAPPED_TRIE: {
+      const TrieNode* pieces_trie_nodes =
+          reinterpret_cast<const TrieNode*>(config->pieces()->Data());
+      const int pieces_trie_nodes_length =
+          config->pieces()->Length() / sizeof(TrieNode);
+      encoder_op->matcher.reset(
+          new DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length));
+      break;
+    }
+    case SentencePieceMatcherType_SORTED_STRING_TABLE: {
+      encoder_op->matcher.reset(new SortedStringsTable(
+          num_pieces, config->pieces_offsets()->data(),
+          StringPiece(config->pieces()->data(), config->pieces()->Length())));
+      break;
+    }
+    default: {
+      TC3_LOG(ERROR) << "Unknown sentence piece matcher type.";
+      return nullptr;
+    }
+  }
+  encoder_op->encoder.reset(new Encoder(
+      encoder_op->matcher.get(), num_pieces, config->pieces_scores()->data(),
+      config->start_code(), config->end_code(), config->encoding_offset()));
   return encoder_op.release();
 }
 
@@ -112,8 +136,8 @@
                            const std::vector<int>& encoding_end_offsets,
                            int start_offset, TfLiteContext* context,
                            TfLiteTensor* out) {
-  TF_LITE_ENSURE_EQ(context, in.dims->size, 2);
-  TF_LITE_ENSURE_EQ(context, in.dims->data[0], 1);
+  TF_LITE_ENSURE_EQ(context, in.dims->size, kInputRank);
+  TF_LITE_ENSURE_EQ(context, in.dims->data[0], kBatchSize);
   const int output_size = out->dims->data[1];
   int output_offset = 0;
   for (int value_index = 0;
@@ -162,7 +186,8 @@
           (output_offset > 0) ? out->data.f[output_offset - 1] : 0;
       std::fill(out->data.f + output_offset, out->data.f + output_size, value);
     } break;
-    default: {}
+    default:
+      break;
   }
   return kTfLiteOk;
 }
@@ -173,16 +198,26 @@
       context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ENCODED]];
 
   TF_LITE_ENSURE_OK(
-      context, context->ResizeTensor(context, &output_encoded,
-                                     CreateSizeArray({1, max_output_length})));
+      context,
+      context->ResizeTensor(context, &output_encoded,
+                            CreateSizeArray({kBatchSize, max_output_length})));
+
+  TfLiteTensor& output_positions =
+      context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_POSITION]];
+
+  TF_LITE_ENSURE_OK(
+      context,
+      context->ResizeTensor(context, &output_positions,
+                            CreateSizeArray({kBatchSize, max_output_length})));
 
   const int num_output_attrs = node->outputs->size - TEXT_ENCODER_OUTPUT_ATTR;
   for (int i = 0; i < num_output_attrs; ++i) {
     TfLiteTensor& output =
         context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ATTR + i]];
-    TF_LITE_ENSURE_OK(context, context->ResizeTensor(
-                                   context, &output,
-                                   CreateSizeArray({1, max_output_length})));
+    TF_LITE_ENSURE_OK(context,
+                      context->ResizeTensor(
+                          context, &output,
+                          CreateSizeArray({kBatchSize, max_output_length})));
   }
   return kTfLiteOk;
 }
@@ -190,19 +225,22 @@
 }  // namespace
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  // Check that the batch dimension is 1.
+  // Check that the batch dimension is kBatchSize.
   const TfLiteTensor& input_text =
       context->tensors[node->inputs->data[TEXT_ENCODER_INPUT_TEXTS]];
-  TF_LITE_ENSURE_EQ(context, input_text.dims->size, 2);
-  TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], 1);
+  TF_LITE_ENSURE_EQ(context, input_text.dims->size, kInputRank);
+  TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kBatchSize);
 
   TfLiteTensor& output_lengths =
       context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_LENGTHS]];
   TfLiteTensor& output_encoded =
       context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ENCODED]];
+  TfLiteTensor& output_positions =
+      context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_POSITION]];
 
-  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, &output_lengths,
-                                                   CreateSizeArray({1})));
+  TF_LITE_ENSURE_OK(context,
+                    context->ResizeTensor(context, &output_lengths,
+                                          CreateSizeArray({kBatchSize})));
 
   // Check that there are enough outputs for attributes.
   const int num_output_attrs = node->outputs->size - TEXT_ENCODER_OUTPUT_ATTR;
@@ -225,6 +263,7 @@
     return ResizeOutputTensors(context, node, output_length.data.i64[0]);
   } else {
     tflite::SetTensorToDynamic(&output_encoded);
+    tflite::SetTensorToDynamic(&output_positions);
     for (int i = 0; i < num_output_attrs; ++i) {
       TfLiteTensor& output_attr =
           context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_ATTR + i]];
@@ -253,10 +292,15 @@
     TF_LITE_ENSURE_OK(
         context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
   }
+  TfLiteTensor& output_positions =
+      context->tensors[node->outputs->data[TEXT_ENCODER_OUTPUT_POSITION]];
 
   std::vector<int> encoded_total;
   std::vector<int> encoded_offsets;
+  std::vector<int> encoded_positions;
   encoded_offsets.reserve(num_strings);
+  const int max_output_length = output_encoded.dims->data[1];
+  const int max_encoded_position = max_output_length;
 
   for (int i = 0; i < num_strings; ++i) {
     const auto& strref = tflite::GetString(&input_text, i);
@@ -264,16 +308,20 @@
         encoder_op->normalizer->Normalize(StringPiece(strref.str, strref.len)));
     encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
     encoded_offsets.push_back(encoded_total.size());
+    for (int j = 0; j < static_cast<int>(encoded.size()); ++j) {
+      encoded_positions.push_back(std::min(j, max_encoded_position - 1));
+    }
   }
-  const int max_output_length = output_encoded.dims->data[1];
 
   // Copy encoding to output tensor.
   const int start_offset =
       std::max(0, static_cast<int>(encoded_total.size()) - max_output_length);
   int output_offset = 0;
   int32_t* output_buffer = output_encoded.data.i32;
+  int32_t* output_positions_buffer = output_positions.data.i32;
   for (int i = start_offset; i < encoded_total.size(); ++i, ++output_offset) {
     output_buffer[output_offset] = encoded_total[i];
+    output_positions_buffer[output_offset] = encoded_positions[i];
   }
 
   // Save output encoded length.
@@ -284,6 +332,7 @@
   // Do padding.
   for (; output_offset < max_output_length; ++output_offset) {
     output_buffer[output_offset] = encoded_total.back();
+    output_positions_buffer[output_offset] = max_encoded_position;
   }
 
   // Process attributes, all checks of sizes and types are done in Prepare.
diff --git a/utils/tflite/text_encoder_config.fbs b/utils/tflite/text_encoder_config.fbs
index 9017116..462da21 100644
--- a/utils/tflite/text_encoder_config.fbs
+++ b/utils/tflite/text_encoder_config.fbs
@@ -17,6 +17,12 @@
 // Configuration for the text encoder op.
 
 namespace libtextclassifier3;
+
+enum SentencePieceMatcherType : byte {
+  MAPPED_TRIE = 0,
+  SORTED_STRING_TABLE = 1,
+}
+
 table TextEncoderConfig {
   // Code that is used as encoding of the start code.
   start_code:int32 = 0;
@@ -46,6 +52,8 @@
   // Sentence pieces scores.
   pieces_scores:[float];
 
-  // Serialized sentence pieces trie.
+  // Serialized sentence pieces.
   pieces:string;
+  pieces_offsets:[int32];
+  matcher_type: SentencePieceMatcherType = MAPPED_TRIE;
 }
diff --git a/utils/tflite/text_encoder_test.cc b/utils/tflite/text_encoder_test.cc
index d1892c7..0b6ff71 100644
--- a/utils/tflite/text_encoder_test.cc
+++ b/utils/tflite/text_encoder_test.cc
@@ -55,6 +55,9 @@
   std::vector<int> GetOutputEncoding() {
     return ExtractVector<int>(output_encoding_);
   }
+  std::vector<int> GetOutputPositions() {
+    return ExtractVector<int>(output_positions_);
+  }
   std::vector<int> GetOutputAttributeInt32() {
     return ExtractVector<int>(output_attributes_int32_);
   }
@@ -71,6 +74,7 @@
   int input_attributes_float_;
 
   int output_encoding_;
+  int output_positions_;
   int output_length_;
   int output_attributes_int32_;
   int output_attributes_float_;
@@ -86,6 +90,7 @@
   input_attributes_float_ = AddInput(tflite::TensorType_FLOAT32);
 
   output_encoding_ = AddOutput(tflite::TensorType_INT32);
+  output_positions_ = AddOutput(tflite::TensorType_INT32);
   output_length_ = AddOutput(tflite::TensorType_INT32);
   output_attributes_int32_ = AddOutput(tflite::TensorType_INT32);
   output_attributes_float_ = AddOutput(tflite::TensorType_FLOAT32);
@@ -113,6 +118,8 @@
   EXPECT_EQ(m.GetEncodedLength(), 5);
   EXPECT_THAT(m.GetOutputEncoding(),
               testing::ElementsAre(1, 90, 547, 58, 2, 2, 2, 2, 2, 2));
+  EXPECT_THAT(m.GetOutputPositions(),
+              testing::ElementsAre(0, 1, 2, 3, 4, 10, 10, 10, 10, 10));
   EXPECT_THAT(m.GetOutputAttributeInt32(),
               testing::ElementsAre(7, 7, 7, 7, 7, 7, 7, 7, 7, 7));
   EXPECT_THAT(
@@ -130,6 +137,8 @@
   EXPECT_EQ(m.GetEncodedLength(), 10);
   EXPECT_THAT(m.GetOutputEncoding(),
               testing::ElementsAre(547, 58, 2, 1, 862, 2, 1, 1919, 19, 2));
+  EXPECT_THAT(m.GetOutputPositions(),
+              testing::ElementsAre(2, 3, 4, 0, 1, 2, 0, 1, 2, 3));
   EXPECT_THAT(m.GetOutputAttributeInt32(),
               testing::ElementsAre(1, 1, 1, 2, 2, 2, 3, 3, 3, 3));
   EXPECT_THAT(
@@ -147,6 +156,8 @@
   EXPECT_EQ(m.GetEncodedLength(), 9);
   EXPECT_THAT(m.GetOutputEncoding(),
               testing::ElementsAre(862, 2, 1, 1919, 19, 2, 1, 862, 2));
+  EXPECT_THAT(m.GetOutputPositions(),
+              testing::ElementsAre(1, 2, 0, 1, 2, 3, 0, 1, 2));
   EXPECT_THAT(m.GetOutputAttributeInt32(),
               testing::ElementsAre(2, 2, 3, 3, 3, 3, 4, 4, 4));
   EXPECT_THAT(
diff --git a/utils/utf8/unicodetext.cc b/utils/utf8/unicodetext.cc
index 057703a..81492d8 100644
--- a/utils/utf8/unicodetext.cc
+++ b/utils/utf8/unicodetext.cc
@@ -176,7 +176,7 @@
 
 }  // namespace
 
-UnicodeText& UnicodeText::AppendCodepoint(char32 ch) {
+UnicodeText& UnicodeText::push_back(char32 ch) {
   char str[4];
   int char_len = runetochar(ch, str);
   repr_.append(str, char_len);
diff --git a/utils/utf8/unicodetext.h b/utils/utf8/unicodetext.h
index f7790f5..eb206b8 100644
--- a/utils/utf8/unicodetext.h
+++ b/utils/utf8/unicodetext.h
@@ -168,7 +168,7 @@
 
   // Calling this may invalidate pointers to underlying data.
   UnicodeText& AppendUTF8(const char* utf8, int len);
-  UnicodeText& AppendCodepoint(char32 ch);
+  UnicodeText& push_back(char32 ch);
   void clear();
 
   std::string ToUTF8String() const;
diff --git a/utils/utf8/unicodetext_test.cc b/utils/utf8/unicodetext_test.cc
index 9cdc850..7ebb415 100644
--- a/utils/utf8/unicodetext_test.cc
+++ b/utils/utf8/unicodetext_test.cc
@@ -24,11 +24,11 @@
 class UnicodeTextTest : public testing::Test {
  protected:
   UnicodeTextTest() : empty_text_() {
-    text_.AppendCodepoint(0x1C0);
-    text_.AppendCodepoint(0x4E8C);
-    text_.AppendCodepoint(0xD7DB);
-    text_.AppendCodepoint(0x34);
-    text_.AppendCodepoint(0x1D11E);
+    text_.push_back(0x1C0);
+    text_.push_back(0x4E8C);
+    text_.push_back(0xD7DB);
+    text_.push_back(0x34);
+    text_.push_back(0x1D11E);
   }
 
   UnicodeText empty_text_;