Initial checkin of libtextclassifier.

Test: libtextclassifier_tests pass in google3 repository.

Change-Id: I2aac5b12b8810c2ec1c1aa27238cb8aa8430f403
diff --git a/smartselect/feature-processor.cc b/smartselect/feature-processor.cc
new file mode 100644
index 0000000..2e04920
--- /dev/null
+++ b/smartselect/feature-processor.cc
@@ -0,0 +1,644 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/feature-processor.h"
+
+#include <iterator>
+#include <set>
+#include <vector>
+
+#include "smartselect/text-classification-model.pb.h"
+#include "util/base/logging.h"
+#include "util/strings/utf8.h"
+#include "util/utf8/unicodetext.h"
+
+namespace libtextclassifier {
+
+// Maximum word length: all words are meant to be trimmed to this length.
+// NOTE(review): not referenced anywhere in feature-processor.cc — confirm
+// whether it is still needed.
+constexpr int kMaxWordLength = 20;
+
+namespace internal {
+
+// Copies the token-feature settings of `options` (bucket count, chargram
+// orders, case and selection-mask feature flags) into a fresh
+// TokenFeatureExtractorOptions.
+TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
+    const FeatureProcessorOptions& options) {
+  TokenFeatureExtractorOptions result;
+  result.num_buckets = options.num_buckets();
+  // Chargram orders are appended in the order they appear in the proto.
+  for (const int chargram_order : options.chargram_orders()) {
+    result.chargram_orders.push_back(chargram_order);
+  }
+  result.extract_case_feature = options.extract_case_feature();
+  result.extract_selection_mask_feature =
+      options.extract_selection_mask_feature();
+  return result;
+}
+
+// Parses `serialized_options` as a FeatureProcessorOptions proto and returns
+// it. If parsing fails, an error is logged and the (possibly partially
+// populated) message is still returned, since callers such as the
+// FeatureProcessor(const std::string&) constructor have no error channel.
+FeatureProcessorOptions ParseSerializedOptions(
+    const std::string& serialized_options) {
+  FeatureProcessorOptions options;
+  if (!options.ParseFromString(serialized_options)) {
+    // Previously ignored; make malformed model options visible in the logs.
+    TC_LOG(ERROR) << "Could not parse serialized FeatureProcessorOptions.";
+  }
+  return options;
+}
+
+// Splits every token that has a selection boundary strictly inside it, so that
+// no token straddles the selection start or end. E.g. with the selection
+// covering "bar", the token "foo{bar}@google.com" becomes "foo", "bar",
+// "@google.com". New tokens get codepoint offsets derived from the original
+// token's start and is_in_span=false; tokens without an inner boundary are
+// left unchanged.
+void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
+                                      std::vector<Token>* tokens) {
+  for (auto it = tokens->begin(); it != tokens->end(); ++it) {
+    const UnicodeText token_word =
+        UTF8ToUnicodeText(it->value, /*do_copy=*/false);
+
+    auto last_start = token_word.begin();
+    int last_start_index = it->start;
+    std::vector<UnicodeText::const_iterator> split_points;
+
+    // Selection start split point.
+    if (selection.first > it->start && selection.first < it->end) {
+      std::advance(last_start, selection.first - last_start_index);
+      split_points.push_back(last_start);
+      last_start_index = selection.first;
+    }
+
+    // Selection end split point. Requiring it to lie strictly after the
+    // previous split point keeps the split points strictly increasing;
+    // otherwise an empty selection (first == second) inside the token would
+    // produce a duplicate split point (and thus an empty token), and a
+    // reversed selection a backwards split.
+    if (selection.second > last_start_index && selection.second < it->end) {
+      std::advance(last_start, selection.second - last_start_index);
+      split_points.push_back(last_start);
+    }
+
+    if (!split_points.empty()) {
+      // Add a final split for the rest of the token unless it's been all
+      // consumed already.
+      if (split_points.back() != token_word.end()) {
+        split_points.push_back(token_word.end());
+      }
+
+      std::vector<Token> replacement_tokens;
+      auto piece_start = token_word.begin();
+      int current_pos = it->start;
+      for (const auto& split_point : split_points) {
+        Token new_token(token_word.UTF8Substring(piece_start, split_point),
+                        current_pos,
+                        current_pos + std::distance(piece_start, split_point),
+                        /*is_in_span=*/false);
+
+        piece_start = split_point;
+        current_pos = new_token.end;
+
+        replacement_tokens.push_back(new_token);
+      }
+
+      // Replace the token in place and leave `it` on the last inserted piece,
+      // so the loop's ++it continues with the next original token.
+      it = tokens->erase(it);
+      it = tokens->insert(it, replacement_tokens.begin(),
+                          replacement_tokens.end());
+      std::advance(it, replacement_tokens.size() - 1);
+    }
+  }
+}
+
+// Splits `t` at every codepoint contained in `codepoints` and appends the
+// non-empty pieces (delimiters excluded) to `ranges` as [begin, end) iterator
+// pairs, in text order. Existing contents of `ranges` are kept.
+void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
+                    std::vector<UnicodeTextRange>* ranges) {
+  const UnicodeText::const_iterator text_end = t.end();
+  UnicodeText::const_iterator segment_begin = t.begin();
+  for (UnicodeText::const_iterator pos = t.begin(); pos != text_end; ++pos) {
+    const bool is_delimiter = codepoints.count(*pos) > 0;
+    if (!is_delimiter) {
+      continue;
+    }
+    // Emit the segment before this delimiter unless it is empty.
+    if (segment_begin != pos) {
+      ranges->push_back(std::make_pair(segment_begin, pos));
+    }
+    // The next segment starts right after the delimiter.
+    segment_begin = pos;
+    ++segment_begin;
+  }
+  // Trailing segment after the last delimiter (or the whole text).
+  if (segment_begin != text_end) {
+    ranges->push_back(std::make_pair(segment_begin, text_end));
+  }
+}
+
+// Returns a copy of `selection_with_context` whose context is reduced (via
+// ExtractLineWithSpan) to the line containing the click span, with the
+// selection and click indices rebased onto that line, together with the
+// codepoint shift that was subtracted. Indices equal to kInvalidIndex are left
+// untouched, so "no selection"/"no click" sentinels survive the rebasing.
+std::pair<SelectionWithContext, int> ExtractLineWithClick(
+    const SelectionWithContext& selection_with_context) {
+  std::string new_context;
+  int shift;
+  std::tie(new_context, shift) = ExtractLineWithSpan(
+      selection_with_context.context,
+      {selection_with_context.click_start, selection_with_context.click_end});
+
+  // Shifting kInvalidIndex would turn the sentinel into an arbitrary negative
+  // index that no longer compares equal to kInvalidIndex downstream.
+  SelectionWithContext result(selection_with_context);
+  if (result.selection_start != kInvalidIndex) result.selection_start -= shift;
+  if (result.selection_end != kInvalidIndex) result.selection_end -= shift;
+  if (result.click_start != kInvalidIndex) result.click_start -= shift;
+  if (result.click_end != kInvalidIndex) result.click_end -= shift;
+  result.context = new_context;
+  return {result, shift};
+}
+
+}  // namespace internal
+
+// Returns the line of `context` that fully contains `span`, together with the
+// codepoint offset of that line's first codepoint within `context`. Lines are
+// separated by '\n' or '|' (separators belong to no line). If no single line
+// contains the whole span, or the span reaches beyond the end of the text,
+// returns {context, 0}. Negative span bounds are treated as 0.
+std::pair<std::string, int> ExtractLineWithSpan(const std::string& context,
+                                                CodepointSpan span) {
+  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+                                                        /*do_copy=*/false);
+  std::vector<UnicodeTextRange> lines;
+  std::set<char32> codepoints;
+  codepoints.insert('\n');
+  codepoints.insert('|');
+  internal::FindSubstrings(context_unicode, codepoints, &lines);
+
+  // Moves `*it` forward by `count` codepoints (no-op for count <= 0). Returns
+  // false instead of stepping past the end of the text: std::advance beyond
+  // end() is undefined behavior for out-of-range spans.
+  const UnicodeText::const_iterator text_end = context_unicode.end();
+  const auto advance_within_text =
+      [&text_end](UnicodeText::const_iterator* it, int count) -> bool {
+    for (; count > 0; --count) {
+      if (*it == text_end) {
+        return false;
+      }
+      ++(*it);
+    }
+    return true;
+  };
+
+  auto span_start = context_unicode.begin();
+  auto span_end = context_unicode.begin();
+  if (!advance_within_text(&span_start, span.first) ||
+      !advance_within_text(&span_end, span.second)) {
+    return {context, 0};
+  }
+
+  for (const UnicodeTextRange& line : lines) {
+    // Find the line that completely contains the span.
+    if (line.first <= span_start && line.second >= span_start &&
+        line.first <= span_end && line.second >= span_end) {
+      const CodepointIndex last_line_begin_index =
+          std::distance(context_unicode.begin(), line.first);
+
+      std::string result =
+          context_unicode.UTF8Substring(line.first, line.second);
+      return {result, last_line_begin_index};
+    }
+  }
+  return {context, 0};
+}
+
+// Name of the numeric feature type (feature_type_) under which ComputeFeatures
+// adds the sparse token features to each FeatureVector.
+const char* const FeatureProcessor::kFeatureTypeName = "chargram_continuous";
+
+// Splits `utf8_text` into tokens with the tokenizer configured from the
+// options' tokenization_codepoint_config in the constructor.
+std::vector<Token> FeatureProcessor::Tokenize(
+    const std::string& utf8_text) const {
+  std::vector<Token> tokens = tokenizer_.Tokenize(utf8_text);
+  return tokens;
+}
+
+// Converts selection `label` into the codepoint span it denotes within the
+// context window `tokens`. Returns false if `tokens` is not a full window of
+// GetNumContextTokens() tokens or the label is unknown. If either boundary
+// token is a padding token (kInvalidIndex offsets), the span is
+// {kInvalidIndex, kInvalidIndex}.
+bool FeatureProcessor::LabelToSpan(
+    const int label, const std::vector<Token>& tokens,
+    std::pair<CodepointIndex, CodepointIndex>* span) const {
+  if (tokens.size() != GetNumContextTokens()) {
+    return false;
+  }
+  TokenSpan token_span;
+  if (!LabelToTokenSpan(label, &token_span)) {
+    return false;
+  }
+
+  // The clicked token sits in the middle of the window at index
+  // context_size(); the token span counts tokens to its left and right.
+  const int click_index = options_.context_size();
+  const int begin_codepoint = tokens[click_index - token_span.first].start;
+  const int end_codepoint = tokens[click_index + token_span.second].end;
+
+  const bool touches_padding =
+      begin_codepoint == kInvalidIndex || end_codepoint == kInvalidIndex;
+  *span = touches_padding ? CodepointSpan({kInvalidIndex, kInvalidIndex})
+                          : CodepointSpan({begin_codepoint, end_codepoint});
+  return true;
+}
+
+// Looks up the token span for selection `label`. Returns false (leaving
+// `token_span` untouched) if the label is outside the table built by
+// MakeLabelMaps().
+bool FeatureProcessor::LabelToTokenSpan(const int label,
+                                        TokenSpan* token_span) const {
+  if (label < 0 || static_cast<size_t>(label) >= label_to_selection_.size()) {
+    return false;
+  }
+  *token_span = label_to_selection_[label];
+  return true;
+}
+
+// Computes the selection label of `span` relative to the context window
+// `tokens` (GetNumContextTokens() entries, click token in the middle). The
+// label encodes how many tokens left/right of the click the span covers. If
+// the span does not coincide with token boundaries, or that token span has no
+// label, *label is set to kInvalidLabel. Returns false only if `tokens` has
+// the wrong size.
+bool FeatureProcessor::SpanToLabel(
+    const std::pair<CodepointIndex, CodepointIndex>& span,
+    const std::vector<Token>& tokens, int* label) const {
+  if (tokens.size() != GetNumContextTokens()) {
+    return false;
+  }
+
+  const int click_position =
+      options_.context_size();  // Click is always in the middle.
+  // Tokens within `padding` of either window edge are never counted into the
+  // span. Clamped at 0: if max_selection_span > context_size the raw value is
+  // negative, which would index tokens[-1] on the left and, once mixed with
+  // the unsigned tokens.size(), wrap around on the right.
+  const int unclamped_padding =
+      options_.context_size() - options_.max_selection_span();
+  const int padding = unclamped_padding > 0 ? unclamped_padding : 0;
+  const int num_tokens = static_cast<int>(tokens.size());
+
+  int span_left = 0;
+  for (int i = click_position - 1; i >= padding; i--) {
+    if (tokens[i].start != kInvalidIndex && tokens[i].start >= span.first) {
+      ++span_left;
+    } else {
+      break;
+    }
+  }
+
+  int span_right = 0;
+  for (int i = click_position + 1; i < num_tokens - padding; ++i) {
+    if (tokens[i].end != kInvalidIndex && tokens[i].end <= span.second) {
+      ++span_right;
+    } else {
+      break;
+    }
+  }
+
+  // Check that the spanned tokens cover the whole span.
+  if (tokens[click_position - span_left].start == span.first &&
+      tokens[click_position + span_right].end == span.second) {
+    *label = TokenSpanToLabel({span_left, span_right});
+  } else {
+    *label = kInvalidLabel;
+  }
+
+  return true;
+}
+
+// Returns the label registered for the (left, right) token span `span` by
+// MakeLabelMaps(), or kInvalidLabel if there is none.
+int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
+  const auto found = selection_to_label_.find(span);
+  return found == selection_to_label_.end() ? kInvalidLabel : found->second;
+}
+
+// Converts a codepoint span to a token span in the given list of tokens. The
+// result is [index of the first token lying completely inside the span,
+// index one past the last such token), or {kInvalidIndex, kInvalidIndex} if
+// no token lies completely inside the span.
+TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
+                                   CodepointSpan codepoint_span) {
+  const int span_begin = codepoint_span.first;
+  const int span_end = codepoint_span.second;
+
+  TokenIndex first_inside = kInvalidIndex;
+  TokenIndex last_inside = kInvalidIndex;
+  for (int index = 0; index < selectable_tokens.size(); ++index) {
+    const Token& token = selectable_tokens[index];
+    const bool inside = token.start >= span_begin && token.end <= span_end;
+    if (!inside) {
+      continue;
+    }
+    if (first_inside == kInvalidIndex) {
+      first_inside = index;
+    }
+    last_inside = index;
+  }
+
+  if (last_inside == kInvalidIndex) {
+    return {kInvalidIndex, kInvalidIndex};
+  }
+  return {first_inside, last_inside + 1};
+}
+
+namespace {
+
+// Returns the index of the first token in `selectable_tokens` whose codepoint
+// range contains the whole of `codepoint_span`, or kInvalidIndex if none does.
+int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
+                              CodepointSpan codepoint_span) {
+  const int span_begin = codepoint_span.first;
+  const int span_end = codepoint_span.second;
+
+  int index = 0;
+  for (const Token& candidate : selectable_tokens) {
+    if (candidate.start <= span_begin && span_end <= candidate.end) {
+      return index;
+    }
+    ++index;
+  }
+  return kInvalidIndex;
+}
+
+// Returns, in order, every token of `selectable_tokens` that overlaps the
+// selection of `selection_with_context`: the token(s) containing the
+// selection start or end, plus every token lying entirely within the
+// selection.
+std::vector<Token> FindTokensInSelection(
+    const std::vector<Token>& selectable_tokens,
+    const SelectionWithContext& selection_with_context) {
+  std::vector<Token> tokens_in_selection;
+  for (const Token& token : selectable_tokens) {
+    const bool selection_start_in_token =
+        token.start <= selection_with_context.selection_start &&
+        token.end > selection_with_context.selection_start;
+
+    const bool selection_end_in_token =
+        token.start < selection_with_context.selection_end &&
+        token.end >= selection_with_context.selection_end;
+
+    // Middle tokens (e.g. "b" when "a b c" is selected) contain neither
+    // boundary, so they need this separate check to be reported.
+    const bool token_inside_selection =
+        token.start >= selection_with_context.selection_start &&
+        token.end <= selection_with_context.selection_end;
+
+    if (selection_start_in_token || selection_end_in_token ||
+        token_inside_selection) {
+      tokens_in_selection.push_back(token);
+    }
+  }
+  return tokens_in_selection;
+}
+
+// Returns the index into `selectable_tokens` of the clicked token.
+// - If a click span is given, it must map to exactly one token: the tokens
+//   lying completely inside the click span or, failing that, a single token
+//   that contains the whole click span.
+// - Otherwise, if a selection span is given, the middle token of the tokens
+//   lying completely inside the selection is used (the left of the two middle
+//   tokens for an even count).
+// Returns kInvalidIndex if no click position can be determined.
+int GetClickPosition(const SelectionWithContext& selection_with_context,
+                     const std::vector<Token>& selectable_tokens) {
+  const auto no_span = std::make_pair(kInvalidIndex, kInvalidIndex);
+
+  const auto click_span = selection_with_context.GetClickSpan();
+  if (click_span != no_span) {
+    TokenSpan click_tokens =
+        CodepointSpanToTokenSpan(selectable_tokens, click_span);
+
+    // No exact match: accept a single token that completely contains the
+    // click span. This helps e.g. when Android builds the selection with ICU
+    // tokenization and ends up with only a portion of our space-separated
+    // token, such as "857" inside "(857)".
+    if (click_tokens.first == kInvalidIndex ||
+        click_tokens.second == kInvalidIndex) {
+      const int containing_token =
+          FindTokenThatContainsSpan(selectable_tokens, click_span);
+      if (containing_token != kInvalidIndex) {
+        click_tokens = TokenSpan(containing_token, containing_token + 1);
+      }
+    }
+
+    // Only clicks covering exactly one selectable token are accepted.
+    const bool single_token = click_tokens.second - click_tokens.first == 1;
+    return single_token ? click_tokens.first : kInvalidIndex;
+  }
+
+  const auto selection_span = selection_with_context.GetSelectionSpan();
+  if (selection_span == no_span) {
+    return kInvalidIndex;
+  }
+
+  const TokenSpan selected_tokens =
+      CodepointSpanToTokenSpan(selectable_tokens, selection_span);
+  if (selected_tokens.first == kInvalidIndex ||
+      selected_tokens.second == kInvalidIndex) {
+    return kInvalidIndex;
+  }
+  // Center the clicked token in the selection range.
+  return (selected_tokens.first + selected_tokens.second - 1) / 2;
+}
+
+}  // namespace
+
+// Tokenizes the context and picks, uniformly at random (drawing once from
+// random_), one of the tokens returned by FindTokensInSelection. Returns that
+// token's codepoint span, or {kInvalidIndex, kInvalidIndex} if there is none.
+CodepointSpan FeatureProcessor::ClickRandomTokenInSelection(
+    const SelectionWithContext& selection_with_context) const {
+  const std::vector<Token> tokens = Tokenize(selection_with_context.context);
+  const std::vector<Token> candidates =
+      FindTokensInSelection(tokens, selection_with_context);
+  if (candidates.empty()) {
+    return {kInvalidIndex, kInvalidIndex};
+  }
+
+  std::uniform_int_distribution<> pick_index(0, candidates.size() - 1);
+  const Token& clicked = candidates[pick_index(*random_)];
+  return {clicked.start, clicked.end};
+}
+
+// Computes features for the (context_size() * 2 + 1)-token window centered on
+// the clicked token (or on the middle of the selection if no click is given)
+// and fills every non-null label output:
+//  - selection_label: label of the selection span within the window;
+//  - selection_codepoint_label: the selection span, in original coordinates;
+//  - selection_label_spans: codepoint span (original coordinates) of every
+//    selection label, appended in label order;
+//  - classification_label: label of the selection's collection.
+// `features` and `extra_features` are required. Returns false on failure.
+// NOTE: If dropout is on, repeated calls may return different results.
+bool FeatureProcessor::GetFeaturesAndLabels(
+    const SelectionWithContext& selection_with_context,
+    std::vector<nlp_core::FeatureVector>* features,
+    std::vector<float>* extra_features,
+    std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
+    CodepointSpan* selection_codepoint_label, int* classification_label) const {
+  // ComputeFeatures() writes through both pointers unconditionally.
+  if (features == nullptr || extra_features == nullptr) {
+    return false;
+  }
+  *features =
+      std::vector<nlp_core::FeatureVector>(options_.context_size() * 2 + 1);
+
+  SelectionWithContext selection_normalized;
+  int normalization_shift;
+  if (options_.only_use_line_with_click()) {
+    std::tie(selection_normalized, normalization_shift) =
+        internal::ExtractLineWithClick(selection_with_context);
+  } else {
+    selection_normalized = selection_with_context;
+    normalization_shift = 0;
+  }
+
+  std::vector<Token> input_tokens = Tokenize(selection_normalized.context);
+  if (options_.split_tokens_on_selection_boundaries()) {
+    // The tokens come from the (possibly line-reduced) normalized context, so
+    // the split must use the normalized selection too; the original selection
+    // is off by normalization_shift codepoints.
+    internal::SplitTokensOnSelectionBoundaries(
+        selection_normalized.GetSelectionSpan(), &input_tokens);
+  }
+  int click_pos = GetClickPosition(selection_normalized, input_tokens);
+  if (click_pos == kInvalidIndex) {
+    TC_LOG(ERROR) << "Could not extract click position.";
+    return false;
+  }
+
+  std::vector<Token> output_tokens;
+  bool status = ComputeFeatures(click_pos, input_tokens,
+                                selection_normalized.GetSelectionSpan(),
+                                features, extra_features, &output_tokens);
+  if (!status) {
+    TC_LOG(ERROR) << "Feature computation failed.";
+    return false;
+  }
+
+  if (selection_label != nullptr) {
+    status = SpanToLabel(selection_normalized.GetSelectionSpan(), output_tokens,
+                         selection_label);
+    if (!status) {
+      TC_LOG(ERROR) << "Could not convert selection span to label.";
+      return false;
+    }
+  }
+
+  if (selection_codepoint_label != nullptr) {
+    // Reported in the coordinates of the original (un-normalized) context.
+    *selection_codepoint_label = selection_with_context.GetSelectionSpan();
+  }
+
+  if (selection_label_spans != nullptr) {
+    // If an input normalization was performed, we need to shift the tokens
+    // back to get the correct ranges in the original input.
+    if (normalization_shift) {
+      for (Token& token : output_tokens) {
+        if (token.start != kInvalidIndex && token.end != kInvalidIndex) {
+          token.start += normalization_shift;
+          token.end += normalization_shift;
+        }
+      }
+    }
+    for (int i = 0; i < label_to_selection_.size(); ++i) {
+      CodepointSpan span;
+      status = LabelToSpan(i, output_tokens, &span);
+      if (!status) {
+        TC_LOG(ERROR) << "Could not convert label to span: " << i;
+        return false;
+      }
+      selection_label_spans->push_back(span);
+    }
+  }
+
+  if (classification_label != nullptr) {
+    *classification_label = CollectionToLabel(selection_normalized.collection);
+  }
+
+  return true;
+}
+
+// Overload of GetFeaturesAndLabels() that returns the features as plain
+// (feature id, weight) pairs, one row per window token, instead of
+// nlp_core::FeatureVector. Both `features` and `extra_features` are required;
+// `features` is only overwritten if feature computation succeeds.
+// NOTE: If dropout is on, repeated calls may return different results.
+bool FeatureProcessor::GetFeaturesAndLabels(
+    const SelectionWithContext& selection_with_context,
+    std::vector<std::vector<std::pair<int, float>>>* features,
+    std::vector<float>* extra_features,
+    std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
+    CodepointSpan* selection_codepoint_label, int* classification_label) const {
+  if (features == nullptr || extra_features == nullptr) {
+    return false;
+  }
+
+  std::vector<nlp_core::FeatureVector> feature_vectors;
+  if (!GetFeaturesAndLabels(selection_with_context, &feature_vectors,
+                            extra_features, selection_label_spans,
+                            selection_label, selection_codepoint_label,
+                            classification_label)) {
+    return false;
+  }
+
+  // Decode every discrete feature value back into its (id, weight) pair.
+  features->clear();
+  for (int token_index = 0; token_index < feature_vectors.size();
+       ++token_index) {
+    nlp_core::FeatureVector& token_features = feature_vectors[token_index];
+    std::vector<std::pair<int, float>> row;
+    for (int k = 0; k < token_features.size(); ++k) {
+      const nlp_core::FloatFeatureValue decoded(token_features.value(k));
+      row.push_back({decoded.id, decoded.weight});
+    }
+    features->push_back(std::move(row));
+  }
+  return true;
+}
+
+// Computes features for the (context_size() * 2 + 1) window positions centered
+// on `click_pos` in `tokens`.
+// - The sparse features of window position k are added to (*features)[k].
+//   `features` must already hold that many entries; it is not resized here.
+// - Dense features are appended to `extra_features`.
+// - One token per window position is appended to `output_tokens`. Positions
+//   outside `tokens` or removed by context dropout get an empty Token().
+// Returns false if GetContextDropoutRange() reports that the whole context is
+// to be dropped.
+// NOTE: If dropout is on, repeated calls may return different results.
+bool FeatureProcessor::ComputeFeatures(
+    int click_pos, const std::vector<Token>& tokens,
+    CodepointSpan selected_span, std::vector<nlp_core::FeatureVector>* features,
+    std::vector<float>* extra_features,
+    std::vector<Token>* output_tokens) const {
+  int dropout_left = 0;
+  int dropout_right = 0;
+  if (options_.context_dropout_probability() > 0.0) {
+    // Determine how much context to drop.
+    bool status = GetContextDropoutRange(&dropout_left, &dropout_right);
+    if (!status) {
+      return false;
+    }
+  }
+
+  int feature_index = 0;
+  output_tokens->reserve(options_.context_size() * 2 + 1);
+  // Capacity hint only. NOTE(review): presumably the extractor emits one dense
+  // feature per enabled flag — confirm against TokenFeatureExtractor.
+  const int num_extra_features =
+      static_cast<int>(options_.extract_case_feature()) +
+      static_cast<int>(options_.extract_selection_mask_feature());
+  extra_features->reserve((options_.context_size() * 2 + 1) *
+                          num_extra_features);
+  for (int i = click_pos - options_.context_size();
+       i <= click_pos + options_.context_size(); ++i, ++feature_index) {
+    std::vector<int> sparse_features;
+    std::vector<float> dense_features;
+
+    // Window positions may fall off either end of the token list.
+    const bool is_valid_token = i >= 0 && i < tokens.size();
+
+    // The leftmost `dropout_left` and rightmost `dropout_right` window
+    // positions are treated as missing.
+    bool is_dropped = false;
+    if (options_.context_dropout_probability() > 0.0) {
+      if (i < click_pos - options_.context_size() + dropout_left) {
+        is_dropped = true;
+      } else if (i > click_pos + options_.context_size() - dropout_right) {
+        is_dropped = true;
+      }
+    }
+
+    if (is_valid_token && !is_dropped) {
+      // is_in_span is set only on the copy used for feature extraction;
+      // output_tokens receives the unmodified input token.
+      Token token(tokens[i]);
+      token.is_in_span = token.start >= selected_span.first &&
+                         token.end <= selected_span.second;
+      feature_extractor_.Extract(token, &sparse_features, &dense_features);
+      output_tokens->push_back(tokens[i]);
+    } else {
+      feature_extractor_.Extract(Token(), &sparse_features, &dense_features);
+      // This adds an empty string for each missing context token to exactly
+      // match the input tokens to the network.
+      output_tokens->emplace_back();
+    }
+
+    // Each sparse feature gets weight 1 / (number of sparse features of this
+    // token) and is stored as a packed discrete (id, weight) value.
+    for (int feature_id : sparse_features) {
+      const int64 feature_value =
+          nlp_core::FloatFeatureValue(feature_id, 1.0 / sparse_features.size())
+              .discrete_value;
+      // NOTE(review): const_cast presumably because FeatureVector::add takes a
+      // non-const feature type pointer in a const method — confirm.
+      (*features)[feature_index].add(
+          const_cast<nlp_core::NumericFeatureType*>(&feature_type_),
+          feature_value);
+    }
+
+    for (float value : dense_features) {
+      extra_features->push_back(value);
+    }
+  }
+  return true;
+}
+
+// Maps a collection name to its label. Names not present in the options map to
+// the configured default_collection() label.
+int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
+  const auto found = collection_to_label_.find(collection);
+  return found != collection_to_label_.end() ? found->second
+                                             : options_.default_collection();
+}
+
+// Returns the collection name for `label`, or the default collection name if
+// `label` is not a valid index into options_.collections().
+std::string FeatureProcessor::LabelToCollection(int label) const {
+  // Bound by the repeated field actually being indexed. collection_to_label_
+  // is smaller than collections() when a name is listed twice, which would
+  // wrongly reject valid trailing labels.
+  if (label >= 0 && label < options_.collections().size()) {
+    return options_.collections(label);
+  } else {
+    return GetDefaultCollection();
+  }
+}
+
+// Fills the label lookup tables:
+//  - collection_to_label_: collection name -> its index in
+//    options_.collections() (the last index wins for duplicated names);
+//  - selection_to_label_ / label_to_selection_: each (left, right) token span
+//    with 0 <= left, right <= max_selection_span() gets the next label id,
+//    enumerated left-major. With selection_reduced_output_space(), only spans
+//    with left + right <= max_selection_span() are labeled.
+void FeatureProcessor::MakeLabelMaps() {
+  for (int index = 0; index < options_.collections().size(); ++index) {
+    collection_to_label_[options_.collections(index)] = index;
+  }
+
+  const int max_span = options_.max_selection_span();
+  const bool reduced_space = options_.selection_reduced_output_space();
+  int next_label = 0;
+  for (int left = 0; left <= max_span; ++left) {
+    for (int right = 0; right <= max_span; ++right) {
+      if (reduced_space && left + right > max_span) {
+        continue;  // Too wide for the reduced output space.
+      }
+      const TokenSpan span{left, right};
+      selection_to_label_[span] = next_label;
+      label_to_selection_.push_back(span);
+      ++next_label;
+    }
+  }
+}
+
+// Decides how many context tokens to drop on each side, drawing from random_.
+// With probability 1 - context_dropout_probability() nothing is dropped.
+// Otherwise the function either:
+// - draws the left and right amounts independently and uniformly from
+//   [0, context_size()] (variable dropout), or
+// - returns false to signal that the whole context is dropped.
+bool FeatureProcessor::GetContextDropoutRange(int* dropout_left,
+                                              int* dropout_right) const {
+  std::uniform_real_distribution<> coin(0, 1);
+  const bool apply_dropout =
+      coin(*random_) < options_.context_dropout_probability();
+  if (!apply_dropout) {
+    *dropout_left = 0;
+    *dropout_right = 0;
+    return true;
+  }
+
+  if (!options_.use_variable_context_dropout()) {
+    // Drop all context.
+    return false;
+  }
+
+  std::uniform_int_distribution<> amount(0, options_.context_size());
+  *dropout_left = amount(*random_);
+  *dropout_right = amount(*random_);
+  return true;
+}
+
+}  // namespace libtextclassifier
diff --git a/smartselect/feature-processor.h b/smartselect/feature-processor.h
new file mode 100644
index 0000000..5684cd7
--- /dev/null
+++ b/smartselect/feature-processor.h
@@ -0,0 +1,206 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Feature processing for FFModel (feed-forward SmartSelection model).
+
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
+
+#include <map>
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "common/feature-extractor.h"
+#include "smartselect/text-classification-model.pb.h"
+#include "smartselect/token-feature-extractor.h"
+#include "smartselect/tokenizer.h"
+#include "smartselect/types.h"
+
+namespace libtextclassifier {
+
+// Label used when a span or token span has no corresponding selection label.
+constexpr int kInvalidLabel = -1;
+
+namespace internal {
+
+// Parses the serialized FeatureProcessorOptions protocol buffer.
+FeatureProcessorOptions ParseSerializedOptions(
+    const std::string& serialized_options);
+
+// Builds the token feature extractor options (buckets, chargram orders, case
+// and selection-mask feature flags) from the feature processor options.
+TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
+    const FeatureProcessorOptions& options);
+
+// Returns a modified version of the selection_with_context, such that only the
+// line that contains the clicked span is kept, together with the number of
+// codepoints the selection was moved by.
+std::pair<SelectionWithContext, int> ExtractLineWithClick(
+    const SelectionWithContext& selection_with_context);
+
+// Splits tokens that contain the selection boundary inside them.
+// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com" (where {} marks
+// the selection).
+void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
+                                      std::vector<Token>* tokens);
+
+}  // namespace internal
+
+// Converts a codepoint span into the token span [first, last + 1) covering the
+// tokens of `selectable_tokens` that lie completely inside it, or
+// {kInvalidIndex, kInvalidIndex} if there are none.
+TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
+                                   CodepointSpan codepoint_span);
+
+// Returns a modified version of the context string, such that only the
+// line that contains the span is kept. Also returns the codepoint shift that
+// happened (the offset of the kept line in the original context). If the span
+// spans multiple lines, returns the original input with zero shift.
+// The following characters are considered to be line separators: '\n', '|'
+std::pair<std::string, int> ExtractLineWithSpan(const std::string& context,
+                                                CodepointSpan span);
+
+// Takes care of preparing features for the FFModel.
+//
+// Several const methods (ClickRandomTokenInSelection, and GetFeaturesAndLabels
+// when context dropout is enabled) draw from the mutable random_ generator
+// without any locking. NOTE(review): presumably an instance must therefore not
+// be used from several threads at once — confirm.
+class FeatureProcessor {
+ public:
+  // Builds the feature extractor, feature type, tokenizer and label maps from
+  // `options`. The random generator is seeded from std::random_device.
+  explicit FeatureProcessor(const FeatureProcessorOptions& options)
+      : options_(options),
+        feature_extractor_(
+            internal::BuildTokenFeatureExtractorOptions(options)),
+        feature_type_(FeatureProcessor::kFeatureTypeName,
+                      options.num_buckets()),
+        tokenizer_({options.tokenization_codepoint_config().begin(),
+                    options.tokenization_codepoint_config().end()}),
+        random_(new std::mt19937(std::random_device()())) {
+    MakeLabelMaps();
+  }
+
+  // Same as above, with the options given as a serialized proto.
+  explicit FeatureProcessor(const std::string& serialized_options)
+      : FeatureProcessor(internal::ParseSerializedOptions(serialized_options)) {
+  }
+
+  // Returns the codepoint span of a randomly chosen token of the selection
+  // (uses random_).
+  CodepointSpan ClickRandomTokenInSelection(
+      const SelectionWithContext& selection_with_context) const;
+
+  // Tokenizes the input string using the selected tokenization method.
+  std::vector<Token> Tokenize(const std::string& utf8_text) const;
+
+  // NOTE: If dropout is on, subsequent calls of this function with the same
+  // arguments might return different results.
+  bool GetFeaturesAndLabels(const SelectionWithContext& selection_with_context,
+                            std::vector<nlp_core::FeatureVector>* features,
+                            std::vector<float>* extra_features,
+                            std::vector<CodepointSpan>* selection_label_spans,
+                            int* selection_label,
+                            CodepointSpan* selection_codepoint_label,
+                            int* classification_label) const;
+
+  // Same as above but uses std::vector instead of FeatureVector.
+  // NOTE: If dropout is on, subsequent calls of this function with the same
+  // arguments might return different results.
+  bool GetFeaturesAndLabels(
+      const SelectionWithContext& selection_with_context,
+      std::vector<std::vector<std::pair<int, float>>>* features,
+      std::vector<float>* extra_features,
+      std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
+      CodepointSpan* selection_codepoint_label,
+      int* classification_label) const;
+
+  // Converts a label into a token span.
+  bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
+
+  // Gets the string value for given collection label.
+  std::string LabelToCollection(int label) const;
+
+  // Gets the total number of collections of the model (the number of distinct
+  // collection names in the options).
+  int NumCollections() const { return collection_to_label_.size(); }
+
+  // Gets the name of the default collection.
+  // NOTE(review): assumes default_collection() is a valid index into
+  // collections(); there is no bounds check here.
+  std::string GetDefaultCollection() const {
+    return options_.collections(options_.default_collection());
+  }
+
+  // Returns a copy of the options.
+  FeatureProcessorOptions GetOptions() const { return options_; }
+
+  // Number of selection labels (token spans) the model distinguishes.
+  int GetSelectionLabelCount() const { return label_to_selection_.size(); }
+
+  // Sets the source of randomness. Takes ownership of `new_random`; the
+  // previous generator is deleted.
+  void SetRandom(std::mt19937* new_random) { random_.reset(new_random); }
+
+ protected:
+  // Extracts features for given word.
+  // NOTE(review): not defined in feature-processor.cc — confirm it is still
+  // needed.
+  std::vector<int> GetWordFeatures(const std::string& word) const;
+
+  // NOTE: If dropout is on, subsequent calls of this function with the same
+  // arguments might return different results.
+  bool ComputeFeatures(int click_pos,
+                       const std::vector<Token>& selectable_tokens,
+                       CodepointSpan selected_span,
+                       std::vector<nlp_core::FeatureVector>* features,
+                       std::vector<float>* extra_features,
+                       std::vector<Token>* output_tokens) const;
+
+  // Helper function that computes how much left context and how much right
+  // context should be dropped. Uses a mutable random_ member as a source of
+  // randomness. Returns false when the whole context should be dropped.
+  bool GetContextDropoutRange(int* dropout_left, int* dropout_right) const;
+
+  // Returns the class id corresponding to the given string collection
+  // identifier. There is a catch-all class id that the function returns for
+  // unknown collections.
+  int CollectionToLabel(const std::string& collection) const;
+
+  // Prepares the mappings from collection names to labels and between
+  // selection token spans and labels.
+  void MakeLabelMaps();
+
+  // Gets the number of spannable tokens for the model.
+  //
+  // Spannable tokens are those tokens of context, which the model predicts
+  // selection spans over (i.e., there is 1:1 correspondence between the output
+  // classes of the model and each of the spannable tokens).
+  int GetNumContextTokens() const { return options_.context_size() * 2 + 1; }
+
+  // Converts a label into a span of codepoint indices corresponding to it
+  // given output_tokens.
+  bool LabelToSpan(int label, const std::vector<Token>& output_tokens,
+                   CodepointSpan* span) const;
+
+  // Converts a span to the corresponding label given output_tokens.
+  bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
+                   const std::vector<Token>& output_tokens, int* label) const;
+
+  // Converts a token span to the corresponding label.
+  int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
+
+ private:
+  FeatureProcessorOptions options_;
+
+  TokenFeatureExtractor feature_extractor_;
+
+  // Name of feature_type_.
+  static const char* const kFeatureTypeName;
+
+  // Numeric feature type that the sparse token features are added under.
+  nlp_core::NumericFeatureType feature_type_;
+
+  // Mapping between token selection spans and labels ids.
+  std::map<TokenSpan, int> selection_to_label_;
+  std::vector<TokenSpan> label_to_selection_;
+
+  // Mapping between collections and labels.
+  std::map<std::string, int> collection_to_label_;
+
+  Tokenizer tokenizer_;
+
+  // Source of randomness. Mutable so const methods can draw from it.
+  mutable std::unique_ptr<std::mt19937> random_;
+};
+
+}  // namespace libtextclassifier
+
+#endif  // LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
new file mode 100644
index 0000000..fc5a0ab
--- /dev/null
+++ b/smartselect/text-classification-model.cc
@@ -0,0 +1,441 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/text-classification-model.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <iterator>
+#include <numeric>
+
+#include "common/embedding-network.h"
+#include "common/feature-extractor.h"
+#include "common/memory_image/embedding-network-params-from-image.h"
+#include "common/memory_image/memory-image-reader.h"
+#include "common/mmap.h"
+#include "common/softmax.h"
+#include "smartselect/text-classification-model.pb.h"
+#include "util/base/logging.h"
+#include "util/utf8/unicodetext.h"
+
+namespace libtextclassifier {
+
+using nlp_core::EmbeddingNetwork;
+using nlp_core::EmbeddingNetworkProto;
+using nlp_core::FeatureVector;
+using nlp_core::MemoryImageReader;
+using nlp_core::MmapFile;
+using nlp_core::MmapHandle;
+
+// Parses the model stored at [start, start + num_bytes). Returns nullptr when
+// the model carries no FeatureProcessorOptions extension; otherwise returns a
+// newly allocated ModelParams that the caller owns.
+ModelParams* ModelParams::Build(const void* start, uint64 num_bytes) {
+  MemoryImageReader<EmbeddingNetworkProto> reader(start, num_bytes);
+
+  // The feature processor options are mandatory.
+  auto feature_processor_extension_id =
+      feature_processor_options_in_embedding_network_proto;
+  if (!reader.trimmed_proto().HasExtension(feature_processor_extension_id)) {
+    return nullptr;
+  }
+  FeatureProcessorOptions feature_processor_options =
+      reader.trimmed_proto().GetExtension(feature_processor_extension_id);
+
+  // If no tokenization codepoint config is present, tokenize on space
+  // (codepoint 32).
+  if (feature_processor_options.tokenization_codepoint_config_size() == 0) {
+    TokenizationCodepointRange* config =
+        feature_processor_options.add_tokenization_codepoint_config();
+    config->set_start(32);
+    config->set_end(33);
+    config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+  }
+
+  SelectionModelOptions selection_options;
+  auto selection_options_extension_id =
+      selection_model_options_in_embedding_network_proto;
+  if (reader.trimmed_proto().HasExtension(selection_options_extension_id)) {
+    selection_options =
+        reader.trimmed_proto().GetExtension(selection_options_extension_id);
+  } else {
+    // TODO(zilka): Remove this once we added the model options to the exported
+    // models.
+    // Fallback for models exported without SelectionModelOptions: strip a
+    // built-in set of punctuation codepoints and enforce symmetry. Each pair is
+    // an inclusive [first, last] codepoint range.
+    static const std::pair<int, int> kDefaultPunctuationRanges[] = {
+        {33, 35},       {37, 39},       {42, 42},       {44, 47},
+        {58, 59},       {63, 64},       {91, 93},       {95, 95},
+        {123, 123},     {125, 125},     {161, 161},     {171, 171},
+        {183, 183},     {187, 187},     {191, 191},     {894, 894},
+        {903, 903},     {1370, 1375},   {1417, 1418},   {1470, 1470},
+        {1472, 1472},   {1475, 1475},   {1478, 1478},   {1523, 1524},
+        {1548, 1549},   {1563, 1563},   {1566, 1567},   {1642, 1645},
+        {1748, 1748},   {1792, 1805},   {2404, 2405},   {2416, 2416},
+        {3572, 3572},   {3663, 3663},   {3674, 3675},   {3844, 3858},
+        {3898, 3901},   {3973, 3973},   {4048, 4049},   {4170, 4175},
+        {4347, 4347},   {4961, 4968},   {5741, 5742},   {5787, 5788},
+        {5867, 5869},   {5941, 5942},   {6100, 6102},   {6104, 6106},
+        {6144, 6154},   {6468, 6469},   {6622, 6623},   {6686, 6687},
+        {8208, 8231},   {8240, 8259},   {8261, 8273},   {8275, 8286},
+        {8317, 8318},   {8333, 8334},   {9001, 9002},   {9140, 9142},
+        {10088, 10101}, {10181, 10182}, {10214, 10219}, {10627, 10648},
+        {10712, 10715}, {10748, 10749}, {11513, 11516}, {11518, 11519},
+        {11776, 11799}, {11804, 11805}, {12289, 12291}, {12296, 12305},
+        {12308, 12319}, {12336, 12336}, {12349, 12349}, {12448, 12448},
+        {12539, 12539}, {64830, 64831}, {65040, 65049}, {65072, 65106},
+        {65108, 65121}, {65123, 65123}, {65128, 65128}, {65130, 65131},
+        {65281, 65283}, {65285, 65290}, {65292, 65295}, {65306, 65307},
+        {65311, 65312}, {65339, 65341}, {65343, 65343}, {65371, 65371},
+        {65373, 65373}, {65375, 65381}, {65792, 65793}, {66463, 66463},
+        {68176, 68184}};
+    for (const auto& range : kDefaultPunctuationRanges) {
+      for (int codepoint = range.first; codepoint <= range.second;
+           ++codepoint) {
+        selection_options.add_punctuation_to_strip(codepoint);
+      }
+    }
+    // Loop-invariant settings, applied once after the list is filled.
+    selection_options.set_strip_punctuation(true);
+    selection_options.set_enforce_symmetry(true);
+    selection_options.set_symmetry_context_size(
+        feature_processor_options.context_size() * 2);
+  }
+
+  return new ModelParams(start, num_bytes, selection_options,
+                         feature_processor_options);
+}
+
+// Removes codepoints in punctuation_to_strip_ from both ends of selection
+// (codepoint offsets into context) and returns the narrowed span. Invalid,
+// empty or reversed spans are returned unchanged, as is a span consisting
+// entirely of punctuation (stripping it would leave nothing to select).
+CodepointSpan TextClassificationModel::StripPunctuation(
+    CodepointSpan selection, const std::string& context) const {
+  UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false);
+  const int context_length =
+      std::distance(context_unicode.begin(), context_unicode.end());
+
+  // Check that the indices are valid and the span is non-empty; the loops
+  // below rely on 0 <= first < second <= context_length.
+  if (selection.first < 0 || selection.first > context_length ||
+      selection.second < 0 || selection.second > context_length ||
+      selection.first >= selection.second) {
+    return selection;
+  }
+
+  const auto is_punctuation = [this](int codepoint) {
+    return punctuation_to_strip_.find(codepoint) !=
+           punctuation_to_strip_.end();
+  };
+
+  CodepointSpan stripped = selection;
+
+  // Move the start forward over leading punctuation. Bounded by the end of the
+  // span, so *it never dereferences past the selection or the context end.
+  UnicodeText::const_iterator it = context_unicode.begin();
+  std::advance(it, stripped.first);
+  while (stripped.first < stripped.second && is_punctuation(*it)) {
+    ++it;
+    ++stripped.first;
+  }
+
+  // Move the end backward over trailing punctuation, never past the new start.
+  // The iterator is decremented only while the span is non-empty, so it never
+  // steps before the beginning of the context.
+  it = context_unicode.begin();
+  std::advance(it, stripped.second);
+  while (stripped.second > stripped.first) {
+    --it;
+    if (!is_punctuation(*it)) {
+      break;
+    }
+    --stripped.second;
+  }
+
+  if (stripped.first == stripped.second) {
+    return selection;
+  }
+  return stripped;
+}
+
+// Loads both models from fd. On success caches the selection options and the
+// set of codepoints StripPunctuation removes; on failure logs an error and
+// leaves initialized_ false.
+TextClassificationModel::TextClassificationModel(int fd) {
+  initialized_ = LoadModels(fd);
+  if (initialized_) {
+    selection_options_ = selection_params_->GetSelectionModelOptions();
+    punctuation_to_strip_.insert(
+        selection_options_.punctuation_to_strip().begin(),
+        selection_options_.punctuation_to_strip().end());
+  } else {
+    TC_LOG(ERROR) << "Failed to load models";
+  }
+}
+
+// Maps the model file and builds the selection and sharing networks from it.
+// The file holds two length-prefixed sections, each a little-endian uint32
+// byte count followed by that many bytes: first the selection model, then the
+// sharing model. Returns false if the file cannot be mapped, is truncated, or
+// either model fails to build.
+bool TextClassificationModel::LoadModels(int fd) {
+  MmapHandle mmap_handle = MmapFile(fd);
+  if (!mmap_handle.ok()) {
+    return false;
+  }
+
+  const char* model_data = reinterpret_cast<const char*>(mmap_handle.start());
+  const char* const model_data_end = model_data + mmap_handle.num_bytes();
+
+  // Reads the next length-prefixed section. On success stores its start and
+  // length and advances model_data past it; fails if the prefix or the section
+  // would run past the end of the mapped file.
+  auto read_section = [&model_data, model_data_end](
+                          const char** section_start, uint32* section_length) {
+    uint32 raw_length;
+    if (static_cast<std::size_t>(model_data_end - model_data) <
+        sizeof(raw_length)) {
+      return false;
+    }
+    // memcpy rather than dereferencing a uint32 pointer: the second prefix
+    // follows a section of arbitrary length, so it need not be 4-byte aligned.
+    std::memcpy(&raw_length, model_data, sizeof(raw_length));
+    const uint32 length = LittleEndian::ToHost32(raw_length);
+    model_data += sizeof(raw_length);
+    if (static_cast<std::size_t>(model_data_end - model_data) < length) {
+      return false;
+    }
+    *section_start = model_data;
+    *section_length = length;
+    model_data += length;
+    return true;
+  };
+
+  const char* selection_model_data = nullptr;
+  uint32 selection_model_length = 0;
+  if (!read_section(&selection_model_data, &selection_model_length)) {
+    TC_LOG(ERROR) << "Model file too short for the selection model";
+    return false;
+  }
+  selection_params_.reset(
+      ModelParams::Build(selection_model_data, selection_model_length));
+  if (!selection_params_.get()) {
+    return false;
+  }
+  selection_network_.reset(new EmbeddingNetwork(selection_params_.get()));
+  selection_feature_processor_.reset(
+      new FeatureProcessor(selection_params_->GetFeatureProcessorOptions()));
+
+  const char* sharing_model_data = nullptr;
+  uint32 sharing_model_length = 0;
+  if (!read_section(&sharing_model_data, &sharing_model_length)) {
+    TC_LOG(ERROR) << "Model file too short for the sharing model";
+    return false;
+  }
+  sharing_params_.reset(
+      ModelParams::Build(sharing_model_data, sharing_model_length));
+  if (!sharing_params_.get()) {
+    return false;
+  }
+  sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get()));
+  sharing_feature_processor_.reset(
+      new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions()));
+
+  return true;
+}
+
+// Extracts features for the given click/selection with feature_processor and
+// scores them with network. Returns an empty vector if features could not be
+// computed. The label out-parameters are forwarded unchanged to
+// FeatureProcessor::GetFeaturesAndLabels.
+EmbeddingNetwork::Vector TextClassificationModel::InferInternal(
+    const std::string& context, CodepointSpan click_indices,
+    CodepointSpan selection_indices, const FeatureProcessor& feature_processor,
+    const EmbeddingNetwork* network,
+    std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
+    CodepointSpan* selection_codepoint_label, int* classification_label) const {
+  SelectionWithContext input;
+  input.context = context;
+  input.click_start = click_indices.first;
+  input.click_end = click_indices.second;
+  input.selection_start = selection_indices.first;
+  input.selection_end = selection_indices.second;
+
+  std::vector<FeatureVector> token_features;
+  std::vector<float> dense_features;
+  EmbeddingNetwork::Vector scores;
+  if (!feature_processor.GetFeaturesAndLabels(
+          input, &token_features, &dense_features, selection_label_spans,
+          selection_label, selection_codepoint_label, classification_label)) {
+    TC_LOG(ERROR) << "Features not computed";
+    return scores;
+  }
+
+  network->ComputeFinalScores(token_features, dense_features, &scores);
+  return scores;
+}
+
+// Public selection entry point: validates the click, runs the symmetric or the
+// single-shot selection search depending on the model options, and optionally
+// strips punctuation from the result. Returns click_indices on error.
+CodepointSpan TextClassificationModel::SuggestSelection(
+    const std::string& context, CodepointSpan click_indices) const {
+  if (!initialized_) {
+    TC_LOG(ERROR) << "Not initialized";
+    return click_indices;
+  }
+
+  if (click_indices.first >= click_indices.second) {
+    TC_LOG(ERROR) << "Trying to run SuggestSelection with invalid indices:"
+                  << click_indices.first << " " << click_indices.second;
+    return click_indices;
+  }
+
+  const CodepointSpan suggestion =
+      selection_options_.enforce_symmetry()
+          ? SuggestSelectionSymmetrical(context, click_indices)
+          : SuggestSelectionInternal(context, click_indices).first;
+
+  return selection_options_.strip_punctuation()
+             ? StripPunctuation(suggestion, context)
+             : suggestion;
+}
+
+// Runs the selection network once around click_indices and returns the
+// highest-probability selection span together with its softmax probability.
+// Returns {click_indices, -1.0} on any failure.
+std::pair<CodepointSpan, float>
+TextClassificationModel::SuggestSelectionInternal(
+    const std::string& context, CodepointSpan click_indices) const {
+  if (!initialized_) {
+    TC_LOG(ERROR) << "Not initialized";
+    return {click_indices, -1.0};
+  }
+
+  // Invalid selection indices make the feature extraction use the provided
+  // click indices.
+  const CodepointSpan selection_indices({kInvalidIndex, kInvalidIndex});
+
+  std::vector<CodepointSpan> selection_label_spans;
+  EmbeddingNetwork::Vector scores = InferInternal(
+      context, click_indices, selection_indices, *selection_feature_processor_,
+      selection_network_.get(), &selection_label_spans,
+      /*selection_label=*/nullptr,
+      /*selection_codepoint_label=*/nullptr,
+      /*classification_label=*/nullptr);
+
+  if (scores.empty()) {
+    TC_LOG(ERROR) << "Returning default selection: scores.size() = "
+                  << scores.size();
+    return {click_indices, -1.0};
+  }
+
+  scores = nlp_core::ComputeSoftmax(scores);
+  const int prediction =
+      std::max_element(scores.begin(), scores.end()) - scores.begin();
+
+  // Guard against the network producing more outputs than the feature
+  // processor produced label spans (e.g. mismatched model options) instead of
+  // indexing out of bounds.
+  if (prediction >= static_cast<int>(selection_label_spans.size())) {
+    TC_LOG(ERROR) << "Predicted label " << prediction
+                  << " has no span; number of spans = "
+                  << selection_label_spans.size();
+    return {click_indices, -1.0};
+  }
+
+  const CodepointSpan& selection = selection_label_spans[prediction];
+  if (selection.first == kInvalidIndex || selection.second == kInvalidIndex) {
+    TC_LOG(ERROR) << "Invalid indices predicted, returning input: "
+                  << prediction << " " << selection.first << " "
+                  << selection.second;
+    return {click_indices, -1.0};
+  }
+
+  return {{selection.first, selection.second}, scores[prediction]};
+}
+
+namespace {
+
+// Returns the index of the clicked token: the single token covered by
+// click_indices if there is exactly one, otherwise the first token whose
+// [start, end] range contains click_indices. Returns kInvalidIndex if none.
+int GetClickTokenIndex(const std::vector<Token>& tokens,
+                       CodepointSpan click_indices) {
+  const TokenSpan covered = CodepointSpanToTokenSpan(tokens, click_indices);
+  if (covered.second - covered.first == 1) {
+    return covered.first;
+  }
+
+  int index = 0;
+  for (const Token& token : tokens) {
+    if (token.start <= click_indices.first &&
+        token.end >= click_indices.second) {
+      return index;
+    }
+    ++index;
+  }
+  return kInvalidIndex;
+}
+
+}  // namespace
+
+// Implements a greedy-search-like algorithm for making selections symmetric.
+//
+// Steps:
+// 1. Get a set of selection proposals from places around the clicked word.
+// 2. Go through the proposals from highest-scoring. Skip a proposal if any of
+//    its tokens were already claimed. Otherwise, if it contains the clicked
+//    token, return it as the suggestion; if not, claim its tokens so that
+//    lower-scoring proposals cannot use them.
+//
+// This algorithm should ensure that if a selection is proposed, it does not
+// matter which word of it was tapped - all of them will lead to the same
+// selection.
+//
+// The search runs on the line containing the click, in line-relative indices;
+// results are shifted back to full_context offsets. If no proposal contains
+// the click, the caller's click_indices are returned unchanged.
+CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical(
+    const std::string& full_context, CodepointSpan click_indices) const {
+  // Extract context from the current line only.
+  std::string context;
+  int context_shift;
+  std::tie(context, context_shift) =
+      ExtractLineWithSpan(full_context, click_indices);
+  const CodepointSpan line_click_indices = {
+      click_indices.first - context_shift, click_indices.second - context_shift};
+
+  const std::vector<Token> tokens =
+      selection_feature_processor_->Tokenize(context);
+
+  const int click_index = GetClickTokenIndex(tokens, line_click_indices);
+  if (click_index == kInvalidIndex) {
+    // click_indices is untouched, so it is still relative to full_context.
+    return click_indices;
+  }
+
+  const int symmetry_context_size = selection_options_.symmetry_context_size();
+  const int num_tokens = static_cast<int>(tokens.size());
+
+  // Scan in the symmetry context for selection span proposals.
+  std::vector<std::pair<CodepointSpan, float>> proposals;
+  for (int i = -symmetry_context_size; i <= symmetry_context_size; ++i) {
+    const int token_index = click_index + i;
+    if (token_index >= 0 && token_index < num_tokens) {
+      proposals.push_back(SuggestSelectionInternal(
+          context, {tokens[token_index].start, tokens[token_index].end}));
+    }
+  }
+
+  // Sort selection span proposals by their respective probabilities. The
+  // stable sort keeps equally-scored proposals in scan order, so ties resolve
+  // deterministically.
+  std::stable_sort(proposals.begin(), proposals.end(),
+                   [](const std::pair<CodepointSpan, float>& a,
+                      const std::pair<CodepointSpan, float>& b) {
+                     return a.second > b.second;
+                   });
+
+  // Go from the highest-scoring proposal and claim tokens. Tokens are marked as
+  // claimed by the higher-scoring selection proposals, so that the
+  // lower-scoring ones cannot use them. Returns the selection proposal if it
+  // contains the clicked token.
+  std::vector<int> used_tokens(tokens.size(), 0);
+  for (const auto& proposal : proposals) {
+    const CodepointSpan& proposal_span = proposal.first;
+    const TokenSpan span = CodepointSpanToTokenSpan(tokens, proposal_span);
+    if (span.first == kInvalidIndex || span.second == kInvalidIndex) {
+      continue;
+    }
+
+    bool feasible = true;
+    for (int i = span.first; i < span.second; i++) {
+      if (used_tokens[i] != 0) {
+        feasible = false;
+        break;
+      }
+    }
+    if (!feasible) {
+      continue;
+    }
+
+    if (span.first <= click_index && span.second > click_index) {
+      return {proposal_span.first + context_shift,
+              proposal_span.second + context_shift};
+    }
+    for (int i = span.first; i < span.second; i++) {
+      used_tokens[i] = 1;
+    }
+  }
+
+  return click_indices;
+}
+
+// Evaluation entry point. Uses the click stored in selection_with_context or,
+// if none is given, a random token inside its selection.
+CodepointSpan TextClassificationModel::SuggestSelection(
+    const SelectionWithContext& selection_with_context) const {
+  CodepointSpan click_indices = {selection_with_context.click_start,
+                                 selection_with_context.click_end};
+
+  // If click_indices are unspecified, click a random token inside the
+  // selection. The feature processor only exists for an initialized model;
+  // otherwise the call below logs the error and returns click_indices.
+  if (initialized_ &&
+      click_indices == CodepointSpan({kInvalidIndex, kInvalidIndex})) {
+    click_indices = selection_feature_processor_->ClickRandomTokenInSelection(
+        selection_with_context);
+  }
+
+  return SuggestSelection(selection_with_context.context, click_indices);
+}
+
+// Classifies the text at selection_indices (codepoint offsets into context)
+// with the sharing network and returns the predicted collection name, or the
+// default collection if inference fails. Returns an empty string when the
+// model is not initialized, because no feature processor (and so no default
+// collection) exists then.
+std::string TextClassificationModel::ClassifyText(
+    const std::string& context, CodepointSpan selection_indices) const {
+  if (!initialized_) {
+    TC_LOG(ERROR) << "Not initialized";
+    // sharing_feature_processor_ is only created by a successful LoadModels(),
+    // so it is null here and must not be dereferenced.
+    return sharing_feature_processor_
+               ? sharing_feature_processor_->GetDefaultCollection()
+               : std::string();
+  }
+
+  // Invalid click indices make the feature extraction select the middle word in
+  // the selection span.
+  const CodepointSpan click_indices({kInvalidIndex, kInvalidIndex});
+
+  EmbeddingNetwork::Vector scores =
+      InferInternal(context, click_indices, selection_indices,
+                    *sharing_feature_processor_, sharing_network_.get(),
+                    /*selection_label_spans=*/nullptr,
+                    /*selection_label=*/nullptr,
+                    /*selection_codepoint_label=*/nullptr,
+                    /*classification_label=*/nullptr);
+  if (scores.empty()) {
+    TC_LOG(ERROR) << "Using default class";
+    return sharing_feature_processor_->GetDefaultCollection();
+  }
+  // One score per collection is expected; otherwise the network and the
+  // feature processor options disagree.
+  if (scores.size() != sharing_feature_processor_->NumCollections()) {
+    TC_LOG(ERROR) << "Using default class: scores.size() = " << scores.size();
+    return sharing_feature_processor_->GetDefaultCollection();
+  }
+
+  const int prediction =
+      std::max_element(scores.begin(), scores.end()) - scores.begin();
+  return sharing_feature_processor_->LabelToCollection(prediction);
+}
+
+}  // namespace libtextclassifier
diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h
new file mode 100644
index 0000000..4e0a3da
--- /dev/null
+++ b/smartselect/text-classification-model.h
@@ -0,0 +1,171 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Inference code for the feed-forward text classification models.
+
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
+
+#include <memory>
+#include <set>
+#include <string>
+
+#include "base.h"
+#include "common/embedding-network.h"
+#include "common/feature-extractor.h"
+#include "common/memory_image/embedding-network-params-from-image.h"
+#include "smartselect/feature-processor.h"
+#include "smartselect/text-classification-model.pb.h"
+#include "smartselect/types.h"
+
+namespace libtextclassifier {
+
+// Loads and holds the parameters of the inference network.
+//
+// This class overrides a couple of methods of EmbeddingNetworkParamsFromImage
+// because we only have one embedding matrix for all positions of context,
+// whereas the original class would have a separate one for each.
+class ModelParams : public nlp_core::EmbeddingNetworkParamsFromImage {
+ public:
+  // Builds parameters from the serialized model at [start, start + num_bytes).
+  // Returns nullptr if the model carries no FeatureProcessorOptions; otherwise
+  // returns a newly allocated object owned by the caller. The memory at start
+  // is handed to EmbeddingNetworkParamsFromImage and presumably must outlive
+  // the returned object (TODO: confirm it is not copied).
+  static ModelParams* Build(const void* start, uint64 num_bytes);
+
+  const FeatureProcessorOptions& GetFeatureProcessorOptions() const {
+    return feature_processor_options_;
+  }
+
+  const SelectionModelOptions& GetSelectionModelOptions() const {
+    return selection_options_;
+  }
+
+ protected:
+  // One embedding input per context position: context_size_ positions on
+  // each side plus the center one.
+  int embeddings_size() const override { return context_size_ * 2 + 1; }
+
+  int embedding_num_features_size() const override {
+    return context_size_ * 2 + 1;
+  }
+
+  int embedding_num_features(int /*i*/) const override { return 1; }
+
+  // All positions share embedding matrix 0, so the per-position accessors
+  // below ignore their index and forward to the base class with index 0.
+  int embeddings_num_rows(int /*i*/) const override {
+    return EmbeddingNetworkParamsFromImage::embeddings_num_rows(0);
+  }
+
+  int embeddings_num_cols(int /*i*/) const override {
+    return EmbeddingNetworkParamsFromImage::embeddings_num_cols(0);
+  }
+
+  const void* embeddings_weights(int /*i*/) const override {
+    return EmbeddingNetworkParamsFromImage::embeddings_weights(0);
+  }
+
+  nlp_core::QuantizationType embeddings_quant_type(int /*i*/) const override {
+    return EmbeddingNetworkParamsFromImage::embeddings_quant_type(0);
+  }
+
+  const nlp_core::float16* embeddings_quant_scales(int /*i*/) const override {
+    return EmbeddingNetworkParamsFromImage::embeddings_quant_scales(0);
+  }
+
+ private:
+  // Use Build() to construct instances.
+  ModelParams(const void* start, uint64 num_bytes,
+              const SelectionModelOptions& selection_options,
+              const FeatureProcessorOptions& feature_processor_options)
+      : EmbeddingNetworkParamsFromImage(start, num_bytes),
+        selection_options_(selection_options),
+        feature_processor_options_(feature_processor_options),
+        // Safe: feature_processor_options_ is declared (so initialized) first.
+        context_size_(feature_processor_options_.context_size()) {}
+
+  SelectionModelOptions selection_options_;
+  FeatureProcessorOptions feature_processor_options_;
+  int context_size_;
+};
+
+// SmartSelection/Sharing feed-forward model.
+class TextClassificationModel {
+ public:
+  // Loads TextClassificationModel from the model file referred to by the given
+  // file descriptor. On failure an error is logged and the instance stays
+  // uninitialized.
+  explicit TextClassificationModel(int fd);
+
+  // Runs inference for a given context and current selection (i.e. index
+  // of the first and one past last selected characters (utf8 codepoint
+  // offsets)). Returns the indices (utf8 codepoint offsets) of the selection
+  // beginning character and one past selection end character.
+  // Returns the original click_indices if an error occurs.
+  // NOTE: The selection indices are passed in and returned in terms of
+  // UTF8 codepoints (not bytes).
+  // Requires that the model is a smart selection model.
+  CodepointSpan SuggestSelection(const std::string& context,
+                                 CodepointSpan click_indices) const;
+
+  // Same as above but accepts a selection_with_context. Only used for
+  // evaluation.
+  CodepointSpan SuggestSelection(
+      const SelectionWithContext& selection_with_context) const;
+
+  // Classifies the text at selection_indices (utf8 codepoint offsets) given
+  // the context string.
+  // Requires that the model is a smart sharing model.
+  // Returns a default collection name if an error occurs.
+  std::string ClassifyText(const std::string& context,
+                           CodepointSpan selection_indices) const;
+
+ protected:
+  // Removes punctuation from the beginning and end of the selection and returns
+  // the new selection span.
+  CodepointSpan StripPunctuation(CodepointSpan selection,
+                                 const std::string& context) const;
+
+  // During evaluation we need access to the feature processor.
+  FeatureProcessor* SelectionFeatureProcessor() const {
+    return selection_feature_processor_.get();
+  }
+
+ private:
+  // Maps the model file and builds both networks; returns false on failure.
+  bool LoadModels(int fd);
+
+  // Extracts features with feature_processor and scores them with network.
+  nlp_core::EmbeddingNetwork::Vector InferInternal(
+      const std::string& context, CodepointSpan click_indices,
+      CodepointSpan selection_indices,
+      const FeatureProcessor& feature_processor,
+      const nlp_core::EmbeddingNetwork* network,
+      std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
+      CodepointSpan* selection_codepoint_label,
+      int* classification_label) const;
+
+  // Returns a selection suggestion with a score.
+  std::pair<CodepointSpan, float> SuggestSelectionInternal(
+      const std::string& context, CodepointSpan click_indices) const;
+
+  // Returns a selection suggestion and makes sure it's symmetric. Internally
+  // runs SuggestSelectionInternal several times.
+  CodepointSpan SuggestSelectionSymmetrical(const std::string& context,
+                                            CodepointSpan click_indices) const;
+
+  // True once LoadModels() has succeeded; the members below are only valid
+  // then.
+  bool initialized_ = false;
+  std::unique_ptr<ModelParams> selection_params_;
+  std::unique_ptr<FeatureProcessor> selection_feature_processor_;
+  std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
+  std::unique_ptr<FeatureProcessor> sharing_feature_processor_;
+  std::unique_ptr<ModelParams> sharing_params_;
+  std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_;
+
+  SelectionModelOptions selection_options_;
+  std::set<int> punctuation_to_strip_;
+};
+
+}  // namespace libtextclassifier
+
+#endif  // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
new file mode 100644
index 0000000..e098e81
--- /dev/null
+++ b/smartselect/text-classification-model.proto
@@ -0,0 +1,108 @@
+// Copyright (C) 2017 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Text classification model configuration.
+
+syntax = "proto2";
+option optimize_for = LITE_RUNTIME;
+
+import "external/libtextclassifier/common/embedding-network.proto";
+import "external/libtextclassifier/smartselect/tokenizer.proto";
+
+package libtextclassifier;
+
+// Options controlling how selection suggestions are post-processed.
+message SelectionModelOptions {
+  // A list of Unicode codepoints to strip from the beginning and end of
+  // predicted selections.
+  repeated int32 punctuation_to_strip = 1;
+
+  // Whether to strip the punctuation_to_strip codepoints after the selection
+  // is made.
+  optional bool strip_punctuation = 2;
+
+  // Enforce symmetrical selections, i.e. the same span is suggested no matter
+  // which of its tokens was clicked.
+  optional bool enforce_symmetry = 3;
+
+  // Number of inferences made around the click position (to one side), for
+  // enforcing symmetry: one inference is run for every token within this many
+  // tokens of the clicked token, on each side.
+  optional int32 symmetry_context_size = 4;
+}
+
+// Options controlling tokenization and feature extraction.
+message FeatureProcessorOptions {
+  // Number of buckets used for hashing charactergrams. It is used as the
+  // modulus of the charactergram hash, so it must be set to a positive value.
+  optional int32 num_buckets = 1 [default = -1];
+
+  // Context size defines the number of words to the left and to the right of
+  // the selected word to be used as context. For example, if context size is
+  // N, then we take N words to the left and N words to the right of the
+  // selected word as its context.
+  optional int32 context_size = 2 [default = -1];
+
+  // Maximum number of words of the context to select in total.
+  optional int32 max_selection_span = 3 [default = -1];
+
+  // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3
+  // character trigrams etc.
+  repeated int32 chargram_orders = 4;
+
+  // Whether to extract the token case feature.
+  optional bool extract_case_feature = 5 [default = false];
+
+  // Whether to extract the selection mask feature.
+  optional bool extract_selection_mask_feature = 6 [default = false];
+
+  // If true, tokenize on space, otherwise tokenize using ICU.
+  optional bool tokenize_on_space = 7 [default = true];
+
+  // If true, the selection classifier output will contain only the selections
+  // that are feasible (e.g., those that are shorter than max_selection_span),
+  // if false, the output will be a complete cross-product of possible
+  // selections to the left and possible selections to the right, including the
+  // infeasible ones.
+  // NOTE: Exists mainly for compatibility with older models that were trained
+  // with the non-reduced output space.
+  optional bool selection_reduced_output_space = 8 [default = true];
+
+  // Collection names.
+  repeated string collections = 9;
+
+  // An index of collection in collections to be used if a collection name can't
+  // be mapped to an id.
+  optional int32 default_collection = 10 [default = -1];
+
+  // Probability with which to drop context of examples.
+  optional float context_dropout_probability = 11 [default = 0.0];
+
+  // If true, drop variable amounts of context, if false all context, with
+  // probability given by context_dropout_probability.
+  optional bool use_variable_context_dropout = 12 [default = false];
+
+  // If true, will split the input by lines, and only use the line that contains
+  // the clicked token.
+  optional bool only_use_line_with_click = 13 [default = false];
+
+  // If true, will split tokens that contain the selection boundary, at the
+  // position of the boundary.
+  // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
+  optional bool split_tokens_on_selection_boundaries = 14 [default = false];
+
+  // Codepoint ranges that determine how different codepoints are tokenized.
+  // The ranges must not overlap. If none are given, the model loader adds a
+  // single whitespace-separator range for the space character.
+  repeated TokenizationCodepointRange tokenization_codepoint_config = 15;
+};
+
+// Extensions under which the options above are stored inside a serialized
+// nlp_core.EmbeddingNetworkProto model.
+extend nlp_core.EmbeddingNetworkProto {
+  // Required: a model without it fails to load.
+  optional FeatureProcessorOptions
+      feature_processor_options_in_embedding_network_proto = 146230910;
+  // Optional: models without it fall back to built-in punctuation stripping
+  // and symmetry settings.
+  optional SelectionModelOptions
+      selection_model_options_in_embedding_network_proto = 148190899;
+}
diff --git a/smartselect/token-feature-extractor.cc b/smartselect/token-feature-extractor.cc
new file mode 100644
index 0000000..a6bbf4f
--- /dev/null
+++ b/smartselect/token-feature-extractor.cc
@@ -0,0 +1,119 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/token-feature-extractor.h"
+
+#include "util/hash/farmhash.h"
+
+namespace libtextclassifier {
+
+// Words longer than this many bytes are shortened to their first and last
+// kMaxWordLength / 2 bytes before charactergram extraction.
+constexpr int kMaxWordLength = 20;
+
+// Maps token to a bucket id by reducing its 64-bit fingerprint modulo
+// options_.num_buckets (assumes num_buckets > 0).
+int TokenFeatureExtractor::HashToken(const std::string& token) const {
+  const auto fingerprint = tcfarmhash::Fingerprint64(token);
+  return fingerprint % options_.num_buckets;
+}
+
+// Returns the hashed charactergram features of token. Padding tokens yield a
+// single "<PAD>" feature. Otherwise the word is wrapped as "^word$" (long words
+// keep only their first and last kMaxWordLength / 2 bytes, joined by "\1")
+// and every substring of each configured order is hashed; unigrams skip the
+// "^" and "$" markers. Operates on bytes, so results must stay identical to
+// what the models were trained with.
+std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
+    const Token& token) const {
+  if (token.is_padding) {
+    return {HashToken("<PAD>")};
+  }
+
+  const std::string& word = token.value;
+  const int half_length = kMaxWordLength / 2;
+
+  std::string feature_word = "^";
+  if (word.size() > kMaxWordLength) {
+    feature_word += word.substr(0, half_length);
+    feature_word += "\1";
+    feature_word += word.substr(word.size() - half_length, half_length);
+  } else {
+    feature_word += word;
+  }
+  feature_word += "$";
+
+  // Upper bound on the number of charactergrams, to avoid reallocation.
+  std::vector<int> features;
+  features.reserve(options_.chargram_orders.size() * feature_word.size());
+
+  const int feature_word_length = static_cast<int>(feature_word.size());
+  for (const int order : options_.chargram_orders) {
+    if (order == 1) {
+      for (int pos = 1; pos + 1 < feature_word_length; ++pos) {
+        features.push_back(HashToken(feature_word.substr(pos, 1)));
+      }
+    } else {
+      for (int pos = 0; pos + order <= feature_word_length; ++pos) {
+        features.push_back(HashToken(feature_word.substr(pos, order)));
+      }
+    }
+  }
+  return features;
+}
+
+// Extracts the features of a single token. *sparse_features is overwritten
+// with the charactergram features; dense features (case, selection mask) are
+// appended to *dense_features. Returns false if either pointer is null.
+bool TokenFeatureExtractor::Extract(const Token& token,
+                                    std::vector<int>* sparse_features,
+                                    std::vector<float>* dense_features) const {
+  if (sparse_features == nullptr || dense_features == nullptr) {
+    return false;
+  }
+
+  *sparse_features = ExtractCharactergramFeatures(token);
+
+  if (options_.extract_case_feature) {
+    // TODO(zilka): Make isupper Unicode-aware.
+    // isupper() is undefined for negative values other than EOF. Bytes of
+    // multi-byte UTF-8 sequences are negative where char is signed, so the
+    // byte is converted to unsigned char first.
+    const bool starts_with_upper =
+        !token.value.empty() &&
+        isupper(static_cast<unsigned char>(token.value[0]));
+    dense_features->push_back(starts_with_upper ? 1.0 : -1.0);
+  }
+
+  if (options_.extract_selection_mask_feature) {
+    // TODO(zilka): Switch to -1. Doing this now would break current models.
+    dense_features->push_back(token.is_in_span ? 1.0 : 0.0);
+  }
+
+  return true;
+}
+
+// Applies the single-token Extract() to each token in order, producing one
+// sparse and one dense feature vector per token. Previous contents of the
+// output vectors are discarded. Returns false if an output pointer is null or
+// extraction fails for any token.
+bool TokenFeatureExtractor::Extract(
+    const std::vector<Token>& tokens,
+    std::vector<std::vector<int>>* sparse_features,
+    std::vector<std::vector<float>>* dense_features) const {
+  if (sparse_features == nullptr || dense_features == nullptr) {
+    return false;
+  }
+
+  // Start from empty per-token vectors. The single-token Extract() appends to
+  // its dense output, so a plain resize() of a reused output vector would keep
+  // stale dense features in front of the new ones.
+  sparse_features->assign(tokens.size(), std::vector<int>());
+  dense_features->assign(tokens.size(), std::vector<float>());
+  for (size_t i = 0; i < tokens.size(); i++) {
+    if (!Extract(tokens[i], &(*sparse_features)[i], &(*dense_features)[i])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+}  // namespace libtextclassifier
diff --git a/smartselect/token-feature-extractor.h b/smartselect/token-feature-extractor.h
new file mode 100644
index 0000000..9ba695e
--- /dev/null
+++ b/smartselect/token-feature-extractor.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_
+
+#include <vector>
+
+#include "base.h"
+#include "smartselect/types.h"
+
+namespace libtextclassifier {
+
// Configuration for TokenFeatureExtractor. Every field has a defined default,
// so a default-constructed instance never holds indeterminate values.
struct TokenFeatureExtractorOptions {
  // Number of buckets used for hashing charactergrams.
  int num_buckets = 0;

  // Orders of charactergrams to extract. E.g., 2 means character bigrams, 3
  // character trigrams etc.
  std::vector<int> chargram_orders;

  // Whether to extract the token case feature.
  bool extract_case_feature = false;

  // Whether to extract the selection mask feature.
  bool extract_selection_mask_feature = false;
};
+
+class TokenFeatureExtractor {
+ public:
+  explicit TokenFeatureExtractor(const TokenFeatureExtractorOptions& options)
+      : options_(options) {}
+
+  // Extracts features from a token.
+  //  - sparse_features are indices into a sparse feature vector of size
+  //    options.num_buckets which are set to 1.0 (others are implicitly 0.0).
+  //  - dense_features are values of a dense feature vector of size 0-2
+  //    (depending on the options) for the token
+  bool Extract(const Token& token, std::vector<int>* sparse_features,
+               std::vector<float>* dense_features) const;
+
+  // Convenience method that sequentially applies Extract to each Token.
+  bool Extract(const std::vector<Token>& tokens,
+               std::vector<std::vector<int>>* sparse_features,
+               std::vector<std::vector<float>>* dense_features) const;
+
+ protected:
+  // Hashes given token to given number of buckets.
+  int HashToken(const std::string& token) const;
+
+  // Extracts the charactergram features from the token.
+  std::vector<int> ExtractCharactergramFeatures(const Token& token) const;
+
+ private:
+  TokenFeatureExtractorOptions options_;
+};
+
+}  // namespace libtextclassifier
+
+#endif  // LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_
diff --git a/smartselect/tokenizer.cc b/smartselect/tokenizer.cc
new file mode 100644
index 0000000..72bb668
--- /dev/null
+++ b/smartselect/tokenizer.cc
@@ -0,0 +1,102 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
#include "smartselect/tokenizer.h"

#include <algorithm>
#include <string>
#include <vector>

#include "util/strings/utf8.h"
#include "util/utf8/unicodetext.h"
+
+namespace libtextclassifier {
+
+void Tokenizer::PrepareTokenizationCodepointRanges(
+    const std::vector<TokenizationCodepointRange> codepoint_range_configs) {
+  codepoint_ranges_.clear();
+  codepoint_ranges_.reserve(codepoint_range_configs.size());
+  for (const TokenizationCodepointRange& range : codepoint_range_configs) {
+    codepoint_ranges_.push_back(
+        CodepointRange(range.start(), range.end(), range.role()));
+  }
+
+  std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
+            [](const CodepointRange& a, const CodepointRange& b) {
+              return a.start < b.start;
+            });
+}
+
+TokenizationCodepointRange::Role Tokenizer::FindTokenizationRole(
+    int codepoint) const {
+  auto it = std::lower_bound(codepoint_ranges_.begin(), codepoint_ranges_.end(),
+                             codepoint,
+                             [](const CodepointRange& range, int codepoint) {
+                               // This function compares range with the
+                               // codepoint for the purpose of finding the first
+                               // greater or equal range. Because of the use of
+                               // std::lower_bound it needs to return true when
+                               // range < codepoint; the first time it will
+                               // return false the lower bound is found and
+                               // returned.
+                               //
+                               // It might seem weird that the condition is
+                               // range < codepoint here but when codepoint ==
+                               // range.end it means it's actually just outside
+                               // of the range, thus the range is less than the
+                               // codepoint.
+                               return range.end <= codepoint;
+                             });
+  if (it != codepoint_ranges_.end() && it->start <= codepoint &&
+      it->end > codepoint) {
+    return it->role;
+  } else {
+    return TokenizationCodepointRange::DEFAULT_ROLE;
+  }
+}
+
+std::vector<Token> Tokenizer::Tokenize(const std::string& utf8_text) const {
+  UnicodeText context_unicode = UTF8ToUnicodeText(utf8_text, /*do_copy=*/false);
+
+  std::vector<Token> result;
+  Token new_token("", 0, 0);
+  int codepoint_index = 0;
+  for (auto it = context_unicode.begin(); it != context_unicode.end();
+       ++it, ++codepoint_index) {
+    TokenizationCodepointRange::Role role = FindTokenizationRole(*it);
+    if (role & TokenizationCodepointRange::SPLIT_BEFORE) {
+      if (!new_token.value.empty()) {
+        result.push_back(new_token);
+      }
+      new_token = Token("", codepoint_index, codepoint_index);
+    }
+    if (!(role & TokenizationCodepointRange::DISCARD_CODEPOINT)) {
+      new_token.value += std::string(
+          it.utf8_data(),
+          it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data()));
+      ++new_token.end;
+    }
+    if (role & TokenizationCodepointRange::SPLIT_AFTER) {
+      if (!new_token.value.empty()) {
+        result.push_back(new_token);
+      }
+      new_token = Token("", codepoint_index + 1, codepoint_index + 1);
+    }
+  }
+  if (!new_token.value.empty()) {
+    result.push_back(new_token);
+  }
+
+  return result;
+}
+
+}  // namespace libtextclassifier
diff --git a/smartselect/tokenizer.h b/smartselect/tokenizer.h
new file mode 100644
index 0000000..9ed152f
--- /dev/null
+++ b/smartselect/tokenizer.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Tokenizer.
+
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TOKENIZER_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_TOKENIZER_H_
+
+#include <string>
+#include <vector>
+
+#include "smartselect/tokenizer.pb.h"
+#include "smartselect/types.h"
+#include "util/base/integral_types.h"
+
+namespace libtextclassifier {
+
+// Represents a codepoint range [start, end) with its role for tokenization.
+struct CodepointRange {
+  int32 start;
+  int32 end;
+  TokenizationCodepointRange::Role role;
+
+  CodepointRange(int32 arg_start, int32 arg_end,
+                 TokenizationCodepointRange::Role arg_role)
+      : start(arg_start), end(arg_end), role(arg_role) {}
+};
+
+// Tokenizer splits the input string into a sequence of tokens, according to the
+// configuration.
+class Tokenizer {
+ public:
+  explicit Tokenizer(
+      const std::vector<TokenizationCodepointRange>& codepoint_range_configs) {
+    PrepareTokenizationCodepointRanges(codepoint_range_configs);
+  }
+
+  // Tokenizes the input string using the selected tokenization method.
+  std::vector<Token> Tokenize(const std::string& utf8_text) const;
+
+ protected:
+  // Prepares tokenization codepoint ranges for use in tokenization.
+  void PrepareTokenizationCodepointRanges(
+      const std::vector<TokenizationCodepointRange> codepoint_range_configs);
+
+  // Finds the tokenization role for given codepoint.
+  // If the character is not found returns DEFAULT_ROLE.
+  // Internally uses binary search so should be O(log2(# of codepoint_ranges)).
+  TokenizationCodepointRange::Role FindTokenizationRole(int codepoint) const;
+
+ private:
+  // Codepoint ranges that determine how different codepoints are tokenized.
+  // The ranges must not overlap.
+  std::vector<CodepointRange> codepoint_ranges_;
+};
+
+}  // namespace libtextclassifier
+
+#endif  // LIBTEXTCLASSIFIER_SMARTSELECT_TOKENIZER_H_
diff --git a/smartselect/tokenizer.proto b/smartselect/tokenizer.proto
new file mode 100644
index 0000000..8e78970
--- /dev/null
+++ b/smartselect/tokenizer.proto
@@ -0,0 +1,48 @@
+// Copyright (C) 2017 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+option optimize_for = LITE_RUNTIME;
+
+package libtextclassifier;
+
// Represents a codepoint range [start, end) with its role for tokenization.
message TokenizationCodepointRange {
  // First codepoint of the range (inclusive).
  optional int32 start = 1;
  // One past the last codepoint of the range (exclusive).
  optional int32 end = 2;

  // Role of the codepoints in the range.
  // The basic values are bit flags (0x1, 0x2, 0x4); the tokenizer tests each
  // bit separately, and the "common values" below are their bitwise ORs.
  // NOTE(review): only the combinations listed here are enum values, so a
  // combination such as SPLIT_BEFORE | DISCARD_CODEPOINT (0x5) would need its
  // own enumerator before it can be configured — confirm if that is required.
  enum Role {
    // Concatenates the codepoint to the current run of codepoints.
    DEFAULT_ROLE = 0;

    // Splits a run of codepoints before the current codepoint.
    SPLIT_BEFORE = 0x1;

    // Splits a run of codepoints after the current codepoint.
    SPLIT_AFTER = 0x2;

    // Discards the codepoint.
    DISCARD_CODEPOINT = 0x4;

    // Common values:
    // Splits on the characters and discards them. Good e.g. for the space
    // character. (= SPLIT_BEFORE | SPLIT_AFTER | DISCARD_CODEPOINT)
    WHITESPACE_SEPARATOR = 0x7;
    // Each codepoint will be a separate token. Good e.g. for Chinese
    // characters. (= SPLIT_BEFORE | SPLIT_AFTER)
    TOKEN_SEPARATOR = 0x3;
  }
  // Role applied to every codepoint in [start, end).
  optional Role role = 3;
}
diff --git a/smartselect/types.h b/smartselect/types.h
new file mode 100644
index 0000000..12e93d9
--- /dev/null
+++ b/smartselect/types.h
@@ -0,0 +1,126 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TYPES_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_TYPES_H_
+
+#include <ostream>
+#include <string>
+#include <utility>
+
+namespace libtextclassifier {
+
// Sentinel meaning "no index", e.g. the start/end of a padding Token.
constexpr int kInvalidIndex = -1;

// 0-based position in a sequence of tokens.
using TokenIndex = int;

// 0-based position in a sequence of Unicode codepoints.
using CodepointIndex = int;

// Half-open codepoint range [first, second): `first` is the index of the
// span's first codepoint, `second` is the index one past its last codepoint.
using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;

// Half-open token range [first, second): `first` is the index of the span's
// first token, `second` is the index one past its last token.
using TokenSpan = std::pair<TokenIndex, TokenIndex>;
+
+// Token holds a token, its position in the original string and whether it was
+// part of the input span.
+struct Token {
+  std::string value;
+  CodepointIndex start;
+  CodepointIndex end;
+
+  // Whether the token was in the input span.
+  bool is_in_span;
+
+  // Whether the token is a padding token.
+  bool is_padding;
+
+  // Default constructor constructs the padding-token.
+  Token()
+      : value(""),
+        start(kInvalidIndex),
+        end(kInvalidIndex),
+        is_in_span(false),
+        is_padding(true) {}
+
+  Token(const std::string& arg_value, CodepointIndex arg_start,
+        CodepointIndex arg_end)
+      : Token(arg_value, arg_start, arg_end, false) {}
+
+  Token(const std::string& arg_value, CodepointIndex arg_start,
+        CodepointIndex arg_end, bool is_in_span)
+      : value(arg_value),
+        start(arg_start),
+        end(arg_end),
+        is_in_span(is_in_span),
+        is_padding(false) {}
+
+  bool operator==(const Token& other) const {
+    return value == other.value && start == other.start && end == other.end &&
+           is_in_span == other.is_in_span && is_padding == other.is_padding;
+  }
+};
+
+// Pretty-printing function for Token.
+inline std::ostream& operator<<(std::ostream& os, const Token& token) {
+  return os << "Token(\"" << token.value << "\", " << token.start << ", "
+            << token.end << ", is_in_span=" << token.is_in_span
+            << ", is_padding=" << token.is_padding << ")";
+}
+
+// Represents a selection.
+struct SelectionWithContext {
+  SelectionWithContext()
+      : context(""),
+        selection_start(-1),
+        selection_end(-1),
+        click_start(-1),
+        click_end(-1) {}
+
+  // UTF8 encoded context.
+  std::string context;
+
+  // Codepoint index to the context where selection starts.
+  CodepointIndex selection_start;
+
+  // Codepoint index to the context one past where selection ends.
+  CodepointIndex selection_end;
+
+  // Codepoint index to the context where click starts.
+  CodepointIndex click_start;
+
+  // Codepoint index to the context one past where click ends.
+  CodepointIndex click_end;
+
+  // Type of the selection.
+  std::string collection;
+
+  CodepointSpan GetClickSpan() const { return {click_start, click_end}; }
+
+  CodepointSpan GetSelectionSpan() const {
+    return {selection_start, selection_end};
+  }
+};
+
+}  // namespace libtextclassifier
+
#endif  // LIBTEXTCLASSIFIER_SMARTSELECT_TYPES_H_