blob: a39a7893be47d33ea1187e84e27dc986f93f031a [file] [log] [blame]
Matt Sharifibda09f12017-03-10 12:29:15 +01001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Feature processing for FFModel (feed-forward SmartSelection model).
18
19#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
20#define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
21
#include <functional>
#include <map>
22#include <memory>
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +020023#include <set>
Matt Sharifibda09f12017-03-10 12:29:15 +010024#include <string>
#include <utility>
25#include <vector>
26
Lukas Zilka6bb39a82017-04-07 19:55:11 +020027#include "smartselect/cached-features.h"
Matt Sharifibda09f12017-03-10 12:29:15 +010028#include "smartselect/text-classification-model.pb.h"
29#include "smartselect/token-feature-extractor.h"
30#include "smartselect/tokenizer.h"
31#include "smartselect/types.h"
Lukas Zilka6bb39a82017-04-07 19:55:11 +020032#include "util/base/logging.h"
Matt Sharifif95c3bd2017-04-25 18:41:11 +020033#include "util/utf8/unicodetext.h"
Matt Sharifibda09f12017-03-10 12:29:15 +010034
35namespace libtextclassifier {
36
// Label value (-1) used to denote an invalid label.
// NOTE(review): the call sites that produce or test this value are not visible
// in this header -- confirm the exact contract against the implementation.
37constexpr int kInvalidLabel = -1;
38
Lukas Zilka6bb39a82017-04-07 19:55:11 +020039// Callback type that maps a vector of sparse features and a vector of dense
40// features to a single vector of features that combines both.
41// The output is written to the memory location pointed to by the last float*
42// argument. The size of that buffer is not part of this signature, so the
// caller and the callback have to agree on it separately.
43// Returns true on success, false on failure.
44using FeatureVectorFn = std::function<bool(const std::vector<int>&,
45 const std::vector<float>&, float*)>;
46
Matt Sharifibda09f12017-03-10 12:29:15 +010047namespace internal {
48
49// Parses the serialized protocol buffer into a FeatureProcessorOptions value.
// Used by the FeatureProcessor(const std::string&) constructor.
// NOTE(review): the behavior on malformed input is not visible in this header.
50FeatureProcessorOptions ParseSerializedOptions(
51 const std::string& serialized_options);
52
// Builds the TokenFeatureExtractorOptions from the given feature processor
// options. The FeatureProcessor constructor passes the result to its
// TokenFeatureExtractor member.
53TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
54 const FeatureProcessorOptions& options);
55
Matt Sharifibe876dc2017-03-17 17:02:43 +010056// Removes, in place, the tokens that are not part of the line of the context
57// which contains the given span.
58void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
59 std::vector<Token>* tokens);
Matt Sharifibda09f12017-03-10 12:29:15 +010060
61// Splits, in place, tokens that contain the selection boundary inside them.
62// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
63void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
64 std::vector<Token>* tokens);
65
Matt Sharifibe876dc2017-03-17 17:02:43 +010066// Returns the index of the token that corresponds to the codepoint span.
// NOTE(review): the value returned when no token matches is not visible in
// this header -- confirm against the implementation.
67int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
68
69// Returns the index of the token that corresponds to the middle of the
70// codepoint span.
71int CenterTokenFromMiddleOfSelection(
72 CodepointSpan span, const std::vector<Token>& selectable_tokens);
73
Lukas Zilka6bb39a82017-04-07 19:55:11 +020074// Strips the tokens from the tokens vector that are not used for feature
75// extraction because they are out of scope, or pads them so that there are
76// enough tokens in the required context_size for all inferences with a click
77// in relative_click_span.
// Both tokens and click_pos are passed by pointer.
// NOTE(review): click_pos is presumably adjusted to stay consistent with the
// modified token vector -- confirm against the implementation.
78void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
79 std::vector<Token>* tokens, int* click_pos);
80
Matt Sharifibda09f12017-03-10 12:29:15 +010081} // namespace internal
82
Lukas Zilka40c18de2017-04-10 17:22:22 +020083// Converts a codepoint span to a token span in the given list of tokens.
// selectable_tokens: the tokens against which codepoint_span is resolved.
// NOTE(review): how spans that only partially overlap a token are handled is
// not visible in this header -- confirm against the implementation.
Matt Sharifibda09f12017-03-10 12:29:15 +010084TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
85 CodepointSpan codepoint_span);
86
Lukas Zilka40c18de2017-04-10 17:22:22 +020087// Converts a token span to a codepoint span in the given list of tokens.
// selectable_tokens: the tokens that token_span indexes into.
// NOTE(review): the behavior for a token span that falls outside
// selectable_tokens is not visible in this header -- confirm.
88CodepointSpan TokenSpanToCodepointSpan(
89 const std::vector<Token>& selectable_tokens, TokenSpan token_span);
90
Lukas Zilka6bb39a82017-04-07 19:55:11 +020091// Takes care of preparing features for the span prediction model.
// The class bundles a tokenizer, a token feature extractor and the label
// mappings (selection spans <-> label ids, collection names -> label ids), all
// configured from a single FeatureProcessorOptions instance.
Matt Sharifibda09f12017-03-10 12:29:15 +010092class FeatureProcessor {
93 public:
  // Constructs the processor from already-parsed options. The options are
  // copied into options_, the tokenizer and the feature extractor are
  // configured from them, and the helper calls in the body prepare the label
  // maps, the codepoint range tables and the ignored span boundary codepoints.
94 explicit FeatureProcessor(const FeatureProcessorOptions& options)
    // Built from the constructor argument rather than from options_: options_
    // is declared after feature_extractor_ and is therefore initialized later.
Lukas Zilka6bb39a82017-04-07 19:55:11 +020095 : feature_extractor_(
Matt Sharifibda09f12017-03-10 12:29:15 +010096 internal::BuildTokenFeatureExtractorOptions(options)),
Lukas Zilka6bb39a82017-04-07 19:55:11 +020097 options_(options),
Matt Sharifibda09f12017-03-10 12:29:15 +010098 tokenizer_({options.tokenization_codepoint_config().begin(),
Lukas Zilka6bb39a82017-04-07 19:55:11 +020099 options.tokenization_codepoint_config().end()}) {
Matt Sharifibda09f12017-03-10 12:29:15 +0100100 MakeLabelMaps();
    // Each braced iterator pair builds a temporary vector that is handed to
    // PrepareCodepointRanges; the prepared result is stored in the member
    // passed as the second argument.
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200101 PrepareCodepointRanges({options.supported_codepoint_ranges().begin(),
102 options.supported_codepoint_ranges().end()},
103 &supported_codepoint_ranges_);
104 PrepareCodepointRanges(
105 {options.internal_tokenizer_codepoint_ranges().begin(),
106 options.internal_tokenizer_codepoint_ranges().end()},
107 &internal_tokenizer_codepoint_ranges_);
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200108 PrepareIgnoredSpanBoundaryCodepoints();
Matt Sharifibda09f12017-03-10 12:29:15 +0100109 }
110
  // Constructs the processor from a serialized options protocol buffer by
  // delegating to the constructor above with the parsed options.
111 explicit FeatureProcessor(const std::string& serialized_options)
112 : FeatureProcessor(internal::ParseSerializedOptions(serialized_options)) {
113 }
114
Matt Sharifibda09f12017-03-10 12:29:15 +0100115 // Tokenizes the input string using the selected tokenization method.
116 std::vector<Token> Tokenize(const std::string& utf8_text) const;
117
Matt Sharifibda09f12017-03-10 12:29:15 +0100118 // Converts a label into a token span, written to token_span.
  // NOTE(review): the bool result presumably signals whether the label was
  // valid -- confirm against the implementation.
119 bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
120
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200121 // Gets the total number of selection labels, i.e. the size of
  // label_to_selection_ (implicitly converted from the container's size type
  // to int).
122 int GetSelectionLabelCount() const { return label_to_selection_.size(); }
123
Matt Sharifibda09f12017-03-10 12:29:15 +0100124 // Gets the string value for given collection label.
125 std::string LabelToCollection(int label) const;
126
127 // Gets the total number of collections of the model, i.e. the size of
  // collection_to_label_ (implicitly converted to int).
128 int NumCollections() const { return collection_to_label_.size(); }
129
130 // Gets the name of the default collection.
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200131 std::string GetDefaultCollection() const;
132
  // Returns a reference to the copy of the options this processor was
  // constructed with.
133 const FeatureProcessorOptions& GetOptions() const { return options_; }
134
135 // Tokenizes the context and input span, and finds the click position.
  // tokens and click_pos are passed by pointer; presumably they receive the
  // resulting tokens and the index of the clicked token -- TODO confirm.
136 void TokenizeAndFindClick(const std::string& context,
137 CodepointSpan input_span,
138 std::vector<Token>* tokens, int* click_pos) const;
139
140 // Extracts features as a CachedFeatures object that can be used for repeated
141 // inference over token spans in the given context.
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200142 // When relative_click_span == {kInvalidIndex, kInvalidIndex} then all tokens
143 // extracted from context will be considered.
  // feature_vector_fn is the callback that combines sparse and dense features
  // (see FeatureVectorFn); tokens, click_pos and cached_features are passed by
  // pointer.
  // NOTE(review): feature_vector_size is presumably the length of the vector
  // written by feature_vector_fn, and the bool result presumably reports
  // success -- confirm against the implementation.
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200144 bool ExtractFeatures(const std::string& context, CodepointSpan input_span,
145 TokenSpan relative_click_span,
146 const FeatureVectorFn& feature_vector_fn,
147 int feature_vector_size, std::vector<Token>* tokens,
148 int* click_pos,
149 std::unique_ptr<CachedFeatures>* cached_features) const;
150
151 // Fills selection_label_spans with CodepointSpans that correspond to the
152 // selection labels. The CodepointSpans are based on the codepoint ranges of
153 // given tokens.
154 bool SelectionLabelSpans(
155 VectorSpan<Token> tokens,
156 std::vector<CodepointSpan>* selection_label_spans) const;
157
  // Returns the dense feature count reported by the token feature extractor.
158 int DenseFeaturesCount() const {
159 return feature_extractor_.DenseFeaturesCount();
Matt Sharifibda09f12017-03-10 12:29:15 +0100160 }
161
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200162 // Strips boundary codepoints from the span in context and returns the new
163 // start and end indices. If the span consists entirely of boundary
164 // codepoints, the first index of span is returned for both indices.
165 CodepointSpan StripBoundaryCodepoints(const std::string& context,
166 CodepointSpan span) const;
167
Matt Sharifibda09f12017-03-10 12:29:15 +0100168 protected:
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200169 // Represents a codepoint range [start, end).
170 struct CodepointRange {
171 int32 start;
172 int32 end;
173
174 CodepointRange(int32 arg_start, int32 arg_end)
175 : start(arg_start), end(arg_end) {}
176 };
177
Matt Sharifibda09f12017-03-10 12:29:15 +0100178 // Returns the class id corresponding to the given string collection
179 // identifier. There is a catch-all class id that the function returns for
180 // unknown collections.
181 int CollectionToLabel(const std::string& collection) const;
182
183 // Prepares mapping from collection names to labels. Called from the
  // constructor.
  // NOTE(review): presumably also fills selection_to_label_ and
  // label_to_selection_ -- confirm against the implementation.
184 void MakeLabelMaps();
185
186 // Gets the number of spannable tokens for the model.
187 //
188 // Spannable tokens are those tokens of context, which the model predicts
189 // selection spans over (i.e., there is 1:1 correspondence between the output
190 // classes of the model and each of the spannable tokens).
  //
  // Computed as 2 * context_size + 1; presumably context_size tokens on each
  // side of the click plus the clicked token itself -- TODO confirm.
191 int GetNumContextTokens() const { return options_.context_size() * 2 + 1; }
192
193 // Converts a label into a span of codepoint indices corresponding to it
194 // given output_tokens. The result is written to span.
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200195 bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
Matt Sharifibda09f12017-03-10 12:29:15 +0100196 CodepointSpan* span) const;
197
198 // Converts a span to the corresponding label given output_tokens. The result
  // is written to label.
