blob: 33c4d75111494ace1b45acdeed9a999f3c369851 [file] [log] [blame]
Matt Sharifibda09f12017-03-10 12:29:15 +01001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Lukas Zilka21d8c982018-01-24 11:11:20 +010017#include "token-feature-extractor.h"
Matt Sharifibda09f12017-03-10 12:29:15 +010018
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +020019#include <cctype>
Matt Sharifideb722d2017-04-24 13:30:47 +020020#include <string>
21
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020022#include "util/base/logging.h"
Matt Sharifibda09f12017-03-10 12:29:15 +010023#include "util/hash/farmhash.h"
Lukas Zilka26e8c2e2017-04-06 15:54:24 +020024#include "util/strings/stringpiece.h"
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020025#include "util/utf8/unicodetext.h"
Matt Sharifibda09f12017-03-10 12:29:15 +010026
Lukas Zilka21d8c982018-01-24 11:11:20 +010027namespace libtextclassifier2 {
Matt Sharifibda09f12017-03-10 12:29:15 +010028
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020029namespace {
30
Matt Sharifideb722d2017-04-24 13:30:47 +020031std::string RemapTokenAscii(const std::string& token,
32 const TokenFeatureExtractorOptions& options) {
33 if (!options.remap_digits && !options.lowercase_tokens) {
34 return token;
35 }
36
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020037 std::string copy = token;
38 for (int i = 0; i < token.size(); ++i) {
Matt Sharifideb722d2017-04-24 13:30:47 +020039 if (options.remap_digits && isdigit(copy[i])) {
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020040 copy[i] = '0';
41 }
Matt Sharifideb722d2017-04-24 13:30:47 +020042 if (options.lowercase_tokens) {
43 copy[i] = tolower(copy[i]);
44 }
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020045 }
46 return copy;
47}
48
Matt Sharifideb722d2017-04-24 13:30:47 +020049void RemapTokenUnicode(const std::string& token,
50 const TokenFeatureExtractorOptions& options,
Lukas Zilka21d8c982018-01-24 11:11:20 +010051 const UniLib& unilib, UnicodeText* remapped) {
Matt Sharifideb722d2017-04-24 13:30:47 +020052 if (!options.remap_digits && !options.lowercase_tokens) {
53 // Leave remapped untouched.
54 return;
55 }
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020056
57 UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false);
Lukas Zilka21d8c982018-01-24 11:11:20 +010058 remapped->clear();
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020059 for (auto it = word.begin(); it != word.end(); ++it) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010060 if (options.remap_digits && unilib.IsDigit(*it)) {
61 remapped->AppendCodepoint('0');
Matt Sharifideb722d2017-04-24 13:30:47 +020062 } else if (options.lowercase_tokens) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010063 remapped->AppendCodepoint(unilib.ToLower(*it));
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020064 } else {
Lukas Zilka21d8c982018-01-24 11:11:20 +010065 remapped->AppendCodepoint(*it);
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020066 }
67 }
68}
69
70} // namespace
71
72TokenFeatureExtractor::TokenFeatureExtractor(
Lukas Zilka21d8c982018-01-24 11:11:20 +010073 const TokenFeatureExtractorOptions& options, const UniLib& unilib)
74 : options_(options), unilib_(unilib) {
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020075 for (const std::string& pattern : options.regexp_features) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010076 regex_patterns_.push_back(std::unique_ptr<UniLib::RegexPattern>(
77 unilib_.CreateRegexPattern(pattern)));
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020078 }
79}
Matt Sharifibda09f12017-03-10 12:29:15 +010080
Lukas Zilka26e8c2e2017-04-06 15:54:24 +020081int TokenFeatureExtractor::HashToken(StringPiece token) const {
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +020082 if (options_.allowed_chargrams.empty()) {
Lukas Zilka21d8c982018-01-24 11:11:20 +010083 return tc2farmhash::Fingerprint64(token) % options_.num_buckets;
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +020084 } else {
85 // Padding and out-of-vocabulary tokens have extra buckets reserved because
86 // they are special and important tokens, and we don't want them to share
87 // embedding with other charactergrams.
88 // TODO(zilka): Experimentally verify.
89 const int kNumExtraBuckets = 2;
90 const std::string token_string = token.ToString();
91 if (token_string == "<PAD>") {
92 return 1;
93 } else if (options_.allowed_chargrams.find(token_string) ==
94 options_.allowed_chargrams.end()) {
95 return 0; // Out-of-vocabulary.
96 } else {
Lukas Zilka21d8c982018-01-24 11:11:20 +010097 return (tc2farmhash::Fingerprint64(token) %
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +020098 (options_.num_buckets - kNumExtraBuckets)) +
99 kNumExtraBuckets;
100 }
101 }
Matt Sharifibda09f12017-03-10 12:29:15 +0100102}
103
104std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
105 const Token& token) const {
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200106 if (options_.unicode_aware_features) {
107 return ExtractCharactergramFeaturesUnicode(token);
108 } else {
109 return ExtractCharactergramFeaturesAscii(token);
110 }
111}
112
113std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii(
114 const Token& token) const {
Matt Sharifibda09f12017-03-10 12:29:15 +0100115 std::vector<int> result;
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200116 if (token.is_padding || token.value.empty()) {
Matt Sharifibda09f12017-03-10 12:29:15 +0100117 result.push_back(HashToken("<PAD>"));
118 } else {
Matt Sharifideb722d2017-04-24 13:30:47 +0200119 const std::string word = RemapTokenAscii(token.value, options_);
Matt Sharifibda09f12017-03-10 12:29:15 +0100120
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200121 // Trim words that are over max_word_length characters.
122 const int max_word_length = options_.max_word_length;
123 std::string feature_word;
124 if (word.size() > max_word_length) {
Matt Sharifibda09f12017-03-10 12:29:15 +0100125 feature_word =
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200126 "^" + word.substr(0, max_word_length / 2) + "\1" +
127 word.substr(word.size() - max_word_length / 2, max_word_length / 2) +
Matt Sharifibda09f12017-03-10 12:29:15 +0100128 "$";
129 } else {
130 // Add a prefix and suffix to the word.
131 feature_word = "^" + word + "$";
132 }
133
134 // Upper-bound the number of charactergram extracted to avoid resizing.
135 result.reserve(options_.chargram_orders.size() * feature_word.size());
136
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200137 if (options_.chargram_orders.empty()) {
138 result.push_back(HashToken(feature_word));
139 } else {
140 // Generate the character-grams.
141 for (int chargram_order : options_.chargram_orders) {
142 if (chargram_order == 1) {
143 for (int i = 1; i < feature_word.size() - 1; ++i) {
144 result.push_back(
145 HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
146 }
147 } else {
148 for (int i = 0;
149 i < static_cast<int>(feature_word.size()) - chargram_order + 1;
150 ++i) {
151 result.push_back(HashToken(StringPiece(feature_word, /*offset=*/i,
152 /*len=*/chargram_order)));
153 }
Matt Sharifibda09f12017-03-10 12:29:15 +0100154 }
155 }
156 }
157 }
158 return result;
159}
160
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200161std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
162 const Token& token) const {
163 std::vector<int> result;
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200164 if (token.is_padding || token.value.empty()) {
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200165 result.push_back(HashToken("<PAD>"));
166 } else {
167 UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100168 RemapTokenUnicode(token.value, options_, unilib_, &word);
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200169
170 // Trim the word if needed by finding a left-cut point and right-cut point.
171 auto left_cut = word.begin();
172 auto right_cut = word.end();
173 for (int i = 0; i < options_.max_word_length / 2; i++) {
174 if (left_cut < right_cut) {
175 ++left_cut;
176 }
177 if (left_cut < right_cut) {
178 --right_cut;
179 }
180 }
181
182 std::string feature_word;
183 if (left_cut == right_cut) {
184 feature_word = "^" + word.UTF8Substring(word.begin(), word.end()) + "$";
185 } else {
186 // clang-format off
187 feature_word = "^" +
188 word.UTF8Substring(word.begin(), left_cut) +
189 "\1" +
190 word.UTF8Substring(right_cut, word.end()) +
191 "$";
192 // clang-format on
193 }
194
195 const UnicodeText feature_word_unicode =
196 UTF8ToUnicodeText(feature_word, /*do_copy=*/false);
197
198 // Upper-bound the number of charactergram extracted to avoid resizing.
199 result.reserve(options_.chargram_orders.size() * feature_word.size());
200
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200201 if (options_.chargram_orders.empty()) {
202 result.push_back(HashToken(feature_word));
203 } else {
204 // Generate the character-grams.
205 for (int chargram_order : options_.chargram_orders) {
206 UnicodeText::const_iterator it_start = feature_word_unicode.begin();
207 UnicodeText::const_iterator it_end = feature_word_unicode.end();
208 if (chargram_order == 1) {
209 ++it_start;
210 --it_end;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200211 }
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200212
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200213 UnicodeText::const_iterator it_chargram_start = it_start;
214 UnicodeText::const_iterator it_chargram_end = it_start;
215 bool chargram_is_complete = true;
216 for (int i = 0; i < chargram_order; ++i) {
217 if (it_chargram_end == it_end) {
218 chargram_is_complete = false;
219 break;
220 }
221 ++it_chargram_end;
222 }
223 if (!chargram_is_complete) {
224 continue;
225 }
226
227 for (; it_chargram_end <= it_end;
228 ++it_chargram_start, ++it_chargram_end) {
229 const int length_bytes =
230 it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
231 result.push_back(HashToken(
232 StringPiece(it_chargram_start.utf8_data(), length_bytes)));
233 }
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200234 }
235 }
236 }
237 return result;
238}
239
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200240bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
Matt Sharifibda09f12017-03-10 12:29:15 +0100241 std::vector<int>* sparse_features,
242 std::vector<float>* dense_features) const {
243 if (sparse_features == nullptr || dense_features == nullptr) {
244 return false;
245 }
246
247 *sparse_features = ExtractCharactergramFeatures(token);
248
249 if (options_.extract_case_feature) {
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200250 if (options_.unicode_aware_features) {
251 UnicodeText token_unicode =
252 UTF8ToUnicodeText(token.value, /*do_copy=*/false);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100253 const bool is_upper = unilib_.IsUpper(*token_unicode.begin());
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200254 if (!token.value.empty() && is_upper) {
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200255 dense_features->push_back(1.0);
256 } else {
257 dense_features->push_back(-1.0);
258 }
Matt Sharifibda09f12017-03-10 12:29:15 +0100259 } else {
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200260 if (!token.value.empty() && isupper(*token.value.begin())) {
261 dense_features->push_back(1.0);
262 } else {
263 dense_features->push_back(-1.0);
264 }
Matt Sharifibda09f12017-03-10 12:29:15 +0100265 }
266 }
267
268 if (options_.extract_selection_mask_feature) {
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200269 if (is_in_span) {
Matt Sharifibda09f12017-03-10 12:29:15 +0100270 dense_features->push_back(1.0);
271 } else {
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200272 if (options_.unicode_aware_features) {
273 dense_features->push_back(-1.0);
274 } else {
275 dense_features->push_back(0.0);
276 }
Matt Sharifibda09f12017-03-10 12:29:15 +0100277 }
278 }
279
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200280 // Add regexp features.
281 if (!regex_patterns_.empty()) {
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200282 for (int i = 0; i < regex_patterns_.size(); ++i) {
283 if (!regex_patterns_[i].get()) {
284 dense_features->push_back(-1.0);
285 continue;
286 }
287
Lukas Zilka21d8c982018-01-24 11:11:20 +0100288 if (regex_patterns_[i]->Matches(token.value)) {
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200289 dense_features->push_back(1.0);
290 } else {
291 dense_features->push_back(-1.0);
292 }
293 }
294 }
Lukas Zilkae5ea2ab2017-10-11 10:50:05 +0200295
Matt Sharifibda09f12017-03-10 12:29:15 +0100296 return true;
297}
298
Lukas Zilka21d8c982018-01-24 11:11:20 +0100299} // namespace libtextclassifier2