blob: 277549ec606873862cb82a4fcb159821210c608e [file] [log] [blame]
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#include "smartselect/token-feature-extractor.h"
18
19#include "gmock/gmock.h"
20#include "gtest/gtest.h"
21
22namespace libtextclassifier {
23namespace {
24
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020025class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
26 public:
27 using TokenFeatureExtractor::TokenFeatureExtractor;
28 using TokenFeatureExtractor::HashToken;
29};
30
31TEST(TokenFeatureExtractorTest, ExtractAscii) {
Matt Sharifid40f9762017-03-14 21:24:23 +010032 TokenFeatureExtractorOptions options;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020033 options.num_buckets = 1000;
34 options.chargram_orders = std::vector<int>{1, 2, 3};
Matt Sharifid40f9762017-03-14 21:24:23 +010035 options.extract_case_feature = true;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020036 options.unicode_aware_features = false;
Matt Sharifid40f9762017-03-14 21:24:23 +010037 options.extract_selection_mask_feature = true;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020038 TestingTokenFeatureExtractor extractor(options);
Matt Sharifid40f9762017-03-14 21:24:23 +010039
40 std::vector<int> sparse_features;
41 std::vector<float> dense_features;
42
Lukas Zilka6bb39a82017-04-07 19:55:11 +020043 extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
Matt Sharifid40f9762017-03-14 21:24:23 +010044 &dense_features);
45
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020046 EXPECT_THAT(sparse_features,
47 testing::ElementsAreArray({
48 // clang-format off
49 extractor.HashToken("H"),
50 extractor.HashToken("e"),
51 extractor.HashToken("l"),
52 extractor.HashToken("l"),
53 extractor.HashToken("o"),
54 extractor.HashToken("^H"),
55 extractor.HashToken("He"),
56 extractor.HashToken("el"),
57 extractor.HashToken("ll"),
58 extractor.HashToken("lo"),
59 extractor.HashToken("o$"),
60 extractor.HashToken("^He"),
61 extractor.HashToken("Hel"),
62 extractor.HashToken("ell"),
63 extractor.HashToken("llo"),
64 extractor.HashToken("lo$")
65 // clang-format on
66 }));
Matt Sharifid40f9762017-03-14 21:24:23 +010067 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
68
69 sparse_features.clear();
70 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +020071 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
Matt Sharifid40f9762017-03-14 21:24:23 +010072 &dense_features);
73
Lukas Zilkad3bc59a2017-04-03 17:32:27 +020074 EXPECT_THAT(sparse_features,
75 testing::ElementsAreArray({
76 // clang-format off
77 extractor.HashToken("w"),
78 extractor.HashToken("o"),
79 extractor.HashToken("r"),
80 extractor.HashToken("l"),
81 extractor.HashToken("d"),
82 extractor.HashToken("!"),
83 extractor.HashToken("^w"),
84 extractor.HashToken("wo"),
85 extractor.HashToken("or"),
86 extractor.HashToken("rl"),
87 extractor.HashToken("ld"),
88 extractor.HashToken("d!"),
89 extractor.HashToken("!$"),
90 extractor.HashToken("^wo"),
91 extractor.HashToken("wor"),
92 extractor.HashToken("orl"),
93 extractor.HashToken("rld"),
94 extractor.HashToken("ld!"),
95 extractor.HashToken("d!$"),
96 // clang-format on
97 }));
Matt Sharifid40f9762017-03-14 21:24:23 +010098 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
99}
100
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200101TEST(TokenFeatureExtractorTest, ExtractUnicode) {
102 TokenFeatureExtractorOptions options;
103 options.num_buckets = 1000;
104 options.chargram_orders = std::vector<int>{1, 2, 3};
105 options.extract_case_feature = true;
106 options.unicode_aware_features = true;
107 options.extract_selection_mask_feature = true;
108 TestingTokenFeatureExtractor extractor(options);
109
110 std::vector<int> sparse_features;
111 std::vector<float> dense_features;
112
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200113 extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200114 &dense_features);
115
116 EXPECT_THAT(sparse_features,
117 testing::ElementsAreArray({
118 // clang-format off
119 extractor.HashToken("H"),
120 extractor.HashToken("ě"),
121 extractor.HashToken("l"),
122 extractor.HashToken("l"),
123 extractor.HashToken("ó"),
124 extractor.HashToken("^H"),
125 extractor.HashToken("Hě"),
126 extractor.HashToken("ěl"),
127 extractor.HashToken("ll"),
128 extractor.HashToken("ló"),
129 extractor.HashToken("ó$"),
130 extractor.HashToken("^Hě"),
131 extractor.HashToken("Hěl"),
132 extractor.HashToken("ěll"),
133 extractor.HashToken("lló"),
134 extractor.HashToken("ló$")
135 // clang-format on
136 }));
137 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
138
139 sparse_features.clear();
140 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200141 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200142 &dense_features);
143
144 EXPECT_THAT(sparse_features,
145 testing::ElementsAreArray({
146 // clang-format off
147 extractor.HashToken("w"),
148 extractor.HashToken("o"),
149 extractor.HashToken("r"),
150 extractor.HashToken("l"),
151 extractor.HashToken("d"),
152 extractor.HashToken("!"),
153 extractor.HashToken("^w"),
154 extractor.HashToken("wo"),
155 extractor.HashToken("or"),
156 extractor.HashToken("rl"),
157 extractor.HashToken("ld"),
158 extractor.HashToken("d!"),
159 extractor.HashToken("!$"),
160 extractor.HashToken("^wo"),
161 extractor.HashToken("wor"),
162 extractor.HashToken("orl"),
163 extractor.HashToken("rld"),
164 extractor.HashToken("ld!"),
165 extractor.HashToken("d!$"),
166 // clang-format on
167 }));
168 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
169}
170
171TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
172 TokenFeatureExtractorOptions options;
173 options.num_buckets = 1000;
174 options.chargram_orders = std::vector<int>{1, 2};
175 options.extract_case_feature = true;
176 options.unicode_aware_features = true;
177 options.extract_selection_mask_feature = false;
178 TokenFeatureExtractor extractor(options);
179
180 std::vector<int> sparse_features;
181 std::vector<float> dense_features;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200182 extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200183 &dense_features);
184 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
185
186 sparse_features.clear();
187 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200188 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200189 &dense_features);
190 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
191
192 sparse_features.clear();
193 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200194 extractor.Extract(Token{"Ř", 23, 29}, false, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200195 &dense_features);
196 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
197
198 sparse_features.clear();
199 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200200 extractor.Extract(Token{"ř", 23, 29}, false, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200201 &dense_features);
202 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
203}
204
205TEST(TokenFeatureExtractorTest, DigitRemapping) {
206 TokenFeatureExtractorOptions options;
207 options.num_buckets = 1000;
208 options.chargram_orders = std::vector<int>{1, 2};
209 options.remap_digits = true;
210 options.unicode_aware_features = false;
211 TokenFeatureExtractor extractor(options);
212
213 std::vector<int> sparse_features;
214 std::vector<float> dense_features;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200215 extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200216 &dense_features);
217
218 std::vector<int> sparse_features2;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200219 extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200220 &dense_features);
221 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
222
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200223 extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200224 &dense_features);
225 EXPECT_THAT(sparse_features,
226 testing::Not(testing::ElementsAreArray(sparse_features2)));
227}
228
229TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) {
230 TokenFeatureExtractorOptions options;
231 options.num_buckets = 1000;
232 options.chargram_orders = std::vector<int>{1, 2};
233 options.remap_digits = true;
234 options.unicode_aware_features = true;
235 TokenFeatureExtractor extractor(options);
236
237 std::vector<int> sparse_features;
238 std::vector<float> dense_features;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200239 extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200240 &dense_features);
241
242 std::vector<int> sparse_features2;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200243 extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200244 &dense_features);
245 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
246
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200247 extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200248 &dense_features);
249 EXPECT_THAT(sparse_features,
250 testing::Not(testing::ElementsAreArray(sparse_features2)));
251}
252
253TEST(TokenFeatureExtractorTest, RegexFeatures) {
254 TokenFeatureExtractorOptions options;
255 options.num_buckets = 1000;
256 options.chargram_orders = std::vector<int>{1, 2};
257 options.remap_digits = false;
258 options.unicode_aware_features = false;
259 options.regexp_features.push_back("^[a-z]+$"); // all lower case.
260 options.regexp_features.push_back("^[0-9]+$"); // all digits.
261 TokenFeatureExtractor extractor(options);
262
263 std::vector<int> sparse_features;
264 std::vector<float> dense_features;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200265 extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200266 &dense_features);
267 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
268
269 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200270 extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200271 &dense_features);
272 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));
273
274 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200275 extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200276 &dense_features);
277 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
278
279 dense_features.clear();
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200280 extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200281 &dense_features);
282 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
283}
284
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200285TEST(TokenFeatureExtractorTest, ExtractTooLongWord) {
286 TokenFeatureExtractorOptions options;
287 options.num_buckets = 1000;
288 options.chargram_orders = std::vector<int>{22};
289 options.extract_case_feature = true;
290 options.unicode_aware_features = true;
291 options.extract_selection_mask_feature = true;
292 TestingTokenFeatureExtractor extractor(options);
293
294 // Test that this runs. ASAN should catch problems.
295 std::vector<int> sparse_features;
296 std::vector<float> dense_features;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200297 extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true,
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200298 &sparse_features, &dense_features);
299
300 EXPECT_THAT(sparse_features,
301 testing::ElementsAreArray({
302 // clang-format off
303 extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
304 extractor.HashToken("abcdefghij\1qřstuvwxyz$"),
305 // clang-format on
306 }));
307}
308
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200309TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
310 TokenFeatureExtractorOptions options;
311 options.num_buckets = 1000;
312 options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
313 options.extract_case_feature = true;
314 options.unicode_aware_features = true;
315 options.extract_selection_mask_feature = true;
316 TestingTokenFeatureExtractor extractor_unicode(options);
317
318 options.unicode_aware_features = false;
319 TestingTokenFeatureExtractor extractor_ascii(options);
320
321 for (const std::string& input :
322 {"https://www.abcdefgh.com/in/xxxkkkvayio",
323 "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
324 "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w", "abcd",
325 "x", "Hello", "Hey,", "Hi", ""}) {
326 std::vector<int> sparse_features_unicode;
327 std::vector<float> dense_features_unicode;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200328 extractor_unicode.Extract(Token{input, 0, 0}, true,
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200329 &sparse_features_unicode,
330 &dense_features_unicode);
331
332 std::vector<int> sparse_features_ascii;
333 std::vector<float> dense_features_ascii;
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200334 extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii,
Lukas Zilka26e8c2e2017-04-06 15:54:24 +0200335 &dense_features_ascii);
336
337 EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input;
338 EXPECT_THAT(dense_features_unicode, dense_features_ascii) << input;
339 }
340}
341
Matt Sharifid40f9762017-03-14 21:24:23 +0100342TEST(TokenFeatureExtractorTest, ExtractForPadToken) {
343 TokenFeatureExtractorOptions options;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200344 options.num_buckets = 1000;
Matt Sharifid40f9762017-03-14 21:24:23 +0100345 options.chargram_orders = std::vector<int>{1, 2};
346 options.extract_case_feature = true;
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200347 options.unicode_aware_features = false;
Matt Sharifid40f9762017-03-14 21:24:23 +0100348 options.extract_selection_mask_feature = true;
349
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200350 TestingTokenFeatureExtractor extractor(options);
Matt Sharifid40f9762017-03-14 21:24:23 +0100351
352 std::vector<int> sparse_features;
353 std::vector<float> dense_features;
354
Lukas Zilka6bb39a82017-04-07 19:55:11 +0200355 extractor.Extract(Token(), false, &sparse_features, &dense_features);
Matt Sharifid40f9762017-03-14 21:24:23 +0100356
Lukas Zilkad3bc59a2017-04-03 17:32:27 +0200357 EXPECT_THAT(sparse_features,
358 testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
Matt Sharifid40f9762017-03-14 21:24:23 +0100359 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
360}
361
362} // namespace
363} // namespace libtextclassifier