blob: 2c64b6751d595cdba06f7dc8706c66a7284b739d [file] [log] [blame]
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
// Feature processing for FFModel (feed-forward SmartSelection model).
18
19#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
20#define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
21
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "smartselect/cached-features.h"
#include "smartselect/text-classification-model.pb.h"
#include "smartselect/token-feature-extractor.h"
#include "smartselect/tokenizer.h"
#include "smartselect/types.h"
#include "util/base/logging.h"
#include "util/utf8/unicodetext.h"
Matt Sharifibda09f12017-03-10 12:29:15 +010033
34namespace libtextclassifier {
35
// Sentinel label value (-1). As the name suggests, it marks "no valid label";
// NOTE(review): confirm the exact producers/consumers in feature-processor.cc.
constexpr int kInvalidLabel = -1;
37
// Maps a vector of sparse features and a vector of dense features to a vector
// of features that combines both.
// Arguments, in order: the sparse features, the dense features, and a float*
// output pointer. The output is written to the memory location pointed to by
// that last float* argument.
// Returns true on success, false on failure.
using FeatureVectorFn = std::function<bool(const std::vector<int>&,
                                           const std::vector<float>&, float*)>;
45
Matt Sharifibda09f12017-03-10 12:29:15 +010046namespace internal {
47
48// Parses the serialized protocol buffer.
49FeatureProcessorOptions ParseSerializedOptions(
50 const std::string& serialized_options);
51
52TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
53 const FeatureProcessorOptions& options);
54
Matt Sharifibe876dc2017-03-17 17:02:43 +010055// Removes tokens that are not part of a line of the context which contains
56// given span.
57void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
58 std::vector<Token>* tokens);
Matt Sharifibda09f12017-03-10 12:29:15 +010059
60// Splits tokens that contain the selection boundary inside them.
61// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
62void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
63 std::vector<Token>* tokens);
64
Matt Sharifibe876dc2017-03-17 17:02:43 +010065// Returns the index of token that corresponds to the codepoint span.
66int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
67
68// Returns the index of token that corresponds to the middle of the codepoint
69// span.
70int CenterTokenFromMiddleOfSelection(
71 CodepointSpan span, const std::vector<Token>& selectable_tokens);
72
Lukas Zilka6bb39a82017-04-07 19:55:11 +020073// Strips the tokens from the tokens vector that are not used for feature
74// extraction because they are out of scope, or pads them so that there is
75// enough tokens in the required context_size for all inferences with a click
76// in relative_click_span.
77void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
78 std::vector<Token>* tokens, int* click_pos);
79
Matt Sharifibda09f12017-03-10 12:29:15 +010080} // namespace internal
81
Lukas Zilka40c18de2017-04-10 17:22:22 +020082// Converts a codepoint span to a token span in the given list of tokens.
Matt Sharifibda09f12017-03-10 12:29:15 +010083TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
84 CodepointSpan codepoint_span);
85
Lukas Zilka40c18de2017-04-10 17:22:22 +020086// Converts a token span to a codepoint span in the given list of tokens.
87CodepointSpan TokenSpanToCodepointSpan(
88 const std::vector<Token>& selectable_tokens, TokenSpan token_span);
89
Lukas Zilka6bb39a82017-04-07 19:55:11 +020090// Takes care of preparing features for the span prediction model.
Matt Sharifibda09f12017-03-10 12:29:15 +010091class FeatureProcessor {
92 public:
93 explicit FeatureProcessor(const FeatureProcessorOptions& options)
Lukas Zilka6bb39a82017-04-07 19:55:11 +020094 : feature_extractor_(
Matt Sharifibda09f12017-03-10 12:29:15 +010095 internal::BuildTokenFeatureExtractorOptions(options)),
Lukas Zilka6bb39a82017-04-07 19:55:11 +020096 options_(options),
Matt Sharifibda09f12017-03-10 12:29:15 +010097 tokenizer_({options.tokenization_codepoint_config().begin(),
Lukas Zilka6bb39a82017-04-07 19:55:11 +020098 options.tokenization_codepoint_config().end()}) {
Matt Sharifibda09f12017-03-10 12:29:15 +010099 MakeLabelMaps();
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200100 PrepareCodepointRanges({options.supported_codepoint_ranges().begin(),
101 options.supported_codepoint_ranges().end()},
102 &supported_codepoint_ranges_);
103 PrepareCodepointRanges(
104 {options.internal_tokenizer_codepoint_ranges().begin(),
105 options.internal_tokenizer_codepoint_ranges().end()},
106 &internal_tokenizer_codepoint_ranges_);
Matt Sharifibda09f12017-03-10 12:29:15 +0100107 }
108
109 explicit FeatureProcessor(const std::string& serialized_options)
110 : FeatureProcessor(internal::ParseSerializedOptions(serialized_options)) {
111 }
112
Matt Sharifibda09f12017-03-10 12:29:15 +0100113 // Tokenizes the input string using the selected tokenization method.
114 std::vector<Token> Tokenize(const std::string& utf8_text) const;
115
Matt Sharifibda09f12017-03-10 12:29:15 +0100116 // Converts a label into a token span.
117 bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
118
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200119 // Gets the total number of selection labels.
120 int GetSelectionLabelCount() const { return label_to_selection_.size(); }
121
Matt Sharifibda09f12017-03-10 12:29:15 +0100122 // Gets the string value for given collection label.
123 std::string LabelToCollection(int label) const;
124
125 // Gets the total number of collections of the model.
126 int NumCollections() const { return collection_to_label_.size(); }
127
128 // Gets the name of the default collection.
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200129 std::string GetDefaultCollection() const;
130
131 const FeatureProcessorOptions& GetOptions() const { return options_; }
132
133 // Tokenizes the context and input span, and finds the click position.
134 void TokenizeAndFindClick(const std::string& context,
135 CodepointSpan input_span,
136 std::vector<Token>* tokens, int* click_pos) const;
137
138 // Extracts features as a CachedFeatures object that can be used for repeated
139 // inference over token spans in the given context.
140 bool ExtractFeatures(const std::string& context, CodepointSpan input_span,
141 TokenSpan relative_click_span,
142 const FeatureVectorFn& feature_vector_fn,
143 int feature_vector_size, std::vector<Token>* tokens,
144 int* click_pos,
145 std::unique_ptr<CachedFeatures>* cached_features) const;
146
147 // Fills selection_label_spans with CodepointSpans that correspond to the
148 // selection labels. The CodepointSpans are based on the codepoint ranges of
149 // given tokens.
150 bool SelectionLabelSpans(
151 VectorSpan<Token> tokens,
152 std::vector<CodepointSpan>* selection_label_spans) const;
153
154 int DenseFeaturesCount() const {
155 return feature_extractor_.DenseFeaturesCount();
Matt Sharifibda09f12017-03-10 12:29:15 +0100156 }
157
Matt Sharifibda09f12017-03-10 12:29:15 +0100158 protected:
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200159 // Represents a codepoint range [start, end).
160 struct CodepointRange {
161 int32 start;
162 int32 end;
163
164 CodepointRange(int32 arg_start, int32 arg_end)
165 : start(arg_start), end(arg_end) {}
166 };
167
Matt Sharifibda09f12017-03-10 12:29:15 +0100168 // Returns the class id corresponding to the given string collection
169 // identifier. There is a catch-all class id that the function returns for
170 // unknown collections.
171 int CollectionToLabel(const std::string& collection) const;
172
173 // Prepares mapping from collection names to labels.
174 void MakeLabelMaps();
175
176 // Gets the number of spannable tokens for the model.
177 //
178 // Spannable tokens are those tokens of context, which the model predicts
179 // selection spans over (i.e., there is 1:1 correspondence between the output
180 // classes of the model and each of the spannable tokens).
181 int GetNumContextTokens() const { return options_.context_size() * 2 + 1; }
182
183 // Converts a label into a span of codepoint indices corresponding to it
184 // given output_tokens.
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200185 bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
Matt Sharifibda09f12017-03-10 12:29:15 +0100186 CodepointSpan* span) const;
187
188 // Converts a span to the corresponding label given output_tokens.
189 bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
190 const std::vector<Token>& output_tokens, int* label) const;
191
192 // Converts a token span to the corresponding label.
193 int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
194
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200195 void PrepareCodepointRanges(
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200196 const std::vector<FeatureProcessorOptions::CodepointRange>&
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200197 codepoint_ranges,
198 std::vector<CodepointRange>* prepared_codepoint_ranges);
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200199
200 // Returns the ratio of supported codepoints to total number of codepoints in
201 // the input context around given click position.
202 float SupportedCodepointsRatio(int click_pos,
203 const std::vector<Token>& tokens) const;
204
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200205 // Returns true if given codepoint is covered by the given sorted vector of
206 // codepoint ranges.
207 bool IsCodepointInRanges(
208 int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200209
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200210 // Finds the center token index in tokens vector, using the method defined
211 // in options_.
212 int FindCenterToken(CodepointSpan span,
213 const std::vector<Token>& tokens) const;
214
Lukas Zilka40c18de2017-04-10 17:22:22 +0200215 // Tokenizes the input text using ICU tokenizer.
216 bool ICUTokenize(const std::string& context,
217 std::vector<Token>* result) const;
218
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200219 // Takes the result of ICU tokenization and retokenizes stretches of tokens
220 // made of a specific subset of characters using the internal tokenizer.
221 void InternalRetokenize(const std::string& context,
222 std::vector<Token>* tokens) const;
223
224 // Tokenizes a substring of the unicode string, appending the resulting tokens
225 // to the output vector. The resulting tokens have bounds relative to the full
226 // string. Does nothing if the start of the span is negative.
227 void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
228 std::vector<Token>* result) const;
229
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200230 const TokenFeatureExtractor feature_extractor_;
231
Matt Sharifif95c3bd2017-04-25 18:41:11 +0200232 // Codepoint ranges that define what codepoints are supported by the model.
233 // NOTE: Must be sorted.
234 std::vector<CodepointRange> supported_codepoint_ranges_;
235
236 // Codepoint ranges that define which tokens (consisting of which codepoints)
237 // should be re-tokenized with the internal tokenizer in the mixed
238 // tokenization mode.
239 // NOTE: Must be sorted.
240 std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
241
Matt Sharifibda09f12017-03-10 12:29:15 +0100242 private:
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200243 const FeatureProcessorOptions options_;
Matt Sharifibda09f12017-03-10 12:29:15 +0100244
245 // Mapping between token selection spans and labels ids.
246 std::map<TokenSpan, int> selection_to_label_;
247 std::vector<TokenSpan> label_to_selection_;
248
249 // Mapping between collections and labels.
250 std::map<std::string, int> collection_to_label_;
251
252 Tokenizer tokenizer_;
Matt Sharifibda09f12017-03-10 12:29:15 +0100253};
254
255} // namespace libtextclassifier
256
257#endif // LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_