Export libtextclassifier to Android
Test: atest android.view.textclassifier.TextClassificationManagerTest
Change-Id: Id7a31dc60c8f6625ff8f2a9c85689e13b121a5a4
diff --git a/utils/tflite/text_encoder.cc b/utils/tflite/text_encoder.cc
index 9554283..734b5b0 100644
--- a/utils/tflite/text_encoder.cc
+++ b/utils/tflite/text_encoder.cc
@@ -35,7 +35,7 @@
namespace {
struct TextEncoderOp {
- std::unique_ptr<Normalizer> normalizer;
+ std::unique_ptr<SentencePieceNormalizer> normalizer;
std::unique_ptr<Encoder> encoder;
std::unique_ptr<SentencePieceMatcher> matcher;
};
@@ -81,7 +81,7 @@
config->normalization_charsmap()->Data());
const int charsmap_trie_nodes_length =
config->normalization_charsmap()->Length() / sizeof(TrieNode);
- encoder_op->normalizer.reset(new Normalizer(
+ encoder_op->normalizer.reset(new SentencePieceNormalizer(
DoubleArrayTrie(charsmap_trie_nodes, charsmap_trie_nodes_length),
StringPiece(config->normalization_charsmap_values()->data(),
config->normalization_charsmap_values()->size()),
@@ -113,7 +113,8 @@
}
encoder_op->encoder.reset(new Encoder(
encoder_op->matcher.get(), num_pieces, config->pieces_scores()->data(),
- config->start_code(), config->end_code(), config->encoding_offset()));
+ config->start_code(), config->end_code(), config->encoding_offset(),
+ config->unknown_code(), config->unknown_score()));
return encoder_op.release();
}
diff --git a/utils/tflite/text_encoder_config.fbs b/utils/tflite/text_encoder_config.fbs
index 462da21..8ae8fc5 100644
--- a/utils/tflite/text_encoder_config.fbs
+++ b/utils/tflite/text_encoder_config.fbs
@@ -34,6 +34,12 @@
// `start_code` and `end_code`.
encoding_offset:int32 = 2;
+ // Code that is used for out-of-dictionary characters.
+ unknown_code:int32 = -1;
+
+ // Penalty associated with the unknown code.
+ unknown_score:float;
+
// Normalization options.
// Serialized normalization charsmap.
normalization_charsmap:string;
diff --git a/utils/tflite/text_encoder_test.cc b/utils/tflite/text_encoder_test.cc
index 0b6ff71..0cd67ce 100644
--- a/utils/tflite/text_encoder_test.cc
+++ b/utils/tflite/text_encoder_test.cc
@@ -20,6 +20,7 @@
#include "utils/tflite/text_encoder.h"
#include "gtest/gtest.h"
+#include "third_party/absl/flags/flag.h"
#include "flatbuffers/flexbuffers.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"