Import libtextclassifier changes from google3.

This includes an upgrade to the latest version of the model.

We now use a test data path constant in the unit tests. This is
consistent with another test project, minikin_tests.

Test: Tests pass on-device.
Bug: 34865247
Change-Id: Ia061888a0bba371c6b429ad7c5457af611ba12f2
diff --git a/Android.mk b/Android.mk
index 823815d..204a97a 100644
--- a/Android.mk
+++ b/Android.mk
@@ -70,6 +70,9 @@
 
 LOCAL_TEST_DATA := $(call find-test-data-in-subdirs, $(LOCAL_PATH), *, tests/testdata)
 
+LOCAL_CPPFLAGS_32 += -DTEST_DATA_DIR="\"/data/nativetest/libtextclassifier_tests/tests/testdata/\""
+LOCAL_CPPFLAGS_64 += -DTEST_DATA_DIR="\"/data/nativetest64/libtextclassifier_tests/tests/testdata/\""
+
 LOCAL_SRC_FILES := $(patsubst ./%,%, $(shell cd $(LOCAL_PATH); \
     find . -name "*.cc" -and -not -name ".*"))
 LOCAL_C_INCLUDES += .
diff --git a/models/textclassifier.smartselection.en.model b/models/textclassifier.smartselection.en.model
index 0530a6f..2b96f42 100644
--- a/models/textclassifier.smartselection.en.model
+++ b/models/textclassifier.smartselection.en.model
Binary files differ
diff --git a/smartselect/feature-processor.cc b/smartselect/feature-processor.cc
index c931021..9e357bd 100644
--- a/smartselect/feature-processor.cc
+++ b/smartselect/feature-processor.cc
@@ -125,27 +125,8 @@
   }
 }
 
-std::pair<SelectionWithContext, int> ExtractLineWithClick(
-    const SelectionWithContext& selection_with_context) {
-  std::string new_context;
-  int shift;
-  std::tie(new_context, shift) = ExtractLineWithSpan(
-      selection_with_context.context,
-      {selection_with_context.click_start, selection_with_context.click_end});
-
-  SelectionWithContext result(selection_with_context);
-  result.selection_start -= shift;
-  result.selection_end -= shift;
-  result.click_start -= shift;
-  result.click_end -= shift;
-  result.context = new_context;
-  return {result, shift};
-}
-
-}  // namespace internal
-
-std::pair<std::string, int> ExtractLineWithSpan(const std::string& context,
-                                                CodepointSpan span) {
+void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
+                               std::vector<Token>* tokens) {
   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
                                                         /*do_copy=*/false);
   std::vector<UnicodeTextRange> lines;
@@ -164,19 +145,26 @@
   }
   for (const UnicodeTextRange& line : lines) {
     // Find the line that completely contains the span.
-    if (line.first <= span_start && line.second >= span_start &&
-        line.first <= span_end && line.second >= span_end) {
+    if (line.first <= span_start && line.second >= span_end) {
       const CodepointIndex last_line_begin_index =
           std::distance(context_unicode.begin(), line.first);
+      const CodepointIndex last_line_end_index =
+          last_line_begin_index + std::distance(line.first, line.second);
 
-      std::string result =
-          context_unicode.UTF8Substring(line.first, line.second);
-      return {result, last_line_begin_index};
+      for (auto token = tokens->begin(); token != tokens->end();) {
+        if (token->start >= last_line_begin_index &&
+            token->end <= last_line_end_index) {
+          ++token;
+        } else {
+          token = tokens->erase(token);
+        }
+      }
     }
   }
-  return {context, 0};
 }
 
+}  // namespace internal
+
 const char* const FeatureProcessor::kFeatureTypeName = "chargram_continuous";
 
 std::vector<Token> FeatureProcessor::Tokenize(
@@ -308,56 +296,81 @@
   return kInvalidIndex;
 }
 
-// Helper function to get the click position (in terms of index into a vector of
-// selectable tokens), given selectable tokens. If click position is given, it
-// will be used. If it is not given, but selection is given, the click will be
-// the middle token in the selection range.
-int GetClickPosition(const SelectionWithContext& selection_with_context,
-                     const std::vector<Token>& selectable_tokens) {
-  int click_pos = kInvalidIndex;
-  if (selection_with_context.GetClickSpan() !=
-      std::make_pair(kInvalidIndex, kInvalidIndex)) {
-    int range_begin;
-    int range_end;
-    std::tie(range_begin, range_end) = CodepointSpanToTokenSpan(
-        selectable_tokens, selection_with_context.GetClickSpan());
+}  // namespace
 
-    // If no exact match was found, try finding a token that completely contains
-    // the click span. This is useful e.g. when Android builds the selection
-    // using ICU tokenization, and ends up with only a portion of our space-
-    // separated token. E.g. for "(857)" Android would select "857".
-    if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
-      int token_index = FindTokenThatContainsSpan(
-          selectable_tokens, selection_with_context.GetClickSpan());
-      if (token_index != kInvalidIndex) {
-        range_begin = token_index;
-        range_end = token_index + 1;
-      }
-    }
+namespace internal {
 
-    // We only allow clicks that are exactly 1 selectable token.
-    if (range_end - range_begin == 1) {
-      click_pos = range_begin;
-    } else {
-      click_pos = kInvalidIndex;
-    }
-  } else if (selection_with_context.GetSelectionSpan() !=
-             std::make_pair(kInvalidIndex, kInvalidIndex)) {
-    int range_begin;
-    int range_end;
-    std::tie(range_begin, range_end) = CodepointSpanToTokenSpan(
-        selectable_tokens, selection_with_context.GetSelectionSpan());
+int CenterTokenFromClick(CodepointSpan span,
+                         const std::vector<Token>& selectable_tokens) {
+  int range_begin;
+  int range_end;
+  std::tie(range_begin, range_end) =
+      CodepointSpanToTokenSpan(selectable_tokens, span);
 
-    // Center the clicked token in the selection range.
-    if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
-      click_pos = (range_begin + range_end - 1) / 2;
+  // If no exact match was found, try finding a token that completely contains
+  // the click span. This is useful e.g. when Android builds the selection
+  // using ICU tokenization, and ends up with only a portion of our space-
+  // separated token. E.g. for "(857)" Android would select "857".
+  if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
+    int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
+    if (token_index != kInvalidIndex) {
+      range_begin = token_index;
+      range_end = token_index + 1;
     }
   }
 
-  return click_pos;
+  // We only allow clicks that are exactly 1 selectable token.
+  if (range_end - range_begin == 1) {
+    return range_begin;
+  } else {
+    return kInvalidIndex;
+  }
 }
 
