Sync of lib2 to AOSP.
Model comes from experiment: 2524_BoundsEnglishv5_R1
Bug: 68239358
Test: Builds & tested on device.
Change-Id: I65cb7f0b067b68e3e1c22ee87232555887446089
diff --git a/feature-processor.h b/feature-processor.h
new file mode 100644
index 0000000..834c260
--- /dev/null
+++ b/feature-processor.h
@@ -0,0 +1,299 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Feature processing for FFModel (feed-forward SmartSelection model).
+
+#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_FEATURE_PROCESSOR_H_
+#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_FEATURE_PROCESSOR_H_
+
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "cached-features.h"
+#include "model_generated.h"
+#include "token-feature-extractor.h"
+#include "tokenizer.h"
+#include "types.h"
+#include "util/base/integral_types.h"
+#include "util/base/logging.h"
+#include "util/utf8/unicodetext.h"
+#include "util/utf8/unilib.h"
+
+namespace libtextclassifier2 {
+
+constexpr int kInvalidLabel{-1};  // Label id that denotes an invalid label.
+
+namespace internal {
+
+TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
+    const FeatureProcessorOptions* options);
+
+// Splits tokens that contain the selection boundary inside them.
+// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
+void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
+                                      std::vector<Token>* tokens);
+
+// Returns the index of the token that corresponds to the codepoint span.
+int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
+
+// Returns the index of the token that corresponds to the middle of the
+// codepoint span.
+int CenterTokenFromMiddleOfSelection(
+    CodepointSpan span, const std::vector<Token>& selectable_tokens);
+
+// Strips the tokens from the tokens vector that are not used for feature
+// extraction because they are out of scope, or pads them so that there are
+// enough tokens in the required context_size for all inferences with a click
+// in relative_click_span.
+void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
+                      std::vector<Token>* tokens, int* click_pos);
+
+// Returns unilib if it is not nullptr. Otherwise creates a new UniLib,
+// transfers its ownership to owned_unilib, and returns a pointer to it.
+UniLib* MaybeCreateUnilib(UniLib* unilib,
+                          std::unique_ptr<UniLib>* owned_unilib);
+
+}  // namespace internal
+
+// Converts codepoint_span into a token span over selectable_tokens. By
+// default a token is included only if it lies fully inside the codepoint
+// range; with snap_boundaries_to_containing_tokens set to true, a token that
+// merely overlaps the range is included as well.
+TokenSpan CodepointSpanToTokenSpan(
+    const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
+    bool snap_boundaries_to_containing_tokens = false);
+
+// Converts token_span, indexing into selectable_tokens, to a codepoint span.
+CodepointSpan TokenSpanToCodepointSpan(
+    const std::vector<Token>& selectable_tokens, TokenSpan token_span);
+
+// Takes care of preparing features for the span prediction model.
+class FeatureProcessor {
+ public:
+  // If unilib is nullptr, creates and owns a UniLib; otherwise borrows the
+  // given one. Both options and a non-null unilib must outlive this object.
+  explicit FeatureProcessor(const FeatureProcessorOptions* options,
+                            UniLib* unilib = nullptr)
+      : owned_unilib_(nullptr),
+        unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)),
+        feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
+                           *unilib_),
+        options_(options),
+        tokenizer_(
+            options->tokenization_codepoint_config() != nullptr
+                ? Tokenizer({options->tokenization_codepoint_config()->begin(),
+                             options->tokenization_codepoint_config()->end()},
+                            options->tokenize_on_script_change())
+                : Tokenizer({}, /*split_on_script_change=*/false)) {
+    MakeLabelMaps();
+    if (options->supported_codepoint_ranges() != nullptr) {
+      PrepareCodepointRanges({options->supported_codepoint_ranges()->begin(),
+                              options->supported_codepoint_ranges()->end()},
+                             &supported_codepoint_ranges_);
+    }
+    if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
+      PrepareCodepointRanges(
+          {options->internal_tokenizer_codepoint_ranges()->begin(),
+           options->internal_tokenizer_codepoint_ranges()->end()},
+          &internal_tokenizer_codepoint_ranges_);
+    }
+    PrepareIgnoredSpanBoundaryCodepoints();
+  }
+
+  // Tokenizes the input string using the selected tokenization method.
+  std::vector<Token> Tokenize(const std::string& utf8_text) const;
+
+  // Converts a label into a token span.
+  bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
+
+  // Gets the total number of selection labels.
+  int GetSelectionLabelCount() const { return label_to_selection_.size(); }
+
+  // Gets the string value for the given collection label.
+  std::string LabelToCollection(int label) const;
+
+  // Gets the total number of collections of the model.
+  int NumCollections() const { return collection_to_label_.size(); }
+
+  // Gets the name of the default collection.
+  std::string GetDefaultCollection() const;
+
+  const FeatureProcessorOptions* GetOptions() const { return options_; }
+
+  // Tokenizes the context and input span, and finds the click position.
+  void TokenizeAndFindClick(const std::string& context,
+                            CodepointSpan input_span,
+                            std::vector<Token>* tokens, int* click_pos) const;
+
+  // Extracts features as a CachedFeatures object that can be used for repeated
+  // inference over token spans in the given context.
+  bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
+                       EmbeddingExecutor* embedding_executor,
+                       int feature_vector_size,
+                       std::unique_ptr<CachedFeatures>* cached_features) const;
+
+  // Fills selection_label_spans with CodepointSpans that correspond to the
+  // selection labels. The CodepointSpans are based on the codepoint ranges of
+  // the given tokens.
+  bool SelectionLabelSpans(
+      VectorSpan<Token> tokens,
+      std::vector<CodepointSpan>* selection_label_spans) const;
+
+  int DenseFeaturesCount() const {
+    return feature_extractor_.DenseFeaturesCount();
+  }
+
+  int EmbeddingSize() const { return options_->embedding_size(); }
+
+  // Splits context into several segments according to the configuration.
+  std::vector<UnicodeTextRange> SplitContext(
+      const UnicodeText& context_unicode) const;
+
+  // Strips boundary codepoints from the span in context and returns the new
+  // start and end indices. If the span consists entirely of boundary
+  // codepoints, the first index of span is returned for both indices.
+  CodepointSpan StripBoundaryCodepoints(const std::string& context,
+                                        CodepointSpan span) const;
+
+ protected:
+  // Represents a codepoint range [start, end).
+  struct CodepointRange {
+    int32 start;
+    int32 end;
+
+    CodepointRange(int32 arg_start, int32 arg_end)
+        : start(arg_start), end(arg_end) {}
+  };
+
+  // Returns the class id corresponding to the given string collection
+  // identifier. There is a catch-all class id that the function returns for
+  // unknown collections.
+  int CollectionToLabel(const std::string& collection) const;
+
+  // Prepares mapping from collection names to labels.
+  void MakeLabelMaps();
+
+  // Gets the number of spannable tokens for the model.
+  //
+  // Spannable tokens are those tokens of context, which the model predicts
+  // selection spans over (i.e., there is 1:1 correspondence between the output
+  // classes of the model and each of the spannable tokens).
+  int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }
+
+  // Converts a label into a span of codepoint indices corresponding to it
+  // given output_tokens.
+  bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
+                   CodepointSpan* span) const;
+
+  // Converts a span to the corresponding label given output_tokens.
+  bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
+                   const std::vector<Token>& output_tokens, int* label) const;
+
+  // Converts a token span to the corresponding label.
+  int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
+
+  void PrepareCodepointRanges(
+      const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
+          codepoint_ranges,
+      std::vector<CodepointRange>* prepared_codepoint_ranges);
+
+  // Returns the ratio of supported codepoints to total number of codepoints in
+  // the given token span.
+  float SupportedCodepointsRatio(const TokenSpan& token_span,
+                                 const std::vector<Token>& tokens) const;
+
+  // Returns true if given codepoint is covered by the given sorted vector of
+  // codepoint ranges.
+  bool IsCodepointInRanges(
+      int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
+
+  void PrepareIgnoredSpanBoundaryCodepoints();
+
+  // Counts the number of ignored span boundary codepoints. If
+  // count_from_beginning is true, the counting starts at the span_start
+  // iterator (inclusive) and ends no later than span_end (exclusive). If
+  // count_from_beginning is false, the counting starts from span_end
+  // (exclusive) and ends at span_start (inclusive).
+  int CountIgnoredSpanBoundaryCodepoints(
+      const UnicodeText::const_iterator& span_start,
+      const UnicodeText::const_iterator& span_end,
+      bool count_from_beginning) const;
+
+  // Finds the center token index in the tokens vector, using the method
+  // defined in options_.
+  int FindCenterToken(CodepointSpan span,
+                      const std::vector<Token>& tokens) const;
+
+  // Tokenizes the input text using the ICU tokenizer.
+  bool ICUTokenize(const std::string& context,
+                   std::vector<Token>* result) const;
+
+  // Takes the result of ICU tokenization and retokenizes stretches of tokens
+  // made of a specific subset of characters using the internal tokenizer.
+  void InternalRetokenize(const std::string& context,
+                          std::vector<Token>* tokens) const;
+
+  // Tokenizes a substring of the unicode string, appending the resulting tokens
+  // to the output vector. The resulting tokens have bounds relative to the full
+  // string. Does nothing if the start of the span is negative.
+  void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
+                         std::vector<Token>* result) const;
+
+  // Removes all tokens from tokens that are not on a line (defined by calling
+  // SplitContext on the context) to which span points.
+  void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
+                                 std::vector<Token>* tokens) const;
+
+ private:
+  std::unique_ptr<UniLib> owned_unilib_;  // Null if a UniLib was passed in.
+  UniLib* unilib_;  // Either owned_unilib_.get() or the caller's instance.
+
+ protected:
+  const TokenFeatureExtractor feature_extractor_;
+
+  // Codepoint ranges that define what codepoints are supported by the model.
+  // NOTE: Must be sorted.
+  std::vector<CodepointRange> supported_codepoint_ranges_;
+
+  // Codepoint ranges that define which tokens (consisting of which codepoints)
+  // should be re-tokenized with the internal tokenizer in the mixed
+  // tokenization mode.
+  // NOTE: Must be sorted.
+  std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
+
+ private:
+  // Set of codepoints that will be stripped from beginning and end of
+  // predicted spans.
+  std::set<int32> ignored_span_boundary_codepoints_;
+
+  const FeatureProcessorOptions* const options_;  // Not owned.
+
+  // Mapping between token selection spans and label ids.
+  std::map<TokenSpan, int> selection_to_label_;
+  std::vector<TokenSpan> label_to_selection_;
+
+  // Mapping between collections and labels.
+  std::map<std::string, int> collection_to_label_;
+
+  Tokenizer tokenizer_;
+};
+
+} // namespace libtextclassifier2
+
+#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_FEATURE_PROCESSOR_H_