/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Feature processing for FFModel (feed-forward SmartSelection model).
18
Tony Mak6c4cc672018-09-17 11:48:50 +010019#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
20#define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
Matt Sharifibda09f12017-03-10 12:29:15 +010021
Lukas Zilka21d8c982018-01-24 11:11:20 +010022#include <map>
Matt Sharifibda09f12017-03-10 12:29:15 +010023#include <memory>
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +020024#include <set>
Matt Sharifibda09f12017-03-10 12:29:15 +010025#include <string>
26#include <vector>
27
Tony Mak6c4cc672018-09-17 11:48:50 +010028#include "annotator/cached-features.h"
29#include "annotator/model_generated.h"
30#include "annotator/token-feature-extractor.h"
31#include "annotator/tokenizer.h"
32#include "annotator/types.h"
33#include "utils/base/integral_types.h"
34#include "utils/base/logging.h"
35#include "utils/utf8/unicodetext.h"
36#include "utils/utf8/unilib.h"
Matt Sharifibda09f12017-03-10 12:29:15 +010037
Tony Mak6c4cc672018-09-17 11:48:50 +010038namespace libtextclassifier3 {
Matt Sharifibda09f12017-03-10 12:29:15 +010039
// Sentinel label id meaning "no valid label" (going by the name; the code that
// produces or checks it lives outside this header).
constexpr int kInvalidLabel = -1;
41
42namespace internal {
43
Matt Sharifibda09f12017-03-10 12:29:15 +010044TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
Lukas Zilka21d8c982018-01-24 11:11:20 +010045 const FeatureProcessorOptions* options);
Matt Sharifibda09f12017-03-10 12:29:15 +010046
Matt Sharifibda09f12017-03-10 12:29:15 +010047// Splits tokens that contain the selection boundary inside them.
48// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
49void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
50 std::vector<Token>* tokens);
51
Matt Sharifibe876dc2017-03-17 17:02:43 +010052// Returns the index of token that corresponds to the codepoint span.
53int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
54
55// Returns the index of token that corresponds to the middle of the codepoint
56// span.
57int CenterTokenFromMiddleOfSelection(
58 CodepointSpan span, const std::vector<Token>& selectable_tokens);
59
Lukas Zilka6bb39a82017-04-07 19:55:11 +020060// Strips the tokens from the tokens vector that are not used for feature
61// extraction because they are out of scope, or pads them so that there is
62// enough tokens in the required context_size for all inferences with a click
63// in relative_click_span.
64void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
65 std::vector<Token>* tokens, int* click_pos);
66
Matt Sharifibda09f12017-03-10 12:29:15 +010067} // namespace internal
68
Lukas Zilka40c18de2017-04-10 17:22:22 +020069// Converts a codepoint span to a token span in the given list of tokens.
Lukas Zilka726b4d22017-12-13 16:37:03 +010070// If snap_boundaries_to_containing_tokens is set to true, it is enough for a
71// token to overlap with the codepoint range to be considered part of it.
72// Otherwise it must be fully included in the range.
73TokenSpan CodepointSpanToTokenSpan(
74 const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
75 bool snap_boundaries_to_containing_tokens = false);
Matt Sharifibda09f12017-03-10 12:29:15 +010076
Lukas Zilka40c18de2017-04-10 17:22:22 +020077// Converts a token span to a codepoint span in the given list of tokens.
78CodepointSpan TokenSpanToCodepointSpan(
79 const std::vector<Token>& selectable_tokens, TokenSpan token_span);
80
Lukas Zilka6bb39a82017-04-07 19:55:11 +020081// Takes care of preparing features for the span prediction model.
Matt Sharifibda09f12017-03-10 12:29:15 +010082class FeatureProcessor {
83 public:
Lukas Zilkaba849e72018-03-08 14:48:21 +010084 // A cache mapping codepoint spans to embedded tokens features. An instance
85 // can be provided to multiple calls to ExtractFeatures() operating on the
86 // same context (the same codepoint spans corresponding to the same tokens),
87 // as an optimization. Note that the tokenizations do not have to be
88 // identical.
89 typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;
90
Lukas Zilka21d8c982018-01-24 11:11:20 +010091 // If unilib is nullptr, will create and own an instance of a UniLib,
92 // otherwise will use what's passed in.
93 explicit FeatureProcessor(const FeatureProcessorOptions* options,
Tony Mak6c4cc672018-09-17 11:48:50 +010094 const UniLib* unilib)
95 : unilib_(unilib),
Lukas Zilka21d8c982018-01-24 11:11:20 +010096 feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
97 *unilib_),
Lukas Zilka6bb39a82017-04-07 19:55:11 +020098 options_(options),
Lukas Zilka21d8c982018-01-24 11:11:20 +010099 tokenizer_(
100 options->tokenization_codepoint_config() != nullptr
101 ? Tokenizer({options->tokenization_codepoint_config()->begin(),
102 options->tokenization_codepoint_config()->end()},
103 options->tokenize_on_script_change())
104 : Tokenizer({}, /*split_on_script_change=*/false)) {
Matt Sharifibda09f12017-03-10 12:29:15 +0100105 MakeLabelMaps();
Lukas Zilka21d8c982018-01-24 11:11:20 +0100106 if (options->supported_codepoint_ranges() != nullptr) {
107 PrepareCodepointRanges({options->supported_codepoint_ranges()->begin(),
108 options->supported_codepoint_ranges()->end()},
109 &supported_codepoint_ranges_);
110 }
111 if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
112 PrepareCodepointRanges(
113 {options->internal_tokenizer_codepoint_ranges()->begin(),
114 options->internal_tokenizer_codepoint_ranges()->end()},
115 &internal_tokenizer_codepoint_ranges_);
116 }
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200117 PrepareIgnoredSpanBoundaryCodepoints();
Matt Sharifibda09f12017-03-10 12:29:15 +0100118 }
119
Matt Sharifibda09f12017-03-10 12:29:15 +0100120 // Tokenizes the input string using the selected tokenization method.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100121 std::vector<Token> Tokenize(const std::string& text) const;
122
123 // Same as above but takes UnicodeText.
124 std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
Matt Sharifibda09f12017-03-10 12:29:15 +0100125
Matt Sharifibda09f12017-03-10 12:29:15 +0100126 // Converts a label into a token span.
127 bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
128
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200129 // Gets the total number of selection labels.
130 int GetSelectionLabelCount() const { return label_to_selection_.size(); }
131
Matt Sharifibda09f12017-03-10 12:29:15 +0100132 // Gets the string value for given collection label.
133 std::string LabelToCollection(int label) const;
134
135 // Gets the total number of collections of the model.
136 int NumCollections() const { return collection_to_label_.size(); }
137
138 // Gets the name of the default collection.
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200139 std::string GetDefaultCollection() const;
140
Lukas Zilka21d8c982018-01-24 11:11:20 +0100141 const FeatureProcessorOptions* GetOptions() const { return options_; }
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200142
Lukas Zilkaba849e72018-03-08 14:48:21 +0100143 // Retokenizes the context and input span, and finds the click position.
144 // Depending on the options, might modify tokens (split them or remove them).
145 void RetokenizeAndFindClick(const std::string& context,
146 CodepointSpan input_span,
147 bool only_use_line_with_click,
148 std::vector<Token>* tokens, int* click_pos) const;
Lukas Zilkab23e2122018-02-09 10:25:19 +0100149
150 // Same as above but takes UnicodeText.
Lukas Zilkaba849e72018-03-08 14:48:21 +0100151 void RetokenizeAndFindClick(const UnicodeText& context_unicode,
152 CodepointSpan input_span,
153 bool only_use_line_with_click,
154 std::vector<Token>* tokens, int* click_pos) const;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200155
Lukas Zilka434442d2018-04-25 11:38:51 +0200156 // Returns true if the token span has enough supported codepoints (as defined
157 // in the model config) or not and model should not run.
158 bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
159 TokenSpan token_span) const;
160
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200161 // Extracts features as a CachedFeatures object that can be used for repeated
162 // inference over token spans in the given context.
Lukas Zilka21d8c982018-01-24 11:11:20 +0100163 bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
Lukas Zilkab23e2122018-02-09 10:25:19 +0100164 CodepointSpan selection_span_for_feature,
Lukas Zilkaba849e72018-03-08 14:48:21 +0100165 const EmbeddingExecutor* embedding_executor,
166 EmbeddingCache* embedding_cache, int feature_vector_size,
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200167 std::unique_ptr<CachedFeatures>* cached_features) const;
168
169 // Fills selection_label_spans with CodepointSpans that correspond to the
170 // selection labels. The CodepointSpans are based on the codepoint ranges of
171 // given tokens.
172 bool SelectionLabelSpans(
173 VectorSpan<Token> tokens,
174 std::vector<CodepointSpan>* selection_label_spans) const;
175
176 int DenseFeaturesCount() const {
177 return feature_extractor_.DenseFeaturesCount();
Matt Sharifibda09f12017-03-10 12:29:15 +0100178 }
179
Lukas Zilka21d8c982018-01-24 11:11:20 +0100180 int EmbeddingSize() const { return options_->embedding_size(); }
181
Lukas Zilkab23e2122018-02-09 10:25:19 +0100182 // Splits context to several segments.
Lukas Zilka726b4d22017-12-13 16:37:03 +0100183 std::vector<UnicodeTextRange> SplitContext(
184 const UnicodeText& context_unicode) const;
185
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200186 // Strips boundary codepoints from the span in context and returns the new
187 // start and end indices. If the span comprises entirely of boundary
188 // codepoints, the first index of span is returned for both indices.
189 CodepointSpan StripBoundaryCodepoints(const std::string& context,
190 CodepointSpan span) const;
191
Lukas Zilkab23e2122018-02-09 10:25:19 +0100192 // Same as above but takes UnicodeText.
193 CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
194 CodepointSpan span) const;
195
Matt Sharifibda09f12017-03-10 12:29:15 +0100196 protected:
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200197 // Represents a codepoint range [start, end).
198 struct CodepointRange {
199 int32 start;
200 int32 end;
201
202 CodepointRange(int32 arg_start, int32 arg_end)
203 : start(arg_start), end(arg_end) {}
204 };
205
Matt Sharifibda09f12017-03-10 12:29:15 +0100206 // Returns the class id corresponding to the given string collection
207 // identifier. There is a catch-all class id that the function returns for
208 // unknown collections.
209 int CollectionToLabel(const std::string& collection) const;
210
211 // Prepares mapping from collection names to labels.
212 void MakeLabelMaps();
213
214 // Gets the number of spannable tokens for the model.
215 //
216 // Spannable tokens are those tokens of context, which the model predicts
217 // selection spans over (i.e., there is 1:1 correspondence between the output
218 // classes of the model and each of the spannable tokens).
Lukas Zilka21d8c982018-01-24 11:11:20 +0100219 int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }
Matt Sharifibda09f12017-03-10 12:29:15 +0100220
221 // Converts a label into a span of codepoint indices corresponding to it
222 // given output_tokens.
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200223 bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
Matt Sharifibda09f12017-03-10 12:29:15 +0100224 CodepointSpan* span) const;
225
226 // Converts a span to the corresponding label given output_tokens.
227 bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
228 const std::vector<Token>& output_tokens, int* label) const;
229
230 // Converts a token span to the corresponding label.
231 int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
232
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200233 void PrepareCodepointRanges(
Lukas Zilka21d8c982018-01-24 11:11:20 +0100234 const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200235 codepoint_ranges,
236 std::vector<CodepointRange>* prepared_codepoint_ranges);
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200237
238 // Returns the ratio of supported codepoints to total number of codepoints in
Lukas Zilka21d8c982018-01-24 11:11:20 +0100239 // the given token span.
240 float SupportedCodepointsRatio(const TokenSpan& token_span,
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200241 const std::vector<Token>& tokens) const;
242
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200243 // Returns true if given codepoint is covered by the given sorted vector of
244 // codepoint ranges.
245 bool IsCodepointInRanges(
246 int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200247
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200248 void PrepareIgnoredSpanBoundaryCodepoints();
249
250 // Counts the number of span boundary codepoints. If count_from_beginning is
251 // True, the counting will start at the span_start iterator (inclusive) and at
252 // maximum end at span_end (exclusive). If count_from_beginning is True, the
253 // counting will start from span_end (exclusive) and end at span_start
254 // (inclusive).
255 int CountIgnoredSpanBoundaryCodepoints(
256 const UnicodeText::const_iterator& span_start,
257 const UnicodeText::const_iterator& span_end,
258 bool count_from_beginning) const;
259
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200260 // Finds the center token index in tokens vector, using the method defined
261 // in options_.
262 int FindCenterToken(CodepointSpan span,
263 const std::vector<Token>& tokens) const;
264
Lukas Zilka40c18de2017-04-10 17:22:22 +0200265 // Tokenizes the input text using ICU tokenizer.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100266 bool ICUTokenize(const UnicodeText& context_unicode,
Lukas Zilka40c18de2017-04-10 17:22:22 +0200267 std::vector<Token>* result) const;
268
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200269 // Takes the result of ICU tokenization and retokenizes stretches of tokens
270 // made of a specific subset of characters using the internal tokenizer.
Lukas Zilkab23e2122018-02-09 10:25:19 +0100271 void InternalRetokenize(const UnicodeText& unicode_text,
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200272 std::vector<Token>* tokens) const;
273
274 // Tokenizes a substring of the unicode string, appending the resulting tokens
275 // to the output vector. The resulting tokens have bounds relative to the full
276 // string. Does nothing if the start of the span is negative.
277 void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
278 std::vector<Token>* result) const;
279
Lukas Zilka726b4d22017-12-13 16:37:03 +0100280 // Removes all tokens from tokens that are not on a line (defined by calling
281 // SplitContext on the context) to which span points.
282 void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
283 std::vector<Token>* tokens) const;
284
Lukas Zilkab23e2122018-02-09 10:25:19 +0100285 // Same as above but takes UnicodeText.
286 void StripTokensFromOtherLines(const UnicodeText& context_unicode,
287 CodepointSpan span,
288 std::vector<Token>* tokens) const;
289
Lukas Zilkaba849e72018-03-08 14:48:21 +0100290 // Extracts the features of a token and appends them to the output vector.
291 // Uses the embedding cache to to avoid re-extracting the re-embedding the
292 // sparse features for the same token.
293 bool AppendTokenFeaturesWithCache(const Token& token,
294 CodepointSpan selection_span_for_feature,
295 const EmbeddingExecutor* embedding_executor,
296 EmbeddingCache* embedding_cache,
297 std::vector<float>* output_features) const;
298
Lukas Zilka21d8c982018-01-24 11:11:20 +0100299 private:
Lukas Zilkab23e2122018-02-09 10:25:19 +0100300 const UniLib* unilib_;
Lukas Zilka21d8c982018-01-24 11:11:20 +0100301
302 protected:
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200303 const TokenFeatureExtractor feature_extractor_;
304
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200305 // Codepoint ranges that define what codepoints are supported by the model.
306 // NOTE: Must be sorted.
307 std::vector<CodepointRange> supported_codepoint_ranges_;
308
309 // Codepoint ranges that define which tokens (consisting of which codepoints)
310 // should be re-tokenized with the internal tokenizer in the mixed
311 // tokenization mode.
312 // NOTE: Must be sorted.
313 std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
314
Matt Sharifibda09f12017-03-10 12:29:15 +0100315 private:
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200316 // Set of codepoints that will be stripped from beginning and end of
317 // predicted spans.
318 std::set<int32> ignored_span_boundary_codepoints_;
319
Lukas Zilka21d8c982018-01-24 11:11:20 +0100320 const FeatureProcessorOptions* const options_;
Matt Sharifibda09f12017-03-10 12:29:15 +0100321
322 // Mapping between token selection spans and labels ids.
323 std::map<TokenSpan, int> selection_to_label_;
324 std::vector<TokenSpan> label_to_selection_;
325
326 // Mapping between collections and labels.
327 std::map<std::string, int> collection_to_label_;
328
329 Tokenizer tokenizer_;
Matt Sharifibda09f12017-03-10 12:29:15 +0100330};
331
Tony Mak6c4cc672018-09-17 11:48:50 +0100332} // namespace libtextclassifier3
Matt Sharifibda09f12017-03-10 12:29:15 +0100333
Tony Mak6c4cc672018-09-17 11:48:50 +0100334#endif // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_