-}  // namespace
+int CenterTokenFromMiddleOfSelection(
+    CodepointSpan span, const std::vector<Token>& selectable_tokens) {
+  int range_begin;
+  int range_end;
+  std::tie(range_begin, range_end) =
+      CodepointSpanToTokenSpan(selectable_tokens, span);
+
+  // Center the clicked token in the selection range.
+  if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
+    return (range_begin + range_end - 1) / 2;
+  } else {
+    return kInvalidIndex;
+  }
+}
+
+}  // namespace internal
+
+int FeatureProcessor::FindCenterToken(CodepointSpan span,
+                                      const std::vector<Token>& tokens) const {
+  if (options_.center_token_selection_method() ==
+      FeatureProcessorOptions::CENTER_TOKEN_FROM_CLICK) {
+    return internal::CenterTokenFromClick(span, tokens);
+  } else if (options_.center_token_selection_method() ==
+             FeatureProcessorOptions::CENTER_TOKEN_MIDDLE_OF_SELECTION) {
+    return internal::CenterTokenFromMiddleOfSelection(span, tokens);
+  } else if (options_.center_token_selection_method() ==
+             FeatureProcessorOptions::DEFAULT_CENTER_TOKEN_METHOD) {
+    // TODO(zilka): This is a HACK not to break the current models. Remove once
+    // we have new models on the device.
+    // It uses the fact that sharing model use
+    // split_tokens_on_selection_boundaries and selection not. So depending on
+    // this we select the right way of finding the click location.
+    if (!options_.split_tokens_on_selection_boundaries()) {
+      // SmartSelection model.
+      return internal::CenterTokenFromClick(span, tokens);
+    } else {
+      // SmartSharing model.
+      return internal::CenterTokenFromMiddleOfSelection(span, tokens);
+    }
+  } else {
+    TC_LOG(ERROR) << "Invalid center token selection method.";
+    return kInvalidIndex;
+  }
+}
 
 std::vector<Token> FeatureProcessor::FindTokensInSelection(
     const std::vector<Token>& selectable_tokens,
@@ -401,8 +414,20 @@
   }
 }
 
