Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (C) 2017 The Android Open Source Project |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
| 17 | // Feature processing for FFModel (feed-forward SmartSelection model). |
| 18 | |
| 19 | #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_ |
| 20 | #define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_ |
| 21 | |
#include <map>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
| 26 | |
| 27 | #include "common/feature-extractor.h" |
| 28 | #include "smartselect/text-classification-model.pb.h" |
| 29 | #include "smartselect/token-feature-extractor.h" |
| 30 | #include "smartselect/tokenizer.h" |
| 31 | #include "smartselect/types.h" |
| 32 | |
| 33 | namespace libtextclassifier { |
| 34 | |
// Sentinel value denoting an invalid label.
// NOTE(review): presumably produced when no valid selection/classification
// label can be assigned; confirm against the implementation file.
constexpr int kInvalidLabel{-1};
| 36 | |
| 37 | namespace internal { |
| 38 | |
| 39 | // Parses the serialized protocol buffer. |
| 40 | FeatureProcessorOptions ParseSerializedOptions( |
| 41 | const std::string& serialized_options); |
| 42 | |
| 43 | TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions( |
| 44 | const FeatureProcessorOptions& options); |
| 45 | |
Matt Sharifi | be876dc | 2017-03-17 17:02:43 +0100 | [diff] [blame] | 46 | // Removes tokens that are not part of a line of the context which contains |
| 47 | // given span. |
| 48 | void StripTokensFromOtherLines(const std::string& context, CodepointSpan span, |
| 49 | std::vector<Token>* tokens); |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 50 | |
| 51 | // Splits tokens that contain the selection boundary inside them. |
| 52 | // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com" |
| 53 | void SplitTokensOnSelectionBoundaries(CodepointSpan selection, |
| 54 | std::vector<Token>* tokens); |
| 55 | |
Matt Sharifi | be876dc | 2017-03-17 17:02:43 +0100 | [diff] [blame] | 56 | // Returns the index of token that corresponds to the codepoint span. |
| 57 | int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens); |
| 58 | |
| 59 | // Returns the index of token that corresponds to the middle of the codepoint |
| 60 | // span. |
| 61 | int CenterTokenFromMiddleOfSelection( |
| 62 | CodepointSpan span, const std::vector<Token>& selectable_tokens); |
| 63 | |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 64 | } // namespace internal |
| 65 | |
| 66 | TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens, |
| 67 | CodepointSpan codepoint_span); |
| 68 | |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 69 | // Takes care of preparing features for the FFModel. |
| 70 | class FeatureProcessor { |
| 71 | public: |
| 72 | explicit FeatureProcessor(const FeatureProcessorOptions& options) |
| 73 | : options_(options), |
| 74 | feature_extractor_( |
| 75 | internal::BuildTokenFeatureExtractorOptions(options)), |
| 76 | feature_type_(FeatureProcessor::kFeatureTypeName, |
| 77 | options.num_buckets()), |
| 78 | tokenizer_({options.tokenization_codepoint_config().begin(), |
| 79 | options.tokenization_codepoint_config().end()}), |
| 80 | random_(new std::mt19937(std::random_device()())) { |
| 81 | MakeLabelMaps(); |
Lukas Zilka | 26e8c2e | 2017-04-06 15:54:24 +0200 | [diff] [blame^] | 82 | PrepareSupportedCodepointRanges( |
| 83 | {options.supported_codepoint_ranges().begin(), |
| 84 | options.supported_codepoint_ranges().end()}); |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 85 | } |
| 86 | |
| 87 | explicit FeatureProcessor(const std::string& serialized_options) |
| 88 | : FeatureProcessor(internal::ParseSerializedOptions(serialized_options)) { |
| 89 | } |
| 90 | |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 91 | // Tokenizes the input string using the selected tokenization method. |
| 92 | std::vector<Token> Tokenize(const std::string& utf8_text) const; |
| 93 | |
Matt Sharifi | be876dc | 2017-03-17 17:02:43 +0100 | [diff] [blame] | 94 | bool GetFeatures(const std::string& context, CodepointSpan input_span, |
| 95 | std::vector<nlp_core::FeatureVector>* features, |
| 96 | std::vector<float>* extra_features, |
| 97 | std::vector<CodepointSpan>* selection_label_spans) const; |
| 98 | |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 99 | // NOTE: If dropout is on, subsequent calls of this function with the same |
| 100 | // arguments might return different results. |
Matt Sharifi | be876dc | 2017-03-17 17:02:43 +0100 | [diff] [blame] | 101 | bool GetFeaturesAndLabels(const std::string& context, |
| 102 | CodepointSpan input_span, CodepointSpan label_span, |
| 103 | const std::string& label_collection, |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 104 | std::vector<nlp_core::FeatureVector>* features, |
| 105 | std::vector<float>* extra_features, |
| 106 | std::vector<CodepointSpan>* selection_label_spans, |
| 107 | int* selection_label, |
| 108 | CodepointSpan* selection_codepoint_label, |
| 109 | int* classification_label) const; |
| 110 | |
| 111 | // Same as above but uses std::vector instead of FeatureVector. |
| 112 | // NOTE: If dropout is on, subsequent calls of this function with the same |
| 113 | // arguments might return different results. |
| 114 | bool GetFeaturesAndLabels( |
Matt Sharifi | be876dc | 2017-03-17 17:02:43 +0100 | [diff] [blame] | 115 | const std::string& context, CodepointSpan input_span, |
| 116 | CodepointSpan label_span, const std::string& label_collection, |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 117 | std::vector<std::vector<std::pair<int, float>>>* features, |
| 118 | std::vector<float>* extra_features, |
| 119 | std::vector<CodepointSpan>* selection_label_spans, int* selection_label, |
| 120 | CodepointSpan* selection_codepoint_label, |
| 121 | int* classification_label) const; |
| 122 | |
| 123 | // Converts a label into a token span. |
| 124 | bool LabelToTokenSpan(int label, TokenSpan* token_span) const; |
| 125 | |
| 126 | // Gets the string value for given collection label. |
| 127 | std::string LabelToCollection(int label) const; |
| 128 | |
| 129 | // Gets the total number of collections of the model. |
| 130 | int NumCollections() const { return collection_to_label_.size(); } |
| 131 | |
| 132 | // Gets the name of the default collection. |
| 133 | std::string GetDefaultCollection() const { |
| 134 | return options_.collections(options_.default_collection()); |
| 135 | } |
| 136 | |
| 137 | FeatureProcessorOptions GetOptions() const { return options_; } |
| 138 | |
| 139 | int GetSelectionLabelCount() const { return label_to_selection_.size(); } |
| 140 | |
| 141 | // Sets the source of randomness. |
| 142 | void SetRandom(std::mt19937* new_random) { random_.reset(new_random); } |
| 143 | |
| 144 | protected: |
Lukas Zilka | 26e8c2e | 2017-04-06 15:54:24 +0200 | [diff] [blame^] | 145 | // Represents a codepoint range [start, end). |
| 146 | struct CodepointRange { |
| 147 | int32 start; |
| 148 | int32 end; |
| 149 | |
| 150 | CodepointRange(int32 arg_start, int32 arg_end) |
| 151 | : start(arg_start), end(arg_end) {} |
| 152 | }; |
| 153 | |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 154 | // Extracts features for given word. |
| 155 | std::vector<int> GetWordFeatures(const std::string& word) const; |
| 156 | |
| 157 | // NOTE: If dropout is on, subsequent calls of this function with the same |
| 158 | // arguments might return different results. |
| 159 | bool ComputeFeatures(int click_pos, |
| 160 | const std::vector<Token>& selectable_tokens, |
| 161 | CodepointSpan selected_span, |
| 162 | std::vector<nlp_core::FeatureVector>* features, |
| 163 | std::vector<float>* extra_features, |
| 164 | std::vector<Token>* output_tokens) const; |
| 165 | |
| 166 | // Helper function that computes how much left context and how much right |
| 167 | // context should be dropped. Uses a mutable random_ member as a source of |
| 168 | // randomness. |
| 169 | bool GetContextDropoutRange(int* dropout_left, int* dropout_right) const; |
| 170 | |
| 171 | // Returns the class id corresponding to the given string collection |
| 172 | // identifier. There is a catch-all class id that the function returns for |
| 173 | // unknown collections. |
| 174 | int CollectionToLabel(const std::string& collection) const; |
| 175 | |
| 176 | // Prepares mapping from collection names to labels. |
| 177 | void MakeLabelMaps(); |
| 178 | |
| 179 | // Gets the number of spannable tokens for the model. |
| 180 | // |
| 181 | // Spannable tokens are those tokens of context, which the model predicts |
| 182 | // selection spans over (i.e., there is 1:1 correspondence between the output |
| 183 | // classes of the model and each of the spannable tokens). |
| 184 | int GetNumContextTokens() const { return options_.context_size() * 2 + 1; } |
| 185 | |
| 186 | // Converts a label into a span of codepoint indices corresponding to it |
| 187 | // given output_tokens. |
| 188 | bool LabelToSpan(int label, const std::vector<Token>& output_tokens, |
| 189 | CodepointSpan* span) const; |
| 190 | |
| 191 | // Converts a span to the corresponding label given output_tokens. |
| 192 | bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span, |
| 193 | const std::vector<Token>& output_tokens, int* label) const; |
| 194 | |
| 195 | // Converts a token span to the corresponding label. |
| 196 | int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const; |
| 197 | |
Matt Sharifi | be876dc | 2017-03-17 17:02:43 +0100 | [diff] [blame] | 198 | // Finds the center token index in tokens vector, using the method defined |
| 199 | // in options_. |
| 200 | int FindCenterToken(CodepointSpan span, |
| 201 | const std::vector<Token>& tokens) const; |
| 202 | |
Lukas Zilka | 26e8c2e | 2017-04-06 15:54:24 +0200 | [diff] [blame^] | 203 | void PrepareSupportedCodepointRanges( |
| 204 | const std::vector<FeatureProcessorOptions::CodepointRange>& |
| 205 | codepoint_range_configs); |
| 206 | |
| 207 | // Returns the ratio of supported codepoints to total number of codepoints in |
| 208 | // the input context around given click position. |
| 209 | float SupportedCodepointsRatio(int click_pos, |
| 210 | const std::vector<Token>& tokens) const; |
| 211 | |
| 212 | // Returns true if given codepoint is supported. |
| 213 | bool IsCodepointSupported(int codepoint) const; |
| 214 | |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 215 | private: |
| 216 | FeatureProcessorOptions options_; |
| 217 | |
| 218 | TokenFeatureExtractor feature_extractor_; |
| 219 | |
| 220 | static const char* const kFeatureTypeName; |
| 221 | |
| 222 | nlp_core::NumericFeatureType feature_type_; |
| 223 | |
| 224 | // Mapping between token selection spans and labels ids. |
| 225 | std::map<TokenSpan, int> selection_to_label_; |
| 226 | std::vector<TokenSpan> label_to_selection_; |
| 227 | |
| 228 | // Mapping between collections and labels. |
| 229 | std::map<std::string, int> collection_to_label_; |
| 230 | |
| 231 | Tokenizer tokenizer_; |
| 232 | |
Lukas Zilka | 26e8c2e | 2017-04-06 15:54:24 +0200 | [diff] [blame^] | 233 | // Codepoint ranges that define what codepoints are supported by the model. |
| 234 | // NOTE: Must be sorted. |
| 235 | std::vector<CodepointRange> supported_codepoint_ranges_; |
| 236 | |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 237 | // Source of randomness. |
| 238 | mutable std::unique_ptr<std::mt19937> random_; |
| 239 | }; |
| 240 | |
| 241 | } // namespace libtextclassifier |
| 242 | |
| 243 | #endif // LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_ |