blob: 311be3e8eb7c27ac730fbb7151adf16d0eebb58a [file] [log] [blame]
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
// Feature processing for FFModel (the feed-forward SmartSelection model).
18
19#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
20#define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
21
#include <map>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "common/feature-extractor.h"
#include "smartselect/text-classification-model.pb.h"
#include "smartselect/token-feature-extractor.h"
#include "smartselect/tokenizer.h"
#include "smartselect/types.h"
32
33namespace libtextclassifier {
34
// Label value reserved to mean "invalid / no label". Judging by the name it
// is a sentinel; where it is produced and consumed lives outside this header.
constexpr int kInvalidLabel{-1};
36
37namespace internal {
38
39// Parses the serialized protocol buffer.
40FeatureProcessorOptions ParseSerializedOptions(
41 const std::string& serialized_options);
42
43TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
44 const FeatureProcessorOptions& options);
45
Matt Sharifibe876dc2017-03-17 17:02:43 +010046// Removes tokens that are not part of a line of the context which contains
47// given span.
48void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
49 std::vector<Token>* tokens);
Matt Sharifibda09f12017-03-10 12:29:15 +010050
51// Splits tokens that contain the selection boundary inside them.
52// E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
53void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
54 std::vector<Token>* tokens);
55
Matt Sharifibe876dc2017-03-17 17:02:43 +010056// Returns the index of token that corresponds to the codepoint span.
57int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
58
59// Returns the index of token that corresponds to the middle of the codepoint
60// span.
61int CenterTokenFromMiddleOfSelection(
62 CodepointSpan span, const std::vector<Token>& selectable_tokens);
63
Matt Sharifibda09f12017-03-10 12:29:15 +010064} // namespace internal
65
66TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
67 CodepointSpan codepoint_span);
68
Matt Sharifibda09f12017-03-10 12:29:15 +010069// Takes care of preparing features for the FFModel.
70class FeatureProcessor {
71 public:
72 explicit FeatureProcessor(const FeatureProcessorOptions& options)
73 : options_(options),
74 feature_extractor_(
75 internal::BuildTokenFeatureExtractorOptions(options)),
76 feature_type_(FeatureProcessor::kFeatureTypeName,
77 options.num_buckets()),
78 tokenizer_({options.tokenization_codepoint_config().begin(),
79 options.tokenization_codepoint_config().end()}),
80 random_(new std::mt19937(std::random_device()())) {
81 MakeLabelMaps();
82 }
83
84 explicit FeatureProcessor(const std::string& serialized_options)
85 : FeatureProcessor(internal::ParseSerializedOptions(serialized_options)) {
86 }
87
88 CodepointSpan ClickRandomTokenInSelection(
89 const SelectionWithContext& selection_with_context) const;
90
91 // Tokenizes the input string using the selected tokenization method.
92 std::vector<Token> Tokenize(const std::string& utf8_text) const;
93
Matt Sharifibe876dc2017-03-17 17:02:43 +010094 bool GetFeatures(const std::string& context, CodepointSpan input_span,
95 std::vector<nlp_core::FeatureVector>* features,
96 std::vector<float>* extra_features,
97 std::vector<CodepointSpan>* selection_label_spans) const;
98
Matt Sharifibda09f12017-03-10 12:29:15 +010099 // NOTE: If dropout is on, subsequent calls of this function with the same
100 // arguments might return different results.
Matt Sharifibe876dc2017-03-17 17:02:43 +0100101 bool GetFeaturesAndLabels(const std::string& context,
102 CodepointSpan input_span, CodepointSpan label_span,
103 const std::string& label_collection,
Matt Sharifibda09f12017-03-10 12:29:15 +0100104 std::vector<nlp_core::FeatureVector>* features,
105 std::vector<float>* extra_features,
106 std::vector<CodepointSpan>* selection_label_spans,
107 int* selection_label,
108 CodepointSpan* selection_codepoint_label,
109 int* classification_label) const;
110
111 // Same as above but uses std::vector instead of FeatureVector.
112 // NOTE: If dropout is on, subsequent calls of this function with the same
113 // arguments might return different results.
114 bool GetFeaturesAndLabels(
Matt Sharifibe876dc2017-03-17 17:02:43 +0100115 const std::string& context, CodepointSpan input_span,
116 CodepointSpan label_span, const std::string& label_collection,
Matt Sharifibda09f12017-03-10 12:29:15 +0100117 std::vector<std::vector<std::pair<int, float>>>* features,
118 std::vector<float>* extra_features,
119 std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
120 CodepointSpan* selection_codepoint_label,
121 int* classification_label) const;
122
123 // Converts a label into a token span.
124 bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
125
126 // Gets the string value for given collection label.
127 std::string LabelToCollection(int label) const;
128
129 // Gets the total number of collections of the model.
130 int NumCollections() const { return collection_to_label_.size(); }
131
132 // Gets the name of the default collection.
133 std::string GetDefaultCollection() const {
134 return options_.collections(options_.default_collection());
135 }
136
137 FeatureProcessorOptions GetOptions() const { return options_; }
138
139 int GetSelectionLabelCount() const { return label_to_selection_.size(); }
140
141 // Sets the source of randomness.
142 void SetRandom(std::mt19937* new_random) { random_.reset(new_random); }
143
144 protected:
145 // Extracts features for given word.
146 std::vector<int> GetWordFeatures(const std::string& word) const;
147
148 // NOTE: If dropout is on, subsequent calls of this function with the same
149 // arguments might return different results.
150 bool ComputeFeatures(int click_pos,
151 const std::vector<Token>& selectable_tokens,
152 CodepointSpan selected_span,
153 std::vector<nlp_core::FeatureVector>* features,
154 std::vector<float>* extra_features,
155 std::vector<Token>* output_tokens) const;
156
157 // Helper function that computes how much left context and how much right
158 // context should be dropped. Uses a mutable random_ member as a source of
159 // randomness.
160 bool GetContextDropoutRange(int* dropout_left, int* dropout_right) const;
161
162 // Returns the class id corresponding to the given string collection
163 // identifier. There is a catch-all class id that the function returns for
164 // unknown collections.
165 int CollectionToLabel(const std::string& collection) const;
166
167 // Prepares mapping from collection names to labels.
168 void MakeLabelMaps();
169
170 // Gets the number of spannable tokens for the model.
171 //
172 // Spannable tokens are those tokens of context, which the model predicts
173 // selection spans over (i.e., there is 1:1 correspondence between the output
174 // classes of the model and each of the spannable tokens).
175 int GetNumContextTokens() const { return options_.context_size() * 2 + 1; }
176
177 // Converts a label into a span of codepoint indices corresponding to it
178 // given output_tokens.
179 bool LabelToSpan(int label, const std::vector<Token>& output_tokens,
180 CodepointSpan* span) const;
181
182 // Converts a span to the corresponding label given output_tokens.
183 bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
184 const std::vector<Token>& output_tokens, int* label) const;
185
186 // Converts a token span to the corresponding label.
187 int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
188
Matt Sharifibe876dc2017-03-17 17:02:43 +0100189 // Finds tokens that are part of the selection.
Matt Sharifi32ebfba2017-03-13 20:57:28 +0100190 // NOTE: Will select all tokens that somehow overlap with the selection.
191 std::vector<Token> FindTokensInSelection(
192 const std::vector<Token>& selectable_tokens,
193 const SelectionWithContext& selection_with_context) const;
194
Matt Sharifibe876dc2017-03-17 17:02:43 +0100195 // Finds the center token index in tokens vector, using the method defined
196 // in options_.
197 int FindCenterToken(CodepointSpan span,
198 const std::vector<Token>& tokens) const;
199
Matt Sharifibda09f12017-03-10 12:29:15 +0100200 private:
201 FeatureProcessorOptions options_;
202
203 TokenFeatureExtractor feature_extractor_;
204
205 static const char* const kFeatureTypeName;
206
207 nlp_core::NumericFeatureType feature_type_;
208
209 // Mapping between token selection spans and labels ids.
210 std::map<TokenSpan, int> selection_to_label_;
211 std::vector<TokenSpan> label_to_selection_;
212
213 // Mapping between collections and labels.
214 std::map<std::string, int> collection_to_label_;
215
216 Tokenizer tokenizer_;
217
218 // Source of randomness.
219 mutable std::unique_ptr<std::mt19937> random_;
220};
221
222} // namespace libtextclassifier
223
224#endif // LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_