+bool FeatureProcessor::GetFeatures(
+    const std::string& context, CodepointSpan input_span,
+    std::vector<nlp_core::FeatureVector>* features,
+    std::vector<float>* extra_features,
+    std::vector<CodepointSpan>* selection_label_spans) const {
+  return FeatureProcessor::GetFeaturesAndLabels(
+      context, input_span, {kInvalidIndex, kInvalidIndex}, "", features,
+      extra_features, selection_label_spans, /*selection_label=*/nullptr,
+      /*selection_codepoint_label=*/nullptr, /*classification_label=*/nullptr);
+}
+
 bool FeatureProcessor::GetFeaturesAndLabels(
-    const SelectionWithContext& selection_with_context,
+    const std::string& context, CodepointSpan input_span,
+    CodepointSpan label_span, const std::string& label_collection,
     std::vector<nlp_core::FeatureVector>* features,
     std::vector<float>* extra_features,
     std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
@@ -413,39 +438,32 @@
   *features =
       std::vector<nlp_core::FeatureVector>(options_.context_size() * 2 + 1);
 
-  SelectionWithContext selection_normalized;
-  int normalization_shift;
-  if (options_.only_use_line_with_click()) {
-    std::tie(selection_normalized, normalization_shift) =
-        internal::ExtractLineWithClick(selection_with_context);
-  } else {
-    selection_normalized = selection_with_context;
-    normalization_shift = 0;
+  std::vector<Token> input_tokens = Tokenize(context);
+
+  if (options_.split_tokens_on_selection_boundaries()) {
+    internal::SplitTokensOnSelectionBoundaries(input_span, &input_tokens);
   }
 
-  std::vector<Token> input_tokens = Tokenize(selection_normalized.context);
-  if (options_.split_tokens_on_selection_boundaries()) {
-    internal::SplitTokensOnSelectionBoundaries(
-        selection_with_context.GetSelectionSpan(), &input_tokens);
+  if (options_.only_use_line_with_click()) {
+    internal::StripTokensFromOtherLines(context, input_span, &input_tokens);
   }
-  int click_pos = GetClickPosition(selection_normalized, input_tokens);
+
+  const int click_pos = FindCenterToken(input_span, input_tokens);
   if (click_pos == kInvalidIndex) {
     TC_LOG(ERROR) << "Could not extract click position.";
     return false;
   }
 
   std::vector<Token> output_tokens;
-  bool status = ComputeFeatures(click_pos, input_tokens,
-                                selection_normalized.GetSelectionSpan(),
-                                features, extra_features, &output_tokens);
+  bool status = ComputeFeatures(click_pos, input_tokens, input_span, features,
+                                extra_features, &output_tokens);
   if (!status) {
     TC_LOG(ERROR) << "Feature computation failed.";
     return false;
   }
 
   if (selection_label != nullptr) {
-    status = SpanToLabel(selection_normalized.GetSelectionSpan(), output_tokens,
-                         selection_label);
+    status = SpanToLabel(label_span, output_tokens, selection_label);
     if (!status) {
       TC_LOG(ERROR) << "Could not convert selection span to label.";
       return false;
@@ -453,20 +471,10 @@
   }
 
   if (selection_codepoint_label != nullptr) {
-    *selection_codepoint_label = selection_with_context.GetSelectionSpan();
+    *selection_codepoint_label = label_span;
   }
 
   if (selection_label_spans != nullptr) {
-    // If an input normalization was performed, we need to shift the tokens
-    // back to get the correct ranges in the original input.
-    if (normalization_shift) {
-      for (Token& token : output_tokens) {
-        if (token.start != kInvalidIndex && token.end != kInvalidIndex) {
-          token.start += normalization_shift;
-          token.end += normalization_shift;
-        }
-      }
-    }
     for (int i = 0; i < label_to_selection_.size(); ++i) {
       CodepointSpan span;
       status = LabelToSpan(i, output_tokens, &span);
@@ -479,14 +487,15 @@
   }
 
   if (classification_label != nullptr) {
-    *classification_label = CollectionToLabel(selection_normalized.collection);
+    *classification_label = CollectionToLabel(label_collection);
   }
 
   return true;
 }
 
 bool FeatureProcessor::GetFeaturesAndLabels(
-    const SelectionWithContext& selection_with_context,
+    const std::string& context, CodepointSpan input_span,
+    CodepointSpan label_span, const std::string& label_collection,
     std::vector<std::vector<std::pair<int, float>>>* features,
     std::vector<float>* extra_features,
     std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
@@ -499,10 +508,10 @@
   }
 
   std::vector<nlp_core::FeatureVector> feature_vectors;
-  bool result = GetFeaturesAndLabels(selection_with_context, &feature_vectors,
-                                     extra_features, selection_label_spans,
-                                     selection_label, selection_codepoint_label,
-                                     classification_label);
+  bool result = GetFeaturesAndLabels(
+      context, input_span, label_span, label_collection, &feature_vectors,
+      extra_features, selection_label_spans, selection_label,
+      selection_codepoint_label, classification_label);
 
   if (!result) {
     return false;
diff --git a/smartselect/feature-processor.h b/smartselect/feature-processor.h
index 712534d..311be3e 100644
--- a/smartselect/feature-processor.h
+++ b/smartselect/feature-processor.h
@@ -43,30 +43,29 @@
 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
     const FeatureProcessorOptions& options);
 
-// Returns a modified version of the selection_with_context, such that only the
-// line that contains the clicked span is kept, the number of codepoints
-// the selection was moved by.
-std::pair<SelectionWithContext, int> ExtractLineWithClick(
-    const SelectionWithContext& selection_with_context);
+// Removes tokens that are not part of a line of the context which contains
+// given span.
+void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
+                               std::vector<Token>* tokens);
 
 // Splits tokens that contain the selection boundary inside them.
 // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
 void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
                                       std::vector<Token>* tokens);
 
+// Returns the index of token that corresponds to the codepoint span.
+int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
+
+// Returns the index of token that corresponds to the middle of the  codepoint
+// span.
+int CenterTokenFromMiddleOfSelection(
+    CodepointSpan span, const std::vector<Token>& selectable_tokens);
+
 }  // namespace internal
 
 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
                                    CodepointSpan codepoint_span);
 
