blob: 77ad7a4b58dd5f980d926ae8ca82040da21eea19 [file] [log] [blame]
Matt Sharifibda09f12017-03-10 12:29:15 +01001/*
Tony Mak6c4cc672018-09-17 11:48:50 +01002 * Copyright (C) 2018 The Android Open Source Project
Matt Sharifibda09f12017-03-10 12:29:15 +01003 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Tony Mak6c4cc672018-09-17 11:48:50 +010017#include "annotator/token-feature-extractor.h"
Matt Sharifibda09f12017-03-10 12:29:15 +010018
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +020019#include <cctype>
Matt Sharifideb722d2017-04-24 13:30:47 +020020#include <string>
21
Tony Mak6c4cc672018-09-17 11:48:50 +010022#include "utils/base/logging.h"
23#include "utils/hash/farmhash.h"
24#include "utils/strings/stringpiece.h"
25#include "utils/utf8/unicodetext.h"
Matt Sharifibda09f12017-03-10 12:29:15 +010026
Tony Mak6c4cc672018-09-17 11:48:50 +010027namespace libtextclassifier3 {
Matt Sharifibda09f12017-03-10 12:29:15 +010028
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020029namespace {
30
Matt Sharifideb722d2017-04-24 13:30:47 +020031std::string RemapTokenAscii(const std::string& token,
32 const TokenFeatureExtractorOptions& options) {
33 if (!options.remap_digits && !options.lowercase_tokens) {
34 return token;
35 }
36
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020037 std::string copy = token;
38 for (int i = 0; i < token.size(); ++i) {
Matt Sharifideb722d2017-04-24 13:30:47 +020039 if (options.remap_digits && isdigit(copy[i])) {
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020040 copy[i] = '0';
41 }
Matt Sharifideb722d2017-04-24 13:30:47 +020042 if (options.lowercase_tokens) {
43 copy[i] = tolower(copy[i]);
44 }
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020045 }
46 return copy;
47}
48
Matt Sharifideb722d2017-04-24 13:30:47 +020049void RemapTokenUnicode(const std::string& token,
50 const TokenFeatureExtractorOptions& options,
Lukas Zilka21d8c982018-01-24 11:11:20 +010051 const UniLib& unilib, UnicodeText* remapped) {
Matt Sharifideb722d2017-04-24 13:30:47 +020052 if (!options.remap_digits && !options.lowercase_tokens) {
53 // Leave remapped untouched.
54 return;
55 }
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020056
57 UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false);
Lukas Zilka21d8c982018-01-24 11:11:20 +010058 remapped->clear();
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020059 for (auto it = word.begin(); it != word.end(); ++it) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010060 if (options.remap_digits && unilib.IsDigit(*it)) {
Tony Mak51a9e542018-11-02 13:36:22 +000061 remapped->push_back('0');
Matt Sharifideb722d2017-04-24 13:30:47 +020062 } else if (options.lowercase_tokens) {
Tony Mak51a9e542018-11-02 13:36:22 +000063 remapped->push_back(unilib.ToLower(*it));
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020064 } else {
Tony Mak51a9e542018-11-02 13:36:22 +000065 remapped->push_back(*it);
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020066 }
67 }
68}
69
70} // namespace
71
72TokenFeatureExtractor::TokenFeatureExtractor(
Lukas Zilka21d8c982018-01-24 11:11:20 +010073 const TokenFeatureExtractorOptions& options, const UniLib& unilib)
74 : options_(options), unilib_(unilib) {
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020075 for (const std::string& pattern : options.regexp_features) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010076 regex_patterns_.push_back(std::unique_ptr<UniLib::RegexPattern>(
Lukas Zilkab23e2122018-02-09 10:25:19 +010077 unilib_.CreateRegexPattern(UTF8ToUnicodeText(
78 pattern.c_str(), pattern.size(), /*do_copy=*/false))));
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020079 }
80}
Matt Sharifibda09f12017-03-10 12:29:15 +010081
Lukas Zilkab23e2122018-02-09 10:25:19 +010082bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
83 std::vector<int>* sparse_features,
84 std::vector<float>* dense_features) const {
Lukas Zilka21915472018-03-08 14:48:21 +010085 if (!dense_features) {
Lukas Zilkab23e2122018-02-09 10:25:19 +010086 return false;
87 }
Lukas Zilka21915472018-03-08 14:48:21 +010088 if (sparse_features) {
89 *sparse_features = ExtractCharactergramFeatures(token);
90 }
Lukas Zilkab23e2122018-02-09 10:25:19 +010091 *dense_features = ExtractDenseFeatures(token, is_in_span);
92 return true;
93}
94
95std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
96 const Token& token) const {
97 if (options_.unicode_aware_features) {
98 return ExtractCharactergramFeaturesUnicode(token);
99 } else {
100 return ExtractCharactergramFeaturesAscii(token);
101 }
102}
103
104std::vector<float> TokenFeatureExtractor::ExtractDenseFeatures(
105 const Token& token, bool is_in_span) const {
106 std::vector<float> dense_features;
107
108 if (options_.extract_case_feature) {
109 if (options_.unicode_aware_features) {
110 UnicodeText token_unicode =
111 UTF8ToUnicodeText(token.value, /*do_copy=*/false);
112 const bool is_upper = unilib_.IsUpper(*token_unicode.begin());
113 if (!token.value.empty() && is_upper) {
114 dense_features.push_back(1.0);
115 } else {
116 dense_features.push_back(-1.0);
117 }
118 } else {
119 if (!token.value.empty() && isupper(*token.value.begin())) {
120 dense_features.push_back(1.0);
121 } else {
122 dense_features.push_back(-1.0);
123 }
124 }
125 }
126
127 if (options_.extract_selection_mask_feature) {
128 if (is_in_span) {
129 dense_features.push_back(1.0);
130 } else {
131 if (options_.unicode_aware_features) {
132 dense_features.push_back(-1.0);
133 } else {
134 dense_features.push_back(0.0);
135 }
136 }
137 }
138
139 // Add regexp features.
140 if (!regex_patterns_.empty()) {
141 UnicodeText token_unicode =
142 UTF8ToUnicodeText(token.value, /*do_copy=*/false);
143 for (int i = 0; i < regex_patterns_.size(); ++i) {
144 if (!regex_patterns_[i].get()) {
145 dense_features.push_back(-1.0);
146 continue;
147 }
148 auto matcher = regex_patterns_[i]->Matcher(token_unicode);
149 int status;
150 if (matcher->Matches(&status)) {
151 dense_features.push_back(1.0);
152 } else {
153 dense_features.push_back(-1.0);
154 }
155 }
156 }
157
158 return dense_features;
159}
160
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200161int TokenFeatureExtractor::HashToken(StringPiece token) const {
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200162 if (options_.allowed_chargrams.empty()) {
Tony Mak51a9e542018-11-02 13:36:22 +0000163 return tc3farmhash::Fingerprint64(token) % options_.num_buckets;
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200164 } else {
165 // Padding and out-of-vocabulary tokens have extra buckets reserved because
166 // they are special and important tokens, and we don't want them to share
167 // embedding with other charactergrams.
168 // TODO(zilka): Experimentally verify.
169 const int kNumExtraBuckets = 2;
170 const std::string token_string = token.ToString();
171 if (token_string == "<PAD>") {
172 return 1;
173 } else if (options_.allowed_chargrams.find(token_string) ==
174 options_.allowed_chargrams.end()) {
175 return 0; // Out-of-vocabulary.
176 } else {
Tony Mak51a9e542018-11-02 13:36:22 +0000177 return (tc3farmhash::Fingerprint64(token) %
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200178 (options_.num_buckets - kNumExtraBuckets)) +
179 kNumExtraBuckets;
180 }
181 }
Matt Sharifibda09f12017-03-10 12:29:15 +0100182}
183
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200184std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii(
185 const Token& token) const {
Matt Sharifibda09f12017-03-10 12:29:15 +0100186 std::vector<int> result;
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200187 if (token.is_padding || token.value.empty()) {
Matt Sharifibda09f12017-03-10 12:29:15 +0100188 result.push_back(HashToken("<PAD>"));
189 } else {
Matt Sharifideb722d2017-04-24 13:30:47 +0200190 const std::string word = RemapTokenAscii(token.value, options_);
Matt Sharifibda09f12017-03-10 12:29:15 +0100191
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200192 // Trim words that are over max_word_length characters.
193 const int max_word_length = options_.max_word_length;
194 std::string feature_word;
195 if (word.size() > max_word_length) {
Matt Sharifibda09f12017-03-10 12:29:15 +0100196 feature_word =
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200197 "^" + word.substr(0, max_word_length / 2) + "\1" +
198 word.substr(word.size() - max_word_length / 2, max_word_length / 2) +
Matt Sharifibda09f12017-03-10 12:29:15 +0100199 "$";
200 } else {
201 // Add a prefix and suffix to the word.
202 feature_word = "^" + word + "$";
203 }
204
205 // Upper-bound the number of charactergram extracted to avoid resizing.
206 result.reserve(options_.chargram_orders.size() * feature_word.size());
207
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200208 if (options_.chargram_orders.empty()) {
209 result.push_back(HashToken(feature_word));
210 } else {
211 // Generate the character-grams.
212 for (int chargram_order : options_.chargram_orders) {
213 if (chargram_order == 1) {
214 for (int i = 1; i < feature_word.size() - 1; ++i) {
215 result.push_back(
216 HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
217 }
218 } else {
219 for (int i = 0;
220 i < static_cast<int>(feature_word.size()) - chargram_order + 1;
221 ++i) {
222 result.push_back(HashToken(StringPiece(feature_word, /*offset=*/i,
223 /*len=*/chargram_order)));
224 }
Matt Sharifibda09f12017-03-10 12:29:15 +0100225 }
226 }
227 }
228 }
229 return result;
230}
231
// Extracts hashed charactergram features from a token, operating on Unicode
// codepoints rather than bytes. Mirrors the ASCII variant: the (possibly
// remapped) token is wrapped in "^"/"$" sentinels, middle-trimmed to
// max_word_length codepoints with a "\1" cut marker, then hashed per
// configured chargram order.
std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
    const Token& token) const {
  std::vector<int> result;
  if (token.is_padding || token.value.empty()) {
    result.push_back(HashToken("<PAD>"));
  } else {
    // `word` initially aliases token.value (do_copy=false); RemapTokenUnicode
    // rewrites it only when digit-remapping/lowercasing is enabled.
    UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false);
    RemapTokenUnicode(token.value, options_, unilib_, &word);

    // Trim the word if needed by finding a left-cut point and right-cut point.
    // Advancing from both ends in lockstep keeps at most max_word_length
    // codepoints total; for short words the iterators meet and no cut is made.
    auto left_cut = word.begin();
    auto right_cut = word.end();
    for (int i = 0; i < options_.max_word_length / 2; i++) {
      if (left_cut < right_cut) {
        ++left_cut;
      }
      if (left_cut < right_cut) {
        --right_cut;
      }
    }

    std::string feature_word;
    if (left_cut == right_cut) {
      // Iterators met: the whole word fits, no trimming needed.
      feature_word = "^" + word.UTF8Substring(word.begin(), word.end()) + "$";
    } else {
      // clang-format off
      feature_word = "^" +
                     word.UTF8Substring(word.begin(), left_cut) +
                     "\1" +
                     word.UTF8Substring(right_cut, word.end()) +
                     "$";
      // clang-format on
    }

    const UnicodeText feature_word_unicode =
        UTF8ToUnicodeText(feature_word, /*do_copy=*/false);

    // Upper-bound the number of charactergram extracted to avoid resizing.
    result.reserve(options_.chargram_orders.size() * feature_word.size());

    if (options_.chargram_orders.empty()) {
      result.push_back(HashToken(feature_word));
    } else {
      // Generate the character-grams.
      for (int chargram_order : options_.chargram_orders) {
        UnicodeText::const_iterator it_start = feature_word_unicode.begin();
        UnicodeText::const_iterator it_end = feature_word_unicode.end();
        if (chargram_order == 1) {
          // Unigrams skip the "^"/"$" sentinel codepoints.
          ++it_start;
          --it_end;
        }

        // Establish the initial window [it_chargram_start, it_chargram_end)
        // of chargram_order codepoints; bail out if the word is too short.
        UnicodeText::const_iterator it_chargram_start = it_start;
        UnicodeText::const_iterator it_chargram_end = it_start;
        bool chargram_is_complete = true;
        for (int i = 0; i < chargram_order; ++i) {
          if (it_chargram_end == it_end) {
            chargram_is_complete = false;
            break;
          }
          ++it_chargram_end;
        }
        if (!chargram_is_complete) {
          continue;
        }

        // Slide the window one codepoint at a time, hashing the raw UTF-8
        // bytes it spans (no re-encoding copy is made).
        for (; it_chargram_end <= it_end;
             ++it_chargram_start, ++it_chargram_end) {
          const int length_bytes =
              it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
          result.push_back(HashToken(
              StringPiece(it_chargram_start.utf8_data(), length_bytes)));
        }
      }
    }
  }
  return result;
}
310
Tony Mak6c4cc672018-09-17 11:48:50 +0100311} // namespace libtextclassifier3