blob: 85043e3995502e1bebb237f7b565311cbda7f02d [file] [log] [blame]
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
// Feature processing for FFModel (feed-forward SmartSelection model).
18
19#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
20#define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
21
#include <map>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "common/feature-extractor.h"
#include "smartselect/text-classification-model.pb.h"
#include "smartselect/token-feature-extractor.h"
#include "smartselect/tokenizer.h"
#include "smartselect/types.h"
32
33namespace libtextclassifier {
34
// Sentinel label id. Presumably used wherever no valid selection/collection
// label exists — confirm against feature-processor.cc.
// (Namespace-scope constexpr implies const, hence internal linkage, so this
// definition is safe to include from multiple translation units.)
constexpr int kInvalidLabel = -1;
36
37namespace internal {
38
39// Parses the serialized protocol buffer.
40FeatureProcessorOptions ParseSerializedOptions(
41 const std::string& serialized_options);
42
43TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
44 const FeatureProcessorOptions& options);
45
Matt Sharifibe876dc2017-03-17 17:02:43 +010046// Removes tokens that are not part of a line of the context which contains
47// given span.
48void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
49 std::vector<Token>* tokens);
Matt Sharifibda09f12017-03-10 12:29:15 +010050
51// Splits tokens that contain the selection boundary inside them.
52// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
53void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
54 std::vector<Token>* tokens);
55
Matt Sharifibe876dc2017-03-17 17:02:43 +010056// Returns the index of token that corresponds to the codepoint span.
57int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
58
59// Returns the index of token that corresponds to the middle of the codepoint
60// span.
61int CenterTokenFromMiddleOfSelection(
62 CodepointSpan span, const std::vector<Token>& selectable_tokens);
63
Matt Sharifibda09f12017-03-10 12:29:15 +010064} // namespace internal
65
66TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
67 CodepointSpan codepoint_span);
68
Matt Sharifibda09f12017-03-10 12:29:15 +010069// Takes care of preparing features for the FFModel.
70class FeatureProcessor {
71 public:
72 explicit FeatureProcessor(const FeatureProcessorOptions& options)
73 : options_(options),
74 feature_extractor_(
75 internal::BuildTokenFeatureExtractorOptions(options)),
76 feature_type_(FeatureProcessor::kFeatureTypeName,
77 options.num_buckets()),
78 tokenizer_({options.tokenization_codepoint_config().begin(),
79 options.tokenization_codepoint_config().end()}),
80 random_(new std::mt19937(std::random_device()())) {
81 MakeLabelMaps();
Lukas Zilka26e8c2e2017-04-06 15:54:24 +020082 PrepareSupportedCodepointRanges(
83 {options.supported_codepoint_ranges().begin(),
84 options.supported_codepoint_ranges().end()});
Matt Sharifibda09f12017-03-10 12:29:15 +010085 }
86
87 explicit FeatureProcessor(const std::string& serialized_options)
88 : FeatureProcessor(internal::ParseSerializedOptions(serialized_options)) {
89 }
90
Matt Sharifibda09f12017-03-10 12:29:15 +010091 // Tokenizes the input string using the selected tokenization method.
92 std::vector<Token> Tokenize(const std::string& utf8_text) const;
93
Matt Sharifibe876dc2017-03-17 17:02:43 +010094 bool GetFeatures(const std::string& context, CodepointSpan input_span,
95 std::vector<nlp_core::FeatureVector>* features,
96 std::vector<float>* extra_features,
97 std::vector<CodepointSpan>* selection_label_spans) const;
98
Matt Sharifibda09f12017-03-10 12:29:15 +010099 // NOTE: If dropout is on, subsequent calls of this function with the same
100 // arguments might return different results.
Matt Sharifibe876dc2017-03-17 17:02:43 +0100101 bool GetFeaturesAndLabels(const std::string& context,
102 CodepointSpan input_span, CodepointSpan label_span,
103 const std::string& label_collection,
Matt Sharifibda09f12017-03-10 12:29:15 +0100104 std::vector<nlp_core::FeatureVector>* features,
105 std::vector<float>* extra_features,
106 std::vector<CodepointSpan>* selection_label_spans,
107 int* selection_label,
108 CodepointSpan* selection_codepoint_label,
109 int* classification_label) const;
110
111 // Same as above but uses std::vector instead of FeatureVector.
112 // NOTE: If dropout is on, subsequent calls of this function with the same
113 // arguments might return different results.
114 bool GetFeaturesAndLabels(
Matt Sharifibe876dc2017-03-17 17:02:43 +0100115 const std::string& context, CodepointSpan input_span,
116 CodepointSpan label_span, const std::string& label_collection,
Matt Sharifibda09f12017-03-10 12:29:15 +0100117 std::vector<std::vector<std::pair<int, float>>>* features,
118 std::vector<float>* extra_features,
119 std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
120 CodepointSpan* selection_codepoint_label,
121 int* classification_label) const;
122
123 // Converts a label into a token span.
124 bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
125
126 // Gets the string value for given collection label.
127 std::string LabelToCollection(int label) const;
128
129 // Gets the total number of collections of the model.
130 int NumCollections() const { return collection_to_label_.size(); }
131
132 // Gets the name of the default collection.
133 std::string GetDefaultCollection() const {
134 return options_.collections(options_.default_collection());
135 }
136
137 FeatureProcessorOptions GetOptions() const { return options_; }
138
139 int GetSelectionLabelCount() const { return label_to_selection_.size(); }
140
141 // Sets the source of randomness.
142 void SetRandom(std::mt19937* new_random) { random_.reset(new_random); }
143
144 protected:
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200145 // Represents a codepoint range [start, end).
146 struct CodepointRange {
147 int32 start;
148 int32 end;
149
150 CodepointRange(int32 arg_start, int32 arg_end)
151 : start(arg_start), end(arg_end) {}
152 };
153
Matt Sharifibda09f12017-03-10 12:29:15 +0100154 // Extracts features for given word.
155 std::vector<int> GetWordFeatures(const std::string& word) const;
156
157 // NOTE: If dropout is on, subsequent calls of this function with the same
158 // arguments might return different results.
159 bool ComputeFeatures(int click_pos,
160 const std::vector<Token>& selectable_tokens,
161 CodepointSpan selected_span,
162 std::vector<nlp_core::FeatureVector>* features,
163 std::vector<float>* extra_features,
164 std::vector<Token>* output_tokens) const;
165
166 // Helper function that computes how much left context and how much right
167 // context should be dropped. Uses a mutable random_ member as a source of
168 // randomness.
169 bool GetContextDropoutRange(int* dropout_left, int* dropout_right) const;
170
171 // Returns the class id corresponding to the given string collection
172 // identifier. There is a catch-all class id that the function returns for
173 // unknown collections.
174 int CollectionToLabel(const std::string& collection) const;
175
176 // Prepares mapping from collection names to labels.
177 void MakeLabelMaps();
178
179 // Gets the number of spannable tokens for the model.
180 //
181 // Spannable tokens are those tokens of context, which the model predicts
182 // selection spans over (i.e., there is 1:1 correspondence between the output
183 // classes of the model and each of the spannable tokens).
184 int GetNumContextTokens() const { return options_.context_size() * 2 + 1; }
185
186 // Converts a label into a span of codepoint indices corresponding to it
187 // given output_tokens.
188 bool LabelToSpan(int label, const std::vector<Token>& output_tokens,
189 CodepointSpan* span) const;
190
191 // Converts a span to the corresponding label given output_tokens.
192 bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
193 const std::vector<Token>& output_tokens, int* label) const;
194
195 // Converts a token span to the corresponding label.
196 int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
197
Matt Sharifibe876dc2017-03-17 17:02:43 +0100198 // Finds the center token index in tokens vector, using the method defined
199 // in options_.
200 int FindCenterToken(CodepointSpan span,
201 const std::vector<Token>& tokens) const;
202
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200203 void PrepareSupportedCodepointRanges(
204 const std::vector<FeatureProcessorOptions::CodepointRange>&
205 codepoint_range_configs);
206
207 // Returns the ratio of supported codepoints to total number of codepoints in
208 // the input context around given click position.
209 float SupportedCodepointsRatio(int click_pos,
210 const std::vector<Token>& tokens) const;
211
212 // Returns true if given codepoint is supported.
213 bool IsCodepointSupported(int codepoint) const;
214
Matt Sharifibda09f12017-03-10 12:29:15 +0100215 private:
216 FeatureProcessorOptions options_;
217
218 TokenFeatureExtractor feature_extractor_;
219
220 static const char* const kFeatureTypeName;
221
222 nlp_core::NumericFeatureType feature_type_;
223
224 // Mapping between token selection spans and labels ids.
225 std::map<TokenSpan, int> selection_to_label_;
226 std::vector<TokenSpan> label_to_selection_;
227
228 // Mapping between collections and labels.
229 std::map<std::string, int> collection_to_label_;
230
231 Tokenizer tokenizer_;
232
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200233 // Codepoint ranges that define what codepoints are supported by the model.
234 // NOTE: Must be sorted.
235 std::vector<CodepointRange> supported_codepoint_ranges_;
236
Matt Sharifibda09f12017-03-10 12:29:15 +0100237 // Source of randomness.
238 mutable std::unique_ptr<std::mt19937> random_;
239};
240
241} // namespace libtextclassifier
242
243#endif // LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_