-// Returns a modified version of the context string, such that only the
-// line that contains the span is kept. Also returns a codepoint shift
-// size that happend. If the span spans multiple lines, returns the original
-// input with zero shift.
-// The following characters are considered to be line separators: '\n', '|'
-std::pair<std::string, int> ExtractLineWithSpan(const std::string& context,
-                                                CodepointSpan span);
-
 // Takes care of preparing features for the FFModel.
 class FeatureProcessor {
  public:
@@ -92,9 +91,16 @@
   // Tokenizes the input string using the selected tokenization method.
   std::vector<Token> Tokenize(const std::string& utf8_text) const;
 
+  bool GetFeatures(const std::string& context, CodepointSpan input_span,
+                   std::vector<nlp_core::FeatureVector>* features,
+                   std::vector<float>* extra_features,
+                   std::vector<CodepointSpan>* selection_label_spans) const;
+
   // NOTE: If dropout is on, subsequent calls of this function with the same
   // arguments might return different results.
-  bool GetFeaturesAndLabels(const SelectionWithContext& selection_with_context,
+  bool GetFeaturesAndLabels(const std::string& context,
+                            CodepointSpan input_span, CodepointSpan label_span,
+                            const std::string& label_collection,
                             std::vector<nlp_core::FeatureVector>* features,
                             std::vector<float>* extra_features,
                             std::vector<CodepointSpan>* selection_label_spans,
@@ -106,7 +112,8 @@
   // NOTE: If dropout is on, subsequent calls of this function with the same
   // arguments might return different results.
   bool GetFeaturesAndLabels(
-      const SelectionWithContext& selection_with_context,
+      const std::string& context, CodepointSpan input_span,
+      CodepointSpan label_span, const std::string& label_collection,
       std::vector<std::vector<std::pair<int, float>>>* features,
       std::vector<float>* extra_features,
       std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
@@ -179,12 +186,17 @@
   // Converts a token span to the corresponding label.
   int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
 
-  // Find tokens that are part of the selection.
+  // Finds tokens that are part of the selection.
   // NOTE: Will select all tokens that somehow overlap with the selection.
   std::vector<Token> FindTokensInSelection(
       const std::vector<Token>& selectable_tokens,
       const SelectionWithContext& selection_with_context) const;
 
+  // Finds the center token index in tokens vector, using the method defined
+  // in options_.
+  int FindCenterToken(CodepointSpan span,
+                      const std::vector<Token>& tokens) const;
+
  private:
   FeatureProcessorOptions options_;
 
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
index 5fece36..557497e 100644
--- a/smartselect/text-classification-model.cc
+++ b/smartselect/text-classification-model.cc
@@ -51,8 +51,15 @@
 
     // If no tokenization codepoint config is present, tokenize on space.
     if (feature_processor_options.tokenization_codepoint_config_size() == 0) {
-      TokenizationCodepointRange* config =
-          feature_processor_options.add_tokenization_codepoint_config();
+      TokenizationCodepointRange* config;
+      // New line character.
+      config = feature_processor_options.add_tokenization_codepoint_config();
+      config->set_start(10);
+      config->set_end(11);
+      config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+
+      // Space character.
+      config = feature_processor_options.add_tokenization_codepoint_config();
       config->set_start(32);
       config->set_end(33);
       config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
@@ -186,23 +193,13 @@
 }
 
 EmbeddingNetwork::Vector TextClassificationModel::InferInternal(
-    const std::string& context, CodepointSpan click_indices,
-    CodepointSpan selection_indices, const FeatureProcessor& feature_processor,
-    const EmbeddingNetwork* network,
-    std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
-    CodepointSpan* selection_codepoint_label, int* classification_label) const {
-  SelectionWithContext selection_with_context;
-  selection_with_context.context = context;
-  selection_with_context.click_start = std::get<0>(click_indices);
-  selection_with_context.click_end = std::get<1>(click_indices);
-  selection_with_context.selection_start = std::get<0>(selection_indices);
-  selection_with_context.selection_end = std::get<1>(selection_indices);
-
+    const std::string& context, CodepointSpan span,
+    const FeatureProcessor& feature_processor, const EmbeddingNetwork* network,
+    std::vector<CodepointSpan>* selection_label_spans) const {
   std::vector<FeatureVector> features;
   std::vector<float> extra_features;
-  const bool features_computed = feature_processor.GetFeaturesAndLabels(
-      selection_with_context, &features, &extra_features, selection_label_spans,
-      selection_label, selection_codepoint_label, classification_label);
+  const bool features_computed = feature_processor.GetFeatures(
+      context, span, &features, &extra_features, selection_label_spans);
 
   EmbeddingNetwork::Vector scores;
   if (!features_computed) {
@@ -250,17 +247,10 @@
     return {click_indices, -1.0};
   }
 
-  // Invalid selection indices make the feature extraction use the provided
-  // click indices.
-  const CodepointSpan selection_indices({kInvalidIndex, kInvalidIndex});
-
   std::vector<CodepointSpan> selection_label_spans;
-  EmbeddingNetwork::Vector scores = InferInternal(
-      context, click_indices, selection_indices, *selection_feature_processor_,
-      selection_network_.get(), &selection_label_spans,
-      /*selection_label=*/nullptr,
-      /*selection_codepoint_label=*/nullptr,
-      /*classification_label=*/nullptr);
+  EmbeddingNetwork::Vector scores =
+      InferInternal(context, click_indices, *selection_feature_processor_,
+                    selection_network_.get(), &selection_label_spans);
 
   if (!scores.empty()) {
     scores = nlp_core::ComputeSoftmax(scores);
@@ -317,18 +307,12 @@
 // matter which word of it was tapped - all of them will lead to the same
 // selection.
 CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical(
-    const std::string& full_context, CodepointSpan click_indices) const {
-  // Extract context from the current line only.
-  std::string context;
-  int context_shift;
-  std::tie(context, context_shift) =
-      ExtractLineWithSpan(full_context, click_indices);
-  click_indices.first -= context_shift;
-  click_indices.second -= context_shift;
-
+    const std::string& context, CodepointSpan click_indices) const {
   std::vector<Token> tokens = selection_feature_processor_->Tokenize(context);
+  internal::StripTokensFromOtherLines(context, click_indices, &tokens);
 
-  const int click_index = GetClickTokenIndex(tokens, click_indices);
+  // const int click_index = GetClickTokenIndex(tokens, click_indices);
+  const int click_index = internal::CenterTokenFromClick(click_indices, tokens);
   if (click_index == kInvalidIndex) {
     return click_indices;
   }
@@ -337,7 +321,7 @@
 
   // Scan in the symmetry context for selection span proposals.
   std::vector<std::pair<CodepointSpan, float>> proposals;
-  for (int i = -symmetry_context_size; i < symmetry_context_size + 1; i++) {
+  for (int i = -symmetry_context_size; i < symmetry_context_size + 1; ++i) {
     const int token_index = click_index + i;
     if (token_index >= 0 && token_index < tokens.size()) {
       float score;
@@ -373,8 +357,7 @@
 
       if (feasible) {
         if (span.first <= click_index && span.second > click_index) {
-          return {span_result.first.first + context_shift,
-                  span_result.first.second + context_shift};
+          return {span_result.first.first, span_result.first.second};
         }
         for (int i = span.first; i < span.second; i++) {
           used_tokens[i] = 1;
@@ -383,8 +366,7 @@
     }
   }
 
-  return {click_indices.first + context_shift,
-          click_indices.second + context_shift};
+  return {click_indices.first, click_indices.second};
 }
 
 CodepointSpan TextClassificationModel::SuggestSelection(
@@ -409,13 +391,9 @@
     return {};
   }
 
-  // Invalid click indices make the feature extraction select the middle word in
-  // the selection span.
-  const CodepointSpan click_indices({kInvalidIndex, kInvalidIndex});
-
-  EmbeddingNetwork::Vector scores = InferInternal(
-      context, click_indices, selection_indices, *sharing_feature_processor_,
-      sharing_network_.get(), nullptr, nullptr, nullptr, nullptr);
+  EmbeddingNetwork::Vector scores =
+      InferInternal(context, selection_indices, *sharing_feature_processor_,
+                    sharing_network_.get(), nullptr);
   if (scores.empty()) {
     TC_LOG(ERROR) << "Using default class";
     return {};
diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h
index c30a4d3..458249a 100644
--- a/smartselect/text-classification-model.h
+++ b/smartselect/text-classification-model.h
@@ -137,13 +137,10 @@
   bool LoadModels(int fd);
 
   nlp_core::EmbeddingNetwork::Vector InferInternal(
-      const std::string& context, CodepointSpan click_indices,
-      CodepointSpan selection_indices,
+      const std::string& context, CodepointSpan span,
       const FeatureProcessor& feature_processor,
       const nlp_core::EmbeddingNetwork* network,
-      std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
-      CodepointSpan* selection_codepoint_label,
-      int* classification_label) const;
+      std::vector<CodepointSpan>* selection_label_spans) const;
 
   // Returns a selection suggestion with a score.
   std::pair<CodepointSpan, float> SuggestSelectionInternal(
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
index e098e81..b96ee2f 100644
--- a/smartselect/text-classification-model.proto
+++ b/smartselect/text-classification-model.proto
@@ -98,6 +98,22 @@
   // Codepoint ranges that determine how different codepoints are tokenized.
   // The ranges must not overlap.
   repeated TokenizationCodepointRange tokenization_codepoint_config = 15;
+
+  // Method for selecting the center token.
+  enum CenterTokenSelectionMethod {
+    DEFAULT_CENTER_TOKEN_METHOD = 0;  // Invalid option.
+
+    // Use click indices to determine the center token.
+    CENTER_TOKEN_FROM_CLICK = 1;
+
+    // Use selection indices to get a token range, and select the middle of it
+    // as the center token.
+    CENTER_TOKEN_MIDDLE_OF_SELECTION = 2;
+  }
+  optional CenterTokenSelectionMethod center_token_selection_method = 16;
+
+  // If true, during training will click random token in the example selection.
+  optional bool click_random_token_in_selection = 17 [default = true];
 };
 
 extend nlp_core.EmbeddingNetworkProto {
diff --git a/tests/feature-processor_test.cc b/tests/feature-processor_test.cc
index ecdad16..652db84 100644
--- a/tests/feature-processor_test.cc
+++ b/tests/feature-processor_test.cc
@@ -106,113 +106,95 @@
 }
 
 TEST(FeatureProcessorTest, KeepLineWithClickFirst) {
-  SelectionWithContext selection;
-  selection.context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+  const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+  const CodepointSpan span = {0, 5};
+  // clang-format off
+  std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+                               Token("Lině", 6, 10),
+                               Token("Sěcond", 11, 17),
+                               Token("Lině", 18, 22),
+                               Token("Thiřd", 23, 28),
+                               Token("Lině", 29, 33)};
+  // clang-format on
 
   // Keeps the first line.
-  selection.click_start = 0;
-  selection.click_end = 5;
-  selection.selection_start = 6;
-  selection.selection_end = 10;
-
-  SelectionWithContext line_selection;
-  int shift;
-  std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
-
-  EXPECT_EQ(line_selection.context, "Fiřst Lině");
-  EXPECT_EQ(line_selection.click_start, 0);
-  EXPECT_EQ(line_selection.click_end, 5);
-  EXPECT_EQ(line_selection.selection_start, 6);
-  EXPECT_EQ(line_selection.selection_end, 10);
-  EXPECT_EQ(shift, 0);
+  internal::StripTokensFromOtherLines(context, span, &tokens);
+  EXPECT_THAT(tokens,
+              ElementsAreArray({Token("Fiřst", 0, 5), Token("Lině", 6, 10)}));
 }
 
 TEST(FeatureProcessorTest, KeepLineWithClickSecond) {
-  SelectionWithContext selection;
-  selection.context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+  const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+  const CodepointSpan span = {18, 22};
+  // clang-format off
+  std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+                               Token("Lině", 6, 10),
+                               Token("Sěcond", 11, 17),
+                               Token("Lině", 18, 22),
+                               Token("Thiřd", 23, 28),
+                               Token("Lině", 29, 33)};
+  // clang-format on
 
-  // Keeps the second line.
-  selection.click_start = 11;
-  selection.click_end = 17;
-  selection.selection_start = 18;
-  selection.selection_end = 22;
-
-  SelectionWithContext line_selection;
-  int shift;
-  std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
-
-  EXPECT_EQ(line_selection.context, "Sěcond Lině");
-  EXPECT_EQ(line_selection.click_start, 0);
-  EXPECT_EQ(line_selection.click_end, 6);
-  EXPECT_EQ(line_selection.selection_start, 7);
-  EXPECT_EQ(line_selection.selection_end, 11);
-  EXPECT_EQ(shift, 11);
+  // Keeps the first line.
+  internal::StripTokensFromOtherLines(context, span, &tokens);
+  EXPECT_THAT(tokens, ElementsAreArray(
+                          {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
 }
 
 TEST(FeatureProcessorTest, KeepLineWithClickThird) {
-  SelectionWithContext selection;
-  selection.context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+  const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+  const CodepointSpan span = {24, 33};
+  // clang-format off
+  std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+                               Token("Lině", 6, 10),
+                               Token("Sěcond", 11, 17),
+                               Token("Lině", 18, 22),
+                               Token("Thiřd", 23, 28),
+                               Token("Lině", 29, 33)};
+  // clang-format on
 
-  // Keeps the third line.
-  selection.click_start = 29;
-  selection.click_end = 33;
-  selection.selection_start = 23;
-  selection.selection_end = 28;
-
-  SelectionWithContext line_selection;
-  int shift;
-  std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
-
-  EXPECT_EQ(line_selection.context, "Thiřd Lině");
-  EXPECT_EQ(line_selection.click_start, 6);
-  EXPECT_EQ(line_selection.click_end, 10);
-  EXPECT_EQ(line_selection.selection_start, 0);
-  EXPECT_EQ(line_selection.selection_end, 5);
-  EXPECT_EQ(shift, 23);
+  // Keeps the first line.
+  internal::StripTokensFromOtherLines(context, span, &tokens);
+  EXPECT_THAT(tokens, ElementsAreArray(
+                          {Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
 }
 
 TEST(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) {
-  SelectionWithContext selection;
-  selection.context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
+  const std::string context = "Fiřst Lině|Sěcond Lině\nThiřd Lině";
+  const CodepointSpan span = {18, 22};
+  // clang-format off
+  std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+                               Token("Lině", 6, 10),
+                               Token("Sěcond", 11, 17),
+                               Token("Lině", 18, 22),
+                               Token("Thiřd", 23, 28),
+                               Token("Lině", 29, 33)};
+  // clang-format on
 
-  // Keeps the second line.
-  selection.click_start = 11;
-  selection.click_end = 17;
-  selection.selection_start = 18;
-  selection.selection_end = 22;
-
-  SelectionWithContext line_selection;
-  int shift;
-  std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
-
-  EXPECT_EQ(line_selection.context, "Sěcond Lině");
-  EXPECT_EQ(line_selection.click_start, 0);
-  EXPECT_EQ(line_selection.click_end, 6);
-  EXPECT_EQ(line_selection.selection_start, 7);
-  EXPECT_EQ(line_selection.selection_end, 11);
-  EXPECT_EQ(shift, 11);
+  // Keeps the first line.
+  internal::StripTokensFromOtherLines(context, span, &tokens);
+  EXPECT_THAT(tokens, ElementsAreArray(
+                          {Token("Sěcond", 11, 17), Token("Lině", 18, 22)}));
 }
 
 TEST(FeatureProcessorTest, KeepLineWithCrosslineClick) {
-  SelectionWithContext selection;
-  selection.context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+  const std::string context = "Fiřst Lině\nSěcond Lině\nThiřd Lině";
+  const CodepointSpan span = {5, 23};
+  // clang-format off
+  std::vector<Token> tokens = {Token("Fiřst", 0, 5),
+                               Token("Lině", 6, 10),
+                               Token("Sěcond", 18, 23),
+                               Token("Lině", 19, 23),
+                               Token("Thiřd", 23, 28),
+                               Token("Lině", 29, 33)};
+  // clang-format on
 
-  // Selects across lines, so KeepLine should not do any changes.
-  selection.click_start = 6;
-  selection.click_end = 17;
-  selection.selection_start = 0;
-  selection.selection_end = 22;
-
-  SelectionWithContext line_selection;
-  int shift;
-  std::tie(line_selection, shift) = internal::ExtractLineWithClick(selection);
-
-  EXPECT_EQ(line_selection.context, "Fiřst Lině\nSěcond Lině\nThiřd Lině");
-  EXPECT_EQ(line_selection.click_start, 6);
-  EXPECT_EQ(line_selection.click_end, 17);
-  EXPECT_EQ(line_selection.selection_start, 0);
-  EXPECT_EQ(line_selection.selection_end, 22);
-  EXPECT_EQ(shift, 0);
+  // Keeps the first line.
+  internal::StripTokensFromOtherLines(context, span, &tokens);
+  EXPECT_THAT(tokens, ElementsAreArray(
+                          {Token("Fiřst", 0, 5), Token("Lině", 6, 10),
+                           Token("Sěcond", 18, 23), Token("Lině", 19, 23),
+                           Token("Thiřd", 23, 28), Token("Lině", 29, 33)}));
 }
 
 TEST(FeatureProcessorTest, GetFeaturesWithContextDropout) {
@@ -231,14 +213,6 @@
   config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
   FeatureProcessor feature_processor(options);
 
-  SelectionWithContext selection_with_context;
-  selection_with_context.context = "1 2 3 c o n t e x t X c o n t e x t 1 2 3";
-  // Selection and click indices of the X in the middle:
-  selection_with_context.selection_start = 20;
-  selection_with_context.selection_end = 21;
-  selection_with_context.click_start = 20;
-  selection_with_context.click_end = 21;
-
   // Test that two subsequent runs with random context dropout produce
   // different features.
   feature_processor.SetRandom(new std::mt19937);
@@ -251,13 +225,13 @@
   CodepointSpan selection_codepoint_label;
   int classification_label;
   EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
-      selection_with_context, &features, &extra_features,
-      &selection_label_spans, &selection_label, &selection_codepoint_label,
-      &classification_label));
+      "1 2 3 c o n t e x t X c o n t e x t 1 2 3", {20, 21}, {20, 21}, "",
+      &features, &extra_features, &selection_label_spans, &selection_label,
+      &selection_codepoint_label, &classification_label));
   EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
-      selection_with_context, &features2, &extra_features,
-      &selection_label_spans, &selection_label, &selection_codepoint_label,
-      &classification_label));
+      "1 2 3 c o n t e x t X c o n t e x t 1 2 3", {20, 21}, {20, 21}, "",
+      &features2, &extra_features, &selection_label_spans, &selection_label,
+      &selection_codepoint_label, &classification_label));
 
   EXPECT_NE(features, features2);
 }
@@ -276,14 +250,6 @@
   config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
   FeatureProcessor feature_processor(options);
 
-  SelectionWithContext selection_with_context;
-  selection_with_context.context = "1 2 3 c o n t e x t X c o n t e x t 1 2 3";
-  // Selection and click indices of the X in the middle:
-  selection_with_context.selection_start = 20;
-  selection_with_context.selection_end = 21;
-  selection_with_context.click_start = 20;
-  selection_with_context.click_end = 21;
-
   std::vector<std::vector<std::pair<int, float>>> features;
   std::vector<float> extra_features;
   std::vector<CodepointSpan> selection_label_spans;
@@ -291,19 +257,14 @@
   CodepointSpan selection_codepoint_label;
   int classification_label;
   EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
-      selection_with_context, &features, &extra_features,
-      &selection_label_spans, &selection_label, &selection_codepoint_label,
-      &classification_label));
+      "1 2 3 c o n t e x t X c o n t e x t 1 2 3", {20, 21}, {20, 21}, "",
+      &features, &extra_features, &selection_label_spans, &selection_label,
+      &selection_codepoint_label, &classification_label));
   EXPECT_EQ(19, features.size());
 
   // Should pad the string.
