Initial checkin of libtextclassifier.
Test: libtextclassifier_tests pass in google3 repository.
Change-Id: I2aac5b12b8810c2ec1c1aa27238cb8aa8430f403
diff --git a/smartselect/feature-processor.cc b/smartselect/feature-processor.cc
new file mode 100644
index 0000000..2e04920
--- /dev/null
+++ b/smartselect/feature-processor.cc
@@ -0,0 +1,644 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/feature-processor.h"
+
+#include <iterator>
+#include <set>
+#include <vector>
+
+#include "smartselect/text-classification-model.pb.h"
+#include "util/base/logging.h"
+#include "util/strings/utf8.h"
+#include "util/utf8/unicodetext.h"
+
+namespace libtextclassifier {
+
+constexpr int kMaxWordLength = 20; // All words will be trimmed to this length.
+
+namespace internal {
+
+TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
+ const FeatureProcessorOptions& options) {
+ TokenFeatureExtractorOptions extractor_options;
+
+ extractor_options.num_buckets = options.num_buckets();
+ for (int order : options.chargram_orders()) {
+ extractor_options.chargram_orders.push_back(order);
+ }
+ extractor_options.extract_case_feature = options.extract_case_feature();
+ extractor_options.extract_selection_mask_feature =
+ options.extract_selection_mask_feature();
+
+ return extractor_options;
+}
+
+FeatureProcessorOptions ParseSerializedOptions(
+ const std::string& serialized_options) {
+ FeatureProcessorOptions options;
+ options.ParseFromString(serialized_options);
+ return options;
+}
+
+void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
+ std::vector<Token>* tokens) {
+ for (auto it = tokens->begin(); it != tokens->end(); ++it) {
+ const UnicodeText token_word =
+ UTF8ToUnicodeText(it->value, /*do_copy=*/false);
+
+ auto last_start = token_word.begin();
+ int last_start_index = it->start;
+ std::vector<UnicodeText::const_iterator> split_points;
+
+ // Selection start split point.
+ if (selection.first > it->start && selection.first < it->end) {
+ std::advance(last_start, selection.first - last_start_index);
+ split_points.push_back(last_start);
+ last_start_index = selection.first;
+ }
+
+ // Selection end split point.
+ if (selection.second > it->start && selection.second < it->end) {
+ std::advance(last_start, selection.second - last_start_index);
+ split_points.push_back(last_start);
+ }
+
+ if (!split_points.empty()) {
+ // Add a final split for the rest of the token unless it's been all
+ // consumed already.
+ if (split_points.back() != token_word.end()) {
+ split_points.push_back(token_word.end());
+ }
+
+ std::vector<Token> replacement_tokens;
+ last_start = token_word.begin();
+ int current_pos = it->start;
+ for (const auto& split_point : split_points) {
+ Token new_token(token_word.UTF8Substring(last_start, split_point),
+ current_pos,
+ current_pos + std::distance(last_start, split_point),
+ /*is_in_span=*/false);
+
+ last_start = split_point;
+ current_pos = new_token.end;
+
+ replacement_tokens.push_back(new_token);
+ }
+
+ it = tokens->erase(it);
+ it = tokens->insert(it, replacement_tokens.begin(),
+ replacement_tokens.end());
+ std::advance(it, replacement_tokens.size() - 1);
+ }
+ }
+}
+
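+// Splits the text into ranges separated by the given codepoints. The separator
+// codepoints themselves are not part of the output ranges, and empty ranges
+// are skipped.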
+void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
+ std::vector<UnicodeTextRange>* ranges) {
+ UnicodeText::const_iterator start = t.begin();
+ UnicodeText::const_iterator curr = start;
+ UnicodeText::const_iterator end = t.end();
+ for (; curr != end; ++curr) {
+ if (codepoints.find(*curr) != codepoints.end()) {
+ if (start != curr) {
+ ranges->push_back(std::make_pair(start, curr));
+ }
+ start = curr;
+ ++start;
+ }
+ }
+ if (start != end) {
+ ranges->push_back(std::make_pair(start, end));
+ }
+}
+
+std::pair<SelectionWithContext, int> ExtractLineWithClick(
+ const SelectionWithContext& selection_with_context) {
+ std::string new_context;
+ int shift;
+ std::tie(new_context, shift) = ExtractLineWithSpan(
+ selection_with_context.context,
+ {selection_with_context.click_start, selection_with_context.click_end});
+
+ SelectionWithContext result(selection_with_context);
+ result.selection_start -= shift;
+ result.selection_end -= shift;
+ result.click_start -= shift;
+ result.click_end -= shift;
+ result.context = new_context;
+ return {result, shift};
+}
+
+} // namespace internal
+
+std::pair<std::string, int> ExtractLineWithSpan(const std::string& context,
+ CodepointSpan span) {
+ const UnicodeText context_unicode = UTF8ToUnicodeText(context,
+ /*do_copy=*/false);
+ std::vector<UnicodeTextRange> lines;
+ std::set<char32> codepoints;
+ codepoints.insert('\n');
+ codepoints.insert('|');
+ internal::FindSubstrings(context_unicode, codepoints, &lines);
+
+ auto span_start = context_unicode.begin();
+ if (span.first > 0) {
+ std::advance(span_start, span.first);
+ }
+ auto span_end = context_unicode.begin();
+ if (span.second > 0) {
+ std::advance(span_end, span.second);
+ }
+ for (const UnicodeTextRange& line : lines) {
+ // Find the line that completely contains the span.
+ if (line.first <= span_start && line.second >= span_start &&
+ line.first <= span_end && line.second >= span_end) {
+ const CodepointIndex last_line_begin_index =
+ std::distance(context_unicode.begin(), line.first);
+
+ std::string result =
+ context_unicode.UTF8Substring(line.first, line.second);
+ return {result, last_line_begin_index};
+ }
+ }
+ return {context, 0};
+}
+
+const char* const FeatureProcessor::kFeatureTypeName = "chargram_continuous";
+
+std::vector<Token> FeatureProcessor::Tokenize(
+ const std::string& utf8_text) const {
+ return tokenizer_.Tokenize(utf8_text);
+}
+
+bool FeatureProcessor::LabelToSpan(
+ const int label, const std::vector<Token>& tokens,
+ std::pair<CodepointIndex, CodepointIndex>* span) const {
+ if (tokens.size() != GetNumContextTokens()) {
+ return false;
+ }
+
+ TokenSpan token_span;
+ if (!LabelToTokenSpan(label, &token_span)) {
+ return false;
+ }
+
+ const int result_begin_token = token_span.first;
+ const int result_begin_codepoint =
+ tokens[options_.context_size() - result_begin_token].start;
+ const int result_end_token = token_span.second;
+ const int result_end_codepoint =
+ tokens[options_.context_size() + result_end_token].end;
+
+ if (result_begin_codepoint == kInvalidIndex ||
+ result_end_codepoint == kInvalidIndex) {
+ *span = CodepointSpan({kInvalidIndex, kInvalidIndex});
+ } else {
+ *span = CodepointSpan({result_begin_codepoint, result_end_codepoint});
+ }
+ return true;
+}
+
+bool FeatureProcessor::LabelToTokenSpan(const int label,
+ TokenSpan* token_span) const {
+ if (label >= 0 && label < label_to_selection_.size()) {
+ *token_span = label_to_selection_[label];
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool FeatureProcessor::SpanToLabel(
+ const std::pair<CodepointIndex, CodepointIndex>& span,
+ const std::vector<Token>& tokens, int* label) const {
+ if (tokens.size() != GetNumContextTokens()) {
+ return false;
+ }
+
+ const int click_position =
+ options_.context_size(); // Click is always in the middle.
+ const int padding = options_.context_size() - options_.max_selection_span();
+
+ int span_left = 0;
+ for (int i = click_position - 1; i >= padding; i--) {
+ if (tokens[i].start != kInvalidIndex && tokens[i].start >= span.first) {
+ ++span_left;
+ } else {
+ break;
+ }
+ }
+
+ int span_right = 0;
+ for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
+ if (tokens[i].end != kInvalidIndex && tokens[i].end <= span.second) {
+ ++span_right;
+ } else {
+ break;
+ }
+ }
+
+ // Check that the spanned tokens cover the whole span.
+ if (tokens[click_position - span_left].start == span.first &&
+ tokens[click_position + span_right].end == span.second) {
+ *label = TokenSpanToLabel({span_left, span_right});
+ } else {
+ *label = kInvalidLabel;
+ }
+
+ return true;
+}
+
+int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
+ auto it = selection_to_label_.find(span);
+ if (it != selection_to_label_.end()) {
+ return it->second;
+ } else {
+ return kInvalidLabel;
+ }
+}
+
+// Converts a codepoint span to a token span in the given list of tokens.
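+// Example (derived from the logic below): for tokens "Hello" at codepoints
+// [0, 5) and "world" at [6, 11), the codepoint span {6, 11} maps to the token
+// span {1, 2}, with the end index being exclusive.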
+TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
+ CodepointSpan codepoint_span) {
+ const int codepoint_start = std::get<0>(codepoint_span);
+ const int codepoint_end = std::get<1>(codepoint_span);
+
+ TokenIndex start_token = kInvalidIndex;
+ TokenIndex end_token = kInvalidIndex;
+ for (int i = 0; i < selectable_tokens.size(); ++i) {
+ if (codepoint_start <= selectable_tokens[i].start &&
+ codepoint_end >= selectable_tokens[i].end) {
+ if (start_token == kInvalidIndex) {
+ start_token = i;
+ }
+ end_token = i + 1;
+ }
+ }
+ return {start_token, end_token};
+}
+
+namespace {
+
+// Finds a single token that completely contains the given span.
+int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
+ CodepointSpan codepoint_span) {
+ const int codepoint_start = std::get<0>(codepoint_span);
+ const int codepoint_end = std::get<1>(codepoint_span);
+
+ for (int i = 0; i < selectable_tokens.size(); ++i) {
+ if (codepoint_start >= selectable_tokens[i].start &&
+ codepoint_end <= selectable_tokens[i].end) {
+ return i;
+ }
+ }
+ return kInvalidIndex;
+}
+
+// Finds the tokens that are part of the selection.
+// NOTE: Will select all tokens that overlap with the selection in any way.
+std::vector<Token> FindTokensInSelection(
+ const std::vector<Token>& selectable_tokens,
+ const SelectionWithContext& selection_with_context) {
+ std::vector<Token> tokens_in_selection;
+ for (const Token& token : selectable_tokens) {
+ const bool selection_start_in_token =
+ token.start <= selection_with_context.selection_start &&
+ token.end > selection_with_context.selection_start;
+
+ const bool selection_end_in_token =
+ token.start < selection_with_context.selection_end &&
+ token.end >= selection_with_context.selection_end;
+
+ if (selection_start_in_token || selection_end_in_token) {
+ tokens_in_selection.push_back(token);
+ }
+ }
+ return tokens_in_selection;
+}
+
+// Helper function that computes the click position as an index into the given
+// vector of selectable tokens. If a click position is given, it is used. If it
+// is not given but a selection is, the click is taken to be the middle token
+// of the selection range.
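+// For example, if the selection covers tokens 2..4 (inclusive) and no click
+// span is given, token 3 becomes the click position.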
+int GetClickPosition(const SelectionWithContext& selection_with_context,
+ const std::vector<Token>& selectable_tokens) {
+ int click_pos = kInvalidIndex;
+ if (selection_with_context.GetClickSpan() !=
+ std::make_pair(kInvalidIndex, kInvalidIndex)) {
+ int range_begin;
+ int range_end;
+ std::tie(range_begin, range_end) = CodepointSpanToTokenSpan(
+ selectable_tokens, selection_with_context.GetClickSpan());
+
+ // If no exact match was found, try finding a token that completely contains
+ // the click span. This is useful e.g. when Android builds the selection
+ // using ICU tokenization, and ends up with only a portion of our space-
+ // separated token. E.g. for "(857)" Android would select "857".
+ if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
+ int token_index = FindTokenThatContainsSpan(
+ selectable_tokens, selection_with_context.GetClickSpan());
+ if (token_index != kInvalidIndex) {
+ range_begin = token_index;
+ range_end = token_index + 1;
+ }
+ }
+
+    // We only allow clicks that span exactly one selectable token.
+ if (range_end - range_begin == 1) {
+ click_pos = range_begin;
+ } else {
+ click_pos = kInvalidIndex;
+ }
+ } else if (selection_with_context.GetSelectionSpan() !=
+ std::make_pair(kInvalidIndex, kInvalidIndex)) {
+ int range_begin;
+ int range_end;
+ std::tie(range_begin, range_end) = CodepointSpanToTokenSpan(
+ selectable_tokens, selection_with_context.GetSelectionSpan());
+
+ // Center the clicked token in the selection range.
+ if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
+ click_pos = (range_begin + range_end - 1) / 2;
+ }
+ }
+
+ return click_pos;
+}
+
+} // namespace
+
+CodepointSpan FeatureProcessor::ClickRandomTokenInSelection(
+ const SelectionWithContext& selection_with_context) const {
+ const std::vector<Token> tokens = Tokenize(selection_with_context.context);
+ const std::vector<Token> tokens_in_selection =
+ FindTokensInSelection(tokens, selection_with_context);
+
+ if (!tokens_in_selection.empty()) {
+ std::uniform_int_distribution<> selection_token_draw(
+ 0, tokens_in_selection.size() - 1);
+ const int token_id = selection_token_draw(*random_);
+ return {tokens_in_selection[token_id].start,
+ tokens_in_selection[token_id].end};
+ } else {
+ return {kInvalidIndex, kInvalidIndex};
+ }
+}
+
+bool FeatureProcessor::GetFeaturesAndLabels(
+ const SelectionWithContext& selection_with_context,
+ std::vector<nlp_core::FeatureVector>* features,
+ std::vector<float>* extra_features,
+ std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
+ CodepointSpan* selection_codepoint_label, int* classification_label) const {
+ if (features == nullptr) {
+ return false;
+ }
+ *features =
+ std::vector<nlp_core::FeatureVector>(options_.context_size() * 2 + 1);
+
+ SelectionWithContext selection_normalized;
+ int normalization_shift;
+ if (options_.only_use_line_with_click()) {
+ std::tie(selection_normalized, normalization_shift) =
+ internal::ExtractLineWithClick(selection_with_context);
+ } else {
+ selection_normalized = selection_with_context;
+ normalization_shift = 0;
+ }
+
+ std::vector<Token> input_tokens = Tokenize(selection_normalized.context);
+ if (options_.split_tokens_on_selection_boundaries()) {
+ internal::SplitTokensOnSelectionBoundaries(
+ selection_with_context.GetSelectionSpan(), &input_tokens);
+ }
+ int click_pos = GetClickPosition(selection_normalized, input_tokens);
+ if (click_pos == kInvalidIndex) {
+ TC_LOG(ERROR) << "Could not extract click position.";
+ return false;
+ }
+
+ std::vector<Token> output_tokens;
+ bool status = ComputeFeatures(click_pos, input_tokens,
+ selection_normalized.GetSelectionSpan(),
+ features, extra_features, &output_tokens);
+ if (!status) {
+ TC_LOG(ERROR) << "Feature computation failed.";
+ return false;
+ }
+
+ if (selection_label != nullptr) {
+ status = SpanToLabel(selection_normalized.GetSelectionSpan(), output_tokens,
+ selection_label);
+ if (!status) {
+ TC_LOG(ERROR) << "Could not convert selection span to label.";
+ return false;
+ }
+ }
+
+ if (selection_codepoint_label != nullptr) {
+ *selection_codepoint_label = selection_with_context.GetSelectionSpan();
+ }
+
+ if (selection_label_spans != nullptr) {
+ // If an input normalization was performed, we need to shift the tokens
+ // back to get the correct ranges in the original input.
+ if (normalization_shift) {
+ for (Token& token : output_tokens) {
+ if (token.start != kInvalidIndex && token.end != kInvalidIndex) {
+ token.start += normalization_shift;
+ token.end += normalization_shift;
+ }
+ }
+ }
+ for (int i = 0; i < label_to_selection_.size(); ++i) {
+ CodepointSpan span;
+ status = LabelToSpan(i, output_tokens, &span);
+ if (!status) {
+ TC_LOG(ERROR) << "Could not convert label to span: " << i;
+ return false;
+ }
+ selection_label_spans->push_back(span);
+ }
+ }
+
+ if (classification_label != nullptr) {
+ *classification_label = CollectionToLabel(selection_normalized.collection);
+ }
+
+ return true;
+}
+
+bool FeatureProcessor::GetFeaturesAndLabels(
+ const SelectionWithContext& selection_with_context,
+ std::vector<std::vector<std::pair<int, float>>>* features,
+ std::vector<float>* extra_features,
+ std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
+ CodepointSpan* selection_codepoint_label, int* classification_label) const {
+ if (features == nullptr) {
+ return false;
+ }
+ if (extra_features == nullptr) {
+ return false;
+ }
+
+ std::vector<nlp_core::FeatureVector> feature_vectors;
+ bool result = GetFeaturesAndLabels(selection_with_context, &feature_vectors,
+ extra_features, selection_label_spans,
+ selection_label, selection_codepoint_label,
+ classification_label);
+
+ if (!result) {
+ return false;
+ }
+
+ features->clear();
+ for (int i = 0; i < feature_vectors.size(); ++i) {
+ features->emplace_back();
+ for (int j = 0; j < feature_vectors[i].size(); ++j) {
+ nlp_core::FloatFeatureValue feature_value(feature_vectors[i].value(j));
+ (*features)[i].push_back({feature_value.id, feature_value.weight});
+ }
+ }
+
+ return true;
+}
+
+bool FeatureProcessor::ComputeFeatures(
+ int click_pos, const std::vector<Token>& tokens,
+ CodepointSpan selected_span, std::vector<nlp_core::FeatureVector>* features,
+ std::vector<float>* extra_features,
+ std::vector<Token>* output_tokens) const {
+ int dropout_left = 0;
+ int dropout_right = 0;
+ if (options_.context_dropout_probability() > 0.0) {
+ // Determine how much context to drop.
+ bool status = GetContextDropoutRange(&dropout_left, &dropout_right);
+ if (!status) {
+ return false;
+ }
+ }
+
+ int feature_index = 0;
+ output_tokens->reserve(options_.context_size() * 2 + 1);
+ const int num_extra_features =
+ static_cast<int>(options_.extract_case_feature()) +
+ static_cast<int>(options_.extract_selection_mask_feature());
+ extra_features->reserve((options_.context_size() * 2 + 1) *
+ num_extra_features);
+ for (int i = click_pos - options_.context_size();
+ i <= click_pos + options_.context_size(); ++i, ++feature_index) {
+ std::vector<int> sparse_features;
+ std::vector<float> dense_features;
+
+ const bool is_valid_token = i >= 0 && i < tokens.size();
+
+ bool is_dropped = false;
+ if (options_.context_dropout_probability() > 0.0) {
+ if (i < click_pos - options_.context_size() + dropout_left) {
+ is_dropped = true;
+ } else if (i > click_pos + options_.context_size() - dropout_right) {
+ is_dropped = true;
+ }
+ }
+
+ if (is_valid_token && !is_dropped) {
+ Token token(tokens[i]);
+ token.is_in_span = token.start >= selected_span.first &&
+ token.end <= selected_span.second;
+ feature_extractor_.Extract(token, &sparse_features, &dense_features);
+ output_tokens->push_back(tokens[i]);
+ } else {
+ feature_extractor_.Extract(Token(), &sparse_features, &dense_features);
+      // This adds an empty token for each missing or dropped context position,
+      // so that the output tokens stay aligned with the network inputs.
+ output_tokens->emplace_back();
+ }
+
+ for (int feature_id : sparse_features) {
+ const int64 feature_value =
+ nlp_core::FloatFeatureValue(feature_id, 1.0 / sparse_features.size())
+ .discrete_value;
+ (*features)[feature_index].add(
+ const_cast<nlp_core::NumericFeatureType*>(&feature_type_),
+ feature_value);
+ }
+
+ for (float value : dense_features) {
+ extra_features->push_back(value);
+ }
+ }
+ return true;
+}
+
+int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
+ const auto it = collection_to_label_.find(collection);
+ if (it == collection_to_label_.end()) {
+ return options_.default_collection();
+ } else {
+ return it->second;
+ }
+}
+
+std::string FeatureProcessor::LabelToCollection(int label) const {
+ if (label >= 0 && label < collection_to_label_.size()) {
+ return options_.collections(label);
+ } else {
+ return GetDefaultCollection();
+ }
+}
+
+void FeatureProcessor::MakeLabelMaps() {
+ for (int i = 0; i < options_.collections().size(); ++i) {
+ collection_to_label_[options_.collections(i)] = i;
+ }
+
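+  // Selection labels enumerate (tokens-to-the-left, tokens-to-the-right)
+  // extents around the click. For example, with max_selection_span = 2 and the
+  // reduced output space, the enumerated extents are (0,0), (0,1), (0,2),
+  // (1,0), (1,1) and (2,0).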
+ int selection_label_id = 0;
+ for (int l = 0; l < (options_.max_selection_span() + 1); ++l) {
+ for (int r = 0; r < (options_.max_selection_span() + 1); ++r) {
+ if (!options_.selection_reduced_output_space() ||
+ r + l <= options_.max_selection_span()) {
+ TokenSpan token_span{l, r};
+ selection_to_label_[token_span] = selection_label_id;
+ label_to_selection_.push_back(token_span);
+ ++selection_label_id;
+ }
+ }
+ }
+}
+
+bool FeatureProcessor::GetContextDropoutRange(int* dropout_left,
+ int* dropout_right) const {
+ std::uniform_real_distribution<> uniform01_draw(0, 1);
+ if (uniform01_draw(*random_) < options_.context_dropout_probability()) {
+ if (options_.use_variable_context_dropout()) {
+ std::uniform_int_distribution<> uniform_context_draw(
+ 0, options_.context_size());
+ // Select how much to drop in the range: [zero; context size]
+ *dropout_left = uniform_context_draw(*random_);
+ *dropout_right = uniform_context_draw(*random_);
+ } else {
+ // Drop all context.
+ return false;
+ }
+ } else {
+ *dropout_left = 0;
+ *dropout_right = 0;
+ }
+ return true;
+}
+
+} // namespace libtextclassifier
diff --git a/smartselect/feature-processor.h b/smartselect/feature-processor.h
new file mode 100644
index 0000000..5684cd7
--- /dev/null
+++ b/smartselect/feature-processor.h
@@ -0,0 +1,206 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Feature processing for FFModel (feed-forward SmartSelection model).
+
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
+
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "common/feature-extractor.h"
+#include "smartselect/text-classification-model.pb.h"
+#include "smartselect/token-feature-extractor.h"
+#include "smartselect/tokenizer.h"
+#include "smartselect/types.h"
+
+namespace libtextclassifier {
+
+constexpr int kInvalidLabel = -1;
+
+namespace internal {
+
+// Parses the serialized protocol buffer.
+FeatureProcessorOptions ParseSerializedOptions(
+ const std::string& serialized_options);
+
+TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
+ const FeatureProcessorOptions& options);
+
+// Returns a modified version of selection_with_context in which only the line
+// that contains the clicked span is kept, together with the number of
+// codepoints the selection was shifted by.
+std::pair<SelectionWithContext, int> ExtractLineWithClick(
+ const SelectionWithContext& selection_with_context);
+
+// Splits tokens that contain the selection boundary inside them.
+// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
+void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
+ std::vector<Token>* tokens);
+
+} // namespace internal
+
+TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
+ CodepointSpan codepoint_span);
+
+// Returns a modified version of the context string in which only the line that
+// contains the span is kept, together with the size of the codepoint shift
+// that was applied. If the span spans multiple lines, returns the original
+// input with zero shift.
+// The following characters are considered to be line separators: '\n', '|'.
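+// Example (illustrative): ExtractLineWithSpan("foo\nbar baz", {4, 7}) returns
+// {"bar baz", 4}.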
+std::pair<std::string, int> ExtractLineWithSpan(const std::string& context,
+ CodepointSpan span);
+
+// Takes care of preparing features for the FFModel.
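+//
+// Minimal usage sketch (illustrative; the option values are hypothetical and a
+// real setup also populates tokenization_codepoint_config):
+//   FeatureProcessorOptions options;
+//   options.set_num_buckets(1000);
+//   options.set_context_size(2);
+//   FeatureProcessor feature_processor(options);
+//   std::vector<Token> tokens = feature_processor.Tokenize("Hello world!");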
+class FeatureProcessor {
+ public:
+ explicit FeatureProcessor(const FeatureProcessorOptions& options)
+ : options_(options),
+ feature_extractor_(
+ internal::BuildTokenFeatureExtractorOptions(options)),
+ feature_type_(FeatureProcessor::kFeatureTypeName,
+ options.num_buckets()),
+ tokenizer_({options.tokenization_codepoint_config().begin(),
+ options.tokenization_codepoint_config().end()}),
+ random_(new std::mt19937(std::random_device()())) {
+ MakeLabelMaps();
+ }
+
+ explicit FeatureProcessor(const std::string& serialized_options)
+ : FeatureProcessor(internal::ParseSerializedOptions(serialized_options)) {
+ }
+
+ CodepointSpan ClickRandomTokenInSelection(
+ const SelectionWithContext& selection_with_context) const;
+
+ // Tokenizes the input string using the selected tokenization method.
+ std::vector<Token> Tokenize(const std::string& utf8_text) const;
+
+ // NOTE: If dropout is on, subsequent calls of this function with the same
+ // arguments might return different results.
+ bool GetFeaturesAndLabels(const SelectionWithContext& selection_with_context,
+ std::vector<nlp_core::FeatureVector>* features,
+ std::vector<float>* extra_features,
+ std::vector<CodepointSpan>* selection_label_spans,
+ int* selection_label,
+ CodepointSpan* selection_codepoint_label,
+ int* classification_label) const;
+
+ // Same as above but uses std::vector instead of FeatureVector.
+ // NOTE: If dropout is on, subsequent calls of this function with the same
+ // arguments might return different results.
+ bool GetFeaturesAndLabels(
+ const SelectionWithContext& selection_with_context,
+ std::vector<std::vector<std::pair<int, float>>>* features,
+ std::vector<float>* extra_features,
+ std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
+ CodepointSpan* selection_codepoint_label,
+ int* classification_label) const;
+
+ // Converts a label into a token span.
+ bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
+
+ // Gets the string value for given collection label.
+ std::string LabelToCollection(int label) const;
+
+ // Gets the total number of collections of the model.
+ int NumCollections() const { return collection_to_label_.size(); }
+
+ // Gets the name of the default collection.
+ std::string GetDefaultCollection() const {
+ return options_.collections(options_.default_collection());
+ }
+
+ FeatureProcessorOptions GetOptions() const { return options_; }
+
+ int GetSelectionLabelCount() const { return label_to_selection_.size(); }
+
+ // Sets the source of randomness.
+ void SetRandom(std::mt19937* new_random) { random_.reset(new_random); }
+
+ protected:
+ // Extracts features for given word.
+ std::vector<int> GetWordFeatures(const std::string& word) const;
+
+ // NOTE: If dropout is on, subsequent calls of this function with the same
+ // arguments might return different results.
+ bool ComputeFeatures(int click_pos,
+ const std::vector<Token>& selectable_tokens,
+ CodepointSpan selected_span,
+ std::vector<nlp_core::FeatureVector>* features,
+ std::vector<float>* extra_features,
+ std::vector<Token>* output_tokens) const;
+
+ // Helper function that computes how much left context and how much right
+ // context should be dropped. Uses a mutable random_ member as a source of
+ // randomness.
+ bool GetContextDropoutRange(int* dropout_left, int* dropout_right) const;
+
+ // Returns the class id corresponding to the given string collection
+ // identifier. There is a catch-all class id that the function returns for
+ // unknown collections.
+ int CollectionToLabel(const std::string& collection) const;
+
+ // Prepares mapping from collection names to labels.
+ void MakeLabelMaps();
+
+ // Gets the number of spannable tokens for the model.
+ //
+  // Spannable tokens are the context tokens over which the model predicts
+  // selection spans (i.e., there is a 1:1 correspondence between the output
+  // classes of the model and the spannable tokens).
+ int GetNumContextTokens() const { return options_.context_size() * 2 + 1; }
+
+ // Converts a label into a span of codepoint indices corresponding to it
+ // given output_tokens.
+ bool LabelToSpan(int label, const std::vector<Token>& output_tokens,
+ CodepointSpan* span) const;
+
+ // Converts a span to the corresponding label given output_tokens.
+ bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
+ const std::vector<Token>& output_tokens, int* label) const;
+
+ // Converts a token span to the corresponding label.
+ int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
+
+ private:
+ FeatureProcessorOptions options_;
+
+ TokenFeatureExtractor feature_extractor_;
+
+ static const char* const kFeatureTypeName;
+
+ nlp_core::NumericFeatureType feature_type_;
+
+  // Mapping between token selection spans and label ids.
+ std::map<TokenSpan, int> selection_to_label_;
+ std::vector<TokenSpan> label_to_selection_;
+
+ // Mapping between collections and labels.
+ std::map<std::string, int> collection_to_label_;
+
+ Tokenizer tokenizer_;
+
+ // Source of randomness.
+ mutable std::unique_ptr<std::mt19937> random_;
+};
+
+} // namespace libtextclassifier
+
+#endif // LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
new file mode 100644
index 0000000..fc5a0ab
--- /dev/null
+++ b/smartselect/text-classification-model.cc
@@ -0,0 +1,441 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/text-classification-model.h"
+
+#include <cmath>
+#include <iterator>
+#include <numeric>
+
+#include "common/embedding-network.h"
+#include "common/feature-extractor.h"
+#include "common/memory_image/embedding-network-params-from-image.h"
+#include "common/memory_image/memory-image-reader.h"
+#include "common/mmap.h"
+#include "common/softmax.h"
+#include "smartselect/text-classification-model.pb.h"
+#include "util/base/logging.h"
+#include "util/utf8/unicodetext.h"
+
+namespace libtextclassifier {
+
+using nlp_core::EmbeddingNetwork;
+using nlp_core::EmbeddingNetworkProto;
+using nlp_core::FeatureVector;
+using nlp_core::MemoryImageReader;
+using nlp_core::MmapFile;
+using nlp_core::MmapHandle;
+
+ModelParams* ModelParams::Build(const void* start, uint64 num_bytes) {
+ MemoryImageReader<EmbeddingNetworkProto> reader(start, num_bytes);
+
+ FeatureProcessorOptions feature_processor_options;
+ auto feature_processor_extension_id =
+ feature_processor_options_in_embedding_network_proto;
+ if (reader.trimmed_proto().HasExtension(feature_processor_extension_id)) {
+ feature_processor_options =
+ reader.trimmed_proto().GetExtension(feature_processor_extension_id);
+
+ // If no tokenization codepoint config is present, tokenize on space.
+ if (feature_processor_options.tokenization_codepoint_config_size() == 0) {
+ TokenizationCodepointRange* config =
+ feature_processor_options.add_tokenization_codepoint_config();
+ config->set_start(32);
+ config->set_end(33);
+ config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+ }
+ } else {
+ return nullptr;
+ }
+
+ SelectionModelOptions selection_options;
+ auto selection_options_extension_id =
+ selection_model_options_in_embedding_network_proto;
+ if (reader.trimmed_proto().HasExtension(selection_options_extension_id)) {
+ selection_options =
+ reader.trimmed_proto().GetExtension(selection_options_extension_id);
+ } else {
+    // TODO(zilka): Remove this once the model options have been added to the
+    // exported models.
+ for (const auto codepoint_pair : std::vector<std::pair<int, int>>(
+ {{33, 35}, {37, 39}, {42, 42}, {44, 47},
+ {58, 59}, {63, 64}, {91, 93}, {95, 95},
+ {123, 123}, {125, 125}, {161, 161}, {171, 171},
+ {183, 183}, {187, 187}, {191, 191}, {894, 894},
+ {903, 903}, {1370, 1375}, {1417, 1418}, {1470, 1470},
+ {1472, 1472}, {1475, 1475}, {1478, 1478}, {1523, 1524},
+ {1548, 1549}, {1563, 1563}, {1566, 1567}, {1642, 1645},
+ {1748, 1748}, {1792, 1805}, {2404, 2405}, {2416, 2416},
+ {3572, 3572}, {3663, 3663}, {3674, 3675}, {3844, 3858},
+ {3898, 3901}, {3973, 3973}, {4048, 4049}, {4170, 4175},
+ {4347, 4347}, {4961, 4968}, {5741, 5742}, {5787, 5788},
+ {5867, 5869}, {5941, 5942}, {6100, 6102}, {6104, 6106},
+ {6144, 6154}, {6468, 6469}, {6622, 6623}, {6686, 6687},
+ {8208, 8231}, {8240, 8259}, {8261, 8273}, {8275, 8286},
+ {8317, 8318}, {8333, 8334}, {9001, 9002}, {9140, 9142},
+ {10088, 10101}, {10181, 10182}, {10214, 10219}, {10627, 10648},
+ {10712, 10715}, {10748, 10749}, {11513, 11516}, {11518, 11519},
+ {11776, 11799}, {11804, 11805}, {12289, 12291}, {12296, 12305},
+ {12308, 12319}, {12336, 12336}, {12349, 12349}, {12448, 12448},
+ {12539, 12539}, {64830, 64831}, {65040, 65049}, {65072, 65106},
+ {65108, 65121}, {65123, 65123}, {65128, 65128}, {65130, 65131},
+ {65281, 65283}, {65285, 65290}, {65292, 65295}, {65306, 65307},
+ {65311, 65312}, {65339, 65341}, {65343, 65343}, {65371, 65371},
+ {65373, 65373}, {65375, 65381}, {65792, 65793}, {66463, 66463},
+ {68176, 68184}})) {
+ for (int i = codepoint_pair.first; i <= codepoint_pair.second; i++) {
+ selection_options.add_punctuation_to_strip(i);
+ }
+ selection_options.set_strip_punctuation(true);
+ selection_options.set_enforce_symmetry(true);
+ selection_options.set_symmetry_context_size(
+ feature_processor_options.context_size() * 2);
+ }
+ }
+
+ return new ModelParams(start, num_bytes, selection_options,
+ feature_processor_options);
+}
+
+CodepointSpan TextClassificationModel::StripPunctuation(
+ CodepointSpan selection, const std::string& context) const {
+ UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false);
+ int context_length =
+ std::distance(context_unicode.begin(), context_unicode.end());
+
+ // Check that the indices are valid.
+ if (selection.first < 0 || selection.first > context_length ||
+ selection.second < 0 || selection.second > context_length) {
+ return selection;
+ }
+
+ UnicodeText::const_iterator it;
+ for (it = context_unicode.begin(), std::advance(it, selection.first);
+ punctuation_to_strip_.find(*it) != punctuation_to_strip_.end();
+ ++it, ++selection.first) {
+ }
+
+ for (it = context_unicode.begin(), std::advance(it, selection.second - 1);
+ punctuation_to_strip_.find(*it) != punctuation_to_strip_.end();
+ --it, --selection.second) {
+ }
+
+ return selection;
+}
+
+TextClassificationModel::TextClassificationModel(int fd) {
+ initialized_ = LoadModels(fd);
+ if (!initialized_) {
+ TC_LOG(ERROR) << "Failed to load models";
+ return;
+ }
+
+ selection_options_ = selection_params_->GetSelectionModelOptions();
+ for (const int codepoint : selection_options_.punctuation_to_strip()) {
+ punctuation_to_strip_.insert(codepoint);
+ }
+}
+
+bool TextClassificationModel::LoadModels(int fd) {
+ MmapHandle mmap_handle = MmapFile(fd);
+ if (!mmap_handle.ok()) {
+ return false;
+ }
+
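+  // The mapped file contains the two models back to back, each prefixed with
+  // its length:
+  //   [uint32 selection model size][selection model data]
+  //   [uint32 sharing model size][sharing model data]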
+ // Read the length of the selection model.
+ const char* model_data = reinterpret_cast<const char*>(mmap_handle.start());
+ uint32 selection_model_length =
+ LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
+ model_data += sizeof(selection_model_length);
+
+ selection_params_.reset(
+ ModelParams::Build(model_data, selection_model_length));
+ if (!selection_params_.get()) {
+ return false;
+ }
+ selection_network_.reset(new EmbeddingNetwork(selection_params_.get()));
+ selection_feature_processor_.reset(
+ new FeatureProcessor(selection_params_->GetFeatureProcessorOptions()));
+
+ model_data += selection_model_length;
+ uint32 sharing_model_length =
+ LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
+ model_data += sizeof(sharing_model_length);
+ sharing_params_.reset(ModelParams::Build(model_data, sharing_model_length));
+ if (!sharing_params_.get()) {
+ return false;
+ }
+ sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get()));
+ sharing_feature_processor_.reset(
+ new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions()));
+
+ return true;
+}
+
+EmbeddingNetwork::Vector TextClassificationModel::InferInternal(
+ const std::string& context, CodepointSpan click_indices,
+ CodepointSpan selection_indices, const FeatureProcessor& feature_processor,
+ const EmbeddingNetwork* network,
+ std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
+ CodepointSpan* selection_codepoint_label, int* classification_label) const {
+ SelectionWithContext selection_with_context;
+ selection_with_context.context = context;
+ selection_with_context.click_start = std::get<0>(click_indices);
+ selection_with_context.click_end = std::get<1>(click_indices);
+ selection_with_context.selection_start = std::get<0>(selection_indices);
+ selection_with_context.selection_end = std::get<1>(selection_indices);
+
+ std::vector<FeatureVector> features;
+ std::vector<float> extra_features;
+ const bool features_computed = feature_processor.GetFeaturesAndLabels(
+ selection_with_context, &features, &extra_features, selection_label_spans,
+ selection_label, selection_codepoint_label, classification_label);
+
+ EmbeddingNetwork::Vector scores;
+ if (!features_computed) {
+ TC_LOG(ERROR) << "Features not computed";
+ return scores;
+ }
+ network->ComputeFinalScores(features, extra_features, &scores);
+ return scores;
+}
+
+CodepointSpan TextClassificationModel::SuggestSelection(
+ const std::string& context, CodepointSpan click_indices) const {
+ if (!initialized_) {
+ TC_LOG(ERROR) << "Not initialized";
+ return click_indices;
+ }
+
+ if (std::get<0>(click_indices) >= std::get<1>(click_indices)) {
+ TC_LOG(ERROR) << "Trying to run SuggestSelection with invalid indices:"
+ << std::get<0>(click_indices) << " "
+ << std::get<1>(click_indices);
+ return click_indices;
+ }
+
+ CodepointSpan result;
+ if (selection_options_.enforce_symmetry()) {
+ result = SuggestSelectionSymmetrical(context, click_indices);
+ } else {
+ float score;
+ std::tie(result, score) = SuggestSelectionInternal(context, click_indices);
+ }
+
+ if (selection_options_.strip_punctuation()) {
+ result = StripPunctuation(result, context);
+ }
+
+ return result;
+}
+
+std::pair<CodepointSpan, float>
+TextClassificationModel::SuggestSelectionInternal(
+ const std::string& context, CodepointSpan click_indices) const {
+ if (!initialized_) {
+ TC_LOG(ERROR) << "Not initialized";
+ return {click_indices, -1.0};
+ }
+
+ // Invalid selection indices make the feature extraction use the provided
+ // click indices.
+ const CodepointSpan selection_indices({kInvalidIndex, kInvalidIndex});
+
+ std::vector<CodepointSpan> selection_label_spans;
+ EmbeddingNetwork::Vector scores = InferInternal(
+ context, click_indices, selection_indices, *selection_feature_processor_,
+ selection_network_.get(), &selection_label_spans,
+ /*selection_label=*/nullptr,
+ /*selection_codepoint_label=*/nullptr,
+ /*classification_label=*/nullptr);
+
+ if (!scores.empty()) {
+ scores = nlp_core::ComputeSoftmax(scores);
+ const int prediction =
+ std::max_element(scores.begin(), scores.end()) - scores.begin();
+ std::pair<CodepointIndex, CodepointIndex> selection =
+ selection_label_spans[prediction];
+
+ if (selection.first == kInvalidIndex || selection.second == kInvalidIndex) {
+ TC_LOG(ERROR) << "Invalid indices predicted, returning input: "
+ << prediction << " " << selection.first << " "
+ << selection.second;
+ return {click_indices, -1.0};
+ }
+
+ return {{selection.first, selection.second}, scores[prediction]};
+ } else {
+ TC_LOG(ERROR) << "Returning default selection: scores.size() = "
+ << scores.size();
+ return {click_indices, -1.0};
+ }
+}
+
+namespace {
+
+int GetClickTokenIndex(const std::vector<Token>& tokens,
+ CodepointSpan click_indices) {
+ TokenSpan span = CodepointSpanToTokenSpan(tokens, click_indices);
+ if (span.second - span.first == 1) {
+ return span.first;
+ } else {
+ for (int i = 0; i < tokens.size(); i++) {
+ if (tokens[i].start <= click_indices.first &&
+ tokens[i].end >= click_indices.second) {
+ return i;
+ }
+ }
+ return kInvalidIndex;
+ }
+}
+
+} // namespace
+
+// Implements a greedy-search-like algorithm for making selections symmetric.
+//
+// Steps:
+// 1. Get a set of selection proposals from places around the clicked word.
+// 2. For each proposal (starting with the highest-scoring one), check whether
+//    the tokens it selects are still free and, if so, claim them. If a
+//    proposal that contains the clicked token is found, it is returned as the
+//    suggestion.
+//
+// This algorithm should ensure that if a selection is proposed, it does not
+// matter which word of it was tapped - all of them will lead to the same
+// selection.
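+//
+// Illustrative example (hypothetical scores): for the line "call 857 225 3556"
+// with a click on "225", the proposals made around the click might all score
+// highest for the span "857 225 3556"; the first feasible proposal containing
+// the clicked token is then returned, so tapping "857", "225" or "3556" yields
+// the same selection.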
+CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical(
+ const std::string& full_context, CodepointSpan click_indices) const {
+ // Extract context from the current line only.
+ std::string context;
+ int context_shift;
+ std::tie(context, context_shift) =
+ ExtractLineWithSpan(full_context, click_indices);
+ click_indices.first -= context_shift;
+ click_indices.second -= context_shift;
+
+ std::vector<Token> tokens = selection_feature_processor_->Tokenize(context);
+
+ const int click_index = GetClickTokenIndex(tokens, click_indices);
+ if (click_index == kInvalidIndex) {
+ return click_indices;
+ }
+
+ const int symmetry_context_size = selection_options_.symmetry_context_size();
+
+ // Scan in the symmetry context for selection span proposals.
+ std::vector<std::pair<CodepointSpan, float>> proposals;
+ for (int i = -symmetry_context_size; i < symmetry_context_size + 1; i++) {
+ const int token_index = click_index + i;
+ if (token_index >= 0 && token_index < tokens.size()) {
+ float score;
+ CodepointSpan span;
+ std::tie(span, score) = SuggestSelectionInternal(
+ context, {tokens[token_index].start, tokens[token_index].end});
+ proposals.push_back({span, score});
+ }
+ }
+
+ // Sort selection span proposals by their respective probabilities.
+ std::sort(
+ proposals.begin(), proposals.end(),
+ [](std::pair<CodepointSpan, float> a, std::pair<CodepointSpan, float> b) {
+ return a.second > b.second;
+ });
+
+  // Go through the proposals from the highest-scoring one and claim their
+  // tokens. Tokens claimed by a higher-scoring proposal cannot be used by
+  // lower-scoring ones. The first feasible proposal that contains the clicked
+  // token is returned.
+ std::vector<int> used_tokens(tokens.size(), 0);
+ for (auto span_result : proposals) {
+ TokenSpan span = CodepointSpanToTokenSpan(tokens, span_result.first);
+ if (span.first != kInvalidIndex && span.second != kInvalidIndex) {
+ bool feasible = true;
+ for (int i = span.first; i < span.second; i++) {
+ if (used_tokens[i] != 0) {
+ feasible = false;
+ break;
+ }
+ }
+
+ if (feasible) {
+ if (span.first <= click_index && span.second > click_index) {
+ return {span_result.first.first + context_shift,
+ span_result.first.second + context_shift};
+ }
+ for (int i = span.first; i < span.second; i++) {
+ used_tokens[i] = 1;
+ }
+ }
+ }
+ }
+
+ return {click_indices.first + context_shift,
+ click_indices.second + context_shift};
+}
+
+CodepointSpan TextClassificationModel::SuggestSelection(
+ const SelectionWithContext& selection_with_context) const {
+ CodepointSpan click_indices = {selection_with_context.click_start,
+ selection_with_context.click_end};
+
+  // If click_indices are unspecified, pick a random token inside the selection
+  // as the click.
+ if (click_indices == CodepointSpan({kInvalidIndex, kInvalidIndex})) {
+ click_indices = selection_feature_processor_->ClickRandomTokenInSelection(
+ selection_with_context);
+ }
+
+ return SuggestSelection(selection_with_context.context, click_indices);
+}
+
+std::string TextClassificationModel::ClassifyText(
+ const std::string& context, CodepointSpan selection_indices) const {
+ if (!initialized_) {
+ TC_LOG(ERROR) << "Not initialized";
+ return sharing_feature_processor_->GetDefaultCollection();
+ }
+
+ // Invalid click indices make the feature extraction select the middle word in
+ // the selection span.
+ const CodepointSpan click_indices({kInvalidIndex, kInvalidIndex});
+
+ EmbeddingNetwork::Vector scores =
+ InferInternal(context, click_indices, selection_indices,
+ *sharing_feature_processor_, sharing_network_.get(),
+ /*selection_label_spans=*/nullptr,
+ /*selection_label=*/nullptr,
+ /*selection_codepoint_label=*/nullptr,
+ /*classification_label=*/nullptr);
+ if (scores.empty()) {
+ TC_LOG(ERROR) << "Using default class";
+ return sharing_feature_processor_->GetDefaultCollection();
+ }
+  if (scores.size() == sharing_feature_processor_->NumCollections()) {
+ const int prediction =
+ std::max_element(scores.begin(), scores.end()) - scores.begin();
+
+ // Convert to a class name.
+ const std::string class_name =
+ sharing_feature_processor_->LabelToCollection(prediction);
+ return class_name;
+ } else {
+ TC_LOG(ERROR) << "Using default class: scores.size() = " << scores.size();
+ return sharing_feature_processor_->GetDefaultCollection();
+ }
+}
+
+} // namespace libtextclassifier
diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h
new file mode 100644
index 0000000..4e0a3da
--- /dev/null
+++ b/smartselect/text-classification-model.h
@@ -0,0 +1,171 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Inference code for the feed-forward text classification models.
+
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
+
+#include <memory>
+#include <set>
+#include <string>
+
+#include "base.h"
+#include "common/embedding-network.h"
+#include "common/feature-extractor.h"
+#include "common/memory_image/embedding-network-params-from-image.h"
+#include "smartselect/feature-processor.h"
+#include "smartselect/text-classification-model.pb.h"
+#include "smartselect/types.h"
+
+namespace libtextclassifier {
+
+// Loads and holds the parameters of the inference network.
+//
+// This class overrides a couple of methods of EmbeddingNetworkParamsFromImage
+// because we only have one embedding matrix for all positions of context,
+// whereas the original class would have a separate one for each.
+class ModelParams : public nlp_core::EmbeddingNetworkParamsFromImage {
+ public:
+ static ModelParams* Build(const void* start, uint64 num_bytes);
+
+ const FeatureProcessorOptions& GetFeatureProcessorOptions() const {
+ return feature_processor_options_;
+ }
+
+ const SelectionModelOptions& GetSelectionModelOptions() const {
+ return selection_options_;
+ }
+
+ protected:
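+  // There is a single embedding matrix (at index 0) shared by all context
+  // positions, which is why the per-index accessors below delegate to index 0.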
+ int embeddings_size() const override { return context_size_ * 2 + 1; }
+
+ int embedding_num_features_size() const override {
+ return context_size_ * 2 + 1;
+ }
+
+ int embedding_num_features(int i) const override { return 1; }
+
+ int embeddings_num_rows(int i) const override {
+ return EmbeddingNetworkParamsFromImage::embeddings_num_rows(0);
+  }
+
+ int embeddings_num_cols(int i) const override {
+ return EmbeddingNetworkParamsFromImage::embeddings_num_cols(0);
+  }
+
+ const void* embeddings_weights(int i) const override {
+ return EmbeddingNetworkParamsFromImage::embeddings_weights(0);
+  }
+
+ nlp_core::QuantizationType embeddings_quant_type(int i) const override {
+ return EmbeddingNetworkParamsFromImage::embeddings_quant_type(0);
+ }
+
+ const nlp_core::float16* embeddings_quant_scales(int i) const override {
+ return EmbeddingNetworkParamsFromImage::embeddings_quant_scales(0);
+ }
+
+ private:
+ ModelParams(const void* start, uint64 num_bytes,
+ const SelectionModelOptions& selection_options,
+ const FeatureProcessorOptions& feature_processor_options)
+ : EmbeddingNetworkParamsFromImage(start, num_bytes),
+ selection_options_(selection_options),
+ feature_processor_options_(feature_processor_options),
+ context_size_(feature_processor_options_.context_size()) {}
+
+ SelectionModelOptions selection_options_;
+ FeatureProcessorOptions feature_processor_options_;
+ int context_size_;
+};
+
+// SmartSelection/Sharing feed-forward model.
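+//
+// Minimal usage sketch (illustrative; the file path, example text and spans
+// are hypothetical):
+//   int fd = open("/path/to/model.bin", O_RDONLY);
+//   TextClassificationModel model(fd);
+//   std::string context = "Call 857 225 3556 today";
+//   CodepointSpan selection = model.SuggestSelection(context, {5, 8});
+//   std::string collection = model.ClassifyText(context, selection);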
+class TextClassificationModel {
+ public:
+  // Loads the TextClassificationModel from the file given by the int file
+  // descriptor.
+ explicit TextClassificationModel(int fd);
+
+  // Runs inference for the given context and current selection, i.e. the index
+  // of the first selected character and one past the last (UTF-8 codepoint
+  // offsets). Returns the indices (UTF-8 codepoint offsets) of the first
+  // character of the suggested selection and one past its last character.
+  // Returns the original click_indices if an error occurs.
+ // NOTE: The selection indices are passed in and returned in terms of
+ // UTF8 codepoints (not bytes).
+ // Requires that the model is a smart selection model.
+ CodepointSpan SuggestSelection(const std::string& context,
+ CodepointSpan click_indices) const;
+
+ // Same as above but accepts a selection_with_context. Only used for
+ // evaluation.
+ CodepointSpan SuggestSelection(
+ const SelectionWithContext& selection_with_context) const;
+
+ // Classifies the selected text given the context string.
+ // Requires that the model is a smart sharing model.
+ // Returns a default collection name if an error occurs.
+  std::string ClassifyText(const std::string& context,
+                           CodepointSpan selection_indices) const;
+
+ protected:
+ // Removes punctuation from the beginning and end of the selection and returns
+ // the new selection span.
+ CodepointSpan StripPunctuation(CodepointSpan selection,
+ const std::string& context) const;
+
+ // During evaluation we need access to the feature processor.
+ FeatureProcessor* SelectionFeatureProcessor() const {
+ return selection_feature_processor_.get();
+ }
+
+ private:
+ bool LoadModels(int fd);
+
+ nlp_core::EmbeddingNetwork::Vector InferInternal(
+ const std::string& context, CodepointSpan click_indices,
+ CodepointSpan selection_indices,
+ const FeatureProcessor& feature_processor,
+ const nlp_core::EmbeddingNetwork* network,
+ std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
+ CodepointSpan* selection_codepoint_label,
+ int* classification_label) const;
+
+ // Returns a selection suggestion with a score.
+ std::pair<CodepointSpan, float> SuggestSelectionInternal(
+ const std::string& context, CodepointSpan click_indices) const;
+
+  // Returns a selection suggestion and makes sure it is symmetric. Internally
+  // runs SuggestSelectionInternal several times.
+ CodepointSpan SuggestSelectionSymmetrical(const std::string& context,
+ CodepointSpan click_indices) const;
+
+ bool initialized_;
+ std::unique_ptr<ModelParams> selection_params_;
+ std::unique_ptr<FeatureProcessor> selection_feature_processor_;
+ std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
+ std::unique_ptr<FeatureProcessor> sharing_feature_processor_;
+ std::unique_ptr<ModelParams> sharing_params_;
+ std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_;
+
+ SelectionModelOptions selection_options_;
+ std::set<int> punctuation_to_strip_;
+};
+
+} // namespace libtextclassifier
+
+#endif  // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
new file mode 100644
index 0000000..e098e81
--- /dev/null
+++ b/smartselect/text-classification-model.proto
@@ -0,0 +1,108 @@
+// Copyright (C) 2017 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Text classification model configuration.
+
+syntax = "proto2";
+option optimize_for = LITE_RUNTIME;
+
+import "external/libtextclassifier/common/embedding-network.proto";
+import "external/libtextclassifier/smartselect/tokenizer.proto";
+
+package libtextclassifier;
+
+message SelectionModelOptions {
+ // A list of Unicode codepoints to strip from predicted selections.
+ repeated int32 punctuation_to_strip = 1;
+
+ // Whether to strip punctuation after the selection is made.
+ optional bool strip_punctuation = 2;
+
+ // Enforce symmetrical selections.
+ optional bool enforce_symmetry = 3;
+
+ // Number of inferences made around the click position (to one side), for
+ // enforcing symmetry.
+ optional int32 symmetry_context_size = 4;
+}
+
+message FeatureProcessorOptions {
+ // Number of buckets used for hashing charactergrams.
+ optional int32 num_buckets = 1 [default = -1];
+
+ // Context size defines the number of words to the left and to the right of
+ // the selected word to be used as context. For example, if context size is
+ // N, then we take N words to the left and N words to the right of the
+ // selected word as its context.
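+  // For example, with context_size = 2 the model is fed a window of
+  // 2 * 2 + 1 = 5 tokens centered on the clicked token.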
+ optional int32 context_size = 2 [default = -1];
+
+ // Maximum number of words of the context to select in total.
+ optional int32 max_selection_span = 3 [default = -1];
+
+  // Orders of charactergrams to extract. E.g., 2 means character bigrams,
+  // 3 means character trigrams, etc.
+ repeated int32 chargram_orders = 4;
+
+ // Whether to extract the token case feature.
+ optional bool extract_case_feature = 5 [default = false];
+
+ // Whether to extract the selection mask feature.
+ optional bool extract_selection_mask_feature = 6 [default = false];
+
+ // If true, tokenize on space, otherwise tokenize using ICU.
+ optional bool tokenize_on_space = 7 [default = true];
+
+ // If true, the selection classifier output will contain only the selections
+  // that are feasible (e.g., those that are shorter than max_selection_span);
+  // if false, the output will be a complete cross-product of possible
+  // selections to the left and possible selections to the right, including the
+ // infeasible ones.
+ // NOTE: Exists mainly for compatibility with older models that were trained
+ // with the non-reduced output space.
+ optional bool selection_reduced_output_space = 8 [default = true];
+
+ // Collection names.
+ repeated string collections = 9;
+
+  // Index of the collection in collections to be used if a collection name
+  // cannot be mapped to an id.
+ optional int32 default_collection = 10 [default = -1];
+
+ // Probability with which to drop context of examples.
+ optional float context_dropout_probability = 11 [default = 0.0];
+
+  // If true, drop a variable amount of context; if false, drop all context,
+  // with probability given by context_dropout_probability.
+ optional bool use_variable_context_dropout = 12 [default = false];
+
+ // If true, will split the input by lines, and only use the line that contains
+ // the clicked token.
+ optional bool only_use_line_with_click = 13 [default = false];
+
+ // If true, will split tokens that contain the selection boundary, at the
+ // position of the boundary.
+ // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
+ optional bool split_tokens_on_selection_boundaries = 14 [default = false];
+
+ // Codepoint ranges that determine how different codepoints are tokenized.
+ // The ranges must not overlap.
+ repeated TokenizationCodepointRange tokenization_codepoint_config = 15;
+}
+
+extend nlp_core.EmbeddingNetworkProto {
+ optional FeatureProcessorOptions
+ feature_processor_options_in_embedding_network_proto = 146230910;
+ optional SelectionModelOptions
+ selection_model_options_in_embedding_network_proto = 148190899;
+}
diff --git a/smartselect/token-feature-extractor.cc b/smartselect/token-feature-extractor.cc
new file mode 100644
index 0000000..a6bbf4f
--- /dev/null
+++ b/smartselect/token-feature-extractor.cc
@@ -0,0 +1,119 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/token-feature-extractor.h"
+
+#include <cctype>
+
+#include "util/hash/farmhash.h"
+
+namespace libtextclassifier {
+
+constexpr int kMaxWordLength = 20; // All words will be trimmed to this length.
+
+int TokenFeatureExtractor::HashToken(const std::string& token) const {
+ return tcfarmhash::Fingerprint64(token) % options_.num_buckets;
+}
+
+std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
+ const Token& token) const {
+ std::vector<int> result;
+ if (token.is_padding) {
+ result.push_back(HashToken("<PAD>"));
+ } else {
+ const std::string& word = token.value;
+ std::string feature_word;
+
+ // Trim words that are over kMaxWordLength characters.
+ if (word.size() > kMaxWordLength) {
+ feature_word =
+ "^" + word.substr(0, kMaxWordLength / 2) + "\1" +
+ word.substr(word.size() - kMaxWordLength / 2, kMaxWordLength / 2) +
+ "$";
+ } else {
+ // Add a prefix and suffix to the word.
+ feature_word = "^" + word + "$";
+ }
+
+    // Upper-bound the number of charactergrams extracted to avoid resizing.
+ result.reserve(options_.chargram_orders.size() * feature_word.size());
+
+ // Generate the character-grams.
+ for (int chargram_order : options_.chargram_orders) {
+ if (chargram_order == 1) {
+ for (int i = 1; i < feature_word.size() - 1; ++i) {
+ result.push_back(HashToken(feature_word.substr(i, 1)));
+ }
+ } else {
+ for (int i = 0;
+ i < static_cast<int>(feature_word.size()) - chargram_order + 1;
+ ++i) {
+ result.push_back(HashToken(feature_word.substr(i, chargram_order)));
+ }
+ }
+ }
+ }
+ return result;
+}
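+
+// Example (illustrative): with chargram_orders = {1, 2},
+// ExtractCharactergramFeatures pads the token "cat" to "^cat$" and produces
+// the unigrams {"c", "a", "t"} and the bigrams {"^c", "ca", "at", "t$"}, each
+// hashed into [0, num_buckets).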
+
+bool TokenFeatureExtractor::Extract(const Token& token,
+ std::vector<int>* sparse_features,
+ std::vector<float>* dense_features) const {
+ if (sparse_features == nullptr || dense_features == nullptr) {
+ return false;
+ }
+
+ *sparse_features = ExtractCharactergramFeatures(token);
+
+ if (options_.extract_case_feature) {
+ // TODO(zilka): Make isupper Unicode-aware.
+ if (!token.value.empty() && isupper(*token.value.begin())) {
+ dense_features->push_back(1.0);
+ } else {
+ dense_features->push_back(-1.0);
+ }
+ }
+
+ if (options_.extract_selection_mask_feature) {
+ if (token.is_in_span) {
+ dense_features->push_back(1.0);
+ } else {
+ // TODO(zilka): Switch to -1. Doing this now would break current models.
+ dense_features->push_back(0.0);
+ }
+ }
+
+ return true;
+}
+
+bool TokenFeatureExtractor::Extract(
+ const std::vector<Token>& tokens,
+ std::vector<std::vector<int>>* sparse_features,
+ std::vector<std::vector<float>>* dense_features) const {
+ if (sparse_features == nullptr || dense_features == nullptr) {
+ return false;
+ }
+
+ sparse_features->resize(tokens.size());
+ dense_features->resize(tokens.size());
+ for (size_t i = 0; i < tokens.size(); i++) {
+ if (!Extract(tokens[i], &((*sparse_features)[i]),
+ &((*dense_features)[i]))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace libtextclassifier
diff --git a/smartselect/token-feature-extractor.h b/smartselect/token-feature-extractor.h
new file mode 100644
index 0000000..9ba695e
--- /dev/null
+++ b/smartselect/token-feature-extractor.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_
+
+#include <vector>
+
+#include "base.h"
+#include "smartselect/types.h"
+
+namespace libtextclassifier {
+
+struct TokenFeatureExtractorOptions {
+ // Number of buckets used for hashing charactergrams.
+ int num_buckets;
+
+  // Orders of charactergrams to extract. E.g., 2 means character bigrams,
+  // 3 means character trigrams, etc.
+ std::vector<int> chargram_orders;
+
+ // Whether to extract the token case feature.
+ bool extract_case_feature;
+
+ // Whether to extract the selection mask feature.
+ bool extract_selection_mask_feature;
+};
+
+class TokenFeatureExtractor {
+ public:
+ explicit TokenFeatureExtractor(const TokenFeatureExtractorOptions& options)
+ : options_(options) {}
+
+ // Extracts features from a token.
+  //  - sparse_features are indices into a sparse feature vector of size
+  //    options.num_buckets; the indexed entries are set to 1.0 (all others
+  //    are implicitly 0.0).
+  //  - dense_features are values of a dense feature vector of size 0-2
+  //    (depending on the options) for the token.
+ bool Extract(const Token& token, std::vector<int>* sparse_features,
+ std::vector<float>* dense_features) const;
+
+ // Convenience method that sequentially applies Extract to each Token.
+ bool Extract(const std::vector<Token>& tokens,
+ std::vector<std::vector<int>>* sparse_features,
+ std::vector<std::vector<float>>* dense_features) const;
+
+ protected:
+  // Hashes the given token into one of options.num_buckets buckets.
+ int HashToken(const std::string& token) const;
+
+ // Extracts the charactergram features from the token.
+ std::vector<int> ExtractCharactergramFeatures(const Token& token) const;
+
+ private:
+ TokenFeatureExtractorOptions options_;
+};
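+
+// Example usage (illustrative sketch; the option values below are
+// hypothetical):
+//
+//   TokenFeatureExtractorOptions options;
+//   options.num_buckets = 1000;
+//   options.chargram_orders = {1, 2};
+//   options.extract_case_feature = true;
+//   options.extract_selection_mask_feature = false;
+//
+//   TokenFeatureExtractor extractor(options);
+//   std::vector<int> sparse_features;
+//   std::vector<float> dense_features;
+//   extractor.Extract(Token("Hello", 0, 5), &sparse_features,
+//                     &dense_features);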
+
+} // namespace libtextclassifier
+
+#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TOKEN_FEATURE_EXTRACTOR_H_
diff --git a/smartselect/tokenizer.cc b/smartselect/tokenizer.cc
new file mode 100644
index 0000000..72bb668
--- /dev/null
+++ b/smartselect/tokenizer.cc
@@ -0,0 +1,102 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/tokenizer.h"
+
+#include <algorithm>
+
+#include "util/strings/utf8.h"
+#include "util/utf8/unicodetext.h"
+
+namespace libtextclassifier {
+
+void Tokenizer::PrepareTokenizationCodepointRanges(
+    const std::vector<TokenizationCodepointRange>& codepoint_range_configs) {
+ codepoint_ranges_.clear();
+ codepoint_ranges_.reserve(codepoint_range_configs.size());
+ for (const TokenizationCodepointRange& range : codepoint_range_configs) {
+ codepoint_ranges_.push_back(
+ CodepointRange(range.start(), range.end(), range.role()));
+ }
+
+ std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
+ [](const CodepointRange& a, const CodepointRange& b) {
+ return a.start < b.start;
+ });
+}
+
+TokenizationCodepointRange::Role Tokenizer::FindTokenizationRole(
+ int codepoint) const {
+ auto it = std::lower_bound(codepoint_ranges_.begin(), codepoint_ranges_.end(),
+ codepoint,
+ [](const CodepointRange& range, int codepoint) {
+ // This function compares range with the
+ // codepoint for the purpose of finding the first
+ // greater or equal range. Because of the use of
+ // std::lower_bound it needs to return true when
+ // range < codepoint; the first time it will
+ // return false the lower bound is found and
+ // returned.
+ //
+ // It might seem weird that the condition is
+ // range < codepoint here but when codepoint ==
+ // range.end it means it's actually just outside
+ // of the range, thus the range is less than the
+ // codepoint.
+ return range.end <= codepoint;
+ });
+ if (it != codepoint_ranges_.end() && it->start <= codepoint &&
+ it->end > codepoint) {
+ return it->role;
+ } else {
+ return TokenizationCodepointRange::DEFAULT_ROLE;
+ }
+}
+
+std::vector<Token> Tokenizer::Tokenize(const std::string& utf8_text) const {
+ UnicodeText context_unicode = UTF8ToUnicodeText(utf8_text, /*do_copy=*/false);
+
+ std::vector<Token> result;
+ Token new_token("", 0, 0);
+ int codepoint_index = 0;
+ for (auto it = context_unicode.begin(); it != context_unicode.end();
+ ++it, ++codepoint_index) {
+ TokenizationCodepointRange::Role role = FindTokenizationRole(*it);
+ if (role & TokenizationCodepointRange::SPLIT_BEFORE) {
+ if (!new_token.value.empty()) {
+ result.push_back(new_token);
+ }
+ new_token = Token("", codepoint_index, codepoint_index);
+ }
+ if (!(role & TokenizationCodepointRange::DISCARD_CODEPOINT)) {
+ new_token.value += std::string(
+ it.utf8_data(),
+ it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data()));
+ ++new_token.end;
+ }
+ if (role & TokenizationCodepointRange::SPLIT_AFTER) {
+ if (!new_token.value.empty()) {
+ result.push_back(new_token);
+ }
+ new_token = Token("", codepoint_index + 1, codepoint_index + 1);
+ }
+ }
+ if (!new_token.value.empty()) {
+ result.push_back(new_token);
+ }
+
+ return result;
+}
+
+} // namespace libtextclassifier
diff --git a/smartselect/tokenizer.h b/smartselect/tokenizer.h
new file mode 100644
index 0000000..9ed152f
--- /dev/null
+++ b/smartselect/tokenizer.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Tokenizer.
+
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TOKENIZER_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_TOKENIZER_H_
+
+#include <string>
+#include <vector>
+
+#include "smartselect/tokenizer.pb.h"
+#include "smartselect/types.h"
+#include "util/base/integral_types.h"
+
+namespace libtextclassifier {
+
+// Represents a codepoint range [start, end) with its role for tokenization.
+struct CodepointRange {
+ int32 start;
+ int32 end;
+ TokenizationCodepointRange::Role role;
+
+ CodepointRange(int32 arg_start, int32 arg_end,
+ TokenizationCodepointRange::Role arg_role)
+ : start(arg_start), end(arg_end), role(arg_role) {}
+};
+
+// Tokenizer splits the input string into a sequence of tokens, according to the
+// configuration.
+class Tokenizer {
+ public:
+ explicit Tokenizer(
+ const std::vector<TokenizationCodepointRange>& codepoint_range_configs) {
+ PrepareTokenizationCodepointRanges(codepoint_range_configs);
+ }
+
+ // Tokenizes the input string using the selected tokenization method.
+ std::vector<Token> Tokenize(const std::string& utf8_text) const;
+
+ protected:
+ // Prepares tokenization codepoint ranges for use in tokenization.
+ void PrepareTokenizationCodepointRanges(
+      const std::vector<TokenizationCodepointRange>& codepoint_range_configs);
+
+  // Finds the tokenization role for the given codepoint.
+  // If the codepoint is not covered by any configured range, returns
+  // DEFAULT_ROLE. Internally uses binary search, so it runs in
+  // O(log2(# of codepoint_ranges)).
+ TokenizationCodepointRange::Role FindTokenizationRole(int codepoint) const;
+
+ private:
+ // Codepoint ranges that determine how different codepoints are tokenized.
+ // The ranges must not overlap.
+ std::vector<CodepointRange> codepoint_ranges_;
+};
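+
+// Example usage (illustrative sketch): treat the ASCII space character as a
+// whitespace separator and tokenize a string.
+//
+//   TokenizationCodepointRange range;
+//   range.set_start(32);  // ' '
+//   range.set_end(33);    // End is exclusive.
+//   range.set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
+//
+//   Tokenizer tokenizer({range});
+//   std::vector<Token> tokens = tokenizer.Tokenize("hello there");
+//   // tokens == {Token("hello", 0, 5), Token("there", 6, 11)}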
+
+} // namespace libtextclassifier
+
+#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TOKENIZER_H_
diff --git a/smartselect/tokenizer.proto b/smartselect/tokenizer.proto
new file mode 100644
index 0000000..8e78970
--- /dev/null
+++ b/smartselect/tokenizer.proto
@@ -0,0 +1,48 @@
+// Copyright (C) 2017 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+option optimize_for = LITE_RUNTIME;
+
+package libtextclassifier;
+
+// Represents a codepoint range [start, end) with its role for tokenization.
+message TokenizationCodepointRange {
+ optional int32 start = 1;
+ optional int32 end = 2;
+
+ // Role of the codepoints in the range.
+ enum Role {
+ // Concatenates the codepoint to the current run of codepoints.
+ DEFAULT_ROLE = 0;
+
+ // Splits a run of codepoints before the current codepoint.
+ SPLIT_BEFORE = 0x1;
+
+ // Splits a run of codepoints after the current codepoint.
+ SPLIT_AFTER = 0x2;
+
+ // Discards the codepoint.
+ DISCARD_CODEPOINT = 0x4;
+
+    // Common values:
+    // Splits on these codepoints and discards them. Useful, e.g., for the
+    // space character.
+    WHITESPACE_SEPARATOR = 0x7;
+    // Each codepoint becomes a separate token. Useful, e.g., for Chinese
+    // characters.
+ TOKEN_SEPARATOR = 0x3;
+ }
+ optional Role role = 3;
+}
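+
+// Example (illustrative): roles are bit flags, so the predefined values above
+// combine them, e.g. WHITESPACE_SEPARATOR == SPLIT_BEFORE | SPLIT_AFTER |
+// DISCARD_CODEPOINT (0x1 | 0x2 | 0x4 == 0x7). In the C++ tokenizer a role is
+// tested with a bitwise AND:
+//   if (role & TokenizationCodepointRange::SPLIT_BEFORE) { ... }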
diff --git a/smartselect/types.h b/smartselect/types.h
new file mode 100644
index 0000000..12e93d9
--- /dev/null
+++ b/smartselect/types.h
@@ -0,0 +1,126 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TYPES_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_TYPES_H_
+
+#include <ostream>
+#include <string>
+#include <utility>
+
+namespace libtextclassifier {
+
+constexpr int kInvalidIndex = -1;
+
+// Index for a 0-based array of tokens.
+using TokenIndex = int;
+
+// Index for a 0-based array of codepoints.
+using CodepointIndex = int;
+
+// Marks a span in a sequence of codepoints. The first element is the index of
+// the first codepoint of the span, and the second element is the index of the
+// codepoint one past the end of the span.
+using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
+
+// Marks a span in a sequence of tokens. The first element is the index of the
+// first token in the span, and the second element is the index of the token one
+// past the end of the span.
+using TokenSpan = std::pair<TokenIndex, TokenIndex>;
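+
+// Example (illustrative): in the text "hello there", the word "there" covers
+// the CodepointSpan {6, 11}; with whitespace tokenization it is the second
+// token, i.e. the TokenSpan {1, 2}.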
+
+// Token holds a token, its position in the original string and whether it was
+// part of the input span.
+struct Token {
+ std::string value;
+ CodepointIndex start;
+ CodepointIndex end;
+
+ // Whether the token was in the input span.
+ bool is_in_span;
+
+ // Whether the token is a padding token.
+ bool is_padding;
+
+ // Default constructor constructs the padding-token.
+ Token()
+ : value(""),
+ start(kInvalidIndex),
+ end(kInvalidIndex),
+ is_in_span(false),
+ is_padding(true) {}
+
+ Token(const std::string& arg_value, CodepointIndex arg_start,
+ CodepointIndex arg_end)
+ : Token(arg_value, arg_start, arg_end, false) {}
+
+ Token(const std::string& arg_value, CodepointIndex arg_start,
+ CodepointIndex arg_end, bool is_in_span)
+ : value(arg_value),
+ start(arg_start),
+ end(arg_end),
+ is_in_span(is_in_span),
+ is_padding(false) {}
+
+ bool operator==(const Token& other) const {
+ return value == other.value && start == other.start && end == other.end &&
+ is_in_span == other.is_in_span && is_padding == other.is_padding;
+ }
+};
+
+// Pretty-printing function for Token.
+inline std::ostream& operator<<(std::ostream& os, const Token& token) {
+ return os << "Token(\"" << token.value << "\", " << token.start << ", "
+ << token.end << ", is_in_span=" << token.is_in_span
+ << ", is_padding=" << token.is_padding << ")";
+}
+
+// Represents a selection together with its context and the click that
+// produced it.
+struct SelectionWithContext {
+ SelectionWithContext()
+ : context(""),
+ selection_start(-1),
+ selection_end(-1),
+ click_start(-1),
+ click_end(-1) {}
+
+ // UTF8 encoded context.
+ std::string context;
+
+ // Codepoint index to the context where selection starts.
+ CodepointIndex selection_start;
+
+ // Codepoint index to the context one past where selection ends.
+ CodepointIndex selection_end;
+
+ // Codepoint index to the context where click starts.
+ CodepointIndex click_start;
+
+ // Codepoint index to the context one past where click ends.
+ CodepointIndex click_end;
+
+ // Type of the selection.
+ std::string collection;
+
+ CodepointSpan GetClickSpan() const { return {click_start, click_end}; }
+
+ CodepointSpan GetSelectionSpan() const {
+ return {selection_start, selection_end};
+ }
+};
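+
+// Example (illustrative; the collection name is hypothetical): for the context
+// "Call 650-555-0100 now" with the phone number selected and the click on
+// "650", one would have selection_start = 5, selection_end = 17,
+// click_start = 5, click_end = 8, and collection = "phone".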
+
+} // namespace libtextclassifier
+
+#endif  // LIBTEXTCLASSIFIER_SMARTSELECT_TYPES_H_