199 bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
200 const std::vector<Token>& output_tokens, int* label) const;
201
202 // Converts a token span to the corresponding label.
  // NOTE(review): the value returned for a span without a label is not visible
  // here (possibly kInvalidLabel) -- confirm against the implementation.
203 int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
204
  // Converts the codepoint ranges taken from the options into CodepointRange
  // values written to prepared_codepoint_ranges. Called from the constructor
  // for both codepoint range members.
  // NOTE(review): those members are documented as "must be sorted"; presumably
  // this function establishes that order -- confirm.
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200205 void PrepareCodepointRanges(
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200206 const std::vector<FeatureProcessorOptions::CodepointRange>&
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200207 codepoint_ranges,
208 std::vector<CodepointRange>* prepared_codepoint_ranges);
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200209
210 // Returns the ratio of supported codepoints to total number of codepoints in
211 // the input context around given click position.
212 float SupportedCodepointsRatio(int click_pos,
213 const std::vector<Token>& tokens) const;
214
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200215 // Returns true if given codepoint is covered by the given sorted vector of
216 // codepoint ranges.
217 bool IsCodepointInRanges(
218 int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200219
  // Called from the constructor.
  // NOTE(review): presumably fills ignored_span_boundary_codepoints_ from
  // options_ -- confirm against the implementation.
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200220 void PrepareIgnoredSpanBoundaryCodepoints();
221
222 // Counts the number of span boundary codepoints. If count_from_beginning is
223 // true, the counting will start at the span_start iterator (inclusive) and at
224 // maximum end at span_end (exclusive). If count_from_beginning is false, the
225 // counting will start from span_end (exclusive) and end at span_start
226 // (inclusive).
  // NOTE(review): the backward-counting case is assumed to correspond to
  // count_from_beginning == false -- confirm against the implementation.
227 int CountIgnoredSpanBoundaryCodepoints(
228 const UnicodeText::const_iterator& span_start,
229 const UnicodeText::const_iterator& span_end,
230 bool count_from_beginning) const;
231
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200232 // Finds the center token index in tokens vector, using the method defined
233 // in options_.
234 int FindCenterToken(CodepointSpan span,
235 const std::vector<Token>& tokens) const;
236
Lukas Zilka40c18de2017-04-10 17:22:22 +0200237 // Tokenizes the input text using ICU tokenizer. The tokens are written to
  // result.
238 bool ICUTokenize(const std::string& context,
239 std::vector<Token>* result) const;
240
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200241 // Takes the result of ICU tokenization and retokenizes stretches of tokens
242 // made of a specific subset of characters using the internal tokenizer.
243 void InternalRetokenize(const std::string& context,
244 std::vector<Token>* tokens) const;
245
246 // Tokenizes a substring of the unicode string, appending the resulting tokens
247 // to the output vector. The resulting tokens have bounds relative to the full
248 // string. Does nothing if the start of the span is negative.
249 void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
250 std::vector<Token>* result) const;
251
  // Token feature extractor, configured in the constructor from the options
  // produced by internal::BuildTokenFeatureExtractorOptions.
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200252 const TokenFeatureExtractor feature_extractor_;
253
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200254 // Codepoint ranges that define what codepoints are supported by the model.
255 // NOTE: Must be sorted.
256 std::vector<CodepointRange> supported_codepoint_ranges_;
257
258 // Codepoint ranges that define which tokens (consisting of which codepoints)
259 // should be re-tokenized with the internal tokenizer in the mixed
260 // tokenization mode.
261 // NOTE: Must be sorted.
262 std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
263
Matt Sharifibda09f12017-03-10 12:29:15 +0100264 private:
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200265 // Set of codepoints that will be stripped from beginning and end of
266 // predicted spans.
267 std::set<int32> ignored_span_boundary_codepoints_;
268
  // Copy of the options passed to the constructor; exposed via GetOptions().
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200269 const FeatureProcessorOptions options_;
Matt Sharifibda09f12017-03-10 12:29:15 +0100270
271 // Mapping between token selection spans and label ids (both directions).
272 std::map<TokenSpan, int> selection_to_label_;
273 std::vector<TokenSpan> label_to_selection_;
274
275 // Mapping between collections and labels.
276 std::map<std::string, int> collection_to_label_;
277
  // Tokenizer configured in the constructor from the options'
  // tokenization_codepoint_config entries.
278 Tokenizer tokenizer_;
Matt Sharifibda09f12017-03-10 12:29:15 +0100279};
280
281} // namespace libtextclassifier
282
283#endif // LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_