Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 1 | /* |
Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 2 | * Copyright (C) 2018 The Android Open Source Project |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
| 17 | // Feature processing for FFModel (feed-forward SmartSelection model). |
| 18 | |
Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 19 | #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_ |
| 20 | #define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_ |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 21 | |
Lukas Zilka | 21d8c98 | 2018-01-24 11:11:20 +0100 | [diff] [blame] | 22 | #include <map> |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 23 | #include <memory> |
Lukas Zilka | e5ea2ab | 2017-10-11 10:50:05 +0200 | [diff] [blame] | 24 | #include <set> |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 25 | #include <string> |
| 26 | #include <vector> |
| 27 | |
Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 28 | #include "annotator/cached-features.h" |
| 29 | #include "annotator/model_generated.h" |
| 30 | #include "annotator/token-feature-extractor.h" |
| 31 | #include "annotator/tokenizer.h" |
| 32 | #include "annotator/types.h" |
| 33 | #include "utils/base/integral_types.h" |
| 34 | #include "utils/base/logging.h" |
| 35 | #include "utils/utf8/unicodetext.h" |
| 36 | #include "utils/utf8/unilib.h" |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 37 | |
Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 38 | namespace libtextclassifier3 { |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 39 | |
| 40 | constexpr int kInvalidLabel = -1; |
| 41 | |
namespace internal {

// Builds the options for the TokenFeatureExtractor from the relevant fields
// of the flatbuffer-backed FeatureProcessorOptions.
TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
    const FeatureProcessorOptions* options);

// Splits tokens that contain the selection boundary inside them.
// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
                                      std::vector<Token>* tokens);

// Returns the index of the token that corresponds to the codepoint span.
int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);

// Returns the index of the token that corresponds to the middle of the
// codepoint span.
int CenterTokenFromMiddleOfSelection(
    CodepointSpan span, const std::vector<Token>& selectable_tokens);

// Strips the tokens from the tokens vector that are not used for feature
// extraction because they are out of scope, or pads them so that there are
// enough tokens in the required context_size for all inferences with a click
// in relative_click_span. Adjusts *click_pos to match the modified vector.
void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
                      std::vector<Token>* tokens, int* click_pos);

}  // namespace internal
| 68 | |
// Converts a codepoint span to a token span in the given list of tokens.
// If snap_boundaries_to_containing_tokens is set to true, it is enough for a
// token to overlap with the codepoint range to be considered part of it.
// Otherwise it must be fully included in the range.
TokenSpan CodepointSpanToTokenSpan(
    const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
    bool snap_boundaries_to_containing_tokens = false);

// Converts a token span to a codepoint span in the given list of tokens.
CodepointSpan TokenSpanToCodepointSpan(
    const std::vector<Token>& selectable_tokens, TokenSpan token_span);
| 80 | |
Lukas Zilka | 6bb39a8 | 2017-04-07 19:55:11 +0200 | [diff] [blame] | 81 | // Takes care of preparing features for the span prediction model. |
// Takes care of preparing features for the span prediction model.
class FeatureProcessor {
 public:
  // A cache mapping codepoint spans to embedded token features. An instance
  // can be provided to multiple calls to ExtractFeatures() operating on the
  // same context (the same codepoint spans corresponding to the same tokens),
  // as an optimization. Note that the tokenizations do not have to be
  // identical.
  typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;

  // Constructs a FeatureProcessor from the given flatbuffer-backed options.
  // NOTE(review): unilib is dereferenced unconditionally in the initializer
  // list below, so it must be non-null — the historic claim that a UniLib
  // instance is created when nullptr is passed does not match this code.
  // The processor does not take ownership of options or unilib; both must
  // outlive this object.
  explicit FeatureProcessor(const FeatureProcessorOptions* options,
                            const UniLib* unilib)
      : unilib_(unilib),
        feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
                           *unilib_),
        options_(options),
        // If a codepoint-based tokenization config is present, use it;
        // otherwise fall back to a Tokenizer with no codepoint ranges and no
        // script-change splitting (ICU tokenization is selected elsewhere).
        tokenizer_(
            options->tokenization_codepoint_config() != nullptr
                ? Tokenizer({options->tokenization_codepoint_config()->begin(),
                             options->tokenization_codepoint_config()->end()},
                            options->tokenize_on_script_change())
                : Tokenizer({}, /*split_on_script_change=*/false)) {
    MakeLabelMaps();
    if (options->supported_codepoint_ranges() != nullptr) {
      PrepareCodepointRanges({options->supported_codepoint_ranges()->begin(),
                              options->supported_codepoint_ranges()->end()},
                             &supported_codepoint_ranges_);
    }
    if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
      PrepareCodepointRanges(
          {options->internal_tokenizer_codepoint_ranges()->begin(),
           options->internal_tokenizer_codepoint_ranges()->end()},
          &internal_tokenizer_codepoint_ranges_);
    }
    PrepareIgnoredSpanBoundaryCodepoints();
  }

  // Tokenizes the input string using the selected tokenization method.
  std::vector<Token> Tokenize(const std::string& text) const;

  // Same as above but takes UnicodeText.
  std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;

  // Converts a label into a token span. Returns false if the label is invalid.
  bool LabelToTokenSpan(int label, TokenSpan* token_span) const;

  // Gets the total number of selection labels.
  int GetSelectionLabelCount() const { return label_to_selection_.size(); }

  // Gets the string value for the given collection label.
  std::string LabelToCollection(int label) const;

  // Gets the total number of collections of the model.
  int NumCollections() const { return collection_to_label_.size(); }

  // Gets the name of the default collection.
  std::string GetDefaultCollection() const;

  // Returns the (non-owned) options this processor was constructed with.
  const FeatureProcessorOptions* GetOptions() const { return options_; }

  // Retokenizes the context and input span, and finds the click position.
  // Depending on the options, might modify tokens (split them or remove them).
  void RetokenizeAndFindClick(const std::string& context,
                              CodepointSpan input_span,
                              bool only_use_line_with_click,
                              std::vector<Token>* tokens, int* click_pos) const;

  // Same as above but takes UnicodeText.
  void RetokenizeAndFindClick(const UnicodeText& context_unicode,
                              CodepointSpan input_span,
                              bool only_use_line_with_click,
                              std::vector<Token>* tokens, int* click_pos) const;

  // Returns true if the token span has enough supported codepoints (as
  // defined in the model config); if it does not, the model should not run.
  bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
                                    TokenSpan token_span) const;

  // Extracts features as a CachedFeatures object that can be used for repeated
  // inference over token spans in the given context. Returns false on failure.
  bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
                       CodepointSpan selection_span_for_feature,
                       const EmbeddingExecutor* embedding_executor,
                       EmbeddingCache* embedding_cache, int feature_vector_size,
                       std::unique_ptr<CachedFeatures>* cached_features) const;

  // Fills selection_label_spans with CodepointSpans that correspond to the
  // selection labels. The CodepointSpans are based on the codepoint ranges of
  // given tokens.
  bool SelectionLabelSpans(
      VectorSpan<Token> tokens,
      std::vector<CodepointSpan>* selection_label_spans) const;

  // Number of dense features produced per token by the feature extractor.
  int DenseFeaturesCount() const {
    return feature_extractor_.DenseFeaturesCount();
  }

  // Size of the sparse-feature embedding, as configured in the options.
  int EmbeddingSize() const { return options_->embedding_size(); }

  // Splits context to several segments (e.g. lines).
  std::vector<UnicodeTextRange> SplitContext(
      const UnicodeText& context_unicode) const;

  // Strips boundary codepoints from the span in context and returns the new
  // start and end indices. If the span comprises entirely of boundary
  // codepoints, the first index of span is returned for both indices.
  CodepointSpan StripBoundaryCodepoints(const std::string& context,
                                        CodepointSpan span) const;

  // Same as above but takes UnicodeText.
  CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
                                        CodepointSpan span) const;

 protected:
  // Represents a codepoint range [start, end).
  struct CodepointRange {
    int32 start;
    int32 end;

    CodepointRange(int32 arg_start, int32 arg_end)
        : start(arg_start), end(arg_end) {}
  };

  // Returns the class id corresponding to the given string collection
  // identifier. There is a catch-all class id that the function returns for
  // unknown collections.
  int CollectionToLabel(const std::string& collection) const;

  // Prepares mapping from collection names to labels.
  void MakeLabelMaps();

  // Gets the number of spannable tokens for the model.
  //
  // Spannable tokens are those tokens of context, which the model predicts
  // selection spans over (i.e., there is 1:1 correspondence between the output
  // classes of the model and each of the spannable tokens).
  int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }

  // Converts a label into a span of codepoint indices corresponding to it
  // given output_tokens.
  bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
                   CodepointSpan* span) const;

  // Converts a span to the corresponding label given output_tokens.
  bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
                   const std::vector<Token>& output_tokens, int* label) const;

  // Converts a token span to the corresponding label.
  int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;

  // Copies the flatbuffer-provided codepoint ranges into
  // prepared_codepoint_ranges (which the range-query helpers below expect to
  // be sorted).
  void PrepareCodepointRanges(
      const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
          codepoint_ranges,
      std::vector<CodepointRange>* prepared_codepoint_ranges);

  // Returns the ratio of supported codepoints to total number of codepoints in
  // the given token span.
  float SupportedCodepointsRatio(const TokenSpan& token_span,
                                 const std::vector<Token>& tokens) const;

  // Returns true if given codepoint is covered by the given sorted vector of
  // codepoint ranges.
  bool IsCodepointInRanges(
      int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;

  // Populates ignored_span_boundary_codepoints_ from the model options.
  void PrepareIgnoredSpanBoundaryCodepoints();

  // Counts the number of span boundary codepoints. If count_from_beginning is
  // True, the counting will start at the span_start iterator (inclusive) and
  // at maximum end at span_end (exclusive). If count_from_beginning is False,
  // the counting will start from span_end (exclusive) and end at span_start
  // (inclusive).
  int CountIgnoredSpanBoundaryCodepoints(
      const UnicodeText::const_iterator& span_start,
      const UnicodeText::const_iterator& span_end,
      bool count_from_beginning) const;

  // Finds the center token index in tokens vector, using the method defined
  // in options_.
  int FindCenterToken(CodepointSpan span,
                      const std::vector<Token>& tokens) const;

  // Tokenizes the input text using ICU tokenizer. Returns false on failure.
  bool ICUTokenize(const UnicodeText& context_unicode,
                   std::vector<Token>* result) const;

  // Takes the result of ICU tokenization and retokenizes stretches of tokens
  // made of a specific subset of characters using the internal tokenizer.
  void InternalRetokenize(const UnicodeText& unicode_text,
                          std::vector<Token>* tokens) const;

  // Tokenizes a substring of the unicode string, appending the resulting tokens
  // to the output vector. The resulting tokens have bounds relative to the full
  // string. Does nothing if the start of the span is negative.
  void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
                         std::vector<Token>* result) const;

  // Removes all tokens from tokens that are not on a line (defined by calling
  // SplitContext on the context) to which span points.
  void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
                                 std::vector<Token>* tokens) const;

  // Same as above but takes UnicodeText.
  void StripTokensFromOtherLines(const UnicodeText& context_unicode,
                                 CodepointSpan span,
                                 std::vector<Token>* tokens) const;

  // Extracts the features of a token and appends them to the output vector.
  // Uses the embedding cache to avoid re-extracting and re-embedding the
  // sparse features for the same token.
  bool AppendTokenFeaturesWithCache(const Token& token,
                                    CodepointSpan selection_span_for_feature,
                                    const EmbeddingExecutor* embedding_executor,
                                    EmbeddingCache* embedding_cache,
                                    std::vector<float>* output_features) const;

 private:
  // Non-owned; must outlive this object (dereferenced in the constructor).
  const UniLib* unilib_;

 protected:
  const TokenFeatureExtractor feature_extractor_;

  // Codepoint ranges that define what codepoints are supported by the model.
  // NOTE: Must be sorted.
  std::vector<CodepointRange> supported_codepoint_ranges_;

  // Codepoint ranges that define which tokens (consisting of which codepoints)
  // should be re-tokenized with the internal tokenizer in the mixed
  // tokenization mode.
  // NOTE: Must be sorted.
  std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;

 private:
  // Set of codepoints that will be stripped from beginning and end of
  // predicted spans.
  std::set<int32> ignored_span_boundary_codepoints_;

  // Non-owned model options; must outlive this object.
  const FeatureProcessorOptions* const options_;

  // Mapping between token selection spans and labels ids.
  std::map<TokenSpan, int> selection_to_label_;
  std::vector<TokenSpan> label_to_selection_;

  // Mapping between collections and labels.
  std::map<std::string, int> collection_to_label_;

  Tokenizer tokenizer_;
};
| 331 | |
Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 332 | } // namespace libtextclassifier3 |
Matt Sharifi | bda09f1 | 2017-03-10 12:29:15 +0100 | [diff] [blame] | 333 | |
Tony Mak | 6c4cc67 | 2018-09-17 11:48:50 +0100 | [diff] [blame] | 334 | #endif // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_ |