-  selection_with_context.context = "X";
-  selection_with_context.selection_start = 0;
-  selection_with_context.selection_end = 1;
-  selection_with_context.click_start = 0;
-  selection_with_context.click_end = 1;
   EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
-      selection_with_context, &features, &extra_features,
+      "X", {0, 1}, {0, 1}, "", &features, &extra_features,
       &selection_label_spans, &selection_label, &selection_codepoint_label,
       &classification_label));
   EXPECT_EQ(19, features.size());
@@ -432,5 +393,67 @@
       }));
 }
 
+TEST(FeatureProcessorTest, CenterTokenFromClick) {
+  int token_index;
+
+  // Exactly aligned indices.
+  token_index = internal::CenterTokenFromClick(
+      {6, 11}, {Token("Hělló", 0, 5, false), Token("world", 6, 11, false),
+                Token("heře!", 12, 17, false)});
+  EXPECT_EQ(token_index, 1);
+
+  // Click is contained in a token.
+  token_index = internal::CenterTokenFromClick(
+      {13, 17}, {Token("Hělló", 0, 5, false), Token("world", 6, 11, false),
+                 Token("heře!", 12, 17, false)});
+  EXPECT_EQ(token_index, 2);
+
+  // Click spans two tokens.
+  token_index = internal::CenterTokenFromClick(
+      {6, 17}, {Token("Hělló", 0, 5, false), Token("world", 6, 11, false),
+                Token("heře!", 12, 17, false)});
+  EXPECT_EQ(token_index, kInvalidIndex);
+}
+
+TEST(FeatureProcessorTest, CenterTokenFromMiddleOfSelection) {
+  SelectionWithContext selection;
+  int token_index;
+
+  // Selection of length 3. Exactly aligned indices.
+  token_index = internal::CenterTokenFromMiddleOfSelection(
+      {7, 27}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
+                Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
+                Token("Token5", 28, 34, false)});
+  EXPECT_EQ(token_index, 2);
+
+  // Selection of length 1 token. Exactly aligned indices.
+  token_index = internal::CenterTokenFromMiddleOfSelection(
+      {21, 27}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
+                 Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
+                 Token("Token5", 28, 34, false)});
+  EXPECT_EQ(token_index, 3);
+
+  // Selection marks sub-token range, with no tokens in it.
+  token_index = internal::CenterTokenFromMiddleOfSelection(
+      {29, 33}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
+                 Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
+                 Token("Token5", 28, 34, false)});
+  EXPECT_EQ(token_index, kInvalidIndex);
+
+  // Selection of length 2. Sub-token indices.
+  token_index = internal::CenterTokenFromMiddleOfSelection(
+      {3, 25}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
+                Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
+                Token("Token5", 28, 34, false)});
+  EXPECT_EQ(token_index, 1);
+
+  // Selection of length 1. Sub-token indices.
+  token_index = internal::CenterTokenFromMiddleOfSelection(
+      {22, 34}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
+                 Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
+                 Token("Token5", 28, 34, false)});
+  EXPECT_EQ(token_index, 4);
+}
+
 }  // namespace
 }  // namespace libtextclassifier
