blob: 834c260615ce034e864e580084857d7d1f83bc5f [file] [log] [blame]
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
// Feature processing for FFModel (feed-forward SmartSelection model).
18
Lukas Zilka21d8c982018-01-24 11:11:20 +010019#ifndef KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_FEATURE_PROCESSOR_H_
20#define KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_FEATURE_PROCESSOR_H_
Matt Sharifibda09f12017-03-10 12:29:15 +010021
Lukas Zilka21d8c982018-01-24 11:11:20 +010022#include <map>
Matt Sharifibda09f12017-03-10 12:29:15 +010023#include <memory>
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +020024#include <set>
Matt Sharifibda09f12017-03-10 12:29:15 +010025#include <string>
26#include <vector>
27
Lukas Zilka21d8c982018-01-24 11:11:20 +010028#include "cached-features.h"
29#include "model_generated.h"
30#include "token-feature-extractor.h"
31#include "tokenizer.h"
32#include "types.h"
33#include "util/base/integral_types.h"
Lukas Zilka6bb39a82017-04-07 19:55:11 +020034#include "util/base/logging.h"
Matt Sharifif95c3bd2017-04-25 18:41:11 +020035#include "util/utf8/unicodetext.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010036#include "util/utf8/unilib.h"
Matt Sharifibda09f12017-03-10 12:29:15 +010037
Lukas Zilka21d8c982018-01-24 11:11:20 +010038namespace libtextclassifier2 {
Matt Sharifibda09f12017-03-10 12:29:15 +010039
// Label value used to signal that no valid label applies.
constexpr int kInvalidLabel{-1};
41
42namespace internal {
43
Matt Sharifibda09f12017-03-10 12:29:15 +010044TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
Lukas Zilka21d8c982018-01-24 11:11:20 +010045 const FeatureProcessorOptions* options);
Matt Sharifibda09f12017-03-10 12:29:15 +010046
Matt Sharifibda09f12017-03-10 12:29:15 +010047// Splits tokens that contain the selection boundary inside them.
48// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
49void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
50 std::vector<Token>* tokens);
51
Matt Sharifibe876dc2017-03-17 17:02:43 +010052// Returns the index of token that corresponds to the codepoint span.
53int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
54
55// Returns the index of token that corresponds to the middle of the codepoint
56// span.
57int CenterTokenFromMiddleOfSelection(
58 CodepointSpan span, const std::vector<Token>& selectable_tokens);
59
Lukas Zilka6bb39a82017-04-07 19:55:11 +020060// Strips the tokens from the tokens vector that are not used for feature
61// extraction because they are out of scope, or pads them so that there is
62// enough tokens in the required context_size for all inferences with a click
63// in relative_click_span.
64void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
65 std::vector<Token>* tokens, int* click_pos);
66
Lukas Zilka21d8c982018-01-24 11:11:20 +010067// If unilib is not nullptr, just returns unilib. Otherwise, if unilib is
68// nullptr, will create UniLib, assign ownership to owned_unilib, and return it.
69UniLib* MaybeCreateUnilib(UniLib* unilib,
70 std::unique_ptr<UniLib>* owned_unilib);
71
Matt Sharifibda09f12017-03-10 12:29:15 +010072} // namespace internal
73
Lukas Zilka40c18de2017-04-10 17:22:22 +020074// Converts a codepoint span to a token span in the given list of tokens.
Lukas Zilka726b4d22017-12-13 16:37:03 +010075// If snap_boundaries_to_containing_tokens is set to true, it is enough for a
76// token to overlap with the codepoint range to be considered part of it.
77// Otherwise it must be fully included in the range.
78TokenSpan CodepointSpanToTokenSpan(
79 const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
80 bool snap_boundaries_to_containing_tokens = false);
Matt Sharifibda09f12017-03-10 12:29:15 +010081
Lukas Zilka40c18de2017-04-10 17:22:22 +020082// Converts a token span to a codepoint span in the given list of tokens.
83CodepointSpan TokenSpanToCodepointSpan(
84 const std::vector<Token>& selectable_tokens, TokenSpan token_span);
85
Lukas Zilka6bb39a82017-04-07 19:55:11 +020086// Takes care of preparing features for the span prediction model.
Matt Sharifibda09f12017-03-10 12:29:15 +010087class FeatureProcessor {
88 public:
Lukas Zilka21d8c982018-01-24 11:11:20 +010089 // If unilib is nullptr, will create and own an instance of a UniLib,
90 // otherwise will use what's passed in.
91 explicit FeatureProcessor(const FeatureProcessorOptions* options,
92 UniLib* unilib = nullptr)
93 : owned_unilib_(nullptr),
94 unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)),
95 feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
96 *unilib_),
Lukas Zilka6bb39a82017-04-07 19:55:11 +020097 options_(options),
Lukas Zilka21d8c982018-01-24 11:11:20 +010098 tokenizer_(
99 options->tokenization_codepoint_config() != nullptr
100 ? Tokenizer({options->tokenization_codepoint_config()->begin(),
101 options->tokenization_codepoint_config()->end()},
102 options->tokenize_on_script_change())
103 : Tokenizer({}, /*split_on_script_change=*/false)) {
Matt Sharifibda09f12017-03-10 12:29:15 +0100104 MakeLabelMaps();
Lukas Zilka21d8c982018-01-24 11:11:20 +0100105 if (options->supported_codepoint_ranges() != nullptr) {
106 PrepareCodepointRanges({options->supported_codepoint_ranges()->begin(),
107 options->supported_codepoint_ranges()->end()},
108 &supported_codepoint_ranges_);
109 }
110 if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
111 PrepareCodepointRanges(
112 {options->internal_tokenizer_codepoint_ranges()->begin(),
113 options->internal_tokenizer_codepoint_ranges()->end()},
114 &internal_tokenizer_codepoint_ranges_);
115 }
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200116 PrepareIgnoredSpanBoundaryCodepoints();
Matt Sharifibda09f12017-03-10 12:29:15 +0100117 }
118
Matt Sharifibda09f12017-03-10 12:29:15 +0100119 // Tokenizes the input string using the selected tokenization method.
120 std::vector<Token> Tokenize(const std::string& utf8_text) const;
121
Matt Sharifibda09f12017-03-10 12:29:15 +0100122 // Converts a label into a token span.
123 bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
124
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200125 // Gets the total number of selection labels.
126 int GetSelectionLabelCount() const { return label_to_selection_.size(); }
127
Matt Sharifibda09f12017-03-10 12:29:15 +0100128 // Gets the string value for given collection label.
129 std::string LabelToCollection(int label) const;
130
131 // Gets the total number of collections of the model.
132 int NumCollections() const { return collection_to_label_.size(); }
133
134 // Gets the name of the default collection.
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200135 std::string GetDefaultCollection() const;
136
Lukas Zilka21d8c982018-01-24 11:11:20 +0100137 const FeatureProcessorOptions* GetOptions() const { return options_; }
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200138
139 // Tokenizes the context and input span, and finds the click position.
140 void TokenizeAndFindClick(const std::string& context,
141 CodepointSpan input_span,
142 std::vector<Token>* tokens, int* click_pos) const;
143
144 // Extracts features as a CachedFeatures object that can be used for repeated
145 // inference over token spans in the given context.
Lukas Zilka21d8c982018-01-24 11:11:20 +0100146 bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
147 EmbeddingExecutor* embedding_executor,
148 int feature_vector_size,
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200149 std::unique_ptr<CachedFeatures>* cached_features) const;
150
151 // Fills selection_label_spans with CodepointSpans that correspond to the
152 // selection labels. The CodepointSpans are based on the codepoint ranges of
153 // given tokens.
154 bool SelectionLabelSpans(
155 VectorSpan<Token> tokens,
156 std::vector<CodepointSpan>* selection_label_spans) const;
157
158 int DenseFeaturesCount() const {
159 return feature_extractor_.DenseFeaturesCount();
Matt Sharifibda09f12017-03-10 12:29:15 +0100160 }
161
Lukas Zilka21d8c982018-01-24 11:11:20 +0100162 int EmbeddingSize() const { return options_->embedding_size(); }
163
Lukas Zilka726b4d22017-12-13 16:37:03 +0100164 // Splits context to several segments according to configuration.
165 std::vector<UnicodeTextRange> SplitContext(
166 const UnicodeText& context_unicode) const;
167
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200168 // Strips boundary codepoints from the span in context and returns the new
169 // start and end indices. If the span comprises entirely of boundary
170 // codepoints, the first index of span is returned for both indices.
171 CodepointSpan StripBoundaryCodepoints(const std::string& context,
172 CodepointSpan span) const;
173
Matt Sharifibda09f12017-03-10 12:29:15 +0100174 protected:
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200175 // Represents a codepoint range [start, end).
176 struct CodepointRange {
177 int32 start;
178 int32 end;
179
180 CodepointRange(int32 arg_start, int32 arg_end)
181 : start(arg_start), end(arg_end) {}
182 };
183
Matt Sharifibda09f12017-03-10 12:29:15 +0100184 // Returns the class id corresponding to the given string collection
185 // identifier. There is a catch-all class id that the function returns for
186 // unknown collections.
187 int CollectionToLabel(const std::string& collection) const;
188
189 // Prepares mapping from collection names to labels.
190 void MakeLabelMaps();
191
192 // Gets the number of spannable tokens for the model.
193 //
194 // Spannable tokens are those tokens of context, which the model predicts
195 // selection spans over (i.e., there is 1:1 correspondence between the output
196 // classes of the model and each of the spannable tokens).
Lukas Zilka21d8c982018-01-24 11:11:20 +0100197 int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }
Matt Sharifibda09f12017-03-10 12:29:15 +0100198
199 // Converts a label into a span of codepoint indices corresponding to it
200 // given output_tokens.
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200201 bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
Matt Sharifibda09f12017-03-10 12:29:15 +0100202 CodepointSpan* span) const;
203
204 // Converts a span to the corresponding label given output_tokens.
205 bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
206 const std::vector<Token>& output_tokens, int* label) const;
207
208 // Converts a token span to the corresponding label.
209 int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
210
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200211 void PrepareCodepointRanges(
Lukas Zilka21d8c982018-01-24 11:11:20 +0100212 const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200213 codepoint_ranges,
214 std::vector<CodepointRange>* prepared_codepoint_ranges);
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200215
216 // Returns the ratio of supported codepoints to total number of codepoints in
Lukas Zilka21d8c982018-01-24 11:11:20 +0100217 // the given token span.
218 float SupportedCodepointsRatio(const TokenSpan& token_span,
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200219 const std::vector<Token>& tokens) const;
220
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200221 // Returns true if given codepoint is covered by the given sorted vector of
222 // codepoint ranges.
223 bool IsCodepointInRanges(
224 int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200225
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200226 void PrepareIgnoredSpanBoundaryCodepoints();
227
228 // Counts the number of span boundary codepoints. If count_from_beginning is
229 // True, the counting will start at the span_start iterator (inclusive) and at
230 // maximum end at span_end (exclusive). If count_from_beginning is True, the
231 // counting will start from span_end (exclusive) and end at span_start
232 // (inclusive).
233 int CountIgnoredSpanBoundaryCodepoints(
234 const UnicodeText::const_iterator& span_start,
235 const UnicodeText::const_iterator& span_end,
236 bool count_from_beginning) const;
237
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200238 // Finds the center token index in tokens vector, using the method defined
239 // in options_.
240 int FindCenterToken(CodepointSpan span,
241 const std::vector<Token>& tokens) const;
242
Lukas Zilka40c18de2017-04-10 17:22:22 +0200243 // Tokenizes the input text using ICU tokenizer.
244 bool ICUTokenize(const std::string& context,
245 std::vector<Token>* result) const;
246
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200247 // Takes the result of ICU tokenization and retokenizes stretches of tokens
248 // made of a specific subset of characters using the internal tokenizer.
249 void InternalRetokenize(const std::string& context,
250 std::vector<Token>* tokens) const;
251
252 // Tokenizes a substring of the unicode string, appending the resulting tokens
253 // to the output vector. The resulting tokens have bounds relative to the full
254 // string. Does nothing if the start of the span is negative.
255 void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
256 std::vector<Token>* result) const;
257
Lukas Zilka726b4d22017-12-13 16:37:03 +0100258 // Removes all tokens from tokens that are not on a line (defined by calling
259 // SplitContext on the context) to which span points.
260 void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
261 std::vector<Token>* tokens) const;
262
Lukas Zilka21d8c982018-01-24 11:11:20 +0100263 private:
264 std::unique_ptr<UniLib> owned_unilib_;
265 UniLib* unilib_;
266
267 protected:
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200268 const TokenFeatureExtractor feature_extractor_;
269
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200270 // Codepoint ranges that define what codepoints are supported by the model.
271 // NOTE: Must be sorted.
272 std::vector<CodepointRange> supported_codepoint_ranges_;
273
274 // Codepoint ranges that define which tokens (consisting of which codepoints)
275 // should be re-tokenized with the internal tokenizer in the mixed
276 // tokenization mode.
277 // NOTE: Must be sorted.
278 std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
279
Matt Sharifibda09f12017-03-10 12:29:15 +0100280 private:
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200281 // Set of codepoints that will be stripped from beginning and end of
282 // predicted spans.
283 std::set<int32> ignored_span_boundary_codepoints_;
284
Lukas Zilka21d8c982018-01-24 11:11:20 +0100285 const FeatureProcessorOptions* const options_;
Matt Sharifibda09f12017-03-10 12:29:15 +0100286
287 // Mapping between token selection spans and labels ids.
288 std::map<TokenSpan, int> selection_to_label_;
289 std::vector<TokenSpan> label_to_selection_;
290
291 // Mapping between collections and labels.
292 std::map<std::string, int> collection_to_label_;
293
294 Tokenizer tokenizer_;
Matt Sharifibda09f12017-03-10 12:29:15 +0100295};
296
Lukas Zilka21d8c982018-01-24 11:11:20 +0100297} // namespace libtextclassifier2
Matt Sharifibda09f12017-03-10 12:29:15 +0100298
Lukas Zilka21d8c982018-01-24 11:11:20 +0100299#endif // KNOWLEDGE_CEREBRA_SENSE_TEXT_CLASSIFIER_LIB2_FEATURE_PROCESSOR_H_