diff --git a/tests/lang-id_test.cc b/tests/lang-id_test.cc
index 39aed63..faf78c6 100644
--- a/tests/lang-id_test.cc
+++ b/tests/lang-id_test.cc
@@ -30,7 +30,7 @@
 namespace {
 
 std::string GetModelPath() {
-  return "tests/testdata/langid.model";
+  return TEST_DATA_DIR "langid.model";
 }
 
 // Creates a LangId with default model.  Passes ownership to
diff --git a/tests/testdata/smartselection.model b/tests/testdata/smartselection.model
index 8972a2e..2b96f42 100644
--- a/tests/testdata/smartselection.model
+++ b/tests/testdata/smartselection.model
Binary files differ
diff --git a/tests/text-classification-model_test.cc b/tests/text-classification-model_test.cc
index a4545cd..d588fc7 100644
--- a/tests/text-classification-model_test.cc
+++ b/tests/text-classification-model_test.cc
@@ -28,28 +28,30 @@
 namespace {
 
 std::string GetModelPath() {
-  return "tests/testdata/smartselection.model";
+  return TEST_DATA_DIR "smartselection.model";
 }
 
 TEST(TextClassificationModelTest, SuggestSelection) {
   const std::string model_path = GetModelPath();
   int fd = open(model_path.c_str(), O_RDONLY);
-  std::unique_ptr<TextClassificationModel> ff_model(
+  std::unique_ptr<TextClassificationModel> model(
       new TextClassificationModel(fd));
   close(fd);
 
-  std::tuple<int, int> selection;
-  selection = ff_model->SuggestSelection(
-      "this afternoon Barack Obama gave a speech at", {15, 21});
-  EXPECT_EQ(15, std::get<0>(selection));
-  EXPECT_EQ(27, std::get<1>(selection));
+  EXPECT_EQ(model->SuggestSelection(
+                "this afternoon Barack Obama gave a speech at", {15, 21}),
+            std::make_pair(15, 27));
 
   // Try passing whole string.
-  selection =
-      ff_model->SuggestSelection("350 Third Street, Cambridge", {0, 27});
   // If more than 1 token is specified, we should return back what entered.
-  EXPECT_EQ(0, std::get<0>(selection));
-  EXPECT_EQ(27, std::get<1>(selection));
+  EXPECT_EQ(model->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
+            std::make_pair(0, 27));
+
+  // Single letter.
+  EXPECT_EQ(std::make_pair(0, 1), model->SuggestSelection("a", {0, 1}));
+
+  // Single word.
+  EXPECT_EQ(std::make_pair(0, 4), model->SuggestSelection("asdf", {0, 4}));
 }
 
 TEST(TextClassificationModelTest, SuggestSelectionsAreSymmetric) {
@@ -183,6 +185,10 @@
 namespace {
 
 std::string FindBestResult(std::vector<std::pair<std::string, float>> results) {
+  if (results.empty()) {
+    return "<INVALID RESULTS>";
+  }
+
   std::sort(results.begin(), results.end(),
             [](const std::pair<std::string, float> a,
                const std::pair<std::string, float> b) {
@@ -211,6 +217,29 @@
                          "Call me at (800) 123-456 today", {11, 24})));
   EXPECT_EQ("url", FindBestResult(model->ClassifyText(
                        "Visit www.google.com every today!", {6, 20})));
+
+  // More lines.
+  EXPECT_EQ("other",
+            FindBestResult(model->ClassifyText(
+                "this afternoon Barack Obama gave a speech at|Visit "
+                "www.google.com every today!|Call me at (800) 123-456 today.",
+                {15, 27})));
+  EXPECT_EQ("url",
+            FindBestResult(model->ClassifyText(
+                "this afternoon Barack Obama gave a speech at|Visit "
+                "www.google.com every today!|Call me at (800) 123-456 today.",
+                {51, 65})));
+  EXPECT_EQ("phone",
+            FindBestResult(model->ClassifyText(
+                "this afternoon Barack Obama gave a speech at|Visit "
+                "www.google.com every today!|Call me at (800) 123-456 today.",
+                {90, 103})));
+
+  // Single word.
+  EXPECT_EQ("other", FindBestResult(model->ClassifyText("Obama", {0, 5})));
+  EXPECT_EQ("other", FindBestResult(model->ClassifyText("asdf", {0, 4})));
+  EXPECT_EQ("<INVALID RESULTS>",
+            FindBestResult(model->ClassifyText("asdf", {0, 0})));
 }
 
 }  // namespace